Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
upload high quality human normal estimation model (#903)

* upload human normal estimation model
* upload human normal estimation unittest
* update human normal estimation
* update test data

---------

Co-authored-by: 葭润 <ranqing.rq@alibaba-inc.com>
Co-authored-by: ranqing <ranqing@Sundays-Mac.local>
Co-authored-by: suluyana <suluyan_sly@163.com>
Submodule data/test updated: 7a7f6b8d05...dedb3ce447
modelscope/metainfo.py
@@ -58,6 +58,7 @@ class Models(object):
     s2net_depth_estimation = 's2net-depth-estimation'
     dro_resnet18_depth_estimation = 'dro-resnet18-depth-estimation'
     raft_dense_optical_flow_estimation = 'raft-dense-optical-flow-estimation'
+    human_normal_estimation = 'human-normal-estimation'
     resnet50_bert = 'resnet50-bert'
     referring_video_object_segmentation = 'swinT-referring-video-object-segmentation'
     fer = 'fer'
@@ -480,6 +481,7 @@ class Pipelines(object):
     anydoor = 'anydoor'
     image_to_3d = 'image-to-3d'
     self_supervised_depth_completion = 'self-supervised-depth-completion'
+    human_normal_estimation = 'human-normal-estimation'

     # nlp tasks
     automatic_post_editing = 'automatic-post-editing'
@@ -814,6 +816,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.image_normal_estimation:
     (Pipelines.image_normal_estimation,
      'Damo_XR_Lab/cv_omnidata_image-normal-estimation_normal'),
+    Tasks.human_normal_estimation:
+    (Pipelines.human_normal_estimation,
+     'Damo_XR_Lab/cv_human_monocular-normal-estimation'),
     Tasks.indoor_layout_estimation:
     (Pipelines.indoor_layout_estimation,
      'damo/cv_panovit_indoor-layout-estimation'),
@@ -846,9 +851,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.image_to_image_generation:
     (Pipelines.image_to_image_generation,
      'damo/cv_latent_diffusion_image2image_generate'),
-    Tasks.image_classification:
-    (Pipelines.daily_image_classification,
-     'damo/cv_vit-base_image-classification_Dailylife-labels'),
+    Tasks.image_classification: (
+        Pipelines.daily_image_classification,
+        'damo/cv_vit-base_image-classification_Dailylife-labels'),
     Tasks.image_object_detection: (
         Pipelines.image_object_detection_auto,
         'damo/cv_yolox_image-object-detection-auto'),
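With these registry changes, the task name, pipeline name, and default model id are wired together, so the model argument can be omitted when constructing the pipeline. A minimal sketch, assuming default-model resolution behaves for this task as it does for the neighboring DEFAULT_MODEL_FOR_PIPELINE entries (the input path is a placeholder):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # resolves to 'Damo_XR_Lab/cv_human_monocular-normal-estimation'
    # through the DEFAULT_MODEL_FOR_PIPELINE entry added above
    estimator = pipeline(Tasks.human_normal_estimation)
    result = estimator('path/to/person.png')  # placeholder input image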
22  modelscope/models/cv/human_normal_estimation/__init__.py  Normal file
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .human_nnet import HumanNormalEstimation
+
+else:
+    _import_structure = {
+        'human_nnet': ['HumanNormalEstimation'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
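The LazyImportModule indirection keeps importing modelscope.models.cv.human_normal_estimation cheap: the torch-heavy human_nnet submodule is only loaded when HumanNormalEstimation is first accessed. A minimal sketch of the underlying pattern (a simplified illustration, not ModelScope's actual LazyImportModule):

    import importlib
    import types


    class LazyModule(types.ModuleType):

        def __init__(self, name, import_structure):
            super().__init__(name)
            # map each exported attribute to the submodule that defines it
            self._attr_to_submodule = {
                attr: sub
                for sub, attrs in import_structure.items() for attr in attrs
            }

        def __getattr__(self, attr):
            # import the owning submodule only on first attribute access
            sub = self._attr_to_submodule[attr]
            module = importlib.import_module(f'{self.__name__}.{sub}')
            return getattr(module, attr)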
80  modelscope/models/cv/human_normal_estimation/human_nnet.py  Normal file
@@ -0,0 +1,80 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+
+import numpy as np
+import torch
+import torchvision.transforms as T
+
+from modelscope.metainfo import Models
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.cv.human_normal_estimation.networks import config, nnet
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import Tasks
+
+
+@MODELS.register_module(
+    Tasks.human_normal_estimation, module_name=Models.human_normal_estimation)
+class HumanNormalEstimation(TorchModel):
+
+    def __init__(self, model_dir: str, **kwargs):
+        super().__init__(model_dir, **kwargs)
+        config_file = os.path.join(model_dir, 'config.txt')
+        args = config.get_args(txt_file=config_file)
+        args.encoder_path = os.path.join(model_dir, args.encoder_path)
+
+        self.device = torch.device(
+            'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+        self.nnet = nnet.NormalNet(args=args).to(self.device)
+        self.nnet_path = os.path.join(model_dir, 'ckpt/best_nnet.pt')
+        if os.path.exists(self.nnet_path):
+            ckpt = torch.load(
+                self.nnet_path, map_location=self.device)['model']
+            load_dict = {}
+            for k, v in ckpt.items():
+                if k.startswith('module.'):
+                    k_ = k.replace('module.', '')
+                    load_dict[k_] = v
+                else:
+                    load_dict[k] = v
+            self.nnet.load_state_dict(load_dict)
+        self.nnet.eval()
+
+        self.normalize = T.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    def forward(self, inputs):
+        img = inputs['img'].astype(np.float32) / 255.0
+        msk = inputs['msk'].astype(np.float32) / 255.0
+        bbox = inputs['bbox']
+
+        img_h, img_w = img.shape[0:2]
+        img = torch.from_numpy(img).permute(2, 0,
+                                            1).unsqueeze(0).to(self.device)
+        img = self.normalize(img)
+
+        fx = fy = (max(img_h, img_w) / 2.0) / np.tan(np.deg2rad(60.0 / 2.0))
+        cx = (img_w / 2.0) - 0.5
+        cy = (img_h / 2.0) - 0.5
+
+        intrins = torch.tensor(
+            [[fx, 0, cx + 0.5], [0, fy, cy + 0.5], [0, 0, 1]],
+            dtype=torch.float32,
+            device=self.device).unsqueeze(0)
+
+        pred_norm = self.nnet(img, intrins=intrins)[-1]
+        pred_norm = pred_norm.detach().cpu().permute(0, 2, 3, 1).numpy()
+        pred_norm = pred_norm[0, ...]
+        pred_norm = pred_norm * msk[..., None]
+        pred_norm = pred_norm[bbox[1]:bbox[3], bbox[0]:bbox[2]]
+        results = pred_norm
+        return results
+
+    def postprocess(self, inputs):
+        normal_result = inputs
+        results = {OutputKeys.NORMALS: normal_result}
+        return results
+
+    def inference(self, data):
+        results = self.forward(data)
+        return results
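forward does not take real camera intrinsics as input; it assumes a fixed 60 degree field of view over the (square) preprocessed canvas, so the focal length is fully determined by the input size. A worked example for the 1024x1024 letterboxed input produced by the pipeline below:

    import numpy as np

    # pinhole intrinsics for a 1024x1024 input under the assumed 60-deg FOV
    H = W = 1024
    fx = fy = (max(H, W) / 2.0) / np.tan(np.deg2rad(60.0 / 2.0))  # ~886.8
    cx, cy = (W / 2.0) - 0.5, (H / 2.0) - 0.5                     # 511.5 each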
40  modelscope/models/cv/human_normal_estimation/networks/config.py  Normal file
@@ -0,0 +1,40 @@
+import argparse
+
+
+def convert_arg_line_to_args(arg_line):
+    for arg in arg_line.split():
+        if not arg.strip():
+            continue
+        yield str(arg)
+
+
+def get_args(txt_file=None):
+    parser = argparse.ArgumentParser(
+        fromfile_prefix_chars='@', conflict_handler='resolve')
+    parser.convert_arg_line_to_args = convert_arg_line_to_args
+
+    # checkpoint (only needed when testing the model)
+    parser.add_argument('--ckpt_path', type=str, default=None)
+    parser.add_argument('--encoder_path', type=str, default=None)
+
+    # ↓↓↓↓
+    # NOTE: project-specific args
+    parser.add_argument('--output_dim', type=int, default=3, help='{3, 4}')
+    parser.add_argument('--output_type', type=str, default='R', help='{R, G}')
+    parser.add_argument('--feature_dim', type=int, default=64)
+    parser.add_argument('--hidden_dim', type=int, default=64)
+
+    parser.add_argument('--encoder_B', type=int, default=5)
+
+    parser.add_argument('--decoder_NF', type=int, default=2048)
+    parser.add_argument('--decoder_BN', default=False, action='store_true')
+    parser.add_argument('--decoder_down', type=int, default=2)
+    parser.add_argument(
+        '--learned_upsampling', default=False, action='store_true')
+
+    # read arguments from txt file
+    if txt_file:
+        config_filename = '@' + txt_file
+
+        args = parser.parse_args([config_filename])
+    return args
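get_args leans on argparse's fromfile_prefix_chars='@' mechanism: parse_args(['@' + txt_file]) reads arguments from the file, and the overridden convert_arg_line_to_args splits every line on whitespace so a flag and its value may share a line. A hypothetical config.txt in this format (the file name and values here are illustrative assumptions, not the shipped configuration):

    --encoder_B 5
    --encoder_path ckpt/encoder.pt
    --decoder_NF 2048
    --learned_upsampling

which would be consumed as args = get_args(txt_file='config.txt').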
125  modelscope/models/cv/human_normal_estimation/networks/nnet.py  Normal file
@@ -0,0 +1,125 @@
+import os
+import sys
+
+import torch
+import torch.nn as nn
+
+from .submodules import (Encoder, UpSampleBN, UpSampleGN, get_pixel_coords,
+                         get_prediction_head, normal_activation,
+                         upsample_via_bilinear, upsample_via_mask)
+
+PROJECT_DIR = os.path.split(os.path.dirname(os.path.realpath(__file__)))[0]
+sys.path.append(PROJECT_DIR)
+
+
+class NormalNet(nn.Module):
+
+    def __init__(self, args):
+        super(NormalNet, self).__init__()
+        B = args.encoder_B
+        NF = args.decoder_NF
+        BN = args.decoder_BN
+        learned_upsampling = args.learned_upsampling
+
+        self.encoder = Encoder(B=B, pretrained=False, ckpt=args.encoder_path)
+        self.decoder = Decoder(
+            num_classes=args.output_dim,
+            B=B,
+            NF=NF,
+            BN=BN,
+            learned_upsampling=learned_upsampling)
+
+    def forward(self, x, **kwargs):
+        return self.decoder(self.encoder(x), **kwargs)
+
+
+class Decoder(nn.Module):
+
+    def __init__(self,
+                 num_classes=3,
+                 B=5,
+                 NF=2048,
+                 BN=False,
+                 learned_upsampling=True):
+        super(Decoder, self).__init__()
+        input_channels = [2048, 176, 64, 40, 24]
+
+        UpSample = UpSampleBN if BN else UpSampleGN
+        features = NF
+
+        self.conv2 = nn.Conv2d(
+            input_channels[0] + 2,
+            features,
+            kernel_size=1,
+            stride=1,
+            padding=0)
+        self.up1 = UpSample(
+            skip_input=features // 1 + input_channels[1] + 2,
+            output_features=features // 2,
+            align_corners=False)
+        self.up2 = UpSample(
+            skip_input=features // 2 + input_channels[2] + 2,
+            output_features=features // 4,
+            align_corners=False)
+        self.up3 = UpSample(
+            skip_input=features // 4 + input_channels[3] + 2,
+            output_features=features // 8,
+            align_corners=False)
+        self.up4 = UpSample(
+            skip_input=features // 8 + input_channels[4] + 2,
+            output_features=features // 16,
+            align_corners=False)
+        i_dim = features // 16
+
+        self.downsample_ratio = 2
+        self.output_dim = num_classes
+
+        self.pred_head = get_prediction_head(i_dim + 2, 128, num_classes)
+        if learned_upsampling:
+            self.mask_head = get_prediction_head(
+                i_dim + 2, 128,
+                9 * self.downsample_ratio * self.downsample_ratio)
+            self.upsample_fn = upsample_via_mask
+        else:
+            self.mask_head = lambda a: None
+            self.upsample_fn = upsample_via_bilinear
+
+        self.pixel_coords = get_pixel_coords(h=1024, w=1024)
+
+    def ray_embedding(self, x, intrins, orig_H, orig_W):
+        B, _, H, W = x.shape
+        fu = intrins[:, 0, 0].unsqueeze(-1).unsqueeze(-1) * (W / orig_W)
+        cu = intrins[:, 0, 2].unsqueeze(-1).unsqueeze(-1) * (W / orig_W)
+        fv = intrins[:, 1, 1].unsqueeze(-1).unsqueeze(-1) * (H / orig_H)
+        cv = intrins[:, 1, 2].unsqueeze(-1).unsqueeze(-1) * (H / orig_H)
+
+        uv = self.pixel_coords[:, :2, :H, :W].to(x.device).repeat(B, 1, 1, 1)
+        uv[:, 0, :, :] = (uv[:, 0, :, :] - cu) / fu
+        uv[:, 1, :, :] = (uv[:, 1, :, :] - cv) / fv
+        return torch.cat([x, uv], dim=1)
+
+    def forward(self, features, intrins):
+        x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], \
+            features[8], features[11]
+        _, _, orig_H, orig_W = features[0].shape
+
+        x_d0 = self.conv2(
+            self.ray_embedding(x_block4, intrins, orig_H, orig_W))
+        x_d1 = self.up1(x_d0,
+                        self.ray_embedding(x_block3, intrins, orig_H, orig_W))
+        x_d2 = self.up2(x_d1,
+                        self.ray_embedding(x_block2, intrins, orig_H, orig_W))
+        x_d3 = self.up3(x_d2,
+                        self.ray_embedding(x_block1, intrins, orig_H, orig_W))
+        x_feat = self.up4(
+            x_d3, self.ray_embedding(x_block0, intrins, orig_H, orig_W))
+
+        out = self.pred_head(
+            self.ray_embedding(x_feat, intrins, orig_H, orig_W))
+        out = normal_activation(out, elu_kappa=True)
+        mask = self.mask_head(
+            self.ray_embedding(x_feat, intrins, orig_H, orig_W))
+        up_out = self.upsample_fn(
+            out, up_mask=mask, downsample_ratio=self.downsample_ratio)
+        up_out = normal_activation(up_out, elu_kappa=False)
+        return [up_out]
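The recurring '+ 2' in the decoder's channel arithmetic comes from ray_embedding, which appends two channels of normalized camera-ray coordinates, (u - cu)/fu and (v - cv)/fv, after rescaling the intrinsics to each feature resolution. A small sketch with illustrative numbers:

    import torch

    # ray x-coordinates for a 4x4 feature map taken from a 16x16 input,
    # assuming illustrative intrinsics fu = cu = 8 in input pixels
    orig_W, W = 16, 4
    fu = 8.0 * (W / orig_W)            # focal length rescaled to feature res
    cu = 8.0 * (W / orig_W)            # principal point rescaled likewise
    u = torch.arange(W).float() + 0.5  # pixel centers, as in get_pixel_coords
    print((u - cu) / fu)  # tensor([-0.7500, -0.2500,  0.2500,  0.7500])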
214  modelscope/models/cv/human_normal_estimation/networks/submodules.py  Normal file
@@ -0,0 +1,214 @@
+import geffnet
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+INPUT_CHANNELS_DICT = {
+    0: [1280, 112, 40, 24, 16],
+    1: [1280, 112, 40, 24, 16],
+    2: [1408, 120, 48, 24, 16],
+    3: [1536, 136, 48, 32, 24],
+    4: [1792, 160, 56, 32, 24],
+    5: [2048, 176, 64, 40, 24],
+    6: [2304, 200, 72, 40, 32],
+    7: [2560, 224, 80, 48, 32]
+}
+
+
+class Encoder(nn.Module):
+
+    def __init__(self, B=5, pretrained=True, ckpt=None):
+        super(Encoder, self).__init__()
+        if ckpt:
+            basemodel = geffnet.create_model(
+                'tf_efficientnet_b%s_ap' % B,
+                pretrained=pretrained,
+                checkpoint_path=ckpt)
+        else:
+            basemodel = geffnet.create_model(
+                'tf_efficientnet_b%s_ap' % B, pretrained=pretrained)
+
+        basemodel.global_pool = nn.Identity()
+        basemodel.classifier = nn.Identity()
+        self.original_model = basemodel
+
+    def forward(self, x):
+        features = [x]
+        for k, v in self.original_model._modules.items():
+            if k == 'blocks':
+                for ki, vi in v._modules.items():
+                    features.append(vi(features[-1]))
+            else:
+                features.append(v(features[-1]))
+        return features
+
+
+class ConvGRU(nn.Module):
+
+    def __init__(self, hidden_dim, input_dim, ks=3):
+        super().__init__()
+        p = (ks - 1) // 2
+        self.convz = nn.Conv2d(
+            hidden_dim + input_dim, hidden_dim, ks, padding=p)
+        self.convr = nn.Conv2d(
+            hidden_dim + input_dim, hidden_dim, ks, padding=p)
+        self.convq = nn.Conv2d(
+            hidden_dim + input_dim, hidden_dim, ks, padding=p)
+
+    def forward(self, h, x):
+        hx = torch.cat([h, x], dim=1)
+        z = torch.sigmoid(self.convz(hx))
+        r = torch.sigmoid(self.convr(hx))
+        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
+        h = (1 - z) * h + z * q
+        return h
+
+
+class UpSampleBN(nn.Module):
+
+    def __init__(self, skip_input, output_features, align_corners=True):
+        super(UpSampleBN, self).__init__()
+        self._net = nn.Sequential(
+            nn.Conv2d(
+                skip_input,
+                output_features,
+                kernel_size=3,
+                stride=1,
+                padding=1), nn.BatchNorm2d(output_features), nn.LeakyReLU(),
+            nn.Conv2d(
+                output_features,
+                output_features,
+                kernel_size=3,
+                stride=1,
+                padding=1), nn.BatchNorm2d(output_features), nn.LeakyReLU())
+        self.align_corners = align_corners
+
+    def forward(self, x, concat_with):
+        up_x = F.interpolate(
+            x,
+            size=[concat_with.size(2),
+                  concat_with.size(3)],
+            mode='bilinear',
+            align_corners=self.align_corners)
+        f = torch.cat([up_x, concat_with], dim=1)
+        return self._net(f)
+
+
+class Conv2d_WS(nn.Conv2d):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
+        super(Conv2d_WS,
+              self).__init__(in_channels, out_channels, kernel_size, stride,
+                             padding, dilation, groups, bias)
+
+    def forward(self, x):
+        weight = self.weight
+        weight_mean = weight.mean(
+            dim=1, keepdim=True).mean(
+                dim=2, keepdim=True).mean(
+                    dim=3, keepdim=True)
+        weight = weight - weight_mean
+        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1,
+                                                              1) + 1e-5
+        weight = weight / std.expand_as(weight)
+        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
+                        self.dilation, self.groups)
+
+
+class UpSampleGN(nn.Module):
+
+    def __init__(self, skip_input, output_features, align_corners=True):
+        super(UpSampleGN, self).__init__()
+        self._net = nn.Sequential(
+            Conv2d_WS(
+                skip_input,
+                output_features,
+                kernel_size=3,
+                stride=1,
+                padding=1), nn.GroupNorm(8, output_features), nn.LeakyReLU(),
+            Conv2d_WS(
+                output_features,
+                output_features,
+                kernel_size=3,
+                stride=1,
+                padding=1), nn.GroupNorm(8, output_features), nn.LeakyReLU())
+        self.align_corners = align_corners
+
+    def forward(self, x, concat_with):
+        up_x = F.interpolate(
+            x,
+            size=[concat_with.size(2),
+                  concat_with.size(3)],
+            mode='bilinear',
+            align_corners=self.align_corners)
+        f = torch.cat([up_x, concat_with], dim=1)
+        return self._net(f)
+
+
+def upsample_via_bilinear(out, up_mask=None, downsample_ratio=None):
+    return F.interpolate(
+        out,
+        scale_factor=downsample_ratio,
+        mode='bilinear',
+        align_corners=False)
+
+
+def upsample_via_mask(out, up_mask, downsample_ratio, padding='zero'):
+    """
+    convex upsampling
+    """
+    # out: low-resolution output (B, o_dim, H, W)
+    # up_mask: (B, 9*k*k, H, W)
+    k = downsample_ratio
+
+    B, C, H, W = out.shape
+    up_mask = up_mask.view(B, 1, 9, k, k, H, W)
+    up_mask = torch.softmax(up_mask, dim=2)  # (B, 1, 9, k, k, H, W)
+
+    if padding == 'zero':
+        up_out = F.unfold(out, [3, 3], padding=1)
+    elif padding == 'replicate':
+        out = F.pad(out, pad=(1, 1, 1, 1), mode='replicate')
+        up_out = F.unfold(out, [3, 3], padding=0)
+    else:
+        raise Exception('invalid padding for convex upsampling')
+
+    up_out = up_out.view(B, C, 9, 1, 1, H, W)
+
+    up_out = torch.sum(up_mask * up_out, dim=2)
+    up_out = up_out.permute(0, 1, 4, 2, 5, 3)
+    return up_out.reshape(B, C, k * H, k * W)
+
+
+def get_prediction_head(input_dim, hidden_dim, output_dim):
+    return nn.Sequential(
+        nn.Conv2d(input_dim, hidden_dim, 3, padding=1), nn.ReLU(inplace=True),
+        nn.Conv2d(hidden_dim, hidden_dim, 1), nn.ReLU(inplace=True),
+        nn.Conv2d(hidden_dim, output_dim, 1))
+
+
+# submodules copied from DSINE
+def get_pixel_coords(h, w):
+    pixel_coords = np.ones((3, h, w)).astype(np.float32)
+    x_range = np.concatenate([np.arange(w).reshape(1, w)] * h, axis=0)
+    y_range = np.concatenate([np.arange(h).reshape(h, 1)] * w, axis=1)
+    pixel_coords[0, :, :] = x_range + 0.5
+    pixel_coords[1, :, :] = y_range + 0.5
+    return torch.from_numpy(pixel_coords).unsqueeze(0)
+
+
+def normal_activation(out, elu_kappa=True):
+    normal, kappa = out[:, :3, :, :], out[:, 3:, :, :]
+    normal = F.normalize(normal, p=2, dim=1)
+    if elu_kappa:
+        kappa = F.elu(kappa) + 1.0
+    return torch.cat([normal, kappa], dim=1)
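upsample_via_mask is the convex upsampling introduced by RAFT and reused in DSINE: every high-resolution pixel is a softmax-weighted combination of its 3x3 low-resolution neighborhood, with the weights coming from mask_head. A quick shape sanity check with assumed random inputs:

    import torch

    # k = downsample_ratio = 2, as configured in the Decoder above
    B, C, H, W, k = 1, 4, 8, 8, 2
    out = torch.randn(B, C, H, W)              # low-res prediction
    up_mask = torch.randn(B, 9 * k * k, H, W)  # 3x3 weights per k x k sub-pixel
    up = upsample_via_mask(out, up_mask, downsample_ratio=k)
    print(up.shape)  # torch.Size([1, 4, 16, 16])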
modelscope/pipelines/cv/__init__.py
@@ -123,6 +123,7 @@ if TYPE_CHECKING:
     from .anydoor_pipeline import AnydoorPipeline
     from .image_depth_estimation_marigold_pipeline import ImageDepthEstimationMarigoldPipeline
     from .self_supervised_depth_completion_pipeline import SelfSupervisedDepthCompletionPipeline
+    from .human_normal_estimation_pipeline import HumanNormalEstimationPipeline

 else:
     _import_structure = {
@@ -312,6 +313,7 @@ else:
         'self_supervised_depth_completion_pipeline': [
             'SelfSupervisedDepthCompletionPipeline'
         ],
+        'human_normal_estimation_pipeline': ['HumanNormalEstimationPipeline'],
     }

     import sys
95  modelscope/pipelines/cv/human_normal_estimation_pipeline.py  Normal file
@@ -0,0 +1,95 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict
+
+import numpy as np
+from PIL import Image
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.human_normal_estimation,
+    module_name=Pipelines.human_normal_estimation)
+class HumanNormalEstimationPipeline(Pipeline):
+    r""" Human Normal Estimation Pipeline.
+
+    Examples:
+
+    >>> from modelscope.pipelines import pipeline
+
+    >>> estimator = pipeline(
+    >>>     Tasks.human_normal_estimation, model='Damo_XR_Lab/cv_human_monocular-normal-estimation')
+    >>> estimator(f'{model_dir}/tests/image_normal_estimation.jpg')
+    """
+
+    def __init__(self, model: str, **kwargs):
+        """
+        Use `model` to create a human normal estimation pipeline for prediction.
+        Args:
+            model: model id on the ModelScope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        logger.info('normal estimation model, pipeline init')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        """
+
+        Args:
+            input: string or ndarray or Image.Image
+
+        Returns:
+            data: dict including inference inputs
+        """
+        if isinstance(input, str):
+            img = np.array(Image.open(input))
+        if isinstance(input, Image.Image):
+            img = np.array(input)
+
+        img_h, img_w, img_ch = img.shape[0:3]
+
+        if img_ch == 3:
+            msk = np.full((img_h, img_w, 1), 255, dtype=np.uint8)
+            img = np.concatenate((img, msk), axis=-1)
+
+        H, W = 1024, 1024
+        scale_factor = min(W / img_w, H / img_h)
+        img = Image.fromarray(img)
+        img = img.resize(
+            (int(img_w * scale_factor), int(img_h * scale_factor)),
+            Image.LANCZOS)
+
+        new_img = Image.new('RGBA', (W, H), color=(0, 0, 0, 0))
+        paste_pos_w = (W - img.width) // 2
+        paste_pos_h = (H - img.height) // 2
+        new_img.paste(img, (paste_pos_w, paste_pos_h))
+
+        bbox = (paste_pos_w, paste_pos_h, paste_pos_w + img.width,
+                paste_pos_h + img.height)
+        img = np.array(new_img)
+
+        data = {'img': img[:, :, 0:3], 'msk': img[:, :, -1], 'bbox': bbox}
+
+        return data
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        results = self.model.inference(input)
+        return results
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        results = self.model.postprocess(inputs)
+        normals = results[OutputKeys.NORMALS]
+
+        normals_vis = (((normals + 1) * 0.5) * 255).astype(np.uint8)
+        normals_vis = normals_vis[..., [2, 1, 0]]
+        outputs = {
+            OutputKeys.NORMALS: normals,
+            OutputKeys.NORMALS_COLOR: normals_vis
+        }
+        return outputs
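preprocess letterboxes an arbitrary input onto a 1024x1024 canvas and records the paste region; the model later crops its prediction back to that bbox, undoing the padding. The geometry for a hypothetical 512x256 (h, w) input, as a worked example:

    # letterbox geometry for img_h, img_w = 512, 256
    img_h, img_w, H, W = 512, 256, 1024, 1024
    scale = min(W / img_w, H / img_h)                      # 2.0
    new_w, new_h = int(img_w * scale), int(img_h * scale)  # 512, 1024
    pos_w, pos_h = (W - new_w) // 2, (H - new_h) // 2      # 256, 0
    bbox = (pos_w, pos_h, pos_w + new_w, pos_h + new_h)    # (256, 0, 768, 1024)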
modelscope/utils/constant.py
@@ -78,6 +78,8 @@ class CVTasks(object):
     image_local_feature_matching = 'image-local-feature-matching'
     image_quality_assessment_degradation = 'image-quality-assessment-degradation'
+
+    human_normal_estimation = 'human-normal-estimation'

     crowd_counting = 'crowd-counting'

     # image editing
@@ -1179,6 +1179,13 @@
         "type": "object"
       }
     },
+    "human-normal-estimation": {
+      "input": {},
+      "parameters": {},
+      "output": {
+        "type": "object"
+      }
+    },
     "image-driving-perception": {
       "input": {
         "type": "object",
37  tests/pipelines/test_human_normal_estimation.py  Normal file
@@ -0,0 +1,37 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import os.path
+import unittest
+
+import cv2
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class HumanNormalEstimationTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.task = 'human-normal-estimation'
+        self.model_id = 'Damo_XR_Lab/cv_human_monocular-normal-estimation'
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_image_normal_estimation(self):
+        cur_dir = os.path.dirname(
+            os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+        input_location = f'{cur_dir}/data/test/images/human_normal_estimation.png'
+        estimator = pipeline(
+            Tasks.human_normal_estimation, model=self.model_id)
+        result = estimator(input_location)
+        normals_vis = result[OutputKeys.NORMALS_COLOR]
+
+        input_img = cv2.imread(input_location)
+        normals_vis = cv2.resize(
+            normals_vis, dsize=(input_img.shape[1], input_img.shape[0]))
+        cv2.imwrite('result.jpg', normals_vis)
+
+
+if __name__ == '__main__':
+    unittest.main()
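Since the module ends in unittest.main(), the test can be run directly; it writes its result.jpg visualization to the current working directory:

    python tests/pipelines/test_human_normal_estimation.py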