Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-14 15:27:42 +01:00)
support EasyCV framework and add Segformer model
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9781849

* support EasyCV
@@ -2,7 +2,6 @@
    "framework": "pytorch",
    "task": "image_classification",
    "work_dir": "./work_dir",
    "model": {
        "type": "classification",
@@ -119,6 +118,7 @@
    },
    "train": {
        "work_dir": "./work_dir",
        "dataloader": {
            "batch_size_per_gpu": 2,
            "workers_per_gpu": 1
3  data/test/images/image_segmentation.jpg  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:af6fa61274e497ecc170de5adc4b8e7ac89eba2bc22a6aa119b08ec7adbe9459
size 146140
@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import jsonplus
import numpy as np

from .base import FormatHandler
@@ -22,14 +22,14 @@ def set_default(obj):


class JsonHandler(FormatHandler):
    """Use jsonplus, serialization of Python types to JSON that "just works"."""

    def load(self, file):
        return json.load(file)
        return jsonplus.loads(file.read())

    def dump(self, obj, file, **kwargs):
        kwargs.setdefault('default', set_default)
        json.dump(obj, file, **kwargs)
        file.write(self.dumps(obj, **kwargs))

    def dumps(self, obj, **kwargs):
        kwargs.setdefault('default', set_default)
        return json.dumps(obj, **kwargs)
        return jsonplus.dumps(obj, **kwargs)
@@ -26,6 +26,10 @@ class Models(object):
    swinL_semantic_segmentation = 'swinL-semantic-segmentation'
    vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'

    # EasyCV models
    yolox = 'YOLOX'
    segformer = 'Segformer'

    # nlp models
    bert = 'bert'
    palm = 'palm-v2'
@@ -92,6 +96,8 @@ class Pipelines(object):
    body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
    human_detection = 'resnet18-human-detection'
    object_detection = 'vit-object-detection'
    easycv_detection = 'easycv-detection'
    easycv_segmentation = 'easycv-segmentation'
    salient_detection = 'u2net-salient-detection'
    image_classification = 'image-classification'
    face_detection = 'resnet-face-detection-scrfd10gkps'
@@ -171,6 +177,7 @@ class Trainers(object):
    """

    default = 'trainer'
    easycv = 'easycv'

    # multi-modal trainers
    clip_multi_modal_embedding = 'clip-multi-modal-embedding'
@@ -307,3 +314,12 @@ class LR_Schedulers(object):
    LinearWarmup = 'LinearWarmup'
    ConstantWarmup = 'ConstantWarmup'
    ExponentialWarmup = 'ExponentialWarmup'


class Datasets(object):
    """ Names for different datasets.
    """
    ClsDataset = 'ClsDataset'
    SegDataset = 'SegDataset'
    DetDataset = 'DetDataset'
    DetImagesMixDataset = 'DetImagesMixDataset'
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict, Mapping, Union

from modelscope.metainfo import Metrics
from modelscope.utils.config import ConfigDict
@@ -35,16 +36,19 @@ task_default_metrics = {
}


def build_metric(metric_name: str,
def build_metric(metric_cfg: Union[str, Dict],
                 field: str = default_group,
                 default_args: dict = None):
    """ Build metric given metric_name and field.

    Args:
        metric_name (:obj:`str`): The metric name.
        metric_name (str | dict): The metric name or metric config dict.
        field (str, optional): The field of this metric, default value: 'default' for all fields.
        default_args (dict, optional): Default initialization arguments.
    """
    cfg = ConfigDict({'type': metric_name})
    if isinstance(metric_cfg, Mapping):
        assert 'type' in metric_cfg
    else:
        metric_cfg = ConfigDict({'type': metric_cfg})
    return build_from_cfg(
        cfg, METRICS, group_key=field, default_args=default_args)
        metric_cfg, METRICS, group_key=field, default_args=default_args)
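Note: with this change build_metric accepts either a bare metric name or a full metric config dict. A minimal sketch of the two call styles (the metric names below are illustrative, not taken from this diff):

# String form: the name is wrapped into ConfigDict({'type': name}) internally.
metric = build_metric('accuracy')

# Dict form: the config must carry a 'type' key; any extra keys are forwarded
# to the metric class constructor by build_from_cfg.
metric = build_metric({'type': 'EasyCVMetric'})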
25  modelscope/models/cv/easycv_base.py  Normal file
@@ -0,0 +1,25 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.models.base import BaseModel
from easycv.utils.ms_utils import EasyCVMeta

from modelscope.models.base import TorchModel


class EasyCVBaseModel(BaseModel, TorchModel):
    """Base model for EasyCV."""

    def __init__(self, model_dir=None, args=(), kwargs={}):
        kwargs.pop(EasyCVMeta.ARCH, None)  # pop useless keys
        BaseModel.__init__(self)
        TorchModel.__init__(self, model_dir=model_dir)

    def forward(self, img, mode='train', **kwargs):
        if self.training:
            losses = self.forward_train(img, **kwargs)
            loss, log_vars = self._parse_losses(losses)
            return dict(loss=loss, log_vars=log_vars)
        else:
            return self.forward_test(img, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
@@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .semantic_seg_model import SemanticSegmentation
    from .segformer import Segformer

else:
    _import_structure = {
        'semantic_seg_model': ['SemanticSegmentation'],
        'segformer': ['Segformer']
    }

    import sys
@@ -0,0 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.models.segmentation import EncoderDecoder

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.models.cv.easycv_base import EasyCVBaseModel
from modelscope.utils.constant import Tasks


@MODELS.register_module(
    group_key=Tasks.image_segmentation, module_name=Models.segformer)
class Segformer(EasyCVBaseModel, EncoderDecoder):

    def __init__(self, model_dir=None, *args, **kwargs):
        EasyCVBaseModel.__init__(self, model_dir, args, kwargs)
        EncoderDecoder.__init__(self, *args, **kwargs)
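Note: every EasyCV architecture is brought into ModelScope with this same thin-wrapper pattern: list EasyCVBaseModel first so its __init__/forward/__call__ take precedence in the MRO, then the original EasyCV class, and register the wrapper under a ModelScope task. A hypothetical sketch for wrapping another EasyCV segmentor (the EasyCV class and module name below are made up for illustration):

from easycv.models.segmentation import SomeSegmentor  # hypothetical EasyCV class

@MODELS.register_module(
    group_key=Tasks.image_segmentation, module_name='some-segmentor')
class SomeSegmentorMS(EasyCVBaseModel, SomeSegmentor):

    def __init__(self, model_dir=None, *args, **kwargs):
        EasyCVBaseModel.__init__(self, model_dir, args, kwargs)
        SomeSegmentor.__init__(self, *args, **kwargs)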
@@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .mmdet_model import DetectionModel
    from .yolox_pai import YOLOX

else:
    _import_structure = {
        'mmdet_model': ['DetectionModel'],
        'yolox_pai': ['YOLOX']
    }

    import sys
16  modelscope/models/cv/object_detection/yolox_pai.py  Normal file
@@ -0,0 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.models.detection.detectors import YOLOX as _YOLOX

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.models.cv.easycv_base import EasyCVBaseModel
from modelscope.utils.constant import Tasks


@MODELS.register_module(
    group_key=Tasks.image_object_detection, module_name=Models.yolox)
class YOLOX(EasyCVBaseModel, _YOLOX):

    def __init__(self, model_dir=None, *args, **kwargs):
        EasyCVBaseModel.__init__(self, model_dir, args, kwargs)
        _YOLOX.__init__(self, *args, **kwargs)
@@ -1 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import cv
from .ms_dataset import MsDataset

3  modelscope/msdatasets/cv/__init__.py  Normal file
@@ -0,0 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import (image_classification, image_semantic_segmentation,
               object_detection)
20  modelscope/msdatasets/cv/image_classification/__init__.py  Normal file
@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .classification_dataset import ClsDataset

else:
    _import_structure = {'classification_dataset': ['ClsDataset']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,19 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.datasets.classification import ClsDataset as _ClsDataset

from modelscope.metainfo import Datasets
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
from modelscope.utils.constant import Tasks


@TASK_DATASETS.register_module(
    group_key=Tasks.image_classification, module_name=Datasets.ClsDataset)
class ClsDataset(_ClsDataset):
    """EasyCV dataset for classification.
    For more details, please refer to :
    https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/classification/raw.py .

    Args:
        data_source: Data source config to parse input data.
        pipeline: Sequence of transform object or config dict to be composed.
    """
@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .segmentation_dataset import SegDataset

else:
    _import_structure = {'easycv_segmentation': ['SegDataset']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.datasets.segmentation import SegDataset as _SegDataset

from modelscope.metainfo import Datasets
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
from modelscope.utils.constant import Tasks


@TASK_DATASETS.register_module(
    group_key=Tasks.image_segmentation, module_name=Datasets.SegDataset)
class SegDataset(_SegDataset):
    """EasyCV dataset for Sementic segmentation.
    For more details, please refer to :
    https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/segmentation/raw.py .

    Args:
        data_source: Data source config to parse input data.
        pipeline: Sequence of transform object or config dict to be composed.
        ignore_index (int): Label index to be ignored.
        profiling: If set True, will print transform time.
    """
22  modelscope/msdatasets/cv/object_detection/__init__.py  Normal file
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .easycv_detection import DetDataset, DetImagesMixDataset

else:
    _import_structure = {
        'easycv_detection': ['DetDataset', 'DetImagesMixDataset']
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,49 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.datasets.detection import DetDataset as _DetDataset
from easycv.datasets.detection import \
    DetImagesMixDataset as _DetImagesMixDataset

from modelscope.metainfo import Datasets
from modelscope.msdatasets.task_datasets import TASK_DATASETS
from modelscope.utils.constant import Tasks


@TASK_DATASETS.register_module(
    group_key=Tasks.image_object_detection, module_name=Datasets.DetDataset)
class DetDataset(_DetDataset):
    """EasyCV dataset for object detection.
    For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/raw.py .

    Args:
        data_source: Data source config to parse input data.
        pipeline: Transform config list
        profiling: If set True, will print pipeline time
        classes: A list of class names, used in evaluation for result and groundtruth visualization
    """


@TASK_DATASETS.register_module(
    group_key=Tasks.image_object_detection,
    module_name=Datasets.DetImagesMixDataset)
class DetImagesMixDataset(_DetImagesMixDataset):
    """EasyCV dataset for object detection, a wrapper of multiple images mixed dataset.
    Suitable for training on multiple images mixed data augmentation like
    mosaic and mixup. For the augmentation pipeline of mixed image data,
    the `get_indexes` method needs to be provided to obtain the image
    indexes, and you can set `skip_flags` to change the pipeline running
    process. At the same time, we provide the `dynamic_scale` parameter
    to dynamically change the output image size.
    output boxes format: cx, cy, w, h

    For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/mix.py .

    Args:
        data_source (:obj:`DetSourceCoco`): Data source config to parse input data.
        pipeline (Sequence[dict]): Sequence of transform object or
            config dict to be composed.
        dynamic_scale (tuple[int], optional): The image scale can be changed
            dynamically. Default to None.
        skip_type_keys (list[str], optional): Sequence of type string to
            be skip pipeline. Default to None.
        label_padding: out labeling padding [N, 120, 5]
    """
@@ -240,9 +240,9 @@ class Pipeline(ABC):
            raise ValueError(f'Unsupported data type {type(data)}')

    def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
        preprocess_params = kwargs.get('preprocess_params')
        forward_params = kwargs.get('forward_params')
        postprocess_params = kwargs.get('postprocess_params')
        preprocess_params = kwargs.get('preprocess_params', {})
        forward_params = kwargs.get('forward_params', {})
        postprocess_params = kwargs.get('postprocess_params', {})

        out = self.preprocess(input, **preprocess_params)
        with device_placement(self.framework, self.device_name):
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
    from .tinynas_classification_pipeline import TinynasClassificationPipeline
    from .video_category_pipeline import VideoCategoryPipeline
    from .virtual_try_on_pipeline import VirtualTryonPipeline

    from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline
else:
    _import_structure = {
        'action_recognition_pipeline': ['ActionRecognitionPipeline'],
@@ -84,6 +84,8 @@ else:
        'tinynas_classification_pipeline': ['TinynasClassificationPipeline'],
        'video_category_pipeline': ['VideoCategoryPipeline'],
        'virtual_try_on_pipeline': ['VirtualTryonPipeline'],
        'easycv_pipeline':
        ['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline']
    }

    import sys
23  modelscope/pipelines/cv/easycv_pipelines/__init__.py  Normal file
@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .detection_pipeline import EasyCVDetectionPipeline
    from .segmentation_pipeline import EasyCVSegmentationPipeline
else:
    _import_structure = {
        'detection_pipeline': ['EasyCVDetectionPipeline'],
        'segmentation_pipeline': ['EasyCVSegmentationPipeline']
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
95  modelscope/pipelines/cv/easycv_pipelines/base.py  Normal file
@@ -0,0 +1,95 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
import os.path as osp
from typing import Any

from easycv.utils.ms_utils import EasyCVMeta

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.pipelines.util import is_official_hub_path
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile


class EasyCVPipeline(object):
    """Base pipeline for EasyCV.
    Loading configuration file of modelscope style by default,
    but it is actually use the predictor api of easycv to predict.
    So here we do some adaptation work for configuration and predict api.
    """

    def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs):
        """
        model (str): model id on modelscope hub or local model path.
        model_file_pattern (str): model file pattern.
        """
        self.model_file_pattern = model_file_pattern

        assert isinstance(model, str)
        if osp.exists(model):
            model_dir = model
        else:
            assert is_official_hub_path(
                model), 'Only support local model path and official hub path!'
            model_dir = snapshot_download(
                model_id=model, revision=DEFAULT_MODEL_REVISION)

        assert osp.isdir(model_dir)
        model_files = glob.glob(
            os.path.join(model_dir, self.model_file_pattern))
        assert len(
            model_files
        ) == 1, f'Need one model file, but find {len(model_files)}: {model_files}'

        model_path = model_files[0]
        self.model_path = model_path

        # get configuration file from source model dir
        self.config_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
        assert os.path.exists(
            self.config_file
        ), f'Not find "{ModelFile.CONFIGURATION}" in model directory!'

        self.cfg = Config.from_file(self.config_file)
        self.predict_op = self._build_predict_op()

    def _build_predict_op(self):
        """Build EasyCV predictor."""
        from easycv.predictors.builder import build_predictor

        easycv_config = self._to_easycv_config()
        pipeline_op = build_predictor(self.cfg.pipeline.predictor_config, {
            'model_path': self.model_path,
            'config_file': easycv_config
        })
        return pipeline_op

    def _to_easycv_config(self):
        """Adapt to EasyCV predictor."""
        # TODO: refine config compatibility problems

        easycv_arch = self.cfg.model.pop(EasyCVMeta.ARCH, None)
        model_cfg = self.cfg.model
        # Revert to the configuration of easycv
        if easycv_arch is not None:
            model_cfg.update(easycv_arch)

        easycv_config = Config(dict(model=model_cfg))

        reserved_keys = []
        if hasattr(self.cfg, EasyCVMeta.META):
            easycv_meta_cfg = getattr(self.cfg, EasyCVMeta.META)
            reserved_keys = easycv_meta_cfg.get(EasyCVMeta.RESERVED_KEYS, [])
            for key in reserved_keys:
                easycv_config.merge_from_dict({key: getattr(self.cfg, key)})
        if 'test_pipeline' not in reserved_keys:
            easycv_config.merge_from_dict(
                {'test_pipeline': self.cfg.dataset.val.get('pipeline', [])})

        return easycv_config

    def __call__(self, inputs) -> Any:
        # TODO: support image url
        return self.predict_op(inputs)
@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.metainfo import Pipelines
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from .base import EasyCVPipeline


@PIPELINES.register_module(
    Tasks.image_object_detection, module_name=Pipelines.easycv_detection)
class EasyCVDetectionPipeline(EasyCVPipeline):
    """Pipeline for easycv detection task."""

    def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs):
        """
        model (str): model id on modelscope hub or local model path.
        model_file_pattern (str): model file pattern.
        """
        super(EasyCVDetectionPipeline, self).__init__(
            model=model,
            model_file_pattern=model_file_pattern,
            *args,
            **kwargs)
@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.metainfo import Pipelines
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from .base import EasyCVPipeline


@PIPELINES.register_module(
    Tasks.image_segmentation, module_name=Pipelines.easycv_segmentation)
class EasyCVSegmentationPipeline(EasyCVPipeline):
    """Pipeline for easycv segmentation task."""

    def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs):
        """
        model (str): model id on modelscope hub or local model path.
        model_file_pattern (str): model file pattern.
        """
        super(EasyCVSegmentationPipeline, self).__init__(
            model=model,
            model_file_pattern=model_file_pattern,
            *args,
            **kwargs)
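Note: once registered, these classes are reached through the normal pipeline() factory rather than instantiated directly. A minimal usage sketch, mirroring the model id and test image used in the tests further down:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Resolves to EasyCVSegmentationPipeline via the PIPELINES registry.
segmentor = pipeline(
    task=Tasks.image_segmentation, model='EasyCV/EasyCV-Segformer-b0')
outputs = segmentor('data/test/images/image_segmentation.jpg')
# outputs is a list with one result dict per image, containing 'seg_pred' masks.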
0  modelscope/trainers/easycv/__init__.py  Normal file
175  modelscope/trainers/easycv/trainer.py  Normal file
@@ -0,0 +1,175 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
from typing import Callable, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data import Dataset

from modelscope.metainfo import Trainers
from modelscope.models.base import TorchModel
from modelscope.msdatasets import MsDataset
from modelscope.preprocessors import Preprocessor
from modelscope.trainers import EpochBasedTrainer
from modelscope.trainers.base import TRAINERS
from modelscope.trainers.easycv.utils import register_util
from modelscope.trainers.hooks import HOOKS
from modelscope.trainers.parallel.builder import build_parallel
from modelscope.trainers.parallel.utils import is_parallel
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.import_utils import LazyImportModule
from modelscope.utils.registry import default_group


@TRAINERS.register_module(module_name=Trainers.easycv)
class EasyCVEpochBasedTrainer(EpochBasedTrainer):
    """Epoch based Trainer for EasyCV.

    Args:
        task: Task name.
        cfg_file(str): The config file of EasyCV.
        model (:obj:`torch.nn.Module` or :obj:`TorchModel` or `str`): The model to be run, or a valid model dir
            or a model id. If model is None, build_model method will be called.
        train_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*):
            The dataset to use for training.
            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
            distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a
            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
            sets the seed of the RNGs used.
        eval_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): The dataset to use for evaluation.
        preprocessor (:obj:`Preprocessor`, *optional*): The optional preprocessor.
            NOTE: If the preprocessor has been called before the dataset fed into this trainer by user's custom code,
            this parameter should be None, meanwhile remove the 'preprocessor' key from the cfg_file.
            Else the preprocessor will be instantiated from the cfg_file or assigned from this parameter and
            this preprocessing action will be executed every time the dataset's __getitem__ is called.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`, *optional*): A tuple
            containing the optimizer and the scheduler to use.
        max_epochs: (int, optional): Total training epochs.
    """

    def __init__(
            self,
            task: str,
            cfg_file: Optional[str] = None,
            model: Optional[Union[TorchModel, nn.Module, str]] = None,
            arg_parse_fn: Optional[Callable] = None,
            train_dataset: Optional[Union[MsDataset, Dataset]] = None,
            eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
            preprocessor: Optional[Preprocessor] = None,
            optimizers: Tuple[torch.optim.Optimizer,
                              torch.optim.lr_scheduler._LRScheduler] = (None,
                                                                        None),
            model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
            **kwargs):

        self.task = task
        register_util.register_parallel()
        register_util.register_part_mmcv_hooks_to_ms()

        super(EasyCVEpochBasedTrainer, self).__init__(
            model=model,
            cfg_file=cfg_file,
            arg_parse_fn=arg_parse_fn,
            preprocessor=preprocessor,
            optimizers=optimizers,
            model_revision=model_revision,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            **kwargs)

        # reset data_collator
        from mmcv.parallel import collate

        self.train_data_collator = partial(
            collate,
            samples_per_gpu=self.cfg.train.dataloader.batch_size_per_gpu)
        self.eval_data_collator = partial(
            collate,
            samples_per_gpu=self.cfg.evaluation.dataloader.batch_size_per_gpu)

        # Register easycv hooks dynamicly. If the hook already exists in modelscope,
        # the hook in modelscope will be used, otherwise register easycv hook into ms.
        # We must manually trigger lazy import to detect whether the hook is in modelscope.
        # TODO: use ast index to detect whether the hook is in modelscope
        for h_i in self.cfg.train.get('hooks', []):
            sig = ('HOOKS', default_group, h_i['type'])
            LazyImportModule.import_module(sig)
            if h_i['type'] not in HOOKS._modules[default_group]:
                if h_i['type'] in [
                        'TensorboardLoggerHookV2', 'WandbLoggerHookV2'
                ]:
                    raise ValueError(
                        'Not support hook %s now, we will support it in the future!'
                        % h_i['type'])
                register_util.register_hook_to_ms(h_i['type'], self.logger)

        # reset parallel
        if not self._dist:
            assert not is_parallel(
                self.model
            ), 'Not support model wrapped by custom parallel if not in distributed mode!'
            dp_cfg = dict(
                type='MMDataParallel',
                module=self.model,
                device_ids=[torch.cuda.current_device()])
            self.model = build_parallel(dp_cfg)

    def create_optimizer_and_scheduler(self):
        """ Create optimizer and lr scheduler
        """
        optimizer, lr_scheduler = self.optimizers
        if optimizer is None:
            optimizer_cfg = self.cfg.train.get('optimizer', None)
        else:
            optimizer_cfg = None

        optim_options = {}
        if optimizer_cfg is not None:
            optim_options = optimizer_cfg.pop('options', {})
            from easycv.apis.train import build_optimizer
            optimizer = build_optimizer(self.model, optimizer_cfg)

        if lr_scheduler is None:
            lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None)
        else:
            lr_scheduler_cfg = None

        lr_options = {}
        # Adapt to mmcv lr scheduler hook.
        # Please refer to: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py
        if lr_scheduler_cfg is not None:
            assert optimizer is not None
            lr_options = lr_scheduler_cfg.pop('options', {})
            assert 'policy' in lr_scheduler_cfg
            policy_type = lr_scheduler_cfg.pop('policy')
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_scheduler_cfg['type'] = hook_type

            self.cfg.train.lr_scheduler_hook = lr_scheduler_cfg

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        return self.optimizer, self.lr_scheduler, optim_options, lr_options

    def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
        if self.cfg.get('parallel', None) is not None:
            self.cfg.parallel.update(
                dict(module=model, device_ids=[torch.cuda.current_device()]))
            return build_parallel(self.cfg.parallel)

        dp_cfg = dict(
            type='MMDistributedDataParallel',
            module=model,
            device_ids=[torch.cuda.current_device()])

        return build_parallel(dp_cfg)

    def rebuild_config(self, cfg: Config):
        cfg.task = self.task

        return cfg
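Note: the policy-to-hook conversion in create_optimizer_and_scheduler is plain string manipulation against the mmcv naming convention ('<Policy>LrUpdaterHook'). A small sketch of what it produces for an assumed lr_scheduler config (the values are illustrative):

lr_scheduler_cfg = {'policy': 'CosineAnnealing', 'min_lr': 0.0}
policy_type = lr_scheduler_cfg.pop('policy')
if policy_type == policy_type.lower():  # an all-lowercase 'step' becomes 'Step'
    policy_type = policy_type.title()
lr_scheduler_cfg['type'] = policy_type + 'LrUpdaterHook'
# lr_scheduler_cfg is now {'min_lr': 0.0, 'type': 'CosineAnnealingLrUpdaterHook'}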
21  modelscope/trainers/easycv/utils/__init__.py  Normal file
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .hooks import AddLrLogHook
    from .metric import EasyCVMetric

else:
    _import_structure = {'hooks': ['AddLrLogHook'], 'metric': ['EasyCVMetric']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
29  modelscope/trainers/easycv/utils/hooks.py  Normal file
@@ -0,0 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.trainers.hooks import HOOKS, Priority
from modelscope.trainers.hooks.lr_scheduler_hook import LrSchedulerHook
from modelscope.utils.constant import LogKeys


@HOOKS.register_module(module_name='AddLrLogHook')
class AddLrLogHook(LrSchedulerHook):
    """For EasyCV to adapt to ModelScope, the lr log of EasyCV is added in the trainer,
    but the trainer of ModelScope does not and it is added in the lr scheduler hook.
    But The lr scheduler hook used by EasyCV is the hook of mmcv, and there is no lr log.
    It will be deleted in the future.
    """
    PRIORITY = Priority.NORMAL

    def __init__(self):
        pass

    def before_run(self, trainer):
        pass

    def before_train_iter(self, trainer):
        trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer)

    def before_train_epoch(self, trainer):
        trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer)

    def after_train_epoch(self, trainer):
        pass
52  modelscope/trainers/easycv/utils/metric.py  Normal file
@@ -0,0 +1,52 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import itertools
from typing import Dict

import numpy as np
import torch

from modelscope.metrics.base import Metric
from modelscope.metrics.builder import METRICS


@METRICS.register_module(module_name='EasyCVMetric')
class EasyCVMetric(Metric):
    """Adapt to ModelScope Metric for EasyCV evaluator.
    """

    def __init__(self, trainer=None, evaluators=None, *args, **kwargs):
        from easycv.core.evaluation.builder import build_evaluator

        self.trainer = trainer
        self.evaluators = build_evaluator(evaluators)
        self.preds = []
        self.grountruths = []

    def add(self, outputs: Dict, inputs: Dict):
        self.preds.append(outputs)
        del inputs

    def evaluate(self):
        results = {}
        for _, batch in enumerate(self.preds):
            for k, v in batch.items():
                if k not in results:
                    results[k] = []
                results[k].append(v)

        for k, v in results.items():
            if len(v) == 0:
                raise ValueError(f'empty result for {k}')

            if isinstance(v[0], torch.Tensor):
                results[k] = torch.cat(v, 0)
            elif isinstance(v[0], (list, np.ndarray)):
                results[k] = list(itertools.chain.from_iterable(v))
            else:
                raise ValueError(
                    f'value of batch prediction dict should only be tensor or list, {k} type is {v[0]}'
                )

        metric_values = self.trainer.eval_dataset.evaluate(
            results, self.evaluators)
        return metric_values
59  modelscope/trainers/easycv/utils/register_util.py  Normal file
@@ -0,0 +1,59 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import logging

from modelscope.trainers.hooks import HOOKS
from modelscope.trainers.parallel.builder import PARALLEL


def register_parallel():
    from mmcv.parallel import MMDistributedDataParallel, MMDataParallel

    PARALLEL.register_module(
        module_name='MMDistributedDataParallel',
        module_cls=MMDistributedDataParallel)
    PARALLEL.register_module(
        module_name='MMDataParallel', module_cls=MMDataParallel)


def register_hook_to_ms(hook_name, logger=None):
    """Register EasyCV hook to ModelScope."""
    from easycv.hooks import HOOKS as _EV_HOOKS

    if hook_name not in _EV_HOOKS._module_dict:
        raise ValueError(
            f'Not found hook "{hook_name}" in EasyCV hook registries!')

    obj = _EV_HOOKS._module_dict[hook_name]
    HOOKS.register_module(module_name=hook_name, module_cls=obj)

    log_str = f'Register hook "{hook_name}" to modelscope hooks.'
    logger.info(log_str) if logger is not None else logging.info(log_str)


def register_part_mmcv_hooks_to_ms():
    """Register required mmcv hooks to ModelScope.
    Currently we only registered all lr scheduler hooks in EasyCV and mmcv.
    Please refer to:
    EasyCV: https://github.com/alibaba/EasyCV/blob/master/easycv/hooks/lr_update_hook.py
    mmcv: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py
    """
    from mmcv.runner.hooks import lr_updater
    from mmcv.runner.hooks import HOOKS as _MMCV_HOOKS
    from easycv.hooks import StepFixCosineAnnealingLrUpdaterHook, YOLOXLrUpdaterHook
    from easycv.hooks.logger import PreLoggerHook

    mmcv_hooks_in_easycv = [('StepFixCosineAnnealingLrUpdaterHook',
                             StepFixCosineAnnealingLrUpdaterHook),
                            ('YOLOXLrUpdaterHook', YOLOXLrUpdaterHook),
                            ('PreLoggerHook', PreLoggerHook)]

    members = inspect.getmembers(lr_updater)
    members.extend(mmcv_hooks_in_easycv)

    for name, obj in members:
        if name in _MMCV_HOOKS._module_dict:
            HOOKS.register_module(
                module_name=name,
                module_cls=obj,
            )
@@ -81,12 +81,19 @@ class CheckpointHook(Hook):
        if self.is_last_epoch(trainer) and self.by_epoch:
            output_dir = os.path.join(self.save_dir,
                                      ModelFile.TRAIN_OUTPUT_DIR)
            from modelscope.trainers.parallel.utils import is_parallel

            trainer.model.save_pretrained(
                output_dir,
                ModelFile.TORCH_MODEL_BIN_FILE,
                save_function=save_checkpoint,
                config=trainer.cfg.to_dict())
            if is_parallel(trainer.model):
                model = trainer.model.module
            else:
                model = trainer.model

            if hasattr(model, 'save_pretrained'):
                model.save_pretrained(
                    output_dir,
                    ModelFile.TORCH_MODEL_BIN_FILE,
                    save_function=save_checkpoint,
                    config=trainer.cfg.to_dict())

    def after_train_iter(self, trainer):
        if self.by_epoch:
@@ -60,6 +60,18 @@ class LoggerHook(Hook):
        else:
            return False

    def fetch_tensor(self, trainer, n=0):
        """Fetch latest n values or all values, process tensor type, convert to numpy for dump logs."""
        assert n >= 0
        for key in trainer.log_buffer.val_history:
            values = trainer.log_buffer.val_history[key][-n:]

            for i, v in enumerate(values):
                if isinstance(v, torch.Tensor):
                    values[i] = v.clone().detach().cpu().numpy()

            trainer.log_buffer.val_history[key][-n:] = values

    def get_epoch(self, trainer):
        if trainer.mode in [ModeKeys.TRAIN, ModeKeys.EVAL]:
            epoch = trainer.epoch + 1
@@ -88,11 +100,14 @@ class LoggerHook(Hook):

    def after_train_iter(self, trainer):
        if self.by_epoch and self.every_n_inner_iters(trainer, self.interval):
            self.fetch_tensor(trainer, self.interval)
            trainer.log_buffer.average(self.interval)
        elif not self.by_epoch and self.every_n_iters(trainer, self.interval):
            self.fetch_tensor(trainer, self.interval)
            trainer.log_buffer.average(self.interval)
        elif self.end_of_epoch(trainer) and not self.ignore_last:
            # not precise but more stable
            self.fetch_tensor(trainer, self.interval)
            trainer.log_buffer.average(self.interval)

        if trainer.log_buffer.ready:
@@ -107,6 +122,7 @@ class LoggerHook(Hook):
        trainer.log_buffer.clear_output()

    def after_val_epoch(self, trainer):
        self.fetch_tensor(trainer)
        trainer.log_buffer.average()
        self.log(trainer)
        if self.reset_flag:
@@ -26,7 +26,6 @@ from modelscope.msdatasets.task_datasets.torch_base_dataset import \
    TorchTaskDataset
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import build_preprocessor
from modelscope.preprocessors.common import Compose
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.priority import Priority, get_priority
from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
@@ -83,7 +82,8 @@ class EpochBasedTrainer(BaseTrainer):
            model: Optional[Union[TorchModel, nn.Module, str]] = None,
            cfg_file: Optional[str] = None,
            arg_parse_fn: Optional[Callable] = None,
            data_collator: Optional[Callable] = None,
            data_collator: Optional[Union[Callable, Dict[str,
                                                         Callable]]] = None,
            train_dataset: Optional[Union[MsDataset, Dataset]] = None,
            eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
            preprocessor: Optional[Union[Preprocessor,
@@ -104,21 +104,24 @@ class EpochBasedTrainer(BaseTrainer):
            if cfg_file is None:
                cfg_file = os.path.join(self.model_dir,
                                        ModelFile.CONFIGURATION)
            self.model = self.build_model()
        else:
            assert cfg_file is not None, 'Config file should not be None if model is an nn.Module class'
            assert isinstance(
                model,
                (TorchModel, nn.Module
                 )), 'model should be either str, TorchMode or nn.Module.'
            assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
            self.model_dir = os.path.dirname(cfg_file)
            self.model = model

        super().__init__(cfg_file, arg_parse_fn)

        # add default config
        self.cfg.merge_from_dict(self._get_default_config(), force=False)
        self.cfg = self.rebuild_config(self.cfg)

        if 'cfg_options' in kwargs:
            self.cfg.merge_from_dict(kwargs['cfg_options'])

        if isinstance(model, (TorchModel, nn.Module)):
            self.model = model
        else:
            self.model = self.build_model()

        if 'work_dir' in kwargs:
            self.work_dir = kwargs['work_dir']
        else:
@@ -162,7 +165,24 @@ class EpochBasedTrainer(BaseTrainer):
            mode=ModeKeys.EVAL,
            preprocessor=self.eval_preprocessor)

        self.data_collator = data_collator if data_collator is not None else default_collate
        self.train_data_collator, self.eval_default_collate = None, None
        if isinstance(data_collator, Mapping):
            if not (ConfigKeys.train in data_collator
                    or ConfigKeys.val in data_collator):
                raise ValueError(
                    f'data_collator must split with `{ConfigKeys.train}` and `{ConfigKeys.val}` keys!'
                )
            if ConfigKeys.train in data_collator:
                assert isinstance(data_collator[ConfigKeys.train], Callable)
                self.train_data_collator = data_collator[ConfigKeys.train]
            if ConfigKeys.val in data_collator:
                assert isinstance(data_collator[ConfigKeys.val], Callable)
                self.eval_data_collator = data_collator[ConfigKeys.val]
        else:
            collate_fn = default_collate if data_collator is None else data_collator
            self.train_data_collator = collate_fn
            self.eval_data_collator = collate_fn

        self.metrics = self.get_metrics()
        self._metric_values = None
        self.optimizers = optimizers
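Note: data_collator can now be a single callable, as before, or a mapping keyed by the train/val config keys so that training and evaluation use different collate functions. A minimal sketch of the two forms, assuming ConfigKeys is importable from modelscope.utils.constant as it is elsewhere in the codebase:

from torch.utils.data.dataloader import default_collate
from modelscope.utils.constant import ConfigKeys

def train_collate(batch):  # hypothetical custom collate function
    return default_collate(batch)

# Either pass one callable used for both phases...
data_collator = default_collate
# ...or a mapping with a separate collate function per phase.
data_collator = {ConfigKeys.train: train_collate, ConfigKeys.val: default_collate}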
@@ -364,7 +384,7 @@ class EpochBasedTrainer(BaseTrainer):

        return train_preprocessor, eval_preprocessor

    def get_metrics(self) -> List[str]:
    def get_metrics(self) -> List[Union[str, Dict]]:
        """Get the metric class types.

        The first choice will be the metrics configured in the config file, if not found, the default metrics will be
@@ -384,7 +404,7 @@ class EpochBasedTrainer(BaseTrainer):
                f'Metrics are needed in evaluation, please try to either '
                f'add metrics in configuration.json or add the default metric for {self.cfg.task}.'
            )
        if isinstance(metrics, str):
        if isinstance(metrics, (str, Mapping)):
            metrics = [metrics]
        return metrics

@@ -399,6 +419,7 @@ class EpochBasedTrainer(BaseTrainer):
                self.train_dataset,
                dist=self._dist,
                seed=self._seed,
                collate_fn=self.train_data_collator,
                **self.cfg.train.get('dataloader', {}))
            self.data_loader = self.train_dataloader

@@ -418,6 +439,7 @@ class EpochBasedTrainer(BaseTrainer):
                self.eval_dataset,
                dist=self._dist,
                seed=self._seed,
                collate_fn=self.eval_data_collator,
                **self.cfg.evaluation.get('dataloader', {}))
            self.data_loader = self.eval_dataloader
            metric_classes = [build_metric(metric) for metric in self.metrics]
@@ -440,7 +462,7 @@ class EpochBasedTrainer(BaseTrainer):
        override this method in a subclass.

        """
        model = Model.from_pretrained(self.model_dir)
        model = Model.from_pretrained(self.model_dir, cfg_dict=self.cfg)
        if not isinstance(model, nn.Module) and hasattr(model, 'model'):
            return model.model
        elif isinstance(model, nn.Module):
@@ -552,6 +574,7 @@ class EpochBasedTrainer(BaseTrainer):
            self.train_dataset,
            dist=self._dist,
            seed=self._seed,
            collate_fn=self.train_data_collator,
            **self.cfg.train.get('dataloader', {}))
        return data_loader

@@ -569,9 +592,9 @@ class EpochBasedTrainer(BaseTrainer):
                mode=ModeKeys.EVAL,
                preprocessor=self.eval_preprocessor)

        batch_size = self.cfg.evaluation.batch_size
        workers = self.cfg.evaluation.workers
        shuffle = self.cfg.evaluation.get('shuffle', False)
        batch_size = self.cfg.evaluation.dataloader.batch_size_per_gpu
        workers = self.cfg.evaluation.dataloader.workers_per_gpu
        shuffle = self.cfg.evaluation.dataloader.get('shuffle', False)
        data_loader = self._build_dataloader_with_dataset(
            self.eval_dataset,
            batch_size_per_gpu=batch_size,
@@ -580,25 +603,31 @@ class EpochBasedTrainer(BaseTrainer):
            dist=self._dist,
            seed=self._seed,
            persistent_workers=True,
            collate_fn=self.eval_data_collator,
        )
        return data_loader

    def build_dataset(self, data_cfg, mode, preprocessor=None):
        """ Build torch dataset object using data config
        """
        dataset = MsDataset.load(
            dataset_name=data_cfg.name,
            split=data_cfg.split,
            subset_name=data_cfg.subset_name if hasattr(
                data_cfg, 'subset_name') else None,
            hub=data_cfg.hub if hasattr(data_cfg, 'hub') else Hubs.modelscope,
            **data_cfg,
        )
        cfg = ConfigDict(type=self.cfg.model.type, mode=mode)
        torch_dataset = dataset.to_torch_dataset(
            task_data_config=cfg,
            task_name=self.cfg.task,
            preprocessors=self.preprocessor)
        # TODO: support MsDataset load for cv
        if hasattr(data_cfg, 'name'):
            dataset = MsDataset.load(
                dataset_name=data_cfg.name,
                split=data_cfg.split,
                subset_name=data_cfg.subset_name if hasattr(
                    data_cfg, 'subset_name') else None,
                hub=data_cfg.hub
                if hasattr(data_cfg, 'hub') else Hubs.modelscope,
                **data_cfg,
            )
            cfg = ConfigDict(type=self.cfg.model.type, mode=mode)
            torch_dataset = dataset.to_torch_dataset(
                task_data_config=cfg,
                task_name=self.cfg.task,
                preprocessors=self.preprocessor)
        else:
            torch_dataset = build_task_dataset(data_cfg, self.cfg.task)
        dataset = self.to_task_dataset(torch_dataset, mode)
        return dataset

@@ -746,7 +775,6 @@ class EpochBasedTrainer(BaseTrainer):
            sampler=sampler,
            num_workers=num_workers,
            batch_sampler=batch_sampler,
            collate_fn=self.data_collator,
            pin_memory=kwargs.pop('pin_memory', False),
            worker_init_fn=init_fn,
            **kwargs)
@@ -820,12 +848,14 @@ class EpochBasedTrainer(BaseTrainer):
        Args:
            hook (:obj:`Hook`): The hook to be registered.
        """
        assert isinstance(hook, Hook)
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if get_priority(hook.PRIORITY) > get_priority(
                    self._hooks[i].PRIORITY):
            p = hook.PRIORITY if hasattr(hook, 'PRIORITY') else Priority.NORMAL
            p_i = self._hooks[i].PRIORITY if hasattr(
                self._hooks[i], 'PRIORITY') else Priority.NORMAL

            if get_priority(p) > get_priority(p_i):
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
@@ -15,9 +15,9 @@ import json

from modelscope import __version__
from modelscope.fileio.file import LocalStorage
from modelscope.metainfo import (Heads, Hooks, LR_Schedulers, Metrics, Models,
                                 Optimizers, Pipelines, Preprocessors,
                                 TaskModels, Trainers)
from modelscope.metainfo import (Datasets, Heads, Hooks, LR_Schedulers,
                                 Metrics, Models, Optimizers, Pipelines,
                                 Preprocessors, TaskModels, Trainers)
from modelscope.utils.constant import Fields, Tasks
from modelscope.utils.file_utils import get_default_cache_dir
from modelscope.utils.logger import get_logger
@@ -32,8 +32,7 @@ MODELSCOPE_PATH = p.resolve().parents[1]
REGISTER_MODULE = 'register_module'
IGNORED_PACKAGES = ['modelscope', '.']
SCAN_SUB_FOLDERS = [
    'models', 'metrics', 'pipelines', 'preprocessors',
    'msdatasets/task_datasets', 'trainers'
    'models', 'metrics', 'pipelines', 'preprocessors', 'trainers', 'msdatasets'
]
INDEXER_FILE = 'ast_indexer'
DECORATOR_KEY = 'decorators'
@@ -14,7 +14,7 @@ mmcls>=0.21.0
mmdet>=2.25.0
networkx>=2.5
onnxruntime>=1.10
pai-easycv>=0.5
pai-easycv>=0.6.0
pandas
psutil
regex

@@ -4,6 +4,7 @@ easydict
einops
filelock>=3.3.0
gast>=0.2.2
jsonplus
numpy
opencv-python
oss2
0  tests/pipelines/easycv_pipelines/__init__.py  Normal file
@@ -0,0 +1,35 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np
from PIL import Image

from modelscope.metainfo import Pipelines
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class EasyCVSegmentationPipelineTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_segformer_b0(self):
        img_path = 'data/test/images/image_segmentation.jpg'
        model_id = 'EasyCV/EasyCV-Segformer-b0'
        img = np.asarray(Image.open(img_path))

        object_detect = pipeline(task=Tasks.image_segmentation, model=model_id)
        outputs = object_detect(img_path)
        self.assertEqual(len(outputs), 1)

        results = outputs[0]
        self.assertListEqual(
            list(img.shape)[:2], list(results['seg_pred'][0].shape))
        self.assertListEqual(results['seg_pred'][0][1, :10].tolist(),
                             [161 for i in range(10)])
        self.assertListEqual(results['seg_pred'][0][-1, -10:].tolist(),
                             [133 for i in range(10)])


if __name__ == '__main__':
    unittest.main()
0  tests/trainers/easycv/__init__.py  Normal file
244  tests/trainers/easycv/test_easycv_trainer.py  Normal file
@@ -0,0 +1,244 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
import shutil
import tempfile
import unittest

import json
import requests
import torch

from modelscope.metainfo import Models, Pipelines, Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import LogKeys, ModeKeys, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import DistributedTestCase, test_level
from modelscope.utils.torch_utils import is_master


def _download_data(url, save_dir):
    r = requests.get(url, verify=True)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    zip_name = os.path.split(url)[-1]
    save_path = os.path.join(save_dir, zip_name)
    with open(save_path, 'wb') as f:
        f.write(r.content)

    unpack_dir = os.path.join(save_dir, os.path.splitext(zip_name)[0])
    shutil.unpack_archive(save_path, unpack_dir)


def train_func(work_dir, dist=False, log_config=3, imgs_per_gpu=4):
    import easycv
    config_path = os.path.join(
        os.path.dirname(easycv.__file__),
        'configs/detection/yolox/yolox_s_8xb16_300e_coco.py')

    data_dir = os.path.join(work_dir, 'small_coco_test')
    url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/datasets/small_coco.zip'
    if is_master():
        _download_data(url, data_dir)

    import time
    time.sleep(1)
    cfg = Config.from_file(config_path)

    cfg.work_dir = work_dir
    cfg.total_epochs = 2
    cfg.checkpoint_config.interval = 1
    cfg.eval_config.interval = 1
    cfg.log_config = dict(
        interval=log_config,
        hooks=[
            dict(type='TextLoggerHook'),
            dict(type='TensorboardLoggerHook')
        ])
    cfg.data.train.data_source.ann_file = os.path.join(
        data_dir, 'small_coco/small_coco/instances_train2017_20.json')
    cfg.data.train.data_source.img_prefix = os.path.join(
        data_dir, 'small_coco/small_coco/train2017')
    cfg.data.val.data_source.ann_file = os.path.join(
        data_dir, 'small_coco/small_coco/instances_val2017_20.json')
    cfg.data.val.data_source.img_prefix = os.path.join(
        data_dir, 'small_coco/small_coco/val2017')
    cfg.data.imgs_per_gpu = imgs_per_gpu
    cfg.data.workers_per_gpu = 2
    cfg.data.val.imgs_per_gpu = 2

    ms_cfg_file = os.path.join(work_dir, 'ms_yolox_s_8xb16_300e_coco.json')
    from easycv.utils.ms_utils import to_ms_config

    if is_master():
        to_ms_config(
            cfg,
            dump=True,
            task=Tasks.image_object_detection,
            ms_model_name=Models.yolox,
            pipeline_name=Pipelines.easycv_detection,
            save_path=ms_cfg_file)

    trainer_name = Trainers.easycv
    kwargs = dict(
        task=Tasks.image_object_detection,
        cfg_file=ms_cfg_file,
        launcher='pytorch' if dist else None)

    trainer = build_trainer(trainer_name, kwargs)
    trainer.train()


@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
class EasyCVTrainerTestSingleGpu(unittest.TestCase):

    def setUp(self):
        self.logger = get_logger()
        self.logger.info(('Testing %s.%s' %
                          (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir, ignore_errors=True)

    @unittest.skipIf(
        True, 'The test cases are all run in the master process, '
        'cause registry conflicts, and it should run in the subprocess.')
    def test_single_gpu(self):
        # TODO: run in subprocess
        train_func(self.tmp_dir)

        results_files = os.listdir(self.tmp_dir)
        json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
        self.assertEqual(len(json_files), 1)

        with open(json_files[0], 'r') as f:
            lines = [i.strip() for i in f.readlines()]

        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 3,
                LogKeys.LR: 0.00013
            }, json.loads(lines[0]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10
            }, json.loads(lines[1]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 3,
                LogKeys.LR: 0.00157
            }, json.loads(lines[2]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10
            }, json.loads(lines[3]))
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
        for i in [0, 2]:
            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
            self.assertIn(LogKeys.ITER_TIME, lines[i])
            self.assertIn(LogKeys.MEMORY, lines[i])
            self.assertIn('total_loss', lines[i])
        for i in [1, 3]:
            self.assertIn(
                'CocoDetectionEvaluator_DetectionBoxes_Precision/mAP',
                lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP', lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP@.50IOU', lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP@.75IOU', lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP (small)', lines[i])


@unittest.skipIf(not torch.cuda.is_available()
                 or torch.cuda.device_count() <= 1, 'distributed unittest')
class EasyCVTrainerTestMultiGpus(DistributedTestCase):

    def setUp(self):
        self.logger = get_logger()
        self.logger.info(('Testing %s.%s' %
                          (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir, ignore_errors=True)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_multi_gpus(self):
        self.start(
            train_func,
            num_gpus=2,
            work_dir=self.tmp_dir,
            dist=True,
            log_config=2,
            imgs_per_gpu=5)

        results_files = os.listdir(self.tmp_dir)
        json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
        self.assertEqual(len(json_files), 1)

        with open(json_files[0], 'r') as f:
            lines = [i.strip() for i in f.readlines()]

        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 2,
                LogKeys.LR: 0.0002
            }, json.loads(lines[0]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 5
            }, json.loads(lines[1]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 2,
                LogKeys.LR: 0.0018
            }, json.loads(lines[2]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 5
            }, json.loads(lines[3]))

        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)

        for i in [0, 2]:
            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
            self.assertIn(LogKeys.ITER_TIME, lines[i])
            self.assertIn(LogKeys.MEMORY, lines[i])
            self.assertIn('total_loss', lines[i])
        for i in [1, 3]:
            self.assertIn(
                'CocoDetectionEvaluator_DetectionBoxes_Precision/mAP',
                lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP', lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP@.50IOU', lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP@.75IOU', lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP (small)', lines[i])


if __name__ == '__main__':
    unittest.main()
99  tests/trainers/easycv/test_segformer.py  Normal file
@@ -0,0 +1,99 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
import shutil
import tempfile
import unittest

import requests
import torch

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.constant import LogKeys, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level
from modelscope.utils.torch_utils import is_master


def _download_data(url, save_dir):
    r = requests.get(url, verify=True)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    zip_name = os.path.split(url)[-1]
    save_path = os.path.join(save_dir, zip_name)
    with open(save_path, 'wb') as f:
        f.write(r.content)

    unpack_dir = os.path.join(save_dir, os.path.splitext(zip_name)[0])
    shutil.unpack_archive(save_path, unpack_dir)


@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
class EasyCVTrainerTestSegformer(unittest.TestCase):

    def setUp(self):
        self.logger = get_logger()
        self.logger.info(('Testing %s.%s' %
                          (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir, ignore_errors=True)

    def _train(self):
        from modelscope.trainers.easycv.trainer import EasyCVEpochBasedTrainer

        url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/datasets/small_coco_stuff164k.zip'
        data_dir = os.path.join(self.tmp_dir, 'data')
        if is_master():
            _download_data(url, data_dir)

        # adapt to ditributed mode
        from easycv.utils.test_util import pseudo_dist_init
        pseudo_dist_init()

        root_path = os.path.join(data_dir, 'small_coco_stuff164k')
        cfg_options = {
            'train.max_epochs':
            2,
            'dataset.train.data_source.img_root':
            os.path.join(root_path, 'train2017'),
            'dataset.train.data_source.label_root':
            os.path.join(root_path, 'annotations/train2017'),
            'dataset.train.data_source.split':
            os.path.join(root_path, 'train.txt'),
            'dataset.val.data_source.img_root':
            os.path.join(root_path, 'val2017'),
            'dataset.val.data_source.label_root':
            os.path.join(root_path, 'annotations/val2017'),
            'dataset.val.data_source.split':
            os.path.join(root_path, 'val.txt'),
        }

        trainer_name = Trainers.easycv
        kwargs = dict(
            task=Tasks.image_segmentation,
            model='EasyCV/EasyCV-Segformer-b0',
            work_dir=self.tmp_dir,
            cfg_options=cfg_options)

        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_single_gpu_segformer(self):
        self._train()

        results_files = os.listdir(self.tmp_dir)
        json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
        self.assertEqual(len(json_files), 1)
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)


if __name__ == '__main__':
    unittest.main()
@@ -4,6 +4,8 @@ import copy
import tempfile
import unittest

import json

from modelscope.utils.config import Config, check_config

obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}
@@ -43,7 +45,8 @@ class ConfigTest(unittest.TestCase):
            self.assertEqual(pretty_text, cfg.dump())
            cfg.dump(ofile.name)
            with open(ofile.name, 'r') as infile:
                self.assertEqual(json_str, infile.read())
                self.assertDictEqual(
                    json.loads(json_str), json.loads(infile.read()))

        with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile:
            cfg.dump(ofile.name)