From 543a2b32d79dc0c87fbfe67e0fb87b2a6fcf88c8 Mon Sep 17 00:00:00 2001 From: Ranqing Date: Wed, 31 Jul 2024 10:19:41 +0800 Subject: [PATCH] upload high quality human normal estimation model (#903) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * upload human normal estimation model * upload human normal estimation unittest * update human normal estimation * update test data --------- Co-authored-by: 葭润 Co-authored-by: ranqing Co-authored-by: suluyana --- data/test | 2 +- modelscope/metainfo.py | 11 +- .../cv/human_normal_estimation/__init__.py | 22 ++ .../cv/human_normal_estimation/human_nnet.py | 80 +++++++ .../networks/__init__.py | 0 .../networks/config.py | 40 ++++ .../human_normal_estimation/networks/nnet.py | 125 ++++++++++ .../networks/submodules.py | 214 ++++++++++++++++++ modelscope/pipelines/cv/__init__.py | 2 + .../cv/human_normal_estimation_pipeline.py | 95 ++++++++ modelscope/utils/constant.py | 2 + modelscope/utils/pipeline_schema.json | 7 + .../pipelines/test_human_normal_estimation.py | 37 +++ 13 files changed, 633 insertions(+), 4 deletions(-) create mode 100644 modelscope/models/cv/human_normal_estimation/__init__.py create mode 100644 modelscope/models/cv/human_normal_estimation/human_nnet.py create mode 100644 modelscope/models/cv/human_normal_estimation/networks/__init__.py create mode 100644 modelscope/models/cv/human_normal_estimation/networks/config.py create mode 100644 modelscope/models/cv/human_normal_estimation/networks/nnet.py create mode 100644 modelscope/models/cv/human_normal_estimation/networks/submodules.py create mode 100644 modelscope/pipelines/cv/human_normal_estimation_pipeline.py create mode 100644 tests/pipelines/test_human_normal_estimation.py diff --git a/data/test b/data/test index 7a7f6b8d..dedb3ce4 160000 --- a/data/test +++ b/data/test @@ -1 +1 @@ -Subproject commit 7a7f6b8d05ba8af4ea42096391fa727d358e585e +Subproject commit dedb3ce44796328b58a2aa47d3434037a9d63c7f diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 16bf679a..28aea889 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -58,6 +58,7 @@ class Models(object): s2net_depth_estimation = 's2net-depth-estimation' dro_resnet18_depth_estimation = 'dro-resnet18-depth-estimation' raft_dense_optical_flow_estimation = 'raft-dense-optical-flow-estimation' + human_normal_estimation = 'human-normal-estimation' resnet50_bert = 'resnet50-bert' referring_video_object_segmentation = 'swinT-referring-video-object-segmentation' fer = 'fer' @@ -480,6 +481,7 @@ class Pipelines(object): anydoor = 'anydoor' image_to_3d = 'image-to-3d' self_supervised_depth_completion = 'self-supervised-depth-completion' + human_normal_estimation = 'human-normal-estimation' # nlp tasks automatic_post_editing = 'automatic-post-editing' @@ -814,6 +816,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.image_normal_estimation: (Pipelines.image_normal_estimation, 'Damo_XR_Lab/cv_omnidata_image-normal-estimation_normal'), + Tasks.human_normal_estimation: + (Pipelines.human_normal_estimation, + 'Damo_XR_Lab/cv_human_monocular-normal-estimation'), Tasks.indoor_layout_estimation: (Pipelines.indoor_layout_estimation, 'damo/cv_panovit_indoor-layout-estimation'), @@ -846,9 +851,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.image_to_image_generation: (Pipelines.image_to_image_generation, 'damo/cv_latent_diffusion_image2image_generate'), - Tasks.image_classification: - (Pipelines.daily_image_classification, - 'damo/cv_vit-base_image-classification_Dailylife-labels'), + 
Tasks.image_classification: ( + Pipelines.daily_image_classification, + 'damo/cv_vit-base_image-classification_Dailylife-labels'), Tasks.image_object_detection: ( Pipelines.image_object_detection_auto, 'damo/cv_yolox_image-object-detection-auto'), diff --git a/modelscope/models/cv/human_normal_estimation/__init__.py b/modelscope/models/cv/human_normal_estimation/__init__.py new file mode 100644 index 00000000..f176c6bf --- /dev/null +++ b/modelscope/models/cv/human_normal_estimation/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .human_nnet import HumanNormalEstimation + +else: + _import_structure = { + 'human_nnet': ['HumanNormalEstimation'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/human_normal_estimation/human_nnet.py b/modelscope/models/cv/human_normal_estimation/human_nnet.py new file mode 100644 index 00000000..6621c8d3 --- /dev/null +++ b/modelscope/models/cv/human_normal_estimation/human_nnet.py @@ -0,0 +1,80 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os + +import numpy as np +import torch +import torchvision.transforms as T + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.human_normal_estimation.networks import config, nnet +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + Tasks.human_normal_estimation, module_name=Models.human_normal_estimation) +class HumanNormalEstimation(TorchModel): + + def __init__(self, model_dir: str, **kwargs): + super().__init__(model_dir, **kwargs) + config_file = os.path.join(model_dir, 'config.txt') + args = config.get_args(txt_file=config_file) + args.encoder_path = os.path.join(model_dir, args.encoder_path) + + self.device = torch.device( + 'cuda:0') if torch.cuda.is_available() else torch.device('cpu') + self.nnet = nnet.NormalNet(args=args).to(self.device) + self.nnet_path = os.path.join(model_dir, 'ckpt/best_nnet.pt') + if os.path.exists(self.nnet_path): + ckpt = torch.load( + self.nnet_path, map_location=self.device)['model'] + load_dict = {} + for k, v in ckpt.items(): + if k.startswith('module.'): + k_ = k.replace('module.', '') + load_dict[k_] = v + else: + load_dict[k] = v + self.nnet.load_state_dict(load_dict) + self.nnet.eval() + + self.normalize = T.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + def forward(self, inputs): + img = inputs['img'].astype(np.float32) / 255.0 + msk = inputs['msk'].astype(np.float32) / 255.0 + bbox = inputs['bbox'] + + img_h, img_w = img.shape[0:2] + img = torch.from_numpy(img).permute(2, 0, + 1).unsqueeze(0).to(self.device) + img = self.normalize(img) + + fx = fy = (max(img_h, img_h) / 2.0) / np.tan(np.deg2rad(60.0 / 2.0)) + cx = (img_h / 2.0) - 0.5 + cy = (img_w / 2.0) - 0.5 + + intrins = torch.tensor( + [[fx, 0, cx + 0.5], [0, fy, cy + 0.5], [0, 0, 1]], + dtype=torch.float32, + device=self.device).unsqueeze(0) + + pred_norm = self.nnet(img, intrins=intrins)[-1] + pred_norm = pred_norm.detach().cpu().permute(0, 2, 3, 1).numpy() + pred_norm = pred_norm[0, ...] 
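+        # zero out background pixels with the person mask, then crop the
+        # prediction back to the bbox computed by the pipeline's preprocess step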
+ pred_norm = pred_norm * msk[..., None] + pred_norm = pred_norm[bbox[1]:bbox[3], bbox[0]:bbox[2]] + results = pred_norm + return results + + def postprocess(self, inputs): + normal_result = inputs + results = {OutputKeys.NORMALS: normal_result} + return results + + def inference(self, data): + results = self.forward(data) + return results diff --git a/modelscope/models/cv/human_normal_estimation/networks/__init__.py b/modelscope/models/cv/human_normal_estimation/networks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/human_normal_estimation/networks/config.py b/modelscope/models/cv/human_normal_estimation/networks/config.py new file mode 100644 index 00000000..1a488309 --- /dev/null +++ b/modelscope/models/cv/human_normal_estimation/networks/config.py @@ -0,0 +1,40 @@ +import argparse + + +def convert_arg_line_to_args(arg_line): + for arg in arg_line.split(): + if not arg.strip(): + continue + yield str(arg) + + +def get_args(txt_file=None): + parser = argparse.ArgumentParser( + fromfile_prefix_chars='@', conflict_handler='resolve') + parser.convert_arg_line_to_args = convert_arg_line_to_args + + # checkpoint (only needed when testing the model) + parser.add_argument('--ckpt_path', type=str, default=None) + parser.add_argument('--encoder_path', type=str, default=None) + + # ↓↓↓↓ + # NOTE: project-specific args + parser.add_argument('--output_dim', type=int, default=3, help='{3, 4}') + parser.add_argument('--output_type', type=str, default='R', help='{R, G}') + parser.add_argument('--feature_dim', type=int, default=64) + parser.add_argument('--hidden_dim', type=int, default=64) + + parser.add_argument('--encoder_B', type=int, default=5) + + parser.add_argument('--decoder_NF', type=int, default=2048) + parser.add_argument('--decoder_BN', default=False, action='store_true') + parser.add_argument('--decoder_down', type=int, default=2) + parser.add_argument( + '--learned_upsampling', default=False, action='store_true') + + # read arguments from txt file + if txt_file: + config_filename = '@' + txt_file + + args = parser.parse_args([config_filename]) + return args diff --git a/modelscope/models/cv/human_normal_estimation/networks/nnet.py b/modelscope/models/cv/human_normal_estimation/networks/nnet.py new file mode 100644 index 00000000..e10e97c9 --- /dev/null +++ b/modelscope/models/cv/human_normal_estimation/networks/nnet.py @@ -0,0 +1,125 @@ +import os +import sys + +import torch +import torch.nn as nn + +from .submodules import (Encoder, UpSampleBN, UpSampleGN, get_pixel_coords, + get_prediction_head, normal_activation, + upsample_via_bilinear, upsample_via_mask) + +PROJECT_DIR = os.path.split(os.path.dirname(os.path.realpath(__file__)))[0] +sys.path.append(PROJECT_DIR) + + +class NormalNet(nn.Module): + + def __init__(self, args): + super(NormalNet, self).__init__() + B = args.encoder_B + NF = args.decoder_NF + BN = args.decoder_BN + learned_upsampling = args.learned_upsampling + + self.encoder = Encoder(B=B, pretrained=False, ckpt=args.encoder_path) + self.decoder = Decoder( + num_classes=args.output_dim, + B=B, + NF=NF, + BN=BN, + learned_upsampling=learned_upsampling) + + def forward(self, x, **kwargs): + return self.decoder(self.encoder(x), **kwargs) + + +class Decoder(nn.Module): + + def __init__(self, + num_classes=3, + B=5, + NF=2048, + BN=False, + learned_upsampling=True): + super(Decoder, self).__init__() + input_channels = [2048, 176, 64, 40, 24] + + UpSample = UpSampleBN if BN else UpSampleGN + features = NF + + self.conv2 = nn.Conv2d( 
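+            # bottleneck encoder features (2048 ch for B5) plus the 2-channel camera-ray embedding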
+ input_channels[0] + 2, + features, + kernel_size=1, + stride=1, + padding=0) + self.up1 = UpSample( + skip_input=features // 1 + input_channels[1] + 2, + output_features=features // 2, + align_corners=False) + self.up2 = UpSample( + skip_input=features // 2 + input_channels[2] + 2, + output_features=features // 4, + align_corners=False) + self.up3 = UpSample( + skip_input=features // 4 + input_channels[3] + 2, + output_features=features // 8, + align_corners=False) + self.up4 = UpSample( + skip_input=features // 8 + input_channels[4] + 2, + output_features=features // 16, + align_corners=False) + i_dim = features // 16 + + self.downsample_ratio = 2 + self.output_dim = num_classes + + self.pred_head = get_prediction_head(i_dim + 2, 128, num_classes) + if learned_upsampling: + self.mask_head = get_prediction_head( + i_dim + 2, 128, + 9 * self.downsample_ratio * self.downsample_ratio) + self.upsample_fn = upsample_via_mask + else: + self.mask_head = lambda a: None + self.upsample_fn = upsample_via_bilinear + + self.pixel_coords = get_pixel_coords(h=1024, w=1024).to(0) + + def ray_embedding(self, x, intrins, orig_H, orig_W): + B, _, H, W = x.shape + fu = intrins[:, 0, 0].unsqueeze(-1).unsqueeze(-1) * (W / orig_W) + cu = intrins[:, 0, 2].unsqueeze(-1).unsqueeze(-1) * (W / orig_W) + fv = intrins[:, 1, 1].unsqueeze(-1).unsqueeze(-1) * (H / orig_H) + cv = intrins[:, 1, 2].unsqueeze(-1).unsqueeze(-1) * (H / orig_H) + + uv = self.pixel_coords[:, :2, :H, :W].repeat(B, 1, 1, 1) + uv[:, 0, :, :] = (uv[:, 0, :, :] - cu) / fu + uv[:, 1, :, :] = (uv[:, 1, :, :] - cv) / fv + return torch.cat([x, uv], dim=1) + + def forward(self, features, intrins): + x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], \ + features[8], features[11] + _, _, orig_H, orig_W = features[0].shape + + x_d0 = self.conv2( + self.ray_embedding(x_block4, intrins, orig_H, orig_W)) + x_d1 = self.up1(x_d0, + self.ray_embedding(x_block3, intrins, orig_H, orig_W)) + x_d2 = self.up2(x_d1, + self.ray_embedding(x_block2, intrins, orig_H, orig_W)) + x_d3 = self.up3(x_d2, + self.ray_embedding(x_block1, intrins, orig_H, orig_W)) + x_feat = self.up4( + x_d3, self.ray_embedding(x_block0, intrins, orig_H, orig_W)) + + out = self.pred_head( + self.ray_embedding(x_feat, intrins, orig_H, orig_W)) + out = normal_activation(out, elu_kappa=True) + mask = self.mask_head( + self.ray_embedding(x_feat, intrins, orig_H, orig_W)) + up_out = self.upsample_fn( + out, up_mask=mask, downsample_ratio=self.downsample_ratio) + up_out = normal_activation(up_out, elu_kappa=False) + return [up_out] diff --git a/modelscope/models/cv/human_normal_estimation/networks/submodules.py b/modelscope/models/cv/human_normal_estimation/networks/submodules.py new file mode 100644 index 00000000..32fbd011 --- /dev/null +++ b/modelscope/models/cv/human_normal_estimation/networks/submodules.py @@ -0,0 +1,214 @@ +import geffnet +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +INPUT_CHANNELS_DICT = { + 0: [1280, 112, 40, 24, 16], + 1: [1280, 112, 40, 24, 16], + 2: [1408, 120, 48, 24, 16], + 3: [1536, 136, 48, 32, 24], + 4: [1792, 160, 56, 32, 24], + 5: [2048, 176, 64, 40, 24], + 6: [2304, 200, 72, 40, 32], + 7: [2560, 224, 80, 48, 32] +} + + +class Encoder(nn.Module): + + def __init__(self, B=5, pretrained=True, ckpt=None): + super(Encoder, self).__init__() + if ckpt: + basemodel = geffnet.create_model( + 'tf_efficientnet_b%s_ap' % B, + pretrained=pretrained, + checkpoint_path=ckpt) + else: + basemodel = 
geffnet.create_model( + 'tf_efficientnet_b%s_ap' % B, pretrained=pretrained) + + basemodel.global_pool = nn.Identity() + basemodel.classifier = nn.Identity() + self.original_model = basemodel + + def forward(self, x): + features = [x] + for k, v in self.original_model._modules.items(): + if k == 'blocks': + for ki, vi in v._modules.items(): + features.append(vi(features[-1])) + else: + features.append(v(features[-1])) + return features + + +class ConvGRU(nn.Module): + + def __init__(self, hidden_dim, input_dim, ks=3): + super().__init__() + p = (ks - 1) // 2 + self.convz = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, ks, padding=p) + self.convr = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, ks, padding=p) + self.convq = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, ks, padding=p) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + return h + + +class UpSampleBN(nn.Module): + + def __init__(self, skip_input, output_features, align_corners=True): + super(UpSampleBN, self).__init__() + self._net = nn.Sequential( + nn.Conv2d( + skip_input, + output_features, + kernel_size=3, + stride=1, + padding=1), nn.BatchNorm2d(output_features), nn.LeakyReLU(), + nn.Conv2d( + output_features, + output_features, + kernel_size=3, + stride=1, + padding=1), nn.BatchNorm2d(output_features), nn.LeakyReLU()) + self.align_corners = align_corners + + def forward(self, x, concat_with): + up_x = F.interpolate( + x, + size=[concat_with.size(2), + concat_with.size(3)], + mode='bilinear', + align_corners=self.align_corners) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +class Conv2d_WS(nn.Conv2d): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True): + super(Conv2d_WS, + self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + + def forward(self, x): + weight = self.weight + weight_mean = weight.mean( + dim=1, keepdim=True).mean( + dim=2, keepdim=True).mean( + dim=3, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, + 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +class UpSampleGN(nn.Module): + + def __init__(self, skip_input, output_features, align_corners=True): + super(UpSampleGN, self).__init__() + self._net = nn.Sequential( + Conv2d_WS( + skip_input, + output_features, + kernel_size=3, + stride=1, + padding=1), nn.GroupNorm(8, output_features), nn.LeakyReLU(), + Conv2d_WS( + output_features, + output_features, + kernel_size=3, + stride=1, + padding=1), nn.GroupNorm(8, output_features), nn.LeakyReLU()) + self.align_corners = align_corners + + def forward(self, x, concat_with): + up_x = F.interpolate( + x, + size=[concat_with.size(2), + concat_with.size(3)], + mode='bilinear', + align_corners=self.align_corners) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +def upsample_via_bilinear(out, up_mask=None, downsample_ratio=None): + return F.interpolate( + out, + scale_factor=downsample_ratio, + mode='bilinear', + align_corners=False) + + +def upsample_via_mask(out, up_mask, downsample_ratio, padding='zero'): + """ + convex upsampling + """ + # out: low-resolution output (B, o_dim, H, W) + # up_mask: (B, 9*k*k, 
H, W) + k = downsample_ratio + + B, C, H, W = out.shape + up_mask = up_mask.view(B, 1, 9, k, k, H, W) + up_mask = torch.softmax(up_mask, dim=2) # (B, 1, 9, k, k, H, W) + + if padding == 'zero': + up_out = F.unfold(out, [3, 3], padding=1) + elif padding == 'replicate': + out = F.pad(out, pad=(1, 1, 1, 1), mode='replicate') + up_out = F.unfold(out, [3, 3], padding=0) + else: + raise Exception('invalid padding for convex upsampling') + + up_out = up_out.view(B, C, 9, 1, 1, H, W) + + up_out = torch.sum(up_mask * up_out, dim=2) + up_out = up_out.permute(0, 1, 4, 2, 5, 3) + return up_out.reshape(B, C, k * H, k * W) + + +def get_prediction_head(input_dim, hidden_dim, output_dim): + return nn.Sequential( + nn.Conv2d(input_dim, hidden_dim, 3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 1), nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, output_dim, 1)) + + +# submodules copy from DSINE +def get_pixel_coords(h, w): + pixel_coords = np.ones((3, h, w)).astype(np.float32) + x_range = np.concatenate([np.arange(w).reshape(1, w)] * h, axis=0) + y_range = np.concatenate([np.arange(h).reshape(h, 1)] * w, axis=1) + pixel_coords[0, :, :] = x_range + 0.5 + pixel_coords[1, :, :] = y_range + 0.5 + return torch.from_numpy(pixel_coords).unsqueeze(0) + + +def normal_activation(out, elu_kappa=True): + normal, kappa = out[:, :3, :, :], out[:, 3:, :, :] + normal = F.normalize(normal, p=2, dim=1) + if elu_kappa: + kappa = F.elu(kappa) + 1.0 + return torch.cat([normal, kappa], dim=1) diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index d987e989..530c86a9 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -123,6 +123,7 @@ if TYPE_CHECKING: from .anydoor_pipeline import AnydoorPipeline from .image_depth_estimation_marigold_pipeline import ImageDepthEstimationMarigoldPipeline from .self_supervised_depth_completion_pipeline import SelfSupervisedDepthCompletionPipeline + from .human_normal_estimation_pipeline import HumanNormalEstimationPipeline else: _import_structure = { @@ -312,6 +313,7 @@ else: 'self_supervised_depth_completion_pipeline': [ 'SelfSupervisedDepthCompletionPipeline' ], + 'human_normal_estimation_pipeline': ['HumanNormalEstimationPipeline'], } import sys diff --git a/modelscope/pipelines/cv/human_normal_estimation_pipeline.py b/modelscope/pipelines/cv/human_normal_estimation_pipeline.py new file mode 100644 index 00000000..bd19b18d --- /dev/null +++ b/modelscope/pipelines/cv/human_normal_estimation_pipeline.py @@ -0,0 +1,95 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +import numpy as np +from PIL import Image + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.human_normal_estimation, + module_name=Pipelines.human_normal_estimation) +class HumanNormalEstimationPipeline(Pipeline): + r""" Human Normal Estimation Pipeline. 
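+
+    Given a single human image (file path or PIL.Image), the pipeline predicts a
+    per-pixel surface normal map and a color-coded visualization of it.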
+ + Examples: + + >>> from modelscope.pipelines import pipeline + + >>> estimator = pipeline( + >>> Tasks.human_normal_estimation, model='Damo_XR_Lab/cv_human_monocular-normal-estimation') + >>> estimator(f"{model_dir}/tests/image_normal_estimation.jpg") + """ + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image normal estimation pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + logger.info('normal estimation model, pipeline init') + + def preprocess(self, input: Input) -> Dict[str, Any]: + """ + + Args: + input: string or ndarray or Image.Image + + Returns: + data: dict including inference inputs + """ + if isinstance(input, str): + img = np.array(Image.open(input)) + if isinstance(input, Image.Image): + img = np.array(input) + + img_h, img_w, img_ch = img.shape[0:3] + + if img_ch == 3: + msk = np.full((img_h, img_w, 1), 255, dtype=np.uint8) + img = np.concatenate((img, msk), axis=-1) + + H, W = 1024, 1024 + scale_factor = min(W / img_w, H / img_h) + img = Image.fromarray(img) + img = img.resize( + (int(img_w * scale_factor), int(img_h * scale_factor)), + Image.LANCZOS) + + new_img = Image.new('RGBA', (W, H), color=(0, 0, 0, 0)) + paste_pos_w = (W - img.width) // 2 + paste_pos_h = (H - img.height) // 2 + new_img.paste(img, (paste_pos_w, paste_pos_h)) + + bbox = (paste_pos_w, paste_pos_h, paste_pos_w + img.width, + paste_pos_h + img.height) + img = np.array(new_img) + + data = {'img': img[:, :, 0:3], 'msk': img[:, :, -1], 'bbox': bbox} + + return data + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + results = self.model.inference(input) + return results + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + results = self.model.postprocess(inputs) + normals = results[OutputKeys.NORMALS] + + normals_vis = (((normals + 1) * 0.5) * 255).astype(np.uint8) + normals_vis = normals_vis[..., [2, 1, 0]] + outputs = { + OutputKeys.NORMALS: normals, + OutputKeys.NORMALS_COLOR: normals_vis + } + return outputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 3570c0cb..9fcaf71c 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -78,6 +78,8 @@ class CVTasks(object): image_local_feature_matching = 'image-local-feature-matching' image_quality_assessment_degradation = 'image-quality-assessment-degradation' + human_normal_estimation = 'human-normal-estimation' + crowd_counting = 'crowd-counting' # image editing diff --git a/modelscope/utils/pipeline_schema.json b/modelscope/utils/pipeline_schema.json index 1f002567..92f9dd59 100644 --- a/modelscope/utils/pipeline_schema.json +++ b/modelscope/utils/pipeline_schema.json @@ -1179,6 +1179,13 @@ "type": "object" } }, + "human-normal-estimation": { + "input": {}, + "parameters": {}, + "output": { + "type": "object" + } + }, "image-driving-perception": { "input": { "type": "object", diff --git a/tests/pipelines/test_human_normal_estimation.py b/tests/pipelines/test_human_normal_estimation.py new file mode 100644 index 00000000..a2699c28 --- /dev/null +++ b/tests/pipelines/test_human_normal_estimation.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
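+# Smoke test for the human-normal-estimation pipeline: run it on the bundled
+# sample image and save the color-coded normal map as result.jpg.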
+import os +import os.path +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class HumanNormalEstimationTest(unittest.TestCase): + + def setUp(self) -> None: + self.task = 'human-normal-estimation' + self.model_id = 'Damo_XR_Lab/cv_human_monocular-normal-estimation' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_normal_estimation(self): + cur_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + input_location = f'{cur_dir}/data/test/images/human_normal_estimation.png' + estimator = pipeline( + Tasks.human_normal_estimation, model=self.model_id) + result = estimator(input_location) + normals_vis = result[OutputKeys.NORMALS_COLOR] + + input_img = cv2.imread(input_location) + normals_vis = cv2.resize( + normals_vis, dsize=(input_img.shape[1], input_img.shape[0])) + cv2.imwrite('result.jpg', normals_vis) + + +if __name__ == '__main__': + unittest.main()
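Reviewer note: a minimal usage sketch of the new pipeline outside the unit test, put together from the pipeline docstring and the test above. The input path reuses the test asset and the output file names are placeholders; treat this as an illustration, not verified output.

    import cv2
    import numpy as np
    from PIL import Image

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    estimator = pipeline(
        Tasks.human_normal_estimation,
        model='Damo_XR_Lab/cv_human_monocular-normal-estimation')

    # A file path or a PIL image both work; an RGBA image keeps its alpha
    # channel as the person mask, while an RGB image gets a full-frame mask.
    img = Image.open('data/test/images/human_normal_estimation.png')
    result = estimator(img)

    normals = result[OutputKeys.NORMALS]              # float normals, roughly in [-1, 1]
    normals_color = result[OutputKeys.NORMALS_COLOR]  # uint8 visualization (BGR channel order)

    np.save('human_normals.npy', normals)
    cv2.imwrite('human_normals_vis.png', normals_color)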