Files
modelscope/tests/msdatasets/test_custom_datasets_compatibility.py
xingjun.wxj 4e77f654f5 backward compatible with to_task_dataset function in adaseq repo
1. Keep backward compatibility with the to_task_dataset function for DefaultTrainer in the adaseq repo
2. Fix the registry issue for RedsImageDeblurringDataset and GoproImageDeblurringDataset
3. Add the unit test TestCustomDatasetsCompatibility
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11981956
2023-03-14 18:34:44 +08:00

77 lines
3.0 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from datasets import Dataset
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.msdatasets.dataset_cls.custom_datasets import \
TorchCustomDataset
from modelscope.preprocessors import Preprocessor
from modelscope.trainers.trainer import EpochBasedTrainer
from modelscope.utils import logger as logging
from modelscope.utils.config import Config
from modelscope.utils.constant import ModeKeys, ModelFile, Tasks
from modelscope.utils.test_utils import test_level
# Module-level logger obtained from modelscope.utils.logger.get_logger().
logger = logging.get_logger()
class TestDummyEpochBasedTrainer(EpochBasedTrainer):
    """Dummy trainer that overrides ``to_task_dataset`` the way a downstream
    trainer would.

    The constructor forwards all extra keyword arguments to
    ``EpochBasedTrainer`` and then stores the result of its own
    ``to_task_dataset`` override in ``self.train_dataset``.
    """

    def __init__(self,
                 dataset: Dataset = None,
                 mode: str = ModeKeys.TRAIN,
                 preprocessor: Preprocessor = None,
                 **kwargs):
        """Initialise the base trainer, then build ``self.train_dataset``.

        Args:
            dataset: Passed through to ``to_task_dataset`` (which does not
                use it).
            mode: Mode handed to the created dataset; defaults to
                ``ModeKeys.TRAIN``.
            preprocessor: Preprocessor handed to the created dataset.
            **kwargs: Forwarded unchanged to ``EpochBasedTrainer.__init__``.
        """
        super().__init__(**kwargs)
        self.train_dataset = self.to_task_dataset(dataset, mode, preprocessor)

    def to_task_dataset(self, dataset: Dataset, mode: str,
                        preprocessor: Preprocessor,
                        **kwargs) -> TorchCustomDataset:
        """Wrap a fixed three-sentence dataset in a ``TorchCustomDataset``.

        The incoming ``dataset`` argument is ignored; an in-memory
        ``Dataset`` with a single ``src_txt`` column is always used instead.

        Args:
            dataset: Unused.
            mode: Mode passed to ``TorchCustomDataset``.
            preprocessor: Preprocessor passed to ``TorchCustomDataset``.
            **kwargs: Accepted for signature compatibility; unused.

        Returns:
            The created ``TorchCustomDataset``, with its ``trainer``
            attribute set to this trainer instance.
        """
        sample_sentences = [
            'This is test sentence1-1',
            'This is test sentence2-1',
            'This is test sentence3-1',
        ]
        fixture = Dataset.from_dict({'src_txt': sample_sentences})

        wrapped = TorchCustomDataset(
            datasets=fixture, mode=mode, preprocessor=preprocessor)
        # Back-reference checked by the compatibility test.
        wrapped.trainer = self
        return wrapped
class TestCustomDatasetsCompatibility(unittest.TestCase):
    """Backward-compatibility checks for the legacy task-dataset API.

    Covers two things: the old ``modelscope.msdatasets.task_datasets``
    import paths still resolve, and a trainer subclass that overrides
    ``to_task_dataset`` still ends up with the dataset it built.
    """

    def setUp(self):
        """Download the model snapshot and load its configuration file."""
        self.task = Tasks.movie_scene_segmentation
        self.model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'
        cache_path = snapshot_download(self.model_id)
        self.config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
        self.cfg = Config.from_file(self.config_path)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_adaseq_import_task_datasets(self):
        """The legacy ``task_datasets`` import paths must remain importable."""
        # The imports themselves are the test: an ImportError fails it.
        # The imported names are intentionally unused.
        from modelscope.msdatasets.task_datasets.torch_base_dataset import TorchTaskDataset  # noqa: F401
        from modelscope.msdatasets.task_datasets import GoproImageDeblurringDataset  # noqa: F401
        from modelscope.msdatasets.task_datasets import RedsImageDeblurringDataset  # noqa: F401
        from modelscope.msdatasets.task_datasets import SiddImageDenoisingDataset  # noqa: F401
        from modelscope.msdatasets.task_datasets import VideoSummarizationDataset  # noqa: F401

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_adaseq_trainer_overwrite(self):
        """A trainer overriding ``to_task_dataset`` keeps the dataset it built."""
        test_trainer = TestDummyEpochBasedTrainer(cfg_file=self.config_path)
        train_dataset = test_trainer.train_dataset

        # Use TestCase assertions rather than bare ``assert``: they are not
        # stripped under ``python -O`` and report the offending values.
        self.assertIsInstance(train_dataset.trainer,
                              TestDummyEpochBasedTrainer)
        self.assertEqual(train_dataset.mode, ModeKeys.TRAIN)
        self.assertIsInstance(train_dataset._inner_dataset, Dataset)
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()