mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 13:15:06 +02:00
[to #50334474] llama tuned model -> pipeline
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13071602 * prepare for inference
This commit is contained in:
@@ -261,3 +261,19 @@ if __name__ == '__main__':
|
||||
trainer = build_trainer(
|
||||
name=Trainers.text_generation_trainer, default_args=kwargs)
|
||||
trainer.train()
|
||||
|
||||
# prepare for inference
|
||||
if int(os.environ.get('LOCAL_RANK', 0)) == 0:
|
||||
tokenizer.save_pretrained(os.path.join(args.work_dir, 'output'))
|
||||
os.system(f'rm {args.work_dir}/output/pytorch_model*')
|
||||
os.system(
|
||||
f'python3 {args.work_dir}/zero_to_fp32.py {args.work_dir} {args.work_dir}/output/pytorch_model.bin'
|
||||
)
|
||||
os.system(
|
||||
f'cp {model_path}/configuration.json {args.work_dir}/output/configuration.json'
|
||||
)
|
||||
with open(f'{model_path}/config.json', 'r') as f:
|
||||
config = json.load(f)
|
||||
config['vocab_size'] = len(tokenizer)
|
||||
with open(f'{args.work_dir}/output/config.json', 'w') as f:
|
||||
json.dump(config, f)
|
||||
|
||||
Reference in New Issue
Block a user