override pipeline by tasks name after finetune done, avoid case like fill mask pipeline with a text cls task

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10554512
This commit is contained in:
zhangzhicheng.zzc
2022-10-27 22:52:29 +08:00
committed by yingda.chen
parent 78f29cf999
commit 9df3f5c41f

View File

@@ -172,12 +172,17 @@ class CheckpointHook(Hook):
else:
model = trainer.model
config = trainer.cfg.to_dict()
# override pipeline by tasks name after finetune done,
# avoid case like fill mask pipeline with a text cls task
config['pipeline'] = {'type': config['task']}
if hasattr(model, 'save_pretrained'):
model.save_pretrained(
output_dir,
ModelFile.TORCH_MODEL_BIN_FILE,
save_function=save_checkpoint,
config=trainer.cfg.to_dict(),
config=config,
with_meta=False)
def after_train_iter(self, trainer):