mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
38 lines
1.3 KiB
Python
38 lines
1.3 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import argparse
|
|
import os
|
|
import shutil
|
|
|
|
import torch
|
|
|
|
|
|
def convert_single_pth(fullname):
    """Split a legacy checkpoint file into model weights and trainer state.

    The file at ``fullname`` is loaded onto the CPU. If the loaded object
    has a ``'state_dict'`` key, that entry is written back to ``fullname``
    and the remaining entries are written next to it as
    ``<name>_trainer_state<ext>``. If there is no ``'state_dict'`` key, the
    loaded object is re-saved to ``fullname`` as is and no trainer-state
    file is created.

    Args:
        fullname: Path of the checkpoint file to convert in place.

    NOTE: depending on the installed torch version, ``torch.load`` may
    unpickle arbitrary objects; only run this on checkpoints you trust.
    """
    filename, ext = os.path.splitext(fullname)
    checkpoint = torch.load(fullname, map_location='cpu')

    if 'state_dict' not in checkpoint:
        # Already a bare state dict: re-save it unchanged.
        torch.save(checkpoint, fullname)
        return

    state_dict = checkpoint.pop('state_dict')
    # Write the trainer state before overwriting the source file, so a
    # failure here leaves the original checkpoint intact and re-runnable.
    torch.save(checkpoint, filename + '_trainer_state' + ext)
    torch.save(state_dict, fullname)
|
|
|
|
|
|
# This script splits *.pth checkpoint files generated before version 1.3.1 into two files.
# It takes a single argument, --dir: the directory that contains the *.pth files.
# NOTE: If you are using this script to convert the checkpoints of GPT3 or other sharded models,
# please rename the checkpoint files manually after the conversion.
|
|
def main(argv=None):
    """Convert every ``*.pth`` file directly inside the ``--dir`` folder.

    Each file is first copied to ``<file>.legacy`` as a backup and then
    passed to ``convert_single_pth``. Sub-directories are not traversed.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv``.

    Raises:
        SystemExit: If ``--dir`` is missing, empty, or not a directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dir', required=True, help='The dir contains the *.pth files.')
    args = parser.parse_args(argv)
    folder = args.dir
    # Explicit validation instead of `assert`, which is stripped under -O.
    if not folder or not os.path.isdir(folder):
        parser.error('--dir must point to an existing directory')

    pth_files = [
        name for name in os.listdir(folder) if name.endswith('.pth')
    ]
    for name in pth_files:
        source = os.path.join(folder, name)
        backup = source + '.legacy'
        # Keep an existing backup: on a re-run the source is already
        # converted, and copying it again would destroy the original.
        if not os.path.exists(backup):
            shutil.copy(source, backup)
        convert_single_pth(source)


if __name__ == '__main__':
    main()
|