Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-16 16:27:45 +01:00

add human reconstruction task

Single-image human reconstruction task. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11778199 * add human reconstruction task
3
data/test/images/human_reconstruction.jpg
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:06ec486657dffbf244563a844c98c19d49b7a45b99da702403b52bb9e6bf3c0a
size 226072
@@ -78,6 +78,7 @@ class Models(object):
    image_body_reshaping = 'image-body-reshaping'
    image_skychange = 'image-skychange'
    video_human_matting = 'video-human-matting'
    human_reconstruction = 'human-reconstruction'
    video_frame_interpolation = 'video-frame-interpolation'
    video_object_segmentation = 'video-object-segmentation'
    video_deinterlace = 'video-deinterlace'
@@ -361,6 +362,7 @@ class Pipelines(object):
    referring_video_object_segmentation = 'referring-video-object-segmentation'
    image_skychange = 'image-skychange'
    video_human_matting = 'video-human-matting'
    human_reconstruction = 'human-reconstruction'
    vision_middleware_multi_task = 'vision-middleware-multi-task'
    vidt = 'vidt'
    video_frame_interpolation = 'video-frame-interpolation'
@@ -751,6 +753,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                             'damo/cv_video-inpainting'),
    Tasks.video_human_matting: (Pipelines.video_human_matting,
                                'damo/cv_effnetv2_video-human-matting'),
    Tasks.human_reconstruction: (Pipelines.human_reconstruction,
                                 'damo/cv_hrnet_image-human-reconstruction'),
    Tasks.video_frame_interpolation: (
        Pipelines.video_frame_interpolation,
        'damo/cv_raft_video-frame-interpolation'),
@@ -5,13 +5,13 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
               body_2d_keypoints, body_3d_keypoints, cartoon,
               cmdssl_video_embedding, controllable_image_generation,
               crowd_counting, face_2d_keypoints, face_detection,
               face_generation, face_reconstruction, human_wholebody_keypoint,
               image_classification, image_color_enhance, image_colorization,
               image_defrcn_fewshot, image_denoise, image_inpainting,
               image_instance_segmentation, image_matching,
               image_mvs_depth_estimation, image_panoptic_segmentation,
               image_portrait_enhancement, image_probing_model,
               image_quality_assessment_degradation,
               face_generation, face_reconstruction, human_reconstruction,
               human_wholebody_keypoint, image_classification,
               image_color_enhance, image_colorization, image_defrcn_fewshot,
               image_denoise, image_inpainting, image_instance_segmentation,
               image_matching, image_mvs_depth_estimation,
               image_panoptic_segmentation, image_portrait_enhancement,
               image_probing_model, image_quality_assessment_degradation,
               image_quality_assessment_man, image_quality_assessment_mos,
               image_reid_person, image_restoration,
               image_semantic_segmentation, image_to_image_generation,
137
modelscope/models/cv/human_reconstruction/Reconstruction.py
Normal file
@@ -0,0 +1,137 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Optional

import cv2
import numpy as np
import PIL.Image as Image
import torch
import torchvision.transforms as transforms
from skimage.io import imread
from skimage.transform import estimate_transform, warp

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.human_reconstruction.models.detectors import \
    FasterRCNN
from modelscope.models.cv.human_reconstruction.models.human_segmenter import \
    human_segmenter
from modelscope.models.cv.human_reconstruction.models.networks import define_G
from modelscope.models.cv.human_reconstruction.models.PixToMesh import \
    Pixto3DNet
from modelscope.models.cv.human_reconstruction.utils import create_grid
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@MODELS.register_module(
    Tasks.human_reconstruction, module_name=Models.human_reconstruction)
class HumanReconstruction(TorchModel):

    def __init__(self, model_dir, modelconfig, *args, **kwargs):
        """The HumanReconstruction model is adapted from PiFuHD and pix2pixHD, publicly available at
        https://shunsukesaito.github.io/PIFuHD/ and
        https://github.com/NVIDIA/pix2pixHD

        Args:
            model_dir: the root directory of the model files
            modelconfig: the config param path of the model
        """
        super().__init__(model_dir=model_dir, *args, **kwargs)
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            logger.info('Use GPU: {}'.format(self.device))
        else:
            self.device = torch.device('cpu')
            logger.info('Use CPU: {}'.format(self.device))

        model_path = '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_FILE)
        normal_back_model = '{}/{}'.format(model_dir, 'Norm_B_GAN.pth')
        normal_front_model = '{}/{}'.format(model_dir, 'Norm_F_GAN.pth')
        human_seg_model = '{}/{}'.format(model_dir, ModelFile.TF_GRAPH_FILE)
        fastrcnn_ckpt = '{}/{}'.format(model_dir, 'fasterrcnn_resnet50.pth')
        self.meshmodel = Pixto3DNet(**modelconfig['model'])
        self.detector = FasterRCNN(ckpt=fastrcnn_ckpt, device=self.device)
        self.meshmodel.load_state_dict(
            torch.load(model_path, map_location='cpu'))
        self.netB = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance')
        self.netF = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance')
        self.netF.load_state_dict(torch.load(normal_front_model))
        self.netB.load_state_dict(torch.load(normal_back_model))
        self.netF = self.netF.to(self.device)
        self.netB = self.netB.to(self.device)
        self.netF.eval()
        self.netB.eval()
        self.meshmodel = self.meshmodel.to(self.device).eval()
        self.portrait_matting = human_segmenter(model_path=human_seg_model)
        b_min = np.array([-1, -1, -1])
        b_max = np.array([1, 1, 1])
        self.coords, self.mat = create_grid(modelconfig['resolution'], b_min,
                                            b_max)
        projection_matrix = np.identity(4)
        projection_matrix[1, 1] = -1
        self.calib = torch.Tensor(projection_matrix).float().to(self.device)
        self.calib = self.calib[:3, :4].unsqueeze(0)
        logger.info('model load over')

    def get_mask(self, img):
        result = self.portrait_matting.run(img)
        result = result[..., None]
        mask = result.repeat(3, axis=2)
        return img, mask

    @torch.no_grad()
    def crop_img(self, img_url):
        image = imread(img_url)[:, :, :3] / 255.
        h, w, _ = image.shape
        image_size = 512
        image_tensor = torch.tensor(
            image.transpose(2, 0, 1), dtype=torch.float32)[None, ...]
        bbox = self.detector.run(image_tensor)
        left = bbox[0]
        right = bbox[2]
        top = bbox[1]
        bottom = bbox[3]

        old_size = max(right - left, bottom - top)
        center = np.array(
            [right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
        size = int(old_size * 1.1)
        src_pts = np.array([[center[0] - size / 2, center[1] - size / 2],
                            [center[0] - size / 2, center[1] + size / 2],
                            [center[0] + size / 2, center[1] - size / 2]])
        DST_PTS = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]])
        tform = estimate_transform('similarity', src_pts, DST_PTS)
        dst_image = warp(
            image, tform.inverse, output_shape=(image_size, image_size))
        dst_image = (dst_image[:, :, ::-1] * 255).astype(np.uint8)
        return dst_image

    @torch.no_grad()
    def generation_normal(self, img, mask):
        to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        im_512 = cv2.resize(img, (512, 512))
        image_512 = Image.fromarray(im_512).convert('RGB')
        image_512 = to_tensor(image_512).unsqueeze(0)
        img = image_512.to(self.device)
        nml_f = self.netF.forward(img)
        nml_b = self.netB.forward(img)
        mask = cv2.resize(mask, (512, 512))
        mask = transforms.ToTensor()(mask).unsqueeze(0)
        nml_f = (nml_f.cpu() * mask).detach().cpu().numpy()[0]
        nml_f = (np.transpose(nml_f,
                              (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
        nml_b = (nml_b.cpu() * mask).detach().cpu().numpy()[0]
        nml_b = (np.transpose(nml_b,
                              (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
        nml_f = nml_f.astype(np.uint8)
        nml_b = nml_b.astype(np.uint8)
        return nml_f, nml_b

    # def forward(self, img, mask, normal_f, normal_b):
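For orientation, a minimal sketch of how the preprocessing methods above chain together (editor's illustration, not part of the diff; model_dir, the config dict cfg and the image path are placeholders):

# Hypothetical usage sketch of the HumanReconstruction preprocessing chain.
# model_dir and cfg are placeholders for the downloaded model directory and
# its parsed configuration; they are not values shipped with this commit.
model = HumanReconstruction(model_dir='/path/to/model_dir', modelconfig=cfg)
crop = model.crop_img('person.jpg')          # 512x512 crop around the detected person
img, mask = model.get_mask(crop)             # matting mask repeated to 3 channels
normal_f, normal_b = model.generation_normal(img, mask)   # front/back normal maps (uint8)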
@@ -0,0 +1,32 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from torch import nn


class Embedding(nn.Module):

    def __init__(self, in_channels, N_freqs, logscale=True):
        """
        Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
        in_channels: number of input channels (3 for both xyz and direction)
        """
        super(Embedding, self).__init__()
        self.N_freqs = N_freqs
        self.in_channels = in_channels
        self.name = 'Embedding'
        self.funcs = [torch.sin, torch.cos]
        self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1)
        self.input_para = dict(in_channels=in_channels, N_freqs=N_freqs)

        if logscale:
            self.freq_bands = 2**torch.linspace(0, N_freqs - 1, N_freqs)
        else:
            self.freq_bands = torch.linspace(1, 2**(N_freqs - 1), N_freqs)

    def forward(self, x):
        out = [x]
        for freq in self.freq_bands:
            for func in self.funcs:
                out += [func(freq * x)]

        return torch.cat(out, 1)
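A quick sanity check of the positional-encoding dimensions described in the docstring above (editor's sketch; the frequency count and batch shape are illustrative, not the shipped configuration):

# Illustrative only: embed a batch of 3D point coordinates with 4 frequency bands.
emb = Embedding(in_channels=3, N_freqs=4)   # out_channels = 3 * (2 * 4 + 1) = 27
pts = torch.rand(2, 3, 100)                 # [B, 3, N] points
feat = emb(pts)                             # concatenated along dim 1 -> [2, 27, 100]
assert feat.shape == (2, emb.out_channels, 100)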
142
modelscope/models/cv/human_reconstruction/models/PixToMesh.py
Normal file
@@ -0,0 +1,142 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn

from .Embedding import Embedding
from .geometry import index, orthogonal, perspective
from .Res_backbone import Res_hournet
from .Surface_head import Surface_Head


class Pixto3DNet(nn.Module):

    def __init__(self,
                 backbone,
                 head,
                 rgbhead,
                 embedding,
                 projection_mode: str = 'orthogonal',
                 error_term: str = 'mse',
                 num_views: int = 1):
        """
        Parameters:
            backbone: parameters of the network that extracts image features
            head: parameters of the network that predicts the surface occupancy value
            rgbhead: parameters of the network that predicts the rgb value of a point
            embedding: parameters of the network that normalizes depth in camera coordinates
            projection_mode: how the 3d model is projected onto the image
            error_term: training loss
            num_views: number of images from which the model is reconstructed
        """
        super(Pixto3DNet, self).__init__()

        self.backbone = Res_hournet(**backbone)
        self.head = Surface_Head(**head)
        self.rgbhead = Surface_Head(**rgbhead)
        self.depth = Embedding(**embedding)

        if error_term == 'mse':
            self.error_term = nn.MSELoss(reduction='none')
        elif error_term == 'bce':
            self.error_term = nn.BCELoss(reduction='none')
        elif error_term == 'l1':
            self.error_term = nn.L1Loss(reduction='none')
        else:
            raise NotImplementedError

        self.index = index
        self.projection = orthogonal if projection_mode == 'orthogonal' else perspective

        self.num_views = num_views
        self.im_feat_list = []
        self.intermediate_preds_list = []

    def extract_features(self, images: torch.Tensor):
        self.im_feat_list = self.backbone(images)

    def query(self, points, calibs, transforms=None, labels=None):
        if labels is not None:
            self.labels = labels

        xyz = self.projection(points, calibs, transforms)

        xy = xyz[:, :2, :]
        xyz_feat = self.depth(xyz)

        self.intermediate_preds_list = []

        im_feat_256 = self.im_feat_list[0]
        im_feat_512 = self.im_feat_list[1]

        point_local_feat_list = [
            self.index(im_feat_256, xy),
            self.index(im_feat_512, xy), xyz_feat
        ]
        point_local_feat = torch.cat(point_local_feat_list, 1)

        pred, phi = self.head(point_local_feat)
        self.intermediate_preds_list.append(pred)
        self.phi = phi

        self.preds = self.intermediate_preds_list[-1]

    def get_preds(self):
        return self.preds

    def query_rgb(self, points, calibs, transforms=None):
        xyz = self.projection(points, calibs, transforms)

        xy = xyz[:, :2, :]
        xyz_feat = self.depth(xyz)

        self.intermediate_preds_list = []

        im_feat_256 = self.im_feat_list[0]
        im_feat_512 = self.im_feat_list[1]

        point_local_feat_list = [
            self.index(im_feat_256, xy),
            self.index(im_feat_512, xy), xyz_feat
        ]
        point_local_feat = torch.cat(point_local_feat_list, 1)

        pred, phi = self.head(point_local_feat)
        rgb_point_feat = torch.cat([point_local_feat, phi], 1)
        rgb, phi = self.rgbhead(rgb_point_feat)
        return rgb

    def get_error(self):
        error = 0
        lc = torch.tensor(self.labels.shape[0] * self.labels.shape[1]
                          * self.labels.shape[2])
        inw = torch.sum(self.labels)
        weight_in = inw / lc
        weight = torch.abs(self.labels - weight_in)
        lamda = 1 / torch.mean(weight)
        for preds in self.intermediate_preds_list:
            error += lamda * torch.mean(
                self.error_term(preds, self.labels) * weight)
        error /= len(self.intermediate_preds_list)

        return error

    def forward(self,
                images,
                points,
                calibs,
                surpoint=None,
                transforms=None,
                labels=None):
        self.extract_features(images)

        self.query(
            points=points, calibs=calibs, transforms=transforms, labels=labels)

        if surpoint is not None:
            rgb = self.query_rgb(
                points=surpoint, calibs=calibs, transforms=transforms)
        else:
            rgb = None
        res = self.preds

        return res, rgb
330
modelscope/models/cv/human_reconstruction/models/Res_backbone.py
Normal file
@@ -0,0 +1,330 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class BlurPool(nn.Module):

    def __init__(self,
                 channels,
                 pad_type='reflect',
                 filt_size=4,
                 stride=2,
                 pad_off=0):
        super(BlurPool, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        self.pad_sizes = [
            int(1. * (filt_size - 1) / 2),
            int(np.ceil(1. * (filt_size - 1) / 2)),
            int(1. * (filt_size - 1) / 2),
            int(np.ceil(1. * (filt_size - 1) / 2))
        ]
        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        if (self.filt_size == 1):
            a = np.array([
                1.,
            ])
        elif (self.filt_size == 2):
            a = np.array([1., 1.])
        elif (self.filt_size == 3):
            a = np.array([1., 2., 1.])
        elif (self.filt_size == 4):
            a = np.array([1., 3., 3., 1.])
        elif (self.filt_size == 5):
            a = np.array([1., 4., 6., 4., 1.])
        elif (self.filt_size == 6):
            a = np.array([1., 5., 10., 10., 5., 1.])
        elif (self.filt_size == 7):
            a = np.array([1., 6., 15., 20., 15., 6., 1.])

        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        self.register_buffer(
            'filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        if (self.filt_size == 1):
            if (self.pad_off == 0):
                return inp[:, :, ::self.stride, ::self.stride]
            else:
                return self.pad(inp)[:, :, ::self.stride, ::self.stride]
        else:
            return F.conv2d(
                self.pad(inp),
                self.filt,
                stride=self.stride,
                groups=inp.shape[1])


def get_pad_layer(pad_type):
    if (pad_type in ['refl', 'reflect']):
        PadLayer = nn.ReflectionPad2d
    elif (pad_type in ['repl', 'replicate']):
        PadLayer = nn.ReplicationPad2d
    elif (pad_type == 'zero'):
        PadLayer = nn.ZeroPad2d
    else:
        print('Pad type [%s] not recognized' % pad_type)
    return PadLayer


class ConvBlockv1(nn.Module):

    def __init__(self, in_planes, out_planes, norm='batch'):
        super(ConvBlockv1, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes,
            int(out_planes / 2),
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False)
        self.conv2 = nn.Conv2d(
            int(out_planes / 2),
            int(out_planes / 4),
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False)
        self.conv3 = nn.Conv2d(
            int(out_planes / 4),
            int(out_planes / 4),
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False)

        if norm == 'batch':
            self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
            self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
            self.bn4 = nn.BatchNorm2d(out_planes)
        elif norm == 'group':
            self.bn2 = nn.GroupNorm(32, int(out_planes / 2))
            self.bn3 = nn.GroupNorm(32, int(out_planes / 4))
            self.bn4 = nn.GroupNorm(32, out_planes)

        if in_planes != out_planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes, kernel_size=1, stride=1,
                    bias=False), )
        else:
            self.downsample = None

    def forward(self, x):
        residual = x
        out1 = self.conv1(x)
        out2 = self.bn2(out1)
        out2 = F.relu(out2, True)
        out2 = self.conv2(out2)

        out3 = self.bn3(out2)
        out3 = F.relu(out3, True)
        out3 = self.conv3(out3)
        out3 = torch.cat((out1, out2, out3), 1)

        if self.downsample is not None:
            residual = self.downsample(residual)
        out3 += residual
        out4 = self.bn4(out3)
        out4 = F.relu(out4, True)
        return out4


class Conv2(nn.Module):

    def __init__(self, in_planes, out_planes, norm='batch'):
        super(Conv2, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes,
            int(out_planes / 4),
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False)
        self.conv2 = nn.Conv2d(
            in_planes,
            int(out_planes / 4),
            kernel_size=5,
            stride=1,
            padding=2,
            bias=False)
        self.conv3 = nn.Conv2d(
            in_planes,
            int(out_planes / 2),
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False)
        self.conv4 = nn.Conv2d(
            out_planes,
            out_planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False)

        if norm == 'batch':
            self.bn1 = nn.BatchNorm2d(int(out_planes / 4))
            self.bn2 = nn.BatchNorm2d(int(out_planes / 4))
            self.bn3 = nn.BatchNorm2d(int(out_planes / 2))
            self.bn4 = nn.BatchNorm2d(out_planes)
        elif norm == 'group':
            self.bn1 = nn.GroupNorm(32, int(out_planes / 4))
            self.bn2 = nn.GroupNorm(32, int(out_planes / 4))
            self.bn3 = nn.GroupNorm(32, int(out_planes / 2))
            self.bn4 = nn.GroupNorm(32, out_planes)

        if in_planes != out_planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes, kernel_size=1, stride=1,
                    bias=False), )
        else:
            self.downsample = None

    def forward(self, x):
        residual = x
        out1 = self.conv1(x)
        out1 = self.bn1(out1)
        out1 = F.relu(out1, True)

        out2 = self.conv2(x)
        out2 = self.bn2(out2)
        out2 = F.relu(out2, True)

        out3 = self.conv3(x)
        out3 = self.bn3(out3)
        out3 = F.relu(out3, True)
        out3 = torch.cat((out1, out2, out3), 1)
        if self.downsample is not None:
            residual = self.downsample(residual)
        out = out3 + residual
        out = self.conv4(out)
        out = self.bn4(out)
        out = F.relu(out, True)
        return out


class Res_hournet(nn.Module):

    def __init__(self, norm: str = 'group', use_front=False, use_back=False):
        """
        Defines the backbone of the human reconstruction network.
        use_front & use_back indicate whether the front/back normal maps of the
        input image are concatenated to the input.
        """
        super(Res_hournet, self).__init__()
        self.name = 'Res Backbone'
        self.norm = norm
        inc = 3
        self.use_front = use_front
        self.use_back = use_back
        if self.use_front:
            inc += 3
        if self.use_back:
            inc += 3
        self.conv1 = nn.Conv2d(inc, 64, kernel_size=7, stride=1, padding=3)
        if self.norm == 'batch':
            self.bn1 = nn.BatchNorm2d(64)
        elif self.norm == 'group':
            self.bn1 = nn.GroupNorm(32, 64)
        self.down_conv1 = BlurPool(
            64, pad_type='reflect', filt_size=7, stride=2)
        self.conv2 = ConvBlockv1(64, 128, self.norm)
        self.down_conv2 = BlurPool(
            128, pad_type='reflect', filt_size=7, stride=2)
        self.conv3 = ConvBlockv1(128, 128, self.norm)
        self.conv5 = ConvBlockv1(128, 256, self.norm)
        self.conv6 = ConvBlockv1(256, 256, self.norm)
        self.down_conv3 = BlurPool(
            256, pad_type='reflect', filt_size=5, stride=2)
        self.conv7 = ConvBlockv1(256, 256, self.norm)
        self.conv8 = ConvBlockv1(256, 256, self.norm)
        self.conv9 = ConvBlockv1(256, 256, self.norm)
        self.conv10 = ConvBlockv1(256, 256, self.norm)
        self.conv10_1 = ConvBlockv1(256, 512, self.norm)
        self.conv10_2 = Conv2(512, 512, self.norm)
        self.down_conv4 = BlurPool(
            512, pad_type='reflect', filt_size=5, stride=2)
        self.conv11 = Conv2(512, 512, self.norm)
        self.conv12 = ConvBlockv1(512, 512, self.norm)
        self.conv13 = Conv2(512, 512, self.norm)
        self.conv14 = ConvBlockv1(512, 512, self.norm)
        self.conv15 = Conv2(512, 512, self.norm)
        self.conv16 = ConvBlockv1(512, 512, self.norm)
        self.conv17 = Conv2(512, 512, self.norm)
        self.conv18 = ConvBlockv1(512, 512, self.norm)
        self.conv19 = Conv2(512, 512, self.norm)
        self.conv20 = ConvBlockv1(512, 512, self.norm)
        self.conv21 = Conv2(512, 512, self.norm)
        self.conv22 = ConvBlockv1(512, 512, self.norm)

        self.up_down1 = nn.Conv2d(1024, 512, 3, 1, 1, bias=False)
        self.upconv1 = ConvBlockv1(512, 512, self.norm)
        self.upconv1_1 = ConvBlockv1(512, 512, self.norm)
        self.up_down2 = nn.Conv2d(768, 512, 3, 1, 1, bias=False)
        self.upconv2 = ConvBlockv1(512, 256, self.norm)
        self.upconv2_1 = ConvBlockv1(256, 256, self.norm)
        self.up_down3 = nn.Conv2d(384, 256, 3, 1, 1, bias=False)
        self.upconv3 = ConvBlockv1(256, 256, self.norm)
        self.upconv3_4 = nn.Conv2d(256, 128, 3, 1, 1, bias=False)
        self.up_down4 = nn.Conv2d(192, 64, 3, 1, 1, bias=False)
        self.upconv4 = ConvBlockv1(64, 64, 'batch')

    def forward(self, x):
        out0 = self.bn1(self.conv1(x))
        out1 = self.down_conv1(out0)
        out1 = self.conv2(out1)
        out2 = self.down_conv2(out1)
        out2 = self.conv3(out2)
        out2 = self.conv5(out2)
        out2 = self.conv6(out2)
        out3 = self.down_conv3(out2)
        out3 = self.conv7(out3)
        out3 = self.conv9(self.conv8(out3))
        out3 = self.conv10(out3)
        out3 = self.conv10_2(self.conv10_1(out3))
        out4 = self.down_conv4(out3)
        out4 = self.conv12(self.conv11(out4))
        out4 = self.conv14(self.conv13(out4))
        out4 = self.conv16(self.conv15(out4))
        out4 = self.conv18(self.conv17(out4))
        out4 = self.conv20(self.conv19(out4))
        out4 = self.conv22(self.conv21(out4))

        up1 = F.interpolate(
            out4, scale_factor=2, mode='bicubic', align_corners=True)
        up1 = torch.cat((up1, out3), 1)
        up1 = self.up_down1(up1)
        up1 = self.upconv1(up1)
        up1 = self.upconv1_1(up1)

        up2 = F.interpolate(
            up1, scale_factor=2, mode='bicubic', align_corners=True)
        up2 = torch.cat((up2, out2), 1)
        up2 = self.up_down2(up2)
        up2 = self.upconv2(up2)
        up2 = self.upconv2_1(up2)

        up3 = F.interpolate(
            up2, scale_factor=2, mode='bicubic', align_corners=True)
        up3 = torch.cat((up3, out1), 1)
        up3 = self.up_down3(up3)
        up3 = self.upconv3(up3)

        up34 = self.upconv3_4(up3)
        up4 = F.interpolate(
            up34, scale_factor=2, mode='bicubic', align_corners=True)
        up4 = torch.cat((up4, out0), 1)
        up4 = self.up_down4(up4)
        up4 = self.upconv4(up4)
        return up3, up4
@@ -0,0 +1,73 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Surface_Head(nn.Module):
    """
    MLP that learns the implicit iso-surface function.
    """

    def __init__(self,
                 filter_channels,
                 merge_layer=0,
                 res_layers=[],
                 norm='group',
                 last_op=None):
        super(Surface_Head, self).__init__()
        if last_op == 'sigmoid':
            self.last_op = nn.Sigmoid()
        elif last_op == 'tanh':
            self.last_op = nn.Tanh()
        else:
            raise NotImplementedError(
                'only sigmoid/tanh function could be used')

        self.filters = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.merge_layer = merge_layer if merge_layer > 0 else len(
            filter_channels) // 2

        self.res_layers = res_layers
        self.norm = norm

        for i in range(0, len(filter_channels) - 1):
            if i in self.res_layers:
                self.filters.append(
                    nn.Conv1d(filter_channels[i] + filter_channels[0],
                              filter_channels[i + 1], 1))
            else:
                self.filters.append(
                    nn.Conv1d(filter_channels[i], filter_channels[i + 1], 1))
            if i != len(filter_channels) - 2:
                if norm == 'group':
                    self.norms.append(nn.GroupNorm(32, filter_channels[i + 1]))
                elif norm == 'batch':
                    self.norms.append(nn.BatchNorm1d(filter_channels[i + 1]))

    def forward(self, feature):
        """feature may include multiple view inputs
        Parameters:
            feature: [B, C_in, N]
        return:
            prediction: [B, C_out, N] and merge-layer features
        """

        y = feature
        tmpy = feature
        phi = None

        for i, f in enumerate(self.filters):
            y = f(y if i not in self.res_layers else torch.cat([y, tmpy], 1))
            if i != len(self.filters) - 1:
                if self.norm not in ['batch', 'group']:
                    y = F.leaky_relu(y)
                else:
                    y = F.leaky_relu(self.norms[i](y))
            if i == self.merge_layer:
                phi = y.clone()

        if self.last_op is not None:
            y = self.last_op(y)
        return y, phi
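A small sketch of how this MLP head maps per-point features to occupancy values (editor's illustration; the channel list is made up and is not the released configuration):

# Illustrative only: a 1x1 Conv1d MLP over N sampled points.
head = Surface_Head(
    filter_channels=[64, 128, 64, 1], norm='group', last_op='sigmoid')
feat = torch.rand(2, 64, 1000)   # [B, C_in, N] pixel-aligned point features
pred, phi = head(feat)           # pred: [B, 1, N] occupancy in (0, 1); phi: merge-layer features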
@@ -0,0 +1,66 @@
# The implementation here is modified based on Pytorch, originally BSD License and publicly available at
# https://github.com/pytorch/pytorch
import numpy as np
import torch


class FasterRCNN(object):
    ''' detect body
    COCO_INSTANCE_CATEGORY_NAMES = [
        '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
        'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
        'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
        'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
        'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
        'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
        'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
        'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
        'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
        'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
    ]
    '''

    def __init__(self, ckpt=None, device='cuda:0'):
        """
        https://pytorch.org/docs/stable/torchvision/models.html#faster-r-cnn
        """
        import torchvision
        if ckpt is None:
            self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
                pretrained=True)
        else:
            self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
                pretrained=False)
            state_dict = torch.load(ckpt, map_location='cpu')
            self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()
        self.device = device

    @torch.no_grad()
    def run(self, input):
        """
        return: detected box, [x1, y1, x2, y2]
        """
        prediction = self.model(input.to(self.device))[0]
        inds = (prediction['labels'] == 1) * (prediction['scores'] > 0.5)
        if len(inds) < 1:
            return None
        else:
            bbox = prediction['boxes'][inds][0].cpu().numpy()
            return bbox

    @torch.no_grad()
    def run_multi(self, input):
        """
        return: detected boxes, one [x1, y1, x2, y2] per person
        """
        prediction = self.model(input.to(self.device))[0]
        inds = (prediction['labels'] == 1) * (prediction['scores'] > 0.9)
        if len(inds) < 1:
            return None
        else:
            bbox = prediction['boxes'][inds].cpu().numpy()
            return bbox
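A minimal sketch of the detector wrapper above (editor's illustration; the image path is a placeholder, and with ckpt=None the standard torchvision weights are downloaded instead of the checkpoint shipped with the model):

# Illustrative only: detect the most confident person box in an image.
from skimage.io import imread

image = imread('person.jpg')[:, :, :3] / 255.  # HxWx3 floats in [0, 1]
tensor = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)[None, ...]
detector = FasterRCNN(ckpt=None, device='cpu')
bbox = detector.run(tensor)                    # [x1, y1, x2, y2] or None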
61
modelscope/models/cv/human_reconstruction/models/geometry.py
Normal file
@@ -0,0 +1,61 @@
# The implementation here is modified based on PIFU, originally MIT License and publicly available at
# https://github.com/shunsukesaito/PIFu/blob/master/lib/geometry.py
import torch


def index(feat, uv):
    """
    extract image features at floating coordinates with bilinear interpolation
    args:
        feat: [B, C, H, W] image features
        uv: [B, 2, N] normalized image coordinates ranged in [-1, 1]
    return:
        [B, C, N] sampled pixel values
    """
    uv = uv.transpose(1, 2)
    uv = uv.unsqueeze(2)
    samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True)
    return samples[:, :, :, 0]


def orthogonal(points, calib, transform=None):
    """
    project points onto screen space using orthogonal projection
    args:
        points: [B, 3, N] 3d points in world coordinates
        calib: [B, 3, 4] projection matrix
        transform: [B, 2, 3] screen space transformation
    return:
        [B, 3, N] 3d coordinates in screen space
    """
    rot = calib[:, :3, :3]
    trans = calib[:, :3, 3:4]
    pts = torch.baddbmm(trans, rot, points)
    if transform is not None:
        scale = transform[:2, :2]
        shift = transform[:2, 2:3]
        pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
    return pts


def perspective(points, calib, transform=None):
    """
    project points onto screen space using perspective projection
    args:
        points: [B, 3, N] 3d points in world coordinates
        calib: [B, 3, 4] projection matrix
        transform: [B, 2, 3] screen space transformation
    return:
        [B, 3, N] 3d coordinates in screen space
    """
    rot = calib[:, :3, :3]
    trans = calib[:, :3, 3:4]
    homo = torch.baddbmm(trans, rot, points)
    xy = homo[:, :2, :] / homo[:, 2:3, :]
    if transform is not None:
        scale = transform[:2, :2]
        shift = transform[:2, 2:3]
        xy = torch.baddbmm(shift, scale, xy)

    xyz = torch.cat([xy, homo[:, 2:3, :]], 1)
    return xyz
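A short sketch tying these helpers to the way the pipeline uses them: HumanReconstruction.__init__ builds an identity calibration with a flipped y-axis, so orthogonal() maps world points almost directly to normalized screen coordinates that index() can sample with (editor's illustration):

# Illustrative only: project points with a calib like the one built by the model.
import numpy as np

projection_matrix = np.identity(4)
projection_matrix[1, 1] = -1                              # flip the y-axis
calib = torch.Tensor(projection_matrix).float()[:3, :4].unsqueeze(0)  # [1, 3, 4]
points = torch.rand(1, 3, 500) * 2 - 1                    # [B, 3, N] in [-1, 1]
screen = orthogonal(points, calib)                        # [1, 3, N], y negated
feats = index(torch.rand(1, 256, 128, 128), screen[:, :2, :])  # [1, 256, N]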
@@ -0,0 +1,60 @@
# The implementation is also open-sourced by the authors, and available at
# https://www.modelscope.cn/models/damo/cv_unet_image-matting/summary
import cv2
import numpy as np
import tensorflow as tf

if tf.__version__ >= '2.0':
    tf = tf.compat.v1


class human_segmenter(object):

    def __init__(self, model_path):
        super(human_segmenter, self).__init__()
        f = tf.gfile.FastGFile(model_path, 'rb')
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        persisted_graph = tf.import_graph_def(graph_def, name='')

        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.3  # use at most 30% of GPU memory
        self.sess = tf.InteractiveSession(graph=persisted_graph, config=config)

        self.image_node = self.sess.graph.get_tensor_by_name('input_image:0')
        self.output_node = self.sess.graph.get_tensor_by_name('output_png:0')
        self.logits_node = self.sess.graph.get_tensor_by_name('if_person:0')
        print('human_segmenter init done')

    def image_preprocess(self, img):
        if len(img.shape) == 2:
            img = np.dstack((img, img, img))
        elif img.shape[2] == 4:
            img = img[:, :, :3]
        img = img.astype(float)
        return img

    def run(self, img):
        image_feed = self.image_preprocess(img)
        output_img_value, logits_value = self.sess.run(
            [self.output_node, self.logits_node],
            feed_dict={self.image_node: image_feed})
        mask = output_img_value[:, :, -1]
        return mask

    def get_human_bbox(self, mask):
        print('dtype:{}, max:{}, shape:{}'.format(mask.dtype, np.max(mask),
                                                  mask.shape))
        ret, thresh = cv2.threshold(mask, 127, 255, 0)
        contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE,
                                               cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            return None

        contoursArea = [cv2.contourArea(c) for c in contours]
        max_area_index = contoursArea.index(max(contoursArea))
        bbox = cv2.boundingRect(contours[max_area_index])
        return bbox

    def release(self):
        self.sess.close()
366
modelscope/models/cv/human_reconstruction/models/networks.py
Normal file
@@ -0,0 +1,366 @@
# The implementation here is modified based on Pix2PixHD, originally BSD License and publicly available at
# https://github.com/NVIDIA/pix2pixHD
import functools

import numpy as np
import torch
import torch.nn as nn


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found'
                                  % norm_type)
    return norm_layer


def define_G(input_nc,
             output_nc,
             ngf,
             netG,
             n_downsample_global=3,
             n_blocks_global=9,
             n_local_enhancers=1,
             n_blocks_local=3,
             norm='instance',
             gpu_ids=[],
             last_op=nn.Tanh()):
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        netG = GlobalGenerator(
            input_nc,
            output_nc,
            ngf,
            n_downsample_global,
            n_blocks_global,
            norm_layer,
            last_op=last_op)
    elif netG == 'local':
        netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global,
                             n_blocks_global, n_local_enhancers,
                             n_blocks_local, norm_layer)
    elif netG == 'encoder':
        netG = Encoder(input_nc, output_nc, ngf, n_downsample_global,
                       norm_layer)
    else:
        raise NotImplementedError('generator not implemented!')
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG


def print_network(net):
    if isinstance(net, list):
        net = net[0]
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)


"""
Generator code
"""


class LocalEnhancer(nn.Module):

    def __init__(self,
                 input_nc,
                 output_nc,
                 ngf=32,
                 n_downsample_global=3,
                 n_blocks_global=9,
                 n_local_enhancers=1,
                 n_blocks_local=3,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        super(LocalEnhancer, self).__init__()
        self.n_local_enhancers = n_local_enhancers

        ngf_global = ngf * (2**n_local_enhancers)
        model_global = GlobalGenerator(input_nc, output_nc, ngf_global,
                                       n_downsample_global, n_blocks_global,
                                       norm_layer).model
        model_global = [model_global[i] for i in range(len(model_global) - 3)
                        ]  # get rid of final convolution layers
        self.model = nn.Sequential(*model_global)

        for n in range(1, n_local_enhancers + 1):
            ngf_global = ngf * (2**(n_local_enhancers - n))
            model_downsample = [
                nn.ReflectionPad2d(3),
                nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
                norm_layer(ngf_global),
                nn.ReLU(True),
                nn.Conv2d(
                    ngf_global,
                    ngf_global * 2,
                    kernel_size=3,
                    stride=2,
                    padding=1),
                norm_layer(ngf_global * 2),
                nn.ReLU(True)
            ]
            model_upsample = []
            for i in range(n_blocks_local):
                model_upsample += [
                    ResnetBlock(
                        ngf_global * 2,
                        padding_type=padding_type,
                        norm_layer=norm_layer)
                ]

            model_upsample += [
                nn.ConvTranspose2d(
                    ngf_global * 2,
                    ngf_global,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1),
                norm_layer(ngf_global),
                nn.ReLU(True)
            ]

            if n == n_local_enhancers:
                model_upsample += [
                    nn.ReflectionPad2d(3),
                    nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                    nn.Tanh()
                ]

            setattr(self, 'model' + str(n) + '_1',
                    nn.Sequential(*model_downsample))
            setattr(self, 'model' + str(n) + '_2',
                    nn.Sequential(*model_upsample))

        self.downsample = nn.AvgPool2d(
            3, stride=2, padding=[1, 1], count_include_pad=False)

    def forward(self, input):
        input_downsampled = [input]
        for i in range(self.n_local_enhancers):
            input_downsampled.append(self.downsample(input_downsampled[-1]))

        output_prev = self.model(input_downsampled[-1])
        for n_local_enhancers in range(1, self.n_local_enhancers + 1):
            model_downsample = getattr(self,
                                       'model' + str(n_local_enhancers) + '_1')
            model_upsample = getattr(self,
                                     'model' + str(n_local_enhancers) + '_2')
            input_i = input_downsampled[self.n_local_enhancers
                                        - n_local_enhancers]
            output_prev = model_upsample(
                model_downsample(input_i) + output_prev)
        return output_prev


class GlobalGenerator(nn.Module):

    def __init__(self,
                 input_nc,
                 output_nc,
                 ngf=64,
                 n_downsampling=3,
                 n_blocks=9,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect',
                 last_op=nn.Tanh()):
        assert (n_blocks >= 0)
        super(GlobalGenerator, self).__init__()
        activation = nn.ReLU(True)

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf), activation
        ]
        for i in range(n_downsampling):
            mult = 2**i
            model += [
                nn.Conv2d(
                    ngf * mult,
                    ngf * mult * 2,
                    kernel_size=3,
                    stride=2,
                    padding=1),
                norm_layer(ngf * mult * 2), activation
            ]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [
                ResnetBlock(
                    ngf * mult,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer)
            ]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    ngf * mult,
                    int(ngf * mult / 2),
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1),
                norm_layer(int(ngf * mult / 2)), activation
            ]
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)
        ]
        if last_op is not None:
            model += [last_op]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


"""
Define a resnet block
"""


class ResnetBlock(nn.Module):

    def __init__(self,
                 dim,
                 padding_type,
                 norm_layer,
                 activation=nn.ReLU(True),
                 use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer,
                                                activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation,
                         use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented'
                                      % padding_type)

        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p),
            norm_layer(dim), activation
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented'
                                      % padding_type)
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p),
            norm_layer(dim)
        ]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


class Encoder(nn.Module):

    def __init__(self,
                 input_nc,
                 output_nc,
                 ngf=32,
                 n_downsampling=4,
                 norm_layer=nn.BatchNorm2d):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True)
        ]
        for i in range(n_downsampling):
            mult = 2**i
            model += [
                nn.Conv2d(
                    ngf * mult,
                    ngf * mult * 2,
                    kernel_size=3,
                    stride=2,
                    padding=1),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True)
            ]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    ngf * mult,
                    int(ngf * mult / 2),
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1),
                norm_layer(int(ngf * mult / 2)),
                nn.ReLU(True)
            ]

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh()
        ]
        self.model = nn.Sequential(*model)

    def forward(self, input, inst):
        outputs = self.model(input)

        outputs_mean = outputs.clone()
        inst_list = np.unique(inst.cpu().numpy().astype(int))
        for i in inst_list:
            for b in range(input.size()[0]):
                indices = (inst[b:b + 1] == int(i)).nonzero()
                for j in range(self.output_nc):
                    output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j,
                                         indices[:, 2], indices[:, 3]]
                    mean_feat = torch.mean(output_ins).expand_as(output_ins)
                    outputs_mean[indices[:, 0] + b, indices[:, 1] + j,
                                 indices[:, 2], indices[:, 3]] = mean_feat
        return outputs_mean
178
modelscope/models/cv/human_reconstruction/utils.py
Normal file
@@ -0,0 +1,178 @@
import os

import mcubes
import numpy as np
import torch


def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
    file = open(mesh_path, 'w')
    for idx, v in enumerate(verts):
        c = colors[idx]
        file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' %
                   (v[0], v[1], v[2], c[0], c[1], c[2]))
    for f in faces:
        f_plus = f + 1
        file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
    file.close()


def save_obj_mesh(mesh_path, verts, faces):
    file = open(mesh_path, 'w')
    for idx, v in enumerate(verts):
        file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
    for f in faces:
        f_plus = f + 1
        file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
    file.close()


def to_tensor(img):
    if len(img.shape) == 2:
        img = img[:, :, np.newaxis]
    img = torch.from_numpy(img.transpose(2, 0, 1)).float()
    img = img / 255.
    return img


def reconstruction(net, calib_tensor, coords, mat, num_samples=50000):

    def eval_func(points):
        points = np.expand_dims(points, axis=0)
        points = np.repeat(points, 1, axis=0)
        samples = torch.from_numpy(points).cuda().float()
        net.query(samples, calib_tensor)
        pred = net.get_preds()
        pred = pred[0]
        return pred.detach().cpu().numpy()

    sdf = eval_grid(coords, eval_func, num_samples=num_samples)
    vertices, faces = mcubes.marching_cubes(sdf, 0.5)
    verts = np.matmul(mat[:3, :3], vertices.T) + mat[:3, 3:4]
    verts = verts.T
    return verts, faces


def keep_largest(mesh_big):
    mesh_lst = mesh_big.split(only_watertight=False)
    keep_mesh = mesh_lst[0]
    for mesh in mesh_lst:
        if mesh.vertices.shape[0] > keep_mesh.vertices.shape[0]:
            keep_mesh = mesh
    return keep_mesh


def eval_grid(coords,
              eval_func,
              init_resolution=64,
              threshold=0.01,
              num_samples=512 * 512 * 512):
    resolution = coords.shape[1:4]
    sdf = np.zeros(resolution)
    dirty = np.ones(resolution, dtype=bool)
    grid_mask = np.zeros(resolution, dtype=bool)
    reso = resolution[0] // init_resolution

    while reso > 0:
        grid_mask[0:resolution[0]:reso, 0:resolution[1]:reso,
                  0:resolution[2]:reso] = True
        test_mask = np.logical_and(grid_mask, dirty)
        points = coords[:, test_mask]

        sdf[test_mask] = batch_eval(points, eval_func, num_samples=num_samples)
        dirty[test_mask] = False

        if reso <= 1:
            break
        for x in range(0, resolution[0] - reso, reso):
            for y in range(0, resolution[1] - reso, reso):
                for z in range(0, resolution[2] - reso, reso):
                    if not dirty[x + reso // 2, y + reso // 2, z + reso // 2]:
                        continue
                    v0 = sdf[x, y, z]
                    v1 = sdf[x, y, z + reso]
                    v2 = sdf[x, y + reso, z]
                    v3 = sdf[x, y + reso, z + reso]
                    v4 = sdf[x + reso, y, z]
                    v5 = sdf[x + reso, y, z + reso]
                    v6 = sdf[x + reso, y + reso, z]
                    v7 = sdf[x + reso, y + reso, z + reso]
                    v = np.array([v0, v1, v2, v3, v4, v5, v6, v7])
                    v_min = v.min()
                    v_max = v.max()
                    if (v_max - v_min) < threshold:
                        sdf[x:x + reso, y:y + reso,
                            z:z + reso] = (v_max + v_min) / 2
                        dirty[x:x + reso, y:y + reso, z:z + reso] = False
        reso //= 2

    return sdf.reshape(resolution)


def batch_eval(points, eval_func, num_samples=512 * 512 * 512):
    num_pts = points.shape[1]
    sdf = np.zeros(num_pts)

    num_batches = num_pts // num_samples
    for i in range(num_batches):
        sdf[i * num_samples:i * num_samples + num_samples] = eval_func(
            points[:, i * num_samples:i * num_samples + num_samples])
    if num_pts % num_samples:
        sdf[num_batches * num_samples:] = eval_func(points[:, num_batches
                                                           * num_samples:])
    return sdf


def create_grid(res,
                b_min=np.array([0, 0, 0]),
                b_max=np.array([1, 1, 1]),
                transform=None):
    coords = np.mgrid[:res, :res, :res]

    coords = coords.reshape(3, -1)
    coords_matrix = np.eye(4)
    length = b_max - b_min

    coords_matrix[0, 0] = length[0] / res
    coords_matrix[1, 1] = length[1] / res
    coords_matrix[2, 2] = length[2] / res
    coords_matrix[0:3, 3] = b_min

    coords = np.matmul(coords_matrix[:3, :3], coords) + coords_matrix[:3, 3:4]
    if transform is not None:
        coords = np.matmul(transform[:3, :3], coords) + transform[:3, 3:4]
        coords_matrix = np.matmul(transform, coords_matrix)
    coords = coords.reshape(3, res, res, res)
    return coords, coords_matrix


def get_submesh(verts,
                faces,
                color,
                verts_retained=None,
                faces_retained=None,
                min_vert_in_face=2):
    verts = verts
    faces = faces
    colors = color
    if verts_retained is not None:
        if verts_retained.dtype != 'bool':
            vert_mask = np.zeros(len(verts), dtype=bool)
            vert_mask[verts_retained] = True
        else:
            vert_mask = verts_retained
        bool_faces = np.sum(
            vert_mask[faces.ravel()].reshape(-1, 3), axis=1) > min_vert_in_face
    elif faces_retained is not None:
        if faces_retained.dtype != 'bool':
            bool_faces = np.zeros(len(faces_retained), dtype=bool)
        else:
            bool_faces = faces_retained
    new_faces = faces[bool_faces]
    vertex_ids = list(set(new_faces.ravel()))
    oldtonew = -1 * np.ones([len(verts)])
    oldtonew[vertex_ids] = range(0, len(vertex_ids))
    new_verts = verts[vertex_ids]
    new_colors = colors[vertex_ids]
    new_faces = oldtonew[new_faces].astype('int32')
    return (new_verts, new_faces, new_colors, bool_faces, vertex_ids)
@@ -457,6 +457,16 @@ TASK_OUTPUTS = {
    # }
    Tasks.face_reconstruction: [OutputKeys.OUTPUT],

    # 3D human reconstruction result for single sample
    # {
    #   "output": {
    #       "vertices": np.array with shape(n, 3),
    #       "faces": np.array with shape(n, 3),
    #       "colors": np.array with shape(n, 3),
    #   }
    # }
    Tasks.human_reconstruction: [OutputKeys.OUTPUT],

    # 2D hand keypoints result for single sample
    # {
    #   "keypoints": [
@@ -316,6 +316,8 @@ TASK_INPUTS = {
    },
    Tasks.action_detection:
    InputType.VIDEO,
    Tasks.human_reconstruction:
    InputType.IMAGE,
    Tasks.image_reid_person:
    InputType.IMAGE,
    Tasks.video_inpainting: {
109
modelscope/pipelines/cv/human_reconstruction_pipeline.py
Normal file
@@ -0,0 +1,109 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
from typing import Any, Dict

import numpy as np
import torch
import trimesh

from modelscope.metainfo import Pipelines
from modelscope.models.cv.human_reconstruction.utils import (
    keep_largest, reconstruction, save_obj_mesh, save_obj_mesh_with_color,
    to_tensor)
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.human_reconstruction, module_name=Pipelines.human_reconstruction)
class HumanReconstructionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """The inference pipeline for the human reconstruction task.
        Given a single image, the pipeline generates a human mesh.

        Args:
            model (`str` or `Model` or module instance): A model instance or a model local dir
                or a model id in the model hub.

        Example:
            >>> from modelscope.pipelines import pipeline
            >>> test_input = 'human_reconstruction.jpg'  # input image path
            >>> pipeline_humanRecon = pipeline('human-reconstruction',
                    model='damo/cv_hrnet_image-human-reconstruction')
            >>> result = pipeline_humanRecon(test_input)
            >>> output = result[OutputKeys.OUTPUT]
        """
        super().__init__(model=model, **kwargs)
        if not isinstance(self.model, Model):
            logger.error('model object is not initialized.')
            raise Exception('model object is not initialized.')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img_crop = self.model.crop_img(input)
        img, mask = self.model.get_mask(img_crop)
        normal_f, normal_b = self.model.generation_normal(img, mask)
        image = to_tensor(img_crop) * 2 - 1
        normal_b = to_tensor(normal_b) * 2 - 1
        normal_f = to_tensor(normal_f) * 2 - 1
        mask = to_tensor(mask)
        result = {
            'img': image,
            'mask': mask,
            'normal_F': normal_f,
            'normal_B': normal_b
        }
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        image = input['img']
        mask = input['mask']
        normF = input['normal_F']
        normB = input['normal_B']
        normF[1, ...] = -normF[1, ...]
        normB[0, ...] = -normB[0, ...]
        img = image * mask
        normal_b = normB * mask
        normal_f = normF * mask
        img = torch.cat([img, normal_f, normal_b], dim=0).float()
        image_tensor = img.unsqueeze(0).to(self.model.device)
        calib_tensor = self.model.calib
        net = self.model.meshmodel
        net.extract_features(image_tensor)
        verts, faces = reconstruction(net, calib_tensor, self.model.coords,
                                      self.model.mat)
        pre_mesh = trimesh.Trimesh(
            verts, faces, process=False, maintain_order=True)
        final_mesh = keep_largest(pre_mesh)
        verts = final_mesh.vertices
        faces = final_mesh.faces
        verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to(
            self.model.device).float()
        color = torch.zeros(verts.shape)
        interval = 20000
        for i in range(len(color) // interval):
            left = i * interval
            right = i * interval + interval
            if i == len(color) // interval - 1:
                right = -1
            pred_color = net.query_rgb(verts_tensor[:, :, left:right],
                                       calib_tensor)
            rgb = pred_color[0].detach().cpu() * 0.5 + 0.5
            color[left:right] = rgb.T
        vert_min = np.min(verts[:, 1])
        verts[:, 1] = verts[:, 1] - vert_min
        save_obj_mesh('human_reconstruction.obj', verts, faces)
        save_obj_mesh_with_color('human_color.obj', verts, faces,
                                 color.numpy())
        results = {'vertices': verts, 'faces': faces, 'colors': color.numpy()}
        return {OutputKeys.OUTPUT: results}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
@@ -141,6 +141,9 @@ class CVTasks(object):
    # 3d face reconstruction
    face_reconstruction = 'face-reconstruction'

    # 3d human reconstruction
    human_reconstruction = 'human-reconstruction'

    # image quality assessment mos
    image_quality_assessment_mos = 'image-quality-assessment-mos'
    # motion generation
@@ -60,6 +60,7 @@ torchmetrics>=0.6.2
torchsummary>=1.5.1
torchvision
transformers>=4.26.0
trimesh
ujson
utils
videofeatures_clipit>=1.0
46
tests/pipelines/test_human_reconstruction.py
Normal file
@@ -0,0 +1,46 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import sys
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

sys.path.append('.')


class HumanReconstructionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.task = Tasks.human_reconstruction
        self.model_id = 'damo/cv_hrnet_image-human-reconstruction'
        self.test_image = 'data/test/images/human_reconstruction.jpg'

    def pipeline_inference(self, pipeline: Pipeline, input_location: str):
        result = pipeline(input_location)
        mesh = result[OutputKeys.OUTPUT]
        print(
            f'Output to {osp.abspath("human_reconstruction.obj")}, vertices num: {mesh["vertices"].shape}'
        )

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        model_dir = snapshot_download(self.model_id)
        human_reconstruction = pipeline(
            Tasks.human_reconstruction, model=model_dir)
        print('running')
        self.pipeline_inference(human_reconstruction, self.test_image)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        human_reconstruction = pipeline(
            Tasks.human_reconstruction, model=self.model_id)
        self.pipeline_inference(human_reconstruction, self.test_image)


if __name__ == '__main__':
    unittest.main()