From 97e35a040ce831cd6980fe4088c02d6ee6565623 Mon Sep 17 00:00:00 2001
From: "maojialiang.mjl"
Date: Fri, 24 Feb 2023 09:49:02 +0800
Subject: [PATCH] [to #42322933] fix: remove restriction on image resolution in
 the preprocessing stage
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Remove the restriction of resizing to 720x1280 in the preprocessing stage, and
support outputting results that match the resolution of the original input
image.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11762135
---
 .../image_driving_percetion_model.py              |  1 +
 .../cv/image_driving_perception/preprocessor.py   |  3 ++-
 .../models/cv/image_driving_perception/utils.py   | 15 +++++----------
 .../cv/image_driving_perception_pipeline.py       | 12 ++++++++----
 modelscope/utils/cv/image_utils.py                |  2 --
 5 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/modelscope/models/cv/image_driving_perception/image_driving_percetion_model.py b/modelscope/models/cv/image_driving_perception/image_driving_percetion_model.py
index b7de37e7..e29ad2b9 100644
--- a/modelscope/models/cv/image_driving_perception/image_driving_percetion_model.py
+++ b/modelscope/models/cv/image_driving_perception/image_driving_percetion_model.py
@@ -37,6 +37,7 @@ class YOLOPv2(TorchModel):
         [pred, anchor_grid], seg, ll = self.model(img)
         return {
             'img_hw': data['img'].shape[2:],
+            'ori_img_shape': data['ori_img_shape'],
             'pred': pred,
             'anchor_grid': anchor_grid,
             'driving_area_mask': seg,
diff --git a/modelscope/models/cv/image_driving_perception/preprocessor.py b/modelscope/models/cv/image_driving_perception/preprocessor.py
index dbb4f761..3e0e476f 100644
--- a/modelscope/models/cv/image_driving_perception/preprocessor.py
+++ b/modelscope/models/cv/image_driving_perception/preprocessor.py
@@ -102,7 +102,7 @@ class ImageDrivingPerceptionPreprocessor(Preprocessor):
             img = self._check_image(img)
         else:
             raise Exception('img is None')
-        img = cv2.resize(img, output_shape, interpolation=cv2.INTER_LINEAR)
+        ori_h, ori_w, _ = img.shape
         img = self._letterbox(img, new_shape)[0]

         img = img.transpose(2, 0, 1)  # to 3x640x640
@@ -117,4 +117,5 @@ class ImageDrivingPerceptionPreprocessor(Preprocessor):

         return {
             'img': img,
+            'ori_img_shape': (ori_h, ori_w),
         }
diff --git a/modelscope/models/cv/image_driving_perception/utils.py b/modelscope/models/cv/image_driving_perception/utils.py
index 82f16ed6..15b228a2 100644
--- a/modelscope/models/cv/image_driving_perception/utils.py
+++ b/modelscope/models/cv/image_driving_perception/utils.py
@@ -29,10 +29,7 @@ def split_for_trace_model(pred=None, anchor_grid=None):
     return pred


-def scale_coords(img1_shape,
-                 coords,
-                 img0_shape=(720, 1280, 3),
-                 ratio_pad=None):
+def scale_coords(img1_shape, coords, img0_shape=(720, 1280), ratio_pad=None):
     # Rescale coords (xyxy) from img1_shape to img0_shape
     if ratio_pad is None:  # calculate from img0_shape
         gain = min(img1_shape[0] / img0_shape[0],
                    img1_shape[1] / img0_shape[1])
@@ -190,19 +187,17 @@ def box_iou(box1, box2):
     )  # iou = inter / (area1 + area2 - inter)


-def driving_area_mask(seg=None):
-    da_predict = seg[:, :, 12:372, :]
+def driving_area_mask(da_predict=None, out_shape=(720, 1280)):
     da_seg_mask = torch.nn.functional.interpolate(
-        da_predict, scale_factor=2, mode='bilinear')
+        da_predict, size=out_shape, mode='bilinear')
     _, da_seg_mask = torch.max(da_seg_mask, 1)
     da_seg_mask = da_seg_mask.int().squeeze().cpu().numpy()
     return da_seg_mask


-def lane_line_mask(ll=None):
-    ll_predict = ll[:, :, 12:372, :]
+def lane_line_mask(ll_predict=None, out_shape=(720, 1280)):
     ll_seg_mask = torch.nn.functional.interpolate(
-        ll_predict, scale_factor=2, mode='bilinear')
+        ll_predict, size=out_shape, mode='bilinear')
     ll_seg_mask = torch.round(ll_seg_mask).squeeze(1)
     ll_seg_mask = ll_seg_mask.int().squeeze().cpu().numpy()
     return ll_seg_mask
diff --git a/modelscope/pipelines/cv/image_driving_perception_pipeline.py b/modelscope/pipelines/cv/image_driving_perception_pipeline.py
index af4320de..ffc6ac6d 100644
--- a/modelscope/pipelines/cv/image_driving_perception_pipeline.py
+++ b/modelscope/pipelines/cv/image_driving_perception_pipeline.py
@@ -84,13 +84,17 @@ class ImageDrivingPerceptionPipeline(Pipeline):
         # Apply NMS
         pred = non_max_suppression(pred)

-        da_seg_mask = driving_area_mask(inputs['driving_area_mask'])
-        ll_seg_mask = lane_line_mask(inputs['lane_line_mask'])
+        h, w = inputs['ori_img_shape']
+        da_seg_mask = driving_area_mask(
+            inputs['driving_area_mask'], out_shape=(h, w))
+        ll_seg_mask = lane_line_mask(
+            inputs['lane_line_mask'], out_shape=(h, w))

         for det in pred:  # detections per image
             if len(det):
-                # Rescale boxes from img_size to (720, 1280)
-                det[:, :4] = scale_coords(inputs['img_hw'], det[:, :4]).round()
+                # Rescale boxes from img_size to (h, w)
+                det[:, :4] = scale_coords(inputs['img_hw'], det[:, :4],
+                                          (h, w)).round()

                 results_dict[OutputKeys.BOXES] = det[:, :4].cpu().numpy()

         results_dict[OutputKeys.MASKS].append(da_seg_mask)
diff --git a/modelscope/utils/cv/image_utils.py b/modelscope/utils/cv/image_utils.py
index bbc35929..6378b82a 100644
--- a/modelscope/utils/cv/image_utils.py
+++ b/modelscope/utils/cv/image_utils.py
@@ -498,8 +498,6 @@ def show_image_driving_perception_result(img,
                                          results,
                                          out_file='result.jpg',
                                          if_draw=[1, 1, 1]):
-    assert img.shape == (720, 1280,
-                         3), 'input image shape need fix to (720, 1280, 3)'
     bboxes = results.get(OutputKeys.BOXES)
     if if_draw[0]:
         for x in bboxes:
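
A minimal sketch (not part of this patch; the helper name and example shapes
are illustrative) of the letterbox-inverse mapping that scale_coords applies
when rescaling boxes from the network input back to the original image. It
assumes the standard YOLO-style computation suggested by the context lines
above (gain = min of the height/width ratios, symmetric padding) and omits the
final clipping step:

    import numpy as np

    def rescale_boxes(img1_shape, coords, img0_shape):
        # img1_shape: (h, w) of the letterboxed network input
        # img0_shape: (h, w) of the original image, e.g. inputs['ori_img_shape']
        gain = min(img1_shape[0] / img0_shape[0],
                   img1_shape[1] / img0_shape[1])
        pad_w = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_h = (img1_shape[0] - img0_shape[0] * gain) / 2
        coords = coords.copy()
        coords[:, [0, 2]] = (coords[:, [0, 2]] - pad_w) / gain  # x1, x2
        coords[:, [1, 3]] = (coords[:, [1, 3]] - pad_h) / gain  # y1, y2
        return coords

    # A 1080x1920 image letterboxed to 384x640: undo the padding, then the gain.
    boxes = np.array([[100.0, 50.0, 300.0, 200.0]])
    print(rescale_boxes((384, 640), boxes, (1080, 1920)))

With img0_shape now taken from the preprocessor's 'ori_img_shape' instead of a
fixed (720, 1280), both the rescaled boxes and the interpolated segmentation
masks come out at the resolution of the original input image.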