backward compatible with to_task_dataset function in adaseq repo

1. backward compatible with to_task_dataset function for DefaultTrainer in adaseq repo
2. fix registry issue for RedsImageDeblurringDataset and GoproImageDeblurringDataset
3. add ut TestCustomDatasetsCompatibility
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11981956
This commit is contained in:
xingjun.wxj
2023-03-14 18:34:44 +08:00
committed by yuze.zyz
parent 4a0cb67e66
commit 4e77f654f5
13 changed files with 153 additions and 25 deletions

View File

@@ -1153,3 +1153,6 @@ class CustomDatasets(object):
DetImagesMixDataset = 'DetImagesMixDataset'
PanopticDataset = 'PanopticDataset'
PairedDataset = 'PairedDataset'
SiddDataset = 'SiddDataset'
GoproDataset = 'GoproDataset'
RedsDataset = 'RedsDataset'

View File

@@ -18,7 +18,7 @@ def default_loader(path):
@CUSTOM_DATASETS.register_module(
Tasks.image_deblurring, module_name=CustomDatasets.PairedDataset)
Tasks.image_deblurring, module_name=CustomDatasets.GoproDataset)
class GoproImageDeblurringDataset(TorchCustomDataset):
"""Paired image dataset for image restoration.
"""

View File

@@ -18,7 +18,7 @@ def default_loader(path):
@CUSTOM_DATASETS.register_module(
Tasks.image_deblurring, module_name=CustomDatasets.PairedDataset)
Tasks.image_deblurring, module_name=CustomDatasets.RedsDataset)
class RedsImageDeblurringDataset(TorchCustomDataset):
"""Paired image dataset for image restoration.
"""

View File

@@ -3,7 +3,7 @@
import cv2
import numpy as np
from modelscope.metainfo import Models
from modelscope.metainfo import CustomDatasets
from modelscope.msdatasets.dataset_cls.custom_datasets import (
CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.utils.constant import Tasks
@@ -16,7 +16,7 @@ def default_loader(path):
@CUSTOM_DATASETS.register_module(
Tasks.image_denoising, module_name=Models.nafnet)
Tasks.image_denoising, module_name=CustomDatasets.SiddDataset)
class SiddImageDenoisingDataset(TorchCustomDataset):
"""Paired image dataset for image restoration.
"""

View File

@@ -1,6 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .gopro_image_deblurring_dataset import GoproImageDeblurringDataset
from .reds_image_deblurring_dataset import RedsImageDeblurringDataset
from .sidd_image_denoising import SiddImageDenoisingDataset
from .video_summarization_dataset import VideoSummarizationDataset
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .torch_base_dataset import TorchTaskDataset
from .gopro_image_deblurring_dataset import GoproImageDeblurringDataset
from .reds_image_deblurring_dataset import RedsImageDeblurringDataset
from .sidd_image_denoising import SiddImageDenoisingDataset
from .video_summarization_dataset import VideoSummarizationDataset
else:
_import_structure = {
'torch_base_dataset': ['TorchTaskDataset'],
'gopro_image_deblurring_dataset': ['GoproImageDeblurringDataset'],
'reds_image_deblurring_dataset': ['RedsImageDeblurringDataset'],
'sidd_image_denoising': ['SiddImageDenoisingDataset'],
'video_summarization_dataset': ['VideoSummarizationDataset'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -1,12 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.msdatasets.dataset_cls.custom_datasets.gopro_image_deblurring_dataset import \
from modelscope.msdatasets.dataset_cls.custom_datasets import \
GoproImageDeblurringDataset
from modelscope.utils.logger import get_logger
logger = get_logger()
logger.warning(
'The reference has been Deprecated, '
'please use `from modelscope.msdatasets.dataset_cls.'
'custom_datasets.gopro_image_deblurring_dataset import GoproImageDeblurringDataset`'
'The reference has been Deprecated in modelscope v1.4.0+, '
'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import GoproImageDeblurringDataset`'
)

View File

@@ -1,11 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.msdatasets.dataset_cls.custom_datasets.reds_image_deblurring_dataset import \
from modelscope.msdatasets.dataset_cls.custom_datasets import \
RedsImageDeblurringDataset
from modelscope.utils.logger import get_logger
logger = get_logger()
logger.warning(
'The reference has been Deprecated, '
'please use `modelscope.msdatasets.dataset_cls.custom_datasets.'
'reds_image_deblurring_dataset import RedsImageDeblurringDataset`')
'The reference has been Deprecated in modelscope v1.4.0+, '
'please use `modelscope.msdatasets.dataset_cls.custom_datasets import RedsImageDeblurringDataset`'
)

View File

@@ -1,11 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising import \
from modelscope.msdatasets.dataset_cls.custom_datasets import \
SiddImageDenoisingDataset
from modelscope.utils.logger import get_logger
logger = get_logger()
logger.warning(
'The reference has been Deprecated, '
'please use `from modelscope.msdatasets.dataset_cls.'
'custom_datasets.sidd_image_denoising import SiddImageDenoisingDataset`')
'The reference has been Deprecated in modelscope v1.4.0+, '
'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import SiddImageDenoisingDataset`'
)

View File

@@ -0,0 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.msdatasets.dataset_cls.custom_datasets import \
TorchCustomDataset as TorchTaskDataset
from modelscope.utils.logger import get_logger
logger = get_logger()
logger.warning(
'The reference has been Deprecated in modelscope v1.4.0+, '
'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import TorchCustomDataset`'
)

View File

@@ -1,12 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.msdatasets.dataset_cls.custom_datasets.video_summarization_dataset import \
from modelscope.msdatasets.dataset_cls.custom_datasets import \
VideoSummarizationDataset
from modelscope.utils.logger import get_logger
logger = get_logger()
logger.warning(
'The reference has been Deprecated, '
'please use `from modelscope.msdatasets.dataset_cls.'
'custom_datasets.video_summarization_dataset import VideoSummarizationDataset`'
'The reference has been Deprecated in modelscope v1.4.0+, '
'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import VideoSummarizationDataset`'
)

View File

@@ -554,6 +554,23 @@ class EpochBasedTrainer(BaseTrainer):
else:
return datasets
def to_task_dataset(self, dataset: Dataset, mode: str,
preprocessor: Preprocessor,
**kwargs) -> TorchCustomDataset:
r"""
@deprecated
This method is deprecated and may be removed in future releases, please use `build_dataset()` instead. Could be
compatible with methods that override the to_task_dataset in other classes.
"""
self.logger.warning(
'This to_task_dataset method is deprecated, please use build_dataset instead.'
)
task_dataset = TorchCustomDataset(
dataset, mode=mode, preprocessor=preprocessor, **kwargs)
task_dataset.trainer = self
return task_dataset
@staticmethod
def build_dataset_from_cfg(model_cfg: Config,
mode: str,

View File

@@ -0,0 +1,76 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from datasets import Dataset
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.msdatasets.dataset_cls.custom_datasets import \
TorchCustomDataset
from modelscope.preprocessors import Preprocessor
from modelscope.trainers.trainer import EpochBasedTrainer
from modelscope.utils import logger as logging
from modelscope.utils.config import Config
from modelscope.utils.constant import ModeKeys, ModelFile, Tasks
from modelscope.utils.test_utils import test_level
logger = logging.get_logger()
class TestDummyEpochBasedTrainer(EpochBasedTrainer):
def __init__(self,
dataset: Dataset = None,
mode: str = ModeKeys.TRAIN,
preprocessor: Preprocessor = None,
**kwargs):
super(TestDummyEpochBasedTrainer, self).__init__(**kwargs)
self.train_dataset = self.to_task_dataset(dataset, mode, preprocessor)
def to_task_dataset(self, dataset: Dataset, mode: str,
preprocessor: Preprocessor,
**kwargs) -> TorchCustomDataset:
src_dataset_dict = {
'src_txt': [
'This is test sentence1-1', 'This is test sentence2-1',
'This is test sentence3-1'
]
}
dataset = Dataset.from_dict(src_dataset_dict)
dataset_res = TorchCustomDataset(
datasets=dataset, mode=mode, preprocessor=preprocessor)
dataset_res.trainer = self
return dataset_res
class TestCustomDatasetsCompatibility(unittest.TestCase):
def setUp(self):
self.task = Tasks.movie_scene_segmentation
self.model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'
cache_path = snapshot_download(self.model_id)
self.config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
self.cfg = Config.from_file(self.config_path)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_adaseq_import_task_datasets(self):
from modelscope.msdatasets.task_datasets.torch_base_dataset import TorchTaskDataset
from modelscope.msdatasets.task_datasets import GoproImageDeblurringDataset
from modelscope.msdatasets.task_datasets import RedsImageDeblurringDataset
from modelscope.msdatasets.task_datasets import SiddImageDenoisingDataset
from modelscope.msdatasets.task_datasets import VideoSummarizationDataset
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_adaseq_trainer_overwrite(self):
test_trainer = TestDummyEpochBasedTrainer(cfg_file=self.config_path)
assert isinstance(test_trainer.train_dataset.trainer,
TestDummyEpochBasedTrainer)
assert test_trainer.train_dataset.mode == ModeKeys.TRAIN
assert isinstance(test_trainer.train_dataset._inner_dataset, Dataset)
if __name__ == '__main__':
unittest.main()

View File

@@ -52,7 +52,7 @@ class ImageColorizationTrainerTest(unittest.TestCase):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
super().tearDown()
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_trainer(self):
kwargs = dict(
model=self.model_id,