# AudioGPT/audio_to_text/captioning/utils/remove_optimizer.py
# (Python, 19 lines, 471 B; committed 2023-03-28 23:30:18 +08:00)
import argparse
import os

import torch
def main(checkpoint, keys=("optimizer", "lr_scheduler")) -> list:
    """Strip training-only state from a saved checkpoint, in place.

    Loads ``checkpoint`` with ``torch.load`` onto the CPU and removes each
    top-level entry named in ``keys``. By default these are the optimizer
    and learning-rate-scheduler state. The result is written back to the
    same path. The new contents go to a sibling ``.tmp`` file that is then
    renamed over the original with ``os.replace`` (atomic on POSIX). An
    interrupted save therefore cannot leave a truncated checkpoint behind.

    Args:
        checkpoint: Path to the checkpoint file. If it is a symlink, the
            file it points to is rewritten and the link itself is kept.
        keys: Top-level keys to remove. Keys that are absent are ignored.

    Returns:
        The keys that were actually removed, in ``keys`` order. When nothing
        is removed, or the loaded object is not a dict, the file is left
        untouched and an empty list is returned.

    Warning:
        ``torch.load`` unpickles the file. Only run this on checkpoints
        from a trusted source.
    """
    state_dict = torch.load(checkpoint, map_location="cpu")
    if not isinstance(state_dict, dict):
        # Not a key-addressable checkpoint (e.g. a pickled module or list):
        # there is nothing to strip, so do not rewrite the file.
        return []

    removed = []
    for key in keys:
        if key in state_dict:
            del state_dict[key]
            removed.append(key)
    if not removed:
        # Avoid needlessly re-serializing a potentially large file.
        return removed

    # Resolve symlinks so we replace the real file, and write the temp file
    # next to it: a rename is only atomic within a single filesystem.
    target = os.path.realpath(checkpoint)
    tmp_path = target + ".tmp"
    try:
        torch.save(state_dict, tmp_path)
        os.replace(tmp_path, target)
    except BaseException:
        # Never leave a half-written temp file around; the original
        # checkpoint is still intact because it was not touched yet.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return removed
if __name__ == "__main__":
    # Command-line entry point: strip the given checkpoint in place.
    parser = argparse.ArgumentParser(
        description="Remove optimizer and lr_scheduler state from a PyTorch "
        "checkpoint, overwriting the file in place."
    )
    parser.add_argument(
        "checkpoint",
        type=str,
        help="path to the checkpoint file to strip (modified in place)",
    )
    args = parser.parse_args()
    main(args.checkpoint)