mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
backward compatible with to_task_dataset function in adaseq repo
1. backward compatible with to_task_dataset function for DefaultTrainer in adaseq repo
2. fix registry issue for RedsImageDeblurringDataset and GoproImageDeblurringDataset
3. add ut TestCustomDatasetsCompatibility
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11981956
This commit is contained in:
@@ -1153,3 +1153,6 @@ class CustomDatasets(object):
|
||||
DetImagesMixDataset = 'DetImagesMixDataset'
|
||||
PanopticDataset = 'PanopticDataset'
|
||||
PairedDataset = 'PairedDataset'
|
||||
SiddDataset = 'SiddDataset'
|
||||
GoproDataset = 'GoproDataset'
|
||||
RedsDataset = 'RedsDataset'
|
||||
|
||||
@@ -18,7 +18,7 @@ def default_loader(path):
|
||||
|
||||
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.image_deblurring, module_name=CustomDatasets.PairedDataset)
|
||||
Tasks.image_deblurring, module_name=CustomDatasets.GoproDataset)
|
||||
class GoproImageDeblurringDataset(TorchCustomDataset):
|
||||
"""Paired image dataset for image restoration.
|
||||
"""
|
||||
|
||||
@@ -18,7 +18,7 @@ def default_loader(path):
|
||||
|
||||
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.image_deblurring, module_name=CustomDatasets.PairedDataset)
|
||||
Tasks.image_deblurring, module_name=CustomDatasets.RedsDataset)
|
||||
class RedsImageDeblurringDataset(TorchCustomDataset):
|
||||
"""Paired image dataset for image restoration.
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.metainfo import CustomDatasets
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.utils.constant import Tasks
|
||||
@@ -16,7 +16,7 @@ def default_loader(path):
|
||||
|
||||
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.image_denoising, module_name=Models.nafnet)
|
||||
Tasks.image_denoising, module_name=CustomDatasets.SiddDataset)
|
||||
class SiddImageDenoisingDataset(TorchCustomDataset):
|
||||
"""Paired image dataset for image restoration.
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,29 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .gopro_image_deblurring_dataset import GoproImageDeblurringDataset
|
||||
from .reds_image_deblurring_dataset import RedsImageDeblurringDataset
|
||||
from .sidd_image_denoising import SiddImageDenoisingDataset
|
||||
from .video_summarization_dataset import VideoSummarizationDataset
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .torch_base_dataset import TorchTaskDataset
|
||||
from .gopro_image_deblurring_dataset import GoproImageDeblurringDataset
|
||||
from .reds_image_deblurring_dataset import RedsImageDeblurringDataset
|
||||
from .sidd_image_denoising import SiddImageDenoisingDataset
|
||||
from .video_summarization_dataset import VideoSummarizationDataset
|
||||
else:
|
||||
_import_structure = {
|
||||
'torch_base_dataset': ['TorchTaskDataset'],
|
||||
'gopro_image_deblurring_dataset': ['GoproImageDeblurringDataset'],
|
||||
'reds_image_deblurring_dataset': ['RedsImageDeblurringDataset'],
|
||||
'sidd_image_denoising': ['SiddImageDenoisingDataset'],
|
||||
'video_summarization_dataset': ['VideoSummarizationDataset'],
|
||||
}
|
||||
|
||||
import sys
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.gopro_image_deblurring_dataset import \
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import \
|
||||
GoproImageDeblurringDataset
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger.warning(
|
||||
'The reference has been Deprecated, '
|
||||
'please use `from modelscope.msdatasets.dataset_cls.'
|
||||
'custom_datasets.gopro_image_deblurring_dataset import GoproImageDeblurringDataset`'
|
||||
'The reference has been Deprecated in modelscope v1.4.0+, '
|
||||
'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import GoproImageDeblurringDataset`'
|
||||
)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.reds_image_deblurring_dataset import \
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import \
|
||||
RedsImageDeblurringDataset
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger.warning(
|
||||
'The reference has been Deprecated, '
|
||||
'please use `modelscope.msdatasets.dataset_cls.custom_datasets.'
|
||||
'reds_image_deblurring_dataset import RedsImageDeblurringDataset`')
|
||||
'The reference has been Deprecated in modelscope v1.4.0+, '
|
||||
'please use `modelscope.msdatasets.dataset_cls.custom_datasets import RedsImageDeblurringDataset`'
|
||||
)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising import \
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import \
|
||||
SiddImageDenoisingDataset
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger.warning(
|
||||
'The reference has been Deprecated, '
|
||||
'please use `from modelscope.msdatasets.dataset_cls.'
|
||||
'custom_datasets.sidd_image_denoising import SiddImageDenoisingDataset`')
|
||||
'The reference has been Deprecated in modelscope v1.4.0+, '
|
||||
'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import SiddImageDenoisingDataset`'
|
||||
)
|
||||
|
||||
11
modelscope/msdatasets/task_datasets/torch_base_dataset.py
Normal file
11
modelscope/msdatasets/task_datasets/torch_base_dataset.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import \
|
||||
TorchCustomDataset as TorchTaskDataset
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger.warning(
|
||||
'The reference has been Deprecated in modelscope v1.4.0+, '
|
||||
'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import TorchCustomDataset`'
|
||||
)
|
||||
@@ -1,12 +1,11 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.video_summarization_dataset import \
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import \
|
||||
VideoSummarizationDataset
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger.warning(
|
||||
'The reference has been Deprecated, '
|
||||
'please use `from modelscope.msdatasets.dataset_cls.'
|
||||
'custom_datasets.video_summarization_dataset import VideoSummarizationDataset`'
|
||||
'The reference has been Deprecated in modelscope v1.4.0+, '
|
||||
'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import VideoSummarizationDataset`'
|
||||
)
|
||||
|
||||
@@ -554,6 +554,23 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
else:
|
||||
return datasets
|
||||
|
||||
def to_task_dataset(self, dataset: Dataset, mode: str,
|
||||
preprocessor: Preprocessor,
|
||||
**kwargs) -> TorchCustomDataset:
|
||||
r"""
|
||||
@deprecated
|
||||
This method is deprecated and may be removed in future releases, please use `build_dataset()` instead. Could be
|
||||
compatible with methods that override the to_task_dataset in other classes.
|
||||
"""
|
||||
self.logger.warning(
|
||||
'This to_task_dataset method is deprecated, please use build_dataset instead.'
|
||||
)
|
||||
|
||||
task_dataset = TorchCustomDataset(
|
||||
dataset, mode=mode, preprocessor=preprocessor, **kwargs)
|
||||
task_dataset.trainer = self
|
||||
return task_dataset
|
||||
|
||||
@staticmethod
|
||||
def build_dataset_from_cfg(model_cfg: Config,
|
||||
mode: str,
|
||||
|
||||
76
tests/msdatasets/test_custom_datasets_compatibility.py
Normal file
76
tests/msdatasets/test_custom_datasets_compatibility.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import \
|
||||
TorchCustomDataset
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.trainers.trainer import EpochBasedTrainer
|
||||
from modelscope.utils import logger as logging
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModeKeys, ModelFile, Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
logger = logging.get_logger()
|
||||
|
||||
|
||||
class TestDummyEpochBasedTrainer(EpochBasedTrainer):
|
||||
|
||||
def __init__(self,
|
||||
dataset: Dataset = None,
|
||||
mode: str = ModeKeys.TRAIN,
|
||||
preprocessor: Preprocessor = None,
|
||||
**kwargs):
|
||||
super(TestDummyEpochBasedTrainer, self).__init__(**kwargs)
|
||||
self.train_dataset = self.to_task_dataset(dataset, mode, preprocessor)
|
||||
|
||||
def to_task_dataset(self, dataset: Dataset, mode: str,
|
||||
preprocessor: Preprocessor,
|
||||
**kwargs) -> TorchCustomDataset:
|
||||
src_dataset_dict = {
|
||||
'src_txt': [
|
||||
'This is test sentence1-1', 'This is test sentence2-1',
|
||||
'This is test sentence3-1'
|
||||
]
|
||||
}
|
||||
dataset = Dataset.from_dict(src_dataset_dict)
|
||||
dataset_res = TorchCustomDataset(
|
||||
datasets=dataset, mode=mode, preprocessor=preprocessor)
|
||||
dataset_res.trainer = self
|
||||
return dataset_res
|
||||
|
||||
|
||||
class TestCustomDatasetsCompatibility(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.task = Tasks.movie_scene_segmentation
|
||||
self.model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'
|
||||
|
||||
cache_path = snapshot_download(self.model_id)
|
||||
self.config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
|
||||
self.cfg = Config.from_file(self.config_path)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_adaseq_import_task_datasets(self):
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import TorchTaskDataset
|
||||
from modelscope.msdatasets.task_datasets import GoproImageDeblurringDataset
|
||||
from modelscope.msdatasets.task_datasets import RedsImageDeblurringDataset
|
||||
from modelscope.msdatasets.task_datasets import SiddImageDenoisingDataset
|
||||
from modelscope.msdatasets.task_datasets import VideoSummarizationDataset
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_adaseq_trainer_overwrite(self):
|
||||
test_trainer = TestDummyEpochBasedTrainer(cfg_file=self.config_path)
|
||||
|
||||
assert isinstance(test_trainer.train_dataset.trainer,
|
||||
TestDummyEpochBasedTrainer)
|
||||
assert test_trainer.train_dataset.mode == ModeKeys.TRAIN
|
||||
assert isinstance(test_trainer.train_dataset._inner_dataset, Dataset)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -52,7 +52,7 @@ class ImageColorizationTrainerTest(unittest.TestCase):
|
||||
shutil.rmtree(self.tmp_dir, ignore_errors=True)
|
||||
super().tearDown()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_trainer(self):
|
||||
kwargs = dict(
|
||||
model=self.model_id,
|
||||
|
||||
Reference in New Issue
Block a user