mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00
Refactor the task_datasets module
Refactor the task_datasets module: 1. Add new module modelscope.msdatasets.dataset_cls.custom_datasets. 2. Add new function: modelscope.msdatasets.ms_dataset.MsDataset.to_custom_dataset(). 2. Add calling to_custom_dataset() func in MsDataset.load() to adapt new custom_datasets module. 3. Refactor the pipeline for loading custom dataset: 1) Only use MsDataset.load() function to load the custom datasets. 2) Combine MsDataset.load() with class EpochBasedTrainer. 4. Add new entry func for building datasets in EpochBasedTrainer: see modelscope.trainers.trainer.EpochBasedTrainer.build_dataset() 5. Add new func to build the custom dataset from model configuration, see: modelscope.trainers.trainer.EpochBasedTrainer.build_dataset_from_cfg() 6. Add new registry function for building custom datasets, see: modelscope.msdatasets.dataset_cls.custom_datasets.builder.build_custom_dataset() 7. Refine the class SiameseUIETrainer to adapt the new custom_datasets module. 8. Add class TorchCustomDataset as a superclass for custom datasets classes. 9. To move modules/classes/functions: 1) Move module msdatasets.audio to custom_datasets 2) Move module msdatasets.cv to custom_datasets 3) Move module bad_image_detecting to custom_datasets 4) Move module damoyolo to custom_datasets 5) Move module face_2d_keypoints to custom_datasets 6) Move module hand_2d_keypoints to custom_datasets 7) Move module human_wholebody_keypoint to custom_datasets 8) Move module image_classification to custom_datasets 9) Move module image_inpainting to custom_datasets 10) Move module image_portrait_enhancement to custom_datasets 11) Move module image_quality_assessment_degradation to custom_datasets 12) Move module image_quality_assmessment_mos to custom_datasets 13) Move class LanguageGuidedVideoSummarizationDataset to custom_datasets 14) Move class MGeoRankingDataset to custom_datasets 15) Move module movie_scene_segmentation custom_datasets 16) Move module object_detection to custom_datasets 17) Move module referring_video_object_segmentation to custom_datasets 18) Move module sidd_image_denoising to custom_datasets 19) Move module video_frame_interpolation to custom_datasets 20) Move module video_stabilization to custom_datasets 21) Move module video_super_resolution to custom_datasets 22) Move class GoproImageDeblurringDataset to custom_datasets 23) Move class EasyCVBaseDataset to custom_datasets 24) Move class ImageInstanceSegmentationCocoDataset to custom_datasets 25) Move class RedsImageDeblurringDataset to custom_datasets 26) Move class TextRankingDataset to custom_datasets 27) Move class VecoDataset to custom_datasets 28) Move class VideoSummarizationDataset to custom_datasets 10. To delete modules/functions/classes: 1) Del module task_datasets 2) Del to_task_dataset() in EpochBasedTrainer 3) Del build_dataset() in EpochBasedTrainer and renew a same name function. 11. Rename class Datasets to CustomDatasets in metainfo.py Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11872747
This commit is contained in:
@@ -1,14 +0,0 @@
|
|||||||
modelscope.msdatasets.cv
|
|
||||||
================================
|
|
||||||
|
|
||||||
.. automodule:: modelscope.msdatasets.cv
|
|
||||||
|
|
||||||
.. currentmodule:: modelscope.msdatasets.cv
|
|
||||||
|
|
||||||
.. autosummary::
|
|
||||||
:toctree: generated
|
|
||||||
:nosignatures:
|
|
||||||
:template: classtemplate.rst
|
|
||||||
|
|
||||||
easycv_base.EasyCVBaseDataset
|
|
||||||
image_classification.ClsDataset
|
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
modelscope.msdatasets.dataset_cls.custom_datasets
|
||||||
|
====================
|
||||||
|
|
||||||
|
.. automodule:: modelscope.msdatasets.dataset_cls.custom_datasets
|
||||||
|
|
||||||
|
.. currentmodule:: modelscope.msdatasets.dataset_cls.custom_datasets
|
||||||
|
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
EasyCVBaseDataset
|
||||||
|
TorchCustomDataset
|
||||||
|
MovieSceneSegmentationDataset
|
||||||
|
ImageInstanceSegmentationCocoDataset
|
||||||
|
GoproImageDeblurringDataset
|
||||||
|
LanguageGuidedVideoSummarizationDataset
|
||||||
|
MGeoRankingDataset
|
||||||
|
RedsImageDeblurringDataset
|
||||||
|
TextRankingDataset
|
||||||
|
VecoDataset
|
||||||
|
VideoSummarizationDataset
|
||||||
|
BadImageDetectingDataset
|
||||||
|
ImageInpaintingDataset
|
||||||
|
ImagePortraitEnhancementDataset
|
||||||
|
ImageQualityAssessmentDegradationDataset
|
||||||
|
ImageQualityAssessmentMosDataset
|
||||||
|
ReferringVideoObjectSegmentationDataset
|
||||||
|
SiddImageDenoisingDataset
|
||||||
|
VideoFrameInterpolationDataset
|
||||||
|
VideoStabilizationDataset
|
||||||
|
VideoSuperResolutionDataset
|
||||||
|
SegDataset
|
||||||
|
FaceKeypointDataset
|
||||||
|
HandCocoWholeBodyDataset
|
||||||
|
WholeBodyCocoTopDownDataset
|
||||||
|
ClsDataset
|
||||||
|
DetImagesMixDataset
|
||||||
|
DetDataset
|
||||||
15
docs/source/api/modelscope.msdatasets.dataset_cls.rst
Normal file
15
docs/source/api/modelscope.msdatasets.dataset_cls.rst
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
modelscope.msdatasets.dataset_cls
|
||||||
|
====================
|
||||||
|
|
||||||
|
.. automodule:: modelscope.msdatasets.dataset_cls
|
||||||
|
|
||||||
|
.. currentmodule:: modelscope.msdatasets.dataset_cls
|
||||||
|
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
ExternalDataset
|
||||||
|
NativeIterableDataset
|
||||||
@@ -10,5 +10,4 @@ modelscope.msdatasets.ms_dataset
|
|||||||
:nosignatures:
|
:nosignatures:
|
||||||
:template: classtemplate.rst
|
:template: classtemplate.rst
|
||||||
|
|
||||||
MsMapDataset
|
|
||||||
MsDataset
|
MsDataset
|
||||||
|
|||||||
@@ -1137,7 +1137,7 @@ class LR_Schedulers(object):
|
|||||||
ExponentialWarmup = 'ExponentialWarmup'
|
ExponentialWarmup = 'ExponentialWarmup'
|
||||||
|
|
||||||
|
|
||||||
class Datasets(object):
|
class CustomDatasets(object):
|
||||||
""" Names for different datasets.
|
""" Names for different datasets.
|
||||||
"""
|
"""
|
||||||
ClsDataset = 'ClsDataset'
|
ClsDataset = 'ClsDataset'
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ from modelscope.models.cv.tinynas_detection.damo.apis.detector_inference import
|
|||||||
inference
|
inference
|
||||||
from modelscope.models.cv.tinynas_detection.damo.detectors.detector import \
|
from modelscope.models.cv.tinynas_detection.damo.detectors.detector import \
|
||||||
build_local_model
|
build_local_model
|
||||||
from modelscope.msdatasets.task_datasets.damoyolo import (build_dataloader,
|
from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import (
|
||||||
build_dataset)
|
build_dataloader, build_dataset)
|
||||||
|
|
||||||
|
|
||||||
def mkdir(path):
|
def mkdir(path):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from modelscope.msdatasets.task_datasets.damoyolo.evaluation import evaluate
|
from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import evaluate
|
||||||
from modelscope.utils.logger import get_logger
|
from modelscope.utils.logger import get_logger
|
||||||
from modelscope.utils.timer import Timer, get_time_str
|
from modelscope.utils.timer import Timer, get_time_str
|
||||||
from modelscope.utils.torch_utils import (all_gather, get_world_size,
|
from modelscope.utils.torch_utils import (all_gather, get_world_size,
|
||||||
|
|||||||
@@ -1,3 +1,2 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
from . import cv
|
|
||||||
from .ms_dataset import MsDataset
|
from .ms_dataset import MsDataset
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
||||||
from . import (image_classification, image_semantic_segmentation,
|
|
||||||
object_detection)
|
|
||||||
@@ -13,6 +13,7 @@ from modelscope.msdatasets.context.dataset_context_config import \
|
|||||||
DatasetContextConfig
|
DatasetContextConfig
|
||||||
from modelscope.msdatasets.data_files.data_files_manager import \
|
from modelscope.msdatasets.data_files.data_files_manager import \
|
||||||
DataFilesManager
|
DataFilesManager
|
||||||
|
from modelscope.msdatasets.dataset_cls.dataset import ExternalDataset
|
||||||
from modelscope.msdatasets.meta.data_meta_manager import DataMetaManager
|
from modelscope.msdatasets.meta.data_meta_manager import DataMetaManager
|
||||||
from modelscope.utils.constant import DatasetFormations
|
from modelscope.utils.constant import DatasetFormations
|
||||||
|
|
||||||
@@ -62,7 +63,8 @@ class OssDataLoader(BaseDataLoader):
|
|||||||
|
|
||||||
self.data_files_builder: Optional[DataFilesManager] = None
|
self.data_files_builder: Optional[DataFilesManager] = None
|
||||||
self.dataset: Optional[Union[Dataset, IterableDataset, DatasetDict,
|
self.dataset: Optional[Union[Dataset, IterableDataset, DatasetDict,
|
||||||
IterableDatasetDict]] = None
|
IterableDatasetDict,
|
||||||
|
ExternalDataset]] = None
|
||||||
self.builder: Optional[DatasetBuilder] = None
|
self.builder: Optional[DatasetBuilder] = None
|
||||||
self.data_files_manager: Optional[DataFilesManager] = None
|
self.data_files_manager: Optional[DataFilesManager] = None
|
||||||
|
|
||||||
@@ -141,7 +143,8 @@ class OssDataLoader(BaseDataLoader):
|
|||||||
self.builder)
|
self.builder)
|
||||||
|
|
||||||
def _post_process(self) -> None:
|
def _post_process(self) -> None:
|
||||||
...
|
if isinstance(self.dataset, ExternalDataset):
|
||||||
|
self.dataset.custom_map = self.dataset_context_config.data_meta_config.meta_type_map
|
||||||
|
|
||||||
|
|
||||||
class MaxComputeDataLoader(BaseDataLoader):
|
class MaxComputeDataLoader(BaseDataLoader):
|
||||||
|
|||||||
@@ -1 +1,3 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
|
from .dataset import ExternalDataset, NativeIterableDataset
|
||||||
|
|||||||
@@ -0,0 +1,84 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from modelscope.utils.import_utils import LazyImportModule
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .easycv_base import EasyCVBaseDataset
|
||||||
|
from .builder import CUSTOM_DATASETS, build_custom_dataset
|
||||||
|
from .torch_custom_dataset import TorchCustomDataset
|
||||||
|
from .movie_scene_segmentation.movie_scene_segmentation_dataset import MovieSceneSegmentationDataset
|
||||||
|
from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset
|
||||||
|
from .gopro_image_deblurring_dataset import GoproImageDeblurringDataset
|
||||||
|
from .language_guided_video_summarization_dataset import LanguageGuidedVideoSummarizationDataset
|
||||||
|
from .mgeo_ranking_dataset import MGeoRankingDataset
|
||||||
|
from .reds_image_deblurring_dataset import RedsImageDeblurringDataset
|
||||||
|
from .text_ranking_dataset import TextRankingDataset
|
||||||
|
from .veco_dataset import VecoDataset
|
||||||
|
from .video_summarization_dataset import VideoSummarizationDataset
|
||||||
|
from .audio import KWSDataset, KWSDataLoader, kws_nearfield_dataset
|
||||||
|
from .bad_image_detecting import BadImageDetectingDataset
|
||||||
|
from .image_inpainting import ImageInpaintingDataset
|
||||||
|
from .image_portrait_enhancement import ImagePortraitEnhancementDataset
|
||||||
|
from .image_quality_assessment_degradation import ImageQualityAssessmentDegradationDataset
|
||||||
|
from .image_quality_assmessment_mos import ImageQualityAssessmentMosDataset
|
||||||
|
from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset
|
||||||
|
from .sidd_image_denoising import SiddImageDenoisingDataset
|
||||||
|
from .video_frame_interpolation import VideoFrameInterpolationDataset
|
||||||
|
from .video_stabilization import VideoStabilizationDataset
|
||||||
|
from .video_super_resolution import VideoSuperResolutionDataset
|
||||||
|
from .image_semantic_segmentation import SegDataset
|
||||||
|
from .face_2d_keypoins import FaceKeypointDataset
|
||||||
|
from .hand_2d_keypoints import HandCocoWholeBodyDataset
|
||||||
|
from .human_wholebody_keypoint import WholeBodyCocoTopDownDataset
|
||||||
|
from .image_classification import ClsDataset
|
||||||
|
from .object_detection import DetDataset, DetImagesMixDataset
|
||||||
|
from .ocr_detection import DataLoader, ImageDataset, QuadMeasurer
|
||||||
|
from .ocr_recognition_dataset import OCRRecognitionDataset
|
||||||
|
else:
|
||||||
|
_import_structure = {
|
||||||
|
'easycv_base': ['EasyCVBaseDataset'],
|
||||||
|
'builder': ['CUSTOM_DATASETS', 'build_custom_dataset'],
|
||||||
|
'torch_custom_dataset': ['TorchCustomDataset'],
|
||||||
|
'movie_scene_segmentation_dataset': ['MovieSceneSegmentationDataset'],
|
||||||
|
'image_instance_segmentation_coco_dataset':
|
||||||
|
['ImageInstanceSegmentationCocoDataset'],
|
||||||
|
'gopro_image_deblurring_dataset': ['GoproImageDeblurringDataset'],
|
||||||
|
'language_guided_video_summarization_dataset':
|
||||||
|
['LanguageGuidedVideoSummarizationDataset'],
|
||||||
|
'mgeo_ranking_dataset': ['MGeoRankingDataset'],
|
||||||
|
'reds_image_deblurring_dataset': ['RedsImageDeblurringDataset'],
|
||||||
|
'text_ranking_dataset': ['TextRankingDataset'],
|
||||||
|
'veco_dataset': ['VecoDataset'],
|
||||||
|
'video_summarization_dataset': ['VideoSummarizationDataset'],
|
||||||
|
'audio': ['KWSDataset', 'KWSDataLoader', 'kws_nearfield_dataset'],
|
||||||
|
'bad_image_detecting': ['BadImageDetectingDataset'],
|
||||||
|
'image_inpainting': ['ImageInpaintingDataset'],
|
||||||
|
'image_portrait_enhancement': ['ImagePortraitEnhancementDataset'],
|
||||||
|
'image_quality_assessment_degradation':
|
||||||
|
['ImageQualityAssessmentDegradationDataset'],
|
||||||
|
'image_quality_assmessment_mos': ['ImageQualityAssessmentMosDataset'],
|
||||||
|
'referring_video_object_segmentation':
|
||||||
|
['ReferringVideoObjectSegmentationDataset'],
|
||||||
|
'sidd_image_denoising': ['SiddImageDenoisingDataset'],
|
||||||
|
'video_frame_interpolation': ['VideoFrameInterpolationDataset'],
|
||||||
|
'video_stabilization': ['VideoStabilizationDataset'],
|
||||||
|
'video_super_resolution': ['VideoSuperResolutionDataset'],
|
||||||
|
'image_semantic_segmentation': ['SegDataset'],
|
||||||
|
'face_2d_keypoins': ['FaceKeypointDataset'],
|
||||||
|
'hand_2d_keypoints': ['HandCocoWholeBodyDataset'],
|
||||||
|
'human_wholebody_keypoint': ['WholeBodyCocoTopDownDataset'],
|
||||||
|
'image_classification': ['ClsDataset'],
|
||||||
|
'object_detection': ['DetDataset', 'DetImagesMixDataset'],
|
||||||
|
'ocr_detection': ['DataLoader', 'ImageDataset', 'QuadMeasurer'],
|
||||||
|
'ocr_recognition_dataset': ['OCRRecognitionDataset'],
|
||||||
|
}
|
||||||
|
|
||||||
|
import sys
|
||||||
|
sys.modules[__name__] = LazyImportModule(
|
||||||
|
__name__,
|
||||||
|
globals()['__file__'],
|
||||||
|
_import_structure,
|
||||||
|
module_spec=__spec__,
|
||||||
|
extra_objects={},
|
||||||
|
)
|
||||||
@@ -5,7 +5,6 @@ import math
|
|||||||
import os.path
|
import os.path
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -18,7 +18,7 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.utils.data import IterableDataset
|
from torch.utils.data import IterableDataset
|
||||||
|
|
||||||
import modelscope.msdatasets.task_datasets.audio.kws_nearfield_processor as processor
|
import modelscope.msdatasets.dataset_cls.custom_datasets.audio.kws_nearfield_processor as processor
|
||||||
from modelscope.trainers.audio.kws_utils.file_utils import (make_pair,
|
from modelscope.trainers.audio.kws_utils.file_utils import (make_pair,
|
||||||
read_lists)
|
read_lists)
|
||||||
from modelscope.utils.logger import get_logger
|
from modelscope.utils.logger import get_logger
|
||||||
@@ -1,12 +1,8 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
TorchTaskDataset
|
|
||||||
from modelscope.outputs import OutputKeys
|
from modelscope.outputs import OutputKeys
|
||||||
from modelscope.preprocessors import LoadImage
|
from modelscope.preprocessors import LoadImage
|
||||||
from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \
|
from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \
|
||||||
@@ -14,9 +10,9 @@ from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \
|
|||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.bad_image_detecting, module_name=Models.bad_image_detecting)
|
Tasks.bad_image_detecting, module_name=Models.bad_image_detecting)
|
||||||
class BadImageDetectingDataset(TorchTaskDataset):
|
class BadImageDetectingDataset(TorchCustomDataset):
|
||||||
"""Paired image dataset for bad image detecting.
|
"""Paired image dataset for bad image detecting.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -3,13 +3,13 @@
|
|||||||
from modelscope.utils.config import ConfigDict
|
from modelscope.utils.config import ConfigDict
|
||||||
from modelscope.utils.registry import Registry, build_from_cfg
|
from modelscope.utils.registry import Registry, build_from_cfg
|
||||||
|
|
||||||
TASK_DATASETS = Registry('task_datasets')
|
CUSTOM_DATASETS = Registry('custom_datasets')
|
||||||
|
|
||||||
|
|
||||||
def build_task_dataset(cfg: ConfigDict,
|
def build_custom_dataset(cfg: ConfigDict,
|
||||||
task_name: str = None,
|
task_name: str,
|
||||||
default_args: dict = None):
|
default_args: dict = None):
|
||||||
""" Build task specific dataset processor given model config dict and the task name.
|
""" Build custom dataset for user-define dataset given model config and task name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (:obj:`ConfigDict`): config dict for model object.
|
cfg (:obj:`ConfigDict`): config dict for model object.
|
||||||
@@ -18,4 +18,4 @@ def build_task_dataset(cfg: ConfigDict,
|
|||||||
default_args (dict, optional): Default initialization arguments.
|
default_args (dict, optional): Default initialization arguments.
|
||||||
"""
|
"""
|
||||||
return build_from_cfg(
|
return build_from_cfg(
|
||||||
cfg, TASK_DATASETS, group_key=task_name, default_args=default_args)
|
cfg, CUSTOM_DATASETS, group_key=task_name, default_args=default_args)
|
||||||
@@ -1,2 +1,3 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
from .build import build_dataloader, build_dataset
|
from .build import build_dataloader, build_dataset
|
||||||
|
from .evaluation import evaluate
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
# Copyright © Alibaba, Inc. and its affiliates.
|
# Copyright © Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
from modelscope.msdatasets.task_datasets.damoyolo import datasets
|
from .. import datasets
|
||||||
from .coco import coco_evaluation
|
from .coco import coco_evaluation
|
||||||
|
|
||||||
|
|
||||||
@@ -1,15 +1,16 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
from easycv.datasets.face import FaceKeypointDataset as _FaceKeypointDataset
|
from easycv.datasets.face import FaceKeypointDataset as _FaceKeypointDataset
|
||||||
|
|
||||||
from modelscope.metainfo import Datasets
|
from modelscope.metainfo import CustomDatasets
|
||||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
|
||||||
|
EasyCVBaseDataset
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
group_key=Tasks.face_2d_keypoints,
|
group_key=Tasks.face_2d_keypoints,
|
||||||
module_name=Datasets.Face2dKeypointsDataset)
|
module_name=CustomDatasets.Face2dKeypointsDataset)
|
||||||
class FaceKeypointDataset(EasyCVBaseDataset, _FaceKeypointDataset):
|
class FaceKeypointDataset(EasyCVBaseDataset, _FaceKeypointDataset):
|
||||||
"""EasyCV dataset for face 2d keypoints.
|
"""EasyCV dataset for face 2d keypoints.
|
||||||
|
|
||||||
@@ -3,14 +3,13 @@
|
|||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from modelscope.metainfo import Datasets
|
from modelscope.metainfo import CustomDatasets
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
from modelscope.msdatasets.task_datasets.sidd_image_denoising.data_utils import (
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
|
from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.data_utils import (
|
||||||
img2tensor, padding)
|
img2tensor, padding)
|
||||||
from modelscope.msdatasets.task_datasets.sidd_image_denoising.transforms import (
|
from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.transforms import (
|
||||||
augment, paired_random_crop)
|
augment, paired_random_crop)
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
|
||||||
TorchTaskDataset
|
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@@ -18,9 +17,9 @@ def default_loader(path):
|
|||||||
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
|
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.image_deblurring, module_name=Datasets.PairedDataset)
|
Tasks.image_deblurring, module_name=CustomDatasets.PairedDataset)
|
||||||
class GoproImageDeblurringDataset(TorchTaskDataset):
|
class GoproImageDeblurringDataset(TorchCustomDataset):
|
||||||
"""Paired image dataset for image restoration.
|
"""Paired image dataset for image restoration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from modelscope.utils.import_utils import LazyImportModule
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .hand_2d_keypoints_dataset import HandCocoWholeBodyDataset
|
||||||
|
|
||||||
|
else:
|
||||||
|
_import_structure = {
|
||||||
|
'hand_2d_keypoints_dataset': ['HandCocoWholeBodyDataset']
|
||||||
|
}
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = LazyImportModule(
|
||||||
|
__name__,
|
||||||
|
globals()['__file__'],
|
||||||
|
_import_structure,
|
||||||
|
module_spec=__spec__,
|
||||||
|
extra_objects={},
|
||||||
|
)
|
||||||
@@ -2,15 +2,16 @@
|
|||||||
from easycv.datasets.pose import \
|
from easycv.datasets.pose import \
|
||||||
HandCocoWholeBodyDataset as _HandCocoWholeBodyDataset
|
HandCocoWholeBodyDataset as _HandCocoWholeBodyDataset
|
||||||
|
|
||||||
from modelscope.metainfo import Datasets
|
from modelscope.metainfo import CustomDatasets
|
||||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
|
||||||
|
EasyCVBaseDataset
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
group_key=Tasks.hand_2d_keypoints,
|
group_key=Tasks.hand_2d_keypoints,
|
||||||
module_name=Datasets.HandCocoWholeBodyDataset)
|
module_name=CustomDatasets.HandCocoWholeBodyDataset)
|
||||||
class HandCocoWholeBodyDataset(EasyCVBaseDataset, _HandCocoWholeBodyDataset):
|
class HandCocoWholeBodyDataset(EasyCVBaseDataset, _HandCocoWholeBodyDataset):
|
||||||
"""EasyCV dataset for human hand 2d keypoints.
|
"""EasyCV dataset for human hand 2d keypoints.
|
||||||
|
|
||||||
@@ -2,15 +2,16 @@
|
|||||||
from easycv.datasets.pose import \
|
from easycv.datasets.pose import \
|
||||||
WholeBodyCocoTopDownDataset as _WholeBodyCocoTopDownDataset
|
WholeBodyCocoTopDownDataset as _WholeBodyCocoTopDownDataset
|
||||||
|
|
||||||
from modelscope.metainfo import Datasets
|
from modelscope.metainfo import CustomDatasets
|
||||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
|
||||||
|
EasyCVBaseDataset
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
group_key=Tasks.human_wholebody_keypoint,
|
group_key=Tasks.human_wholebody_keypoint,
|
||||||
module_name=Datasets.HumanWholeBodyKeypointDataset)
|
module_name=CustomDatasets.HumanWholeBodyKeypointDataset)
|
||||||
class WholeBodyCocoTopDownDataset(EasyCVBaseDataset,
|
class WholeBodyCocoTopDownDataset(EasyCVBaseDataset,
|
||||||
_WholeBodyCocoTopDownDataset):
|
_WholeBodyCocoTopDownDataset):
|
||||||
"""EasyCV dataset for human whole body 2d keypoints.
|
"""EasyCV dataset for human whole body 2d keypoints.
|
||||||
@@ -1,14 +1,16 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
from easycv.datasets.classification import ClsDataset as _ClsDataset
|
from easycv.datasets.classification import ClsDataset as _ClsDataset
|
||||||
|
|
||||||
from modelscope.metainfo import Datasets
|
from modelscope.metainfo import CustomDatasets
|
||||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
|
||||||
|
EasyCVBaseDataset
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
group_key=Tasks.image_classification, module_name=Datasets.ClsDataset)
|
group_key=Tasks.image_classification,
|
||||||
|
module_name=CustomDatasets.ClsDataset)
|
||||||
class ClsDataset(_ClsDataset):
|
class ClsDataset(_ClsDataset):
|
||||||
"""EasyCV dataset for classification.
|
"""EasyCV dataset for classification.
|
||||||
|
|
||||||
@@ -4,13 +4,11 @@ from typing import TYPE_CHECKING
|
|||||||
from modelscope.utils.import_utils import LazyImportModule
|
from modelscope.utils.import_utils import LazyImportModule
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .hand_2d_keypoints_dataset import Hand2DKeypointDataset
|
from .image_inpainting_dataset import ImageInpaintingDataset
|
||||||
|
|
||||||
else:
|
else:
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
'hand_2d_keypoints_dataset': ['Hand2DKeypointDataset']
|
'image_inpainting_dataset': ['ImageInpaintingDataset'],
|
||||||
}
|
}
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.modules[__name__] = LazyImportModule(
|
sys.modules[__name__] = LazyImportModule(
|
||||||
@@ -3,20 +3,16 @@ Part of the implementation is borrowed and modified from LaMa,
|
|||||||
publicly available at https://github.com/saic-mdal/lama
|
publicly available at https://github.com/saic-mdal/lama
|
||||||
"""
|
"""
|
||||||
import glob
|
import glob
|
||||||
import os
|
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import cv2
|
import cv2
|
||||||
import json
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
TorchTaskDataset
|
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
from modelscope.utils.logger import get_logger
|
from modelscope.utils.logger import get_logger
|
||||||
from .aug import IAAAffine2, IAAPerspective2
|
from .aug import IAAAffine2, IAAPerspective2
|
||||||
@@ -296,9 +292,9 @@ def get_transforms(test_mode, out_size):
|
|||||||
return transform
|
return transform
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.image_inpainting, module_name=Models.image_inpainting)
|
Tasks.image_inpainting, module_name=Models.image_inpainting)
|
||||||
class ImageInpaintingDataset(TorchTaskDataset):
|
class ImageInpaintingDataset(TorchCustomDataset):
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
split_config = kwargs['split_config']
|
split_config = kwargs['split_config']
|
||||||
@@ -6,9 +6,9 @@ import numpy as np
|
|||||||
from pycocotools.coco import COCO
|
from pycocotools.coco import COCO
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
from .builder import TASK_DATASETS
|
|
||||||
from .torch_base_dataset import TorchTaskDataset
|
|
||||||
|
|
||||||
DATASET_STRUCTURE = {
|
DATASET_STRUCTURE = {
|
||||||
'train': {
|
'train': {
|
||||||
@@ -22,10 +22,10 @@ DATASET_STRUCTURE = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
module_name=Models.cascade_mask_rcnn_swin,
|
module_name=Models.cascade_mask_rcnn_swin,
|
||||||
group_key=Tasks.image_segmentation)
|
group_key=Tasks.image_segmentation)
|
||||||
class ImageInstanceSegmentationCocoDataset(TorchTaskDataset):
|
class ImageInstanceSegmentationCocoDataset(TorchCustomDataset):
|
||||||
"""Coco-style dataset for image instance segmentation.
|
"""Coco-style dataset for image instance segmentation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -3,10 +3,9 @@
|
|||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from modelscope.metainfo import Datasets, Models
|
from modelscope.metainfo import CustomDatasets
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
TorchTaskDataset
|
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
from .data_utils import img2tensor
|
from .data_utils import img2tensor
|
||||||
|
|
||||||
@@ -15,9 +14,9 @@ def default_loader(path):
|
|||||||
return cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0
|
return cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.image_portrait_enhancement, module_name=Datasets.PairedDataset)
|
Tasks.image_portrait_enhancement, module_name=CustomDatasets.PairedDataset)
|
||||||
class ImagePortraitEnhancementDataset(TorchTaskDataset):
|
class ImagePortraitEnhancementDataset(TorchCustomDataset):
|
||||||
"""Paired image dataset for image portrait enhancement.
|
"""Paired image dataset for image portrait enhancement.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1,21 +1,18 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
TorchTaskDataset
|
|
||||||
from modelscope.preprocessors import LoadImage
|
from modelscope.preprocessors import LoadImage
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.image_quality_assessment_degradation,
|
Tasks.image_quality_assessment_degradation,
|
||||||
module_name=Models.image_quality_assessment_degradation)
|
module_name=Models.image_quality_assessment_degradation)
|
||||||
class ImageQualityAssessmentDegradationDataset(TorchTaskDataset):
|
class ImageQualityAssessmentDegradationDataset(TorchCustomDataset):
|
||||||
"""Paired image dataset for image quality assessment degradation.
|
"""Paired image dataset for image quality assessment degradation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1,20 +1,16 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
TorchTaskDataset
|
|
||||||
from modelscope.preprocessors.cv import ImageQualityAssessmentMosPreprocessor
|
from modelscope.preprocessors.cv import ImageQualityAssessmentMosPreprocessor
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.image_quality_assessment_mos,
|
Tasks.image_quality_assessment_mos,
|
||||||
module_name=Models.image_quality_assessment_mos)
|
module_name=Models.image_quality_assessment_mos)
|
||||||
class ImageQualityAssessmentMosDataset(TorchTaskDataset):
|
class ImageQualityAssessmentMosDataset(TorchCustomDataset):
|
||||||
"""Paired image dataset for image quality assessment mos.
|
"""Paired image dataset for image quality assessment mos.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1,14 +1,15 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
from easycv.datasets.segmentation import SegDataset as _SegDataset
|
from easycv.datasets.segmentation import SegDataset as _SegDataset
|
||||||
|
|
||||||
from modelscope.metainfo import Datasets
|
from modelscope.metainfo import CustomDatasets
|
||||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
|
||||||
|
EasyCVBaseDataset
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
group_key=Tasks.image_segmentation, module_name=Datasets.SegDataset)
|
group_key=Tasks.image_segmentation, module_name=CustomDatasets.SegDataset)
|
||||||
class SegDataset(EasyCVBaseDataset, _SegDataset):
|
class SegDataset(EasyCVBaseDataset, _SegDataset):
|
||||||
"""EasyCV dataset for Sementic segmentation.
|
"""EasyCV dataset for Sementic segmentation.
|
||||||
For more details, please refer to :
|
For more details, please refer to :
|
||||||
@@ -25,16 +25,15 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
TorchTaskDataset
|
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.language_guided_video_summarization,
|
Tasks.language_guided_video_summarization,
|
||||||
module_name=Models.language_guided_video_summarization)
|
module_name=Models.language_guided_video_summarization)
|
||||||
class LanguageGuidedVideoSummarizationDataset(TorchTaskDataset):
|
class LanguageGuidedVideoSummarizationDataset(TorchCustomDataset):
|
||||||
|
|
||||||
def __init__(self, mode, opt, root_dir):
|
def __init__(self, mode, opt, root_dir):
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
@@ -1,24 +1,20 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from typing import Any, List, Union
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset, concatenate_datasets
|
|
||||||
from torch.utils.data import ConcatDataset
|
from torch.utils.data import ConcatDataset
|
||||||
from transformers import DataCollatorWithPadding
|
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
from modelscope.utils.constant import ModeKeys, Tasks
|
from modelscope.utils.constant import ModeKeys, Tasks
|
||||||
from .base import TaskDataset
|
|
||||||
from .builder import TASK_DATASETS
|
|
||||||
from .torch_base_dataset import TorchTaskDataset
|
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
group_key=Tasks.text_ranking, module_name=Models.mgeo)
|
group_key=Tasks.text_ranking, module_name=Models.mgeo)
|
||||||
class MGeoRankingDataset(TorchTaskDataset):
|
class MGeoRankingDataset(TorchCustomDataset):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
datasets: Union[Any, List[Any]],
|
datasets: Union[Any, List[Any]],
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from modelscope.utils.import_utils import LazyImportModule
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .movie_scene_segmentation_dataset import MovieSceneSegmentationDataset
|
||||||
|
else:
|
||||||
|
_import_structure = {
|
||||||
|
'movie_scene_segmentation_dataset': ['MovieSceneSegmentationDataset'],
|
||||||
|
}
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = LazyImportModule(
|
||||||
|
__name__,
|
||||||
|
globals()['__file__'],
|
||||||
|
_import_structure,
|
||||||
|
module_spec=__spec__,
|
||||||
|
extra_objects={},
|
||||||
|
)
|
||||||
@@ -10,9 +10,8 @@ import torch
|
|||||||
from torchvision.datasets.folder import pil_loader
|
from torchvision.datasets.folder import pil_loader
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
CUSTOM_DATASETS
|
||||||
TorchTaskDataset
|
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
from . import sampler
|
from . import sampler
|
||||||
|
|
||||||
@@ -30,9 +29,9 @@ DATASET_STRUCTURE = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert)
|
group_key=Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert)
|
||||||
class MovieSceneSegmentationDataset(TorchTaskDataset):
|
class MovieSceneSegmentationDataset(torch.utils.data.Dataset):
|
||||||
"""dataset for movie scene segmentation.
|
"""dataset for movie scene segmentation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1,20 +1,21 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
import os.path as osp
|
|
||||||
|
|
||||||
from easycv.datasets.detection import DetDataset as _DetDataset
|
from easycv.datasets.detection import DetDataset as _DetDataset
|
||||||
from easycv.datasets.detection import \
|
from easycv.datasets.detection import \
|
||||||
DetImagesMixDataset as _DetImagesMixDataset
|
DetImagesMixDataset as _DetImagesMixDataset
|
||||||
|
|
||||||
from modelscope.metainfo import Datasets
|
from modelscope.metainfo import CustomDatasets
|
||||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
|
||||||
from modelscope.msdatasets.task_datasets import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
|
||||||
|
EasyCVBaseDataset
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
group_key=Tasks.image_object_detection, module_name=Datasets.DetDataset)
|
group_key=Tasks.image_object_detection,
|
||||||
@TASK_DATASETS.register_module(
|
module_name=CustomDatasets.DetDataset)
|
||||||
group_key=Tasks.image_segmentation, module_name=Datasets.DetDataset)
|
@CUSTOM_DATASETS.register_module(
|
||||||
|
group_key=Tasks.image_segmentation, module_name=CustomDatasets.DetDataset)
|
||||||
class DetDataset(EasyCVBaseDataset, _DetDataset):
|
class DetDataset(EasyCVBaseDataset, _DetDataset):
|
||||||
"""EasyCV dataset for object detection.
|
"""EasyCV dataset for object detection.
|
||||||
For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/raw.py .
|
For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/raw.py .
|
||||||
@@ -47,12 +48,12 @@ class DetDataset(EasyCVBaseDataset, _DetDataset):
|
|||||||
_DetDataset.__init__(self, *args, **kwargs)
|
_DetDataset.__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
group_key=Tasks.image_object_detection,
|
group_key=Tasks.image_object_detection,
|
||||||
module_name=Datasets.DetImagesMixDataset)
|
module_name=CustomDatasets.DetImagesMixDataset)
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
group_key=Tasks.domain_specific_object_detection,
|
group_key=Tasks.domain_specific_object_detection,
|
||||||
module_name=Datasets.DetImagesMixDataset)
|
module_name=CustomDatasets.DetImagesMixDataset)
|
||||||
class DetImagesMixDataset(EasyCVBaseDataset, _DetImagesMixDataset):
|
class DetImagesMixDataset(EasyCVBaseDataset, _DetImagesMixDataset):
|
||||||
"""EasyCV dataset for object detection, a wrapper of multiple images mixed dataset.
|
"""EasyCV dataset for object detection, a wrapper of multiple images mixed dataset.
|
||||||
Suitable for training on multiple images mixed data augmentation like
|
Suitable for training on multiple images mixed data augmentation like
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
from .data_loader import DataLoader
|
from .data_loader import DataLoader
|
||||||
from .image_dataset import ImageDataset
|
from .image_dataset import ImageDataset
|
||||||
|
from .measures import QuadMeasurer
|
||||||
@@ -9,9 +9,10 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
CUSTOM_DATASETS
|
||||||
TorchTaskDataset
|
from modelscope.msdatasets.dataset_cls.custom_datasets.torch_custom_dataset import \
|
||||||
|
TorchCustomDataset
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
from modelscope.utils.logger import get_logger
|
from modelscope.utils.logger import get_logger
|
||||||
|
|
||||||
@@ -29,9 +30,9 @@ def Q2B(uchar):
|
|||||||
return chr(inside_code)
|
return chr(inside_code)
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.ocr_recognition, module_name=Models.ocr_recognition)
|
Tasks.ocr_recognition, module_name=Models.ocr_recognition)
|
||||||
class OCRRecognitionDataset(TorchTaskDataset):
|
class OCRRecognitionDataset(TorchCustomDataset):
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
split_config = kwargs['split_config']
|
split_config = kwargs['split_config']
|
||||||
@@ -3,14 +3,13 @@
|
|||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from modelscope.metainfo import Datasets
|
from modelscope.metainfo import CustomDatasets
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
from modelscope.msdatasets.task_datasets.sidd_image_denoising.data_utils import (
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
|
from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.data_utils import (
|
||||||
img2tensor, padding)
|
img2tensor, padding)
|
||||||
from modelscope.msdatasets.task_datasets.sidd_image_denoising.transforms import (
|
from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.transforms import (
|
||||||
augment, paired_random_crop)
|
augment, paired_random_crop)
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
|
||||||
TorchTaskDataset
|
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@@ -18,9 +17,9 @@ def default_loader(path):
|
|||||||
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
|
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.image_deblurring, module_name=Datasets.PairedDataset)
|
Tasks.image_deblurring, module_name=CustomDatasets.PairedDataset)
|
||||||
class RedsImageDeblurringDataset(TorchTaskDataset):
|
class RedsImageDeblurringDataset(TorchCustomDataset):
|
||||||
"""Paired image dataset for image restoration.
|
"""Paired image dataset for image restoration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from modelscope.utils.import_utils import LazyImportModule
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .referring_video_object_segmentation_dataset import ReferringVideoObjectSegmentationDataset
|
||||||
|
else:
|
||||||
|
_import_structure = {
|
||||||
|
'referring_video_object_segmentation_dataset':
|
||||||
|
['MovieSceneSegmentationDataset'],
|
||||||
|
}
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = LazyImportModule(
|
||||||
|
__name__,
|
||||||
|
globals()['__file__'],
|
||||||
|
_import_structure,
|
||||||
|
module_spec=__spec__,
|
||||||
|
extra_objects={},
|
||||||
|
)
|
||||||
@@ -18,9 +18,8 @@ from tqdm import tqdm
|
|||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
from modelscope.models.cv.referring_video_object_segmentation.utils import \
|
from modelscope.models.cv.referring_video_object_segmentation.utils import \
|
||||||
nested_tensor_from_videos_list
|
nested_tensor_from_videos_list
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
TorchTaskDataset
|
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
from modelscope.utils.logger import get_logger
|
from modelscope.utils.logger import get_logger
|
||||||
from . import transformers as T
|
from . import transformers as T
|
||||||
@@ -33,10 +32,10 @@ def get_image_id(video_id, frame_idx, ref_instance_a2d_id):
|
|||||||
return image_id
|
return image_id
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.referring_video_object_segmentation,
|
Tasks.referring_video_object_segmentation,
|
||||||
module_name=Models.referring_video_object_segmentation)
|
module_name=Models.referring_video_object_segmentation)
|
||||||
class ReferringVideoObjectSegmentationDataset(TorchTaskDataset):
|
class ReferringVideoObjectSegmentationDataset(TorchCustomDataset):
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
split_config = kwargs['split_config']
|
split_config = kwargs['split_config']
|
||||||
@@ -4,9 +4,8 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
TorchTaskDataset
|
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
from .data_utils import img2tensor, padding
|
from .data_utils import img2tensor, padding
|
||||||
from .transforms import augment, paired_random_crop
|
from .transforms import augment, paired_random_crop
|
||||||
@@ -16,9 +15,9 @@ def default_loader(path):
|
|||||||
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
|
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.image_denoising, module_name=Models.nafnet)
|
Tasks.image_denoising, module_name=Models.nafnet)
|
||||||
class SiddImageDenoisingDataset(TorchTaskDataset):
|
class SiddImageDenoisingDataset(TorchCustomDataset):
|
||||||
"""Paired image dataset for image restoration.
|
"""Paired image dataset for image restoration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1,25 +1,21 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from typing import Any, List, Union
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset, concatenate_datasets
|
|
||||||
from torch.utils.data import ConcatDataset
|
from torch.utils.data import ConcatDataset
|
||||||
from transformers import DataCollatorWithPadding
|
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
from modelscope.utils.constant import ModeKeys, Tasks
|
from modelscope.utils.constant import ModeKeys, Tasks
|
||||||
from .base import TaskDataset
|
|
||||||
from .builder import TASK_DATASETS
|
|
||||||
from .torch_base_dataset import TorchTaskDataset
|
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
group_key=Tasks.text_ranking, module_name=Models.bert)
|
group_key=Tasks.text_ranking, module_name=Models.bert)
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
group_key=Tasks.sentence_embedding, module_name=Models.bert)
|
group_key=Tasks.sentence_embedding, module_name=Models.bert)
|
||||||
class TextRankingDataset(TorchTaskDataset):
|
class TextRankingDataset(TorchCustomDataset):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
datasets: Union[Any, List[Any]],
|
datasets: Union[Any, List[Any]],
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
from typing import Any, List, Union
|
||||||
|
|
||||||
|
import torch.utils.data
|
||||||
|
from torch.utils.data import ConcatDataset as TorchConcatDataset
|
||||||
|
|
||||||
|
from modelscope.utils.constant import ModeKeys
|
||||||
|
|
||||||
|
|
||||||
|
class TorchCustomDataset(torch.utils.data.Dataset):
|
||||||
|
"""The custom dataset base class for all the torch-based task processors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
datasets: Union[Any, List[Any]],
|
||||||
|
mode=ModeKeys.TRAIN,
|
||||||
|
preprocessor=None,
|
||||||
|
**kwargs):
|
||||||
|
self.trainer = None
|
||||||
|
self.mode = mode
|
||||||
|
self.preprocessor = preprocessor
|
||||||
|
self._inner_dataset = self.prepare_dataset(datasets)
|
||||||
|
|
||||||
|
def __getitem__(self, index) -> Any:
|
||||||
|
return self.preprocessor(
|
||||||
|
self._inner_dataset[index]
|
||||||
|
) if self.preprocessor else self._inner_dataset[index]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._inner_dataset)
|
||||||
|
|
||||||
|
def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any:
|
||||||
|
"""Prepare a dataset.
|
||||||
|
|
||||||
|
User can process the input datasets in a whole dataset perspective.
|
||||||
|
This method gives a default implementation of datasets merging, user can override this
|
||||||
|
method to write custom logics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
datasets: The original dataset(s)
|
||||||
|
|
||||||
|
Returns: A single dataset, which may be created after merging.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if isinstance(datasets, List):
|
||||||
|
if len(datasets) == 1:
|
||||||
|
return datasets[0]
|
||||||
|
elif len(datasets) > 1:
|
||||||
|
return TorchConcatDataset(datasets)
|
||||||
|
else:
|
||||||
|
return datasets
|
||||||
@@ -5,13 +5,13 @@ import numpy as np
|
|||||||
from datasets import Dataset, IterableDataset, concatenate_datasets
|
from datasets import Dataset, IterableDataset, concatenate_datasets
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
from .builder import TASK_DATASETS
|
|
||||||
from .torch_base_dataset import TorchTaskDataset
|
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(module_name=Models.veco, group_key=Tasks.nli)
|
@CUSTOM_DATASETS.register_module(module_name=Models.veco, group_key=Tasks.nli)
|
||||||
class VecoDataset(TorchTaskDataset):
|
class VecoDataset(TorchCustomDataset):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
datasets: Union[Any, List[Any]],
|
datasets: Union[Any, List[Any]],
|
||||||
@@ -1,16 +1,13 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
TorchTaskDataset
|
from modelscope.msdatasets.dataset_cls.custom_datasets.video_frame_interpolation.data_utils import (
|
||||||
from modelscope.msdatasets.task_datasets.video_frame_interpolation.data_utils import (
|
|
||||||
img2tensor, img_padding)
|
img2tensor, img_padding)
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
@@ -19,10 +16,10 @@ def default_loader(path):
|
|||||||
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32)
|
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.video_frame_interpolation,
|
Tasks.video_frame_interpolation,
|
||||||
module_name=Models.video_frame_interpolation)
|
module_name=Models.video_frame_interpolation)
|
||||||
class VideoFrameInterpolationDataset(TorchTaskDataset):
|
class VideoFrameInterpolationDataset(TorchCustomDataset):
|
||||||
"""Dataset for video frame-interpolation.
|
"""Dataset for video frame-interpolation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1,15 +1,14 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
from modelscope.metainfo import Models
|
||||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
from modelscope.msdatasets.dataset_cls.custom_datasets import (
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
CUSTOM_DATASETS, TorchCustomDataset)
|
||||||
TorchTaskDataset
|
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
@TASK_DATASETS.register_module(
|
@CUSTOM_DATASETS.register_module(
|
||||||
Tasks.video_stabilization, module_name=Models.video_stabilization)
|
Tasks.video_stabilization, module_name=Models.video_stabilization)
|
||||||
class VideoStabilizationDataset(TorchTaskDataset):
|
class VideoStabilizationDataset(TorchCustomDataset):
|
||||||
"""Paired video dataset for video stabilization.
|
"""Paired video dataset for video stabilization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -8,11 +8,11 @@ import json
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
from modelscope.msdatasets.dataset_cls.custom_datasets import \
|
||||||
TorchTaskDataset
|
TorchCustomDataset
|
||||||
|
|
||||||
|
|
||||||
class VideoSummarizationDataset(TorchTaskDataset):
|
class VideoSummarizationDataset(TorchCustomDataset):
|
||||||
|
|
||||||
def __init__(self, mode, opt, root_dir):
|
def __init__(self, mode, opt, root_dir):
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user