From a799dd237d807eceef80bf3361f5bd2a0db9ce1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A1=8C=E5=97=94?= Date: Thu, 29 Sep 2022 15:26:20 +0800 Subject: [PATCH] remove ofa_file_dataset --- .../models/multi_modal/ofa_for_all_tasks.py | 17 ++++++++++++++++- .../multi_modal/ofa/ofa_trainer_utils.py | 1 - 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py index 38d1538d..dc2db59c 100644 --- a/modelscope/models/multi_modal/ofa_for_all_tasks.py +++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py @@ -1,8 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import math +import os import string +from functools import partial from os import path as osp -from typing import Any, Dict +from typing import Any, Callable, Dict, List, Optional, Union import json import torch.cuda @@ -295,3 +297,16 @@ class OfaForAllTasks(TorchModel): self.cfg.model.answer2label) with open(ans2label_file, 'r') as reader: self.ans2label_dict = json.load(reader) + + def save_pretrained(self, + target_folder: Union[str, os.PathLike], + save_checkpoint_names: Union[str, List[str]] = None, + save_function: Callable = None, + config: Optional[dict] = None, + **kwargs): + super(OfaForAllTasks, self). \ + save_pretrained(target_folder=target_folder, + save_checkpoint_names=save_checkpoint_names, + save_function=partial(save_function, with_meta=False), + config=config, + **kwargs) diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py index cdae21c6..ecd8cd1d 100644 --- a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py @@ -12,7 +12,6 @@ from torch.nn.modules.loss import _Loss from torch.utils.data import Dataset from modelscope.preprocessors.multi_modal import OfaPreprocessor -from .ofa_file_dataset import OFAFileDataset class OFADataset(Dataset):