add human reconstruction task

Single-image human reconstruction task

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11778199

 * add human reconstruction task
Author: jinmao.yk
Date: 2023-03-09 21:58:48 +08:00
Committed by: wenmeng.zwm
Parent: f493e33720
Commit: 4078abf488
21 changed files with 1630 additions and 7 deletions
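For quick reference, a minimal, hypothetical usage sketch of the new task (not part of the commit; the input path is a placeholder, while the task name, model id and output keys are taken from the registrations below):

# Hypothetical end-to-end call of the new 'human-reconstruction' task.
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline

recon = pipeline('human-reconstruction',
                 model='damo/cv_hrnet_image-human-reconstruction')
result = recon('human_reconstruction.jpg')    # placeholder input image path
mesh = result[OutputKeys.OUTPUT]              # dict with 'vertices', 'faces', 'colors'
print(mesh['vertices'].shape, mesh['faces'].shape)
# The pipeline also writes human_reconstruction.obj and human_color.obj
# to the current working directory.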

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:06ec486657dffbf244563a844c98c19d49b7a45b99da702403b52bb9e6bf3c0a
size 226072

View File

@@ -78,6 +78,7 @@ class Models(object):
image_body_reshaping = 'image-body-reshaping'
image_skychange = 'image-skychange'
video_human_matting = 'video-human-matting'
human_reconstruction = 'human-reconstruction'
video_frame_interpolation = 'video-frame-interpolation'
video_object_segmentation = 'video-object-segmentation'
video_deinterlace = 'video-deinterlace'
@@ -361,6 +362,7 @@ class Pipelines(object):
referring_video_object_segmentation = 'referring-video-object-segmentation'
image_skychange = 'image-skychange'
video_human_matting = 'video-human-matting'
human_reconstruction = 'human-reconstruction'
vision_middleware_multi_task = 'vision-middleware-multi-task'
vidt = 'vidt'
video_frame_interpolation = 'video-frame-interpolation'
@@ -751,6 +753,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_video-inpainting'),
Tasks.video_human_matting: (Pipelines.video_human_matting,
'damo/cv_effnetv2_video-human-matting'),
Tasks.human_reconstruction: (Pipelines.human_reconstruction,
'damo/cv_hrnet_image-human-reconstruction'),
Tasks.video_frame_interpolation: (
Pipelines.video_frame_interpolation,
'damo/cv_raft_video-frame-interpolation'),

View File

@@ -5,13 +5,13 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
body_2d_keypoints, body_3d_keypoints, cartoon,
cmdssl_video_embedding, controllable_image_generation,
crowd_counting, face_2d_keypoints, face_detection,
face_generation, face_reconstruction, human_wholebody_keypoint,
image_classification, image_color_enhance, image_colorization,
image_defrcn_fewshot, image_denoise, image_inpainting,
image_instance_segmentation, image_matching,
image_mvs_depth_estimation, image_panoptic_segmentation,
image_portrait_enhancement, image_probing_model,
image_quality_assessment_degradation,
face_generation, face_reconstruction, human_reconstruction,
human_wholebody_keypoint, image_classification,
image_color_enhance, image_colorization, image_defrcn_fewshot,
image_denoise, image_inpainting, image_instance_segmentation,
image_matching, image_mvs_depth_estimation,
image_panoptic_segmentation, image_portrait_enhancement,
image_probing_model, image_quality_assessment_degradation,
image_quality_assessment_man, image_quality_assessment_mos,
image_reid_person, image_restoration,
image_semantic_segmentation, image_to_image_generation,

View File

@@ -0,0 +1,137 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Optional
import cv2
import numpy as np
import PIL.Image as Image
import torch
import torchvision.transforms as transforms
from skimage.io import imread
from skimage.transform import estimate_transform, warp
from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.human_reconstruction.models.detectors import \
FasterRCNN
from modelscope.models.cv.human_reconstruction.models.human_segmenter import \
human_segmenter
from modelscope.models.cv.human_reconstruction.models.networks import define_G
from modelscope.models.cv.human_reconstruction.models.PixToMesh import \
Pixto3DNet
from modelscope.models.cv.human_reconstruction.utils import create_grid
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@MODELS.register_module(
Tasks.human_reconstruction, module_name=Models.human_reconstruction)
class HumanReconstruction(TorchModel):
def __init__(self, model_dir, modelconfig, *args, **kwargs):
"""The HumanReconstruction is modified based on PiFuHD and pix2pixhd, publicly available at
https://shunsukesaito.github.io/PIFuHD/ &
https://github.com/NVIDIA/pix2pixHD
Args:
model_dir: the root directory of the model files
modelconfig: the config param path of the model
"""
super().__init__(model_dir=model_dir, *args, **kwargs)
if torch.cuda.is_available():
self.device = torch.device('cuda')
logger.info('Use GPU: {}'.format(self.device))
else:
self.device = torch.device('cpu')
logger.info('Use CPU: {}'.format(self.device))
model_path = '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_FILE)
normal_back_model = '{}/{}'.format(model_dir, 'Norm_B_GAN.pth')
normal_front_model = '{}/{}'.format(model_dir, 'Norm_F_GAN.pth')
human_seg_model = '{}/{}'.format(model_dir, ModelFile.TF_GRAPH_FILE)
fastrcnn_ckpt = '{}/{}'.format(model_dir, 'fasterrcnn_resnet50.pth')
self.meshmodel = Pixto3DNet(**modelconfig['model'])
self.detector = FasterRCNN(ckpt=fastrcnn_ckpt, device=self.device)
self.meshmodel.load_state_dict(
torch.load(model_path, map_location='cpu'))
self.netB = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance')
self.netF = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance')
self.netF.load_state_dict(torch.load(normal_front_model))
self.netB.load_state_dict(torch.load(normal_back_model))
self.netF = self.netF.to(self.device)
self.netB = self.netB.to(self.device)
self.netF.eval()
self.netB.eval()
self.meshmodel = self.meshmodel.to(self.device).eval()
self.portrait_matting = human_segmenter(model_path=human_seg_model)
b_min = np.array([-1, -1, -1])
b_max = np.array([1, 1, 1])
self.coords, self.mat = create_grid(modelconfig['resolution'], b_min,
b_max)
projection_matrix = np.identity(4)
projection_matrix[1, 1] = -1
self.calib = torch.Tensor(projection_matrix).float().to(self.device)
self.calib = self.calib[:3, :4].unsqueeze(0)
logger.info('model load over')
def get_mask(self, img):
result = self.portrait_matting.run(img)
result = result[..., None]
mask = result.repeat(3, axis=2)
return img, mask
@torch.no_grad()
def crop_img(self, img_url):
image = imread(img_url)[:, :, :3] / 255.
h, w, _ = image.shape
image_size = 512
image_tensor = torch.tensor(
image.transpose(2, 0, 1), dtype=torch.float32)[None, ...]
bbox = self.detector.run(image_tensor)
left = bbox[0]
right = bbox[2]
top = bbox[1]
bottom = bbox[3]
old_size = max(right - left, bottom - top)
center = np.array(
[right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
size = int(old_size * 1.1)
src_pts = np.array([[center[0] - size / 2, center[1] - size / 2],
[center[0] - size / 2, center[1] + size / 2],
[center[0] + size / 2, center[1] - size / 2]])
DST_PTS = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]])
tform = estimate_transform('similarity', src_pts, DST_PTS)
dst_image = warp(
image, tform.inverse, output_shape=(image_size, image_size))
dst_image = (dst_image[:, :, ::-1] * 255).astype(np.uint8)
return dst_image
@torch.no_grad()
def generation_normal(self, img, mask):
to_tensor = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
im_512 = cv2.resize(img, (512, 512))
image_512 = Image.fromarray(im_512).convert('RGB')
image_512 = to_tensor(image_512).unsqueeze(0)
img = image_512.to(self.device)
nml_f = self.netF.forward(img)
nml_b = self.netB.forward(img)
mask = cv2.resize(mask, (512, 512))
mask = transforms.ToTensor()(mask).unsqueeze(0)
nml_f = (nml_f.cpu() * mask).detach().cpu().numpy()[0]
nml_f = (np.transpose(nml_f,
(1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
nml_b = (nml_b.cpu() * mask).detach().cpu().numpy()[0]
nml_b = (np.transpose(nml_b,
(1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
nml_f = nml_f.astype(np.uint8)
nml_b = nml_b.astype(np.uint8)
return nml_f, nml_b
# def forward(self, img, mask, normal_f, normal_b):

View File

@@ -0,0 +1,32 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from torch import nn
class Embedding(nn.Module):
def __init__(self, in_channels, N_freqs, logscale=True):
"""
Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
in_channels: number of input channels (3 for both xyz and direction)
"""
super(Embedding, self).__init__()
self.N_freqs = N_freqs
self.in_channels = in_channels
self.name = 'Embedding'
self.funcs = [torch.sin, torch.cos]
self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1)
self.input_para = dict(in_channels=in_channels, N_freqs=N_freqs)
if logscale:
self.freq_bands = 2**torch.linspace(0, N_freqs - 1, N_freqs)
else:
self.freq_bands = torch.linspace(1, 2**(N_freqs - 1), N_freqs)
def forward(self, x):
out = [x]
for freq in self.freq_bands:
for func in self.funcs:
out += [func(freq * x)]
return torch.cat(out, 1)
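For orientation, a small shape check of this positional encoding (hypothetical snippet, not part of the commit; the import path is inferred from the package layout in this diff): with in_channels=3 and N_freqs=4 the module maps [B, 3, N] points to [B, 3 * (2 * 4 + 1), N] = [B, 27, N] features.

# Hypothetical sanity check of the frequency embedding dimensions.
import torch
from modelscope.models.cv.human_reconstruction.models.Embedding import Embedding  # inferred path

emb = Embedding(in_channels=3, N_freqs=4)    # out_channels = 3 * (2 * 4 + 1) = 27
pts = torch.randn(2, 3, 100)                 # [B, 3, N] xyz samples
feat = emb(pts)                              # [B, 27, N]
assert feat.shape == (2, emb.out_channels, 100)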

View File

@@ -0,0 +1,142 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
from .Embedding import Embedding
from .geometry import index, orthogonal, perspective
from .Res_backbone import Res_hournet
from .Surface_head import Surface_Head
class Pixto3DNet(nn.Module):
def __init__(self,
backbone,
head,
rgbhead,
embedding,
projection_mode: str = 'orthogonal',
error_term: str = 'mse',
num_views: int = 1):
"""
Parameters:
backbone: config of the network that extracts image features
head: config of the network that predicts the occupancy value of a surface point
rgbhead: config of the network that predicts the rgb color of a point
embedding: config of the network that embeds the depth of a point in camera coordinates
projection_mode: how 3d points are projected onto the image plane
error_term: training loss
num_views: number of input images used to reconstruct the model
"""
super(Pixto3DNet, self).__init__()
self.backbone = Res_hournet(**backbone)
self.head = Surface_Head(**head)
self.rgbhead = Surface_Head(**rgbhead)
self.depth = Embedding(**embedding)
if error_term == 'mse':
self.error_term = nn.MSELoss(reduction='none')
elif error_term == 'bce':
self.error_term = nn.BCELoss(reduction='none')
elif error_term == 'l1':
self.error_term = nn.L1Loss(reduction='none')
else:
raise NotImplementedError
self.index = index
self.projection = orthogonal if projection_mode == 'orthogonal' else perspective
self.num_views = num_views
self.im_feat_list = []
self.intermediate_preds_list = []
def extract_features(self, images: torch.Tensor):
self.im_feat_list = self.backbone(images)
def query(self, points, calibs, transforms=None, labels=None):
if labels is not None:
self.labels = labels
xyz = self.projection(points, calibs, transforms)
xy = xyz[:, :2, :]
xyz_feat = self.depth(xyz)
self.intermediate_preds_list = []
im_feat_256 = self.im_feat_list[0]
im_feat_512 = self.im_feat_list[1]
point_local_feat_list = [
self.index(im_feat_256, xy),
self.index(im_feat_512, xy), xyz_feat
]
point_local_feat = torch.cat(point_local_feat_list, 1)
pred, phi = self.head(point_local_feat)
self.intermediate_preds_list.append(pred)
self.phi = phi
self.preds = self.intermediate_preds_list[-1]
def get_preds(self):
return self.preds
def query_rgb(self, points, calibs, transforms=None):
xyz = self.projection(points, calibs, transforms)
xy = xyz[:, :2, :]
xyz_feat = self.depth(xyz)
self.intermediate_preds_list = []
im_feat_256 = self.im_feat_list[0]
im_feat_512 = self.im_feat_list[1]
point_local_feat_list = [
self.index(im_feat_256, xy),
self.index(im_feat_512, xy), xyz_feat
]
point_local_feat = torch.cat(point_local_feat_list, 1)
pred, phi = self.head(point_local_feat)
rgb_point_feat = torch.cat([point_local_feat, phi], 1)
rgb, phi = self.rgbhead(rgb_point_feat)
return rgb
def get_error(self):
error = 0
lc = torch.tensor(self.labels.shape[0] * self.labels.shape[1]
* self.labels.shape[2])
inw = torch.sum(self.labels)
weight_in = inw / lc
weight = torch.abs(self.labels - weight_in)
lamda = 1 / torch.mean(weight)
for preds in self.intermediate_preds_list:
error += lamda * torch.mean(
self.error_term(preds, self.labels) * weight)
error /= len(self.intermediate_preds_list)
return error
def forward(self,
images,
points,
calibs,
surpoint=None,
transforms=None,
labels=None):
self.extract_features(images)
self.query(
points=points, calibs=calibs, transforms=transforms, labels=labels)
if surpoint is not None:
rgb = self.query_rgb(
points=surpoint, calibs=calibs, transforms=transforms)
else:
rgb = None
res = self.preds
return res, rgb

View File

@@ -0,0 +1,330 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class BlurPool(nn.Module):
def __init__(self,
channels,
pad_type='reflect',
filt_size=4,
stride=2,
pad_off=0):
super(BlurPool, self).__init__()
self.filt_size = filt_size
self.pad_off = pad_off
self.pad_sizes = [
int(1. * (filt_size - 1) / 2),
int(np.ceil(1. * (filt_size - 1) / 2)),
int(1. * (filt_size - 1) / 2),
int(np.ceil(1. * (filt_size - 1) / 2))
]
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
self.stride = stride
self.off = int((self.stride - 1) / 2.)
self.channels = channels
if (self.filt_size == 1):
a = np.array([
1.,
])
elif (self.filt_size == 2):
a = np.array([1., 1.])
elif (self.filt_size == 3):
a = np.array([1., 2., 1.])
elif (self.filt_size == 4):
a = np.array([1., 3., 3., 1.])
elif (self.filt_size == 5):
a = np.array([1., 4., 6., 4., 1.])
elif (self.filt_size == 6):
a = np.array([1., 5., 10., 10., 5., 1.])
elif (self.filt_size == 7):
a = np.array([1., 6., 15., 20., 15., 6., 1.])
filt = torch.Tensor(a[:, None] * a[None, :])
filt = filt / torch.sum(filt)
self.register_buffer(
'filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
self.pad = get_pad_layer(pad_type)(self.pad_sizes)
def forward(self, inp):
if (self.filt_size == 1):
if (self.pad_off == 0):
return inp[:, :, ::self.stride, ::self.stride]
else:
return self.pad(inp)[:, :, ::self.stride, ::self.stride]
else:
return F.conv2d(
self.pad(inp),
self.filt,
stride=self.stride,
groups=inp.shape[1])
def get_pad_layer(pad_type):
if (pad_type in ['refl', 'reflect']):
PadLayer = nn.ReflectionPad2d
elif (pad_type in ['repl', 'replicate']):
PadLayer = nn.ReplicationPad2d
elif (pad_type == 'zero'):
PadLayer = nn.ZeroPad2d
else:
raise NotImplementedError('Pad type [%s] not recognized' % pad_type)
return PadLayer
class ConvBlockv1(nn.Module):
def __init__(self, in_planes, out_planes, norm='batch'):
super(ConvBlockv1, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
int(out_planes / 2),
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.conv2 = nn.Conv2d(
int(out_planes / 2),
int(out_planes / 4),
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.conv3 = nn.Conv2d(
int(out_planes / 4),
int(out_planes / 4),
kernel_size=3,
stride=1,
padding=1,
bias=False)
if norm == 'batch':
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
self.bn4 = nn.BatchNorm2d(out_planes)
elif norm == 'group':
self.bn2 = nn.GroupNorm(32, int(out_planes / 2))
self.bn3 = nn.GroupNorm(32, int(out_planes / 4))
self.bn4 = nn.GroupNorm(32, out_planes)
if in_planes != out_planes:
self.downsample = nn.Sequential(
nn.Conv2d(
in_planes, out_planes, kernel_size=1, stride=1,
bias=False), )
else:
self.downsample = None
def forward(self, x):
residual = x
out1 = self.conv1(x)
out2 = self.bn2(out1)
out2 = F.relu(out2, True)
out2 = self.conv2(out2)
out3 = self.bn3(out2)
out3 = F.relu(out3, True)
out3 = self.conv3(out3)
out3 = torch.cat((out1, out2, out3), 1)
if self.downsample is not None:
residual = self.downsample(residual)
out3 += residual
out4 = self.bn4(out3)
out4 = F.relu(out4, True)
return out4
class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, norm='batch'):
super(Conv2, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
int(out_planes / 4),
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.conv2 = nn.Conv2d(
in_planes,
int(out_planes / 4),
kernel_size=5,
stride=1,
padding=2,
bias=False)
self.conv3 = nn.Conv2d(
in_planes,
int(out_planes / 2),
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.conv4 = nn.Conv2d(
out_planes,
out_planes,
kernel_size=3,
stride=1,
padding=1,
bias=False)
if norm == 'batch':
self.bn1 = nn.BatchNorm2d(int(out_planes / 4))
self.bn2 = nn.BatchNorm2d(int(out_planes / 4))
self.bn3 = nn.BatchNorm2d(int(out_planes / 2))
self.bn4 = nn.BatchNorm2d(out_planes)
elif norm == 'group':
self.bn1 = nn.GroupNorm(32, int(out_planes / 4))
self.bn2 = nn.GroupNorm(32, int(out_planes / 4))
self.bn3 = nn.GroupNorm(32, int(out_planes / 2))
self.bn4 = nn.GroupNorm(32, out_planes)
if in_planes != out_planes:
self.downsample = nn.Sequential(
nn.Conv2d(
in_planes, out_planes, kernel_size=1, stride=1,
bias=False), )
else:
self.downsample = None
def forward(self, x):
residual = x
out1 = self.conv1(x)
out1 = self.bn1(out1)
out1 = F.relu(out1, True)
out2 = self.conv2(x)
out2 = self.bn2(out2)
out2 = F.relu(out2, True)
out3 = self.conv3(x)
out3 = self.bn3(out3)
out3 = F.relu(out3, True)
out3 = torch.cat((out1, out2, out3), 1)
if self.downsample is not None:
residual = self.downsample(residual)
out = out3 + residual
out = self.conv4(out)
out = self.bn4(out)
out = F.relu(out, True)
return out
class Res_hournet(nn.Module):
def __init__(self, norm: str = 'group', use_front=False, use_back=False):
"""
Defines the backbone of the human reconstruction network.
use_front & use_back indicate whether the front / back normal maps of the input image are concatenated to the RGB input.
"""
super(Res_hournet, self).__init__()
self.name = 'Res Backbone'
self.norm = norm
inc = 3
self.use_front = use_front
self.use_back = use_back
if self.use_front:
inc += 3
if self.use_back:
inc += 3
self.conv1 = nn.Conv2d(inc, 64, kernel_size=7, stride=1, padding=3)
if self.norm == 'batch':
self.bn1 = nn.BatchNorm2d(64)
elif self.norm == 'group':
self.bn1 = nn.GroupNorm(32, 64)
self.down_conv1 = BlurPool(
64, pad_type='reflect', filt_size=7, stride=2)
self.conv2 = ConvBlockv1(64, 128, self.norm)
self.down_conv2 = BlurPool(
128, pad_type='reflect', filt_size=7, stride=2)
self.conv3 = ConvBlockv1(128, 128, self.norm)
self.conv5 = ConvBlockv1(128, 256, self.norm)
self.conv6 = ConvBlockv1(256, 256, self.norm)
self.down_conv3 = BlurPool(
256, pad_type='reflect', filt_size=5, stride=2)
self.conv7 = ConvBlockv1(256, 256, self.norm)
self.conv8 = ConvBlockv1(256, 256, self.norm)
self.conv9 = ConvBlockv1(256, 256, self.norm)
self.conv10 = ConvBlockv1(256, 256, self.norm)
self.conv10_1 = ConvBlockv1(256, 512, self.norm)
self.conv10_2 = Conv2(512, 512, self.norm)
self.down_conv4 = BlurPool(
512, pad_type='reflect', filt_size=5, stride=2)
self.conv11 = Conv2(512, 512, self.norm)
self.conv12 = ConvBlockv1(512, 512, self.norm)
self.conv13 = Conv2(512, 512, self.norm)
self.conv14 = ConvBlockv1(512, 512, self.norm)
self.conv15 = Conv2(512, 512, self.norm)
self.conv16 = ConvBlockv1(512, 512, self.norm)
self.conv17 = Conv2(512, 512, self.norm)
self.conv18 = ConvBlockv1(512, 512, self.norm)
self.conv19 = Conv2(512, 512, self.norm)
self.conv20 = ConvBlockv1(512, 512, self.norm)
self.conv21 = Conv2(512, 512, self.norm)
self.conv22 = ConvBlockv1(512, 512, self.norm)
self.up_down1 = nn.Conv2d(1024, 512, 3, 1, 1, bias=False)
self.upconv1 = ConvBlockv1(512, 512, self.norm)
self.upconv1_1 = ConvBlockv1(512, 512, self.norm)
self.up_down2 = nn.Conv2d(768, 512, 3, 1, 1, bias=False)
self.upconv2 = ConvBlockv1(512, 256, self.norm)
self.upconv2_1 = ConvBlockv1(256, 256, self.norm)
self.up_down3 = nn.Conv2d(384, 256, 3, 1, 1, bias=False)
self.upconv3 = ConvBlockv1(256, 256, self.norm)
self.upconv3_4 = nn.Conv2d(256, 128, 3, 1, 1, bias=False)
self.up_down4 = nn.Conv2d(192, 64, 3, 1, 1, bias=False)
self.upconv4 = ConvBlockv1(64, 64, 'batch')
def forward(self, x):
out0 = self.bn1(self.conv1(x))
out1 = self.down_conv1(out0)
out1 = self.conv2(out1)
out2 = self.down_conv2(out1)
out2 = self.conv3(out2)
out2 = self.conv5(out2)
out2 = self.conv6(out2)
out3 = self.down_conv3(out2)
out3 = self.conv7(out3)
out3 = self.conv9(self.conv8(out3))
out3 = self.conv10(out3)
out3 = self.conv10_2(self.conv10_1(out3))
out4 = self.down_conv4(out3)
out4 = self.conv12(self.conv11(out4))
out4 = self.conv14(self.conv13(out4))
out4 = self.conv16(self.conv15(out4))
out4 = self.conv18(self.conv17(out4))
out4 = self.conv20(self.conv19(out4))
out4 = self.conv22(self.conv21(out4))
up1 = F.interpolate(
out4, scale_factor=2, mode='bicubic', align_corners=True)
up1 = torch.cat((up1, out3), 1)
up1 = self.up_down1(up1)
up1 = self.upconv1(up1)
up1 = self.upconv1_1(up1)
up2 = F.interpolate(
up1, scale_factor=2, mode='bicubic', align_corners=True)
up2 = torch.cat((up2, out2), 1)
up2 = self.up_down2(up2)
up2 = self.upconv2(up2)
up2 = self.upconv2_1(up2)
up3 = F.interpolate(
up2, scale_factor=2, mode='bicubic', align_corners=True)
up3 = torch.cat((up3, out1), 1)
up3 = self.up_down3(up3)
up3 = self.upconv3(up3)
up34 = self.upconv3_4(up3)
up4 = F.interpolate(
up34, scale_factor=2, mode='bicubic', align_corners=True)
up4 = torch.cat((up4, out0), 1)
up4 = self.up_down4(up4)
up4 = self.upconv4(up4)
return up3, up4

View File

@@ -0,0 +1,73 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Surface_Head(nn.Module):
"""
MLP head that learns the implicit iso-surface (occupancy) function.
"""
def __init__(self,
filter_channels,
merge_layer=0,
res_layers=[],
norm='group',
last_op=None):
super(Surface_Head, self).__init__()
if last_op == 'sigmoid':
self.last_op = nn.Sigmoid()
elif last_op == 'tanh':
self.last_op = nn.Tanh()
else:
raise NotImplementedError(
'only sigmoid/tanh function could be used')
self.filters = nn.ModuleList()
self.norms = nn.ModuleList()
self.merge_layer = merge_layer if merge_layer > 0 else len(
filter_channels) // 2
self.res_layers = res_layers
self.norm = norm
for i in range(0, len(filter_channels) - 1):
if i in self.res_layers:
self.filters.append(
nn.Conv1d(filter_channels[i] + filter_channels[0],
filter_channels[i + 1], 1))
else:
self.filters.append(
nn.Conv1d(filter_channels[i], filter_channels[i + 1], 1))
if i != len(filter_channels) - 2:
if norm == 'group':
self.norms.append(nn.GroupNorm(32, filter_channels[i + 1]))
elif norm == 'batch':
self.norms.append(nn.BatchNorm1d(filter_channels[i + 1]))
def forward(self, feature):
"""feature may include multiple view inputs
Parameters:
feature: [B, C_in, N]
return:
prediction: [B, C_out, N] and merge layer features
"""
y = feature
tmpy = feature
phi = None
for i, f in enumerate(self.filters):
y = f(y if i not in self.res_layers else torch.cat([y, tmpy], 1))
if i != len(self.filters) - 1:
if self.norm not in ['batch', 'group']:
y = F.leaky_relu(y)
else:
y = F.leaky_relu(self.norms[i](y))
if i == self.merge_layer:
phi = y.clone()
if self.last_op is not None:
y = self.last_op(y)
return y, phi
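A shape sketch of the head (hypothetical; the channel list and res_layers below are arbitrary assumptions, not the config shipped with the model): with filter_channels=[256, 512, 256, 128, 1] the head maps [B, 256, N] point features to [B, 1, N] predictions and also returns the merge-layer features phi.

# Hypothetical configuration, only to illustrate input/output shapes.
import torch
from modelscope.models.cv.human_reconstruction.models.Surface_head import Surface_Head  # inferred path

head = Surface_Head(filter_channels=[256, 512, 256, 128, 1],
                    res_layers=[2], norm='group', last_op='sigmoid')
feat = torch.randn(1, 256, 1000)   # [B, C_in, N] sampled point features
pred, phi = head(feat)             # pred: [B, 1, N], phi: [B, 128, N] (merge layer)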

View File

@@ -0,0 +1,66 @@
# The implementation here is modified based on Pytorch, originally BSD License and publicly available at
# https://github.com/pytorch/pytorch
import numpy as np
import torch
class FasterRCNN(object):
''' detect body
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
'''
def __init__(self, ckpt=None, device='cuda:0'):
"""
https://pytorch.org/docs/stable/torchvision/models.html#faster-r-cnn
"""
import torchvision
if ckpt is None:
self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
pretrained=True)
else:
self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
pretrained=False)
state_dict = torch.load(ckpt, map_location='cpu')
self.model.load_state_dict(state_dict)
self.model.to(device)
self.model.eval()
self.device = device
@torch.no_grad()
def run(self, input):
"""
return: detected box, [x1, y1, x2, y2]
"""
prediction = self.model(input.to(self.device))[0]
inds = (prediction['labels'] == 1) * (prediction['scores'] > 0.5)
if inds.sum() < 1:
return None
else:
bbox = prediction['boxes'][inds][0].cpu().numpy()
return bbox
@torch.no_grad()
def run_multi(self, input):
"""
return: detected boxes, an [N, 4] array of [x1, y1, x2, y2]
"""
prediction = self.model(input.to(self.device))[0]
inds = (prediction['labels'] == 1) * (prediction['scores'] > 0.9)
if inds.sum() < 1:
return None
else:
bbox = prediction['boxes'][inds].cpu().numpy()
return bbox
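For reference, a hypothetical invocation that mirrors how HumanReconstruction.crop_img feeds this detector (the image path is a placeholder): the detector expects a float tensor in [0, 1] with shape [1, 3, H, W] and returns the first (highest-scoring) person box as [x1, y1, x2, y2].

# Hypothetical usage sketch; ckpt=None falls back to the torchvision pretrained weights.
import torch
from skimage.io import imread
from modelscope.models.cv.human_reconstruction.models.detectors import FasterRCNN

detector = FasterRCNN(ckpt=None, device='cpu')
image = imread('person.jpg')[:, :, :3] / 255.   # placeholder image path
image_tensor = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)[None, ...]
bbox = detector.run(image_tensor)               # [x1, y1, x2, y2] of the first detected person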

View File

@@ -0,0 +1,61 @@
# The implementation here is modified based on PIFU, originally MIT License and publicly available at
# https://github.com/shunsukesaito/PIFu/blob/master/lib/geometry.py
import torch
def index(feat, uv):
"""
extract image features at floating coordinates with bilinear interpolation
args:
feat: [B, C, H, W] image features
uv: [B, 2, N] normalized image coordinates ranged in [-1, 1]
return:
[B, C, N] sampled pixel values
"""
uv = uv.transpose(1, 2)
uv = uv.unsqueeze(2)
samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True)
return samples[:, :, :, 0]
def orthogonal(points, calib, transform=None):
"""
project points onto screen space using orthogonal projection
args:
points: [B, 3, N] 3d points in world coordinates
calib: [B, 3, 4] projection matrix
transform: [B, 2, 3] screen space transformation
return:
[B, 3, N] 3d coordinates in screen space
"""
rot = calib[:, :3, :3]
trans = calib[:, :3, 3:4]
pts = torch.baddbmm(trans, rot, points)
if transform is not None:
scale = transform[:2, :2]
shift = transform[:2, 2:3]
pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
return pts
def perspective(points, calib, transform=None):
"""
project points onto screen space using perspective projection
args:
points: [B, 3, N] 3d points in world coordinates
calib: [B, 3, 4] projection matrix
transform: [B, 2, 3] screen space transformation
return:
[B, 3, N] 3d coordinates in screen space
"""
rot = calib[:, :3, :3]
trans = calib[:, :3, 3:4]
homo = torch.baddbmm(trans, rot, points)
xy = homo[:, :2, :] / homo[:, 2:3, :]
if transform is not None:
scale = transform[:2, :2]
shift = transform[:2, 2:3]
xy = torch.baddbmm(shift, scale, xy)
xyz = torch.cat([xy, homo[:, 2:3, :]], 1)
return xyz
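A small sketch of how these helpers cooperate (hypothetical snippet; the calibration matrix mirrors the one built in HumanReconstruction.__init__, and the import path is inferred from the package layout): orthogonal() maps world-space points into the normalized image plane, and index() then bilinearly samples image features at those locations.

# Hypothetical snippet tying orthogonal() and index() together.
import numpy as np
import torch
from modelscope.models.cv.human_reconstruction.models.geometry import index, orthogonal  # inferred path

proj = np.identity(4)
proj[1, 1] = -1                                  # flip the y axis, as in HumanReconstruction.__init__
calib = torch.tensor(proj, dtype=torch.float32)[:3, :4].unsqueeze(0)   # [B, 3, 4]
points = torch.rand(1, 3, 500) * 2 - 1           # [B, 3, N] points in [-1, 1]^3
feat = torch.randn(1, 64, 128, 128)              # [B, C, H, W] feature map
xyz = orthogonal(points, calib)                  # [B, 3, N] screen-space coordinates
sampled = index(feat, xyz[:, :2, :])             # [B, 64, N] bilinearly sampled features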

View File

@@ -0,0 +1,60 @@
# The implementation is also open-sourced by the authors, and available at
# https://www.modelscope.cn/models/damo/cv_unet_image-matting/summary
import cv2
import numpy as np
import tensorflow as tf
if tf.__version__ >= '2.0':
tf = tf.compat.v1
class human_segmenter(object):
def __init__(self, model_path):
super(human_segmenter, self).__init__()
f = tf.gfile.FastGFile(model_path, 'rb')
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
persisted_graph = tf.import_graph_def(graph_def, name='')
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.3  # use at most 30% of the GPU memory
self.sess = tf.InteractiveSession(graph=persisted_graph, config=config)
self.image_node = self.sess.graph.get_tensor_by_name('input_image:0')
self.output_node = self.sess.graph.get_tensor_by_name('output_png:0')
self.logits_node = self.sess.graph.get_tensor_by_name('if_person:0')
print('human_segmenter init done')
def image_preprocess(self, img):
if len(img.shape) == 2:
img = np.dstack((img, img, img))
elif img.shape[2] == 4:
img = img[:, :, :3]
img = img.astype(float)
return img
def run(self, img):
image_feed = self.image_preprocess(img)
output_img_value, logits_value = self.sess.run(
[self.output_node, self.logits_node],
feed_dict={self.image_node: image_feed})
mask = output_img_value[:, :, -1]
return mask
def get_human_bbox(self, mask):
print('dtype:{}, max:{},shape:{}'.format(mask.dtype, np.max(mask),
mask.shape))
ret, thresh = cv2.threshold(mask, 127, 255, 0)
contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)
if len(contours) == 0:
return None
contoursArea = [cv2.contourArea(c) for c in contours]
max_area_index = contoursArea.index(max(contoursArea))
bbox = cv2.boundingRect(contours[max_area_index])
return bbox
def release(self):
self.sess.close()

View File

@@ -0,0 +1,366 @@
# The implementation here is modified based on Pix2PixHD, originally BSD License and publicly available at
# https://github.com/NVIDIA/pix2pixHD
import functools
import numpy as np
import torch
import torch.nn as nn
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
else:
raise NotImplementedError('normalization layer [%s] is not found'
% norm_type)
return norm_layer
def define_G(input_nc,
output_nc,
ngf,
netG,
n_downsample_global=3,
n_blocks_global=9,
n_local_enhancers=1,
n_blocks_local=3,
norm='instance',
gpu_ids=[],
last_op=nn.Tanh()):
norm_layer = get_norm_layer(norm_type=norm)
if netG == 'global':
netG = GlobalGenerator(
input_nc,
output_nc,
ngf,
n_downsample_global,
n_blocks_global,
norm_layer,
last_op=last_op)
elif netG == 'local':
netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global,
n_blocks_global, n_local_enhancers,
n_blocks_local, norm_layer)
elif netG == 'encoder':
netG = Encoder(input_nc, output_nc, ngf, n_downsample_global,
norm_layer)
else:
raise NotImplementedError('generator not implemented!')
if len(gpu_ids) > 0:
assert (torch.cuda.is_available())
netG.cuda(gpu_ids[0])
netG.apply(weights_init)
return netG
def print_network(net):
if isinstance(net, list):
net = net[0]
num_params = 0
for param in net.parameters():
num_params += param.numel()
print(net)
print('Total number of parameters: %d' % num_params)
"""
Generator code
"""
class LocalEnhancer(nn.Module):
def __init__(self,
input_nc,
output_nc,
ngf=32,
n_downsample_global=3,
n_blocks_global=9,
n_local_enhancers=1,
n_blocks_local=3,
norm_layer=nn.BatchNorm2d,
padding_type='reflect'):
super(LocalEnhancer, self).__init__()
self.n_local_enhancers = n_local_enhancers
ngf_global = ngf * (2**n_local_enhancers)
model_global = GlobalGenerator(input_nc, output_nc, ngf_global,
n_downsample_global, n_blocks_global,
norm_layer).model
model_global = [model_global[i] for i in range(len(model_global) - 3)
] # get rid of final convolution layers
self.model = nn.Sequential(*model_global)
for n in range(1, n_local_enhancers + 1):
ngf_global = ngf * (2**(n_local_enhancers - n))
model_downsample = [
nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
norm_layer(ngf_global),
nn.ReLU(True),
nn.Conv2d(
ngf_global,
ngf_global * 2,
kernel_size=3,
stride=2,
padding=1),
norm_layer(ngf_global * 2),
nn.ReLU(True)
]
model_upsample = []
for i in range(n_blocks_local):
model_upsample += [
ResnetBlock(
ngf_global * 2,
padding_type=padding_type,
norm_layer=norm_layer)
]
model_upsample += [
nn.ConvTranspose2d(
ngf_global * 2,
ngf_global,
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
norm_layer(ngf_global),
nn.ReLU(True)
]
if n == n_local_enhancers:
model_upsample += [
nn.ReflectionPad2d(3),
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
nn.Tanh()
]
setattr(self, 'model' + str(n) + '_1',
nn.Sequential(*model_downsample))
setattr(self, 'model' + str(n) + '_2',
nn.Sequential(*model_upsample))
self.downsample = nn.AvgPool2d(
3, stride=2, padding=[1, 1], count_include_pad=False)
def forward(self, input):
input_downsampled = [input]
for i in range(self.n_local_enhancers):
input_downsampled.append(self.downsample(input_downsampled[-1]))
output_prev = self.model(input_downsampled[-1])
for n_local_enhancers in range(1, self.n_local_enhancers + 1):
model_downsample = getattr(self,
'model' + str(n_local_enhancers) + '_1')
model_upsample = getattr(self,
'model' + str(n_local_enhancers) + '_2')
input_i = input_downsampled[self.n_local_enhancers
- n_local_enhancers]
output_prev = model_upsample(
model_downsample(input_i) + output_prev)
return output_prev
class GlobalGenerator(nn.Module):
def __init__(self,
input_nc,
output_nc,
ngf=64,
n_downsampling=3,
n_blocks=9,
norm_layer=nn.BatchNorm2d,
padding_type='reflect',
last_op=nn.Tanh()):
assert (n_blocks >= 0)
super(GlobalGenerator, self).__init__()
activation = nn.ReLU(True)
model = [
nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
norm_layer(ngf), activation
]
for i in range(n_downsampling):
mult = 2**i
model += [
nn.Conv2d(
ngf * mult,
ngf * mult * 2,
kernel_size=3,
stride=2,
padding=1),
norm_layer(ngf * mult * 2), activation
]
mult = 2**n_downsampling
for i in range(n_blocks):
model += [
ResnetBlock(
ngf * mult,
padding_type=padding_type,
activation=activation,
norm_layer=norm_layer)
]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [
nn.ConvTranspose2d(
ngf * mult,
int(ngf * mult / 2),
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
norm_layer(int(ngf * mult / 2)), activation
]
model += [
nn.ReflectionPad2d(3),
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)
]
if last_op is not None:
model += [last_op]
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
"""
Define a resnet block
"""
class ResnetBlock(nn.Module):
def __init__(self,
dim,
padding_type,
norm_layer,
activation=nn.ReLU(True),
use_dropout=False):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer,
activation, use_dropout)
def build_conv_block(self, dim, padding_type, norm_layer, activation,
use_dropout):
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented'
% padding_type)
conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim), activation
]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented'
% padding_type)
conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim)
]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
class Encoder(nn.Module):
def __init__(self,
input_nc,
output_nc,
ngf=32,
n_downsampling=4,
norm_layer=nn.BatchNorm2d):
super(Encoder, self).__init__()
self.output_nc = output_nc
model = [
nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
norm_layer(ngf),
nn.ReLU(True)
]
for i in range(n_downsampling):
mult = 2**i
model += [
nn.Conv2d(
ngf * mult,
ngf * mult * 2,
kernel_size=3,
stride=2,
padding=1),
norm_layer(ngf * mult * 2),
nn.ReLU(True)
]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [
nn.ConvTranspose2d(
ngf * mult,
int(ngf * mult / 2),
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)
]
model += [
nn.ReflectionPad2d(3),
nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
nn.Tanh()
]
self.model = nn.Sequential(*model)
def forward(self, input, inst):
outputs = self.model(input)
outputs_mean = outputs.clone()
inst_list = np.unique(inst.cpu().numpy().astype(int))
for i in inst_list:
for b in range(input.size()[0]):
indices = (inst[b:b + 1] == int(i)).nonzero()
for j in range(self.output_nc):
output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j,
indices[:, 2], indices[:, 3]]
mean_feat = torch.mean(output_ins).expand_as(output_ins)
outputs_mean[indices[:, 0] + b, indices[:, 1] + j,
indices[:, 2], indices[:, 3]] = mean_feat
return outputs_mean
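For reference, a sketch of how HumanReconstruction.__init__ instantiates this generator for normal-map prediction (the weights filename comes from the model directory used above and is a placeholder here): a 'global' generator with 3 input / 3 output channels, 64 base filters, 4 downsamplings, 9 residual blocks and instance normalization, whose Tanh output is a normal map in [-1, 1].

# Sketch of the normal-map generator set-up mirrored from HumanReconstruction.__init__.
import torch
from modelscope.models.cv.human_reconstruction.models.networks import define_G

netF = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance')
netF.load_state_dict(torch.load('Norm_F_GAN.pth', map_location='cpu'))  # placeholder weights path
netF.eval()
normal_front = netF(torch.randn(1, 3, 512, 512))   # [1, 3, 512, 512] front normal map in [-1, 1]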

View File

@@ -0,0 +1,178 @@
import os
import mcubes
import numpy as np
import torch
def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
file = open(mesh_path, 'w')
for idx, v in enumerate(verts):
c = colors[idx]
file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' %
(v[0], v[1], v[2], c[0], c[1], c[2]))
for f in faces:
f_plus = f + 1
file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
file.close()
def save_obj_mesh(mesh_path, verts, faces):
file = open(mesh_path, 'w')
for idx, v in enumerate(verts):
file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
for f in faces:
f_plus = f + 1
file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
file.close()
def to_tensor(img):
if len(img.shape) == 2:
img = img[:, :, np.newaxis]
img = torch.from_numpy(img.transpose(2, 0, 1)).float()
img = img / 255.
return img
def reconstruction(net, calib_tensor, coords, mat, num_samples=50000):
def eval_func(points):
points = np.expand_dims(points, axis=0)
points = np.repeat(points, 1, axis=0)
samples = torch.from_numpy(points).cuda().float()
net.query(samples, calib_tensor)
pred = net.get_preds()
pred = pred[0]
return pred.detach().cpu().numpy()
sdf = eval_grid(coords, eval_func, num_samples=num_samples)
vertices, faces = mcubes.marching_cubes(sdf, 0.5)
verts = np.matmul(mat[:3, :3], vertices.T) + mat[:3, 3:4]
verts = verts.T
return verts, faces
def keep_largest(mesh_big):
mesh_lst = mesh_big.split(only_watertight=False)
keep_mesh = mesh_lst[0]
for mesh in mesh_lst:
if mesh.vertices.shape[0] > keep_mesh.vertices.shape[0]:
keep_mesh = mesh
return keep_mesh
def eval_grid(coords,
eval_func,
init_resolution=64,
threshold=0.01,
num_samples=512 * 512 * 512):
resolution = coords.shape[1:4]
sdf = np.zeros(resolution)
dirty = np.ones(resolution, dtype=bool)
grid_mask = np.zeros(resolution, dtype=bool)
reso = resolution[0] // init_resolution
while reso > 0:
grid_mask[0:resolution[0]:reso, 0:resolution[1]:reso,
0:resolution[2]:reso] = True
test_mask = np.logical_and(grid_mask, dirty)
points = coords[:, test_mask]
sdf[test_mask] = batch_eval(points, eval_func, num_samples=num_samples)
dirty[test_mask] = False
if reso <= 1:
break
for x in range(0, resolution[0] - reso, reso):
for y in range(0, resolution[1] - reso, reso):
for z in range(0, resolution[2] - reso, reso):
if not dirty[x + reso // 2, y + reso // 2, z + reso // 2]:
continue
v0 = sdf[x, y, z]
v1 = sdf[x, y, z + reso]
v2 = sdf[x, y + reso, z]
v3 = sdf[x, y + reso, z + reso]
v4 = sdf[x + reso, y, z]
v5 = sdf[x + reso, y, z + reso]
v6 = sdf[x + reso, y + reso, z]
v7 = sdf[x + reso, y + reso, z + reso]
v = np.array([v0, v1, v2, v3, v4, v5, v6, v7])
v_min = v.min()
v_max = v.max()
if (v_max - v_min) < threshold:
sdf[x:x + reso, y:y + reso,
z:z + reso] = (v_max + v_min) / 2
dirty[x:x + reso, y:y + reso, z:z + reso] = False
reso //= 2
return sdf.reshape(resolution)
def batch_eval(points, eval_func, num_samples=512 * 512 * 512):
num_pts = points.shape[1]
sdf = np.zeros(num_pts)
num_batches = num_pts // num_samples
for i in range(num_batches):
sdf[i * num_samples:i * num_samples + num_samples] = eval_func(
points[:, i * num_samples:i * num_samples + num_samples])
if num_pts % num_samples:
sdf[num_batches * num_samples:] = eval_func(points[:, num_batches
* num_samples:])
return sdf
def create_grid(res,
b_min=np.array([0, 0, 0]),
b_max=np.array([1, 1, 1]),
transform=None):
coords = np.mgrid[:res, :res, :res]
coords = coords.reshape(3, -1)
coords_matrix = np.eye(4)
length = b_max - b_min
coords_matrix[0, 0] = length[0] / res
coords_matrix[1, 1] = length[1] / res
coords_matrix[2, 2] = length[2] / res
coords_matrix[0:3, 3] = b_min
coords = np.matmul(coords_matrix[:3, :3], coords) + coords_matrix[:3, 3:4]
if transform is not None:
coords = np.matmul(transform[:3, :3], coords) + transform[:3, 3:4]
coords_matrix = np.matmul(transform, coords_matrix)
coords = coords.reshape(3, res, res, res)
return coords, coords_matrix
def get_submesh(verts,
faces,
color,
verts_retained=None,
faces_retained=None,
min_vert_in_face=2):
colors = color
if verts_retained is not None:
if verts_retained.dtype != 'bool':
vert_mask = np.zeros(len(verts), dtype=bool)
vert_mask[verts_retained] = True
else:
vert_mask = verts_retained
bool_faces = np.sum(
vert_mask[faces.ravel()].reshape(-1, 3), axis=1) > min_vert_in_face
elif faces_retained is not None:
if faces_retained.dtype != 'bool':
bool_faces = np.zeros(len(faces_retained), dtype=bool)
else:
bool_faces = faces_retained
new_faces = faces[bool_faces]
vertex_ids = list(set(new_faces.ravel()))
oldtonew = -1 * np.ones([len(verts)])
oldtonew[vertex_ids] = range(0, len(vertex_ids))
new_verts = verts[vertex_ids]
new_colors = colors[vertex_ids]
new_faces = oldtonew[new_faces].astype('int32')
return (new_verts, new_faces, new_colors, bool_faces, vertex_ids)
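As a rough illustration of how these utilities fit together (hypothetical snippet; the resolution value is an assumption, and reconstruction() additionally needs an initialized Pixto3DNet, so only the grid construction is shown): create_grid() returns the dense world-space sample grid that eval_grid() evaluates, plus the voxel-to-world matrix that reconstruction() applies to the marching-cubes vertices.

# Hypothetical sketch of the grid set-up mirrored from HumanReconstruction.__init__.
import numpy as np
from modelscope.models.cv.human_reconstruction.utils import create_grid

b_min = np.array([-1, -1, -1])
b_max = np.array([1, 1, 1])
coords, mat = create_grid(256, b_min, b_max)   # resolution 256 is an assumed value
print(coords.shape)   # (3, 256, 256, 256) world-space sample coordinates
print(mat.shape)      # (4, 4) voxel-index -> world-coordinate transform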

View File

@@ -457,6 +457,16 @@ TASK_OUTPUTS = {
# }
Tasks.face_reconstruction: [OutputKeys.OUTPUT],
# 3D human reconstruction result for single sample
# {
# "output": {
# "vertices": np.array with shape(n, 3),
# "faces": np.array with shape(n, 3),
# "colors": np.array with shape(n, 3),
# }
# }
Tasks.human_reconstruction: [OutputKeys.OUTPUT],
# 2D hand keypoints result for single sample
# {
# "keypoints": [

View File

@@ -316,6 +316,8 @@ TASK_INPUTS = {
},
Tasks.action_detection:
InputType.VIDEO,
Tasks.human_reconstruction:
InputType.IMAGE,
Tasks.image_reid_person:
InputType.IMAGE,
Tasks.video_inpainting: {

View File

@@ -0,0 +1,109 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
from typing import Any, Dict
import numpy as np
import torch
import trimesh
from modelscope.metainfo import Pipelines
from modelscope.models.cv.human_reconstruction.utils import (
keep_largest, reconstruction, save_obj_mesh, save_obj_mesh_with_color,
to_tensor)
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.human_reconstruction, module_name=Pipelines.human_reconstruction)
class HumanReconstructionPipeline(Pipeline):
def __init__(self, model: str, **kwargs):
"""The inference pipeline for human reconstruction task.
Human Reconstruction Pipeline. Given one image generate a human mesh.
Args:
model (`str` or `Model` or module instance): A model instance or a model local dir
or a model id in the model hub.
Example:
>>> from modelscope.pipelines import pipeline
>>> test_input = 'human_reconstruction.jpg' # input image path
>>> pipeline_humanRecon = pipeline('human-reconstruction',
model='damo/cv_hrnet_image-human-reconstruction')
>>> result = pipeline_humanRecon(test_input)
>>> output = result[OutputKeys.OUTPUT]
"""
super().__init__(model=model, **kwargs)
if not isinstance(self.model, Model):
logger.error('model object is not initialized.')
raise Exception('model object is not initialized.')
def preprocess(self, input: Input) -> Dict[str, Any]:
img_crop = self.model.crop_img(input)
img, mask = self.model.get_mask(img_crop)
normal_f, normal_b = self.model.generation_normal(img, mask)
image = to_tensor(img_crop) * 2 - 1
normal_b = to_tensor(normal_b) * 2 - 1
normal_f = to_tensor(normal_f) * 2 - 1
mask = to_tensor(mask)
result = {
'img': image,
'mask': mask,
'normal_F': normal_f,
'normal_B': normal_b
}
return result
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
image = input['img']
mask = input['mask']
normF = input['normal_F']
normB = input['normal_B']
normF[1, ...] = -normF[1, ...]
normB[0, ...] = -normB[0, ...]
img = image * mask
normal_b = normB * mask
normal_f = normF * mask
img = torch.cat([img, normal_f, normal_b], dim=0).float()
image_tensor = img.unsqueeze(0).to(self.model.device)
calib_tensor = self.model.calib
net = self.model.meshmodel
net.extract_features(image_tensor)
verts, faces = reconstruction(net, calib_tensor, self.model.coords,
self.model.mat)
pre_mesh = trimesh.Trimesh(
verts, faces, process=False, maintain_order=True)
final_mesh = keep_largest(pre_mesh)
verts = final_mesh.vertices
faces = final_mesh.faces
verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to(
self.model.device).float()
color = torch.zeros(verts.shape)
interval = 20000
for i in range(len(color) // interval):
left = i * interval
right = i * interval + interval
if i == len(color) // interval - 1:
right = -1
pred_color = net.query_rgb(verts_tensor[:, :, left:right],
calib_tensor)
rgb = pred_color[0].detach().cpu() * 0.5 + 0.5
color[left:right] = rgb.T
vert_min = np.min(verts[:, 1])
verts[:, 1] = verts[:, 1] - vert_min
save_obj_mesh('human_reconstruction.obj', verts, faces)
save_obj_mesh_with_color('human_color.obj', verts, faces,
color.numpy())
results = {'vertices': verts, 'faces': faces, 'colors': color.numpy()}
return {OutputKeys.OUTPUT: results}
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

View File

@@ -141,6 +141,9 @@ class CVTasks(object):
# 3d face reconstruction
face_reconstruction = 'face-reconstruction'
# 3d human reconstruction
human_reconstruction = 'human-reconstruction'
# image quality assessment mos
image_quality_assessment_mos = 'image-quality-assessment-mos'
# motion generation

View File

@@ -60,6 +60,7 @@ torchmetrics>=0.6.2
torchsummary>=1.5.1
torchvision
transformers>=4.26.0
trimesh
ujson
utils
videofeatures_clipit>=1.0

View File

@@ -0,0 +1,46 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import sys
import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
sys.path.append('.')
class HumanReconstructionTest(unittest.TestCase):
def setUp(self) -> None:
self.task = Tasks.human_reconstruction
self.model_id = 'damo/cv_hrnet_image-human-reconstruction'
self.test_image = 'data/test/images/human_reconstruction.jpg'
def pipeline_inference(self, pipeline: Pipeline, input_location: str):
result = pipeline(input_location)
mesh = result[OutputKeys.OUTPUT]
print(
f'Output to {osp.abspath("human_reconstruction.obj")}, vertices num: {mesh["vertices"].shape}'
)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
model_dir = snapshot_download(self.model_id)
human_reconstruction = pipeline(
Tasks.human_reconstruction, model=model_dir)
print('running')
self.pipeline_inference(human_reconstruction, self.test_image)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
human_reconstruction = pipeline(
Tasks.human_reconstruction, model=self.model_id)
self.pipeline_inference(human_reconstruction, self.test_image)
if __name__ == '__main__':
unittest.main()