Add HDFormer for 3D human body pose estimation

Add the HDFormer 3D human body pose estimation model

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11576328

* add hdformer

* rm redundant pipeline

* add license

* add test case
This commit is contained in:
hanyuan.chy
2023-02-10 02:13:09 +00:00
committed by wenmeng.zwm
parent 0894b1ea71
commit 6da3be3047
14 changed files with 1341 additions and 17 deletions
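For quick reference, the new model plugs into the existing body_3d_keypoints pipeline. The sketch below mirrors the docstring example and the new test case added in this commit (model id, video URL and the output_video argument are the ones used there):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline with the HDFormer checkpoint referenced in this commit.
body_3d_keypoints = pipeline(
    Tasks.body_3d_keypoints,
    model='damo/cv_hdformer_body-3d-keypoints_video')
test_video_url = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/Walking.54138969.mp4'
output = body_3d_keypoints(test_video_url, output_video='./result.mp4')
print(output[OutputKeys.KEYPOINTS])  # predicted 3D keypoints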

View File

@@ -34,6 +34,7 @@ class Models(object):
product_retrieval_embedding = 'product-retrieval-embedding'
body_2d_keypoints = 'body-2d-keypoints'
body_3d_keypoints = 'body-3d-keypoints'
body_3d_keypoints_hdformer = 'hdformer'
crowd_counting = 'HRNetCrowdCounting'
face_2d_keypoints = 'face-2d-keypoints'
panoptic_segmentation = 'swinL-panoptic-segmentation'

View File

@@ -4,12 +4,12 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .body_3d_pose import BodyKeypointsDetection3D
from .cannonical_pose import BodyKeypointsDetection3D
from .hdformer import HDFormerDetector
else:
_import_structure = {
'body_3d_pose': ['BodyKeypointsDetection3D'],
'cannonical_pose': ['BodyKeypointsDetection3D'],
'hdformer': ['HDFormerDetector'],
}
import sys

View File

@@ -0,0 +1 @@
from .body_3d_pose import BodyKeypointsDetection3D

View File

@@ -10,7 +10,7 @@ import torch
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.body_3d_keypoints.canonical_pose_modules import (
from modelscope.models.cv.body_3d_keypoints.cannonical_pose.canonical_pose_modules import (
TemporalModel, TransCan3Dkeys)
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks

View File

@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .hdformer_detector import HDFormerDetector

View File

@@ -0,0 +1,306 @@
# --------------------------------------------------------
# The implementation is also open-sourced by the author, Hanyuan Chen,
# and is publicly available at https://github.com/hyer/HDFormer
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.cv.body_3d_keypoints.hdformer.block import \
HightOrderAttentionBlock
from modelscope.models.cv.body_3d_keypoints.hdformer.directed_graph import (
DiGraph, Graph)
from modelscope.models.cv.body_3d_keypoints.hdformer.skeleton import \
get_skeleton
class HDFormerNet(nn.Module):
def __init__(self, cfg):
super(HDFormerNet, self).__init__()
in_channels = cfg.in_channels
dropout = cfg.dropout
self.cfg = cfg
self.PLANES = [16, 32, 64, 128, 256]
# load graph
skeleton = get_skeleton()
self.di_graph = DiGraph(skeleton=skeleton)
self.graph = Graph(
skeleton=skeleton, strategy='agcn', max_hop=1, dilation=1)
self.A = torch.tensor(
self.graph.A,
dtype=torch.float32,
requires_grad=True,
device='cuda')
# build networks
spatial_kernel_size = self.A.size(0)
temporal_kernel_size = 9
kernel_size = (temporal_kernel_size, spatial_kernel_size)
if not cfg.data_bn:
self.data_bn = None
else:
n_joints = self.cfg.IN_NUM_JOINTS \
if hasattr(self.cfg, 'IN_NUM_JOINTS') \
else self.cfg.n_joints
self.data_bn = nn.BatchNorm1d(in_channels * n_joints) if hasattr(cfg, 'PJN') and cfg.PJN \
else nn.BatchNorm2d(in_channels)
self.downsample = nn.ModuleList(
(
HightOrderAttentionBlock(
in_channels,
self.PLANES[0],
kernel_size,
A=self.A,
di_graph=self.di_graph,
residual=False,
adj_len=self.A.size(1),
attention=cfg.attention_down if hasattr(
cfg, 'attention_down') else False,
dropout=0),
HightOrderAttentionBlock(
self.PLANES[0],
self.PLANES[1],
kernel_size,
A=self.A,
di_graph=self.di_graph,
stride=2,
adj_len=self.A.size(1),
attention=cfg.attention_down if hasattr(
cfg, 'attention_down') else False,
dropout=0),
HightOrderAttentionBlock(
self.PLANES[1],
self.PLANES[1],
kernel_size,
A=self.A,
di_graph=self.di_graph,
adj_len=self.A.size(1),
attention=cfg.attention_down if hasattr(
cfg, 'attention_down') else False,
dropout=0),
HightOrderAttentionBlock(
self.PLANES[1],
self.PLANES[2],
kernel_size,
A=self.A,
di_graph=self.di_graph,
stride=2,
adj_len=self.A.size(1),
attention=cfg.attention_down if hasattr(
cfg, 'attention_down') else False,
dropout=0),
HightOrderAttentionBlock(
self.PLANES[2],
self.PLANES[2],
kernel_size,
A=self.A,
di_graph=self.di_graph,
adj_len=self.A.size(1),
attention=cfg.attention_down if hasattr(
cfg, 'attention_down') else False,
dropout=0),
HightOrderAttentionBlock(
self.PLANES[2],
self.PLANES[3],
kernel_size,
A=self.A,
di_graph=self.di_graph,
stride=2,
adj_len=self.A.size(1),
attention=cfg.attention_down if hasattr(
cfg, 'attention_down') else False,
dropout=dropout),
HightOrderAttentionBlock(
self.PLANES[3],
self.PLANES[3],
kernel_size,
A=self.A,
di_graph=self.di_graph,
adj_len=self.A.size(1),
attention=cfg.attention_down if hasattr(
cfg, 'attention_down') else False,
dropout=dropout),
HightOrderAttentionBlock(
self.PLANES[3],
self.PLANES[4],
kernel_size,
A=self.A,
di_graph=self.di_graph,
stride=2,
adj_len=self.A.size(1),
attention=cfg.attention_down if hasattr(
cfg, 'attention_down') else False,
dropout=dropout),
HightOrderAttentionBlock(
self.PLANES[4],
self.PLANES[4],
kernel_size,
A=self.A,
di_graph=self.di_graph,
adj_len=self.A.size(1),
attention=cfg.attention_down if hasattr(
cfg, 'attention_down') else False,
dropout=dropout),
))
self.upsample = nn.ModuleList((
HightOrderAttentionBlock(
self.PLANES[4],
self.PLANES[3],
kernel_size,
A=self.A,
di_graph=self.di_graph,
attention=cfg.attention_up
if hasattr(cfg, 'attention_up') else False,
adj_len=self.A.size(1),
dropout=dropout),
HightOrderAttentionBlock(
self.PLANES[3],
self.PLANES[2],
kernel_size,
A=self.A,
di_graph=self.di_graph,
attention=cfg.attention_up
if hasattr(cfg, 'attention_up') else False,
adj_len=self.A.size(1),
dropout=dropout),
HightOrderAttentionBlock(
self.PLANES[2],
self.PLANES[1],
kernel_size,
A=self.A,
di_graph=self.di_graph,
attention=cfg.attention_up
if hasattr(cfg, 'attention_up') else False,
adj_len=self.A.size(1),
dropout=0),
HightOrderAttentionBlock(
self.PLANES[1],
self.PLANES[0],
kernel_size,
A=self.A,
di_graph=self.di_graph,
attention=cfg.attention_up
if hasattr(cfg, 'attention_up') else False,
adj_len=self.A.size(1),
dropout=0),
))
self.merge = nn.ModuleList((
HightOrderAttentionBlock(
self.PLANES[4],
self.PLANES[0],
kernel_size,
A=self.A,
di_graph=self.di_graph,
attention=cfg.attention_merge if hasattr(
cfg, 'attention_merge') else False,
adj_len=self.A.size(1),
dropout=dropout,
max_hop=self.cfg.max_hop),
HightOrderAttentionBlock(
self.PLANES[3],
self.PLANES[0],
kernel_size,
A=self.A,
di_graph=self.di_graph,
attention=cfg.attention_merge if hasattr(
cfg, 'attention_merge') else False,
adj_len=self.A.size(1),
dropout=dropout,
max_hop=self.cfg.max_hop),
HightOrderAttentionBlock(
self.PLANES[2],
self.PLANES[0],
kernel_size,
A=self.A,
di_graph=self.di_graph,
attention=cfg.attention_merge if hasattr(
cfg, 'attention_merge') else False,
adj_len=self.A.size(1),
dropout=0,
max_hop=self.cfg.max_hop),
HightOrderAttentionBlock(
self.PLANES[1],
self.PLANES[0],
kernel_size,
A=self.A,
di_graph=self.di_graph,
attention=cfg.attention_merge if hasattr(
cfg, 'attention_merge') else False,
adj_len=self.A.size(1),
dropout=0,
max_hop=self.cfg.max_hop),
))
def get_edge_fea(self, x_v):
x_e = (x_v[..., [c for p, c in self.di_graph.directed_edges_hop1]]
- x_v[..., [p for p, c in self.di_graph.directed_edges_hop1]]
).contiguous()
N, C, T, V = x_v.shape
edge_append = torch.zeros((N, C, T, 1), device=x_e.device)
x_e = torch.cat((x_e, edge_append), dim=-1)
return x_e
def forward(self, x_v: torch.Tensor):
"""
x: shape [B,C,T,V_v]
"""
B, C, T, V = x_v.shape
# data normalization
if self.data_bn is not None:
if hasattr(self.cfg, 'PJN') and self.cfg.PJN:
x_v = self.data_bn(x_v.permute(0, 1, 3, 2).contiguous().view(B, -1, T)).view(B, C, V, T) \
.contiguous().permute(0, 1, 3, 2)
else:
x_v = self.data_bn(x_v)
x_e = self.get_edge_fea(x_v)
# forward
feature = []
for idx, hoa_block in enumerate(self.downsample):
x_v, x_e = hoa_block(x_v, x_e)
if idx == 0 or idx == 2 or idx == 4 or idx == 6:
feature.append((x_v, x_e))
feature.append((x_v, x_e))
feature = feature[::-1]
x_v, x_e = feature[0]
identity_feature = feature[1:]
ushape_feature = []
ushape_feature.append((x_v, x_e))
for idx, (hoa_block, id) in \
enumerate(zip(self.upsample, identity_feature)):
x_v, x_e = hoa_block(x_v, x_e)
if hasattr(self.cfg, 'deterministic') and self.cfg.deterministic:
x_v = F.interpolate(x_v, scale_factor=(2, 1), mode='nearest')
else:
x_v = F.interpolate(
x_v,
scale_factor=(2, 1),
mode='bilinear',
align_corners=False)
x_v += id[0]
ushape_feature.append((x_v, x_e))
ushape_feature = ushape_feature[:-1]
for idx, (hoa_block, u) in \
enumerate(zip(self.merge, ushape_feature)):
x_v2, x_e2 = hoa_block(*u)
if hasattr(self.cfg, 'deterministic') and self.cfg.deterministic:
x_v += F.interpolate(
x_v2, scale_factor=(2**(4 - idx), 1), mode='nearest')
else:
x_v += F.interpolate(
x_v2,
scale_factor=(2**(4 - idx), 1),
mode='bilinear',
align_corners=False)
return x_v, x_e
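Note for reviewers: get_edge_fea() builds directed-edge features by subtracting each parent joint from its child and padding one zero column so the edge axis matches the 17 joints. A minimal sketch of that computation, assuming the hdformer modules added in this PR are importable:

import torch
from modelscope.models.cv.body_3d_keypoints.hdformer.directed_graph import DiGraph
from modelscope.models.cv.body_3d_keypoints.hdformer.skeleton import get_skeleton

di_graph = DiGraph(skeleton=get_skeleton())
x_v = torch.randn(2, 16, 9, 17)          # [B, C, T, V]: 17 H3.6M joints
edges = di_graph.directed_edges_hop1      # 16 (parent, child) pairs
x_e = x_v[..., [c for p, c in edges]] - x_v[..., [p for p, c in edges]]
x_e = torch.cat((x_e, torch.zeros(2, 16, 9, 1)), dim=-1)  # pad so V_edge == V_node
print(x_e.shape)                          # torch.Size([2, 16, 9, 17])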

View File

@@ -0,0 +1,380 @@
# Part of the implementation is borrowed and modified from 2s-AGCN, publicly available at
# https://github.com/lshiwjx/2s-AGCN
import math
import torch
import torch.nn as nn
from einops import rearrange
def import_class(name):
components = name.split('.')
mod = __import__(components[0])
for comp in components[1:]:
mod = getattr(mod, comp)
return mod
def conv_branch_init(conv, branches):
weight = conv.weight
n = weight.size(0)
k1 = weight.size(1)
k2 = weight.size(2)
nn.init.normal_(weight, 0, math.sqrt(2. / (n * k1 * k2 * branches)))
if conv.bias is not None:
nn.init.constant_(conv.bias, 0)
def conv_init(conv):
if conv.weight is not None:
nn.init.kaiming_normal_(conv.weight, mode='fan_out')
if conv.bias is not None:
nn.init.constant_(conv.bias, 0)
def bn_init(bn, scale):
nn.init.constant_(bn.weight, scale)
nn.init.constant_(bn.bias, 0)
def zero(x):
"""return zero."""
return 0
def iden(x):
"""return input itself."""
return x
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.,
changedim=False,
currentdim=0,
depth=0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
comb=False,
vis=False):
"""Attention is all you need
Args:
dim (_type_): _description_
num_heads (int, optional): _description_. Defaults to 8.
qkv_bias (bool, optional): _description_. Defaults to False.
qk_scale (_type_, optional): _description_. Defaults to None.
attn_drop (_type_, optional): _description_. Defaults to 0..
proj_drop (_type_, optional): _description_. Defaults to 0..
comb (bool, optional): Defaults to False.
True: q transpose * k.
False: q * k transpose.
vis (bool, optional): _description_. Defaults to False.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE: the scale factor was wrong in the original version; it can be set manually to stay compatible with previous weights
self.scale = qk_scale or head_dim**-0.5
self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
self.to_v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.comb = comb
self.vis = vis
def forward(self, fv, fe):
B, N, C = fv.shape
B, E, C = fe.shape
q = self.to_q(fv).reshape(B, N, self.num_heads,
C // self.num_heads).permute(0, 2, 1, 3)
k = self.to_k(fe).reshape(B, E, self.num_heads,
C // self.num_heads).permute(0, 2, 1, 3)
v = self.to_v(fe).reshape(B, E, self.num_heads,
C // self.num_heads).permute(0, 2, 1, 3)
# Now fv shape (B, H, N, C//heads)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
if self.comb:
fv = (attn @ v.transpose(-2, -1)).transpose(-2, -1)
fv = rearrange(fv, 'B H N C -> B N (H C)')
elif self.comb is False:
fv = (attn @ v).transpose(1, 2).reshape(B, N, C)
fv = self.proj(fv)
fv = self.proj_drop(fv)
return fv
class FirstOrderAttention(nn.Module):
"""First Order Attention block for spatial relationship.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
A,
t_kernel_size=1,
t_stride=1,
t_padding=0,
t_dilation=1,
adj_len=17,
bias=True):
super().__init__()
self.kernel_size = kernel_size
self.A = A
self.PA = nn.Parameter(torch.FloatTensor(3, adj_len, adj_len))
torch.nn.init.constant_(self.PA, 1e-6)
self.num_subset = 3
inter_channels = out_channels // 4
self.inter_c = inter_channels
self.conv_a = nn.ModuleList()
self.conv_b = nn.ModuleList()
self.conv_d = nn.ModuleList()
self.linears = nn.ModuleList()
for i in range(self.num_subset):
self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))
self.linears.append(nn.Linear(in_channels, in_channels))
if in_channels != out_channels:
self.down = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1),
nn.BatchNorm2d(out_channels))
else:
self.down = lambda x: x
self.bn = nn.BatchNorm2d(out_channels)
self.soft = nn.Softmax(-2)
self.relu = nn.ReLU()
for m in self.modules():
if isinstance(m, nn.Conv2d):
conv_init(m)
elif isinstance(m, nn.BatchNorm2d):
bn_init(m, 1)
bn_init(self.bn, 1e-6)
for i in range(self.num_subset):
conv_branch_init(self.conv_d[i], self.num_subset)
def forward(self, x):
assert self.A.shape[0] == self.kernel_size[1]
N, C, T, V = x.size()
A = self.A + self.PA
y = None
for i in range(self.num_subset):
x_in = rearrange(x, 'N C T V -> N T V C')
x_in = self.linears[i](x_in)
A0 = rearrange(x_in, 'N T V C -> N (C T) V')
A1 = self.conv_a[i](x).permute(0, 3, 1, 2).contiguous().view(
N, V, self.inter_c * T)
A2 = self.conv_b[i](x).view(N, self.inter_c * T, V)
A1 = self.soft(torch.matmul(A1, A2) / A1.size(-1))
A1 = A1 + A[i]
z = self.conv_d[i](torch.matmul(A0, A1).view(N, C, T, V))
y = z + y if y is not None else z
y = self.bn(y)
y += self.down(x)
return self.relu(y)
class HightOrderAttentionBlock(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
A,
di_graph,
attention=False,
stride=1,
adj_len=17,
dropout=0,
residual=True,
norm_layer=nn.BatchNorm2d,
edge_importance=False,
graph=None,
conditional=False,
experts=4,
bias=True,
share_tcn=False,
max_hop=2):
super().__init__()
t_kernel_size = kernel_size[0]
assert t_kernel_size % 2 == 1
padding = ((t_kernel_size - 1) // 2, 0)
self.max_hop = max_hop
self.attention = attention
self.di_graph = di_graph
self.foa_block = FirstOrderAttention(
in_channels,
out_channels,
kernel_size,
A,
bias=bias,
adj_len=adj_len)
self.tcn_v = nn.Sequential(
norm_layer(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(
out_channels,
out_channels, (t_kernel_size, 1), (stride, 1),
padding,
bias=bias),
norm_layer(out_channels),
nn.Dropout(dropout, inplace=True),
)
if not residual:
self.residual_v = zero
elif (in_channels == out_channels) and (stride == 1):
self.residual_v = iden
else:
self.residual_v = nn.Sequential(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=(stride, 1),
bias=bias),
norm_layer(out_channels),
)
self.relu = nn.ReLU(inplace=True)
if self.attention:
self.cross_attn = Attention(
dim=out_channels,
num_heads=8,
qkv_bias=True,
qk_scale=None,
attn_drop=dropout,
proj_drop=dropout)
self.norm_v = nn.LayerNorm(out_channels)
self.mlp = Mlp(
in_features=out_channels,
out_features=out_channels,
hidden_features=out_channels * 2,
act_layer=nn.GELU,
drop=dropout)
self.norm_mlp = nn.LayerNorm(out_channels)
# linear to change fep channels
self.linears = nn.ModuleList()
for hop_i in range(self.max_hop - 1):
hop_linear = nn.ModuleList()
for i in range(
len(
eval(f'self.di_graph.directed_edges_hop{hop_i+2}'))
):
hop_linear.append(nn.Linear(hop_i + 2, 1))
self.linears.append(hop_linear)
def forward(self, fv, fe):
# `fv` (node features) has shape (B, C, T, V_node)
# `fe` (edge features) has shape (B, C, T, V_edge)
N, C, T, V = fv.size()
res_v = self.residual_v(fv)
fvp = self.foa_block(fv)
fep_out = (
fvp[..., [c for p, c in self.di_graph.directed_edges_hop1]]
- fvp[..., [p for p, c in self.di_graph.directed_edges_hop1]]
).contiguous()
if self.attention:
fep_concat = None
for hop_i in range(self.max_hop):
if 0 == hop_i:
fep_hop_i = (fvp[..., [
c for p, c in eval(
f'self.di_graph.directed_edges_hop{hop_i+1}')
]] - fvp[..., [
p for p, c in eval(
f'self.di_graph.directed_edges_hop{hop_i+1}')
]]).contiguous()
fep_hop_i = rearrange(fep_hop_i, 'N C T E -> (N T) E C')
else:
joints_parts = eval(
f'self.di_graph.directed_edges_hop{hop_i+1}')
fep_hop_i = None
for part_idx, part in enumerate(joints_parts):
fep_part = None
for j in range(len(part) - 1):
fep = (fvp[..., part[j + 1]]
- fvp[..., part[j]]).contiguous().unsqueeze(
dim=-1)
if fep_part is None:
fep_part = fep
else:
fep_part = torch.cat((fep_part, fep), dim=-1)
fep_part = self.linears[hop_i - 1][part_idx](fep_part)
if fep_hop_i is None:
fep_hop_i = fep_part
else:
fep_hop_i = torch.cat((fep_hop_i, fep_part),
dim=-1)
fep_hop_i = rearrange(fep_hop_i, 'N C T E -> (N T) E C')
if fep_concat is None:
fep_concat = fep_hop_i
else:
fep_concat = torch.cat((fep_concat, fep_hop_i),
dim=-2) # dim=-2 represent edge dim
fvp = rearrange(fvp, 'N C T V -> (N T) V C')
fvp = self.norm_v(self.cross_attn(fvp, fep_concat)) + iden(fvp)
fvp = self.mlp(self.norm_mlp(fvp)) + iden(
fvp) # make output joint number = adj_len
fvp = rearrange(fvp, '(N T) V C -> N C T V', N=N)
fvp = self.tcn_v(fvp) + res_v
return self.relu(fvp), fep_out
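A small shape sanity check for the Attention module above (queries come from node tokens, keys/values from edge tokens), assuming this PR is installed:

import torch
from modelscope.models.cv.body_3d_keypoints.hdformer.block import Attention

attn = Attention(dim=64, num_heads=8, qkv_bias=True)
fv = torch.randn(2, 17, 64)  # (B, N_nodes, C) node tokens
fe = torch.randn(2, 16, 64)  # (B, N_edges, C) edge tokens
out = attn(fv, fe)           # attention weights have shape (B, heads, 17, 16)
print(out.shape)             # torch.Size([2, 17, 64])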

View File

@@ -0,0 +1,209 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import sys
from typing import List, Tuple
import numpy as np
sys.path.insert(0, './')
def edge2mat(link, num_node):
"""According to the directed edge link, the adjacency matrix is constructed.
link: [V, 2], each row is a tuple(start node, end node).
"""
A = np.zeros((num_node, num_node))
for i, j in link:
A[j, i] = 1
return A
def normalize_incidence_matrix(im: np.ndarray) -> np.ndarray:
Dl = im.sum(-1)
num_node = im.shape[0]
Dn = np.zeros((num_node, num_node))
for i in range(num_node):
if Dl[i] > 0:
Dn[i, i] = Dl[i]**(-1)
res = Dn @ im
return res
def build_digraph_incidence_matrix(num_nodes: int,
edges: List[Tuple]) -> np.ndarray:
source_graph = np.zeros((num_nodes, len(edges)), dtype='float32')
target_graph = np.zeros((num_nodes, len(edges)), dtype='float32')
for edge_id, (source_node, target_node) in enumerate(edges):
source_graph[source_node, edge_id] = 1.
target_graph[target_node, edge_id] = 1.
source_graph = normalize_incidence_matrix(source_graph)
target_graph = normalize_incidence_matrix(target_graph)
return source_graph, target_graph
class DiGraph():
def __init__(self, skeleton):
super().__init__()
self.num_nodes = len(skeleton.parents())
self.directed_edges_hop1 = [
(parent, child)
for child, parent in enumerate(skeleton.parents()) if parent >= 0
]
self.directed_edges_hop2 = [(0, 1, 2), (0, 4, 5), (0, 7, 8), (1, 2, 3),
(4, 5, 6), (7, 8, 9),
(7, 8, 11), (7, 8, 14), (8, 9, 10),
(8, 11, 12), (8, 14, 15), (11, 12, 13),
(14, 15, 16)] # (parent, child)
self.directed_edges_hop3 = [(0, 1, 2, 3), (0, 4, 5, 6), (0, 7, 8, 9),
(7, 8, 9, 10), (7, 8, 11, 12),
(7, 8, 14, 15), (8, 11, 12, 13),
(8, 14, 15, 16)]
self.directed_edges_hop4 = [(0, 7, 8, 9, 10), (0, 7, 8, 11, 12),
(0, 7, 8, 14, 15), (7, 8, 11, 12, 13),
(7, 8, 14, 15, 16)]
self.num_edges = len(self.directed_edges_hop1)
self.edge_left = [0, 1, 2, 10, 11, 12]
self.edge_right = [3, 4, 5, 13, 14, 15]
self.edge_middle = [6, 7, 8, 9]
self.center = 0 # for h36m data skeleton
# Incidence matrices
self.source_M, self.target_M = \
build_digraph_incidence_matrix(self.num_nodes, self.directed_edges_hop1)
class Graph():
""" The Graph to model the skeletons extracted by the openpose
Args:
strategy (string): must be one of the follow candidates
- uniform: Uniform Labeling
- distance: Distance Partitioning
- spatial: Spatial Configuration
- agcn: AGCN Configuration
For more information, please refer to the section 'Partition Strategies'
in our paper (https://arxiv.org/abs/1801.07455).
layout (string): must be one of the follow candidates
- openpose: Is consists of 18 joints. For more information, please
refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output
- ntu-rgb+d: Is consists of 25 joints. For more information, please
refer to https://github.com/shahroudy/NTURGB-D
max_hop (int): the maximal distance between two connected nodes
dilation (int): controls the spacing between the kernel points
"""
def __init__(self,
skeleton=None,
strategy='uniform',
max_hop=1,
dilation=1):
self.max_hop = max_hop
self.dilation = dilation
assert strategy in ['uniform', 'distance', 'spatial', 'agcn']
self.get_edge(skeleton)
self.hop_dis = get_hop_distance(
self.num_node, self.edge, max_hop=max_hop)
self.get_adjacency(strategy)
def __str__(self):
return self.A
def get_edge(self, skeleton):
# edge is a list of [child, parent] pairs
self.num_node = len(skeleton.parents())
self_link = [(i, i) for i in range(self.num_node)]
neighbor_link = [(child, parent)
for child, parent in enumerate(skeleton.parents())]
self.self_link = self_link
self.neighbor_link = neighbor_link
self.edge = self_link + neighbor_link
self.center = 0 # for h36m data skeleton, root node idx
def get_adjacency(self, strategy):
valid_hop = range(0, self.max_hop + 1, self.dilation)
adjacency = np.zeros((self.num_node, self.num_node))
for hop in valid_hop:
adjacency[self.hop_dis == hop] = 1
normalize_adjacency = normalize_digraph(adjacency)
if strategy == 'uniform':
A = np.zeros((1, self.num_node, self.num_node))
A[0] = normalize_adjacency
self.A = A
elif strategy == 'distance':
A = np.zeros((len(valid_hop), self.num_node, self.num_node))
for i, hop in enumerate(valid_hop):
A[i][self.hop_dis == hop] = \
normalize_adjacency[self.hop_dis == hop]
self.A = A
elif strategy == 'spatial':
A = []
for hop in valid_hop:
a_root = np.zeros((self.num_node, self.num_node))
a_close = np.zeros((self.num_node, self.num_node))
a_further = np.zeros((self.num_node, self.num_node))
for i in range(self.num_node):
for j in range(self.num_node):
if self.hop_dis[j, i] == hop:
if self.hop_dis[j, self.center] == self.hop_dis[
i, self.center]:
a_root[j, i] = normalize_adjacency[j, i]
elif self.hop_dis[j, self.center] > self.hop_dis[
i, self.center]:
a_close[j, i] = normalize_adjacency[j, i]
else:
a_further[j, i] = normalize_adjacency[j, i]
if hop == 0:
A.append(a_root)
else:
A.append(a_root + a_close)
A.append(a_further)
A = np.stack(A)
self.A = A
elif strategy == 'agcn':
A = []
link_mat = edge2mat(self.self_link, self.num_node)
In = normalize_digraph(edge2mat(self.neighbor_link, self.num_node))
outward = [(j, i) for (i, j) in self.neighbor_link]
Out = normalize_digraph(edge2mat(outward, self.num_node))
A = np.stack((link_mat, In, Out))
self.A = A
else:
raise ValueError('This strategy does not exist')
def get_hop_distance(num_node, edge, max_hop=1):
A = np.zeros((num_node, num_node))
for i, j in edge:
A[j, i] = 1
A[i, j] = 1
# compute hop steps
hop_dis = np.zeros((num_node, num_node)) + np.inf
transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
arrive_mat = (np.stack(transfer_mat) > 0)
for d in range(max_hop, -1, -1):
hop_dis[arrive_mat[d]] = d
return hop_dis
def normalize_digraph(A):
Dl = np.sum(A, 0)
num_node = A.shape[0]
Dn = np.zeros((num_node, num_node))
for i in range(num_node):
if Dl[i] > 0:
Dn[i, i] = Dl[i]**(-1)
AD = np.dot(A, Dn)
return AD
def normalize_undigraph(A):
Dl = np.sum(A, 0)
num_node = A.shape[0]
Dn = np.zeros((num_node, num_node))
for i in range(num_node):
if Dl[i] > 0:
Dn[i, i] = Dl[i]**(-0.5)
DAD = np.dot(np.dot(Dn, A), Dn)
return DAD
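To make the adjacency construction concrete, here is a tiny demo of edge2mat and normalize_digraph (the building blocks of the 'agcn' strategy), assuming the functions above are importable:

from modelscope.models.cv.body_3d_keypoints.hdformer.directed_graph import (
    edge2mat, normalize_digraph)

# 3-node star: joint 0 is the parent of joints 1 and 2.
outward = [(0, 1), (0, 2)]       # (parent, child) pairs
A_out = edge2mat(outward, 3)     # sets A[child, parent] = 1
print(A_out)
# [[0. 0. 0.]
#  [1. 0. 0.]
#  [1. 0. 0.]]
print(normalize_digraph(A_out))  # each column divided by its out-degree
# [[0.  0.  0. ]
#  [0.5 0.  0. ]
#  [0.5 0.  0. ]]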

View File

@@ -0,0 +1,64 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
from modelscope.models.cv.body_3d_keypoints.hdformer.backbone import \
HDFormerNet
class HDFormer(nn.Module):
def __init__(self, cfg):
super(HDFormer, self).__init__()
self.regress_with_edge = hasattr(
cfg, 'regress_with_edge') and cfg.regress_with_edge
self.backbone = HDFormerNet(cfg)
num_v, num_e = self.backbone.di_graph.source_M.shape
self.regressor_type = cfg.regressor_type if hasattr(
cfg, 'regressor_type') else 'conv'
if self.regressor_type == 'conv':
self.joint_regressor = nn.Conv2d(
self.backbone.PLANES[0],
3 * (num_v - 1),
kernel_size=(3, num_v + num_e) if self.regress_with_edge else
(3, num_v),
padding=(1, 0),
bias=True)
elif self.regressor_type == 'fc':
self.joint_regressor = nn.Conv1d(
self.backbone.PLANES[0] * (num_v + num_e)
if self.regress_with_edge else self.backbone.PLANES[0] * num_v,
3 * (num_v - 1),
kernel_size=3,
padding=1,
bias=True)
else:
raise NotImplementedError
def forward(self, x_v: torch.Tensor, mean_3d: torch.Tensor,
std_3d: torch.Tensor):
"""
x: shape [B,C,T,V_v]
"""
fv, fe = self.backbone(x_v)
B, C, T, V = fv.shape
if self.regressor_type == 'conv':
pre_joints = self.joint_regressor(torch.cat([
fv, fe
], dim=-1)) if self.regress_with_edge else self.joint_regressor(fv)
elif self.regressor_type == 'fc':
x = (torch.cat([fv, fe], dim=-1) if self.regress_with_edge else fv) \
.permute(0, 1, 3, 2).contiguous().view(B, -1, T)
pre_joints = self.joint_regressor(x)
else:
raise NotImplementedError
pre_joints = pre_joints.view(B, 3, V - 1,
T).permute(0, 1, 3,
2).contiguous() # [B,3,T,V-1]
root_node = torch.zeros((B, 3, T, 1),
dtype=pre_joints.dtype,
device=pre_joints.device)
pre_joints = torch.cat((root_node, pre_joints), dim=-1)
pre_joints = pre_joints * std_3d + mean_3d
return pre_joints
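Note: the regressor predicts only 3 * (V - 1) values because the root joint is pinned at the origin and re-attached before de-normalization. A toy sketch of that final step with placeholder statistics (not the real mean/std files):

import torch

B, T, V = 2, 9, 17
pre_joints = torch.randn(B, 3, T, V - 1)    # regressor output: 16 non-root joints
root = torch.zeros(B, 3, T, 1)              # root joint fixed at the origin
pre_joints = torch.cat((root, pre_joints), dim=-1)
mean_3d = torch.randn(1, 3, 1, V)           # placeholder per-joint statistics
std_3d = torch.rand(1, 3, 1, V) + 0.5
pre_joints = pre_joints * std_3d + mean_3d  # de-normalize, broadcast over B and T
print(pre_joints.shape)                     # torch.Size([2, 3, 9, 17])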

View File

@@ -0,0 +1,196 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict, List, Union
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.body_3d_keypoints.hdformer.hdformer import HDFormer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
class KeypointsTypes(object):
POSES_CAMERA = 'poses_camera'
logger = get_logger()
@MODELS.register_module(
Tasks.body_3d_keypoints, module_name=Models.body_3d_keypoints_hdformer)
class HDFormerDetector(TorchModel):
def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
self.model_dir = model_dir
cudnn.benchmark = True
self.model_path = osp.join(self.model_dir, ModelFile.TORCH_MODEL_FILE)
self.mean_std_2d = np.load(
osp.join(self.model_dir, 'mean_std_2d.npy'), allow_pickle=True)
self.mean_std_3d = np.load(
osp.join(self.model_dir, 'mean_std_3d.npy'), allow_pickle=True)
self.left_right_symmetry_2d = np.array(
[0, 4, 5, 6, 1, 2, 3, 7, 8, 9, 10, 14, 15, 16, 11, 12, 13])
cfg_path = osp.join(self.model_dir, ModelFile.CONFIGURATION)
self.cfg = Config.from_file(cfg_path)
if torch.cuda.is_available():
self.device = torch.device('cuda')
else:
self.device = torch.device('cpu')
self.net = HDFormer(self.cfg.model.MODEL)
self.load_model()
self.net = self.net.to(self.device)
def load_model(self, load_to_cpu=False):
pretrained_dict = torch.load(
self.model_path,
map_location=torch.device('cuda')
if torch.cuda.is_available() else torch.device('cpu'))
self.net.load_state_dict(pretrained_dict['state_dict'], strict=False)
self.net.eval()
def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Proprocess of 2D input joints.
Args:
input (Dict[str, Any]): [NUM_FRAME, NUM_JOINTS, 2], input 2d human body keypoints.
Returns:
Dict[str, Any]: canonical 2d points and root relative joints.
"""
if 'cuda' == input.device.type:
input = input.data.cpu().numpy()
elif 'cpu' == input.device.type:
input = input.data.numpy()
pose2d = input
num_frames, num_joints, in_channels = pose2d.shape
logger.info(f'2d pose frame number: {num_frames}')
# [NUM_FRAME, NUM_JOINTS, 2]
c = np.array(self.cfg.model.INPUT.center)
f = np.array(self.cfg.model.INPUT.focal_length)
self.window_size = self.cfg.model.INPUT.window_size
receptive_field = self.cfg.model.INPUT.n_frames
# split the 2D pose sequences into fixed length frames
inputs_2d = []
inputs_2d_flip = []
n = 0
indices = []
while n + receptive_field <= num_frames:
indices.append((n, n + receptive_field))
n += self.window_size
self.valid_length = n - self.window_size + receptive_field
if 0 == len(indices):
logger.warn(
f'Failed to construct test sequences, total_frames = {num_frames}, \
while receptive_field = {receptive_field}')
self.mean_2d = self.mean_std_2d[0]
self.std_2d = self.mean_std_2d[1]
for (start, end) in indices:
data_2d = pose2d[start:end]
data_2d = (data_2d - 0.5 - c) / f
data_2d_flip = data_2d.copy()
data_2d_flip[:, :, 0] *= -1
data_2d_flip = data_2d_flip[:, self.left_right_symmetry_2d, :]
data_2d_flip = (data_2d_flip - self.mean_2d) / self.std_2d
data_2d = (data_2d - self.mean_2d) / self.std_2d
data_2d = torch.from_numpy(data_2d.transpose(
(2, 0, 1))).float() # [C,T,V]
data_2d_flip = torch.from_numpy(data_2d_flip.transpose(
(2, 0, 1))).float() # [C,T,V]
inputs_2d.append(data_2d)
inputs_2d_flip.append(data_2d_flip)
self.mean_3d = self.mean_std_3d[0]
self.std_3d = self.mean_std_3d[1]
mean_3d = torch.from_numpy(self.mean_3d).float().unsqueeze(-1)
mean_3d = mean_3d.permute(1, 2, 0) # [3, 1, 17]
std_3d = torch.from_numpy(self.std_3d).float().unsqueeze(-1)
std_3d = std_3d.permute(1, 2, 0)
return {
'inputs_2d': inputs_2d,
'inputs_2d_flip': inputs_2d_flip,
'mean_3d': mean_3d,
'std_3d': std_3d
}
def avg_flip(self, pre, pre_flip):
left_right_symmetry = [
0, 4, 5, 6, 1, 2, 3, 7, 8, 9, 10, 14, 15, 16, 11, 12, 13
]
pre_flip[:, 0, :, :] *= -1
pre_flip = pre_flip[:, :, :, left_right_symmetry]
pred_avg = (pre + pre_flip) / 2.
return pred_avg
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""3D human pose estimation.
Args:
input (Dict):
inputs_2d: [1, NUM_FRAME, NUM_JOINTS, 2]
Returns:
Dict[str, Any]:
"camera_pose": Tensor, [1, NUM_FRAME, OUT_NUM_JOINTS, OUT_3D_FEATURE_DIM],
3D human pose keypoints in camera frame.
"success": 3D pose estimation success or failed.
"""
inputs_2d = input['inputs_2d']
inputs_2d_flip = input['inputs_2d_flip']
mean_3d = input['mean_3d']
std_3d = input['std_3d']
preds_3d = None
vertex_pre = None
if [] == inputs_2d:
predict_dict = {'success': False, KeypointsTypes.POSES_CAMERA: []}
return predict_dict
with torch.no_grad():
for i, pose_2d in enumerate(inputs_2d):
pose_2d = pose_2d.unsqueeze(0).cuda(non_blocking=True) \
if torch.cuda.is_available() else pose_2d.unsqueeze(0)
pose_2d_flip = inputs_2d_flip[i]
pose_2d_flip = pose_2d_flip.unsqueeze(0).cuda(non_blocking=True) \
if torch.cuda.is_available() else pose_2d_flip.unsqueeze(0)
mean_3d = mean_3d.unsqueeze(0).cuda(non_blocking=True) \
if torch.cuda.is_available() else mean_3d.unsqueeze(0)
std_3d = std_3d.unsqueeze(0).cuda(non_blocking=True) \
if torch.cuda.is_available() else std_3d.unsqueeze(0)
vertex_pre = self.net(pose_2d, mean_3d, std_3d)
vertex_pre_flip = self.net(pose_2d_flip, mean_3d, std_3d)
vertex_pre = self.avg_flip(vertex_pre, vertex_pre_flip)
# concat the prediction results for each window_size
predict_3d = vertex_pre.permute(
0, 2, 3, 1).contiguous()[0][:self.window_size]
if preds_3d is None:
preds_3d = predict_3d
else:
preds_3d = torch.concat((preds_3d, predict_3d), dim=0)
remain_pose_results = vertex_pre.permute(
0, 2, 3, 1).contiguous()[0][self.window_size:]
preds_3d = torch.concat((preds_3d, remain_pose_results), dim=0)
preds_3d = preds_3d.unsqueeze(0) # add batch dim
preds_3d = preds_3d / self.cfg.model.INPUT.res_w # Normalize to [-1, 1]
predict_dict = {'success': True, KeypointsTypes.POSES_CAMERA: preds_3d}
return predict_dict
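Note for reviewers: preprocess() slices the 2D sequence into overlapping windows of receptive_field frames advanced by window_size. A standalone sketch of that index arithmetic; the concrete numbers below are placeholders, the real values are read from the model configuration:

num_frames, receptive_field, window_size = 300, 96, 81  # placeholder values
indices, n = [], 0
while n + receptive_field <= num_frames:
    indices.append((n, n + receptive_field))
    n += window_size
valid_length = n - window_size + receptive_field
print(indices)       # [(0, 96), (81, 177), (162, 258)]
print(valid_length)  # 258 frames are covered by complete windows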

View File

@@ -0,0 +1,103 @@
# Copyright (c) 2018-present, Facebook, Inc. All rights reserved.
import numpy as np
class Skeleton:
def __init__(self, parents, joints_left, joints_right):
assert len(joints_left) == len(joints_right)
self._parents = np.array(parents)
self._joints_left = joints_left
self._joints_right = joints_right
self._compute_metadata()
def num_joints(self):
return len(self._parents)
def parents(self):
return self._parents
def has_children(self):
return self._has_children
def children(self):
return self._children
def remove_joints(self, joints_to_remove):
"""
Remove the joints specified in 'joints_to_remove'.
"""
valid_joints = []
for joint in range(len(self._parents)):
if joint not in joints_to_remove:
valid_joints.append(joint)
for i in range(len(self._parents)):
while self._parents[i] in joints_to_remove:
self._parents[i] = self._parents[self._parents[i]]
index_offsets = np.zeros(len(self._parents), dtype=int)
new_parents = []
for i, parent in enumerate(self._parents):
if i not in joints_to_remove:
new_parents.append(parent - index_offsets[parent])
else:
index_offsets[i:] += 1
self._parents = np.array(new_parents)
if self._joints_left is not None:
new_joints_left = []
for joint in self._joints_left:
if joint in valid_joints:
new_joints_left.append(joint - index_offsets[joint])
self._joints_left = new_joints_left
if self._joints_right is not None:
new_joints_right = []
for joint in self._joints_right:
if joint in valid_joints:
new_joints_right.append(joint - index_offsets[joint])
self._joints_right = new_joints_right
self._compute_metadata()
return valid_joints
def joints_left(self):
return self._joints_left
def joints_right(self):
return self._joints_right
def _compute_metadata(self):
self._has_children = np.zeros(len(self._parents)).astype(bool)
for i, parent in enumerate(self._parents):
if parent != -1:
self._has_children[parent] = True
self._children = []
for i, parent in enumerate(self._parents):
self._children.append([])
for i, parent in enumerate(self._parents):
if parent != -1:
self._children[parent].append(i)
def get_skeleton():
skeleton = Skeleton(
parents=[
-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12, 16, 17,
18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30
],
joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23],
joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31])
# Bring the skeleton to 17 joints instead of the original 32
skeleton.remove_joints(
[4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31])
# Rewire shoulders to the correct parents
skeleton._parents[11] = 8
skeleton._parents[14] = 8
# Fix children error
skeleton._children[7] = [8]
skeleton._children[8] = [9, 11, 14]
return skeleton
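A quick check of the reduced 17-joint skeleton produced by get_skeleton(), assuming the module above is importable; the printed values follow directly from the code:

from modelscope.models.cv.body_3d_keypoints.hdformer.skeleton import get_skeleton

skeleton = get_skeleton()
print(skeleton.num_joints())       # 17
print(skeleton.parents()[11],
      skeleton.parents()[14])      # 8 8  (shoulders re-wired as above)
print(skeleton.children()[8])      # [9, 11, 14]  (as fixed manually above)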

View File

@@ -16,8 +16,8 @@ from matplotlib.animation import writers
from matplotlib.ticker import MultipleLocator
from modelscope.metainfo import Pipelines
from modelscope.models.cv.body_3d_keypoints.body_3d_pose import (
BodyKeypointsDetection3D, KeypointsTypes)
from modelscope.models.cv.body_3d_keypoints.cannonical_pose.body_3d_pose import \
KeypointsTypes
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Model, Pipeline, Tensor
@@ -112,11 +112,19 @@ def convert_2_h36m_data(lst_kps, lst_bboxes, joints_nbr=15):
Tasks.body_3d_keypoints, module_name=Pipelines.body_3d_keypoints)
class Body3DKeypointsPipeline(Pipeline):
def __init__(self, model: Union[str, BodyKeypointsDetection3D], **kwargs):
def __init__(self, model: str, **kwargs):
"""Human body 3D pose estimation.
Args:
model (Union[str, BodyKeypointsDetection3D]): model id on modelscope hub.
model (str): model id on modelscope hub.
kwargs (dict, `optional`): Extra kwargs passed into the preprocessor's constructor.
Example:
>>> from modelscope.pipelines import pipeline
>>> body_3d_keypoints = pipeline(Tasks.body_3d_keypoints,
model='damo/cv_hdformer_body-3d-keypoints_video')
>>> test_video_url = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/Walking.54138969.mp4'
>>> output = body_3d_keypoints(test_video_url)
>>> print(output)
"""
super().__init__(model=model, **kwargs)
@@ -130,6 +138,10 @@ class Body3DKeypointsPipeline(Pipeline):
model=self.human_body_2d_kps_det_pipeline,
device='gpu' if torch.cuda.is_available() else 'cpu')
self.max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME \
if hasattr(self.keypoint_model_3d.cfg.model.INPUT, 'MAX_FRAME') \
else self.keypoint_model_3d.cfg.model.INPUT.max_frame  # max number of video frames for which 3D joints are predicted
def preprocess(self, input: Input) -> Dict[str, Any]:
self.video_url = input
video_frames = self.read_video_frames(self.video_url)
@@ -139,7 +151,6 @@ class Body3DKeypointsPipeline(Pipeline):
all_2d_poses = []
all_boxes_with_socre = []
max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME # max video frame number to be predicted 3D joints
for i, frame in enumerate(video_frames):
kps_2d = self.human_body_2d_kps_detector(frame)
if [] == kps_2d.get('boxes'):
@@ -157,7 +168,7 @@ class Body3DKeypointsPipeline(Pipeline):
all_boxes_with_socre.append(
list(np.array(box).reshape(
(-1))) + [score]) # construct to list with shape [5]
if (i + 1) >= max_frame:
if (i + 1) >= self.max_frame:
break
all_2d_poses_np = np.array(all_2d_poses).reshape(
@@ -166,10 +177,11 @@ class Body3DKeypointsPipeline(Pipeline):
all_boxes_np = np.array(all_boxes_with_socre).reshape(
(len(all_boxes_with_socre), 5)) # [x1, y1, x2, y2, score]
joint_num = self.keypoint_model_3d.cfg.model.MODEL.IN_NUM_JOINTS \
if hasattr(self.keypoint_model_3d.cfg.model.MODEL, 'IN_NUM_JOINTS') \
else self.keypoint_model_3d.cfg.model.MODEL.n_joints
kps_2d_h36m_17 = convert_2_h36m_data(
all_2d_poses_np,
all_boxes_np,
joints_nbr=self.keypoint_model_3d.cfg.model.MODEL.IN_NUM_JOINTS)
all_2d_poses_np, all_boxes_np, joints_nbr=joint_num)
kps_2d_h36m_17 = np.array(kps_2d_h36m_17)
res = {'success': True, 'input_2d_pts': kps_2d_h36m_17}
return res
@@ -246,7 +258,6 @@ class Body3DKeypointsPipeline(Pipeline):
raise Exception('modelscope error: %s cannot get video fps info.' %
(video_url))
max_frame_num = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME
frame_idx = 0
while True:
ret, frame = cap.read()
@@ -256,7 +267,7 @@ class Body3DKeypointsPipeline(Pipeline):
timestamp_format(seconds=frame_idx / self.fps))
frame_idx += 1
frames.append(frame)
if frame_idx >= max_frame_num:
if frame_idx >= self.max_frame:
break
cap.release()
return frames
@@ -278,7 +289,8 @@ class Body3DKeypointsPipeline(Pipeline):
[12, 13], [9, 10]] # connection between joints
fig = plt.figure()
ax = p3.Axes3D(fig)
ax = p3.Axes3D(fig, auto_add_to_figure=False)
fig.add_axes(ax)
x_major_locator = MultipleLocator(0.5)
ax.xaxis.set_major_locator(x_major_locator)

View File

@@ -0,0 +1,50 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import cv2
import numpy as np
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class Body3DKeypointsHDFormerTest(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self) -> None:
self.model_id = 'damo/cv_hdformer_body-3d-keypoints_video'
self.test_video = 'data/test/videos/Walking.54138969.mp4'
self.task = Tasks.body_3d_keypoints
def pipeline_inference(self, pipeline: Pipeline, pipeline_input):
output = pipeline(pipeline_input, output_video='./result.mp4')
poses = np.array(output[OutputKeys.KEYPOINTS])
print(f'result 3d points shape {poses.shape}')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub_with_video_file(self):
body_3d_keypoints = pipeline(
Tasks.body_3d_keypoints, model=self.model_id)
pipeline_input = self.test_video
self.pipeline_inference(
body_3d_keypoints, pipeline_input=pipeline_input)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub_with_video_stream(self):
body_3d_keypoints = pipeline(Tasks.body_3d_keypoints)
cap = cv2.VideoCapture(self.test_video)
if not cap.isOpened():
raise Exception('modelscope error: %s cannot be decoded by OpenCV.'
% (self.test_video))
self.pipeline_inference(body_3d_keypoints, pipeline_input=cap)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_demo_compatibility(self):
self.compatibility_check()
if __name__ == '__main__':
unittest.main()