diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py
index 38d1538d..dc2db59c 100644
--- a/modelscope/models/multi_modal/ofa_for_all_tasks.py
+++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py
@@ -1,8 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import math
+import os
 import string
+from functools import partial
 from os import path as osp
-from typing import Any, Dict
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import json
 import torch.cuda
@@ -295,3 +297,16 @@ class OfaForAllTasks(TorchModel):
                                       self.cfg.model.answer2label)
         with open(ans2label_file, 'r') as reader:
             self.ans2label_dict = json.load(reader)
+
+    def save_pretrained(self,
+                        target_folder: Union[str, os.PathLike],
+                        save_checkpoint_names: Union[str, List[str]] = None,
+                        save_function: Callable = None,
+                        config: Optional[dict] = None,
+                        **kwargs):
+        super(OfaForAllTasks, self). \
+            save_pretrained(target_folder=target_folder,
+                            save_checkpoint_names=save_checkpoint_names,
+                            save_function=partial(save_function, with_meta=False),
+                            config=config,
+                            **kwargs)
diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
index cdae21c6..ecd8cd1d 100644
--- a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
+++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
@@ -12,7 +12,6 @@ from torch.nn.modules.loss import _Loss
 from torch.utils.data import Dataset
 
 from modelscope.preprocessors.multi_modal import OfaPreprocessor
-from .ofa_file_dataset import OFAFileDataset
 
 
 class OFADataset(Dataset):
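
Usage sketch (not part of the patch): the new OfaForAllTasks.save_pretrained override forwards to the base TorchModel implementation, wrapping whatever save_function is passed with functools.partial so that with_meta=False is forced for OFA checkpoints. The stand-in save_function, model id, file name, and output folder below are illustrative assumptions, not values taken from the diff.

from modelscope.models import Model

import torch


def save_function(model, checkpoint_path, with_meta=True, **kwargs):
    # Stand-in save callable for illustration only; the override rebinds it
    # via partial(save_function, with_meta=False) before it is invoked, so
    # the OFA checkpoint is written without trainer meta information.
    torch.save(model.state_dict(), checkpoint_path)


# Hypothetical example model id and output locations.
model = Model.from_pretrained('damo/ofa_image-caption_coco_large_en')
model.save_pretrained(
    target_folder='./ofa_caption_export',
    save_checkpoint_names='pytorch_model.pt',
    save_function=save_function,
)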