[to #42322933] fix: remove restriction on image resolution in the preprocessing stage

Remove the preprocessing-stage resize to 720x1280 so that the output matches the resolution of the original input image.
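Reviewer note: a minimal before/after sketch of the intended behavior. The task name, model id, and output key below are assumptions inferred from the class names in this diff, not taken from it.

    import cv2
    from modelscope.pipelines import pipeline

    # Placeholders: task string and model id are assumed, not from this diff.
    p = pipeline('image-driving-perception', model='<yolopv2-model-id>')
    img = cv2.imread('road.jpg')          # e.g. 1080x1920 rather than 720x1280
    result = p(img)
    # Previously masks came out at a fixed 720x1280; after this change they
    # should follow the input resolution.
    assert result['masks'][0].shape == img.shape[:2]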

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11762135
Author:       maojialiang.mjl
Date:         2023-02-24 09:49:02 +08:00
Committed by: wenmeng.zwm
Parent:       4ca52a0934
Commit:       97e35a040c

5 changed files with 16 additions and 17 deletions


@@ -37,6 +37,7 @@ class YOLOPv2(TorchModel):
         [pred, anchor_grid], seg, ll = self.model(img)
         return {
             'img_hw': data['img'].shape[2:],
+            'ori_img_shape': data['ori_img_shape'],
             'pred': pred,
             'anchor_grid': anchor_grid,
             'driving_area_mask': seg,


@@ -102,7 +102,7 @@ class ImageDrivingPerceptionPreprocessor(Preprocessor):
             img = self._check_image(img)
         else:
             raise Exception('img is None')
-        img = cv2.resize(img, output_shape, interpolation=cv2.INTER_LINEAR)
+        ori_h, ori_w, _ = img.shape
         img = self._letterbox(img, new_shape)[0]
         img = img.transpose(2, 0, 1)  # to 3x640x640
@@ -117,4 +117,5 @@ class ImageDrivingPerceptionPreprocessor(Preprocessor):
         return {
             'img': img,
+            'ori_img_shape': (ori_h, ori_w),
         }
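Reviewer note: with the fixed cv2.resize gone, only the aspect-ratio-preserving _letterbox touches the image, so the original (ori_h, ori_w) must travel with the tensor for postprocessing to undo the scaling and padding. The _letterbox helper itself is not part of this diff; a minimal sketch of the usual YOLO-style behavior it is assumed to have:

    import cv2

    def letterbox_sketch(img, new_shape=(640, 640), color=(114, 114, 114)):
        # Scale with preserved aspect ratio, then pad to new_shape.
        h, w = img.shape[:2]
        r = min(new_shape[0] / h, new_shape[1] / w)      # resize gain
        nh, nw = int(round(h * r)), int(round(w * r))    # unpadded size
        dh, dw = new_shape[0] - nh, new_shape[1] - nw    # total padding
        img = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_LINEAR)
        top, bottom = dh // 2, dh - dh // 2
        left, right = dw // 2, dw - dw // 2
        img = cv2.copyMakeBorder(img, top, bottom, left, right,
                                 cv2.BORDER_CONSTANT, value=color)
        return img, r, (dw / 2, dh / 2)

Because the gain r and the padding depend on (ori_h, ori_w), that pair is exactly what scale_coords and the mask functions need downstream.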


@@ -29,10 +29,7 @@ def split_for_trace_model(pred=None, anchor_grid=None):
     return pred
 
 
-def scale_coords(img1_shape,
-                 coords,
-                 img0_shape=(720, 1280, 3),
-                 ratio_pad=None):
+def scale_coords(img1_shape, coords, img0_shape=(720, 1280), ratio_pad=None):
     # Rescale coords (xyxy) from img1_shape to img0_shape
     if ratio_pad is None:  # calculate from img0_shape
         gain = min(img1_shape[0] / img0_shape[0],
@@ -190,19 +187,17 @@ def box_iou(box1, box2):
     )  # iou = inter / (area1 + area2 - inter)
 
 
-def driving_area_mask(seg=None):
-    da_predict = seg[:, :, 12:372, :]
+def driving_area_mask(da_predict=None, out_shape=(720, 1280)):
     da_seg_mask = torch.nn.functional.interpolate(
-        da_predict, scale_factor=2, mode='bilinear')
+        da_predict, size=out_shape, mode='bilinear')
     _, da_seg_mask = torch.max(da_seg_mask, 1)
     da_seg_mask = da_seg_mask.int().squeeze().cpu().numpy()
     return da_seg_mask
 
 
-def lane_line_mask(ll=None):
-    ll_predict = ll[:, :, 12:372, :]
+def lane_line_mask(ll_predict=None, out_shape=(720, 1280)):
     ll_seg_mask = torch.nn.functional.interpolate(
-        ll_predict, scale_factor=2, mode='bilinear')
+        ll_predict, size=out_shape, mode='bilinear')
     ll_seg_mask = torch.round(ll_seg_mask).squeeze(1)
     ll_seg_mask = ll_seg_mask.int().squeeze().cpu().numpy()
     return ll_seg_mask
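Reviewer note: scale_coords inverts the letterbox transform, and dropping the (720, 1280) default at the call sites means it now targets whatever shape the preprocessor recorded. A worked example of the inversion, using the gain/pad formulas visible in the hunk above (the concrete numbers are illustrative):

    import torch

    img1_shape, img0_shape = (384, 640), (1080, 1920)   # network input vs. original
    gain = min(img1_shape[0] / img0_shape[0],
               img1_shape[1] / img0_shape[1])           # 1/3
    pad_w = (img1_shape[1] - img0_shape[1] * gain) / 2  # 0.0
    pad_h = (img1_shape[0] - img0_shape[0] * gain) / 2  # 12.0
    box = torch.tensor([[100., 62., 540., 372.]])       # xyxy in input pixels
    box[:, [0, 2]] -= pad_w                             # remove horizontal padding
    box[:, [1, 3]] -= pad_h                             # remove vertical padding
    box /= gain                                         # undo the resize gain
    print(box)  # tensor([[ 300.,  150., 1620., 1080.]]) in original pixels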


@@ -84,13 +84,17 @@ class ImageDrivingPerceptionPipeline(Pipeline):
         # Apply NMS
         pred = non_max_suppression(pred)
 
-        da_seg_mask = driving_area_mask(inputs['driving_area_mask'])
-        ll_seg_mask = lane_line_mask(inputs['lane_line_mask'])
+        h, w = inputs['ori_img_shape']
+        da_seg_mask = driving_area_mask(
+            inputs['driving_area_mask'], out_shape=(h, w))
+        ll_seg_mask = lane_line_mask(
+            inputs['lane_line_mask'], out_shape=(h, w))
 
         for det in pred:  # detections per image
             if len(det):
-                # Rescale boxes from img_size to (720, 1280)
-                det[:, :4] = scale_coords(inputs['img_hw'], det[:, :4]).round()
+                # Rescale boxes from img_size to (h, w)
+                det[:, :4] = scale_coords(inputs['img_hw'], det[:, :4],
+                                          (h, w)).round()
                 results_dict[OutputKeys.BOXES] = det[:, :4].cpu().numpy()
                 results_dict[OutputKeys.MASKS].append(da_seg_mask)
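Reviewer note: passing size=(h, w) makes interpolate produce the masks directly at the original resolution, replacing the old scale_factor=2 upsample that only made sense for the fixed 720x1280 input. A quick self-contained check of that behavior:

    import torch
    import torch.nn.functional as F

    h, w = 1080, 1920                    # as read from inputs['ori_img_shape']
    seg = torch.randn(1, 2, 360, 640)    # dummy segmentation head output (N, C, H, W)
    mask = F.interpolate(seg, size=(h, w), mode='bilinear')
    assert mask.shape[-2:] == (h, w)     # mask now matches the input image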


@@ -498,8 +498,6 @@ def show_image_driving_perception_result(img,
                                          results,
                                          out_file='result.jpg',
                                          if_draw=[1, 1, 1]):
-    assert img.shape == (720, 1280,
-                         3), 'input image shape need fix to (720, 1280, 3)'
     bboxes = results.get(OutputKeys.BOXES)
     if if_draw[0]:
         for x in bboxes:
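Reviewer note: with the hard-coded (720, 1280, 3) assert removed, the drawing helper accepts any resolution, since the boxes and masks it receives were already rescaled upstream. A hypothetical call (file names are placeholders, and `results` is the pipeline output from the sketch under the commit message):

    import cv2

    img = cv2.imread('road_4k.jpg')      # no longer needs to be 720x1280
    show_image_driving_perception_result(img, results, out_file='vis.jpg',
                                         if_draw=[1, 1, 1])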