From 259bfb64c69210dc93c2939c866ce4934e035eb8 Mon Sep 17 00:00:00 2001 From: ljl191782 Date: Thu, 9 Feb 2023 10:13:32 +0000 Subject: [PATCH] [to #42322933] add universal_matting pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 上线通用抠图模型 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11505692 --- data/test/images/universal_matting.jpg | 3 ++ modelscope/metainfo.py | 3 ++ modelscope/outputs/outputs.py | 1 + .../pipelines/cv/image_matting_pipeline.py | 9 ++-- modelscope/utils/constant.py | 1 + tests/pipelines/test_universal_matting.py | 44 +++++++++++++++++++ 6 files changed, 58 insertions(+), 3 deletions(-) create mode 100644 data/test/images/universal_matting.jpg create mode 100644 tests/pipelines/test_universal_matting.py diff --git a/data/test/images/universal_matting.jpg b/data/test/images/universal_matting.jpg new file mode 100644 index 00000000..d824eb21 --- /dev/null +++ b/data/test/images/universal_matting.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:78d7bf999d1a4186309693ff1b966edb3ccd40f7861a7589167cf9e33897a693 +size 369725 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index a14b1375..95804ac7 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -207,6 +207,7 @@ class Pipelines(object): """ # vision tasks portrait_matting = 'unet-image-matting' + universal_matting = 'unet-universal-matting' image_denoise = 'nafnet-image-denoise' image_deblur = 'nafnet-image-deblur' person_image_cartoon = 'unet-person-image-cartoon' @@ -461,6 +462,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { ), # TODO: revise back after passing the pr Tasks.portrait_matting: (Pipelines.portrait_matting, 'damo/cv_unet_image-matting'), + Tasks.universal_matting: (Pipelines.universal_matting, + 'damo/cv_unet_universal-matting'), Tasks.human_detection: (Pipelines.human_detection, 'damo/cv_resnet18_human-detection'), Tasks.image_object_detection: (Pipelines.object_detection, diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py index 1cff5de9..322e8c62 100644 --- a/modelscope/outputs/outputs.py +++ b/modelscope/outputs/outputs.py @@ -314,6 +314,7 @@ TASK_OUTPUTS = { # , shape(h, w) for crowd counting # } Tasks.portrait_matting: [OutputKeys.OUTPUT_IMG], + Tasks.universal_matting: [OutputKeys.OUTPUT_IMG], # image_quality_assessment_mos result for a single image is a score in range [0, 1] # {0.5} diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index fb5d8f8b..5f5d1d56 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -4,6 +4,7 @@ from typing import Any, Dict import cv2 import numpy as np +import tensorflow as tf from modelscope.metainfo import Pipelines from modelscope.outputs import OutputKeys @@ -14,11 +15,16 @@ from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.device import device_placement from modelscope.utils.logger import get_logger +if tf.__version__ >= '2.0': + tf = tf.compat.v1 + logger = get_logger() @PIPELINES.register_module( Tasks.portrait_matting, module_name=Pipelines.portrait_matting) +@PIPELINES.register_module( + Tasks.universal_matting, module_name=Pipelines.universal_matting) class ImageMattingPipeline(Pipeline): def __init__(self, model: str, **kwargs): @@ -28,9 +34,6 @@ class ImageMattingPipeline(Pipeline): model: model id on modelscope hub. """ super().__init__(model=model, **kwargs) - import tensorflow as tf - if tf.__version__ >= '2.0': - tf = tf.compat.v1 model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE) with device_placement(self.framework, self.device_name): diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 38944d81..5de328f9 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -55,6 +55,7 @@ class CVTasks(object): video_depth_estimation = 'video-depth-estimation' panorama_depth_estimation = 'panorama-depth-estimation' portrait_matting = 'portrait-matting' + universal_matting = 'universal-matting' text_driven_segmentation = 'text-driven-segmentation' shop_segmentation = 'shop-segmentation' hand_static = 'hand-static' diff --git a/tests/pipelines/test_universal_matting.py b/tests/pipelines/test_universal_matting.py new file mode 100644 index 00000000..5868cf36 --- /dev/null +++ b/tests/pipelines/test_universal_matting.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.msdatasets import MsDataset +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class UniversalMattingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.model_id = 'damo/cv_unet_universal-matting' + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_dataset(self): + input_location = ['data/test/images/universal_matting.jpg'] + + dataset = MsDataset.load(input_location, target='image') + img_matting = pipeline(Tasks.universal_matting, model=self.model_id) + result = img_matting(dataset) + cv2.imwrite('result.png', next(result)[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + img_matting = pipeline(Tasks.universal_matting, model=self.model_id) + + result = img_matting('data/test/images/universal_matting.jpg') + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main()