From 0041ab0ab8928a362b3bcc293fd6289dd618d29a Mon Sep 17 00:00:00 2001 From: myf272609 Date: Mon, 19 Sep 2022 11:28:01 +0800 Subject: [PATCH] [to #42322933] add multi-style cartoon models to ut MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 卡通化接入多风格模型(原始日漫风、3D、手绘风、素描风、艺术特效风格),添加ut接入测试 2. 修改pipeline中模型文件名称至通用名 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10153717 --- .../pipelines/cv/image_cartoon_pipeline.py | 6 ++-- tests/pipelines/test_person_image_cartoon.py | 28 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/modelscope/pipelines/cv/image_cartoon_pipeline.py b/modelscope/pipelines/cv/image_cartoon_pipeline.py index eb669354..f34be618 100644 --- a/modelscope/pipelines/cv/image_cartoon_pipeline.py +++ b/modelscope/pipelines/cv/image_cartoon_pipeline.py @@ -40,11 +40,9 @@ class ImageCartoonPipeline(Pipeline): with device_placement(self.framework, self.device_name): self.facer = FaceAna(self.model) self.sess_anime_head = self.load_sess( - os.path.join(self.model, 'cartoon_anime_h.pb'), - 'model_anime_head') + os.path.join(self.model, 'cartoon_h.pb'), 'model_anime_head') self.sess_anime_bg = self.load_sess( - os.path.join(self.model, 'cartoon_anime_bg.pb'), - 'model_anime_bg') + os.path.join(self.model, 'cartoon_bg.pb'), 'model_anime_bg') self.box_width = 288 global_mask = cv2.imread(os.path.join(self.model, 'alpha.jpg')) diff --git a/tests/pipelines/test_person_image_cartoon.py b/tests/pipelines/test_person_image_cartoon.py index 5c81cd28..b8549f4f 100644 --- a/tests/pipelines/test_person_image_cartoon.py +++ b/tests/pipelines/test_person_image_cartoon.py @@ -16,6 +16,10 @@ class ImageCartoonTest(unittest.TestCase, DemoCompatibilityCheck): def setUp(self) -> None: self.model_id = 'damo/cv_unet_person-image-cartoon_compound-models' + self.model_id_3d = 'damo/cv_unet_person-image-cartoon-3d_compound-models' + self.model_id_handdrawn = 'damo/cv_unet_person-image-cartoon-handdrawn_compound-models' + self.model_id_sketch = 'damo/cv_unet_person-image-cartoon-sketch_compound-models' + self.model_id_artstyle = 'damo/cv_unet_person-image-cartoon-artstyle_compound-models' self.task = Tasks.image_portrait_stylization self.test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png' @@ -31,6 +35,30 @@ class ImageCartoonTest(unittest.TestCase, DemoCompatibilityCheck): Tasks.image_portrait_stylization, model=self.model_id) self.pipeline_inference(img_cartoon, self.test_image) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_3d(self): + img_cartoon = pipeline( + Tasks.image_portrait_stylization, model=self.model_id_3d) + self.pipeline_inference(img_cartoon, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_handdrawn(self): + img_cartoon = pipeline( + Tasks.image_portrait_stylization, model=self.model_id_handdrawn) + self.pipeline_inference(img_cartoon, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_sketch(self): + img_cartoon = pipeline( + Tasks.image_portrait_stylization, model=self.model_id_sketch) + self.pipeline_inference(img_cartoon, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_artstyle(self): + img_cartoon = pipeline( + Tasks.image_portrait_stylization, model=self.model_id_artstyle) + self.pipeline_inference(img_cartoon, self.test_image) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_modelhub_default_model(self): img_cartoon = pipeline(Tasks.image_portrait_stylization)