# Copyright (c) Alibaba, Inc. and its affiliates.
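
# Split or merge a Megatron-based (tensor-parallel) checkpoint: the model is
# loaded with one process per tensor-parallel rank, unwrapped, and re-saved
# under --target_dir via convert_megatron_checkpoint.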

import argparse
import os

from modelscope.models import Model
from modelscope.utils.megatron_utils import convert_megatron_checkpoint


def unwrap_model(model):
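    """Peel off nested wrapper attributes ('model', 'module', 'dist_model')
    so the raw Megatron model is passed to the checkpoint converter."""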
    for name in ('model', 'module', 'dist_model'):
        while hasattr(model, name):
            model = getattr(model, name)
    return model


parser = argparse.ArgumentParser(
    description='Split or merge your Megatron-based checkpoint.')
parser.add_argument(
    '--model_dir', type=str, required=True, help='Checkpoint to be converted.')
parser.add_argument(
    '--target_dir', type=str, required=True, help='Target save path.')
args = parser.parse_args()
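
# RANK and WORLD_SIZE are expected to be set by the distributed launcher
# (e.g. torchrun), with one process per tensor-parallel rank; WORLD_SIZE is
# reused as the tensor parallel size when loading the model.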
model = Model.from_pretrained(
    args.model_dir,
    rank=int(os.getenv('RANK')),
    megatron_cfg={'tensor_model_parallel_size': int(os.getenv('WORLD_SIZE'))})
unwrapped_model = unwrap_model(model)
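
# Note: model.model_dir is the locally resolved checkpoint directory
# (from_pretrained may download the model first), so it can differ from the
# raw --model_dir argument.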
convert_megatron_checkpoint(unwrapped_model, model.model_dir, args.target_dir)
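
# Usage sketch (assumptions: the script is saved as convert_megatron_ckpt.py
# and launched with torchrun, which sets RANK/WORLD_SIZE for each process):
#   torchrun --nproc_per_node=<tensor_parallel_size> convert_megatron_ckpt.py \
#       --model_dir <src_checkpoint> --target_dir <dst_checkpoint>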