mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00
Refactor the task_datasets module
Refactor the task_datasets module: 1. Add new module modelscope.msdatasets.dataset_cls.custom_datasets. 2. Add new function: modelscope.msdatasets.ms_dataset.MsDataset.to_custom_dataset(). 3. Call to_custom_dataset() in MsDataset.load() to adapt to the new custom_datasets module. 4. Refactor the pipeline for loading custom datasets: 1) Only use the MsDataset.load() function to load custom datasets. 2) Combine MsDataset.load() with class EpochBasedTrainer. 5. Add a new entry function for building datasets in EpochBasedTrainer, see: modelscope.trainers.trainer.EpochBasedTrainer.build_dataset() 6. Add a new function to build the custom dataset from the model configuration, see: modelscope.trainers.trainer.EpochBasedTrainer.build_dataset_from_cfg() 7. Add a new registry function for building custom datasets, see: modelscope.msdatasets.dataset_cls.custom_datasets.builder.build_custom_dataset() 8. Refine the class SiameseUIETrainer to adapt to the new custom_datasets module. 9. Add class TorchCustomDataset as a superclass for custom dataset classes. 10. 
Move the following modules/classes/functions: 1) Move module msdatasets.audio to custom_datasets 2) Move module msdatasets.cv to custom_datasets 3) Move module bad_image_detecting to custom_datasets 4) Move module damoyolo to custom_datasets 5) Move module face_2d_keypoints to custom_datasets 6) Move module hand_2d_keypoints to custom_datasets 7) Move module human_wholebody_keypoint to custom_datasets 8) Move module image_classification to custom_datasets 9) Move module image_inpainting to custom_datasets 10) Move module image_portrait_enhancement to custom_datasets 11) Move module image_quality_assessment_degradation to custom_datasets 12) Move module image_quality_assmessment_mos to custom_datasets 13) Move class LanguageGuidedVideoSummarizationDataset to custom_datasets 14) Move class MGeoRankingDataset to custom_datasets 15) Move module movie_scene_segmentation to custom_datasets 16) Move module object_detection to custom_datasets 17) Move module referring_video_object_segmentation to custom_datasets 18) Move module sidd_image_denoising to custom_datasets 19) Move module video_frame_interpolation to custom_datasets 20) Move module video_stabilization to custom_datasets 21) Move module video_super_resolution to custom_datasets 22) Move class GoproImageDeblurringDataset to custom_datasets 23) Move class EasyCVBaseDataset to custom_datasets 24) Move class ImageInstanceSegmentationCocoDataset to custom_datasets 25) Move class RedsImageDeblurringDataset to custom_datasets 26) Move class TextRankingDataset to custom_datasets 27) Move class VecoDataset to custom_datasets 28) Move class VideoSummarizationDataset to custom_datasets 11. Delete the following modules/functions/classes: 1) Delete module task_datasets 2) Delete to_task_dataset() in EpochBasedTrainer 3) Delete build_dataset() in EpochBasedTrainer and add a new function with the same name. 12. Rename class Datasets to CustomDatasets in metainfo.py Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11872747
This commit is contained in:
@@ -1,14 +0,0 @@
|
||||
modelscope.msdatasets.cv
|
||||
================================
|
||||
|
||||
.. automodule:: modelscope.msdatasets.cv
|
||||
|
||||
.. currentmodule:: modelscope.msdatasets.cv
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
easycv_base.EasyCVBaseDataset
|
||||
image_classification.ClsDataset
|
||||
@@ -0,0 +1,41 @@
|
||||
modelscope.msdatasets.dataset_cls.custom_datasets
|
||||
=================================================
|
||||
|
||||
.. automodule:: modelscope.msdatasets.dataset_cls.custom_datasets
|
||||
|
||||
.. currentmodule:: modelscope.msdatasets.dataset_cls.custom_datasets
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
EasyCVBaseDataset
|
||||
TorchCustomDataset
|
||||
MovieSceneSegmentationDataset
|
||||
ImageInstanceSegmentationCocoDataset
|
||||
GoproImageDeblurringDataset
|
||||
LanguageGuidedVideoSummarizationDataset
|
||||
MGeoRankingDataset
|
||||
RedsImageDeblurringDataset
|
||||
TextRankingDataset
|
||||
VecoDataset
|
||||
VideoSummarizationDataset
|
||||
BadImageDetectingDataset
|
||||
ImageInpaintingDataset
|
||||
ImagePortraitEnhancementDataset
|
||||
ImageQualityAssessmentDegradationDataset
|
||||
ImageQualityAssessmentMosDataset
|
||||
ReferringVideoObjectSegmentationDataset
|
||||
SiddImageDenoisingDataset
|
||||
VideoFrameInterpolationDataset
|
||||
VideoStabilizationDataset
|
||||
VideoSuperResolutionDataset
|
||||
SegDataset
|
||||
FaceKeypointDataset
|
||||
HandCocoWholeBodyDataset
|
||||
WholeBodyCocoTopDownDataset
|
||||
ClsDataset
|
||||
DetImagesMixDataset
|
||||
DetDataset
|
||||
15
docs/source/api/modelscope.msdatasets.dataset_cls.rst
Normal file
15
docs/source/api/modelscope.msdatasets.dataset_cls.rst
Normal file
@@ -0,0 +1,15 @@
|
||||
modelscope.msdatasets.dataset_cls
|
||||
=================================
|
||||
|
||||
.. automodule:: modelscope.msdatasets.dataset_cls
|
||||
|
||||
.. currentmodule:: modelscope.msdatasets.dataset_cls
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
ExternalDataset
|
||||
NativeIterableDataset
|
||||
@@ -10,5 +10,4 @@ modelscope.msdatasets.ms_dataset
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
MsMapDataset
|
||||
MsDataset
|
||||
|
||||
@@ -1137,7 +1137,7 @@ class LR_Schedulers(object):
|
||||
ExponentialWarmup = 'ExponentialWarmup'
|
||||
|
||||
|
||||
class Datasets(object):
|
||||
class CustomDatasets(object):
|
||||
""" Names for different datasets.
|
||||
"""
|
||||
ClsDataset = 'ClsDataset'
|
||||
|
||||
@@ -8,8 +8,8 @@ from modelscope.models.cv.tinynas_detection.damo.apis.detector_inference import
|
||||
inference
|
||||
from modelscope.models.cv.tinynas_detection.damo.detectors.detector import \
|
||||
build_local_model
|
||||
from modelscope.msdatasets.task_datasets.damoyolo import (build_dataloader,
|
||||
build_dataset)
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import (
|
||||
build_dataloader, build_dataset)
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.msdatasets.task_datasets.damoyolo.evaluation import evaluate
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import evaluate
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.timer import Timer, get_time_str
|
||||
from modelscope.utils.torch_utils import (all_gather, get_world_size,
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from . import cv
|
||||
from .ms_dataset import MsDataset
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from . import (image_classification, image_semantic_segmentation,
|
||||
object_detection)
|
||||
@@ -13,6 +13,7 @@ from modelscope.msdatasets.context.dataset_context_config import \
|
||||
DatasetContextConfig
|
||||
from modelscope.msdatasets.data_files.data_files_manager import \
|
||||
DataFilesManager
|
||||
from modelscope.msdatasets.dataset_cls.dataset import ExternalDataset
|
||||
from modelscope.msdatasets.meta.data_meta_manager import DataMetaManager
|
||||
from modelscope.utils.constant import DatasetFormations
|
||||
|
||||
@@ -62,7 +63,8 @@ class OssDataLoader(BaseDataLoader):
|
||||
|
||||
self.data_files_builder: Optional[DataFilesManager] = None
|
||||
self.dataset: Optional[Union[Dataset, IterableDataset, DatasetDict,
|
||||
IterableDatasetDict]] = None
|
||||
IterableDatasetDict,
|
||||
ExternalDataset]] = None
|
||||
self.builder: Optional[DatasetBuilder] = None
|
||||
self.data_files_manager: Optional[DataFilesManager] = None
|
||||
|
||||
@@ -141,7 +143,8 @@ class OssDataLoader(BaseDataLoader):
|
||||
self.builder)
|
||||
|
||||
def _post_process(self) -> None:
|
||||
...
|
||||
if isinstance(self.dataset, ExternalDataset):
|
||||
self.dataset.custom_map = self.dataset_context_config.data_meta_config.meta_type_map
|
||||
|
||||
|
||||
class MaxComputeDataLoader(BaseDataLoader):
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .dataset import ExternalDataset, NativeIterableDataset
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .easycv_base import EasyCVBaseDataset
|
||||
from .builder import CUSTOM_DATASETS, build_custom_dataset
|
||||
from .torch_custom_dataset import TorchCustomDataset
|
||||
from .movie_scene_segmentation.movie_scene_segmentation_dataset import MovieSceneSegmentationDataset
|
||||
from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset
|
||||
from .gopro_image_deblurring_dataset import GoproImageDeblurringDataset
|
||||
from .language_guided_video_summarization_dataset import LanguageGuidedVideoSummarizationDataset
|
||||
from .mgeo_ranking_dataset import MGeoRankingDataset
|
||||
from .reds_image_deblurring_dataset import RedsImageDeblurringDataset
|
||||
from .text_ranking_dataset import TextRankingDataset
|
||||
from .veco_dataset import VecoDataset
|
||||
from .video_summarization_dataset import VideoSummarizationDataset
|
||||
from .audio import KWSDataset, KWSDataLoader, kws_nearfield_dataset
|
||||
from .bad_image_detecting import BadImageDetectingDataset
|
||||
from .image_inpainting import ImageInpaintingDataset
|
||||
from .image_portrait_enhancement import ImagePortraitEnhancementDataset
|
||||
from .image_quality_assessment_degradation import ImageQualityAssessmentDegradationDataset
|
||||
from .image_quality_assmessment_mos import ImageQualityAssessmentMosDataset
|
||||
from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset
|
||||
from .sidd_image_denoising import SiddImageDenoisingDataset
|
||||
from .video_frame_interpolation import VideoFrameInterpolationDataset
|
||||
from .video_stabilization import VideoStabilizationDataset
|
||||
from .video_super_resolution import VideoSuperResolutionDataset
|
||||
from .image_semantic_segmentation import SegDataset
|
||||
from .face_2d_keypoins import FaceKeypointDataset
|
||||
from .hand_2d_keypoints import HandCocoWholeBodyDataset
|
||||
from .human_wholebody_keypoint import WholeBodyCocoTopDownDataset
|
||||
from .image_classification import ClsDataset
|
||||
from .object_detection import DetDataset, DetImagesMixDataset
|
||||
from .ocr_detection import DataLoader, ImageDataset, QuadMeasurer
|
||||
from .ocr_recognition_dataset import OCRRecognitionDataset
|
||||
else:
|
||||
_import_structure = {
|
||||
'easycv_base': ['EasyCVBaseDataset'],
|
||||
'builder': ['CUSTOM_DATASETS', 'build_custom_dataset'],
|
||||
'torch_custom_dataset': ['TorchCustomDataset'],
|
||||
'movie_scene_segmentation_dataset': ['MovieSceneSegmentationDataset'],
|
||||
'image_instance_segmentation_coco_dataset':
|
||||
['ImageInstanceSegmentationCocoDataset'],
|
||||
'gopro_image_deblurring_dataset': ['GoproImageDeblurringDataset'],
|
||||
'language_guided_video_summarization_dataset':
|
||||
['LanguageGuidedVideoSummarizationDataset'],
|
||||
'mgeo_ranking_dataset': ['MGeoRankingDataset'],
|
||||
'reds_image_deblurring_dataset': ['RedsImageDeblurringDataset'],
|
||||
'text_ranking_dataset': ['TextRankingDataset'],
|
||||
'veco_dataset': ['VecoDataset'],
|
||||
'video_summarization_dataset': ['VideoSummarizationDataset'],
|
||||
'audio': ['KWSDataset', 'KWSDataLoader', 'kws_nearfield_dataset'],
|
||||
'bad_image_detecting': ['BadImageDetectingDataset'],
|
||||
'image_inpainting': ['ImageInpaintingDataset'],
|
||||
'image_portrait_enhancement': ['ImagePortraitEnhancementDataset'],
|
||||
'image_quality_assessment_degradation':
|
||||
['ImageQualityAssessmentDegradationDataset'],
|
||||
'image_quality_assmessment_mos': ['ImageQualityAssessmentMosDataset'],
|
||||
'referring_video_object_segmentation':
|
||||
['ReferringVideoObjectSegmentationDataset'],
|
||||
'sidd_image_denoising': ['SiddImageDenoisingDataset'],
|
||||
'video_frame_interpolation': ['VideoFrameInterpolationDataset'],
|
||||
'video_stabilization': ['VideoStabilizationDataset'],
|
||||
'video_super_resolution': ['VideoSuperResolutionDataset'],
|
||||
'image_semantic_segmentation': ['SegDataset'],
|
||||
'face_2d_keypoins': ['FaceKeypointDataset'],
|
||||
'hand_2d_keypoints': ['HandCocoWholeBodyDataset'],
|
||||
'human_wholebody_keypoint': ['WholeBodyCocoTopDownDataset'],
|
||||
'image_classification': ['ClsDataset'],
|
||||
'object_detection': ['DetDataset', 'DetImagesMixDataset'],
|
||||
'ocr_detection': ['DataLoader', 'ImageDataset', 'QuadMeasurer'],
|
||||
'ocr_recognition_dataset': ['OCRRecognitionDataset'],
|
||||
}
|
||||
|
||||
import sys
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
@@ -5,7 +5,6 @@ import math
|
||||
import os.path
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
import modelscope.msdatasets.task_datasets.audio.kws_nearfield_processor as processor
|
||||
import modelscope.msdatasets.dataset_cls.custom_datasets.audio.kws_nearfield_processor as processor
|
||||
from modelscope.trainers.audio.kws_utils.file_utils import (make_pair,
|
||||
read_lists)
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -1,12 +1,8 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \
|
||||
@@ -14,9 +10,9 @@ from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.bad_image_detecting, module_name=Models.bad_image_detecting)
|
||||
class BadImageDetectingDataset(TorchTaskDataset):
|
||||
class BadImageDetectingDataset(TorchCustomDataset):
|
||||
"""Paired image dataset for bad image detecting.
|
||||
"""
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
from modelscope.utils.config import ConfigDict
|
||||
from modelscope.utils.registry import Registry, build_from_cfg
|
||||
|
||||
TASK_DATASETS = Registry('task_datasets')
|
||||
CUSTOM_DATASETS = Registry('custom_datasets')
|
||||
|
||||
|
||||
def build_task_dataset(cfg: ConfigDict,
|
||||
task_name: str = None,
|
||||
default_args: dict = None):
|
||||
""" Build task specific dataset processor given model config dict and the task name.
|
||||
def build_custom_dataset(cfg: ConfigDict,
|
||||
task_name: str,
|
||||
default_args: dict = None):
|
||||
""" Build custom dataset for user-define dataset given model config and task name.
|
||||
|
||||
Args:
|
||||
cfg (:obj:`ConfigDict`): config dict for model object.
|
||||
@@ -18,4 +18,4 @@ def build_task_dataset(cfg: ConfigDict,
|
||||
default_args (dict, optional): Default initialization arguments.
|
||||
"""
|
||||
return build_from_cfg(
|
||||
cfg, TASK_DATASETS, group_key=task_name, default_args=default_args)
|
||||
cfg, CUSTOM_DATASETS, group_key=task_name, default_args=default_args)
|
||||
@@ -1,2 +1,3 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from .build import build_dataloader, build_dataset
|
||||
from .evaluation import evaluate
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright © Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.msdatasets.task_datasets.damoyolo import datasets
|
||||
from .. import datasets
|
||||
from .coco import coco_evaluation
|
||||
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from easycv.datasets.face import FaceKeypointDataset as _FaceKeypointDataset
|
||||
|
||||
from modelscope.metainfo import Datasets
|
||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.metainfo import CustomDatasets
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
|
||||
EasyCVBaseDataset
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.face_2d_keypoints,
|
||||
module_name=Datasets.Face2dKeypointsDataset)
|
||||
module_name=CustomDatasets.Face2dKeypointsDataset)
|
||||
class FaceKeypointDataset(EasyCVBaseDataset, _FaceKeypointDataset):
|
||||
"""EasyCV dataset for face 2d keypoints.
|
||||
|
||||
@@ -3,14 +3,13 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Datasets
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.sidd_image_denoising.data_utils import (
|
||||
from modelscope.metainfo import CustomDatasets
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.data_utils import (
|
||||
img2tensor, padding)
|
||||
from modelscope.msdatasets.task_datasets.sidd_image_denoising.transforms import (
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.transforms import (
|
||||
augment, paired_random_crop)
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@@ -18,9 +17,9 @@ def default_loader(path):
|
||||
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
Tasks.image_deblurring, module_name=Datasets.PairedDataset)
|
||||
class GoproImageDeblurringDataset(TorchTaskDataset):
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.image_deblurring, module_name=CustomDatasets.PairedDataset)
|
||||
class GoproImageDeblurringDataset(TorchCustomDataset):
|
||||
"""Paired image dataset for image restoration.
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .hand_2d_keypoints_dataset import HandCocoWholeBodyDataset
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'hand_2d_keypoints_dataset': ['HandCocoWholeBodyDataset']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
@@ -2,15 +2,16 @@
|
||||
from easycv.datasets.pose import \
|
||||
HandCocoWholeBodyDataset as _HandCocoWholeBodyDataset
|
||||
|
||||
from modelscope.metainfo import Datasets
|
||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.metainfo import CustomDatasets
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
|
||||
EasyCVBaseDataset
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.hand_2d_keypoints,
|
||||
module_name=Datasets.HandCocoWholeBodyDataset)
|
||||
module_name=CustomDatasets.HandCocoWholeBodyDataset)
|
||||
class HandCocoWholeBodyDataset(EasyCVBaseDataset, _HandCocoWholeBodyDataset):
|
||||
"""EasyCV dataset for human hand 2d keypoints.
|
||||
|
||||
@@ -2,15 +2,16 @@
|
||||
from easycv.datasets.pose import \
|
||||
WholeBodyCocoTopDownDataset as _WholeBodyCocoTopDownDataset
|
||||
|
||||
from modelscope.metainfo import Datasets
|
||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.metainfo import CustomDatasets
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
|
||||
EasyCVBaseDataset
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.human_wholebody_keypoint,
|
||||
module_name=Datasets.HumanWholeBodyKeypointDataset)
|
||||
module_name=CustomDatasets.HumanWholeBodyKeypointDataset)
|
||||
class WholeBodyCocoTopDownDataset(EasyCVBaseDataset,
|
||||
_WholeBodyCocoTopDownDataset):
|
||||
"""EasyCV dataset for human whole body 2d keypoints.
|
||||
@@ -1,14 +1,16 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from easycv.datasets.classification import ClsDataset as _ClsDataset
|
||||
|
||||
from modelscope.metainfo import Datasets
|
||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.metainfo import CustomDatasets
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
|
||||
EasyCVBaseDataset
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
group_key=Tasks.image_classification, module_name=Datasets.ClsDataset)
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.image_classification,
|
||||
module_name=CustomDatasets.ClsDataset)
|
||||
class ClsDataset(_ClsDataset):
|
||||
"""EasyCV dataset for classification.
|
||||
|
||||
@@ -4,13 +4,11 @@ from typing import TYPE_CHECKING
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .hand_2d_keypoints_dataset import Hand2DKeypointDataset
|
||||
|
||||
from .image_inpainting_dataset import ImageInpaintingDataset
|
||||
else:
|
||||
_import_structure = {
|
||||
'hand_2d_keypoints_dataset': ['Hand2DKeypointDataset']
|
||||
'image_inpainting_dataset': ['ImageInpaintingDataset'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
@@ -3,20 +3,16 @@ Part of the implementation is borrowed and modified from LaMa,
|
||||
publicly available at https://github.com/saic-mdal/lama
|
||||
"""
|
||||
import glob
|
||||
import os
|
||||
import os.path as osp
|
||||
from enum import Enum
|
||||
|
||||
import albumentations as A
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .aug import IAAAffine2, IAAPerspective2
|
||||
@@ -296,9 +292,9 @@ def get_transforms(test_mode, out_size):
|
||||
return transform
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.image_inpainting, module_name=Models.image_inpainting)
|
||||
class ImageInpaintingDataset(TorchTaskDataset):
|
||||
class ImageInpaintingDataset(TorchCustomDataset):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
split_config = kwargs['split_config']
|
||||
@@ -6,9 +6,9 @@ import numpy as np
|
||||
from pycocotools.coco import COCO
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.utils.constant import Tasks
|
||||
from .builder import TASK_DATASETS
|
||||
from .torch_base_dataset import TorchTaskDataset
|
||||
|
||||
DATASET_STRUCTURE = {
|
||||
'train': {
|
||||
@@ -22,10 +22,10 @@ DATASET_STRUCTURE = {
|
||||
}
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
module_name=Models.cascade_mask_rcnn_swin,
|
||||
group_key=Tasks.image_segmentation)
|
||||
class ImageInstanceSegmentationCocoDataset(TorchTaskDataset):
|
||||
class ImageInstanceSegmentationCocoDataset(TorchCustomDataset):
|
||||
"""Coco-style dataset for image instance segmentation.
|
||||
|
||||
Args:
|
||||
@@ -3,10 +3,9 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Datasets, Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.metainfo import CustomDatasets
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.utils.constant import Tasks
|
||||
from .data_utils import img2tensor
|
||||
|
||||
@@ -15,9 +14,9 @@ def default_loader(path):
|
||||
return cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
Tasks.image_portrait_enhancement, module_name=Datasets.PairedDataset)
|
||||
class ImagePortraitEnhancementDataset(TorchTaskDataset):
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.image_portrait_enhancement, module_name=CustomDatasets.PairedDataset)
|
||||
class ImagePortraitEnhancementDataset(TorchCustomDataset):
|
||||
"""Paired image dataset for image portrait enhancement.
|
||||
"""
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.image_quality_assessment_degradation,
|
||||
module_name=Models.image_quality_assessment_degradation)
|
||||
class ImageQualityAssessmentDegradationDataset(TorchTaskDataset):
|
||||
class ImageQualityAssessmentDegradationDataset(TorchCustomDataset):
|
||||
"""Paired image dataset for image quality assessment degradation.
|
||||
"""
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.preprocessors.cv import ImageQualityAssessmentMosPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.image_quality_assessment_mos,
|
||||
module_name=Models.image_quality_assessment_mos)
|
||||
class ImageQualityAssessmentMosDataset(TorchTaskDataset):
|
||||
class ImageQualityAssessmentMosDataset(TorchCustomDataset):
|
||||
"""Paired image dataset for image quality assessment mos.
|
||||
"""
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from easycv.datasets.segmentation import SegDataset as _SegDataset
|
||||
|
||||
from modelscope.metainfo import Datasets
|
||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.metainfo import CustomDatasets
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
|
||||
EasyCVBaseDataset
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
group_key=Tasks.image_segmentation, module_name=Datasets.SegDataset)
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.image_segmentation, module_name=CustomDatasets.SegDataset)
|
||||
class SegDataset(EasyCVBaseDataset, _SegDataset):
|
||||
"""EasyCV dataset for Sementic segmentation.
|
||||
For more details, please refer to :
|
||||
@@ -25,16 +25,15 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.language_guided_video_summarization,
|
||||
module_name=Models.language_guided_video_summarization)
|
||||
class LanguageGuidedVideoSummarizationDataset(TorchTaskDataset):
|
||||
class LanguageGuidedVideoSummarizationDataset(TorchCustomDataset):
|
||||
|
||||
def __init__(self, mode, opt, root_dir):
|
||||
self.mode = mode
|
||||
@@ -1,24 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, List, Union
|
||||
|
||||
import json
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset, concatenate_datasets
|
||||
from torch.utils.data import ConcatDataset
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.utils.constant import ModeKeys, Tasks
|
||||
from .base import TaskDataset
|
||||
from .builder import TASK_DATASETS
|
||||
from .torch_base_dataset import TorchTaskDataset
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.text_ranking, module_name=Models.mgeo)
|
||||
class MGeoRankingDataset(TorchTaskDataset):
|
||||
class MGeoRankingDataset(TorchCustomDataset):
|
||||
|
||||
def __init__(self,
|
||||
datasets: Union[Any, List[Any]],
|
||||
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
# Package entry point: swap this module for a LazyImportModule at runtime,
# while still giving static type checkers a real import to analyse.
if not TYPE_CHECKING:
    import sys

    # Submodule name -> public names exported from that submodule.
    _import_structure = {
        'movie_scene_segmentation_dataset': ['MovieSceneSegmentationDataset'],
    }

    # Replace the entry for this package in sys.modules with the lazy proxy.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
else:
    # Only evaluated by type checkers; never executed at runtime.
    from .movie_scene_segmentation_dataset import MovieSceneSegmentationDataset
|
||||
@@ -10,9 +10,8 @@ import torch
|
||||
from torchvision.datasets.folder import pil_loader
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
|
||||
CUSTOM_DATASETS
|
||||
from modelscope.utils.constant import Tasks
|
||||
from . import sampler
|
||||
|
||||
@@ -30,9 +29,9 @@ DATASET_STRUCTURE = {
|
||||
}
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert)
|
||||
class MovieSceneSegmentationDataset(TorchTaskDataset):
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert)
|
||||
class MovieSceneSegmentationDataset(torch.utils.data.Dataset):
|
||||
"""dataset for movie scene segmentation.
|
||||
|
||||
Args:
|
||||
@@ -1,20 +1,21 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
|
||||
from easycv.datasets.detection import DetDataset as _DetDataset
|
||||
from easycv.datasets.detection import \
|
||||
DetImagesMixDataset as _DetImagesMixDataset
|
||||
|
||||
from modelscope.metainfo import Datasets
|
||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
||||
from modelscope.msdatasets.task_datasets import TASK_DATASETS
|
||||
from modelscope.metainfo import CustomDatasets
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
|
||||
EasyCVBaseDataset
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
group_key=Tasks.image_object_detection, module_name=Datasets.DetDataset)
|
||||
@TASK_DATASETS.register_module(
|
||||
group_key=Tasks.image_segmentation, module_name=Datasets.DetDataset)
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.image_object_detection,
|
||||
module_name=CustomDatasets.DetDataset)
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.image_segmentation, module_name=CustomDatasets.DetDataset)
|
||||
class DetDataset(EasyCVBaseDataset, _DetDataset):
|
||||
"""EasyCV dataset for object detection.
|
||||
For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/raw.py .
|
||||
@@ -47,12 +48,12 @@ class DetDataset(EasyCVBaseDataset, _DetDataset):
|
||||
_DetDataset.__init__(self, *args, **kwargs)
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.image_object_detection,
|
||||
module_name=Datasets.DetImagesMixDataset)
|
||||
@TASK_DATASETS.register_module(
|
||||
module_name=CustomDatasets.DetImagesMixDataset)
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.domain_specific_object_detection,
|
||||
module_name=Datasets.DetImagesMixDataset)
|
||||
module_name=CustomDatasets.DetImagesMixDataset)
|
||||
class DetImagesMixDataset(EasyCVBaseDataset, _DetImagesMixDataset):
|
||||
"""EasyCV dataset for object detection, a wrapper of multiple images mixed dataset.
|
||||
Suitable for training on multiple images mixed data augmentation like
|
||||
@@ -1,3 +1,4 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from .data_loader import DataLoader
|
||||
from .image_dataset import ImageDataset
|
||||
from .measures import QuadMeasurer
|
||||
@@ -9,9 +9,10 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
|
||||
CUSTOM_DATASETS
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.torch_custom_dataset import \
|
||||
TorchCustomDataset
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -29,9 +30,9 @@ def Q2B(uchar):
|
||||
return chr(inside_code)
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.ocr_recognition, module_name=Models.ocr_recognition)
|
||||
class OCRRecognitionDataset(TorchTaskDataset):
|
||||
class OCRRecognitionDataset(TorchCustomDataset):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
split_config = kwargs['split_config']
|
||||
@@ -3,14 +3,13 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Datasets
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.sidd_image_denoising.data_utils import (
|
||||
from modelscope.metainfo import CustomDatasets
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.data_utils import (
|
||||
img2tensor, padding)
|
||||
from modelscope.msdatasets.task_datasets.sidd_image_denoising.transforms import (
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.transforms import (
|
||||
augment, paired_random_crop)
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@@ -18,9 +17,9 @@ def default_loader(path):
|
||||
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
Tasks.image_deblurring, module_name=Datasets.PairedDataset)
|
||||
class RedsImageDeblurringDataset(TorchTaskDataset):
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.image_deblurring, module_name=CustomDatasets.PairedDataset)
|
||||
class RedsImageDeblurringDataset(TorchCustomDataset):
|
||||
"""Paired image dataset for image restoration.
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
# Package entry point: expose the dataset class through a LazyImportModule at
# runtime, while giving static type checkers a real import to analyse.
if TYPE_CHECKING:
    from .referring_video_object_segmentation_dataset import (
        ReferringVideoObjectSegmentationDataset)
else:
    # Submodule name -> public names exported from that submodule.
    # The exported name must match the class imported in the TYPE_CHECKING
    # branch above; it previously listed 'MovieSceneSegmentationDataset',
    # which this submodule does not provide.
    _import_structure = {
        'referring_video_object_segmentation_dataset':
        ['ReferringVideoObjectSegmentationDataset'],
    }
    import sys

    # Replace the entry for this package in sys.modules with the lazy proxy.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
|
||||
@@ -18,9 +18,8 @@ from tqdm import tqdm
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.cv.referring_video_object_segmentation.utils import \
|
||||
nested_tensor_from_videos_list
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from . import transformers as T
|
||||
@@ -33,10 +32,10 @@ def get_image_id(video_id, frame_idx, ref_instance_a2d_id):
|
||||
return image_id
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.referring_video_object_segmentation,
|
||||
module_name=Models.referring_video_object_segmentation)
|
||||
class ReferringVideoObjectSegmentationDataset(TorchTaskDataset):
|
||||
class ReferringVideoObjectSegmentationDataset(TorchCustomDataset):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
split_config = kwargs['split_config']
|
||||
@@ -4,9 +4,8 @@ import cv2
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.utils.constant import Tasks
|
||||
from .data_utils import img2tensor, padding
|
||||
from .transforms import augment, paired_random_crop
|
||||
@@ -16,9 +15,9 @@ def default_loader(path):
|
||||
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.image_denoising, module_name=Models.nafnet)
|
||||
class SiddImageDenoisingDataset(TorchTaskDataset):
|
||||
class SiddImageDenoisingDataset(TorchCustomDataset):
|
||||
"""Paired image dataset for image restoration.
|
||||
"""
|
||||
|
||||
@@ -1,25 +1,21 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, List, Union
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset, concatenate_datasets
|
||||
from torch.utils.data import ConcatDataset
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.utils.constant import ModeKeys, Tasks
|
||||
from .base import TaskDataset
|
||||
from .builder import TASK_DATASETS
|
||||
from .torch_base_dataset import TorchTaskDataset
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.text_ranking, module_name=Models.bert)
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
group_key=Tasks.sentence_embedding, module_name=Models.bert)
|
||||
class TextRankingDataset(TorchTaskDataset):
|
||||
class TextRankingDataset(TorchCustomDataset):
|
||||
|
||||
def __init__(self,
|
||||
datasets: Union[Any, List[Any]],
|
||||
@@ -0,0 +1,51 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Any, List, Union
|
||||
|
||||
import torch.utils.data
|
||||
from torch.utils.data import ConcatDataset as TorchConcatDataset
|
||||
|
||||
from modelscope.utils.constant import ModeKeys
|
||||
|
||||
|
||||
class TorchCustomDataset(torch.utils.data.Dataset):
    """Base class for torch-based custom datasets.

    Wraps one or more source datasets behind a single indexable object and
    optionally runs a preprocessor over every item that is fetched.

    Args:
        datasets: A single dataset, or a list of datasets that is merged by
            ``prepare_dataset``.
        mode: Stored on the instance as ``self.mode``; defaults to
            ``ModeKeys.TRAIN``.
        preprocessor: Optional callable applied to each item returned by
            ``__getitem__``.
        **kwargs: Accepted for signature compatibility; not used here.
    """

    def __init__(self,
                 datasets: Union[Any, List[Any]],
                 mode=ModeKeys.TRAIN,
                 preprocessor=None,
                 **kwargs):
        self.trainer = None
        self.mode = mode
        self.preprocessor = preprocessor
        # Attributes above are set first so that an overridden
        # prepare_dataset() can read them.
        self._inner_dataset = self.prepare_dataset(datasets)

    def __getitem__(self, index) -> Any:
        """Return the item at ``index``, preprocessed if a preprocessor is set."""
        if not self.preprocessor:
            return self._inner_dataset[index]
        return self.preprocessor(self._inner_dataset[index])

    def __len__(self):
        """Return the number of items in the prepared inner dataset."""
        return len(self._inner_dataset)

    def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any:
        """Collapse the input dataset(s) into a single dataset.

        This default implementation only merges; subclasses may override it
        to apply whole-dataset processing.

        Args:
            datasets: The original dataset, or a list of datasets.

        Returns:
            The input itself when it is not a list; the sole element of a
            one-element list; a torch ``ConcatDataset`` over a list with
            several elements; ``None`` for an empty list.
        """
        if not isinstance(datasets, list):
            return datasets

        count = len(datasets)
        if count == 1:
            return datasets[0]
        if count > 1:
            return TorchConcatDataset(datasets)
        # Empty list: nothing to merge.
        return None
|
||||
@@ -5,13 +5,13 @@ import numpy as np
|
||||
from datasets import Dataset, IterableDataset, concatenate_datasets
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.utils.constant import Tasks
|
||||
from .builder import TASK_DATASETS
|
||||
from .torch_base_dataset import TorchTaskDataset
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(module_name=Models.veco, group_key=Tasks.nli)
|
||||
class VecoDataset(TorchTaskDataset):
|
||||
@CUSTOM_DATASETS.register_module(module_name=Models.veco, group_key=Tasks.nli)
|
||||
class VecoDataset(TorchCustomDataset):
|
||||
|
||||
def __init__(self,
|
||||
datasets: Union[Any, List[Any]],
|
||||
@@ -1,16 +1,13 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.task_datasets.video_frame_interpolation.data_utils import (
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.video_frame_interpolation.data_utils import (
|
||||
img2tensor, img_padding)
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
@@ -19,10 +16,10 @@ def default_loader(path):
|
||||
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32)
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.video_frame_interpolation,
|
||||
module_name=Models.video_frame_interpolation)
|
||||
class VideoFrameInterpolationDataset(TorchTaskDataset):
|
||||
class VideoFrameInterpolationDataset(TorchCustomDataset):
|
||||
"""Dataset for video frame-interpolation.
|
||||
"""
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||
CUSTOM_DATASETS, TorchCustomDataset)
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
@CUSTOM_DATASETS.register_module(
|
||||
Tasks.video_stabilization, module_name=Models.video_stabilization)
|
||||
class VideoStabilizationDataset(TorchTaskDataset):
|
||||
class VideoStabilizationDataset(TorchCustomDataset):
|
||||
"""Paired video dataset for video stabilization.
|
||||
"""
|
||||
|
||||
@@ -8,11 +8,11 @@ import json
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import \
|
||||
TorchCustomDataset
|
||||
|
||||
|
||||
class VideoSummarizationDataset(TorchTaskDataset):
|
||||
class VideoSummarizationDataset(TorchCustomDataset):
|
||||
|
||||
def __init__(self, mode, opt, root_dir):
|
||||
self.mode = mode
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user