remove ofa_file_dataset

This commit is contained in:
行嗔
2022-09-29 15:26:20 +08:00
parent 993b944b65
commit a799dd237d
2 changed files with 16 additions and 2 deletions

View File

@@ -1,8 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import string
from functools import partial
from os import path as osp
from typing import Any, Dict
from typing import Any, Callable, Dict, List, Optional, Union
import json
import torch.cuda
@@ -295,3 +297,16 @@ class OfaForAllTasks(TorchModel):
self.cfg.model.answer2label)
with open(ans2label_file, 'r') as reader:
self.ans2label_dict = json.load(reader)
def save_pretrained(self,
target_folder: Union[str, os.PathLike],
save_checkpoint_names: Union[str, List[str]] = None,
save_function: Callable = None,
config: Optional[dict] = None,
**kwargs):
super(OfaForAllTasks, self). \
save_pretrained(target_folder=target_folder,
save_checkpoint_names=save_checkpoint_names,
save_function=partial(save_function, with_meta=False),
config=config,
**kwargs)

View File

@@ -12,7 +12,6 @@ from torch.nn.modules.loss import _Loss
from torch.utils.data import Dataset
from modelscope.preprocessors.multi_modal import OfaPreprocessor
from .ofa_file_dataset import OFAFileDataset
class OFADataset(Dataset):