From cb1aa66a498be32556c32811f77da1ce9da904e7 Mon Sep 17 00:00:00 2001 From: "shouzhou.bx" Date: Thu, 4 Aug 2022 23:34:48 +0800 Subject: [PATCH] [to #43259593]cv: add human pose estimation to maas-lib add human pose estimation to maas-lib v3 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9491970 --- .../images/keypoints_detect/000000438304.jpg | 3 + .../images/keypoints_detect/000000438862.jpg | 3 + .../images/keypoints_detect/000000439522.jpg | 3 + .../images/keypoints_detect/000000440336.jpg | 3 + .../images/keypoints_detect/000000442836.jpg | 3 + .../images/keypoints_detect/000000447088.jpg | 3 + .../images/keypoints_detect/000000447917.jpg | 3 + .../images/keypoints_detect/000000448263.jpg | 3 + .../body_keypoints_detection.jpg | 3 + modelscope/metainfo.py | 2 + modelscope/models/cv/__init__.py | 8 +- .../models/cv/body_2d_keypoints/__init__.py | 23 + .../body_2d_keypoints/hrnet_basic_modules.py | 397 ++++++++++++++++++ .../models/cv/body_2d_keypoints/hrnet_v2.py | 221 ++++++++++ modelscope/models/cv/body_2d_keypoints/w48.py | 51 +++ modelscope/outputs.py | 21 + modelscope/pipelines/builder.py | 3 + modelscope/pipelines/cv/__init__.py | 2 + .../cv/body_2d_keypoints_pipeline.py | 261 ++++++++++++ modelscope/utils/constant.py | 1 + tests/pipelines/test_body_2d_keypoints.py | 100 +++++ 21 files changed, 1113 insertions(+), 4 deletions(-) create mode 100644 data/test/images/keypoints_detect/000000438304.jpg create mode 100644 data/test/images/keypoints_detect/000000438862.jpg create mode 100644 data/test/images/keypoints_detect/000000439522.jpg create mode 100644 data/test/images/keypoints_detect/000000440336.jpg create mode 100644 data/test/images/keypoints_detect/000000442836.jpg create mode 100644 data/test/images/keypoints_detect/000000447088.jpg create mode 100644 data/test/images/keypoints_detect/000000447917.jpg create mode 100644 data/test/images/keypoints_detect/000000448263.jpg create mode 100644 data/test/images/keypoints_detect/body_keypoints_detection.jpg create mode 100644 modelscope/models/cv/body_2d_keypoints/__init__.py create mode 100644 modelscope/models/cv/body_2d_keypoints/hrnet_basic_modules.py create mode 100644 modelscope/models/cv/body_2d_keypoints/hrnet_v2.py create mode 100644 modelscope/models/cv/body_2d_keypoints/w48.py create mode 100644 modelscope/pipelines/cv/body_2d_keypoints_pipeline.py create mode 100644 tests/pipelines/test_body_2d_keypoints.py diff --git a/data/test/images/keypoints_detect/000000438304.jpg b/data/test/images/keypoints_detect/000000438304.jpg new file mode 100644 index 00000000..5d03c471 --- /dev/null +++ b/data/test/images/keypoints_detect/000000438304.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64ab6a5556b022cbd398d98cd5bb243a4ee6e4ea6e3285f433eb78b76b53fd4e +size 269177 diff --git a/data/test/images/keypoints_detect/000000438862.jpg b/data/test/images/keypoints_detect/000000438862.jpg new file mode 100644 index 00000000..47946a91 --- /dev/null +++ b/data/test/images/keypoints_detect/000000438862.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3689831ed23f734ebab9405f48ffbfbbefb778e9de3101a9d56e421ea45288cf +size 248595 diff --git a/data/test/images/keypoints_detect/000000439522.jpg b/data/test/images/keypoints_detect/000000439522.jpg new file mode 100644 index 00000000..32b59e7a --- /dev/null +++ b/data/test/images/keypoints_detect/000000439522.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:663545f71af556370c7cba7fd8010a665d00c0b477075562a3d7669c6d853ad3 +size 107685 diff --git a/data/test/images/keypoints_detect/000000440336.jpg b/data/test/images/keypoints_detect/000000440336.jpg new file mode 100644 index 00000000..b61d7c8d --- /dev/null +++ b/data/test/images/keypoints_detect/000000440336.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5c2df473a26427ae57950acec86d1e4d3a49cdf1a18d427cd1a354465408f00 +size 102909 diff --git a/data/test/images/keypoints_detect/000000442836.jpg b/data/test/images/keypoints_detect/000000442836.jpg new file mode 100644 index 00000000..9642df68 --- /dev/null +++ b/data/test/images/keypoints_detect/000000442836.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44b225eaff012bd016fcfe8a3dbeace93fd418164f40e4b5f5b9f0d76f39097b +size 308635 diff --git a/data/test/images/keypoints_detect/000000447088.jpg b/data/test/images/keypoints_detect/000000447088.jpg new file mode 100644 index 00000000..8d4f1752 --- /dev/null +++ b/data/test/images/keypoints_detect/000000447088.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:510da487b16303646cf4b500cae0a4168cba2feb3dd706c007a3f5c64400501c +size 148413 diff --git a/data/test/images/keypoints_detect/000000447917.jpg b/data/test/images/keypoints_detect/000000447917.jpg new file mode 100644 index 00000000..542c7b3a --- /dev/null +++ b/data/test/images/keypoints_detect/000000447917.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbaa52b9ecc59b899500db9200ce65b17aa8b87172c8c70de585fa27c80e7ad1 +size 238442 diff --git a/data/test/images/keypoints_detect/000000448263.jpg b/data/test/images/keypoints_detect/000000448263.jpg new file mode 100644 index 00000000..474563e2 --- /dev/null +++ b/data/test/images/keypoints_detect/000000448263.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72fcff7fd4da5ede2d3c1a31449769b0595685f7250597f05cd176c4c80ced03 +size 37753 diff --git a/data/test/images/keypoints_detect/body_keypoints_detection.jpg b/data/test/images/keypoints_detect/body_keypoints_detection.jpg new file mode 100644 index 00000000..71ce7d7e --- /dev/null +++ b/data/test/images/keypoints_detect/body_keypoints_detection.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:379e11d7fc3734d3ec95afd0d86460b4653fbf4bb1f57f993610d6a6fd30fd3d +size 1702339 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 16aa8bb6..91d0a4b6 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -18,6 +18,7 @@ class Models(object): cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' gpen = 'gpen' product_retrieval_embedding = 'product-retrieval-embedding' + body_2d_keypoints = 'body-2d-keypoints' # nlp models bert = 'bert' @@ -77,6 +78,7 @@ class Pipelines(object): action_recognition = 'TAdaConv_action-recognition' animal_recognation = 'resnet101-animal_recog' cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' + body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image' human_detection = 'resnet18-human-detection' object_detection = 'vit-object-detection' image_classification = 'image-classification' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index beeb0994..397c2fba 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from . 
import (action_recognition, animal_recognition, cartoon, - cmdssl_video_embedding, face_detection, face_generation, - image_classification, image_color_enhance, image_colorization, - image_denoise, image_instance_segmentation, +from . import (action_recognition, animal_recognition, body_2d_keypoints, + cartoon, cmdssl_video_embedding, face_detection, + face_generation, image_classification, image_color_enhance, + image_colorization, image_denoise, image_instance_segmentation, image_portrait_enhancement, image_to_image_generation, image_to_image_translation, object_detection, product_retrieval_embedding, super_resolution, virual_tryon) diff --git a/modelscope/models/cv/body_2d_keypoints/__init__.py b/modelscope/models/cv/body_2d_keypoints/__init__.py new file mode 100644 index 00000000..d953b773 --- /dev/null +++ b/modelscope/models/cv/body_2d_keypoints/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .hrnet_v2 import PoseHighResolutionNetV2 + +else: + _import_structure = { + 'keypoints_detector': ['PoseHighResolutionNetV2'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/body_2d_keypoints/hrnet_basic_modules.py b/modelscope/models/cv/body_2d_keypoints/hrnet_basic_modules.py new file mode 100644 index 00000000..3b960688 --- /dev/null +++ b/modelscope/models/cv/body_2d_keypoints/hrnet_basic_modules.py @@ -0,0 +1,397 @@ +# The implementation is based on HRNET, available at https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation. + +import torch +import torch.nn as nn + +BN_MOMENTUM = 0.1 + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d( + planes * self.expansion, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + 
out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + + def __init__(self, + num_branches, + blocks, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches(num_branches, blocks, num_blocks, num_inchannels, + num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches(num_branches, blocks, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(True) + + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + raise ValueError(error_msg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d( + num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + layers = [] + layers.append( + block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), nn.BatchNorm2d(num_inchannels[i]), + nn.Upsample( + scale_factor=2**(j - i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + 
nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3), + nn.ReLU(True))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): + result = nn.Sequential() + result.add_module( + 'conv', + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False)) + result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) + return result + + +def upsample(scale, oup): + return nn.Sequential( + nn.Upsample(scale_factor=scale, mode='bilinear'), + nn.Conv2d( + in_channels=oup, + out_channels=oup, + kernel_size=3, + stride=1, + padding=1, + groups=1, + bias=False), nn.BatchNorm2d(oup), nn.PReLU()) + + +class SE_Block(nn.Module): + + def __init__(self, c, r=16): + super().__init__() + self.squeeze = nn.AdaptiveAvgPool2d(1) + self.excitation = nn.Sequential( + nn.Linear(c, c // r, bias=False), nn.ReLU(inplace=True), + nn.Linear(c // r, c, bias=False), nn.Sigmoid()) + + def forward(self, x): + bs, c, _, _ = x.shape + y = self.squeeze(x).view(bs, c) + y = self.excitation(y).view(bs, c, 1, 1) + return x * y.expand_as(x) + + +class BasicBlockSE(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, r=64): + super(BasicBlockSE, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + self.se = SE_Block(planes, r) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class BottleneckSE(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, r=64): + super(BottleneckSE, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d( + planes * self.expansion, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + self.se = SE_Block(planes * self.expansion, r) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + 
out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck, + 'BASICSE': BasicBlockSE, + 'BOTTLENECKSE': BottleneckSE, +} diff --git a/modelscope/models/cv/body_2d_keypoints/hrnet_v2.py b/modelscope/models/cv/body_2d_keypoints/hrnet_v2.py new file mode 100644 index 00000000..1570c8cc --- /dev/null +++ b/modelscope/models/cv/body_2d_keypoints/hrnet_v2.py @@ -0,0 +1,221 @@ +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.body_2d_keypoints.hrnet_basic_modules import ( + BN_MOMENTUM, BasicBlock, Bottleneck, HighResolutionModule, blocks_dict) +from modelscope.models.cv.body_2d_keypoints.w48 import cfg_128x128_15 +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + Tasks.body_2d_keypoints, module_name=Models.body_2d_keypoints) +class PoseHighResolutionNetV2(TorchModel): + + def __init__(self, cfg=None, **kwargs): + if cfg is None: + cfg = cfg_128x128_15 + self.inplanes = 64 + extra = cfg['MODEL']['EXTRA'] + super(PoseHighResolutionNetV2, self).__init__(**kwargs) + + # stem net + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + 64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(Bottleneck, 64, 4) + + self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion + for i in range(len(num_channels)) + ] + self.transition1 = self._make_transition_layer([256], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion + for i in range(len(num_channels)) + ] + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion + for i in range(len(num_channels)) + ] + self.transition3 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + """final four layers""" + last_inp_channels = np.int(np.sum(pre_stage_channels)) + self.final_layer = nn.Sequential( + nn.Conv2d( + in_channels=last_inp_channels, + out_channels=last_inp_channels, + kernel_size=1, + stride=1, + padding=0), + nn.BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=False), + nn.Conv2d( + in_channels=last_inp_channels, + 
out_channels=cfg['MODEL']['NUM_JOINTS'], + kernel_size=extra['FINAL_CONV_KERNEL'], + stride=1, + padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0)) + + self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS'] + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + nn.Conv2d( + num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + nn.BatchNorm2d(num_channels_cur_layer[i]), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[ + i] if j == i - num_branches_pre else inchannels + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, + layer_config, + num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule(num_branches, block, num_blocks, + num_inchannels, num_channels, fuse_method, + reset_multi_scale_output)) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + + y_list = self.stage4(x_list) + + y0_h, y0_w = y_list[0].size(2), y_list[0].size(3) + y1 = F.upsample(y_list[1], size=(y0_h, y0_w), 
mode='bilinear') + y2 = F.upsample(y_list[2], size=(y0_h, y0_w), mode='bilinear') + y3 = F.upsample(y_list[3], size=(y0_h, y0_w), mode='bilinear') + + y = torch.cat([y_list[0], y1, y2, y3], 1) + output = self.final_layer(y) + + return output diff --git a/modelscope/models/cv/body_2d_keypoints/w48.py b/modelscope/models/cv/body_2d_keypoints/w48.py new file mode 100644 index 00000000..7140f8fe --- /dev/null +++ b/modelscope/models/cv/body_2d_keypoints/w48.py @@ -0,0 +1,51 @@ +cfg_128x128_15 = { + 'DATASET': { + 'TYPE': 'DAMO', + 'PARENT_IDS': [0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1], + 'LEFT_IDS': [2, 3, 4, 8, 9, 10], + 'RIGHT_IDS': [5, 6, 7, 11, 12, 13], + 'SPINE_IDS': [0, 1, 14] + }, + 'MODEL': { + 'INIT_WEIGHTS': True, + 'NAME': 'pose_hrnet', + 'NUM_JOINTS': 15, + 'PRETRAINED': '', + 'TARGET_TYPE': 'gaussian', + 'IMAGE_SIZE': [128, 128], + 'HEATMAP_SIZE': [32, 32], + 'SIGMA': 2.0, + 'EXTRA': { + 'PRETRAINED_LAYERS': [ + 'conv1', 'bn1', 'conv2', 'bn2', 'layer1', 'transition1', + 'stage2', 'transition2', 'stage3', 'transition3', 'stage4' + ], + 'FINAL_CONV_KERNEL': + 1, + 'STAGE2': { + 'NUM_MODULES': 1, + 'NUM_BRANCHES': 2, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': [4, 4], + 'NUM_CHANNELS': [48, 96], + 'FUSE_METHOD': 'SUM' + }, + 'STAGE3': { + 'NUM_MODULES': 4, + 'NUM_BRANCHES': 3, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': [4, 4, 4], + 'NUM_CHANNELS': [48, 96, 192], + 'FUSE_METHOD': 'SUM' + }, + 'STAGE4': { + 'NUM_MODULES': 3, + 'NUM_BRANCHES': 4, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': [4, 4, 4, 4], + 'NUM_CHANNELS': [48, 96, 192, 384], + 'FUSE_METHOD': 'SUM' + }, + } + } +} diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 8c88262d..a288a4c3 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -158,6 +158,27 @@ TASK_OUTPUTS = { # } Tasks.action_recognition: [OutputKeys.LABELS], + # human body keypoints detection result for single sample + # { + # "poses": [ + # [x, y], + # [x, y], + # [x, y] + # ] + # "scores": [ + # [score], + # [score], + # [score], + # ] + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ] + # } + Tasks.body_2d_keypoints: + [OutputKeys.POSES, OutputKeys.SCORES, OutputKeys.BOXES], + # live category recognition result for single video # { # "scores": [0.885272, 0.014790631, 0.014558001], diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 28dd190a..ea18d7b7 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -87,6 +87,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.text_to_image_synthesis: (Pipelines.text_to_image_synthesis, 'damo/cv_diffusion_text-to-image-synthesis_tiny'), + Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints, + 'damo/cv_hrnetv2w32_body-2d-keypoints_image'), Tasks.face_detection: (Pipelines.face_detection, 'damo/cv_resnet_facedetection_scrfd10gkps'), Tasks.face_recognition: (Pipelines.face_recognition, @@ -238,6 +240,7 @@ def pipeline(task: str = None, cfg = ConfigDict(type=pipeline_name, model=model) cfg.device = device + if kwargs: cfg.update(kwargs) diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index d8b09c63..d7a8da2c 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .action_recognition_pipeline import ActionRecognitionPipeline from .animal_recognition_pipeline import AnimalRecognitionPipeline + from .body_2d_keypoints_pipeline import 
Body2DKeypointsPipeline from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline from .image_detection_pipeline import ImageDetectionPipeline from .face_detection_pipeline import FaceDetectionPipeline @@ -34,6 +35,7 @@ else: _import_structure = { 'action_recognition_pipeline': ['ActionRecognitionPipeline'], 'animal_recognition_pipeline': ['AnimalRecognitionPipeline'], + 'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'], 'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'], 'image_detection_pipeline': ['ImageDetectionPipeline'], 'face_detection_pipeline': ['FaceDetectionPipeline'], diff --git a/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py new file mode 100644 index 00000000..f16c48e4 --- /dev/null +++ b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py @@ -0,0 +1,261 @@ +import os.path as osp +from typing import Any, Dict, List, Union + +import cv2 +import json +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.body_2d_keypoints.hrnet_v2 import \ + PoseHighResolutionNetV2 +from modelscope.models.cv.body_2d_keypoints.w48 import cfg_128x128_15 +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Input, Model, Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.body_2d_keypoints, module_name=Pipelines.body_2d_keypoints) +class Body2DKeypointsPipeline(Pipeline): + + def __init__(self, model: str, human_detector: Pipeline, **kwargs): + super().__init__(model=model, **kwargs) + self.keypoint_model = KeypointsDetection(model) + self.keypoint_model.eval() + self.human_detector = human_detector + + def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: + output = self.human_detector(input) + + if isinstance(input, str): + image = cv2.imread(input, -1)[:, :, 0:3] + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + image = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) + image = image[:, :, 0:3] + + return {'image': image, 'output': output} + + def forward(self, input: Tensor) -> Dict[Tensor, Dict[str, np.ndarray]]: + input_image = input['image'] + output = input['output'] + + bboxes = [] + scores = np.array(output[OutputKeys.SCORES].cpu(), dtype=np.float32) + boxes = np.array(output[OutputKeys.BOXES].cpu(), dtype=np.float32) + + for id, box in enumerate(boxes): + box_tmp = [ + box[0], box[1], box[2] - box[0], box[3] - box[1], scores[id], 0 + ] + bboxes.append(box_tmp) + if len(bboxes) == 0: + logger.error('cannot detect human in the image') + return [None, None] + human_images, metas = self.keypoint_model.preprocess( + [bboxes, input_image]) + outputs = self.keypoint_model.forward(human_images) + return [outputs, metas] + + def postprocess(self, input: Dict[Tensor, Dict[str, np.ndarray]], + **kwargs) -> str: + if input[0] is None or input[1] is None: + return { + OutputKeys.BOXES: [], + OutputKeys.POSES: [], + OutputKeys.SCORES: [] + } + + poses, scores, boxes = self.keypoint_model.postprocess(input) + return { + OutputKeys.BOXES: boxes, + OutputKeys.POSES: poses, + OutputKeys.SCORES: scores + } + + +class KeypointsDetection(): + + def __init__(self, 
model: str, **kwargs): + self.model = model + cfg = cfg_128x128_15 + self.key_points_model = PoseHighResolutionNetV2(cfg) + pretrained_state_dict = torch.load( + osp.join(self.model, ModelFile.TORCH_MODEL_FILE)) + self.key_points_model.load_state_dict( + pretrained_state_dict, strict=False) + + self.input_size = cfg['MODEL']['IMAGE_SIZE'] + self.lst_parent_ids = cfg['DATASET']['PARENT_IDS'] + self.lst_left_ids = cfg['DATASET']['LEFT_IDS'] + self.lst_right_ids = cfg['DATASET']['RIGHT_IDS'] + self.box_enlarge_ratio = 0.05 + + def train(self): + return self.key_points_model.train() + + def eval(self): + return self.key_points_model.eval() + + def forward(self, input: Tensor) -> Tensor: + with torch.no_grad(): + return self.key_points_model.forward(input) + + def get_pts(self, heatmaps): + [pts_num, height, width] = heatmaps.shape + pts = [] + scores = [] + for i in range(pts_num): + heatmap = heatmaps[i, :, :] + pt = np.where(heatmap == np.max(heatmap)) + scores.append(np.max(heatmap)) + x = pt[1][0] + y = pt[0][0] + + [h, w] = heatmap.shape + if x >= 1 and x <= w - 2 and y >= 1 and y <= h - 2: + x_diff = heatmap[y, x + 1] - heatmap[y, x - 1] + y_diff = heatmap[y + 1, x] - heatmap[y - 1, x] + x_sign = 0 + y_sign = 0 + if x_diff < 0: + x_sign = -1 + if x_diff > 0: + x_sign = 1 + if y_diff < 0: + y_sign = -1 + if y_diff > 0: + y_sign = 1 + x = x + x_sign * 0.25 + y = y + y_sign * 0.25 + + pts.append([x, y]) + return pts, scores + + def pts_transform(self, meta, pts, lt_x, lt_y): + pts_new = [] + s = meta['s'] + o = meta['o'] + size = len(pts) + for i in range(size): + ratio = 4 + x = (int(pts[i][0] * ratio) - o[0]) / s[0] + y = (int(pts[i][1] * ratio) - o[1]) / s[1] + + pt = [x, y] + pts_new.append(pt) + + return pts_new + + def postprocess(self, inputs: Dict[Tensor, Dict[str, np.ndarray]], + **kwargs): + output_poses = [] + output_scores = [] + output_boxes = [] + for i in range(inputs[0].shape[0]): + outputs, scores = self.get_pts( + (inputs[0][i]).detach().cpu().numpy()) + outputs = self.pts_transform(inputs[1][i], outputs, 0, 0) + box = np.array(inputs[1][i]['human_box'][0:4]).reshape(2, 2) + outputs = np.array(outputs) + box[0] + output_poses.append(outputs.tolist()) + output_scores.append(scores) + output_boxes.append(box.tolist()) + return output_poses, output_scores, output_boxes + + def image_crop_resize(self, input, margin=[0, 0]): + pad_img = np.zeros((self.input_size[1], self.input_size[0], 3), + dtype=np.uint8) + + h, w, ch = input.shape + + h_new = self.input_size[1] - margin[1] * 2 + w_new = self.input_size[0] - margin[0] * 2 + s0 = float(h_new) / h + s1 = float(w_new) / w + s = min(s0, s1) + w_new = int(s * w) + h_new = int(s * h) + + img_new = cv2.resize(input, (w_new, h_new), cv2.INTER_LINEAR) + + cx = self.input_size[0] // 2 + cy = self.input_size[1] // 2 + + pad_img[cy - h_new // 2:cy - h_new // 2 + h_new, + cx - w_new // 2:cx - w_new // 2 + w_new, :] = img_new + + return pad_img, np.array([cx, cy]), np.array([s, s]), np.array( + [cx - w_new // 2, cy - h_new // 2]) + + def image_transform(self, input: Input) -> Dict[Tensor, Any]: + if isinstance(input, str): + image = cv2.imread(input, -1)[:, :, 0:3] + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + image = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) + else: + image = input + image = image[:, :, 0:3] + elif isinstance(input, torch.Tensor): + image = input.cpu().numpy()[:, :, 0:3] + + w, h, _ = image.shape + w_new = self.input_size[0] + h_new = self.input_size[1] + + image = cv2.cvtColor(image, 
cv2.COLOR_BGR2RGB) + img_resize, c, s, o = self.image_crop_resize(image) + + img_resize = np.float32(img_resize) / 255. + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + img_resize = (img_resize - mean) / std + + input_data = np.zeros([1, 3, h_new, w_new], dtype=np.float32) + + img_resize = img_resize.transpose((2, 0, 1)) + input_data[0, :] = img_resize + meta = {'c': c, 's': s, 'o': o} + return [torch.from_numpy(input_data), meta] + + def crop_image(self, image, box): + height, width, _ = image.shape + w, h = box[1] - box[0] + box[0, :] -= (w * self.box_enlarge_ratio, h * self.box_enlarge_ratio) + box[1, :] += (w * self.box_enlarge_ratio, h * self.box_enlarge_ratio) + + box[0, 0] = min(max(box[0, 0], 0.0), width) + box[0, 1] = min(max(box[0, 1], 0.0), height) + box[1, 0] = min(max(box[1, 0], 0.0), width) + box[1, 1] = min(max(box[1, 1], 0.0), height) + + cropped_image = image[int(box[0][1]):int(box[1][1]), + int(box[0][0]):int(box[1][0])] + return cropped_image + + def preprocess(self, input: Dict[Tensor, Tensor]) -> Dict[Tensor, Any]: + bboxes = input[0] + image = input[1] + + lst_human_images = [] + lst_meta = [] + for i in range(len(bboxes)): + box = np.array(bboxes[i][0:4]).reshape(2, 2) + box[1] += box[0] + human_image = self.crop_image(image.clone(), box) + human_image, meta = self.image_transform(human_image) + lst_human_images.append(human_image) + meta['human_box'] = box + lst_meta.append(meta) + + return [torch.cat(lst_human_images, dim=0), lst_meta] diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index f11546b1..2e49dfc5 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -22,6 +22,7 @@ class CVTasks(object): human_detection = 'human-detection' human_object_interaction = 'human-object-interaction' face_image_generation = 'face-image-generation' + body_2d_keypoints = 'body-2d-keypoints' image_classification = 'image-classification' image_multilabel_classification = 'image-multilabel-classification' diff --git a/tests/pipelines/test_body_2d_keypoints.py b/tests/pipelines/test_body_2d_keypoints.py new file mode 100644 index 00000000..9b5bcdee --- /dev/null +++ b/tests/pipelines/test_body_2d_keypoints.py @@ -0,0 +1,100 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +import os.path as osp +import pdb +import unittest + +import cv2 +import numpy as np +import torch + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + +lst_parent_ids_17 = [0, 0, 0, 1, 2, 0, 0, 5, 6, 7, 8, 5, 6, 11, 12, 13, 14] +lst_left_ids_17 = [1, 3, 5, 7, 9, 11, 13, 15] +lst_right_ids_17 = [2, 4, 6, 8, 10, 12, 14, 16] +lst_spine_ids_17 = [0] + +lst_parent_ids_15 = [0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1] +lst_left_ids_15 = [2, 3, 4, 8, 9, 10] +lst_right_ids_15 = [5, 6, 7, 11, 12, 13] +lst_spine_ids_15 = [0, 1, 14] + + +def draw_joints(image, np_kps, score, threshold=0.2): + if np_kps.shape[0] == 17: + lst_parent_ids = lst_parent_ids_17 + lst_left_ids = lst_left_ids_17 + lst_right_ids = lst_right_ids_17 + + elif np_kps.shape[0] == 15: + lst_parent_ids = lst_parent_ids_15 + lst_left_ids = lst_left_ids_15 + lst_right_ids = lst_right_ids_15 + + for i in range(len(lst_parent_ids)): + pid = lst_parent_ids[i] + if i == pid: + continue + + if (score[i] < threshold or score[1] < threshold): + continue + + if i in lst_left_ids and pid in lst_left_ids: + color = (0, 255, 0) + elif i in lst_right_ids and pid in lst_right_ids: + color = (255, 0, 0) + else: + color = (0, 255, 255) + + cv2.line(image, (int(np_kps[i, 0]), int(np_kps[i, 1])), + (int(np_kps[pid][0]), int(np_kps[pid, 1])), color, 3) + + for i in range(np_kps.shape[0]): + if score[i] < threshold: + continue + cv2.circle(image, (int(np_kps[i, 0]), int(np_kps[i, 1])), 5, + (0, 0, 255), -1) + + +def draw_box(image, box): + cv2.rectangle(image, (int(box[0][0]), int(box[0][1])), + (int(box[1][0]), int(box[1][1])), (0, 0, 255), 2) + + +class Body2DKeypointsTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_hrnetv2w32_body-2d-keypoints_image' + self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg' + self.human_detect_model_id = 'damo/cv_resnet18_human-detection' + + def pipeline_inference(self, pipeline: Pipeline): + output = pipeline(self.test_image) + poses = np.array(output[OutputKeys.POSES]) + scores = np.array(output[OutputKeys.SCORES]) + boxes = np.array(output[OutputKeys.BOXES]) + assert len(poses) == len(scores) and len(poses) == len(boxes) + image = cv2.imread(self.test_image, -1) + for i in range(len(poses)): + draw_box(image, np.array(boxes[i])) + draw_joints(image, np.array(poses[i]), np.array(scores[i])) + cv2.imwrite('pose_keypoint.jpg', image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + human_detector = pipeline( + Tasks.human_detection, model=self.human_detect_model_id) + body_2d_keypoints = pipeline( + Tasks.body_2d_keypoints, + human_detector=human_detector, + model=self.model_id) + self.pipeline_inference(body_2d_keypoints) + + +if __name__ == '__main__': + unittest.main()
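
A minimal usage sketch of the pipeline added by this patch, mirroring the added test in tests/pipelines/test_body_2d_keypoints.py (model ids and task names are the ones registered above): the human detection pipeline is built first and handed to the keypoint pipeline through the human_detector argument, since Body2DKeypointsPipeline.preprocess() runs it to obtain person boxes before estimating keypoints.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the person detector first; the keypoint pipeline crops each detected
# person box and feeds the crop to PoseHighResolutionNetV2.
human_detector = pipeline(
    Tasks.human_detection, model='damo/cv_resnet18_human-detection')
body_2d_keypoints = pipeline(
    Tasks.body_2d_keypoints,
    human_detector=human_detector,
    model='damo/cv_hrnetv2w32_body-2d-keypoints_image')

result = body_2d_keypoints(
    'data/test/images/keypoints_detect/000000438862.jpg')
poses = result[OutputKeys.POSES]    # per person: 15 [x, y] joints (NUM_JOINTS=15 in w48.py)
scores = result[OutputKeys.SCORES]  # per person: 15 per-joint confidences
boxes = result[OutputKeys.BOXES]    # per person: [[x1, y1], [x2, y2]] as returned by postprocess()

Note that KeypointsDetection.postprocess() returns each box in the 2x2 [[x1, y1], [x2, y2]] form (and the test's draw_box consumes that form), while the TASK_OUTPUTS comment added to outputs.py documents a flat [x1, y1, x2, y2] layout; aligning the two would avoid confusion for downstream users.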