Mirror of https://github.com/modelscope/modelscope.git
Add HDFormer for 3D human body pose estimation

Add the HDFormer 3D human body pose estimation model. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11576328

* add hdformer
* rm redundant pipeline
* add license
* add test case
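For context, the new model is exposed through the standard ModelScope pipeline interface. A minimal usage sketch, assembled from the docstring example and the test case in this commit (model id and test video URL are taken verbatim from them):

import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

body_3d_keypoints = pipeline(
    Tasks.body_3d_keypoints, model='damo/cv_hdformer_body-3d-keypoints_video')
test_video_url = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/Walking.54138969.mp4'
output = body_3d_keypoints(test_video_url)
# output[OutputKeys.KEYPOINTS] holds the per-frame 3D joints in the camera frame
print(np.array(output[OutputKeys.KEYPOINTS]).shape)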
modelscope/metainfo.py

@@ -34,6 +34,7 @@ class Models(object):
     product_retrieval_embedding = 'product-retrieval-embedding'
     body_2d_keypoints = 'body-2d-keypoints'
     body_3d_keypoints = 'body-3d-keypoints'
+    body_3d_keypoints_hdformer = 'hdformer'
     crowd_counting = 'HRNetCrowdCounting'
     face_2d_keypoints = 'face-2d-keypoints'
     panoptic_segmentation = 'swinL-panoptic-segmentation'
modelscope/models/cv/body_3d_keypoints/__init__.py

@@ -4,12 +4,12 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
-    from .body_3d_pose import BodyKeypointsDetection3D
+    from .cannonical_pose import BodyKeypointsDetection3D
+    from .hdformer import HDFormerDetector
 else:
     _import_structure = {
-        'body_3d_pose': ['BodyKeypointsDetection3D'],
+        'cannonical_pose': ['BodyKeypointsDetection3D'],
+        'hdformer': ['HDFormerDetector'],
     }

     import sys
modelscope/models/cv/body_3d_keypoints/cannonical_pose/__init__.py (new file)

@@ -0,0 +1 @@
+from .body_3d_pose import BodyKeypointsDetection3D
modelscope/models/cv/body_3d_keypoints/cannonical_pose/body_3d_pose.py

@@ -10,7 +10,7 @@ import torch
 from modelscope.metainfo import Models
 from modelscope.models.base.base_torch_model import TorchModel
 from modelscope.models.builder import MODELS
-from modelscope.models.cv.body_3d_keypoints.canonical_pose_modules import (
+from modelscope.models.cv.body_3d_keypoints.cannonical_pose.canonical_pose_modules import (
     TemporalModel, TransCan3Dkeys)
 from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile, Tasks
modelscope/models/cv/body_3d_keypoints/hdformer/__init__.py (new file)

@@ -0,0 +1,2 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .hdformer_detector import HDFormerDetector
modelscope/models/cv/body_3d_keypoints/hdformer/backbone.py (new file, 306 lines)

@@ -0,0 +1,306 @@
# --------------------------------------------------------
# The implementation is adapted from HDFormer, open-sourced by the author
# Hanyuan Chen and publicly available at
# https://github.com/hyer/HDFormer
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.models.cv.body_3d_keypoints.hdformer.block import \
    HightOrderAttentionBlock
from modelscope.models.cv.body_3d_keypoints.hdformer.directed_graph import (
    DiGraph, Graph)
from modelscope.models.cv.body_3d_keypoints.hdformer.skeleton import \
    get_skeleton


class HDFormerNet(nn.Module):

    def __init__(self, cfg):
        super(HDFormerNet, self).__init__()
        in_channels = cfg.in_channels
        dropout = cfg.dropout
        self.cfg = cfg
        self.PLANES = [16, 32, 64, 128, 256]

        # load graph
        skeleton = get_skeleton()
        self.di_graph = DiGraph(skeleton=skeleton)
        self.graph = Graph(
            skeleton=skeleton, strategy='agcn', max_hop=1, dilation=1)
        self.A = torch.tensor(
            self.graph.A,
            dtype=torch.float32,
            requires_grad=True,
            device='cuda')

        # build networks
        spatial_kernel_size = self.A.size(0)
        temporal_kernel_size = 9
        kernel_size = (temporal_kernel_size, spatial_kernel_size)

        if not cfg.data_bn:
            self.data_bn = None
        else:
            n_joints = self.cfg.IN_NUM_JOINTS \
                if hasattr(self.cfg, 'IN_NUM_JOINTS') \
                else self.cfg.n_joints
            self.data_bn = nn.BatchNorm1d(in_channels * n_joints) if hasattr(cfg, 'PJN') and cfg.PJN \
                else nn.BatchNorm2d(in_channels)

        self.downsample = nn.ModuleList(
            (
                HightOrderAttentionBlock(
                    in_channels,
                    self.PLANES[0],
                    kernel_size,
                    A=self.A,
                    di_graph=self.di_graph,
                    residual=False,
                    adj_len=self.A.size(1),
                    attention=cfg.attention_down if hasattr(
                        cfg, 'attention_down') else False,
                    dropout=0),
                HightOrderAttentionBlock(
                    self.PLANES[0],
                    self.PLANES[1],
                    kernel_size,
                    A=self.A,
                    di_graph=self.di_graph,
                    stride=2,
                    adj_len=self.A.size(1),
                    attention=cfg.attention_down if hasattr(
                        cfg, 'attention_down') else False,
                    dropout=0),
                HightOrderAttentionBlock(
                    self.PLANES[1],
                    self.PLANES[1],
                    kernel_size,
                    A=self.A,
                    di_graph=self.di_graph,
                    adj_len=self.A.size(1),
                    attention=cfg.attention_down if hasattr(
                        cfg, 'attention_down') else False,
                    dropout=0),
                HightOrderAttentionBlock(
                    self.PLANES[1],
                    self.PLANES[2],
                    kernel_size,
                    A=self.A,
                    di_graph=self.di_graph,
                    stride=2,
                    adj_len=self.A.size(1),
                    attention=cfg.attention_down if hasattr(
                        cfg, 'attention_down') else False,
                    dropout=0),
                HightOrderAttentionBlock(
                    self.PLANES[2],
                    self.PLANES[2],
                    kernel_size,
                    A=self.A,
                    di_graph=self.di_graph,
                    adj_len=self.A.size(1),
                    attention=cfg.attention_down if hasattr(
                        cfg, 'attention_down') else False,
                    dropout=0),
                HightOrderAttentionBlock(
                    self.PLANES[2],
                    self.PLANES[3],
                    kernel_size,
                    A=self.A,
                    di_graph=self.di_graph,
                    stride=2,
                    adj_len=self.A.size(1),
                    attention=cfg.attention_down if hasattr(
                        cfg, 'attention_down') else False,
                    dropout=dropout),
                HightOrderAttentionBlock(
                    self.PLANES[3],
                    self.PLANES[3],
                    kernel_size,
                    A=self.A,
                    di_graph=self.di_graph,
                    adj_len=self.A.size(1),
                    attention=cfg.attention_down if hasattr(
                        cfg, 'attention_down') else False,
                    dropout=dropout),
                HightOrderAttentionBlock(
                    self.PLANES[3],
                    self.PLANES[4],
                    kernel_size,
                    A=self.A,
                    di_graph=self.di_graph,
                    stride=2,
                    adj_len=self.A.size(1),
                    attention=cfg.attention_down if hasattr(
                        cfg, 'attention_down') else False,
                    dropout=dropout),
                HightOrderAttentionBlock(
                    self.PLANES[4],
                    self.PLANES[4],
                    kernel_size,
                    A=self.A,
                    di_graph=self.di_graph,
                    adj_len=self.A.size(1),
                    attention=cfg.attention_down if hasattr(
                        cfg, 'attention_down') else False,
                    dropout=dropout),
            ))

        self.upsample = nn.ModuleList((
            HightOrderAttentionBlock(
                self.PLANES[4],
                self.PLANES[3],
                kernel_size,
                A=self.A,
                di_graph=self.di_graph,
                attention=cfg.attention_up
                if hasattr(cfg, 'attention_up') else False,
                adj_len=self.A.size(1),
                dropout=dropout),
            HightOrderAttentionBlock(
                self.PLANES[3],
                self.PLANES[2],
                kernel_size,
                A=self.A,
                di_graph=self.di_graph,
                attention=cfg.attention_up
                if hasattr(cfg, 'attention_up') else False,
                adj_len=self.A.size(1),
                dropout=dropout),
            HightOrderAttentionBlock(
                self.PLANES[2],
                self.PLANES[1],
                kernel_size,
                A=self.A,
                di_graph=self.di_graph,
                attention=cfg.attention_up
                if hasattr(cfg, 'attention_up') else False,
                adj_len=self.A.size(1),
                dropout=0),
            HightOrderAttentionBlock(
                self.PLANES[1],
                self.PLANES[0],
                kernel_size,
                A=self.A,
                di_graph=self.di_graph,
                attention=cfg.attention_up
                if hasattr(cfg, 'attention_up') else False,
                adj_len=self.A.size(1),
                dropout=0),
        ))

        self.merge = nn.ModuleList((
            HightOrderAttentionBlock(
                self.PLANES[4],
                self.PLANES[0],
                kernel_size,
                A=self.A,
                di_graph=self.di_graph,
                attention=cfg.attention_merge if hasattr(
                    cfg, 'attention_merge') else False,
                adj_len=self.A.size(1),
                dropout=dropout,
                max_hop=self.cfg.max_hop),
            HightOrderAttentionBlock(
                self.PLANES[3],
                self.PLANES[0],
                kernel_size,
                A=self.A,
                di_graph=self.di_graph,
                attention=cfg.attention_merge if hasattr(
                    cfg, 'attention_merge') else False,
                adj_len=self.A.size(1),
                dropout=dropout,
                max_hop=self.cfg.max_hop),
            HightOrderAttentionBlock(
                self.PLANES[2],
                self.PLANES[0],
                kernel_size,
                A=self.A,
                di_graph=self.di_graph,
                attention=cfg.attention_merge if hasattr(
                    cfg, 'attention_merge') else False,
                adj_len=self.A.size(1),
                dropout=0,
                max_hop=self.cfg.max_hop),
            HightOrderAttentionBlock(
                self.PLANES[1],
                self.PLANES[0],
                kernel_size,
                A=self.A,
                di_graph=self.di_graph,
                attention=cfg.attention_merge if hasattr(
                    cfg, 'attention_merge') else False,
                adj_len=self.A.size(1),
                dropout=0,
                max_hop=self.cfg.max_hop),
        ))

    def get_edge_fea(self, x_v):
        x_e = (x_v[..., [c for p, c in self.di_graph.directed_edges_hop1]]
               - x_v[..., [p for p, c in self.di_graph.directed_edges_hop1]]
               ).contiguous()
        N, C, T, V = x_v.shape
        edeg_append = torch.zeros((N, C, T, 1), device=x_e.device)
        x_e = torch.cat((x_e, edeg_append), dim=-1)
        return x_e

    def forward(self, x_v: torch.Tensor):
        """
        x: shape [B,C,T,V_v]
        """
        B, C, T, V = x_v.shape
        # data normalization
        if self.data_bn is not None:
            if hasattr(self.cfg, 'PJN') and self.cfg.PJN:
                x_v = self.data_bn(x_v.permute(0, 1, 3, 2).contiguous().view(B, -1, T)).view(B, C, V, T) \
                    .contiguous().permute(0, 1, 3, 2)
            else:
                x_v = self.data_bn(x_v)

        x_e = self.get_edge_fea(x_v)

        # forward
        feature = []
        for idx, hoa_block in enumerate(self.downsample):
            x_v, x_e = hoa_block(x_v, x_e)
            if idx == 0 or idx == 2 or idx == 4 or idx == 6:
                feature.append((x_v, x_e))

        feature.append((x_v, x_e))
        feature = feature[::-1]

        x_v, x_e = feature[0]
        identity_feature = feature[1:]

        ushape_feature = []
        ushape_feature.append((x_v, x_e))
        for idx, (hoa_block, id) in \
                enumerate(zip(self.upsample, identity_feature)):
            x_v, x_e = hoa_block(x_v, x_e)
            if hasattr(self.cfg, 'deterministic') and self.cfg.deterministic:
                x_v = F.interpolate(x_v, scale_factor=(2, 1), mode='nearest')
            else:
                x_v = F.interpolate(
                    x_v,
                    scale_factor=(2, 1),
                    mode='bilinear',
                    align_corners=False)
            x_v += id[0]
            ushape_feature.append((x_v, x_e))

        ushape_feature = ushape_feature[:-1]
        for idx, (hoa_block, u) in \
                enumerate(zip(self.merge, ushape_feature)):
            x_v2, x_e2 = hoa_block(*u)
            if hasattr(self.cfg, 'deterministic') and self.cfg.deterministic:
                x_v += F.interpolate(
                    x_v2, scale_factor=(2**(4 - idx), 1), mode='nearest')
            else:
                x_v += F.interpolate(
                    x_v2,
                    scale_factor=(2**(4 - idx), 1),
                    mode='bilinear',
                    align_corners=False)
        return x_v, x_e
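Note on the U-shaped design above: four of the nine downsample blocks use stride 2 on the temporal axis, and the upsample and merge passes restore the resolution with F.interpolate over (T, V), scaling only T. A standalone sketch of that temporal rescaling with dummy tensors (channel and frame sizes here are illustrative, not taken from the model config):

import torch
import torch.nn.functional as F

x = torch.randn(1, 256, 6, 17)  # [B, C, T, V] after four stride-2 blocks (e.g. T: 96 -> 6)
up = F.interpolate(x, scale_factor=(2, 1), mode='bilinear', align_corners=False)
print(up.shape)  # torch.Size([1, 256, 12, 17]): T doubled, joint axis untouched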
modelscope/models/cv/body_3d_keypoints/hdformer/block.py (new file, 380 lines)

@@ -0,0 +1,380 @@
# Part of the implementation is borrowed and modified from 2s-AGCN, publicly available at
# https://github.com/lshiwjx/2s-AGCN
import math

import torch
import torch.nn as nn
from einops import rearrange


def import_class(name):
    components = name.split('.')
    mod = __import__(components[0])
    for comp in components[1:]:
        mod = getattr(mod, comp)
    return mod


def conv_branch_init(conv, branches):
    weight = conv.weight
    n = weight.size(0)
    k1 = weight.size(1)
    k2 = weight.size(2)
    nn.init.normal_(weight, 0, math.sqrt(2. / (n * k1 * k2 * branches)))
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)


def conv_init(conv):
    if conv.weight is not None:
        nn.init.kaiming_normal_(conv.weight, mode='fan_out')
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)


def bn_init(bn, scale):
    nn.init.constant_(bn.weight, scale)
    nn.init.constant_(bn.bias, 0)


def zero(x):
    """return zero."""
    return 0


def iden(x):
    """return input itself."""
    return x


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.,
                 changedim=False,
                 currentdim=0,
                 depth=0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 comb=False,
                 vis=False):
"""Attention is all you need
|
||||
|
||||
Args:
|
||||
dim (_type_): _description_
|
||||
num_heads (int, optional): _description_. Defaults to 8.
|
||||
qkv_bias (bool, optional): _description_. Defaults to False.
|
||||
qk_scale (_type_, optional): _description_. Defaults to None.
|
||||
attn_drop (_type_, optional): _description_. Defaults to 0..
|
||||
proj_drop (_type_, optional): _description_. Defaults to 0..
|
||||
comb (bool, optional): Defaults to False.
|
||||
True: q transpose * k.
|
||||
False: q * k transpose.
|
||||
vis (bool, optional): _description_. Defaults to False.
|
||||
"""
|
||||
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim**-0.5

        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

        self.proj_drop = nn.Dropout(proj_drop)
        self.comb = comb
        self.vis = vis

    def forward(self, fv, fe):
        B, N, C = fv.shape
        B, E, C = fe.shape
        q = self.to_q(fv).reshape(B, N, self.num_heads,
                                  C // self.num_heads).permute(0, 2, 1, 3)
        k = self.to_k(fe).reshape(B, E, self.num_heads,
                                  C // self.num_heads).permute(0, 2, 1, 3)
        v = self.to_v(fe).reshape(B, E, self.num_heads,
                                  C // self.num_heads).permute(0, 2, 1, 3)
        # Now fv shape (B, H, N, C//heads)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if self.comb:
            fv = (attn @ v.transpose(-2, -1)).transpose(-2, -1)
            fv = rearrange(fv, 'B H N C -> B N (H C)')
        elif self.comb is False:
            fv = (attn @ v).transpose(1, 2).reshape(B, N, C)
        fv = self.proj(fv)
        fv = self.proj_drop(fv)
        return fv


class FirstOrderAttention(nn.Module):
    """First Order Attention block for spatial relationship.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 A,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 adj_len=17,
                 bias=True):
        super().__init__()

        self.kernel_size = kernel_size
        self.A = A
        self.PA = nn.Parameter(torch.FloatTensor(3, adj_len, adj_len))
        torch.nn.init.constant_(self.PA, 1e-6)

        self.num_subset = 3
        inter_channels = out_channels // 4
        self.inter_c = inter_channels
        self.conv_a = nn.ModuleList()
        self.conv_b = nn.ModuleList()
        self.conv_d = nn.ModuleList()
        self.linears = nn.ModuleList()
        for i in range(self.num_subset):
            self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
            self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
            self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))
            self.linears.append(nn.Linear(in_channels, in_channels))

        if in_channels != out_channels:
            self.down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels))
        else:
            self.down = lambda x: x

        self.bn = nn.BatchNorm2d(out_channels)
        self.soft = nn.Softmax(-2)
        self.relu = nn.ReLU()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
        bn_init(self.bn, 1e-6)
        for i in range(self.num_subset):
            conv_branch_init(self.conv_d[i], self.num_subset)

    def forward(self, x):
        assert self.A.shape[0] == self.kernel_size[1]

        N, C, T, V = x.size()
        A = self.A + self.PA

        y = None
        for i in range(self.num_subset):
            x_in = rearrange(x, 'N C T V -> N T V C')
            x_in = self.linears[i](x_in)
            A0 = rearrange(x_in, 'N T V C -> N (C T) V')

            A1 = self.conv_a[i](x).permute(0, 3, 1, 2).contiguous().view(
                N, V, self.inter_c * T)
            A2 = self.conv_b[i](x).view(N, self.inter_c * T, V)
            A1 = self.soft(torch.matmul(A1, A2) / A1.size(-1))
            A1 = A1 + A[i]
            z = self.conv_d[i](torch.matmul(A0, A1).view(N, C, T, V))
            y = z + y if y is not None else z
        y = self.bn(y)
        y += self.down(x)

        return self.relu(y)


class HightOrderAttentionBlock(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 A,
                 di_graph,
                 attention=False,
                 stride=1,
                 adj_len=17,
                 dropout=0,
                 residual=True,
                 norm_layer=nn.BatchNorm2d,
                 edge_importance=False,
                 graph=None,
                 conditional=False,
                 experts=4,
                 bias=True,
                 share_tcn=False,
                 max_hop=2):
        super().__init__()

        t_kernel_size = kernel_size[0]
        assert t_kernel_size % 2 == 1
        padding = ((t_kernel_size - 1) // 2, 0)
        self.max_hop = max_hop
        self.attention = attention
        self.di_graph = di_graph

        self.foa_block = FirstOrderAttention(
            in_channels,
            out_channels,
            kernel_size,
            A,
            bias=bias,
            adj_len=adj_len)

        self.tcn_v = nn.Sequential(
            norm_layer(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels, (t_kernel_size, 1), (stride, 1),
                padding,
                bias=bias),
            norm_layer(out_channels),
            nn.Dropout(dropout, inplace=True),
        )

        if not residual:
            self.residual_v = zero
        elif (in_channels == out_channels) and (stride == 1):
            self.residual_v = iden
        else:
            self.residual_v = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=(stride, 1),
                    bias=bias),
                norm_layer(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)

        if self.attention:
            self.cross_attn = Attention(
                dim=out_channels,
                num_heads=8,
                qkv_bias=True,
                qk_scale=None,
                attn_drop=dropout,
                proj_drop=dropout)
            self.norm_v = nn.LayerNorm(out_channels)
            self.mlp = Mlp(
                in_features=out_channels,
                out_features=out_channels,
                hidden_features=out_channels * 2,
                act_layer=nn.GELU,
                drop=dropout)
            self.norm_mlp = nn.LayerNorm(out_channels)

            # linear to change fep channels
            self.linears = nn.ModuleList()
            for hop_i in range(self.max_hop - 1):
                hop_linear = nn.ModuleList()
                for i in range(
                        len(
                            eval(f'self.di_graph.directed_edges_hop{hop_i+2}'))
                ):
                    hop_linear.append(nn.Linear(hop_i + 2, 1))
                self.linears.append(hop_linear)

    def forward(self, fv, fe):
        # `fv` (node features) has shape (B, C, T, V_node)
        # `fe` (edge features) has shape (B, C, T, V_edge)
        N, C, T, V = fv.size()

        res_v = self.residual_v(fv)

        fvp = self.foa_block(fv)
        fep_out = (
            fvp[..., [c for p, c in self.di_graph.directed_edges_hop1]]
            - fvp[..., [p for p, c in self.di_graph.directed_edges_hop1]]
        ).contiguous()

        if self.attention:
            fep_concat = None
            for hop_i in range(self.max_hop):
                if 0 == hop_i:
                    fep_hop_i = (fvp[..., [
                        c for p, c in eval(
                            f'self.di_graph.directed_edges_hop{hop_i+1}')
                    ]] - fvp[..., [
                        p for p, c in eval(
                            f'self.di_graph.directed_edges_hop{hop_i+1}')
                    ]]).contiguous()
                    fep_hop_i = rearrange(fep_hop_i, 'N C T E -> (N T) E C')
                else:
                    joints_parts = eval(
                        f'self.di_graph.directed_edges_hop{hop_i+1}')
                    fep_hop_i = None
                    for part_idx, part in enumerate(joints_parts):
                        fep_part = None
                        for j in range(len(part) - 1):
                            fep = (fvp[..., part[j + 1]]
                                   - fvp[..., part[j]]).contiguous().unsqueeze(
                                       dim=-1)
                            if fep_part is None:
                                fep_part = fep
                            else:
                                fep_part = torch.cat((fep_part, fep), dim=-1)
                        fep_part = self.linears[hop_i - 1][part_idx](fep_part)
                        if fep_hop_i is None:
                            fep_hop_i = fep_part
                        else:
                            fep_hop_i = torch.cat((fep_hop_i, fep_part),
                                                  dim=-1)

                    fep_hop_i = rearrange(fep_hop_i, 'N C T E -> (N T) E C')

                if fep_concat is None:
                    fep_concat = fep_hop_i
                else:
                    fep_concat = torch.cat((fep_concat, fep_hop_i),
                                           dim=-2)  # dim=-2 is the edge dimension
            fvp = rearrange(fvp, 'N C T V -> (N T) V C')
            fvp = self.norm_v(self.cross_attn(fvp, fep_concat)) + iden(fvp)
            fvp = self.mlp(self.norm_mlp(fvp)) + iden(
                fvp)  # make output joint number = adj_len
            fvp = rearrange(fvp, '(N T) V C -> N C T V', N=N)

        fvp = self.tcn_v(fvp) + res_v

        return self.relu(fvp), fep_out
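In the block above, node features are the queries and the concatenated multi-hop edge features are the keys/values, so the attention output always lands back on the joints. A minimal shape check using the Attention class from this file (toy batch size; 29 edges corresponds to the 16 hop-1 plus 13 hop-2 edges of the default max_hop=2):

import torch

from modelscope.models.cv.body_3d_keypoints.hdformer.block import Attention

attn = Attention(dim=16, num_heads=8, qkv_bias=True)
fv = torch.randn(4, 17, 16)  # (B*T, V, C): one query per joint
fe = torch.randn(4, 29, 16)  # (B*T, E, C): hop-1 and hop-2 edge features concatenated
out = attn(fv, fe)
print(out.shape)  # torch.Size([4, 17, 16]): output stays on the 17 joints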
modelscope/models/cv/body_3d_keypoints/hdformer/directed_graph.py (new file, 209 lines)

@@ -0,0 +1,209 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import sys
from typing import List, Tuple

import numpy as np

sys.path.insert(0, './')
def edge2mat(link, num_node):
    """Construct an adjacency matrix from a list of directed edges.

    link: [V, 2], each row is a tuple (start node, end node).
    """
    A = np.zeros((num_node, num_node))
    for i, j in link:
        A[j, i] = 1
    return A
def normalize_incidence_matrix(im: np.ndarray) -> np.ndarray:
    Dl = im.sum(-1)
    num_node = im.shape[0]
    Dn = np.zeros((num_node, num_node))
    for i in range(num_node):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i]**(-1)
    res = Dn @ im
    return res


def build_digraph_incidence_matrix(num_nodes: int,
                                   edges: List[Tuple]) -> np.ndarray:
    source_graph = np.zeros((num_nodes, len(edges)), dtype='float32')
    target_graph = np.zeros((num_nodes, len(edges)), dtype='float32')
    for edge_id, (source_node, target_node) in enumerate(edges):
        source_graph[source_node, edge_id] = 1.
        target_graph[target_node, edge_id] = 1.
    source_graph = normalize_incidence_matrix(source_graph)
    target_graph = normalize_incidence_matrix(target_graph)
    return source_graph, target_graph
class DiGraph():

    def __init__(self, skeleton):
        super().__init__()
        self.num_nodes = len(skeleton.parents())
        self.directed_edges_hop1 = [
            (parrent, child)
            for child, parrent in enumerate(skeleton.parents()) if parrent >= 0
        ]
        self.directed_edges_hop2 = [(0, 1, 2), (0, 4, 5), (0, 7, 8), (1, 2, 3),
                                    (4, 5, 6), (7, 8, 9),
                                    (7, 8, 11), (7, 8, 14), (8, 9, 10),
                                    (8, 11, 12), (8, 14, 15), (11, 12, 13),
                                    (14, 15, 16)]  # (parent, ..., child) chains
        self.directed_edges_hop3 = [(0, 1, 2, 3), (0, 4, 5, 6), (0, 7, 8, 9),
                                    (7, 8, 9, 10), (7, 8, 11, 12),
                                    (7, 8, 14, 15), (8, 11, 12, 13),
                                    (8, 14, 15, 16)]
        self.directed_edges_hop4 = [(0, 7, 8, 9, 10), (0, 7, 8, 11, 12),
                                    (0, 7, 8, 14, 15), (7, 8, 11, 12, 13),
                                    (7, 8, 14, 15, 16)]

        self.num_edges = len(self.directed_edges_hop1)
        self.edge_left = [0, 1, 2, 10, 11, 12]
        self.edge_right = [3, 4, 5, 13, 14, 15]
        self.edge_middle = [6, 7, 8, 9]
        self.center = 0  # for h36m data skeleton
        # Incidence matrices
        self.source_M, self.target_M = \
            build_digraph_incidence_matrix(self.num_nodes, self.directed_edges_hop1)
class Graph():
    """The graph modeling the skeletons extracted by OpenPose.

    Args:
        strategy (string): must be one of the following candidates
            - uniform: Uniform Labeling
            - distance: Distance Partitioning
            - spatial: Spatial Configuration
            - agcn: AGCN Configuration
            For more information, please refer to the section 'Partition Strategies'
            in the ST-GCN paper (https://arxiv.org/abs/1801.07455).
        layout (string): must be one of the following candidates
            - openpose: consists of 18 joints. For more information, please
              refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output
            - ntu-rgb+d: consists of 25 joints. For more information, please
              refer to https://github.com/shahroudy/NTURGB-D
        max_hop (int): the maximal distance between two connected nodes
        dilation (int): controls the spacing between the kernel points
    """

    def __init__(self,
                 skeleton=None,
                 strategy='uniform',
                 max_hop=1,
                 dilation=1):
        self.max_hop = max_hop
        self.dilation = dilation

        assert strategy in ['uniform', 'distance', 'spatial', 'agcn']
        self.get_edge(skeleton)
        self.hop_dis = get_hop_distance(
            self.num_node, self.edge, max_hop=max_hop)
        self.get_adjacency(strategy)

    def __str__(self):
        return str(self.A)
    def get_edge(self, skeleton):
        # edge is a list of [child, parent] pairs
        self.num_node = len(skeleton.parents())
        self_link = [(i, i) for i in range(self.num_node)]
        neighbor_link = [(child, parrent)
                         for child, parrent in enumerate(skeleton.parents())]
        self.self_link = self_link
        self.neighbor_link = neighbor_link
        self.edge = self_link + neighbor_link
        self.center = 0  # for h36m data skeleton, root node idx
    def get_adjacency(self, strategy):
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            adjacency[self.hop_dis == hop] = 1
        normalize_adjacency = normalize_digraph(adjacency)

        if strategy == 'uniform':
            A = np.zeros((1, self.num_node, self.num_node))
            A[0] = normalize_adjacency
            self.A = A
        elif strategy == 'distance':
            A = np.zeros((len(valid_hop), self.num_node, self.num_node))
            for i, hop in enumerate(valid_hop):
                A[i][self.hop_dis == hop] = \
                    normalize_adjacency[self.hop_dis == hop]
            self.A = A
        elif strategy == 'spatial':
            A = []
            for hop in valid_hop:
                a_root = np.zeros((self.num_node, self.num_node))
                a_close = np.zeros((self.num_node, self.num_node))
                a_further = np.zeros((self.num_node, self.num_node))
                for i in range(self.num_node):
                    for j in range(self.num_node):
                        if self.hop_dis[j, i] == hop:
                            if self.hop_dis[j, self.center] == self.hop_dis[
                                    i, self.center]:
                                a_root[j, i] = normalize_adjacency[j, i]
                            elif self.hop_dis[j, self.center] > self.hop_dis[
                                    i, self.center]:
                                a_close[j, i] = normalize_adjacency[j, i]
                            else:
                                a_further[j, i] = normalize_adjacency[j, i]
                if hop == 0:
                    A.append(a_root)
                else:
                    A.append(a_root + a_close)
                    A.append(a_further)
            A = np.stack(A)
            self.A = A
        elif strategy == 'agcn':
            A = []
            link_mat = edge2mat(self.self_link, self.num_node)
            In = normalize_digraph(edge2mat(self.neighbor_link, self.num_node))
            outward = [(j, i) for (i, j) in self.neighbor_link]
            Out = normalize_digraph(edge2mat(outward, self.num_node))
            A = np.stack((link_mat, In, Out))
            self.A = A
        else:
            raise ValueError('Unknown partition strategy')
def get_hop_distance(num_node, edge, max_hop=1):
    A = np.zeros((num_node, num_node))
    for i, j in edge:
        A[j, i] = 1
        A[i, j] = 1

    # compute hop steps
    hop_dis = np.zeros((num_node, num_node)) + np.inf
    transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
    arrive_mat = (np.stack(transfer_mat) > 0)
    for d in range(max_hop, -1, -1):
        hop_dis[arrive_mat[d]] = d
    return hop_dis


def normalize_digraph(A):
    Dl = np.sum(A, 0)
    num_node = A.shape[0]
    Dn = np.zeros((num_node, num_node))
    for i in range(num_node):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i]**(-1)
    AD = np.dot(A, Dn)
    return AD


def normalize_undigraph(A):
    Dl = np.sum(A, 0)
    num_node = A.shape[0]
    Dn = np.zeros((num_node, num_node))
    for i in range(num_node):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i]**(-0.5)
    DAD = np.dot(np.dot(Dn, A), Dn)
    return DAD
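A quick sanity check of the two graph constructions above, using the 17-joint H36M skeleton defined in skeleton.py:

from modelscope.models.cv.body_3d_keypoints.hdformer.directed_graph import (
    DiGraph, Graph)
from modelscope.models.cv.body_3d_keypoints.hdformer.skeleton import \
    get_skeleton

skeleton = get_skeleton()
graph = Graph(skeleton=skeleton, strategy='agcn', max_hop=1, dilation=1)
print(graph.A.shape)  # (3, 17, 17): self-links plus inward/outward normalized adjacency
di_graph = DiGraph(skeleton=skeleton)
print(di_graph.num_nodes, di_graph.num_edges)  # 17 16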
modelscope/models/cv/body_3d_keypoints/hdformer/hdformer.py (new file, 64 lines)

@@ -0,0 +1,64 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn

from modelscope.models.cv.body_3d_keypoints.hdformer.backbone import \
    HDFormerNet


class HDFormer(nn.Module):

    def __init__(self, cfg):
        super(HDFormer, self).__init__()
        self.regress_with_edge = hasattr(
            cfg, 'regress_with_edge') and cfg.regress_with_edge
        self.backbone = HDFormerNet(cfg)
        num_v, num_e = self.backbone.di_graph.source_M.shape
        self.regressor_type = cfg.regressor_type if hasattr(
            cfg, 'regressor_type') else 'conv'
        if self.regressor_type == 'conv':
            self.joint_regressor = nn.Conv2d(
                self.backbone.PLANES[0],
                3 * (num_v - 1),
                kernel_size=(3, num_v + num_e) if self.regress_with_edge else
                (3, num_v),
                padding=(1, 0),
                bias=True)
        elif self.regressor_type == 'fc':
            self.joint_regressor = nn.Conv1d(
                self.backbone.PLANES[0] * (num_v + num_e)
                if self.regress_with_edge else self.backbone.PLANES[0] * num_v,
                3 * (num_v - 1),
                kernel_size=3,
                padding=1,
                bias=True)
        else:
            raise NotImplementedError

    def forward(self, x_v: torch.Tensor, mean_3d: torch.Tensor,
                std_3d: torch.Tensor):
        """
        x: shape [B,C,T,V_v]
        """
        fv, fe = self.backbone(x_v)
        B, C, T, V = fv.shape

        if self.regressor_type == 'conv':
            pre_joints = self.joint_regressor(torch.cat([
                fv, fe
            ], dim=-1)) if self.regress_with_edge else self.joint_regressor(fv)
        elif self.regressor_type == 'fc':
            x = (torch.cat([fv, fe], dim=-1) if self.regress_with_edge else fv) \
                .permute(0, 1, 3, 2).contiguous().view(B, -1, T)
            pre_joints = self.joint_regressor(x)
        else:
            raise NotImplementedError
        pre_joints = pre_joints.view(B, 3, V - 1,
                                     T).permute(0, 1, 3,
                                                2).contiguous()  # [B,3,T,V-1]
        root_node = torch.zeros((B, 3, T, 1),
                                dtype=pre_joints.dtype,
                                device=pre_joints.device)
        pre_joints = torch.cat((root_node, pre_joints), dim=-1)
        pre_joints = pre_joints * std_3d + mean_3d
        return pre_joints
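The regressor above predicts only the 16 non-root joints; the root is appended as a zero offset and the output is de-normalized with dataset statistics that broadcast over time. A dummy-tensor sketch of that final step (shapes are illustrative):

import torch

B, T, V = 1, 96, 17
pre_joints = torch.randn(B, 3, T, V - 1)    # regressor output, root excluded
root_node = torch.zeros(B, 3, T, 1)         # root joint sits at the origin
pre_joints = torch.cat((root_node, pre_joints), dim=-1)
mean_3d = torch.zeros(B, 3, 1, V)           # stand-ins for the real statistics
std_3d = torch.ones(B, 3, 1, V)
pre_joints = pre_joints * std_3d + mean_3d  # broadcasts over the T axis
print(pre_joints.shape)  # torch.Size([1, 3, 96, 17])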
modelscope/models/cv/body_3d_keypoints/hdformer/hdformer_detector.py (new file, 196 lines)

@@ -0,0 +1,196 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict, List, Union

import numpy as np
import torch
import torch.backends.cudnn as cudnn

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.body_3d_keypoints.hdformer.hdformer import HDFormer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger


class KeypointsTypes(object):
    POSES_CAMERA = 'poses_camera'


logger = get_logger()


@MODELS.register_module(
    Tasks.body_3d_keypoints, module_name=Models.body_3d_keypoints_hdformer)
class HDFormerDetector(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir

        cudnn.benchmark = True
        self.model_path = osp.join(self.model_dir, ModelFile.TORCH_MODEL_FILE)
        self.mean_std_2d = np.load(
            osp.join(self.model_dir, 'mean_std_2d.npy'), allow_pickle=True)
        self.mean_std_3d = np.load(
            osp.join(self.model_dir, 'mean_std_3d.npy'), allow_pickle=True)
        self.left_right_symmetry_2d = np.array(
            [0, 4, 5, 6, 1, 2, 3, 7, 8, 9, 10, 14, 15, 16, 11, 12, 13])
        cfg_path = osp.join(self.model_dir, ModelFile.CONFIGURATION)
        self.cfg = Config.from_file(cfg_path)
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.net = HDFormer(self.cfg.model.MODEL)

        self.load_model()
        self.net = self.net.to(self.device)

    def load_model(self, load_to_cpu=False):
        pretrained_dict = torch.load(
            self.model_path,
            map_location=torch.device('cuda')
            if torch.cuda.is_available() else torch.device('cpu'))
        self.net.load_state_dict(pretrained_dict['state_dict'], strict=False)
        self.net.eval()
    def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess the 2D input joints.

        Args:
            input: [NUM_FRAME, NUM_JOINTS, 2], input 2D human body keypoints.

        Returns:
            Dict[str, Any]: canonical 2D points and root-relative joints.
        """
        if 'cuda' == input.device.type:
            input = input.data.cpu().numpy()
        elif 'cpu' == input.device.type:
            input = input.data.numpy()
        pose2d = input
        num_frames, num_joints, in_channels = pose2d.shape
        logger.info(f'2d pose frame number: {num_frames}')

        # [NUM_FRAME, NUM_JOINTS, 2]
        c = np.array(self.cfg.model.INPUT.center)
        f = np.array(self.cfg.model.INPUT.focal_length)
        self.window_size = self.cfg.model.INPUT.window_size
        receptive_field = self.cfg.model.INPUT.n_frames

        # split the 2D pose sequence into fixed-length windows
        inputs_2d = []
        inputs_2d_flip = []
        n = 0
        indices = []
        while n + receptive_field <= num_frames:
            indices.append((n, n + receptive_field))
            n += self.window_size
        self.valid_length = n - self.window_size + receptive_field

        if 0 == len(indices):
            logger.warn(
                f'Fail to construct test sequences, total_frames = {num_frames}, \
                while receptive_field = {receptive_field}')

        self.mean_2d = self.mean_std_2d[0]
        self.std_2d = self.mean_std_2d[1]
        for (start, end) in indices:
            data_2d = pose2d[start:end]
            data_2d = (data_2d - 0.5 - c) / f
            data_2d_flip = data_2d.copy()
            data_2d_flip[:, :, 0] *= -1
            data_2d_flip = data_2d_flip[:, self.left_right_symmetry_2d, :]
            data_2d_flip = (data_2d_flip - self.mean_2d) / self.std_2d

            data_2d = (data_2d - self.mean_2d) / self.std_2d
            data_2d = torch.from_numpy(data_2d.transpose(
                (2, 0, 1))).float()  # [C,T,V]

            data_2d_flip = torch.from_numpy(data_2d_flip.transpose(
                (2, 0, 1))).float()  # [C,T,V]

            inputs_2d.append(data_2d)
            inputs_2d_flip.append(data_2d_flip)

        self.mean_3d = self.mean_std_3d[0]
        self.std_3d = self.mean_std_3d[1]
        mean_3d = torch.from_numpy(self.mean_3d).float().unsqueeze(-1)
        mean_3d = mean_3d.permute(1, 2, 0)  # [3, 1, 17]
        std_3d = torch.from_numpy(self.std_3d).float().unsqueeze(-1)
        std_3d = std_3d.permute(1, 2, 0)

        return {
            'inputs_2d': inputs_2d,
            'inputs_2d_flip': inputs_2d_flip,
            'mean_3d': mean_3d,
            'std_3d': std_3d
        }

    def avg_flip(self, pre, pre_flip):
        left_right_symmetry = [
            0, 4, 5, 6, 1, 2, 3, 7, 8, 9, 10, 14, 15, 16, 11, 12, 13
        ]
        pre_flip[:, 0, :, :] *= -1
        pre_flip = pre_flip[:, :, :, left_right_symmetry]
        pred_avg = (pre + pre_flip) / 2.
        return pred_avg

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """3D human pose estimation.

        Args:
            input (Dict):
                inputs_2d: [1, NUM_FRAME, NUM_JOINTS, 2]

        Returns:
            Dict[str, Any]:
                "camera_pose": Tensor, [1, NUM_FRAME, OUT_NUM_JOINTS, OUT_3D_FEATURE_DIM],
                    3D human pose keypoints in the camera frame.
                "success": whether the 3D pose estimation succeeded.
        """
        inputs_2d = input['inputs_2d']
        inputs_2d_flip = input['inputs_2d_flip']
        mean_3d = input['mean_3d']
        std_3d = input['std_3d']
        preds_3d = None
        vertex_pre = None

        if [] == inputs_2d:
            predict_dict = {'success': False, KeypointsTypes.POSES_CAMERA: []}
            return predict_dict

        with torch.no_grad():
            for i, pose_2d in enumerate(inputs_2d):
                pose_2d = pose_2d.unsqueeze(0).cuda(non_blocking=True) \
                    if torch.cuda.is_available() else pose_2d.unsqueeze(0)
                pose_2d_flip = inputs_2d_flip[i]
                pose_2d_flip = pose_2d_flip.unsqueeze(0).cuda(non_blocking=True) \
                    if torch.cuda.is_available() else pose_2d_flip.unsqueeze(0)
                mean_3d = mean_3d.unsqueeze(0).cuda(non_blocking=True) \
                    if torch.cuda.is_available() else mean_3d.unsqueeze(0)
                std_3d = std_3d.unsqueeze(0).cuda(non_blocking=True) \
                    if torch.cuda.is_available() else std_3d.unsqueeze(0)

                vertex_pre = self.net(pose_2d, mean_3d, std_3d)
                vertex_pre_flip = self.net(pose_2d_flip, mean_3d, std_3d)
                vertex_pre = self.avg_flip(vertex_pre, vertex_pre_flip)

                # concatenate the prediction results window by window
                predict_3d = vertex_pre.permute(
                    0, 2, 3, 1).contiguous()[0][:self.window_size]
                if preds_3d is None:
                    preds_3d = predict_3d
                else:
                    preds_3d = torch.concat((preds_3d, predict_3d), dim=0)
            remain_pose_results = vertex_pre.permute(
                0, 2, 3, 1).contiguous()[0][self.window_size:]
            preds_3d = torch.concat((preds_3d, remain_pose_results), dim=0)

            preds_3d = preds_3d.unsqueeze(0)  # add batch dim
            preds_3d = preds_3d / self.cfg.model.INPUT.res_w  # normalize to [-1, 1]
        predict_dict = {'success': True, KeypointsTypes.POSES_CAMERA: preds_3d}

        return predict_dict
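The preprocessing above slides a window of receptive_field (n_frames) over the sequence, advancing window_size frames each step. The index arithmetic in isolation (the concrete values here are illustrative; the real ones come from the model configuration):

num_frames, receptive_field, window_size = 200, 96, 48  # illustrative values
indices = []
n = 0
while n + receptive_field <= num_frames:
    indices.append((n, n + receptive_field))
    n += window_size
print(indices)  # [(0, 96), (48, 144), (96, 192)]
valid_length = n - window_size + receptive_field
print(valid_length)  # 192: frames of the input actually covered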
modelscope/models/cv/body_3d_keypoints/hdformer/skeleton.py (new file, 103 lines)

@@ -0,0 +1,103 @@
# Copyright (c) 2018-present, Facebook, Inc. All rights reserved.
import numpy as np


class Skeleton:

    def __init__(self, parents, joints_left, joints_right):
        assert len(joints_left) == len(joints_right)

        self._parents = np.array(parents)
        self._joints_left = joints_left
        self._joints_right = joints_right
        self._compute_metadata()

    def num_joints(self):
        return len(self._parents)

    def parents(self):
        return self._parents

    def has_children(self):
        return self._has_children

    def children(self):
        return self._children

    def remove_joints(self, joints_to_remove):
        """
        Remove the joints specified in 'joints_to_remove'.
        """
        valid_joints = []
        for joint in range(len(self._parents)):
            if joint not in joints_to_remove:
                valid_joints.append(joint)

        for i in range(len(self._parents)):
            while self._parents[i] in joints_to_remove:
                self._parents[i] = self._parents[self._parents[i]]

        index_offsets = np.zeros(len(self._parents), dtype=int)
        new_parents = []
        for i, parent in enumerate(self._parents):
            if i not in joints_to_remove:
                new_parents.append(parent - index_offsets[parent])
            else:
                index_offsets[i:] += 1
        self._parents = np.array(new_parents)

        if self._joints_left is not None:
            new_joints_left = []
            for joint in self._joints_left:
                if joint in valid_joints:
                    new_joints_left.append(joint - index_offsets[joint])
            self._joints_left = new_joints_left
        if self._joints_right is not None:
            new_joints_right = []
            for joint in self._joints_right:
                if joint in valid_joints:
                    new_joints_right.append(joint - index_offsets[joint])
            self._joints_right = new_joints_right

        self._compute_metadata()

        return valid_joints

    def joints_left(self):
        return self._joints_left

    def joints_right(self):
        return self._joints_right

    def _compute_metadata(self):
        self._has_children = np.zeros(len(self._parents)).astype(bool)
        for i, parent in enumerate(self._parents):
            if parent != -1:
                self._has_children[parent] = True

        self._children = []
        for i, parent in enumerate(self._parents):
            self._children.append([])
        for i, parent in enumerate(self._parents):
            if parent != -1:
                self._children[parent].append(i)


def get_skeleton():
    skeleton = Skeleton(
        parents=[
            -1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12, 16, 17,
            18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30
        ],
        joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23],
        joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31])
    # Bring the skeleton to 17 joints instead of the original 32
    skeleton.remove_joints(
        [4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31])
    # Rewire shoulders to the correct parents
    skeleton._parents[11] = 8
    skeleton._parents[14] = 8
    # Fix children error
    skeleton._children[7] = [8]
    skeleton._children[8] = [9, 11, 14]
    return skeleton
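get_skeleton() reduces the 32-joint Human3.6M definition to the standard 17-joint layout and patches the parent/children links; a quick check:

from modelscope.models.cv.body_3d_keypoints.hdformer.skeleton import \
    get_skeleton

sk = get_skeleton()
print(sk.num_joints())   # 17
print(sk.parents())      # parent index per joint, -1 marks the root
print(sk.children()[8])  # [9, 11, 14] after the manual rewiring above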
modelscope/pipelines/cv/body_3d_keypoints_pipeline.py

@@ -16,8 +16,8 @@ from matplotlib.animation import writers
 from matplotlib.ticker import MultipleLocator

 from modelscope.metainfo import Pipelines
-from modelscope.models.cv.body_3d_keypoints.body_3d_pose import (
-    BodyKeypointsDetection3D, KeypointsTypes)
+from modelscope.models.cv.body_3d_keypoints.cannonical_pose.body_3d_pose import \
+    KeypointsTypes
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines import pipeline
 from modelscope.pipelines.base import Input, Model, Pipeline, Tensor

@@ -112,11 +112,19 @@ def convert_2_h36m_data(lst_kps, lst_bboxes, joints_nbr=15):
     Tasks.body_3d_keypoints, module_name=Pipelines.body_3d_keypoints)
 class Body3DKeypointsPipeline(Pipeline):

-    def __init__(self, model: Union[str, BodyKeypointsDetection3D], **kwargs):
+    def __init__(self, model: str, **kwargs):
         """Human body 3D pose estimation.

         Args:
-            model (Union[str, BodyKeypointsDetection3D]): model id on modelscope hub.
+            model (str): model id on modelscope hub.
             kwargs (dict, `optional`): Extra kwargs passed into the preprocessor's constructor.
+
+        Example:
+            >>> from modelscope.pipelines import pipeline
+            >>> body_3d_keypoints = pipeline(Tasks.body_3d_keypoints,
+                    model='damo/cv_hdformer_body-3d-keypoints_video')
+            >>> test_video_url = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/Walking.54138969.mp4'
+            >>> output = body_3d_keypoints(test_video_url)
+            >>> print(output)
         """
         super().__init__(model=model, **kwargs)

@@ -130,6 +138,10 @@ class Body3DKeypointsPipeline(Pipeline):
             model=self.human_body_2d_kps_det_pipeline,
             device='gpu' if torch.cuda.is_available() else 'cpu')

+        self.max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME \
+            if hasattr(self.keypoint_model_3d.cfg.model.INPUT, 'MAX_FRAME') \
+            else self.keypoint_model_3d.cfg.model.INPUT.max_frame  # max video frame number to be predicted 3D joints
+
     def preprocess(self, input: Input) -> Dict[str, Any]:
         self.video_url = input
         video_frames = self.read_video_frames(self.video_url)

@@ -139,7 +151,6 @@ class Body3DKeypointsPipeline(Pipeline):
         all_2d_poses = []
         all_boxes_with_socre = []
-        max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME  # max video frame number to be predicted 3D joints
         for i, frame in enumerate(video_frames):
             kps_2d = self.human_body_2d_kps_detector(frame)
             if [] == kps_2d.get('boxes'):

@@ -157,7 +168,7 @@ class Body3DKeypointsPipeline(Pipeline):
             all_boxes_with_socre.append(
                 list(np.array(box).reshape(
                     (-1))) + [score])  # construct to list with shape [5]
-            if (i + 1) >= max_frame:
+            if (i + 1) >= self.max_frame:
                 break

         all_2d_poses_np = np.array(all_2d_poses).reshape(

@@ -166,10 +177,11 @@ class Body3DKeypointsPipeline(Pipeline):
         all_boxes_np = np.array(all_boxes_with_socre).reshape(
             (len(all_boxes_with_socre), 5))  # [x1, y1, x2, y2, score]

+        joint_num = self.keypoint_model_3d.cfg.model.MODEL.IN_NUM_JOINTS \
+            if hasattr(self.keypoint_model_3d.cfg.model.MODEL, 'IN_NUM_JOINTS') \
+            else self.keypoint_model_3d.cfg.model.MODEL.n_joints
         kps_2d_h36m_17 = convert_2_h36m_data(
-            all_2d_poses_np,
-            all_boxes_np,
-            joints_nbr=self.keypoint_model_3d.cfg.model.MODEL.IN_NUM_JOINTS)
+            all_2d_poses_np, all_boxes_np, joints_nbr=joint_num)
         kps_2d_h36m_17 = np.array(kps_2d_h36m_17)
         res = {'success': True, 'input_2d_pts': kps_2d_h36m_17}
         return res

@@ -246,7 +258,6 @@ class Body3DKeypointsPipeline(Pipeline):
             raise Exception('modelscope error: %s cannot get video fps info.' %
                             (video_url))

-        max_frame_num = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME
         frame_idx = 0
         while True:
             ret, frame = cap.read()

@@ -256,7 +267,7 @@ class Body3DKeypointsPipeline(Pipeline):
                 timestamp_format(seconds=frame_idx / self.fps))
             frame_idx += 1
             frames.append(frame)
-            if frame_idx >= max_frame_num:
+            if frame_idx >= self.max_frame:
                 break
         cap.release()
         return frames

@@ -278,7 +289,8 @@ class Body3DKeypointsPipeline(Pipeline):
                    [12, 13], [9, 10]]  # connection between joints

         fig = plt.figure()
-        ax = p3.Axes3D(fig)
+        ax = p3.Axes3D(fig, auto_add_to_figure=False)
+        fig.add_axes(ax)
         x_major_locator = MultipleLocator(0.5)

         ax.xaxis.set_major_locator(x_major_locator)
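The hasattr fallbacks introduced in this diff let old upper-case config keys (MAX_FRAME, IN_NUM_JOINTS) and new lower-case ones (max_frame, n_joints) coexist; the pattern in isolation:

class _Cfg:
    max_frame = 300  # only the new-style key is present

cfg = _Cfg()
max_frame = cfg.MAX_FRAME if hasattr(cfg, 'MAX_FRAME') else cfg.max_frame
print(max_frame)  # 300: falls back to the lower-case key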
tests/pipelines/test_body_3d_keypoints_hdformer.py (new file, 50 lines)

@@ -0,0 +1,50 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import cv2
import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class Body3DKeypointsHDFormerTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_hdformer_body-3d-keypoints_video'
        self.test_video = 'data/test/videos/Walking.54138969.mp4'
        self.task = Tasks.body_3d_keypoints

    def pipeline_inference(self, pipeline: Pipeline, pipeline_input):
        output = pipeline(pipeline_input, output_video='./result.mp4')
        poses = np.array(output[OutputKeys.KEYPOINTS])
        print(f'result 3d points shape {poses.shape}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub_with_video_file(self):
        body_3d_keypoints = pipeline(
            Tasks.body_3d_keypoints, model=self.model_id)
        pipeline_input = self.test_video
        self.pipeline_inference(
            body_3d_keypoints, pipeline_input=pipeline_input)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub_with_video_stream(self):
        body_3d_keypoints = pipeline(Tasks.body_3d_keypoints)
        cap = cv2.VideoCapture(self.test_video)
        if not cap.isOpened():
            raise Exception('modelscope error: %s cannot be decoded by OpenCV.'
                            % (self.test_video))
        self.pipeline_inference(body_3d_keypoints, pipeline_input=cap)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()