mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
[to #42322933]修复pipeline串联时collate_fn异常
修复pipeline串联时collate_fn异常
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10457058
This commit is contained in:
@@ -433,6 +433,8 @@ def collate_fn(data, device):
|
||||
if isinstance(data, dict) or isinstance(data, Mapping):
|
||||
return type(data)({k: collate_fn(v, device) for k, v in data.items()})
|
||||
elif isinstance(data, (tuple, list)):
|
||||
if 0 == len(data):
|
||||
return torch.Tensor([])
|
||||
if isinstance(data[0], (int, float)):
|
||||
return default_collate(data).to(device)
|
||||
else:
|
||||
|
||||
@@ -143,6 +143,13 @@ class Body3DKeypointsPipeline(Pipeline):
|
||||
max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME # max video frame number to be predicted 3D joints
|
||||
for i, frame in enumerate(video_frames):
|
||||
kps_2d = self.human_body_2d_kps_detector(frame)
|
||||
if [] == kps_2d.get('boxes'):
|
||||
res = {
|
||||
'success': False,
|
||||
'msg': f'fail to detect person at image frame {i}'
|
||||
}
|
||||
return res
|
||||
|
||||
box = kps_2d['boxes'][
|
||||
0] # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox
|
||||
pose = kps_2d['keypoints'][0] # keypoints: [15, 2]
|
||||
@@ -180,7 +187,15 @@ class Body3DKeypointsPipeline(Pipeline):
|
||||
return res
|
||||
|
||||
def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
|
||||
res = {OutputKeys.KEYPOINTS: [], OutputKeys.TIMESTAMPS: []}
|
||||
output_video_path = kwargs.get('output_video', None)
|
||||
if output_video_path is None:
|
||||
output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
||||
|
||||
res = {
|
||||
OutputKeys.KEYPOINTS: [],
|
||||
OutputKeys.TIMESTAMPS: [],
|
||||
OutputKeys.OUTPUT_VIDEO: output_video_path
|
||||
}
|
||||
|
||||
if not input['success']:
|
||||
pass
|
||||
@@ -189,10 +204,6 @@ class Body3DKeypointsPipeline(Pipeline):
|
||||
pred_3d_pose = poses.data.cpu().numpy()[
|
||||
0] # [frame_num, joint_num, joint_dim]
|
||||
|
||||
output_video_path = kwargs.get('output_video', None)
|
||||
if output_video_path is None:
|
||||
output_video_path = tempfile.NamedTemporaryFile(
|
||||
suffix='.mp4').name
|
||||
if 'render' in self.keypoint_model_3d.cfg.keys():
|
||||
self.render_prediction(pred_3d_pose, output_video_path)
|
||||
res[OutputKeys.OUTPUT_VIDEO] = output_video_path
|
||||
|
||||
Reference in New Issue
Block a user