From 4cb5f8a2cd104f89b765d56527d448b2df1be151 Mon Sep 17 00:00:00 2001 From: "shouzhou.bx" Date: Wed, 12 Oct 2022 19:53:14 +0800 Subject: [PATCH] [to #42322933] add human whole body model and image object detection auto model Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10319306 --- data/test/images/auto_demo.jpg | 3 + .../body_keypoints_detection.jpg | 3 - .../keypoints_detect/img_test_wholebody.jpg | 3 + modelscope/metainfo.py | 5 ++ modelscope/models/cv/__init__.py | 20 +++--- .../cv/human_wholebody_keypoint/__init__.py | 22 +++++++ .../human_wholebody_keypoint.py | 17 +++++ .../models/cv/object_detection/__init__.py | 2 +- .../models/cv/object_detection/yolox_pai.py | 3 + .../cv/human_wholebody_keypoint/__init__.py | 22 +++++++ .../human_wholebody_keypoint_dataset.py | 39 +++++++++++ modelscope/outputs.py | 19 +++++- modelscope/pipelines/builder.py | 8 ++- modelscope/pipelines/cv/__init__.py | 11 +++- .../cv/body_2d_keypoints_pipeline.py | 4 +- .../cv/body_3d_keypoints_pipeline.py | 2 +- .../pipelines/cv/easycv_pipelines/__init__.py | 5 +- .../cv/easycv_pipelines/detection_pipeline.py | 41 +++++++++++- .../human_wholebody_keypoint_pipeline.py | 65 +++++++++++++++++++ modelscope/utils/constant.py | 1 + modelscope/utils/cv/image_utils.py | 34 +++++++++- .../test_human_wholebody_keypoint.py | 40 ++++++++++++ tests/pipelines/test_object_detection.py | 12 ++++ 23 files changed, 353 insertions(+), 28 deletions(-) create mode 100644 data/test/images/auto_demo.jpg delete mode 100644 data/test/images/keypoints_detect/body_keypoints_detection.jpg create mode 100644 data/test/images/keypoints_detect/img_test_wholebody.jpg create mode 100644 modelscope/models/cv/human_wholebody_keypoint/__init__.py create mode 100644 modelscope/models/cv/human_wholebody_keypoint/human_wholebody_keypoint.py create mode 100644 modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py create mode 100644 modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py create mode 100644 modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py create mode 100644 tests/pipelines/test_human_wholebody_keypoint.py diff --git a/data/test/images/auto_demo.jpg b/data/test/images/auto_demo.jpg new file mode 100644 index 00000000..30393e53 --- /dev/null +++ b/data/test/images/auto_demo.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76bf84536edbaf192a8a699efc62ba2b06056bac12c426ecfcc2e003d91fbd32 +size 53219 diff --git a/data/test/images/keypoints_detect/body_keypoints_detection.jpg b/data/test/images/keypoints_detect/body_keypoints_detection.jpg deleted file mode 100644 index 71ce7d7e..00000000 --- a/data/test/images/keypoints_detect/body_keypoints_detection.jpg +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:379e11d7fc3734d3ec95afd0d86460b4653fbf4bb1f57f993610d6a6fd30fd3d -size 1702339 diff --git a/data/test/images/keypoints_detect/img_test_wholebody.jpg b/data/test/images/keypoints_detect/img_test_wholebody.jpg new file mode 100644 index 00000000..40a9f3f8 --- /dev/null +++ b/data/test/images/keypoints_detect/img_test_wholebody.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dec0fbb931cb609bf481e56b89cd2fbbab79839f22832c3bbe69a8fae2769cdd +size 167407 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index cae9d188..759f1688 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -40,6 +40,7 @@ class Models(object): mtcnn = 'mtcnn' ulfd = 'ulfd' video_inpainting = 'video-inpainting' + human_wholebody_keypoint = 'human-wholebody-keypoint' hand_static = 'hand-static' face_human_hand_detection = 'face-human-hand-detection' face_emotion = 'face-emotion' @@ -49,6 +50,7 @@ class Models(object): # EasyCV models yolox = 'YOLOX' segformer = 'Segformer' + image_object_detection_auto = 'image-object-detection-auto' # nlp models bert = 'bert' @@ -170,6 +172,7 @@ class Pipelines(object): ocr_recognition = 'convnextTiny-ocr-recognition' image_portrait_enhancement = 'gpen-image-portrait-enhancement' image_to_image_generation = 'image-to-image-generation' + image_object_detection_auto = 'yolox_image-object-detection-auto' skin_retouching = 'unet-skin-retouching' tinynas_classification = 'tinynas-classification' tinynas_detection = 'tinynas-detection' @@ -185,6 +188,7 @@ class Pipelines(object): movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' shop_segmentation = 'shop-segmentation' video_inpainting = 'video-inpainting' + human_wholebody_keypoint = 'hrnetw48_human-wholebody-keypoint_image' pst_action_recognition = 'patchshift-action-recognition' hand_static = 'hand-static' face_human_hand_detection = 'face-human-hand-detection' @@ -427,6 +431,7 @@ class Datasets(object): """ ClsDataset = 'ClsDataset' Face2dKeypointsDataset = 'Face2dKeypointsDataset' + HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset' SegDataset = 'SegDataset' DetDataset = 'DetDataset' DetImagesMixDataset = 'DetImagesMixDataset' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index ba7b03c5..fd950f4c 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -4,15 +4,15 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, body_3d_keypoints, cartoon, cmdssl_video_embedding, crowd_counting, face_2d_keypoints, face_detection, - face_generation, image_classification, image_color_enhance, - image_colorization, image_denoise, image_inpainting, - image_instance_segmentation, image_panoptic_segmentation, - image_portrait_enhancement, image_reid_person, - image_semantic_segmentation, image_to_image_generation, - image_to_image_translation, movie_scene_segmentation, - object_detection, product_retrieval_embedding, - realtime_object_detection, salient_detection, shop_segmentation, - super_resolution, video_single_object_tracking, - video_summarization, virual_tryon) + face_generation, human_wholebody_keypoint, image_classification, + image_color_enhance, image_colorization, image_denoise, + image_inpainting, image_instance_segmentation, + image_panoptic_segmentation, image_portrait_enhancement, + image_reid_person, image_semantic_segmentation, + image_to_image_generation, image_to_image_translation, + movie_scene_segmentation, object_detection, + product_retrieval_embedding, realtime_object_detection, + salient_detection, shop_segmentation, super_resolution, + video_single_object_tracking, video_summarization, virual_tryon) # yapf: enable diff --git a/modelscope/models/cv/human_wholebody_keypoint/__init__.py b/modelscope/models/cv/human_wholebody_keypoint/__init__.py new file mode 100644 index 00000000..30e23457 --- /dev/null +++ b/modelscope/models/cv/human_wholebody_keypoint/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .human_wholebody_keypoint import HumanWholeBodyKeypoint + +else: + _import_structure = { + 'human_wholebody_keypoint': ['HumanWholeBodyKeypoint'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/human_wholebody_keypoint/human_wholebody_keypoint.py b/modelscope/models/cv/human_wholebody_keypoint/human_wholebody_keypoint.py new file mode 100644 index 00000000..dd3c0290 --- /dev/null +++ b/modelscope/models/cv/human_wholebody_keypoint/human_wholebody_keypoint.py @@ -0,0 +1,17 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.models.pose.top_down import TopDown + +from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.models.cv.easycv_base import EasyCVBaseModel +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + group_key=Tasks.human_wholebody_keypoint, + module_name=Models.human_wholebody_keypoint) +class HumanWholeBodyKeypoint(EasyCVBaseModel, TopDown): + + def __init__(self, model_dir=None, *args, **kwargs): + EasyCVBaseModel.__init__(self, model_dir, args, kwargs) + TopDown.__init__(self, *args, **kwargs) diff --git a/modelscope/models/cv/object_detection/__init__.py b/modelscope/models/cv/object_detection/__init__.py index 974375ce..0c782d7b 100644 --- a/modelscope/models/cv/object_detection/__init__.py +++ b/modelscope/models/cv/object_detection/__init__.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: else: _import_structure = { 'mmdet_model': ['DetectionModel'], - 'yolox_pai': ['YOLOX'] + 'yolox_pai': ['YOLOX'], } import sys diff --git a/modelscope/models/cv/object_detection/yolox_pai.py b/modelscope/models/cv/object_detection/yolox_pai.py index 985cc136..46bd4e3c 100644 --- a/modelscope/models/cv/object_detection/yolox_pai.py +++ b/modelscope/models/cv/object_detection/yolox_pai.py @@ -9,6 +9,9 @@ from modelscope.utils.constant import Tasks @MODELS.register_module( group_key=Tasks.image_object_detection, module_name=Models.yolox) +@MODELS.register_module( + group_key=Tasks.image_object_detection, + module_name=Models.image_object_detection_auto) class YOLOX(EasyCVBaseModel, _YOLOX): def __init__(self, model_dir=None, *args, **kwargs): diff --git a/modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py b/modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py new file mode 100644 index 00000000..472ed2d8 --- /dev/null +++ b/modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .human_wholebody_keypoint_dataset import WholeBodyCocoTopDownDataset + +else: + _import_structure = { + 'human_wholebody_keypoint_dataset': ['WholeBodyCocoTopDownDataset'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py b/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py new file mode 100644 index 00000000..fc9469f2 --- /dev/null +++ b/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from easycv.datasets.pose import \ + WholeBodyCocoTopDownDataset as _WholeBodyCocoTopDownDataset + +from modelscope.metainfo import Datasets +from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.utils.constant import Tasks + + +@TASK_DATASETS.register_module( + group_key=Tasks.human_wholebody_keypoint, + module_name=Datasets.HumanWholeBodyKeypointDataset) +class WholeBodyCocoTopDownDataset(EasyCVBaseDataset, + _WholeBodyCocoTopDownDataset): + """EasyCV dataset for human whole body 2d keypoints. + + Args: + split_config (dict): Dataset root path from MSDataset, e.g. + {"train":"local cache path"} or {"evaluation":"local cache path"}. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. Not support yet. + mode: Training or Evaluation. + """ + + def __init__(self, + split_config=None, + preprocessor=None, + mode=None, + *args, + **kwargs) -> None: + EasyCVBaseDataset.__init__( + self, + split_config=split_config, + preprocessor=preprocessor, + mode=mode, + args=args, + kwargs=kwargs) + _WholeBodyCocoTopDownDataset.__init__(self, *args, **kwargs) diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 0f353d3d..ab3ea54a 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -203,7 +203,7 @@ TASK_OUTPUTS = { # human body keypoints detection result for single sample # { - # "poses": [ + # "keypoints": [ # [[x, y]*15], # [[x, y]*15], # [[x, y]*15] @@ -220,7 +220,7 @@ TASK_OUTPUTS = { # ] # } Tasks.body_2d_keypoints: - [OutputKeys.POSES, OutputKeys.SCORES, OutputKeys.BOXES], + [OutputKeys.KEYPOINTS, OutputKeys.SCORES, OutputKeys.BOXES], # 3D human body keypoints detection result for single sample # { @@ -339,6 +339,21 @@ TASK_OUTPUTS = { OutputKeys.SCENE_META_LIST ], + # human whole body keypoints detection result for single sample + # { + # "keypoints": [ + # [[x, y]*133], + # [[x, y]*133], + # [[x, y]*133] + # ] + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ] + # } + Tasks.human_wholebody_keypoint: [OutputKeys.KEYPOINTS, OutputKeys.BOXES], + # video summarization result for a single video # { # "output": diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 1f563915..bc9073bc 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -75,8 +75,6 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/nlp_bart_text-error-correction_chinese'), Tasks.image_captioning: (Pipelines.image_captioning, 'damo/ofa_image-caption_coco_large_en'), - Tasks.image_body_reshaping: (Pipelines.image_body_reshaping, - 'damo/cv_flow-based-body-reshaping_damo'), Tasks.image_portrait_stylization: (Pipelines.person_image_cartoon, 'damo/cv_unet_person-image-cartoon_compound-models'), @@ -159,6 +157,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.image_classification: (Pipelines.daily_image_classification, 'damo/cv_vit-base_image-classification_Dailylife-labels'), + Tasks.image_object_detection: + (Pipelines.image_object_detection_auto, + 'damo/cv_yolox_image-object-detection-auto'), Tasks.ocr_recognition: (Pipelines.ocr_recognition, 'damo/cv_convnextTiny_ocr-recognition-general_damo'), @@ -186,6 +187,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_fft_inpainting_lama'), Tasks.video_inpainting: (Pipelines.video_inpainting, 'damo/cv_video-inpainting'), + Tasks.human_wholebody_keypoint: + (Pipelines.human_wholebody_keypoint, + 'damo/cv_hrnetw48_human-wholebody-keypoint_image'), Tasks.hand_static: (Pipelines.hand_static, 'damo/cv_mobileface_hand-static'), Tasks.face_human_hand_detection: diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 118eaf17..f84f5fe5 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -46,7 +46,10 @@ if TYPE_CHECKING: from .video_category_pipeline import VideoCategoryPipeline from .virtual_try_on_pipeline import VirtualTryonPipeline from .shop_segmentation_pipleline import ShopSegmentationPipeline - from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline, Face2DKeypointsPipeline + from .easycv_pipelines import (EasyCVDetectionPipeline, + EasyCVSegmentationPipeline, + Face2DKeypointsPipeline, + HumanWholebodyKeypointsPipeline) from .text_driven_segmentation_pipleline import TextDrivenSegmentationPipeline from .movie_scene_segmentation_pipeline import MovieSceneSegmentationPipeline from .mog_face_detection_pipeline import MogFaceDetectionPipeline @@ -109,8 +112,10 @@ else: 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], 'shop_segmentation_pipleline': ['ShopSegmentationPipeline'], 'easycv_pipeline': [ - 'EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline', - 'Face2DKeypointsPipeline' + 'EasyCVDetectionPipeline', + 'EasyCVSegmentationPipeline', + 'Face2DKeypointsPipeline', + 'HumanWholebodyKeypointsPipeline', ], 'text_driven_segmentation_pipeline': ['TextDrivenSegmentationPipeline'], diff --git a/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py index d6afbae4..bc2e975d 100644 --- a/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py +++ b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py @@ -73,7 +73,7 @@ class Body2DKeypointsPipeline(Pipeline): if input[0] is None or input[1] is None: return { OutputKeys.BOXES: [], - OutputKeys.POSES: [], + OutputKeys.KEYPOINTS: [], OutputKeys.SCORES: [] } @@ -83,7 +83,7 @@ class Body2DKeypointsPipeline(Pipeline): result_boxes.append([box[0][0], box[0][1], box[1][0], box[1][1]]) return { OutputKeys.BOXES: result_boxes, - OutputKeys.POSES: poses, + OutputKeys.KEYPOINTS: poses, OutputKeys.SCORES: scores } diff --git a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py index c3f4e8c1..3502915c 100644 --- a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py +++ b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py @@ -145,7 +145,7 @@ class Body3DKeypointsPipeline(Pipeline): kps_2d = self.human_body_2d_kps_detector(frame) box = kps_2d['boxes'][ 0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox - pose = kps_2d['poses'][0] # keypoints: [15, 2] + pose = kps_2d['keypoints'][0] # keypoints: [15, 2] score = kps_2d['scores'][0] # keypoints: [15, 2] all_2d_poses.append(pose) all_boxes_with_socre.append( diff --git a/modelscope/pipelines/cv/easycv_pipelines/__init__.py b/modelscope/pipelines/cv/easycv_pipelines/__init__.py index 4f149130..e0209b85 100644 --- a/modelscope/pipelines/cv/easycv_pipelines/__init__.py +++ b/modelscope/pipelines/cv/easycv_pipelines/__init__.py @@ -7,11 +7,14 @@ if TYPE_CHECKING: from .detection_pipeline import EasyCVDetectionPipeline from .segmentation_pipeline import EasyCVSegmentationPipeline from .face_2d_keypoints_pipeline import Face2DKeypointsPipeline + from .human_wholebody_keypoint_pipeline import HumanWholebodyKeypointsPipeline else: _import_structure = { 'detection_pipeline': ['EasyCVDetectionPipeline'], 'segmentation_pipeline': ['EasyCVSegmentationPipeline'], - 'face_2d_keypoints_pipeline': ['Face2DKeypointsPipeline'] + 'face_2d_keypoints_pipeline': ['Face2DKeypointsPipeline'], + 'human_wholebody_keypoint_pipeline': + ['HumanWholebodyKeypointsPipeline'], } import sys diff --git a/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py index 32365102..0c2058d5 100644 --- a/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py +++ b/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py @@ -1,16 +1,28 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any + from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys from modelscope.pipelines.builder import PIPELINES -from modelscope.utils.constant import Tasks +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.cv.image_utils import \ + show_image_object_detection_auto_result from .base import EasyCVPipeline @PIPELINES.register_module( Tasks.image_object_detection, module_name=Pipelines.easycv_detection) +@PIPELINES.register_module( + Tasks.image_object_detection, + module_name=Pipelines.image_object_detection_auto) class EasyCVDetectionPipeline(EasyCVPipeline): """Pipeline for easycv detection task.""" - def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs): + def __init__(self, + model: str, + model_file_pattern=ModelFile.TORCH_MODEL_FILE, + *args, + **kwargs): """ model (str): model id on modelscope hub or local model path. model_file_pattern (str): model file pattern. @@ -21,3 +33,28 @@ class EasyCVDetectionPipeline(EasyCVPipeline): model_file_pattern=model_file_pattern, *args, **kwargs) + + def show_result(self, img_path, result, save_path=None): + show_image_object_detection_auto_result(img_path, result, save_path) + + def __call__(self, inputs) -> Any: + outputs = self.predict_op(inputs) + + scores = [] + labels = [] + boxes = [] + for output in outputs: + for score, label, box in zip(output['detection_scores'], + output['detection_classes'], + output['detection_boxes']): + scores.append(score) + labels.append(self.cfg.CLASSES[label]) + boxes.append([b for b in box]) + + results = [{ + OutputKeys.SCORES: scores, + OutputKeys.LABELS: labels, + OutputKeys.BOXES: boxes + } for output in outputs] + + return results diff --git a/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py new file mode 100644 index 00000000..263f8225 --- /dev/null +++ b/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path +from typing import Any + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import ModelFile, Tasks +from .base import EasyCVPipeline + + +@PIPELINES.register_module( + Tasks.human_wholebody_keypoint, + module_name=Pipelines.human_wholebody_keypoint) +class HumanWholebodyKeypointsPipeline(EasyCVPipeline): + """Pipeline for human wholebody 2d keypoints detection.""" + + def __init__(self, + model: str, + model_file_pattern=ModelFile.TORCH_MODEL_FILE, + *args, + **kwargs): + """ + model (str): model id on modelscope hub or local model path. + model_file_pattern (str): model file pattern. + """ + self.model_dir = model + super(HumanWholebodyKeypointsPipeline, self).__init__( + model=model, + model_file_pattern=model_file_pattern, + *args, + **kwargs) + + def _build_predict_op(self, **kwargs): + """Build EasyCV predictor.""" + from easycv.predictors.builder import build_predictor + detection_predictor_type = self.cfg['DETECTION']['type'] + detection_model_path = os.path.join( + self.model_dir, self.cfg['DETECTION']['model_path']) + detection_cfg_file = os.path.join(self.model_dir, + self.cfg['DETECTION']['config_file']) + detection_score_threshold = self.cfg['DETECTION']['score_threshold'] + self.cfg.pipeline.predictor_config[ + 'detection_predictor_config'] = dict( + type=detection_predictor_type, + model_path=detection_model_path, + config_file=detection_cfg_file, + score_threshold=detection_score_threshold) + easycv_config = self._to_easycv_config() + pipeline_op = build_predictor(self.cfg.pipeline.predictor_config, { + 'model_path': self.model_path, + 'config_file': easycv_config, + **kwargs + }) + return pipeline_op + + def __call__(self, inputs) -> Any: + outputs = self.predict_op(inputs) + + results = [{ + OutputKeys.KEYPOINTS: output['keypoints'], + OutputKeys.BOXES: output['boxes'] + } for output in outputs] + + return results diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 2a5ac694..4fa3d766 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -29,6 +29,7 @@ class CVTasks(object): body_3d_keypoints = 'body-3d-keypoints' hand_2d_keypoints = 'hand-2d-keypoints' general_recognition = 'general-recognition' + human_wholebody_keypoint = 'human-wholebody-keypoint' image_classification = 'image-classification' image_multilabel_classification = 'image-multilabel-classification' diff --git a/modelscope/utils/cv/image_utils.py b/modelscope/utils/cv/image_utils.py index eab74688..06a9bbaa 100644 --- a/modelscope/utils/cv/image_utils.py +++ b/modelscope/utils/cv/image_utils.py @@ -80,7 +80,7 @@ def realtime_object_detection_bbox_vis(image, bboxes): def draw_keypoints(output, original_image): - poses = np.array(output[OutputKeys.POSES]) + poses = np.array(output[OutputKeys.KEYPOINTS]) scores = np.array(output[OutputKeys.SCORES]) boxes = np.array(output[OutputKeys.BOXES]) assert len(poses) == len(scores) and len(poses) == len(boxes) @@ -234,3 +234,35 @@ def show_video_summarization_result(video_in_path, result, video_save_path): video_writer.write(frame) video_writer.release() cap.release() + + +def show_image_object_detection_auto_result(img_path, + detection_result, + save_path=None): + scores = detection_result[OutputKeys.SCORES] + labels = detection_result[OutputKeys.LABELS] + bboxes = detection_result[OutputKeys.BOXES] + img = cv2.imread(img_path) + assert img is not None, f"Can't read img: {img_path}" + + for (score, label, box) in zip(scores, labels, bboxes): + cv2.rectangle(img, (int(box[0]), int(box[1])), + (int(box[2]), int(box[3])), (0, 0, 255), 2) + cv2.putText( + img, + f'{score:.2f}', (int(box[0]), int(box[1])), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + cv2.putText( + img, + label, (int((box[0] + box[2]) * 0.5), int(box[1])), + 1, + 1.0, (0, 255, 0), + thickness=1, + lineType=8) + + if save_path is not None: + cv2.imwrite(save_path, img) + return img diff --git a/tests/pipelines/test_human_wholebody_keypoint.py b/tests/pipelines/test_human_wholebody_keypoint.py new file mode 100644 index 00000000..b214f4e1 --- /dev/null +++ b/tests/pipelines/test_human_wholebody_keypoint.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_human_wholebody_keypoint(self): + img_path = 'data/test/images/keypoints_detect/img_test_wholebody.jpg' + model_id = 'damo/cv_hrnetw48_human-wholebody-keypoint_image' + + human_wholebody_keypoint_pipeline = pipeline( + task=Tasks.human_wholebody_keypoint, model=model_id) + output = human_wholebody_keypoint_pipeline(img_path)[0] + + output_keypoints = output[OutputKeys.KEYPOINTS] + output_pose = output[OutputKeys.BOXES] + + human_wholebody_keypoint_pipeline.predict_op.show_result( + img_path, + output_keypoints, + output_pose, + scale=1, + save_path='human_wholebody_keypoint_ret.jpg') + + for keypoint in output_keypoints: + self.assertEqual(keypoint.shape[0], 133) + for box in output_pose: + self.assertEqual(box.shape[0], 4) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_object_detection.py b/tests/pipelines/test_object_detection.py index 2a74eb41..2cb217d9 100644 --- a/tests/pipelines/test_object_detection.py +++ b/tests/pipelines/test_object_detection.py @@ -59,6 +59,18 @@ class ObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): def test_demo_compatibility(self): self.compatibility_check() + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_object_detection_auto_pipeline(self): + model_id = 'damo/cv_yolox_image-object-detection-auto' + test_image = 'data/test/images/auto_demo.jpg' + + image_object_detection_auto = pipeline( + Tasks.image_object_detection, model=model_id) + + result = image_object_detection_auto(test_image)[0] + image_object_detection_auto.show_result(test_image, result, + 'auto_demo_ret.jpg') + if __name__ == '__main__': unittest.main()