[to #42322933]修复pipeline串联时collate_fn异常

修复pipeline串联时collate_fn异常
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10457058
This commit is contained in:
hanyuan.chy
2022-10-20 19:33:06 +08:00
committed by yingda.chen
parent b2c5876ead
commit de6d84cb97
2 changed files with 18 additions and 5 deletions

View File

@@ -433,6 +433,8 @@ def collate_fn(data, device):
if isinstance(data, dict) or isinstance(data, Mapping):
return type(data)({k: collate_fn(v, device) for k, v in data.items()})
elif isinstance(data, (tuple, list)):
if 0 == len(data):
return torch.Tensor([])
if isinstance(data[0], (int, float)):
return default_collate(data).to(device)
else:

View File

@@ -143,6 +143,13 @@ class Body3DKeypointsPipeline(Pipeline):
max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME # max video frame number to be predicted 3D joints
for i, frame in enumerate(video_frames):
kps_2d = self.human_body_2d_kps_detector(frame)
if [] == kps_2d.get('boxes'):
res = {
'success': False,
'msg': f'fail to detect person at image frame {i}'
}
return res
box = kps_2d['boxes'][
0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox
pose = kps_2d['keypoints'][0] # keypoints: [15, 2]
@@ -180,7 +187,15 @@ class Body3DKeypointsPipeline(Pipeline):
return res
def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
res = {OutputKeys.KEYPOINTS: [], OutputKeys.TIMESTAMPS: []}
output_video_path = kwargs.get('output_video', None)
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
res = {
OutputKeys.KEYPOINTS: [],
OutputKeys.TIMESTAMPS: [],
OutputKeys.OUTPUT_VIDEO: output_video_path
}
if not input['success']:
pass
@@ -189,10 +204,6 @@ class Body3DKeypointsPipeline(Pipeline):
pred_3d_pose = poses.data.cpu().numpy()[
0] # [frame_num, joint_num, joint_dim]
output_video_path = kwargs.get('output_video', None)
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(
suffix='.mp4').name
if 'render' in self.keypoint_model_3d.cfg.keys():
self.render_prediction(pred_3d_pose, output_video_path)
res[OutputKeys.OUTPUT_VIDEO] = output_video_path