upload high-quality human normal estimation model (#903)

* upload human normal estimation model

* upload human normal estimation unittest

* update human normal estimation

* update test data

---------

Co-authored-by: 葭润 <ranqing.rq@alibaba-inc.com>
Co-authored-by: ranqing <ranqing@Sundays-Mac.local>
Co-authored-by: suluyana <suluyan_sly@163.com>
Authored by Ranqing on 2024-07-31 10:19:41 +08:00; committed by GitHub.
parent 836f206207
commit 543a2b32d7
13 changed files with 633 additions and 4 deletions

View File

@@ -58,6 +58,7 @@ class Models(object):
s2net_depth_estimation = 's2net-depth-estimation'
dro_resnet18_depth_estimation = 'dro-resnet18-depth-estimation'
raft_dense_optical_flow_estimation = 'raft-dense-optical-flow-estimation'
human_normal_estimation = 'human-normal-estimation'
resnet50_bert = 'resnet50-bert'
referring_video_object_segmentation = 'swinT-referring-video-object-segmentation'
fer = 'fer'
@@ -480,6 +481,7 @@ class Pipelines(object):
anydoor = 'anydoor'
image_to_3d = 'image-to-3d'
self_supervised_depth_completion = 'self-supervised-depth-completion'
human_normal_estimation = 'human-normal-estimation'
# nlp tasks
automatic_post_editing = 'automatic-post-editing'
@@ -814,6 +816,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.image_normal_estimation:
(Pipelines.image_normal_estimation,
'Damo_XR_Lab/cv_omnidata_image-normal-estimation_normal'),
Tasks.human_normal_estimation:
(Pipelines.human_normal_estimation,
'Damo_XR_Lab/cv_human_monocular-normal-estimation'),
Tasks.indoor_layout_estimation:
(Pipelines.indoor_layout_estimation,
'damo/cv_panovit_indoor-layout-estimation'),
@@ -846,9 +851,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.image_to_image_generation:
(Pipelines.image_to_image_generation,
'damo/cv_latent_diffusion_image2image_generate'),
- Tasks.image_classification:
-     (Pipelines.daily_image_classification,
-      'damo/cv_vit-base_image-classification_Dailylife-labels'),
+ Tasks.image_classification: (
+     Pipelines.daily_image_classification,
+     'damo/cv_vit-base_image-classification_Dailylife-labels'),
Tasks.image_object_detection: (
Pipelines.image_object_detection_auto,
'damo/cv_yolox_image-object-detection-auto'),
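
With the registry entries above plus the DEFAULT_MODEL_FOR_PIPELINE mapping, the pipeline builder can resolve the task name to its default checkpoint. A minimal usage sketch (assumes hub access; the model id may be omitted because of the mapping added in this hunk):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Falls back to 'Damo_XR_Lab/cv_human_monocular-normal-estimation'
# via DEFAULT_MODEL_FOR_PIPELINE.
estimator = pipeline(Tasks.human_normal_estimation)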

View File

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .human_nnet import HumanNormalEstimation
else:
_import_structure = {
'human_nnet': ['HumanNormalEstimation'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
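
This mirrors the lazy-import pattern used across modelscope: importing the subpackage stays cheap, and `human_nnet` (with its torch dependencies) is only loaded on first attribute access. A small sketch of the effect, assuming modelscope is installed:

from modelscope.models.cv import human_normal_estimation

# Nothing heavy has been imported yet; this attribute access triggers the
# deferred `from .human_nnet import HumanNormalEstimation`.
model_cls = human_normal_estimation.HumanNormalEstimation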

View File

@@ -0,0 +1,80 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import numpy as np
import torch
import torchvision.transforms as T
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.human_normal_estimation.networks import config, nnet
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
@MODELS.register_module(
Tasks.human_normal_estimation, module_name=Models.human_normal_estimation)
class HumanNormalEstimation(TorchModel):
def __init__(self, model_dir: str, **kwargs):
super().__init__(model_dir, **kwargs)
config_file = os.path.join(model_dir, 'config.txt')
args = config.get_args(txt_file=config_file)
args.encoder_path = os.path.join(model_dir, args.encoder_path)
self.device = torch.device(
'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
self.nnet = nnet.NormalNet(args=args).to(self.device)
self.nnet_path = os.path.join(model_dir, 'ckpt/best_nnet.pt')
if os.path.exists(self.nnet_path):
ckpt = torch.load(
self.nnet_path, map_location=self.device)['model']
load_dict = {}
for k, v in ckpt.items():
if k.startswith('module.'):
k_ = k.replace('module.', '')
load_dict[k_] = v
else:
load_dict[k] = v
self.nnet.load_state_dict(load_dict)
self.nnet.eval()
self.normalize = T.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def forward(self, inputs):
img = inputs['img'].astype(np.float32) / 255.0
msk = inputs['msk'].astype(np.float32) / 255.0
bbox = inputs['bbox']
img_h, img_w = img.shape[0:2]
img = torch.from_numpy(img).permute(2, 0,
1).unsqueeze(0).to(self.device)
img = self.normalize(img)
# pinhole intrinsics assuming a 60-degree field of view; the pipeline pads
# inputs to a square canvas, so img_h == img_w here in practice
fx = fy = (max(img_h, img_w) / 2.0) / np.tan(np.deg2rad(60.0 / 2.0))
cx = (img_w / 2.0) - 0.5
cy = (img_h / 2.0) - 0.5
intrins = torch.tensor(
[[fx, 0, cx + 0.5], [0, fy, cy + 0.5], [0, 0, 1]],
dtype=torch.float32,
device=self.device).unsqueeze(0)
pred_norm = self.nnet(img, intrins=intrins)[-1]
pred_norm = pred_norm.detach().cpu().permute(0, 2, 3, 1).numpy()
pred_norm = pred_norm[0, ...]
pred_norm = pred_norm * msk[..., None]
pred_norm = pred_norm[bbox[1]:bbox[3], bbox[0]:bbox[2]]
results = pred_norm
return results
def postprocess(self, inputs):
normal_result = inputs
results = {OutputKeys.NORMALS: normal_result}
return results
def inference(self, data):
results = self.forward(data)
return results
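
For reference, a hedged sketch of the input contract that `forward`/`inference` expect; the dict keys follow the code above, and the all-zero image is a placeholder (real use goes through the pipeline's preprocess):

import numpy as np

inputs = {
    'img': np.zeros((1024, 1024, 3), dtype=np.uint8),   # RGB frame
    'msk': np.full((1024, 1024), 255, dtype=np.uint8),  # foreground mask
    'bbox': (0, 0, 1024, 1024),  # region the prediction is cropped back to
}
# model = HumanNormalEstimation(model_dir)  # needs config.txt + ckpt/best_nnet.pt
# normals = model.inference(inputs)         # float array of masked unit normals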

View File

@@ -0,0 +1,40 @@
import argparse
def convert_arg_line_to_args(arg_line):
for arg in arg_line.split():
if not arg.strip():
continue
yield str(arg)
def get_args(txt_file=None):
parser = argparse.ArgumentParser(
fromfile_prefix_chars='@', conflict_handler='resolve')
parser.convert_arg_line_to_args = convert_arg_line_to_args
# checkpoint (only needed when testing the model)
parser.add_argument('--ckpt_path', type=str, default=None)
parser.add_argument('--encoder_path', type=str, default=None)
# ↓↓↓↓
# NOTE: project-specific args
parser.add_argument('--output_dim', type=int, default=3, help='{3, 4}')
parser.add_argument('--output_type', type=str, default='R', help='{R, G}')
parser.add_argument('--feature_dim', type=int, default=64)
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--encoder_B', type=int, default=5)
parser.add_argument('--decoder_NF', type=int, default=2048)
parser.add_argument('--decoder_BN', default=False, action='store_true')
parser.add_argument('--decoder_down', type=int, default=2)
parser.add_argument(
'--learned_upsampling', default=False, action='store_true')
# read arguments from txt file (defaults are used when txt_file is None)
if txt_file:
config_filename = '@' + txt_file
args = parser.parse_args([config_filename])
else:
args = parser.parse_args([])
return args
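
Since the parser sets fromfile_prefix_chars='@' and a custom convert_arg_line_to_args, a config file may carry several whitespace-separated tokens per line. A minimal sketch with a hypothetical config file:

import tempfile

with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write('--encoder_B 5\n--decoder_NF 2048\n--learned_upsampling\n')

args = get_args(txt_file=f.name)
assert args.encoder_B == 5 and args.decoder_NF == 2048
assert args.learned_upsampling is True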

View File

@@ -0,0 +1,125 @@
import os
import sys
import torch
import torch.nn as nn
from .submodules import (Encoder, UpSampleBN, UpSampleGN, get_pixel_coords,
get_prediction_head, normal_activation,
upsample_via_bilinear, upsample_via_mask)
PROJECT_DIR = os.path.split(os.path.dirname(os.path.realpath(__file__)))[0]
sys.path.append(PROJECT_DIR)
class NormalNet(nn.Module):
def __init__(self, args):
super(NormalNet, self).__init__()
B = args.encoder_B
NF = args.decoder_NF
BN = args.decoder_BN
learned_upsampling = args.learned_upsampling
self.encoder = Encoder(B=B, pretrained=False, ckpt=args.encoder_path)
self.decoder = Decoder(
num_classes=args.output_dim,
B=B,
NF=NF,
BN=BN,
learned_upsampling=learned_upsampling)
def forward(self, x, **kwargs):
return self.decoder(self.encoder(x), **kwargs)
class Decoder(nn.Module):
def __init__(self,
num_classes=3,
B=5,
NF=2048,
BN=False,
learned_upsampling=True):
super(Decoder, self).__init__()
input_channels = [2048, 176, 64, 40, 24]
UpSample = UpSampleBN if BN else UpSampleGN
features = NF
self.conv2 = nn.Conv2d(
input_channels[0] + 2,
features,
kernel_size=1,
stride=1,
padding=0)
self.up1 = UpSample(
skip_input=features // 1 + input_channels[1] + 2,
output_features=features // 2,
align_corners=False)
self.up2 = UpSample(
skip_input=features // 2 + input_channels[2] + 2,
output_features=features // 4,
align_corners=False)
self.up3 = UpSample(
skip_input=features // 4 + input_channels[3] + 2,
output_features=features // 8,
align_corners=False)
self.up4 = UpSample(
skip_input=features // 8 + input_channels[4] + 2,
output_features=features // 16,
align_corners=False)
i_dim = features // 16
self.downsample_ratio = 2
self.output_dim = num_classes
self.pred_head = get_prediction_head(i_dim + 2, 128, num_classes)
if learned_upsampling:
self.mask_head = get_prediction_head(
i_dim + 2, 128,
9 * self.downsample_ratio * self.downsample_ratio)
self.upsample_fn = upsample_via_mask
else:
self.mask_head = lambda a: None
self.upsample_fn = upsample_via_bilinear
# register as a buffer so the coords follow the module's device (a
# hard-coded .to(0) would break CPU-only inference)
self.register_buffer(
'pixel_coords', get_pixel_coords(h=1024, w=1024), persistent=False)
def ray_embedding(self, x, intrins, orig_H, orig_W):
B, _, H, W = x.shape
fu = intrins[:, 0, 0].unsqueeze(-1).unsqueeze(-1) * (W / orig_W)
cu = intrins[:, 0, 2].unsqueeze(-1).unsqueeze(-1) * (W / orig_W)
fv = intrins[:, 1, 1].unsqueeze(-1).unsqueeze(-1) * (H / orig_H)
cv = intrins[:, 1, 2].unsqueeze(-1).unsqueeze(-1) * (H / orig_H)
uv = self.pixel_coords[:, :2, :H, :W].repeat(B, 1, 1, 1)
uv[:, 0, :, :] = (uv[:, 0, :, :] - cu) / fu
uv[:, 1, :, :] = (uv[:, 1, :, :] - cv) / fv
return torch.cat([x, uv], dim=1)
def forward(self, features, intrins):
x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], \
features[8], features[11]
_, _, orig_H, orig_W = features[0].shape
x_d0 = self.conv2(
self.ray_embedding(x_block4, intrins, orig_H, orig_W))
x_d1 = self.up1(x_d0,
self.ray_embedding(x_block3, intrins, orig_H, orig_W))
x_d2 = self.up2(x_d1,
self.ray_embedding(x_block2, intrins, orig_H, orig_W))
x_d3 = self.up3(x_d2,
self.ray_embedding(x_block1, intrins, orig_H, orig_W))
x_feat = self.up4(
x_d3, self.ray_embedding(x_block0, intrins, orig_H, orig_W))
out = self.pred_head(
self.ray_embedding(x_feat, intrins, orig_H, orig_W))
out = normal_activation(out, elu_kappa=True)
mask = self.mask_head(
self.ray_embedding(x_feat, intrins, orig_H, orig_W))
up_out = self.upsample_fn(
out, up_mask=mask, downsample_ratio=self.downsample_ratio)
up_out = normal_activation(up_out, elu_kappa=False)
return [up_out]
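
A detail worth noting in ray_embedding: intrinsics are rescaled to each feature map's resolution, so per-pixel ray directions stay consistent across pyramid levels. A worked example at 1/32 scale (numbers are illustrative only):

orig_W, W = 1024, 32          # full resolution vs. coarsest feature map
fx, cx = 886.8, 511.5         # example full-resolution intrinsics
fu = fx * (W / orig_W)        # 27.7125: focal length in feature-map pixels
cu = cx * (W / orig_W)        # ~15.98: principal point in feature-map pixels
# (u - cu) / fu at feature scale equals (U - cx) / fx at full resolution
# for corresponding pixels, so the embedding is resolution-independent.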

View File

@@ -0,0 +1,214 @@
import geffnet
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
INPUT_CHANNELS_DICT = {
0: [1280, 112, 40, 24, 16],
1: [1280, 112, 40, 24, 16],
2: [1408, 120, 48, 24, 16],
3: [1536, 136, 48, 32, 24],
4: [1792, 160, 56, 32, 24],
5: [2048, 176, 64, 40, 24],
6: [2304, 200, 72, 40, 32],
7: [2560, 224, 80, 48, 32]
}
class Encoder(nn.Module):
def __init__(self, B=5, pretrained=True, ckpt=None):
super(Encoder, self).__init__()
if ckpt:
basemodel = geffnet.create_model(
'tf_efficientnet_b%s_ap' % B,
pretrained=pretrained,
checkpoint_path=ckpt)
else:
basemodel = geffnet.create_model(
'tf_efficientnet_b%s_ap' % B, pretrained=pretrained)
basemodel.global_pool = nn.Identity()
basemodel.classifier = nn.Identity()
self.original_model = basemodel
def forward(self, x):
features = [x]
for k, v in self.original_model._modules.items():
if k == 'blocks':
for ki, vi in v._modules.items():
features.append(vi(features[-1]))
else:
features.append(v(features[-1]))
return features
class ConvGRU(nn.Module):
def __init__(self, hidden_dim, input_dim, ks=3):
super().__init__()
p = (ks - 1) // 2
self.convz = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, ks, padding=p)
self.convr = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, ks, padding=p)
self.convq = nn.Conv2d(
hidden_dim + input_dim, hidden_dim, ks, padding=p)
def forward(self, h, x):
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx))
r = torch.sigmoid(self.convr(hx))
q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q
return h
class UpSampleBN(nn.Module):
def __init__(self, skip_input, output_features, align_corners=True):
super(UpSampleBN, self).__init__()
self._net = nn.Sequential(
nn.Conv2d(
skip_input,
output_features,
kernel_size=3,
stride=1,
padding=1), nn.BatchNorm2d(output_features), nn.LeakyReLU(),
nn.Conv2d(
output_features,
output_features,
kernel_size=3,
stride=1,
padding=1), nn.BatchNorm2d(output_features), nn.LeakyReLU())
self.align_corners = align_corners
def forward(self, x, concat_with):
up_x = F.interpolate(
x,
size=[concat_with.size(2),
concat_with.size(3)],
mode='bilinear',
align_corners=self.align_corners)
f = torch.cat([up_x, concat_with], dim=1)
return self._net(f)
class Conv2d_WS(nn.Conv2d):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True):
super(Conv2d_WS,
self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
def forward(self, x):
weight = self.weight
weight_mean = weight.mean(
dim=1, keepdim=True).mean(
dim=2, keepdim=True).mean(
dim=3, keepdim=True)
weight = weight - weight_mean
std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1,
1) + 1e-5
weight = weight / std.expand_as(weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding,
self.dilation, self.groups)
class UpSampleGN(nn.Module):
def __init__(self, skip_input, output_features, align_corners=True):
super(UpSampleGN, self).__init__()
self._net = nn.Sequential(
Conv2d_WS(
skip_input,
output_features,
kernel_size=3,
stride=1,
padding=1), nn.GroupNorm(8, output_features), nn.LeakyReLU(),
Conv2d_WS(
output_features,
output_features,
kernel_size=3,
stride=1,
padding=1), nn.GroupNorm(8, output_features), nn.LeakyReLU())
self.align_corners = align_corners
def forward(self, x, concat_with):
up_x = F.interpolate(
x,
size=[concat_with.size(2),
concat_with.size(3)],
mode='bilinear',
align_corners=self.align_corners)
f = torch.cat([up_x, concat_with], dim=1)
return self._net(f)
def upsample_via_bilinear(out, up_mask=None, downsample_ratio=None):
return F.interpolate(
out,
scale_factor=downsample_ratio,
mode='bilinear',
align_corners=False)
def upsample_via_mask(out, up_mask, downsample_ratio, padding='zero'):
"""
convex upsampling
"""
# out: low-resolution output (B, o_dim, H, W)
# up_mask: (B, 9*k*k, H, W)
k = downsample_ratio
B, C, H, W = out.shape
up_mask = up_mask.view(B, 1, 9, k, k, H, W)
up_mask = torch.softmax(up_mask, dim=2) # (B, 1, 9, k, k, H, W)
if padding == 'zero':
up_out = F.unfold(out, [3, 3], padding=1)
elif padding == 'replicate':
out = F.pad(out, pad=(1, 1, 1, 1), mode='replicate')
up_out = F.unfold(out, [3, 3], padding=0)
else:
raise Exception('invalid padding for convex upsampling')
up_out = up_out.view(B, C, 9, 1, 1, H, W)
up_out = torch.sum(up_mask * up_out, dim=2)
up_out = up_out.permute(0, 1, 4, 2, 5, 3)
return up_out.reshape(B, C, k * H, k * W)
def get_prediction_head(input_dim, hidden_dim, output_dim):
return nn.Sequential(
nn.Conv2d(input_dim, hidden_dim, 3, padding=1), nn.ReLU(inplace=True),
nn.Conv2d(hidden_dim, hidden_dim, 1), nn.ReLU(inplace=True),
nn.Conv2d(hidden_dim, output_dim, 1))
# submodules copied from DSINE
def get_pixel_coords(h, w):
pixel_coords = np.ones((3, h, w)).astype(np.float32)
x_range = np.concatenate([np.arange(w).reshape(1, w)] * h, axis=0)
y_range = np.concatenate([np.arange(h).reshape(h, 1)] * w, axis=1)
pixel_coords[0, :, :] = x_range + 0.5
pixel_coords[1, :, :] = y_range + 0.5
return torch.from_numpy(pixel_coords).unsqueeze(0)
def normal_activation(out, elu_kappa=True):
normal, kappa = out[:, :3, :, :], out[:, 3:, :, :]
normal = F.normalize(normal, p=2, dim=1)
if elu_kappa:
kappa = F.elu(kappa) + 1.0
return torch.cat([normal, kappa], dim=1)
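
A quick shape check for the convex-upsampling helper above (a sketch; assumes upsample_via_mask from this module is in scope):

import torch

B, C, H, W, k = 1, 4, 16, 16, 2
out = torch.randn(B, C, H, W)
up_mask = torch.randn(B, 9 * k * k, H, W)  # 9 softmax weights per k*k subpixel
up = upsample_via_mask(out, up_mask, downsample_ratio=k)
assert up.shape == (B, C, k * H, k * W)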

View File

@@ -123,6 +123,7 @@ if TYPE_CHECKING:
from .anydoor_pipeline import AnydoorPipeline
from .image_depth_estimation_marigold_pipeline import ImageDepthEstimationMarigoldPipeline
from .self_supervised_depth_completion_pipeline import SelfSupervisedDepthCompletionPipeline
from .human_normal_estimation_pipeline import HumanNormalEstimationPipeline
else:
_import_structure = {
@@ -312,6 +313,7 @@ else:
'self_supervised_depth_completion_pipeline': [
'SelfSupervisedDepthCompletionPipeline'
],
'human_normal_estimation_pipeline': ['HumanNormalEstimationPipeline'],
}
import sys

View File

@@ -0,0 +1,95 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict
import numpy as np
from PIL import Image
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.human_normal_estimation,
module_name=Pipelines.human_normal_estimation)
class HumanNormalEstimationPipeline(Pipeline):
r""" Human Normal Estimation Pipeline.
Examples:
>>> from modelscope.pipelines import pipeline
>>> estimator = pipeline(
>>> Tasks.human_normal_estimation, model='Damo_XR_Lab/cv_human_monocular-normal-estimation')
>>> estimator(f"{model_dir}/tests/image_normal_estimation.jpg")
"""
def __init__(self, model: str, **kwargs):
"""
Use `model` to create a human normal estimation pipeline for prediction.
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
logger.info('human normal estimation pipeline initialized')
def preprocess(self, input: Input) -> Dict[str, Any]:
"""
Args:
input: string or ndarray or Image.Image
Returns:
data: dict including inference inputs
"""
if isinstance(input, str):
img = np.array(Image.open(input))
elif isinstance(input, Image.Image):
img = np.array(input)
else:
img = np.asarray(input)  # ndarray input, per the docstring above
img_h, img_w, img_ch = img.shape[0:3]
if img_ch == 3:
msk = np.full((img_h, img_w, 1), 255, dtype=np.uint8)
img = np.concatenate((img, msk), axis=-1)
H, W = 1024, 1024
scale_factor = min(W / img_w, H / img_h)
img = Image.fromarray(img)
img = img.resize(
(int(img_w * scale_factor), int(img_h * scale_factor)),
Image.LANCZOS)
new_img = Image.new('RGBA', (W, H), color=(0, 0, 0, 0))
paste_pos_w = (W - img.width) // 2
paste_pos_h = (H - img.height) // 2
new_img.paste(img, (paste_pos_w, paste_pos_h))
bbox = (paste_pos_w, paste_pos_h, paste_pos_w + img.width,
paste_pos_h + img.height)
img = np.array(new_img)
data = {'img': img[:, :, 0:3], 'msk': img[:, :, -1], 'bbox': bbox}
return data
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
results = self.model.inference(input)
return results
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
results = self.model.postprocess(inputs)
normals = results[OutputKeys.NORMALS]
normals_vis = (((normals + 1) * 0.5) * 255).astype(np.uint8)
normals_vis = normals_vis[..., [2, 1, 0]]  # RGB -> BGR for cv2-based viewers
outputs = {
OutputKeys.NORMALS: normals,
OutputKeys.NORMALS_COLOR: normals_vis
}
return outputs
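
The preprocess step always produces a 1024x1024 RGBA canvas. A worked example of the geometry for a 600x800 input, following the formulas above:

img_w, img_h = 600, 800
scale = min(1024 / img_w, 1024 / img_h)                 # 1.28
new_w, new_h = int(img_w * scale), int(img_h * scale)   # 768, 1024
paste_w = (1024 - new_w) // 2                           # 128
paste_h = (1024 - new_h) // 2                           # 0
bbox = (paste_w, paste_h, paste_w + new_w, paste_h + new_h)  # (128, 0, 896, 1024)
# forward() crops the predicted normal map with this bbox, so the output
# aligns with the resized person region, not the padded canvas.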

View File

@@ -78,6 +78,8 @@ class CVTasks(object):
image_local_feature_matching = 'image-local-feature-matching'
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
human_normal_estimation = 'human-normal-estimation'
crowd_counting = 'crowd-counting'
# image editing

View File

@@ -1179,6 +1179,13 @@
"type": "object"
}
},
"human-normal-estimation": {
"input": {},
"parameters": {},
"output": {
"type": "object"
}
},
"image-driving-perception": {
"input": {
"type": "object",

View File

@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path
import unittest
import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class HumanNormalEstimationTest(unittest.TestCase):
def setUp(self) -> None:
self.task = 'human-normal-estimation'
self.model_id = 'Damo_XR_Lab/cv_human_monocular-normal-estimation'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_image_normal_estimation(self):
cur_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
input_location = f'{cur_dir}/data/test/images/human_normal_estimation.png'
estimator = pipeline(
Tasks.human_normal_estimation, model=self.model_id)
result = estimator(input_location)
normals_vis = result[OutputKeys.NORMALS_COLOR]
input_img = cv2.imread(input_location)
normals_vis = cv2.resize(
normals_vis, dsize=(input_img.shape[1], input_img.shape[0]))
cv2.imwrite('result.jpg', normals_vis)
if __name__ == '__main__':
unittest.main()