mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
update EasyCV MsDataset
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10103248 * update EasyCV MSDataset
This commit is contained in:
committed by
wenmeng.zwm
parent
0b27c77a54
commit
07728b164e
@@ -1,26 +1,9 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os.path as osp
|
||||
|
||||
|
||||
class EasyCVBaseDataset(object):
|
||||
"""Adapt to MSDataset.
|
||||
Subclasses need to implement ``DATA_STRUCTURE``, the format is as follows, e.g.:
|
||||
|
||||
{
|
||||
'${data source name}': {
|
||||
'train':{
|
||||
'${image root arg}': 'images', # directory name of images relative to the root path
|
||||
'${label root arg}': 'labels', # directory name of lables relative to the root path
|
||||
...
|
||||
},
|
||||
'validation': {
|
||||
'${image root arg}': 'images',
|
||||
'${label root arg}': 'labels',
|
||||
...
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Args:
|
||||
split_config (dict): Dataset root path from MSDataset, e.g.
|
||||
@@ -29,7 +12,7 @@ class EasyCVBaseDataset(object):
|
||||
the model if supplied. Not support yet.
|
||||
mode: Training or Evaluation.
|
||||
"""
|
||||
DATA_STRUCTURE = None
|
||||
DATA_ROOT_PATTERN = '${data_root}'
|
||||
|
||||
def __init__(self,
|
||||
split_config=None,
|
||||
@@ -45,15 +28,9 @@ class EasyCVBaseDataset(object):
|
||||
|
||||
def _update_data_source(self, data_source):
|
||||
data_root = next(iter(self.split_config.values()))
|
||||
split = next(iter(self.split_config.keys()))
|
||||
data_root = data_root.rstrip(osp.sep)
|
||||
|
||||
# TODO: msdataset should support these keys to be configured in the dataset's json file and passed in
|
||||
if data_source['type'] not in list(self.DATA_STRUCTURE.keys()):
|
||||
raise ValueError(
|
||||
'Only support %s now, but get %s.' %
|
||||
(list(self.DATA_STRUCTURE.keys()), data_source['type']))
|
||||
|
||||
# join data root path of msdataset and default relative name
|
||||
update_args = self.DATA_STRUCTURE[data_source['type']][split]
|
||||
for k, v in update_args.items():
|
||||
data_source.update({k: osp.join(data_root, v)})
|
||||
for k, v in data_source.items():
|
||||
if isinstance(v, str) and self.DATA_ROOT_PATTERN in v:
|
||||
data_source.update(
|
||||
{k: v.replace(self.DATA_ROOT_PATTERN, data_root)})
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from easycv.datasets.face import FaceKeypointDataset as _FaceKeypointDataset
|
||||
|
||||
from modelscope.metainfo import Datasets
|
||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
@@ -9,5 +10,28 @@ from modelscope.utils.constant import Tasks
|
||||
@TASK_DATASETS.register_module(
|
||||
group_key=Tasks.face_2d_keypoints,
|
||||
module_name=Datasets.Face2dKeypointsDataset)
|
||||
class FaceKeypointDataset(_FaceKeypointDataset):
|
||||
"""EasyCV dataset for face 2d keypoints."""
|
||||
class FaceKeypointDataset(EasyCVBaseDataset, _FaceKeypointDataset):
|
||||
"""EasyCV dataset for face 2d keypoints.
|
||||
|
||||
Args:
|
||||
split_config (dict): Dataset root path from MSDataset, e.g.
|
||||
{"train":"local cache path"} or {"evaluation":"local cache path"}.
|
||||
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
|
||||
the model if supplied. Not support yet.
|
||||
mode: Training or Evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
split_config=None,
|
||||
preprocessor=None,
|
||||
mode=None,
|
||||
*args,
|
||||
**kwargs) -> None:
|
||||
EasyCVBaseDataset.__init__(
|
||||
self,
|
||||
split_config=split_config,
|
||||
preprocessor=preprocessor,
|
||||
mode=mode,
|
||||
args=args,
|
||||
kwargs=kwargs)
|
||||
_FaceKeypointDataset.__init__(self, *args, **kwargs)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from easycv.datasets.classification import ClsDataset as _ClsDataset
|
||||
|
||||
from modelscope.metainfo import Datasets
|
||||
from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
@@ -10,10 +11,26 @@ from modelscope.utils.constant import Tasks
|
||||
group_key=Tasks.image_classification, module_name=Datasets.ClsDataset)
|
||||
class ClsDataset(_ClsDataset):
|
||||
"""EasyCV dataset for classification.
|
||||
For more details, please refer to :
|
||||
https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/classification/raw.py .
|
||||
|
||||
Args:
|
||||
data_source: Data source config to parse input data.
|
||||
pipeline: Sequence of transform object or config dict to be composed.
|
||||
split_config (dict): Dataset root path from MSDataset, e.g.
|
||||
{"train":"local cache path"} or {"evaluation":"local cache path"}.
|
||||
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
|
||||
the model if supplied. Not support yet.
|
||||
mode: Training or Evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
split_config=None,
|
||||
preprocessor=None,
|
||||
mode=None,
|
||||
*args,
|
||||
**kwargs) -> None:
|
||||
EasyCVBaseDataset.__init__(
|
||||
self,
|
||||
split_config=split_config,
|
||||
preprocessor=preprocessor,
|
||||
mode=mode,
|
||||
args=args,
|
||||
kwargs=kwargs)
|
||||
_ClsDataset.__init__(self, *args, **kwargs)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
|
||||
from easycv.datasets.segmentation import SegDataset as _SegDataset
|
||||
|
||||
from modelscope.metainfo import Datasets
|
||||
@@ -9,30 +7,9 @@ from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
class EasyCVSegBaseDataset(EasyCVBaseDataset):
|
||||
DATA_STRUCTURE = {
|
||||
# data source name
|
||||
'SegSourceRaw': {
|
||||
'train': {
|
||||
'img_root':
|
||||
'images', # directory name of images relative to the root path
|
||||
'label_root':
|
||||
'annotations', # directory name of annotation relative to the root path
|
||||
'split':
|
||||
'train.txt' # split file name relative to the root path
|
||||
},
|
||||
'validation': {
|
||||
'img_root': 'images',
|
||||
'label_root': 'annotations',
|
||||
'split': 'val.txt'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
group_key=Tasks.image_segmentation, module_name=Datasets.SegDataset)
|
||||
class SegDataset(EasyCVSegBaseDataset, _SegDataset):
|
||||
class SegDataset(EasyCVBaseDataset, _SegDataset):
|
||||
"""EasyCV dataset for Sementic segmentation.
|
||||
For more details, please refer to :
|
||||
https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/segmentation/raw.py .
|
||||
@@ -55,7 +32,7 @@ class SegDataset(EasyCVSegBaseDataset, _SegDataset):
|
||||
mode=None,
|
||||
*args,
|
||||
**kwargs) -> None:
|
||||
EasyCVSegBaseDataset.__init__(
|
||||
EasyCVBaseDataset.__init__(
|
||||
self,
|
||||
split_config=split_config,
|
||||
preprocessor=preprocessor,
|
||||
|
||||
@@ -11,26 +11,9 @@ from modelscope.msdatasets.task_datasets import TASK_DATASETS
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
class EasyCVDetBaseDataset(EasyCVBaseDataset):
|
||||
DATA_STRUCTURE = {
|
||||
'DetSourceCoco': {
|
||||
'train': {
|
||||
'ann_file':
|
||||
'train.json', # file name of annotation relative to the root path
|
||||
'img_prefix':
|
||||
'images', # directory name of images relative to the root path
|
||||
},
|
||||
'validation': {
|
||||
'ann_file': 'val.json',
|
||||
'img_prefix': 'images',
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
|
||||
group_key=Tasks.image_object_detection, module_name=Datasets.DetDataset)
|
||||
class DetDataset(EasyCVDetBaseDataset, _DetDataset):
|
||||
class DetDataset(EasyCVBaseDataset, _DetDataset):
|
||||
"""EasyCV dataset for object detection.
|
||||
For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/raw.py .
|
||||
|
||||
@@ -52,7 +35,7 @@ class DetDataset(EasyCVDetBaseDataset, _DetDataset):
|
||||
mode=None,
|
||||
*args,
|
||||
**kwargs) -> None:
|
||||
EasyCVDetBaseDataset.__init__(
|
||||
EasyCVBaseDataset.__init__(
|
||||
self,
|
||||
split_config=split_config,
|
||||
preprocessor=preprocessor,
|
||||
@@ -65,7 +48,7 @@ class DetDataset(EasyCVDetBaseDataset, _DetDataset):
|
||||
@TASK_DATASETS.register_module(
|
||||
group_key=Tasks.image_object_detection,
|
||||
module_name=Datasets.DetImagesMixDataset)
|
||||
class DetImagesMixDataset(EasyCVDetBaseDataset, _DetImagesMixDataset):
|
||||
class DetImagesMixDataset(EasyCVBaseDataset, _DetImagesMixDataset):
|
||||
"""EasyCV dataset for object detection, a wrapper of multiple images mixed dataset.
|
||||
Suitable for training on multiple images mixed data augmentation like
|
||||
mosaic and mixup. For the augmentation pipeline of mixed image data,
|
||||
@@ -99,7 +82,7 @@ class DetImagesMixDataset(EasyCVDetBaseDataset, _DetImagesMixDataset):
|
||||
mode=None,
|
||||
*args,
|
||||
**kwargs) -> None:
|
||||
EasyCVDetBaseDataset.__init__(
|
||||
EasyCVBaseDataset.__init__(
|
||||
self,
|
||||
split_config=split_config,
|
||||
preprocessor=preprocessor,
|
||||
|
||||
@@ -403,8 +403,8 @@ class MsDataset:
|
||||
)
|
||||
if isinstance(self._hf_ds, ExternalDataset):
|
||||
task_data_config.update({'preprocessor': preprocessors})
|
||||
return build_task_dataset(task_data_config, task_name,
|
||||
self._hf_ds.config_kwargs)
|
||||
task_data_config.update(self._hf_ds.config_kwargs)
|
||||
return build_task_dataset(task_data_config, task_name)
|
||||
if preprocessors is not None:
|
||||
return self.to_torch_dataset_with_processors(
|
||||
preprocessors, columns=columns)
|
||||
|
||||
@@ -19,7 +19,7 @@ quiet-level = 3
|
||||
ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids
|
||||
|
||||
[flake8]
|
||||
select = B,C,E,F,P,T4,W,B9
|
||||
max-line-length = 120
|
||||
ignore = F401,F405,F821,W503
|
||||
select = B,C,E,F,P,T4,W,B9
|
||||
ignore = F401,F405,F821,W503,E251
|
||||
exclude = docs/src,*.pyi,.git
|
||||
|
||||
@@ -47,7 +47,8 @@ class EasyCVTrainerTestSegformer(unittest.TestCase):
|
||||
namespace='EasyCV',
|
||||
split='validation')
|
||||
kwargs = dict(
|
||||
model='EasyCV/EasyCV-Segformer-b0',
|
||||
model=
|
||||
'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k',
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
work_dir=self.tmp_dir,
|
||||
|
||||
Reference in New Issue
Block a user