mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
remove ofa_file_dataset
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import math
|
||||
import os
|
||||
import string
|
||||
from functools import partial
|
||||
from os import path as osp
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import json
|
||||
import torch.cuda
|
||||
@@ -295,3 +297,16 @@ class OfaForAllTasks(TorchModel):
|
||||
self.cfg.model.answer2label)
|
||||
with open(ans2label_file, 'r') as reader:
|
||||
self.ans2label_dict = json.load(reader)
|
||||
|
||||
def save_pretrained(self,
|
||||
target_folder: Union[str, os.PathLike],
|
||||
save_checkpoint_names: Union[str, List[str]] = None,
|
||||
save_function: Callable = None,
|
||||
config: Optional[dict] = None,
|
||||
**kwargs):
|
||||
super(OfaForAllTasks, self). \
|
||||
save_pretrained(target_folder=target_folder,
|
||||
save_checkpoint_names=save_checkpoint_names,
|
||||
save_function=partial(save_function, with_meta=False),
|
||||
config=config,
|
||||
**kwargs)
|
||||
|
||||
@@ -12,7 +12,6 @@ from torch.nn.modules.loss import _Loss
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from modelscope.preprocessors.multi_modal import OfaPreprocessor
|
||||
from .ofa_file_dataset import OFAFileDataset
|
||||
|
||||
|
||||
class OFADataset(Dataset):
|
||||
|
||||
Reference in New Issue
Block a user