From b2c5876eadc3af6fbb5049fbe907dd99cca08d2a Mon Sep 17 00:00:00 2001 From: "jiangnana.jnn" Date: Thu, 20 Oct 2022 19:31:53 +0800 Subject: [PATCH 1/3] [to #42322933]fix create dir when dist Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10459756 --- modelscope/trainers/hooks/logger/text_logger_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelscope/trainers/hooks/logger/text_logger_hook.py b/modelscope/trainers/hooks/logger/text_logger_hook.py index 6629a0c9..8552ab4e 100644 --- a/modelscope/trainers/hooks/logger/text_logger_hook.py +++ b/modelscope/trainers/hooks/logger/text_logger_hook.py @@ -51,7 +51,7 @@ class TextLoggerHook(LoggerHook): if self.out_dir is None: self.out_dir = trainer.work_dir - if not osp.exists(self.out_dir): + if not osp.exists(self.out_dir) and is_master(): os.makedirs(self.out_dir) trainer.logger.info('Text logs will be saved to {}'.format( From de6d84cb9781181dc49b8162fb1f5a23fe4c8993 Mon Sep 17 00:00:00 2001 From: "hanyuan.chy" Date: Thu, 20 Oct 2022 19:33:06 +0800 Subject: [PATCH 2/3] =?UTF-8?q?[to=20#42322933]=E4=BF=AE=E5=A4=8Dpipeline?= =?UTF-8?q?=E4=B8=B2=E8=81=94=E6=97=B6collate=5Ffn=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复pipeline串联时collate_fn异常 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10457058 --- modelscope/pipelines/base.py | 2 ++ .../cv/body_3d_keypoints_pipeline.py | 21 ++++++++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index ea329be4..644749fc 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -433,6 +433,8 @@ def collate_fn(data, device): if isinstance(data, dict) or isinstance(data, Mapping): return type(data)({k: collate_fn(v, device) for k, v in data.items()}) elif isinstance(data, (tuple, list)): + if 0 == len(data): + return torch.Tensor([]) if isinstance(data[0], (int, float)): return default_collate(data).to(device) else: diff --git a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py index 3502915c..8522ceff 100644 --- a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py +++ b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py @@ -143,6 +143,13 @@ class Body3DKeypointsPipeline(Pipeline): max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME # max video frame number to be predicted 3D joints for i, frame in enumerate(video_frames): kps_2d = self.human_body_2d_kps_detector(frame) + if [] == kps_2d.get('boxes'): + res = { + 'success': False, + 'msg': f'fail to detect person at image frame {i}' + } + return res + box = kps_2d['boxes'][ 0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox pose = kps_2d['keypoints'][0] # keypoints: [15, 2] @@ -180,7 +187,15 @@ class Body3DKeypointsPipeline(Pipeline): return res def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: - res = {OutputKeys.KEYPOINTS: [], OutputKeys.TIMESTAMPS: []} + output_video_path = kwargs.get('output_video', None) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name + + res = { + OutputKeys.KEYPOINTS: [], + OutputKeys.TIMESTAMPS: [], + OutputKeys.OUTPUT_VIDEO: output_video_path + } if not input['success']: pass @@ -189,10 +204,6 @@ class Body3DKeypointsPipeline(Pipeline): pred_3d_pose = poses.data.cpu().numpy()[ 0] # [frame_num, joint_num, joint_dim] - output_video_path = kwargs.get('output_video', None) - if output_video_path is None: - output_video_path = tempfile.NamedTemporaryFile( - suffix='.mp4').name if 'render' in self.keypoint_model_3d.cfg.keys(): self.render_prediction(pred_3d_pose, output_video_path) res[OutputKeys.OUTPUT_VIDEO] = output_video_path From 2b49b322a2b452b96413fe70c678c78be7b5b61a Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Thu, 20 Oct 2022 19:50:40 +0800 Subject: [PATCH 3/3] [to #42322933] Add palm ut MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为以下三个模型补充 ut damo/nlp_palm2.0_text-generation_chinese-large damo/nlp_palm2.0_text-generation_commodity_chinese-base damo/nlp_palm2.0_text-generation_weather_chinese-base Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10435599 --- tests/pipelines/test_text_generation.py | 50 +++++++++++++++++++++---- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py index 5a270f83..4b0ebd47 100644 --- a/tests/pipelines/test_text_generation.py +++ b/tests/pipelines/test_text_generation.py @@ -15,12 +15,17 @@ from modelscope.utils.test_utils import test_level class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): def setUp(self) -> None: - self.palm_model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base' + self.palm_model_id_zh_base = 'damo/nlp_palm2.0_text-generation_chinese-base' + self.palm_model_id_zh_large = 'damo/nlp_palm2.0_text-generation_chinese-large' + self.palm_model_id_zh_commodity = 'damo/nlp_palm2.0_text-generation_commodity_chinese-base' + self.palm_model_id_zh_weather = 'damo/nlp_palm2.0_text-generation_weather_chinese-base' self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base' self.palm_input_zh = """ 本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方: 1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代 """ + self.palm_input_commodity = '垃圾桶,双层,可拆卸,加高,加高双层,把手,垃圾桶,内附,万向轮' + self.palm_input_weather = "今日天气类型='浮尘'&空气质量等级='重度污染'&紫外线强度指数='中等'" self.palm_input_en = """ The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders , @@ -51,8 +56,8 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): print(pipeline_ins(input)) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_palm_zh_with_model_name(self): - self.run_pipeline_with_model_id(self.palm_model_id_zh, + def test_palm_zh_base_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_base, self.palm_input_zh) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -71,10 +76,40 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): self.gpt3_input) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_palm_zh_with_model_instance(self): - self.run_pipeline_with_model_instance(self.palm_model_id_zh, + def test_palm_zh_large_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_large, + self.palm_input_zh) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_commodity_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_commodity, + self.palm_input_commodity) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_weather_with_model_name(self): + self.run_pipeline_with_model_id(self.palm_model_id_zh_weather, + self.palm_input_weather) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_base_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_base, self.palm_input_zh) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_large_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_large, + self.palm_input_zh) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_commodity_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_commodity, + self.palm_input_commodity) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_palm_zh_weather_with_model_instance(self): + self.run_pipeline_with_model_instance(self.palm_model_id_zh_weather, + self.palm_input_weather) + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_palm_en_with_model_instance(self): self.run_pipeline_with_model_instance(self.palm_model_id_en, @@ -92,8 +127,9 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_palm(self): - for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh), - (self.palm_model_id_en, self.palm_input_en)): + for model_id, input in ((self.palm_model_id_zh_base, + self.palm_input_zh), (self.palm_model_id_en, + self.palm_input_en)): cache_path = snapshot_download(model_id) model = PalmForTextGeneration.from_pretrained(cache_path) preprocessor = TextGenerationPreprocessor(