Refactor the task_datasets module

Refactor the task_datasets module:

1. Add new module modelscope.msdatasets.dataset_cls.custom_datasets.
2. Add new function: modelscope.msdatasets.ms_dataset.MsDataset.to_custom_dataset().
3. Call to_custom_dataset() in MsDataset.load() to adapt to the new custom_datasets module.
4. Refactor the pipeline for loading custom datasets:
	1) Use only the MsDataset.load() function to load custom datasets.
	2) Combine MsDataset.load() with the class EpochBasedTrainer.
5. Add a new entry function for building datasets in EpochBasedTrainer: see modelscope.trainers.trainer.EpochBasedTrainer.build_dataset().
6. Add a new function to build a custom dataset from the model configuration: see modelscope.trainers.trainer.EpochBasedTrainer.build_dataset_from_cfg().
7. Add a new registry function for building custom datasets: see modelscope.msdatasets.dataset_cls.custom_datasets.builder.build_custom_dataset().
8. Refine the class SiameseUIETrainer to adapt to the new custom_datasets module.
9. Add class TorchCustomDataset as a superclass for custom dataset classes.
10. Move modules/classes/functions:
	1) Move module msdatasets.audio to custom_datasets
	2) Move module msdatasets.cv to custom_datasets
	3) Move module bad_image_detecting to custom_datasets
	4) Move module damoyolo to custom_datasets
	5) Move module face_2d_keypoints to custom_datasets
	6) Move module hand_2d_keypoints to custom_datasets
	7) Move module human_wholebody_keypoint to custom_datasets
	8) Move module image_classification to custom_datasets
	9) Move module image_inpainting to custom_datasets
	10) Move module image_portrait_enhancement to custom_datasets
	11) Move module image_quality_assessment_degradation to custom_datasets
	12) Move module image_quality_assmessment_mos to custom_datasets
	13) Move class LanguageGuidedVideoSummarizationDataset to custom_datasets
	14) Move class MGeoRankingDataset to custom_datasets
	15) Move module movie_scene_segmentation to custom_datasets
	16) Move module object_detection to custom_datasets
	17) Move module referring_video_object_segmentation to custom_datasets
	18) Move module sidd_image_denoising to custom_datasets
	19) Move module video_frame_interpolation to custom_datasets
	20) Move module video_stabilization to custom_datasets
	21) Move module video_super_resolution to custom_datasets
	22) Move class GoproImageDeblurringDataset to custom_datasets
	23) Move class EasyCVBaseDataset to custom_datasets
	24) Move class ImageInstanceSegmentationCocoDataset to custom_datasets
	25) Move class RedsImageDeblurringDataset to custom_datasets
	26) Move class TextRankingDataset to custom_datasets
	27) Move class VecoDataset to custom_datasets
	28) Move class VideoSummarizationDataset to custom_datasets
11. Delete modules/functions/classes:
	1) Delete module task_datasets
	2) Delete to_task_dataset() in EpochBasedTrainer
	3) Delete build_dataset() in EpochBasedTrainer and re-implement a function with the same name.
12. Rename class Datasets to CustomDatasets in metainfo.py.
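The load-and-adapt flow described above can be sketched with stand-ins. All names below mirror this commit's API (MsDataset.load, to_custom_dataset, TorchCustomDataset) but are toy re-implementations for illustration, not the actual modelscope code:

```python
class TorchCustomDataset:
    """Stand-in for the new superclass of all custom dataset classes."""

    def __init__(self, samples):
        self.samples = list(samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


class MsDataset:

    def __init__(self, samples):
        self._samples = samples

    @classmethod
    def load(cls, samples, custom_cfg=None):
        ds = cls(samples)
        # New in this refactor: load() hands off to to_custom_dataset()
        # whenever a custom-dataset config is present, so MsDataset.load()
        # is the single entry point for loading custom datasets.
        return ds.to_custom_dataset(custom_cfg) if custom_cfg else ds

    def to_custom_dataset(self, custom_cfg):
        # In the real code the target class comes from the model config;
        # here it is passed directly for brevity.
        dataset_cls = custom_cfg['type']
        return dataset_cls(self._samples)


raw = [{'image': 'a.jpg'}, {'image': 'b.jpg'}]
plain = MsDataset.load(raw)
custom = MsDataset.load(raw, custom_cfg={'type': TorchCustomDataset})
```

A trainer would then receive `custom` directly, which is the "combine with EpochBasedTrainer" step.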

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11872747
Authored by xingjun.wxj on 2023-03-10 09:03:32 +08:00
Committed by wenmeng.zwm
Parent fc7daea9c2, commit e02a260c93
135 changed files with 1158 additions and 867 deletions


@@ -1,14 +0,0 @@
modelscope.msdatasets.cv
================================
.. automodule:: modelscope.msdatasets.cv
.. currentmodule:: modelscope.msdatasets.cv
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
easycv_base.EasyCVBaseDataset
image_classification.ClsDataset


@@ -0,0 +1,41 @@
modelscope.msdatasets.dataset_cls.custom_datasets
====================
.. automodule:: modelscope.msdatasets.dataset_cls.custom_datasets
.. currentmodule:: modelscope.msdatasets.dataset_cls.custom_datasets
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
EasyCVBaseDataset
TorchCustomDataset
MovieSceneSegmentationDataset
ImageInstanceSegmentationCocoDataset
GoproImageDeblurringDataset
LanguageGuidedVideoSummarizationDataset
MGeoRankingDataset
RedsImageDeblurringDataset
TextRankingDataset
VecoDataset
VideoSummarizationDataset
BadImageDetectingDataset
ImageInpaintingDataset
ImagePortraitEnhancementDataset
ImageQualityAssessmentDegradationDataset
ImageQualityAssessmentMosDataset
ReferringVideoObjectSegmentationDataset
SiddImageDenoisingDataset
VideoFrameInterpolationDataset
VideoStabilizationDataset
VideoSuperResolutionDataset
SegDataset
FaceKeypointDataset
HandCocoWholeBodyDataset
WholeBodyCocoTopDownDataset
ClsDataset
DetImagesMixDataset
DetDataset


@@ -0,0 +1,15 @@
modelscope.msdatasets.dataset_cls
====================
.. automodule:: modelscope.msdatasets.dataset_cls
.. currentmodule:: modelscope.msdatasets.dataset_cls
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
ExternalDataset
NativeIterableDataset


@@ -10,5 +10,4 @@ modelscope.msdatasets.ms_dataset
     :nosignatures:
     :template: classtemplate.rst
-    MsMapDataset
     MsDataset


@@ -1137,7 +1137,7 @@ class LR_Schedulers(object):
     ExponentialWarmup = 'ExponentialWarmup'


-class Datasets(object):
+class CustomDatasets(object):
     """ Names for different datasets.
     """
     ClsDataset = 'ClsDataset'


@@ -8,8 +8,8 @@ from modelscope.models.cv.tinynas_detection.damo.apis.detector_inference import
     inference
 from modelscope.models.cv.tinynas_detection.damo.detectors.detector import \
     build_local_model
-from modelscope.msdatasets.task_datasets.damoyolo import (build_dataloader,
-                                                          build_dataset)
+from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import (
+    build_dataloader, build_dataset)


 def mkdir(path):


@@ -5,7 +5,7 @@ import os
 import torch
 from tqdm import tqdm

-from modelscope.msdatasets.task_datasets.damoyolo.evaluation import evaluate
+from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import evaluate
 from modelscope.utils.logger import get_logger
 from modelscope.utils.timer import Timer, get_time_str
 from modelscope.utils.torch_utils import (all_gather, get_world_size,


@@ -1,3 +1,2 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from . import cv
 from .ms_dataset import MsDataset


@@ -1,3 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import (image_classification, image_semantic_segmentation,
object_detection)


@@ -13,6 +13,7 @@ from modelscope.msdatasets.context.dataset_context_config import \
     DatasetContextConfig
 from modelscope.msdatasets.data_files.data_files_manager import \
     DataFilesManager
+from modelscope.msdatasets.dataset_cls.dataset import ExternalDataset
 from modelscope.msdatasets.meta.data_meta_manager import DataMetaManager
 from modelscope.utils.constant import DatasetFormations
@@ -62,7 +63,8 @@ class OssDataLoader(BaseDataLoader):
         self.data_files_builder: Optional[DataFilesManager] = None
         self.dataset: Optional[Union[Dataset, IterableDataset, DatasetDict,
-                                     IterableDatasetDict]] = None
+                                     IterableDatasetDict,
+                                     ExternalDataset]] = None
         self.builder: Optional[DatasetBuilder] = None
         self.data_files_manager: Optional[DataFilesManager] = None
@@ -141,7 +143,8 @@ class OssDataLoader(BaseDataLoader):
             self.builder)

     def _post_process(self) -> None:
-        ...
+        if isinstance(self.dataset, ExternalDataset):
+            self.dataset.custom_map = self.dataset_context_config.data_meta_config.meta_type_map


 class MaxComputeDataLoader(BaseDataLoader):


@@ -1 +1,3 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+
+from .dataset import ExternalDataset, NativeIterableDataset


@@ -0,0 +1,84 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .easycv_base import EasyCVBaseDataset
    from .builder import CUSTOM_DATASETS, build_custom_dataset
    from .torch_custom_dataset import TorchCustomDataset
    from .movie_scene_segmentation.movie_scene_segmentation_dataset import MovieSceneSegmentationDataset
    from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset
    from .gopro_image_deblurring_dataset import GoproImageDeblurringDataset
    from .language_guided_video_summarization_dataset import LanguageGuidedVideoSummarizationDataset
    from .mgeo_ranking_dataset import MGeoRankingDataset
    from .reds_image_deblurring_dataset import RedsImageDeblurringDataset
    from .text_ranking_dataset import TextRankingDataset
    from .veco_dataset import VecoDataset
    from .video_summarization_dataset import VideoSummarizationDataset
    from .audio import KWSDataset, KWSDataLoader, kws_nearfield_dataset
    from .bad_image_detecting import BadImageDetectingDataset
    from .image_inpainting import ImageInpaintingDataset
    from .image_portrait_enhancement import ImagePortraitEnhancementDataset
    from .image_quality_assessment_degradation import ImageQualityAssessmentDegradationDataset
    from .image_quality_assmessment_mos import ImageQualityAssessmentMosDataset
    from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset
    from .sidd_image_denoising import SiddImageDenoisingDataset
    from .video_frame_interpolation import VideoFrameInterpolationDataset
    from .video_stabilization import VideoStabilizationDataset
    from .video_super_resolution import VideoSuperResolutionDataset
    from .image_semantic_segmentation import SegDataset
    from .face_2d_keypoins import FaceKeypointDataset
    from .hand_2d_keypoints import HandCocoWholeBodyDataset
    from .human_wholebody_keypoint import WholeBodyCocoTopDownDataset
    from .image_classification import ClsDataset
    from .object_detection import DetDataset, DetImagesMixDataset
    from .ocr_detection import DataLoader, ImageDataset, QuadMeasurer
    from .ocr_recognition_dataset import OCRRecognitionDataset

else:
    _import_structure = {
        'easycv_base': ['EasyCVBaseDataset'],
        'builder': ['CUSTOM_DATASETS', 'build_custom_dataset'],
        'torch_custom_dataset': ['TorchCustomDataset'],
        'movie_scene_segmentation_dataset': ['MovieSceneSegmentationDataset'],
        'image_instance_segmentation_coco_dataset':
        ['ImageInstanceSegmentationCocoDataset'],
        'gopro_image_deblurring_dataset': ['GoproImageDeblurringDataset'],
        'language_guided_video_summarization_dataset':
        ['LanguageGuidedVideoSummarizationDataset'],
        'mgeo_ranking_dataset': ['MGeoRankingDataset'],
        'reds_image_deblurring_dataset': ['RedsImageDeblurringDataset'],
        'text_ranking_dataset': ['TextRankingDataset'],
        'veco_dataset': ['VecoDataset'],
        'video_summarization_dataset': ['VideoSummarizationDataset'],
        'audio': ['KWSDataset', 'KWSDataLoader', 'kws_nearfield_dataset'],
        'bad_image_detecting': ['BadImageDetectingDataset'],
        'image_inpainting': ['ImageInpaintingDataset'],
        'image_portrait_enhancement': ['ImagePortraitEnhancementDataset'],
        'image_quality_assessment_degradation':
        ['ImageQualityAssessmentDegradationDataset'],
        'image_quality_assmessment_mos': ['ImageQualityAssessmentMosDataset'],
        'referring_video_object_segmentation':
        ['ReferringVideoObjectSegmentationDataset'],
        'sidd_image_denoising': ['SiddImageDenoisingDataset'],
        'video_frame_interpolation': ['VideoFrameInterpolationDataset'],
        'video_stabilization': ['VideoStabilizationDataset'],
        'video_super_resolution': ['VideoSuperResolutionDataset'],
        'image_semantic_segmentation': ['SegDataset'],
        'face_2d_keypoins': ['FaceKeypointDataset'],
        'hand_2d_keypoints': ['HandCocoWholeBodyDataset'],
        'human_wholebody_keypoint': ['WholeBodyCocoTopDownDataset'],
        'image_classification': ['ClsDataset'],
        'object_detection': ['DetDataset', 'DetImagesMixDataset'],
        'ocr_detection': ['DataLoader', 'ImageDataset', 'QuadMeasurer'],
        'ocr_recognition_dataset': ['OCRRecognitionDataset'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
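The `LazyImportModule` pattern used in this `__init__` can be approximated with a small, self-contained sketch. This is a toy re-implementation (PEP 562-style deferred imports), not the modelscope class, and stdlib modules stand in for the dataset submodules:

```python
import importlib
import sys
import types


class LazyModule(types.ModuleType):
    """Toy version of LazyImportModule: defer submodule imports until an
    exported attribute is first accessed. Names here are illustrative."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        self._import_structure = import_structure
        # Reverse map: exported attribute name -> module that defines it.
        self._name_to_module = {
            attr: mod
            for mod, attrs in import_structure.items()
            for attr in attrs
        }

    def __getattr__(self, attr):
        # Only called for attributes not yet set on the module object.
        mod_name = self._name_to_module.get(attr)
        if mod_name is None:
            raise AttributeError(attr)
        module = importlib.import_module(mod_name)
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so __getattr__ is skipped next time
        return value


# stdlib modules stand in for the custom_datasets submodules:
lazy = LazyModule('demo_custom_datasets', {'json': ['dumps'], 'math': ['sqrt']})
```

Accessing `lazy.sqrt` triggers the import of `math` on first use; nothing listed in `_import_structure` is imported at module-creation time, which is the point of the pattern.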


@@ -5,7 +5,6 @@ import math
 import os.path
 import queue
 import threading
-import time

 import numpy as np
 import torch


@@ -18,7 +18,7 @@ import torch
 import torch.distributed as dist
 from torch.utils.data import IterableDataset

-import modelscope.msdatasets.task_datasets.audio.kws_nearfield_processor as processor
+import modelscope.msdatasets.dataset_cls.custom_datasets.audio.kws_nearfield_processor as processor
 from modelscope.trainers.audio.kws_utils.file_utils import (make_pair,
                                                             read_lists)
 from modelscope.utils.logger import get_logger


@@ -1,12 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import cv2
-import numpy as np
-
 from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.outputs import OutputKeys
 from modelscope.preprocessors import LoadImage
 from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \
@@ -14,9 +10,9 @@ from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \
 from modelscope.utils.constant import Tasks


-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     Tasks.bad_image_detecting, module_name=Models.bad_image_detecting)
-class BadImageDetectingDataset(TorchTaskDataset):
+class BadImageDetectingDataset(TorchCustomDataset):
     """Paired image dataset for bad image detecting.
     """


@@ -3,13 +3,13 @@
 from modelscope.utils.config import ConfigDict
 from modelscope.utils.registry import Registry, build_from_cfg

-TASK_DATASETS = Registry('task_datasets')
+CUSTOM_DATASETS = Registry('custom_datasets')


-def build_task_dataset(cfg: ConfigDict,
-                       task_name: str = None,
-                       default_args: dict = None):
-    """ Build task specific dataset processor given model config dict and the task name.
+def build_custom_dataset(cfg: ConfigDict,
+                         task_name: str,
+                         default_args: dict = None):
+    """ Build custom dataset for user-define dataset given model config and task name.

     Args:
         cfg (:obj:`ConfigDict`): config dict for model object.
@@ -18,4 +18,4 @@ def build_custom_dataset(cfg: ConfigDict,
         default_args (dict, optional): Default initialization arguments.
     """
     return build_from_cfg(
-        cfg, TASK_DATASETS, group_key=task_name, default_args=default_args)
+        cfg, CUSTOM_DATASETS, group_key=task_name, default_args=default_args)
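The registry mechanics behind `CUSTOM_DATASETS` and `build_custom_dataset()` can be sketched in a self-contained way. The real `Registry` and `build_from_cfg` live in `modelscope.utils.registry` and are more general; the task name, module name, and `PairedDataset` class below are made up for illustration:

```python
class Registry:
    """Minimal two-level registry: group_key (task) -> module_name -> class."""

    def __init__(self, name):
        self.name = name
        self._groups = {}

    def register_module(self, group_key, module_name):
        def decorator(cls):
            self._groups.setdefault(group_key, {})[module_name] = cls
            return cls
        return decorator


CUSTOM_DATASETS = Registry('custom_datasets')


def build_custom_dataset(cfg, task_name, default_args=None):
    """Look up the dataset class by (task_name, cfg['type']) and build it."""
    cfg = dict(cfg)  # do not mutate the caller's config
    dataset_cls = CUSTOM_DATASETS._groups[task_name][cfg.pop('type')]
    return dataset_cls(**cfg, **(default_args or {}))


@CUSTOM_DATASETS.register_module('image-denoising', module_name='PairedDataset')
class PairedDataset:

    def __init__(self, data_root):
        self.data_root = data_root


ds = build_custom_dataset({'type': 'PairedDataset', 'data_root': '/tmp/x'},
                          task_name='image-denoising')
```

This mirrors how the `@CUSTOM_DATASETS.register_module(...)` decorators in the diffs below make each dataset class discoverable by task name and module name at build time.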


@@ -1,2 +1,3 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .build import build_dataloader, build_dataset
+from .evaluation import evaluate


@@ -1,6 +1,6 @@
 # Copyright © Alibaba, Inc. and its affiliates.
-from modelscope.msdatasets.task_datasets.damoyolo import datasets
+from .. import datasets

 from .coco import coco_evaluation


@@ -1,15 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from easycv.datasets.face import FaceKeypointDataset as _FaceKeypointDataset

-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
+    EasyCVBaseDataset
 from modelscope.utils.constant import Tasks


-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     group_key=Tasks.face_2d_keypoints,
-    module_name=Datasets.Face2dKeypointsDataset)
+    module_name=CustomDatasets.Face2dKeypointsDataset)
 class FaceKeypointDataset(EasyCVBaseDataset, _FaceKeypointDataset):
     """EasyCV dataset for face 2d keypoints.


@@ -3,14 +3,13 @@
 import cv2
 import numpy as np

-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.sidd_image_denoising.data_utils import (
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
+from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.data_utils import (
     img2tensor, padding)
-from modelscope.msdatasets.task_datasets.sidd_image_denoising.transforms import (
+from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.transforms import (
     augment, paired_random_crop)
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
 from modelscope.utils.constant import Tasks
@@ -18,9 +17,9 @@ def default_loader(path):
     return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0


-@TASK_DATASETS.register_module(
-    Tasks.image_deblurring, module_name=Datasets.PairedDataset)
-class GoproImageDeblurringDataset(TorchTaskDataset):
+@CUSTOM_DATASETS.register_module(
+    Tasks.image_deblurring, module_name=CustomDatasets.PairedDataset)
+class GoproImageDeblurringDataset(TorchCustomDataset):
     """Paired image dataset for image restoration.
     """


@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .hand_2d_keypoints_dataset import HandCocoWholeBodyDataset

else:
    _import_structure = {
        'hand_2d_keypoints_dataset': ['HandCocoWholeBodyDataset']
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )


@@ -2,15 +2,16 @@
 from easycv.datasets.pose import \
     HandCocoWholeBodyDataset as _HandCocoWholeBodyDataset

-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
+    EasyCVBaseDataset
 from modelscope.utils.constant import Tasks


-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     group_key=Tasks.hand_2d_keypoints,
-    module_name=Datasets.HandCocoWholeBodyDataset)
+    module_name=CustomDatasets.HandCocoWholeBodyDataset)
 class HandCocoWholeBodyDataset(EasyCVBaseDataset, _HandCocoWholeBodyDataset):
     """EasyCV dataset for human hand 2d keypoints.


@@ -2,15 +2,16 @@
 from easycv.datasets.pose import \
     WholeBodyCocoTopDownDataset as _WholeBodyCocoTopDownDataset

-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
+    EasyCVBaseDataset
 from modelscope.utils.constant import Tasks


-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     group_key=Tasks.human_wholebody_keypoint,
-    module_name=Datasets.HumanWholeBodyKeypointDataset)
+    module_name=CustomDatasets.HumanWholeBodyKeypointDataset)
 class WholeBodyCocoTopDownDataset(EasyCVBaseDataset,
                                   _WholeBodyCocoTopDownDataset):
     """EasyCV dataset for human whole body 2d keypoints.


@@ -1,14 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from easycv.datasets.classification import ClsDataset as _ClsDataset

-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
+    EasyCVBaseDataset
 from modelscope.utils.constant import Tasks


-@TASK_DATASETS.register_module(
-    group_key=Tasks.image_classification, module_name=Datasets.ClsDataset)
+@CUSTOM_DATASETS.register_module(
+    group_key=Tasks.image_classification,
+    module_name=CustomDatasets.ClsDataset)
 class ClsDataset(_ClsDataset):
     """EasyCV dataset for classification.


@@ -4,13 +4,11 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
-    from .hand_2d_keypoints_dataset import Hand2DKeypointDataset
+    from .image_inpainting_dataset import ImageInpaintingDataset

 else:
     _import_structure = {
-        'hand_2d_keypoints_dataset': ['Hand2DKeypointDataset']
+        'image_inpainting_dataset': ['ImageInpaintingDataset'],
     }

     import sys

     sys.modules[__name__] = LazyImportModule(


@@ -3,20 +3,16 @@ Part of the implementation is borrowed and modified from LaMa,
 publicly available at https://github.com/saic-mdal/lama
 """
 import glob
-import os
 import os.path as osp
 from enum import Enum

 import albumentations as A
 import cv2
-import json
 import numpy as np
-import torch

 from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
 from .aug import IAAAffine2, IAAPerspective2
@@ -296,9 +292,9 @@ def get_transforms(test_mode, out_size):
     return transform


-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     Tasks.image_inpainting, module_name=Models.image_inpainting)
-class ImageInpaintingDataset(TorchTaskDataset):
+class ImageInpaintingDataset(TorchCustomDataset):

     def __init__(self, **kwargs):
         split_config = kwargs['split_config']


@@ -6,9 +6,9 @@ import numpy as np
 from pycocotools.coco import COCO

 from modelscope.metainfo import Models
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.utils.constant import Tasks
-from .builder import TASK_DATASETS
-from .torch_base_dataset import TorchTaskDataset

 DATASET_STRUCTURE = {
     'train': {
@@ -22,10 +22,10 @@ DATASET_STRUCTURE = {
 }


-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     module_name=Models.cascade_mask_rcnn_swin,
     group_key=Tasks.image_segmentation)
-class ImageInstanceSegmentationCocoDataset(TorchTaskDataset):
+class ImageInstanceSegmentationCocoDataset(TorchCustomDataset):
     """Coco-style dataset for image instance segmentation.

     Args:


@@ -3,10 +3,9 @@
 import cv2
 import numpy as np

-from modelscope.metainfo import Datasets, Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.utils.constant import Tasks
 from .data_utils import img2tensor
@@ -15,9 +14,9 @@ def default_loader(path):
     return cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0


-@TASK_DATASETS.register_module(
-    Tasks.image_portrait_enhancement, module_name=Datasets.PairedDataset)
-class ImagePortraitEnhancementDataset(TorchTaskDataset):
+@CUSTOM_DATASETS.register_module(
+    Tasks.image_portrait_enhancement, module_name=CustomDatasets.PairedDataset)
+class ImagePortraitEnhancementDataset(TorchCustomDataset):
     """Paired image dataset for image portrait enhancement.
     """


@@ -1,21 +1,18 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import cv2
-import numpy as np
 from torchvision import transforms

 from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import Tasks


-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     Tasks.image_quality_assessment_degradation,
     module_name=Models.image_quality_assessment_degradation)
-class ImageQualityAssessmentDegradationDataset(TorchTaskDataset):
+class ImageQualityAssessmentDegradationDataset(TorchCustomDataset):
     """Paired image dataset for image quality assessment degradation.
     """


@@ -1,20 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import cv2
-import numpy as np

 from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.preprocessors.cv import ImageQualityAssessmentMosPreprocessor
 from modelscope.utils.constant import Tasks


-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     Tasks.image_quality_assessment_mos,
     module_name=Models.image_quality_assessment_mos)
-class ImageQualityAssessmentMosDataset(TorchTaskDataset):
+class ImageQualityAssessmentMosDataset(TorchCustomDataset):
     """Paired image dataset for image quality assessment mos.
     """


@@ -1,14 +1,15 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from easycv.datasets.segmentation import SegDataset as _SegDataset

-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
+    EasyCVBaseDataset
 from modelscope.utils.constant import Tasks


-@TASK_DATASETS.register_module(
-    group_key=Tasks.image_segmentation, module_name=Datasets.SegDataset)
+@CUSTOM_DATASETS.register_module(
+    group_key=Tasks.image_segmentation, module_name=CustomDatasets.SegDataset)
 class SegDataset(EasyCVBaseDataset, _SegDataset):
     """EasyCV dataset for Sementic segmentation.

     For more details, please refer to :


@@ -25,16 +25,15 @@ import numpy as np
 import torch
 from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     Tasks.language_guided_video_summarization,
     module_name=Models.language_guided_video_summarization)
-class LanguageGuidedVideoSummarizationDataset(TorchTaskDataset):
+class LanguageGuidedVideoSummarizationDataset(TorchCustomDataset):
     def __init__(self, mode, opt, root_dir):
         self.mode = mode

View File

@@ -1,24 +1,20 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import random
-from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, List, Union
 import json
 import torch
-from datasets import Dataset, IterableDataset, concatenate_datasets
 from torch.utils.data import ConcatDataset
-from transformers import DataCollatorWithPadding
 from modelscope.metainfo import Models
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.utils.constant import ModeKeys, Tasks
-from .base import TaskDataset
-from .builder import TASK_DATASETS
-from .torch_base_dataset import TorchTaskDataset
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     group_key=Tasks.text_ranking, module_name=Models.mgeo)
-class MGeoRankingDataset(TorchTaskDataset):
+class MGeoRankingDataset(TorchCustomDataset):
     def __init__(self,
                  datasets: Union[Any, List[Any]],

View File

@@ -0,0 +1,20 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .movie_scene_segmentation_dataset import MovieSceneSegmentationDataset
+else:
+    _import_structure = {
+        'movie_scene_segmentation_dataset': ['MovieSceneSegmentationDataset'],
+    }
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
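The `LazyImportModule` pattern above replaces the package module in `sys.modules` so that heavy submodules are only imported on first attribute access. A rough stdlib-only analogue of the idea (this `LazyModule` class is an illustration, not modelscope's `LazyImportModule`):

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Illustrative stand-in for a lazy-import module: attribute access
    triggers the real submodule import on first use, then caches it."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # invert {submodule: [exported names]} to {name: submodule}
        self._attr_to_submodule = {
            attr: sub
            for sub, attrs in import_structure.items()
            for attr in attrs
        }

    def __getattr__(self, attr):
        # only called when `attr` is not already in __dict__
        if attr not in self._attr_to_submodule:
            raise AttributeError(attr)
        sub = self._attr_to_submodule[attr]
        module = importlib.import_module(f'{self.__name__}.{sub}')
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so __getattr__ is not hit again
        return value
```

Using a real stdlib package as the target, `LazyModule('os', {'path': ['join']}).join` resolves to `os.path.join` only when first accessed, which is the same deferral these generated `__init__.py` files rely on.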

View File

@@ -10,9 +10,8 @@ import torch
 from torchvision.datasets.folder import pil_loader
 from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
+    CUSTOM_DATASETS
 from modelscope.utils.constant import Tasks
 from . import sampler
@@ -30,9 +29,9 @@ DATASET_STRUCTURE = {
 }
-@TASK_DATASETS.register_module(
-    Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert)
-class MovieSceneSegmentationDataset(TorchTaskDataset):
+@CUSTOM_DATASETS.register_module(
+    group_key=Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert)
+class MovieSceneSegmentationDataset(torch.utils.data.Dataset):
     """dataset for movie scene segmentation.
     Args:

View File

@@ -1,20 +1,21 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import os.path as osp
 from easycv.datasets.detection import DetDataset as _DetDataset
 from easycv.datasets.detection import \
     DetImagesMixDataset as _DetImagesMixDataset
-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
-from modelscope.msdatasets.task_datasets import TASK_DATASETS
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
+    EasyCVBaseDataset
 from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
-    group_key=Tasks.image_object_detection, module_name=Datasets.DetDataset)
-@TASK_DATASETS.register_module(
-    group_key=Tasks.image_segmentation, module_name=Datasets.DetDataset)
+@CUSTOM_DATASETS.register_module(
+    group_key=Tasks.image_object_detection,
+    module_name=CustomDatasets.DetDataset)
+@CUSTOM_DATASETS.register_module(
+    group_key=Tasks.image_segmentation, module_name=CustomDatasets.DetDataset)
 class DetDataset(EasyCVBaseDataset, _DetDataset):
     """EasyCV dataset for object detection.
     For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/raw.py .
@@ -47,12 +48,12 @@ class DetDataset(EasyCVBaseDataset, _DetDataset):
         _DetDataset.__init__(self, *args, **kwargs)
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     group_key=Tasks.image_object_detection,
-    module_name=Datasets.DetImagesMixDataset)
-@TASK_DATASETS.register_module(
+    module_name=CustomDatasets.DetImagesMixDataset)
+@CUSTOM_DATASETS.register_module(
     group_key=Tasks.domain_specific_object_detection,
-    module_name=Datasets.DetImagesMixDataset)
+    module_name=CustomDatasets.DetImagesMixDataset)
 class DetImagesMixDataset(EasyCVBaseDataset, _DetImagesMixDataset):
     """EasyCV dataset for object detection, a wrapper of multiple images mixed dataset.
     Suitable for training on multiple images mixed data augmentation like

View File

@@ -1,3 +1,4 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .data_loader import DataLoader
 from .image_dataset import ImageDataset
+from .measures import QuadMeasurer

View File

@@ -9,9 +9,10 @@ import torch
 from PIL import Image
 from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
+    CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.torch_custom_dataset import \
+    TorchCustomDataset
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
@@ -29,9 +30,9 @@ def Q2B(uchar):
     return chr(inside_code)
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     Tasks.ocr_recognition, module_name=Models.ocr_recognition)
-class OCRRecognitionDataset(TorchTaskDataset):
+class OCRRecognitionDataset(TorchCustomDataset):
     def __init__(self, **kwargs):
         split_config = kwargs['split_config']

View File

@@ -3,14 +3,13 @@
 import cv2
 import numpy as np
-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.sidd_image_denoising.data_utils import (
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
+from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.data_utils import (
     img2tensor, padding)
-from modelscope.msdatasets.task_datasets.sidd_image_denoising.transforms import (
+from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.transforms import (
     augment, paired_random_crop)
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
 from modelscope.utils.constant import Tasks
@@ -18,9 +17,9 @@ def default_loader(path):
     return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
-@TASK_DATASETS.register_module(
-    Tasks.image_deblurring, module_name=Datasets.PairedDataset)
-class RedsImageDeblurringDataset(TorchTaskDataset):
+@CUSTOM_DATASETS.register_module(
+    Tasks.image_deblurring, module_name=CustomDatasets.PairedDataset)
+class RedsImageDeblurringDataset(TorchCustomDataset):
     """Paired image dataset for image restoration.
     """

View File

@@ -0,0 +1,21 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .referring_video_object_segmentation_dataset import ReferringVideoObjectSegmentationDataset
+else:
+    _import_structure = {
+        'referring_video_object_segmentation_dataset':
+        ['ReferringVideoObjectSegmentationDataset'],
+    }
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )

View File

@@ -18,9 +18,8 @@ from tqdm import tqdm
 from modelscope.metainfo import Models
 from modelscope.models.cv.referring_video_object_segmentation.utils import \
     nested_tensor_from_videos_list
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
 from . import transformers as T
@@ -33,10 +32,10 @@ def get_image_id(video_id, frame_idx, ref_instance_a2d_id):
     return image_id
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     Tasks.referring_video_object_segmentation,
     module_name=Models.referring_video_object_segmentation)
-class ReferringVideoObjectSegmentationDataset(TorchTaskDataset):
+class ReferringVideoObjectSegmentationDataset(TorchCustomDataset):
     def __init__(self, **kwargs):
         split_config = kwargs['split_config']

View File

@@ -4,9 +4,8 @@ import cv2
 import numpy as np
 from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.utils.constant import Tasks
 from .data_utils import img2tensor, padding
 from .transforms import augment, paired_random_crop
@@ -16,9 +15,9 @@ def default_loader(path):
     return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     Tasks.image_denoising, module_name=Models.nafnet)
-class SiddImageDenoisingDataset(TorchTaskDataset):
+class SiddImageDenoisingDataset(TorchCustomDataset):
     """Paired image dataset for image restoration.
     """

View File

@@ -1,25 +1,21 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import random
-from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, List, Union
 import torch
-from datasets import Dataset, IterableDataset, concatenate_datasets
 from torch.utils.data import ConcatDataset
-from transformers import DataCollatorWithPadding
 from modelscope.metainfo import Models
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.utils.constant import ModeKeys, Tasks
-from .base import TaskDataset
-from .builder import TASK_DATASETS
-from .torch_base_dataset import TorchTaskDataset
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     group_key=Tasks.text_ranking, module_name=Models.bert)
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     group_key=Tasks.sentence_embedding, module_name=Models.bert)
-class TextRankingDataset(TorchTaskDataset):
+class TextRankingDataset(TorchCustomDataset):
     def __init__(self,
                  datasets: Union[Any, List[Any]],

View File

@@ -0,0 +1,51 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, List, Union
+
+import torch.utils.data
+from torch.utils.data import ConcatDataset as TorchConcatDataset
+
+from modelscope.utils.constant import ModeKeys
+
+
+class TorchCustomDataset(torch.utils.data.Dataset):
+    """The custom dataset base class for all the torch-based task processors.
+    """
+
+    def __init__(self,
+                 datasets: Union[Any, List[Any]],
+                 mode=ModeKeys.TRAIN,
+                 preprocessor=None,
+                 **kwargs):
+        self.trainer = None
+        self.mode = mode
+        self.preprocessor = preprocessor
+        self._inner_dataset = self.prepare_dataset(datasets)
+
+    def __getitem__(self, index) -> Any:
+        return self.preprocessor(
+            self._inner_dataset[index]
+        ) if self.preprocessor else self._inner_dataset[index]
+
+    def __len__(self):
+        return len(self._inner_dataset)
+
+    def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any:
+        """Prepare a dataset.
+
+        User can process the input datasets in a whole dataset perspective.
+        This method gives a default implementation of datasets merging, user can override this
+        method to write custom logics.
+
+        Args:
+            datasets: The original dataset(s)
+
+        Returns: A single dataset, which may be created after merging.
+        """
+        if isinstance(datasets, List):
+            if len(datasets) == 1:
+                return datasets[0]
+            elif len(datasets) > 1:
+                return TorchConcatDataset(datasets)
+        else:
+            return datasets
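The `prepare_dataset` default above reduces to a small rule: unwrap a single-element list, concatenate a longer one, and pass anything else through unchanged. The same branching with a plain list concatenation standing in for `torch.utils.data.ConcatDataset` (purely illustrative, so it runs without torch):

```python
from typing import Any, List, Union


def prepare_dataset(datasets: Union[Any, List[Any]]) -> Any:
    """Mirror of TorchCustomDataset.prepare_dataset's branching, with a
    plain list concatenation standing in for torch's ConcatDataset."""
    if isinstance(datasets, list):
        if len(datasets) == 1:
            return datasets[0]          # unwrap a single dataset
        elif len(datasets) > 1:
            # the real class returns TorchConcatDataset(datasets) here
            return [item for ds in datasets for item in ds]
        # note: an empty list falls through both branches and the
        # function returns None, matching the original code path
    else:
        return datasets
```

Note the empty-list case: neither branch fires, so the function implicitly returns `None`; subclasses that may receive an empty list should override `prepare_dataset` accordingly.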

View File

@@ -5,13 +5,13 @@ import numpy as np
 from datasets import Dataset, IterableDataset, concatenate_datasets
 from modelscope.metainfo import Models
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.utils.constant import Tasks
-from .builder import TASK_DATASETS
-from .torch_base_dataset import TorchTaskDataset
-@TASK_DATASETS.register_module(module_name=Models.veco, group_key=Tasks.nli)
-class VecoDataset(TorchTaskDataset):
+@CUSTOM_DATASETS.register_module(module_name=Models.veco, group_key=Tasks.nli)
+class VecoDataset(TorchCustomDataset):
     def __init__(self,
                  datasets: Union[Any, List[Any]],

View File

@@ -1,16 +1,13 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from collections import defaultdict
 import cv2
 import numpy as np
 import torch
 from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
-from modelscope.msdatasets.task_datasets.video_frame_interpolation.data_utils import (
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
+from modelscope.msdatasets.dataset_cls.custom_datasets.video_frame_interpolation.data_utils import (
     img2tensor, img_padding)
 from modelscope.utils.constant import Tasks
@@ -19,10 +16,10 @@ def default_loader(path):
     return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32)
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     Tasks.video_frame_interpolation,
     module_name=Models.video_frame_interpolation)
-class VideoFrameInterpolationDataset(TorchTaskDataset):
+class VideoFrameInterpolationDataset(TorchCustomDataset):
     """Dataset for video frame-interpolation.
     """

View File

@@ -1,15 +1,14 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+    CUSTOM_DATASETS, TorchCustomDataset)
 from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
     Tasks.video_stabilization, module_name=Models.video_stabilization)
-class VideoStabilizationDataset(TorchTaskDataset):
+class VideoStabilizationDataset(TorchCustomDataset):
     """Paired video dataset for video stabilization.
     """

View File

@@ -8,11 +8,11 @@ import json
 import numpy as np
 import torch
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
-    TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import \
+    TorchCustomDataset
-class VideoSummarizationDataset(TorchTaskDataset):
+class VideoSummarizationDataset(TorchCustomDataset):
     def __init__(self, mode, opt, root_dir):
         self.mode = mode

Some files were not shown because too many files have changed in this diff.