mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
[to #42322933]style(license): add license + render result poses with video
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10263904
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import logging
|
||||
import os.path as osp
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# The implementation is based on OSTrack, available at https://github.com/facebookresearch/VideoPose3D
|
||||
# The implementation is based on VideoPose3D, available at https://github.com/facebookresearch/VideoPose3D
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ class OutputKeys(object):
|
||||
POLYGONS = 'polygons'
|
||||
OUTPUT = 'output'
|
||||
OUTPUT_IMG = 'output_img'
|
||||
OUTPUT_VIDEO = 'output_video'
|
||||
OUTPUT_PCM = 'output_pcm'
|
||||
IMG_EMBEDDING = 'img_embedding'
|
||||
SPO_LIST = 'spo_list'
|
||||
@@ -218,13 +219,21 @@ TASK_OUTPUTS = {
|
||||
|
||||
# 3D human body keypoints detection result for single sample
|
||||
# {
|
||||
# "poses": [
|
||||
# [[x, y, z]*17],
|
||||
# [[x, y, z]*17],
|
||||
# [[x, y, z]*17]
|
||||
# ]
|
||||
# "poses": [ # 3d pose coordinate in camera coordinate
|
||||
# [[x, y, z]*17], # joints of per image
|
||||
# [[x, y, z]*17],
|
||||
# ...
|
||||
# ],
|
||||
# "timestamps": [ # timestamps of all frames
|
||||
# "00:00:0.230",
|
||||
# "00:00:0.560",
|
||||
# "00:00:0.690",
|
||||
# ],
|
||||
# "output_video": "path_to_rendered_video" , this is optional
|
||||
# and is only available when the "render" option is enabled.
|
||||
# }
|
||||
Tasks.body_3d_keypoints: [OutputKeys.POSES],
|
||||
Tasks.body_3d_keypoints:
|
||||
[OutputKeys.POSES, OutputKeys.TIMESTAMPS, OutputKeys.OUTPUT_VIDEO],
|
||||
|
||||
# 2D hand keypoints result for single sample
|
||||
# {
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
import os
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import datetime
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import mpl_toolkits.mplot3d.axes3d as p3
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import animation
|
||||
from matplotlib.animation import writers
|
||||
from matplotlib.ticker import MultipleLocator
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.body_3d_keypoints.body_3d_pose import (
|
||||
@@ -16,6 +25,8 @@ from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
matplotlib.use('Agg')
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@@ -121,7 +132,13 @@ class Body3DKeypointsPipeline(Pipeline):
|
||||
device='gpu' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
video_frames = self.read_video_frames(input)
|
||||
video_url = input.get('input_video')
|
||||
self.output_video_path = input.get('output_video_path')
|
||||
if self.output_video_path is None:
|
||||
self.output_video_path = tempfile.NamedTemporaryFile(
|
||||
suffix='.mp4').name
|
||||
|
||||
video_frames = self.read_video_frames(video_url)
|
||||
if 0 == len(video_frames):
|
||||
res = {'success': False, 'msg': 'get video frame failed.'}
|
||||
return res
|
||||
@@ -168,13 +185,21 @@ class Body3DKeypointsPipeline(Pipeline):
|
||||
return res
|
||||
|
||||
def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
|
||||
res = {OutputKeys.POSES: []}
|
||||
res = {OutputKeys.POSES: [], OutputKeys.TIMESTAMPS: []}
|
||||
|
||||
if not input['success']:
|
||||
pass
|
||||
else:
|
||||
poses = input[KeypointsTypes.POSES_CAMERA]
|
||||
res = {OutputKeys.POSES: poses.data.cpu().numpy()}
|
||||
pred_3d_pose = poses.data.cpu().numpy()[
|
||||
0] # [frame_num, joint_num, joint_dim]
|
||||
|
||||
if 'render' in self.keypoint_model_3d.cfg.keys():
|
||||
self.render_prediction(pred_3d_pose)
|
||||
res[OutputKeys.OUTPUT_VIDEO] = self.output_video_path
|
||||
|
||||
res[OutputKeys.POSES] = pred_3d_pose
|
||||
res[OutputKeys.TIMESTAMPS] = self.timestamps
|
||||
return res
|
||||
|
||||
def read_video_frames(self, video_url: Union[str, cv2.VideoCapture]):
|
||||
@@ -189,7 +214,15 @@ class Body3DKeypointsPipeline(Pipeline):
|
||||
Returns:
|
||||
[nd.array]: List of video frames.
|
||||
"""
|
||||
|
||||
def timestamp_format(seconds):
    """Format a duration in seconds as ``HH:MM:SS.mmm``.

    Args:
        seconds (int | float): non-negative elapsed time in seconds.

    Returns:
        str: zero-padded timestamp with millisecond precision,
            e.g. ``'00:00:00.230'``.
    """
    # Round to the output precision *before* splitting into fields.
    # Otherwise a value such as 59.9996 is split into (0 min, 59.9996 s)
    # and the '%06.3f' conversion rounds the seconds field up to
    # '60.000' without carrying into the minutes.
    seconds = round(seconds, 3)
    minutes, secs = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return '%02d:%02d:%06.3f' % (hours, minutes, secs)
|
||||
|
||||
frames = []
|
||||
self.timestamps = [] # for video render
|
||||
if isinstance(video_url, str):
|
||||
cap = cv2.VideoCapture(video_url)
|
||||
if not cap.isOpened():
|
||||
@@ -199,15 +232,131 @@ class Body3DKeypointsPipeline(Pipeline):
|
||||
else:
|
||||
cap = video_url
|
||||
|
||||
self.fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
if self.fps is None or self.fps <= 0:
|
||||
raise Exception('modelscope error: %s cannot get video fps info.' %
|
||||
(video_url))
|
||||
|
||||
max_frame_num = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME
|
||||
frame_idx = 0
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
self.timestamps.append(
|
||||
timestamp_format(seconds=frame_idx / self.fps))
|
||||
frame_idx += 1
|
||||
frames.append(frame)
|
||||
if frame_idx >= max_frame_num:
|
||||
break
|
||||
cap.release()
|
||||
return frames
|
||||
|
||||
def render_prediction(self, pose3d_cam_rr):
    """Render the predicted 3d poses as a skeleton animation and save it.

    The animation is written with the ffmpeg writer to
    ``self.output_video_path`` at ``self.fps`` frames per second. The view
    direction is read from ``self.keypoint_model_3d.cfg.render``.

    Args:
        pose3d_cam_rr (nd.array): [frame_num, joint_num, joint_dim], 3d pose joints

    Returns:
        None
    """
    frame_num = pose3d_cam_rr.shape[0]

    left_points = (11, 12, 13, 4, 5, 6)  # joints of left body
    edges = [[0, 1], [0, 4], [0, 7], [1, 2], [4, 5], [5, 6], [2, 3],
             [7, 8], [8, 9], [8, 11], [8, 14], [14, 15], [15, 16],
             [11, 12], [12, 13], [9, 10]]  # connection between joints

    fig = plt.figure()
    try:
        # Create the 3d axes through the figure. Recent matplotlib releases
        # no longer attach an `Axes3D(fig)` instance to the figure
        # automatically, which would produce an empty video. The rect
        # matches the full-figure default of `Axes3D(fig)`.
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], projection='3d')

        major_locator = MultipleLocator(0.5)
        ax.xaxis.set_major_locator(major_locator)
        ax.yaxis.set_major_locator(major_locator)
        ax.zaxis.set_major_locator(major_locator)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)

        # view direction
        azim = self.keypoint_model_3d.cfg.render.azim
        elev = self.keypoint_model_3d.cfg.render.elev
        ax.view_init(elev, azim)

        # initial plot, drawn from the first frame
        x = pose3d_cam_rr[0, :, 0]
        y = pose3d_cam_rr[0, :, 1]
        z = pose3d_cam_rr[0, :, 2]
        points, = ax.plot(x, y, z, 'r.')

        def render_bones(xs, ys, zs):
            """Draw one line per skeleton edge.

            Args:
                xs (nd.array): x coordinate of every joint
                ys (nd.array): y coordinate of every joint
                zs (nd.array): z coordinate of every joint

            Returns:
                list: one line artist per entry of ``edges``, in order.
            """
            bone_lines = []
            for index1, index2 in edges:
                # Edges starting at a left-body joint are red, others blue.
                edge_color = 'red' if index1 in left_points else 'blue'
                connect = ax.plot([xs[index1], xs[index2]],
                                  [ys[index1], ys[index2]],
                                  [zs[index1], zs[index2]],
                                  linewidth=2,
                                  color=edge_color)  # plot edge
                bone_lines.append(connect[0])
            return bone_lines

        bones = render_bones(x, y, z)

        def update(frame_idx, points, bones):
            """Move the joint and bone artists to the pose of one frame.

            Args:
                frame_idx (int): frame index
                points: line artist holding the joint markers
                bones (list): line artists of the bones, ordered like ``edges``

            Returns:
                tuple: the joint artist followed by every bone artist
            """
            xs = pose3d_cam_rr[frame_idx, :, 0]
            ys = pose3d_cam_rr[frame_idx, :, 1]
            zs = pose3d_cam_rr[frame_idx, :, 2]

            # update bones
            for bone, (index1, index2) in zip(bones, edges):
                bone.set_xdata((xs[index1], xs[index2]))
                bone.set_ydata((ys[index1], ys[index2]))
                bone.set_3d_properties((zs[index1], zs[index2]), 'z')

            # update joints
            points.set_data(xs, ys)
            points.set_3d_properties(zs, 'z')

            # Report progress every 100 frames. The modulo is required:
            # `frame_idx / 100` is zero only for the very first frame.
            if 0 == frame_idx % 100:
                logger.info(f'rendering {frame_idx}/{frame_num}')
            return (points, *bones)

        ani = animation.FuncAnimation(
            fig=fig,
            func=update,
            frames=frame_num,
            # `interval` is the delay between frames in milliseconds,
            # not a frame rate.
            interval=1000.0 / self.fps,
            fargs=(points, bones))

        # save mp4
        Writer = writers['ffmpeg']
        writer = Writer(fps=self.fps, metadata={}, bitrate=4096)
        ani.save(self.output_video_path, writer=writer)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated pipeline calls do not accumulate figures.
        plt.close(fig)
|
||||
|
||||
@@ -28,7 +28,12 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
def test_run_modelhub_with_video_file(self):
|
||||
body_3d_keypoints = pipeline(
|
||||
Tasks.body_3d_keypoints, model=self.model_id)
|
||||
self.pipeline_inference(body_3d_keypoints, self.test_video)
|
||||
pipeline_input = {
|
||||
'input_video': self.test_video,
|
||||
'output_video_path': './result.mp4'
|
||||
}
|
||||
self.pipeline_inference(
|
||||
body_3d_keypoints, pipeline_input=pipeline_input)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_modelhub_with_video_stream(self):
|
||||
@@ -37,12 +42,12 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
if not cap.isOpened():
|
||||
raise Exception('modelscope error: %s cannot be decoded by OpenCV.'
|
||||
% (self.test_video))
|
||||
self.pipeline_inference(body_3d_keypoints, cap)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_modelhub_default_model(self):
|
||||
body_3d_keypoints = pipeline(Tasks.body_3d_keypoints)
|
||||
self.pipeline_inference(body_3d_keypoints, self.test_video)
|
||||
pipeline_input = {
|
||||
'input_video': cap,
|
||||
'output_video_path': './result.mp4'
|
||||
}
|
||||
self.pipeline_inference(
|
||||
body_3d_keypoints, pipeline_input=pipeline_input)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_demo_compatibility(self):
|
||||
|
||||
Reference in New Issue
Block a user