diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py index 47bd84c4..89aa39ba 100644 --- a/modelscope/trainers/hooks/checkpoint_hook.py +++ b/modelscope/trainers/hooks/checkpoint_hook.py @@ -172,12 +172,17 @@ class CheckpointHook(Hook): else: model = trainer.model + config = trainer.cfg.to_dict() + # override pipeline by tasks name after finetune done, + # avoid case like fill mask pipeline with a text cls task + config['pipeline'] = {'type': config['task']} + if hasattr(model, 'save_pretrained'): model.save_pretrained( output_dir, ModelFile.TORCH_MODEL_BIN_FILE, save_function=save_checkpoint, - config=trainer.cfg.to_dict(), + config=config, with_meta=False) def after_train_iter(self, trainer):