mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
Code review: human image generation.
This CR adds human-image-generation code that, given a portrait image and the corresponding target-pose data, synthesizes an image of that person in the target pose. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13715612
This commit is contained in:
Submodule data/test updated: b648024203...ff64feb02b
@@ -122,6 +122,7 @@ class Models(object):
|
||||
fastinst = 'fastinst'
|
||||
pedestrian_attribute_recognition = 'pedestrian-attribute-recognition'
|
||||
image_try_on = 'image-try-on'
|
||||
human_image_generation = 'human-image-generation'
|
||||
|
||||
# nlp models
|
||||
bert = 'bert'
|
||||
@@ -427,6 +428,7 @@ class Pipelines(object):
|
||||
pedestrian_attribute_recognition = 'resnet50_pedestrian-attribute-recognition_image'
|
||||
text_to_360panorama_image = 'text-to-360panorama-image'
|
||||
image_try_on = 'image-try-on'
|
||||
human_image_generation = 'human-image-generation'
|
||||
|
||||
# nlp tasks
|
||||
automatic_post_editing = 'automatic-post-editing'
|
||||
@@ -877,7 +879,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Pipelines.text_to_360panorama_image,
|
||||
'damo/cv_diffusion_text-to-360panorama-image_generation'),
|
||||
Tasks.image_try_on: (Pipelines.image_try_on,
|
||||
'damo/cv_SAL-VTON_virtual-try-on')
|
||||
'damo/cv_SAL-VTON_virtual-try-on'),
|
||||
Tasks.human_image_generation: (Pipelines.human_image_generation,
|
||||
'damo/cv_FreqHPT_human-image-generation')
|
||||
}
|
||||
|
||||
|
||||
|
||||
22
modelscope/models/cv/human_image_generation/__init__.py
Normal file
22
modelscope/models/cv/human_image_generation/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
"""Lazy-loading package init for the human image generation model.

Defers importing the heavy model module until an attribute is first
accessed, following the standard ModelScope LazyImportModule pattern.
"""
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    # Static type checkers see the real import.
    from .human_image_generation_infer import FreqHPTForHumanImageGeneration

else:
    # Map submodule name -> public symbols for lazy resolution at runtime.
    _import_structure = {
        'human_image_generation_infer': ['FreqHPTForHumanImageGeneration']
    }

    import sys

    # Replace this module object with a lazy proxy so the submodule is only
    # imported on first attribute access.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
|
||||
@@ -0,0 +1,717 @@
|
||||
import collections
|
||||
import math
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from pytorch_wavelets import DWTForward, DWTInverse
|
||||
from torch import kl_div, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from modelscope.ops.human_image_generation.fused_act import (FusedLeakyReLU,
|
||||
fused_leaky_relu)
|
||||
from modelscope.ops.human_image_generation.upfirdn2d import upfirdn2d
|
||||
from .conv2d_gradfix import conv2d, conv_transpose2d
|
||||
from .wavelet_module import *
|
||||
|
||||
|
||||
# add flow
class ExtractionOperation_flow(nn.Module):
    """Extracts `num_label` neural textures from a feature map (NTED-style).

    A semantic filter produces per-label attention over spatial positions;
    the value projection is aggregated under that attention to yield one
    texture vector per semantic label. Results are also pushed into the
    shared `recoder` dict for later reuse/visualization.
    """

    def __init__(self, in_channel, num_label, match_kernel):
        super(ExtractionOperation_flow, self).__init__()
        # Value projection: keeps channel count, kernel size `match_kernel`.
        self.value_conv = EqualConv2d(
            in_channel,
            in_channel,
            match_kernel,
            1,
            match_kernel // 2,
            bias=True)
        # Produces one attention logit map per semantic label.
        self.semantic_extraction_filter = EqualConv2d(
            in_channel,
            num_label,
            match_kernel,
            1,
            match_kernel // 2,
            bias=False)

        # Softmax over the flattened spatial axis (h*w).
        self.softmax = nn.Softmax(dim=-1)
        self.num_label = num_label

    def forward(self, value, recoder):
        """Return (neural_textures [b, c, num_label], extraction_softmax).

        Side effect: prepends both tensors to the corresponding lists in
        `recoder` (expects keys 'extraction_softmax' and 'neural_textures').
        """
        key = value
        b, c, h, w = value.shape
        # Channel-normalized features before computing attention logits.
        key = self.semantic_extraction_filter(self.feature_norm(key))
        extraction_softmax = self.softmax(key.view(b, -1, h * w))
        values_flatten = self.value_conv(value).view(b, -1, h * w)
        # Weighted sum of values per label: [b, k, m] x [b, v, m] -> [b, v, k].
        neural_textures = torch.einsum('bkm,bvm->bvk', extraction_softmax,
                                       values_flatten)
        recoder['extraction_softmax'].insert(0, extraction_softmax)
        recoder['neural_textures'].insert(0, neural_textures)
        return neural_textures, extraction_softmax

    def feature_norm(self, input_tensor):
        """L2-normalize features across the channel dim after mean-centering."""
        input_tensor = input_tensor - input_tensor.mean(dim=1, keepdim=True)
        # Epsilon guards against division by zero for all-zero positions.
        norm = torch.norm(
            input_tensor, 2, 1, keepdim=True) + sys.float_info.epsilon
        out = torch.div(input_tensor, norm)
        return out
||||
|
||||
|
||||
class DistributionOperation_flow(nn.Module):
    """Distributes extracted neural textures back onto a query feature map.

    The query is mapped to per-label logits; a softmax over labels at each
    position mixes the label-indexed texture vectors back into a spatial
    feature map (inverse of ExtractionOperation_flow).
    """

    def __init__(self, num_label, input_dim, match_kernel=3):
        super(DistributionOperation_flow, self).__init__()
        # Per-position logits over the `num_label` semantic labels.
        self.semantic_distribution_filter = EqualConv2d(
            input_dim,
            num_label,
            kernel_size=match_kernel,
            stride=1,
            padding=match_kernel // 2)
        self.num_label = num_label

    def forward(self, query, extracted_feature, recoder):
        """Return [b, v, h, w] attention output.

        Args:
            query: [b, c, h, w] decoder features.
            extracted_feature: [b, v, k] neural textures from extraction.
            recoder: dict; the raw label logits are appended under
                'semantic_distribution' as a side effect.
        """
        b, c, h, w = query.shape

        query = self.semantic_distribution_filter(query)
        query_flatten = query.view(b, self.num_label, -1)
        # Softmax over labels (dim=1), independently at each position.
        query_softmax = F.softmax(query_flatten, 1)
        values_q = torch.einsum('bkm,bkv->bvm', query_softmax,
                                extracted_feature.permute(0, 2, 1))
        attn_out = values_q.view(b, -1, h, w)
        recoder['semantic_distribution'].append(query)
        return attn_out
||||
|
||||
|
||||
class EncoderLayer_flow(nn.Sequential):
    """One encoder stage: (optional blur + stride-2) conv + FusedLeakyReLU,
    optionally followed by semantic extraction heads.

    When `use_extraction` is set, `num_extractions` ExtractionOperation_flow
    modules run on the output for their side effects on `recoder`; their
    return values are discarded here.
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 downsample=False,
                 blur_kernel=[1, 3, 3, 1],
                 bias=True,
                 activate=True,
                 use_extraction=False,
                 num_label=None,
                 match_kernel=None,
                 num_extractions=2):
        super().__init__()

        if downsample:
            # StyleGAN2-style anti-aliased downsampling: blur, then stride-2 conv.
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2
            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

            stride = 2
            padding = 0

        else:
            self.blur = None
            stride = 1
            padding = kernel_size // 2

        self.conv = EqualConv2d(
            in_channel,
            out_channel,
            kernel_size,
            padding=padding,
            stride=stride,
            # Bias is folded into FusedLeakyReLU when activation is used.
            bias=bias and not activate,
        )

        self.activate = FusedLeakyReLU(
            out_channel, bias=bias) if activate else None
        self.use_extraction = use_extraction
        if self.use_extraction:
            self.extraction_operations = nn.ModuleList()
            for _ in range(num_extractions):
                self.extraction_operations.append(
                    ExtractionOperation_flow(out_channel, num_label,
                                             match_kernel))

    def forward(self, input, recoder=None):
        """Return the convolved (and activated) features.

        Extraction heads only write into `recoder`; they do not alter `out`.
        """
        out = self.blur(input) if self.blur is not None else input
        out = self.conv(out)
        out = self.activate(out) if self.activate is not None else out
        if self.use_extraction:
            for extraction_operation in self.extraction_operations:
                extraction_operation(out, recoder)
        return out
||||
|
||||
|
||||
class DecoderLayer_flow_wavelet_fuse24(nn.Module):
    """Decoder stage fusing NTED attention, warped source features and the
    generator path in the wavelet domain, followed by a Fourier refinement.

    # add fft refinement and tps
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
        use_distribution=True,
        num_label=16,
        match_kernel=3,
        wavelet_down_level=False,
        window_size=8,
    ):
        # NOTE(review): `window_size` is accepted but never read in this
        # class — confirm it is intentionally unused.
        super().__init__()
        if upsample:
            # Transposed conv x2 followed by a compensating blur.
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(
                blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
            self.conv = EqualTransposeConv2d(
                in_channel,
                out_channel,
                kernel_size,
                stride=2,
                padding=0,
                bias=bias and not activate,
            )
        else:
            self.conv = EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                stride=1,
                padding=kernel_size // 2,
                bias=bias and not activate,
            )
            self.blur = None

        self.distribution_operation = DistributionOperation_flow(
            num_label, out_channel,
            match_kernel=match_kernel) if use_distribution else None
        self.activate = FusedLeakyReLU(
            out_channel, bias=bias) if activate else None
        self.use_distribution = use_distribution

        # mask prediction network
        if use_distribution:
            # Low-frequency fusion mask (single channel, sigmoid).
            self.conv_mask_lf = nn.Sequential(*[
                EqualConv2d(
                    out_channel, 1, 3, stride=1, padding=3 // 2, bias=False),
                nn.Sigmoid()
            ])
            # One high-frequency mask head per wavelet level.
            self.conv_mask_dict = nn.ModuleDict()
            for level in range(wavelet_down_level):
                conv_mask = nn.Sequential(*[
                    EqualConv2d(
                        out_channel,
                        1,
                        3,
                        stride=1,
                        padding=3 // 2,
                        bias=False),
                    nn.Sigmoid()
                ])
                self.conv_mask_dict[str(level)] = conv_mask

        self.wavelet_down_level = wavelet_down_level
        if wavelet_down_level:
            # Haar DWT/IDWT with `wavelet_down_level` decomposition levels.
            self.dwt = DWTForward(
                J=self.wavelet_down_level, mode='zero', wave='haar')
            self.idwt = DWTInverse(mode='zero', wave='haar')

        # for mask input channel squeeze and expand
        self.conv_l_squeeze = EqualConv2d(
            2 * out_channel, out_channel, 1, 1, 0, bias=False)
        self.conv_h_squeeze = EqualConv2d(
            6 * out_channel, out_channel, 1, 1, 0, bias=False)

        # NOTE(review): `conv_l` and `hf_modules` below are constructed but
        # never referenced in forward() — confirm they are dead weight.
        self.conv_l = EqualConv2d(
            out_channel, out_channel, 3, 1, 3 // 2, bias=False)

        self.hf_modules = nn.ModuleDict()
        for level in range(wavelet_down_level):
            hf_module = nn.Module()
            prev_channel = out_channel if level == self.wavelet_down_level - 1 else 3 * out_channel
            hf_module.conv_prev = EqualConv2d(
                prev_channel, 3 * out_channel, 3, 1, 3 // 2, bias=False)
            hf_module.conv_hf = GatedConv2dWithActivation(
                3 * out_channel, 3 * out_channel, 3, 1, 3 // 2, bias=False)
            hf_module.conv_out = GatedConv2dWithActivation(
                3 * out_channel, 3 * out_channel, 3, 1, 3 // 2, bias=False)
            self.hf_modules[str(level)] = hf_module

        # Fourier-domain fusion heads (amplitude / phase branches).
        self.amp_fuse = nn.Sequential(
            EqualConv2d(2 * out_channel, out_channel, 1, 1, 0),
            FusedLeakyReLU(out_channel, bias=False),
            EqualConv2d(out_channel, out_channel, 1, 1, 0))
        self.pha_fuse = nn.Sequential(
            EqualConv2d(2 * out_channel, out_channel, 1, 1, 0),
            FusedLeakyReLU(out_channel, bias=False),
            EqualConv2d(out_channel, out_channel, 1, 1, 0))
        self.post = EqualConv2d(out_channel, out_channel, 1, 1, 0)
        self.eps = 1e-8

    def forward(self,
                input,
                neural_texture=None,
                recoder=None,
                warped_texture=None,
                style_net=None,
                gstyle=None):
        """Return (out, mask_h, mask_l).

        mask_h / mask_l are the first-level high-frequency and the
        low-frequency fusion masks (None when distribution is skipped).
        """
        out = self.conv(input)
        out = self.blur(out) if self.blur is not None else out

        mask_l, mask_h = None, None
        out_attn = None
        if self.use_distribution and neural_texture is not None:
            out_ori = out
            out_attn = self.distribution_operation(out, neural_texture,
                                                   recoder)
            # wavelet fusion
            if self.wavelet_down_level:
                assert out.shape[2] % 2 == 0, \
                    f'out shape {out.shape} is not appropriate for processing'
                b, c, h, w = out.shape

                # wavelet decomposition
                LF_attn, HF_attn = self.dwt(out_attn)
                LF_warp, HF_warp = self.dwt(warped_texture)
                LF_out, HF_out = self.dwt(out)

                # generate mask
                hf_dict = {}
                l_mask_input = torch.cat([LF_attn, LF_warp], dim=1)
                l_mask_input = self.conv_l_squeeze(l_mask_input)
                l_mask_input = style_net(l_mask_input, gstyle)
                ml = self.conv_mask_lf(l_mask_input)
                mask_l = ml

                for level in range(self.wavelet_down_level):
                    # level up, feature size down
                    scale = 2**(level + 1)
                    hfa = HF_attn[level].view(b, c * 3, h // scale, w // scale)
                    hfw = HF_warp[level].view(b, c * 3, h // scale, w // scale)
                    hfg = HF_out[level].view(b, c * 3, h // scale, w // scale)

                    h_mask_input = torch.cat([hfa, hfw], dim=1)
                    h_mask_input = self.conv_h_squeeze(h_mask_input)
                    h_mask_input = style_net(h_mask_input, gstyle)
                    mh = self.conv_mask_dict[str(level)](h_mask_input)
                    if level == 0:
                        mask_h = mh

                    # fuse high frequency
                    xh = (mh * hfa + (1 - mh) * hfw + hfg) / math.sqrt(2)
                    hf_dict[str(level)] = xh

                # Low-frequency fusion: mask blends attention vs warp,
                # generator LF is always added.
                temp_result = (1 - ml) * LF_warp + LF_out
                out_l = (ml * LF_attn + temp_result) / math.sqrt(2)
                out_h_list = []
                for level in range(self.wavelet_down_level - 1, -1, -1):
                    xh = hf_dict[str(level)]
                    b, c, h, w = xh.shape
                    out_h_list.append(xh.view(b, c // 3, 3, h, w))
                out_h_list = (
                    out_h_list)[::-1]  # the h list from large to small size
                #
                out = self.idwt((out_l, out_h_list))
            else:
                out = (out + out_attn) / math.sqrt(2)

            # fourier refinement
            _, _, H, W = out.shape
            fuseF = torch.fft.rfft2(out + self.eps, norm='backward')
            outF = torch.fft.rfft2(out_ori + self.eps, norm='backward')
            amp = self.amp_fuse(
                torch.cat([torch.abs(fuseF), torch.abs(outF)], 1))
            pha = self.pha_fuse(
                torch.cat(
                    [torch.angle(fuseF), torch.angle(outF)], 1))
            out_fft = torch.fft.irfft2(
                amp * torch.exp(1j * pha) + self.eps,
                s=(H, W),
                dim=(-2, -1),
                norm='backward')

            out = out + self.post(out_fft)

        out = self.activate(
            out.contiguous()) if self.activate is not None else out
        return out, mask_h, mask_l
|
||||
|
||||
|
||||
# base functions


class EqualConv2d(nn.Module):
    """Conv2d with equalized learning rate (StyleGAN2).

    Weights are drawn from N(0, 1) and rescaled at runtime by
    1/sqrt(fan_in), so all layers share an effective learning rate.
    Uses the gradfix `conv2d` wrapper rather than F.conv2d directly.
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bias=True,
                 dilation=1):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        # Runtime scale implementing the equalized learning rate trick.
        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)

        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        out = conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation)

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )
|
||||
|
||||
|
||||
class EqualTransposeConv2d(nn.Module):
    """Transposed Conv2d with equalized learning rate (StyleGAN2-style).

    Weight is stored in (out, in, k, k) layout and transposed to the
    (in, out, k, k) layout expected by conv_transpose2d at call time.
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bias=True):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        # Equalized-LR scale: 1/sqrt(fan_in).
        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        # conv_transpose2d expects (in_channels, out_channels, kH, kW).
        weight = self.weight.transpose(0, 1)
        out = conv_transpose2d(
            input,
            weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )
|
||||
|
||||
|
||||
class ToRGB(nn.Module):
    """Projects features to a 3-channel RGB image and accumulates skips.

    NOTE(review): `self.upsample` only exists when constructed with
    upsample=True; calling forward with a non-None `skip` on an
    upsample=False instance would raise AttributeError — confirm callers
    never do that.
    """

    def __init__(self, in_channel, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)
        self.conv = EqualConv2d(in_channel, 3, 3, stride=1, padding=1)

    def forward(self, input, skip=None):
        """Return RGB(input), plus the upsampled previous skip if given."""
        out = self.conv(input)
        if skip is not None:
            skip = self.upsample(skip)
            out = out + skip
        return out
|
||||
|
||||
|
||||
class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate (StyleGAN2).

    Weights are sampled from N(0, 1) (divided by `lr_mul` at init) and
    rescaled by `scale = lr_mul / sqrt(in_dim)` at every forward pass;
    the bias is likewise multiplied by `lr_mul`.

    Args:
        in_dim: input feature dimension.
        out_dim: output feature dimension.
        bias: whether to learn an additive bias.
        bias_init: constant used to initialize the bias.
        lr_mul: learning-rate multiplier for this layer.
        activation: if truthy, applies fused_leaky_relu after the matmul.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 bias=True,
                 bias_init=0,
                 lr_mul=1,
                 activation=None):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        # Bug fix: the original evaluated `self.bias * self.lr_mul`
        # unconditionally, raising TypeError (None * int) whenever the
        # layer was built with bias=False. Compute a safe scaled bias once.
        scaled_bias = self.bias * self.lr_mul if self.bias is not None else None
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, scaled_bias)

        else:
            out = F.linear(input, self.weight * self.scale, bias=scaled_bias)

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )
|
||||
|
||||
|
||||
class Upsample(nn.Module):
    """Anti-aliased upsampling: zero-stuff by `factor`, then FIR low-pass."""

    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        # Scale the kernel so the upsampled signal keeps its overall energy.
        filt = make_kernel(kernel) * (factor**2)
        self.register_buffer('kernel', filt)

        # Asymmetric padding that yields exactly factor-times-larger output.
        excess = filt.shape[0] - factor
        self.pad = ((excess + 1) // 2 + factor - 1, excess // 2)

    def forward(self, input):
        """Return `input` upsampled by `self.factor` in both spatial dims."""
        return upfirdn2d(
            input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv layers plus a 1x1 skip projection.

    The two branches are summed and divided by sqrt(2) to keep the
    activation variance roughly constant across depth (StyleGAN2 style).
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 blur_kernel=[1, 3, 3, 1],
                 downsample=True):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(
            in_channel, out_channel, 3, downsample=downsample)

        # 1x1 projection so the shortcut matches channels/resolution.
        self.skip = ConvLayer(
            in_channel,
            out_channel,
            1,
            downsample=downsample,
            activate=False,
            bias=False)

    def forward(self, input):
        residual = self.conv2(self.conv1(input))
        shortcut = self.skip(input)
        # Variance-preserving merge of the two branches.
        return (residual + shortcut) / math.sqrt(2)
|
||||
|
||||
|
||||
class ConvLayer(nn.Sequential):
    """Conv building block: optional anti-alias blur + EqualConv2d
    (+ FusedLeakyReLU), assembled as an nn.Sequential.

    When downsampling, a Blur precedes a stride-2 conv with no padding;
    otherwise a stride-1 'same' conv is used.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                # Bias is folded into FusedLeakyReLU when activation is on.
                bias=bias and not activate,
            ))

        if activate:
            layers.append(FusedLeakyReLU(out_channel, bias=bias))

        # super().__init__ is called last so the collected layers form the
        # Sequential body (plain-int attributes set above are safe).
        super().__init__(*layers)
|
||||
|
||||
|
||||
class Blur(nn.Module):
    """FIR blur via upfirdn2d with a fixed, pre-normalized kernel."""

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        filt = make_kernel(kernel)
        if upsample_factor > 1:
            # Compensate the energy lost by zero-stuffed upsampling.
            filt = filt * (upsample_factor**2)

        self.register_buffer('kernel', filt)
        self.pad = pad

    def forward(self, input):
        """Apply the blur kernel with the configured (pad0, pad1) padding."""
        return upfirdn2d(input, self.kernel, pad=self.pad)
|
||||
|
||||
|
||||
class GatedConv2dWithActivation(nn.Module):
    """Gated convolution: output = act(conv(x)) * sigmoid(mask_conv(x)).

    NOTE(review): the `activation` constructor argument is accepted but
    ignored — self.activation is always a FusedLeakyReLU, so the
    `self.activation is None` branch in forward is unreachable as written.
    Confirm whether honoring the argument was intended.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=True,
                 activation=None):
        super(GatedConv2dWithActivation, self).__init__()
        self.activation = FusedLeakyReLU(out_channels, bias=False)
        # Feature branch.
        self.conv2d = EqualConv2d(in_channels, out_channels, kernel_size,
                                  stride, padding, bias, dilation)
        # Gate branch (same geometry), squashed through a sigmoid.
        self.mask_conv2d = EqualConv2d(in_channels, out_channels, kernel_size,
                                       stride, padding, bias, dilation)
        self.sigmoid = nn.Sigmoid()

    def gated(self, mask):
        """Map gate logits to (0, 1) multiplicative weights."""
        return self.sigmoid(mask)

    def forward(self, input):
        x = self.conv2d(input)
        mask = self.mask_conv2d(input)
        if self.activation is not None:
            x = self.activation(x) * self.gated(mask)
        else:
            x = x * self.gated(mask)

        return x
|
||||
|
||||
|
||||
def make_kernel(k):
    """Build a normalized 2-D FIR kernel from a 1-D or 2-D tap list.

    A 1-D input is promoted to its 2-D outer product; the result is
    normalized so the taps sum to 1.
    """
    kernel = torch.tensor(k, dtype=torch.float32)

    # Promote a 1-D tap vector to a separable 2-D kernel.
    if kernel.ndim == 1:
        kernel = kernel[None, :] * kernel[:, None]

    return kernel / kernel.sum()
|
||||
|
||||
|
||||
class SPDNorm(nn.Module):
    """SPADE-style spatially-adaptive denormalization.

    Normalizes `x` with a parameter-free norm (instance / batch /
    positional), then modulates it with per-pixel gamma/beta predicted
    from a conditioning feature map `prior_f`.
    """

    def __init__(self,
                 norm_channel,
                 label_nc,
                 norm_type='position',
                 use_equal=False):
        super().__init__()
        param_free_norm_type = norm_type
        ks = 3
        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(
                norm_channel, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_channel, affine=False)
        elif param_free_norm_type == 'position':
            # Plain function, not a module: normalizes across channels.
            self.param_free_norm = PositionalNorm2d
        else:
            raise ValueError(
                '%s is not a recognized param-free norm type in SPADE'
                % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        pw = ks // 2
        nhidden = 128
        if not use_equal:
            self.mlp_activate = nn.Sequential(
                nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
                nn.ReLU())
            self.mlp_gamma = nn.Conv2d(
                nhidden, norm_channel, kernel_size=ks, padding=pw)
            self.mlp_beta = nn.Conv2d(
                nhidden, norm_channel, kernel_size=ks, padding=pw)
        else:
            # Equalized-LR variant of the same embedding MLP.
            self.mlp_activate = nn.Sequential(*[
                EqualConv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
                FusedLeakyReLU(nhidden, bias=False)
            ])
            self.mlp_gamma = EqualConv2d(
                nhidden, norm_channel, kernel_size=ks, padding=pw)
            self.mlp_beta = EqualConv2d(
                nhidden, norm_channel, kernel_size=ks, padding=pw)

    def forward(self, x, prior_f, weight=1.0):
        """Return x normalized then modulated by gamma/beta from prior_f.

        `weight` linearly scales the modulation strength (1.0 = full).
        """
        normalized = self.param_free_norm(x)
        # Part 2. produce scaling and bias conditioned on condition feature
        actv = self.mlp_activate(prior_f)
        gamma = self.mlp_gamma(actv) * weight
        beta = self.mlp_beta(actv) * weight
        # apply scale and bias
        out = normalized * (1 + gamma) + beta
        return out
|
||||
|
||||
|
||||
def PositionalNorm2d(x, epsilon=1e-5):
    """Normalize each spatial position of a B*C*H*W tensor across channels."""
    mu = x.mean(dim=1, keepdim=True)
    # Unbiased variance across C, stabilized by epsilon before the sqrt.
    sigma = (x.var(dim=1, keepdim=True) + epsilon).sqrt()
    return (x - mu) / sigma
|
||||
@@ -0,0 +1,358 @@
|
||||
import collections
|
||||
import functools
|
||||
import math
|
||||
from tkinter.ttk import Style
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base_function import *
|
||||
from .flow_module import MaskStyle, StyleFlow
|
||||
from .tps import TPS
|
||||
|
||||
|
||||
# adding flow version
class Encoder_wiflow(nn.Module):
    """Pyramid encoder built from EncoderLayer_flow stages.

    Downsamples from `size` to 16x16 (stops when 2**i reaches 16), enabling
    semantic extraction at resolutions where both `num_labels[res]` and
    `match_kernels[res]` are truthy.
    """

    def __init__(
        self,
        size,
        input_dim,
        channels,
        num_labels=None,
        match_kernels=None,
        blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()
        # 1x1 stem conv to the base channel width.
        self.first = EncoderLayer_flow(input_dim, channels[size], 1)
        self.convs = nn.ModuleList()
        self.num_labels = num_labels
        self.match_kernels = match_kernels

        log_size = int(math.log(size, 2))
        self.log_size = log_size

        in_channel = channels[size]
        # One downsampling stage per octave, from size/2 down to 16.
        for i in range(log_size - 1, 3, -1):
            out_channel = channels[2**i]
            num_label = num_labels[2**i] if num_labels is not None else None
            match_kernel = match_kernels[
                2**i] if match_kernels is not None else None
            use_extraction = num_label and match_kernel
            conv = EncoderLayer_flow(
                in_channel,
                out_channel,
                kernel_size=3,
                downsample=True,
                blur_kernel=blur_kernel,
                use_extraction=use_extraction,
                num_label=num_label,
                match_kernel=match_kernel)

            self.convs.append(conv)
            in_channel = out_channel

    def forward(self, input, recoder=None, out_list=None):
        """Return the coarsest feature map.

        If `out_list` is provided, the per-stage outputs are appended to it
        (fine -> coarse) as a side effect; `recoder` collects extraction
        results inside each stage.
        """
        out = self.first(input)
        for layer in self.convs:
            out = layer(out, recoder)
            if out_list is not None:
                out_list.append(out)
        return out
|
||||
|
||||
|
||||
class Decoder_wiflow_wavelet_fuse25(nn.Module):
    """Decoder combining StyleFlow warping, TPS initialization, NTED
    attention and wavelet/Fourier fusion (via DecoderLayer_flow_wavelet_fuse24).

    Upsamples from 16x16 to `size`, producing an RGB image plus flow
    diagnostics (smoothness deltas, per-scale flows, fusion masks).
    """

    def __init__(
        self,
        size,
        channels,
        num_labels,
        match_kernels,
        blur_kernel=[1, 3, 3, 1],
        wavelet_down_levels={'16': 3},
        window_size=8,
    ):
        # NOTE(review): the default `wavelet_down_levels` uses the string
        # key '16', but the lookup below indexes with the int `2**i` —
        # the default would raise KeyError; confirm callers always pass
        # int-keyed dicts covering every decoder resolution.
        super().__init__()

        self.convs = nn.ModuleList()
        # input at resolution 16*16
        in_channel = channels[16]
        self.log_size = int(math.log(size, 2))
        self.conv_mask_dict = nn.ModuleDict()
        self.conv_mask_fuse_dict = nn.ModuleDict()

        # Build conv_flow_fusion only once, at the first distribution scale.
        flow_fusion = False

        for i in range(4, self.log_size + 1):
            out_channel = channels[2**i]
            num_label, match_kernel = num_labels[2**i], match_kernels[2**i]
            use_distribution = num_label and match_kernel
            upsample = (i != 4)
            wavelet_down_level = wavelet_down_levels[(2**i)]
            base_layer = functools.partial(
                DecoderLayer_flow_wavelet_fuse24,
                out_channel=out_channel,
                kernel_size=3,
                blur_kernel=blur_kernel,
                use_distribution=use_distribution,
                num_label=num_label,
                match_kernel=match_kernel,
                wavelet_down_level=wavelet_down_level,
                window_size=window_size)
            # mask head for fusion
            if use_distribution:
                conv_mask = [
                    EqualConv2d(
                        2 * out_channel,
                        3,
                        3,
                        stride=1,
                        padding=3 // 2,
                        bias=False),
                    nn.Sigmoid()
                ]
                conv_mask = nn.Sequential(*conv_mask)
                self.conv_mask_dict[str(2**i)] = conv_mask

                if not i == 4:
                    conv_mask_fuse = nn.Sequential(*[
                        EqualConv2d(
                            2, 1, 3, stride=1, padding=3 // 2, bias=False),
                        nn.Sigmoid()
                    ])
                    self.conv_mask_fuse_dict[str(2**i)] = conv_mask_fuse

                if not flow_fusion:
                    # Predicts the blend weight between TPS and dense flows.
                    self.conv_flow_fusion = nn.Sequential(
                        EqualConv2d(
                            2 * out_channel,
                            1,
                            kernel_size=7,
                            stride=1,
                            padding=3,
                            bias=False), nn.Sigmoid())
                    flow_fusion = True

            # One decoder octave: two fuse layers + RGB head.
            up = nn.Module()
            up.conv0 = base_layer(in_channel=in_channel, upsample=upsample)
            up.conv1 = base_layer(in_channel=out_channel, upsample=False)
            up.to_rgb = ToRGB(out_channel, upsample=upsample)
            self.convs.append(up)
            in_channel = out_channel

        # Global style vectors from the coarsest source / skeleton features.
        style_in_channels = channels[16]
        self.style_out_channel = 128
        self.cond_style = nn.Sequential(
            nn.Conv2d(
                style_in_channels,
                self.style_out_channel,
                kernel_size=3,
                stride=1,
                padding=1), nn.LeakyReLU(inplace=False, negative_slope=0.1),
            nn.AdaptiveAvgPool2d(1))
        self.image_style = nn.Sequential(
            nn.Conv2d(
                style_in_channels,
                self.style_out_channel,
                kernel_size=3,
                stride=1,
                padding=1), nn.LeakyReLU(inplace=False, negative_slope=0.1),
            nn.AdaptiveAvgPool2d(1))
        self.flow_model = StyleFlow(
            channels, self.log_size, style_in=2 * self.style_out_channel)

        self.num_labels, self.match_kernels = num_labels, match_kernels

        # for mask prediction
        self.mask_style = MaskStyle(
            channels,
            self.log_size,
            style_in=2 * self.style_out_channel,
            channels_multiplier=1)

        # tps transformation
        self.tps = TPS()

    def forward(self,
                input,
                neural_textures,
                skeleton_features,
                source_features,
                kp_skeleton,
                recoder,
                add_nted=True):
        """Return (image, delta_x_all, delta_y_all, delta_list,
        last_flow_all, mask_all_h, mask_all_l).

        Feature lists arrive fine->coarse and are reversed to match the
        coarse->fine decoding order.
        """
        source_features = source_features[::-1]
        skeleton_features = skeleton_features[::-1]

        counter = 0
        out, skip = input, None

        last_flow = None
        mask_all_h, mask_all_l = [], []
        delta_list = []
        delta_x_all = []
        delta_y_all = []
        last_flow_all = []
        # Second-derivative filters used for flow smoothness diagnostics.
        filter_x = [[0, 0, 0], [1, -2, 1], [0, 0, 0]]
        filter_y = [[0, 1, 0], [0, -2, 0], [0, 1, 0]]
        filter_diag1 = [[1, 0, 0], [0, -2, 0], [0, 0, 1]]
        filter_diag2 = [[0, 0, 1], [0, -2, 0], [1, 0, 0]]
        weight_array = np.ones([3, 3, 1, 4])
        weight_array[:, :, 0, 0] = filter_x
        weight_array[:, :, 0, 1] = filter_y
        weight_array[:, :, 0, 2] = filter_diag1
        weight_array[:, :, 0, 3] = filter_diag2
        weight_array = torch.FloatTensor(weight_array).permute(3, 2, 0, 1).to(
            input.device)
        # NOTE(review): registering an nn.Parameter inside forward() on
        # every call is unusual (it mutates module state per forward);
        # a fixed buffer created in __init__ would be cleaner — confirm.
        self.weight = nn.Parameter(data=weight_array, requires_grad=False)

        B = source_features[0].shape[0]
        source_style = self.cond_style(source_features[0]).view(B, -1)
        target_style = self.image_style(skeleton_features[0]).view(B, -1)
        style = torch.cat([source_style, target_style], 1)

        for i, up in enumerate(self.convs):
            use_distribution = (
                self.num_labels[2**(i + 4)] and self.match_kernels[2**(i + 4)])
            if use_distribution:
                # warp features with styleflow
                source_feature = source_features[i]
                skeleton_feature = skeleton_features[i]
                if last_flow is not None:
                    last_flow = F.interpolate(
                        last_flow, scale_factor=2, mode='bilinear')
                    s_warp_after = F.grid_sample(
                        source_feature,
                        last_flow.detach().permute(0, 2, 3, 1),
                        mode='bilinear',
                        padding_mode='border')
                else:
                    s_warp_after = source_feature
                scale = str(2**(i + 4))

                # use tps transformation to estimate flow at the very beginning
                if last_flow is not None:
                    style_map = self.flow_model.netStyle[scale](s_warp_after,
                                                                style)
                    flow = self.flow_model.netF[scale](style_map, style)
                    flow = apply_offset(flow)

                else:
                    # First scale: blend dense flow with a TPS-derived flow.
                    style_map = self.flow_model.netStyle[scale](s_warp_after,
                                                                style)
                    flow = self.flow_model.netF[scale](style_map, style)
                    flow_dense = apply_offset(flow)
                    flow_tps = self.tps(source_feature, kp_skeleton)
                    warped_dense = F.grid_sample(
                        source_feature,
                        flow_dense,
                        mode='bilinear',
                        padding_mode='border')
                    warped_tps = F.grid_sample(
                        source_feature,
                        flow_tps,
                        mode='bilinear',
                        padding_mode='border')
                    contribution_map = self.conv_flow_fusion(
                        torch.cat([warped_dense, warped_tps], 1))
                    flow = contribution_map * flow_tps.permute(0, 3, 1, 2) + (
                        1 - contribution_map) * flow_dense.permute(0, 3, 1, 2)
                    flow = flow.permute(0, 2, 3, 1).contiguous()

                if last_flow is not None:
                    # update flow according to the last scale flow
                    flow = F.grid_sample(
                        last_flow,
                        flow,
                        mode='bilinear',
                        padding_mode='border')
                else:
                    flow = flow.permute(0, 3, 1, 2)

                last_flow = flow
                s_warp = F.grid_sample(
                    source_feature,
                    flow.permute(0, 2, 3, 1),
                    mode='bilinear',
                    padding_mode='border')

                # refine flow according to the original flow
                flow = self.flow_model.netRefine[scale](
                    torch.cat([s_warp, skeleton_feature], 1))

                delta_list.append(flow)
                flow = apply_offset(flow)
                flow = F.grid_sample(
                    last_flow, flow, mode='bilinear', padding_mode='border')
                last_flow_all.append(flow)

                last_flow = flow
                flow_x, flow_y = torch.split(last_flow, 1, dim=1)
                # Second-derivative responses of the flow for regularization.
                delta_x = F.conv2d(flow_x, self.weight)
                delta_y = F.conv2d(flow_y, self.weight)
                delta_x_all.append(delta_x)
                delta_y_all.append(delta_y)

                s_warp = F.grid_sample(
                    source_feature,
                    last_flow.permute(0, 2, 3, 1),
                    mode='bilinear',
                    padding_mode='border')

                # nted attention
                neural_texture_conv0 = neural_textures[counter]
                neural_texture_conv1 = neural_textures[counter + 1]
                counter += 2

                # NOTE(review): both branches below null out the textures,
                # and the else-branch also drops the warped features —
                # this looks like leftover ablation code; confirm intent.
                if not add_nted:  # turn off the nted attention
                    neural_texture_conv0, neural_texture_conv1 = None, None
                else:
                    neural_texture_conv0, neural_texture_conv1 = None, None
                    s_warp = None

            # NOTE(review): `scale`, `s_warp` and the texture variables are
            # only (re)bound when use_distribution is true — a scale without
            # distribution reuses values from the previous iteration (or
            # fails on the first); confirm every decoder scale enables it.
            mask_style_net = self.mask_style.netM[
                scale] if use_distribution else None
            out, mask_h, mask_l = up.conv0(
                out,
                neural_texture=neural_texture_conv0,
                recoder=recoder,
                warped_texture=s_warp,
                style_net=mask_style_net,
                gstyle=style)
            out, mask_h, mask_l = up.conv1(
                out,
                neural_texture=neural_texture_conv1,
                recoder=recoder,
                warped_texture=s_warp,
                style_net=mask_style_net,
                gstyle=style)
            if use_distribution:
                if mask_h is not None:
                    mask_all_h.append(mask_h)
                if mask_l is not None:
                    mask_all_l.append(mask_l)
            skip = up.to_rgb(out, skip)

        image = skip
        return image, delta_x_all, delta_y_all, delta_list, last_flow_all, mask_all_h, mask_all_l
|
||||
|
||||
|
||||
def apply_offset(offset):
    """Convert a dense (B, 2, H, W) offset field into a normalized sampling
    grid of shape (B, H, W, 2) suitable for ``F.grid_sample``.

    Channel 0 of ``offset`` displaces the x (width) coordinate and channel 1
    the y (height) coordinate; output coordinates are scaled to [-1, 1].
    """
    spatial = list(offset.size()[2:])
    base_grids = torch.meshgrid(
        [torch.arange(n, device=offset.device) for n in spatial])
    # meshgrid yields (y, x); reverse so index 0 pairs with offset channel 0 (x).
    shifted = []
    for axis, base in enumerate(reversed(base_grids)):
        shifted.append(base.float().unsqueeze(0) + offset[:, axis, ...])
    # Map absolute pixel coordinates into the [-1, 1] range grid_sample expects.
    normalized = []
    for coords, extent in zip(shifted, reversed(spatial)):
        normalized.append(coords / ((extent - 1.0) / 2.0) - 1.0)

    return torch.stack(normalized, dim=-1)
|
||||
@@ -0,0 +1,227 @@
|
||||
import contextlib
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import autograd
|
||||
from torch.nn import functional as F
|
||||
|
||||
# Global switches for the gradient-fixed conv ops.
enabled = True
weight_gradients_disabled = False


@contextlib.contextmanager
def no_weight_gradients():
    """Temporarily disable weight-gradient computation in the custom conv ops.

    BUGFIX: the previous version did not restore the flag when the wrapped
    code raised, leaving weight gradients permanently disabled; the restore
    now lives in a ``finally`` clause.
    """
    global weight_gradients_disabled

    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        weight_gradients_disabled = old
|
||||
|
||||
|
||||
def conv2d(input,
           weight,
           bias=None,
           stride=1,
           padding=0,
           dilation=1,
           groups=1):
    """Drop-in replacement for ``F.conv2d`` that routes through the
    gradient-fixed custom op when the environment supports it."""
    if not could_use_op(input):
        # Unsupported environment: defer to the stock implementation.
        return F.conv2d(
            input=input,
            weight=weight,
            bias=bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )

    op = conv2d_gradfix(
        transpose=False,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=0,
        dilation=dilation,
        groups=groups,
    )
    return op.apply(input, weight, bias)
|
||||
|
||||
|
||||
def conv_transpose2d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    """Drop-in replacement for ``F.conv_transpose2d`` that routes through the
    gradient-fixed custom op when the environment supports it."""
    if not could_use_op(input):
        # Unsupported environment: defer to the stock implementation.
        return F.conv_transpose2d(
            input=input,
            weight=weight,
            bias=bias,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
        )

    op = conv2d_gradfix(
        transpose=True,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation,
    )
    return op.apply(input, weight, bias)
|
||||
|
||||
|
||||
def could_use_op(input):
    """Return True when the custom gradient-fixed conv op can handle ``input``.

    Requires the module-level ``enabled`` switch, cuDNN, a CUDA tensor, and a
    PyTorch version the op was written against (1.7.x / 1.8.x).
    """
    usable = enabled and torch.backends.cudnn.enabled
    if not usable or input.device.type != 'cuda':
        return False

    supported_prefixes = ('1.7.', '1.8.')
    if torch.__version__.startswith(supported_prefixes):
        return True

    warnings.warn(
        f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().'
    )

    return False
|
||||
|
||||
|
||||
def ensure_tuple(xs, ndim):
    """Return ``xs`` as a tuple, broadcasting a scalar to length ``ndim``."""
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs, ) * ndim

    return xs


# Cache of generated autograd.Function classes, keyed by the full conv
# configuration, so each configuration is built only once.
conv2d_gradfix_cache = dict()


def conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding,
                   dilation, groups):
    """Build (and cache) an ``autograd.Function`` implementing conv2d or
    conv_transpose2d with custom gradient rules (StyleGAN2 "gradfix").

    Args:
        transpose: implement transposed-convolution semantics when True.
        weight_shape: full shape of the convolution weight tensor.
        stride / padding / output_padding / dilation: int or 2-sequence.
        groups: number of convolution groups.

    Returns:
        An ``autograd.Function`` subclass; call ``.apply(input, weight, bias)``.
    """
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = ensure_tuple(stride, ndim)
    padding = ensure_tuple(padding, ndim)
    output_padding = ensure_tuple(output_padding, ndim)
    dilation = ensure_tuple(dilation, ndim)

    key = (transpose, weight_shape, stride, padding, output_padding, dilation,
           groups)
    if key in conv2d_gradfix_cache:
        return conv2d_gradfix_cache[key]

    common_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, groups=groups)

    def calc_output_padding(input_shape, output_shape):
        """Output padding that makes the transposed conv used for the input
        gradient reproduce the exact forward input shape."""
        if transpose:
            return [0, 0]

        # BUGFIX: the per-axis terms previously referenced the comprehension
        # index ``i`` outside the comprehension, raising NameError the first
        # time an input gradient was computed. All terms now live inside it
        # (standard gradfix formula).
        return [
            input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    class Conv2d(autograd.Function):

        @staticmethod
        def forward(ctx, input, weight, bias):
            # Forward delegates to the stock functional ops.
            if not transpose:
                out = F.conv2d(
                    input=input, weight=weight, bias=bias, **common_kwargs)

            else:
                out = F.conv_transpose2d(
                    input=input,
                    weight=weight,
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )

            ctx.save_for_backward(input, weight)

            return out

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input, grad_weight, grad_bias = None, None, None

            if ctx.needs_input_grad[0]:
                # Input gradient of conv is a transposed conv (and vice
                # versa); output_padding recovers the exact input shape.
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape)
                grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, weight, None)

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum((0, 2, 3))

            return grad_input, grad_weight, grad_bias

    class Conv2dGradWeight(autograd.Function):

        @staticmethod
        def forward(ctx, grad_output, input):
            # NOTE(review): this internal cuDNN op was removed in newer
            # PyTorch releases; could_use_op gates usage to 1.7/1.8 — confirm.
            op = torch._C._jit_get_operation(
                'aten::cudnn_convolution_backward_weight' if not transpose else
                'aten::cudnn_convolution_transpose_backward_weight')
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32,
            ]
            grad_weight = op(
                weight_shape,
                grad_output,
                input,
                padding,
                stride,
                dilation,
                groups,
                *flags,
            )
            ctx.save_for_backward(grad_output, input)

            return grad_weight

        @staticmethod
        def backward(ctx, grad_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_grad_output, grad_grad_input = None, None

            if ctx.needs_input_grad[0]:
                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape)
                grad_grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, grad_grad_weight, None)

            return grad_grad_output, grad_grad_input

    conv2d_gradfix_cache[key] = Conv2d

    return Conv2d
|
||||
@@ -0,0 +1,64 @@
|
||||
import collections
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .base_module import *
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """FreqHPT generator: encodes a reference image and a target skeleton,
    then renders the target-pose image with a wavelet-fusion decoder.

    Args:
        size: output image resolution.
        semantic_dim: channel count of the skeleton/semantic input.
        channels: dict mapping resolution -> feature width.
        num_labels: dict mapping resolution -> number of neural-texture labels.
        match_kernels: dict mapping resolution -> matching kernel size.
        blur_kernel: FIR blur taps; defaults to [1, 3, 3, 1].
        wavelet_down_levels: dict mapping resolution -> wavelet depth;
            defaults to {'16': 3}.
        window_size: decoder attention window size.
    """

    def __init__(
        self,
        size,
        semantic_dim,
        channels,
        num_labels,
        match_kernels,
        blur_kernel=None,
        wavelet_down_levels=None,
        window_size=8,
    ):
        super().__init__()
        # BUGFIX(best practice): the defaults were mutable objects shared
        # across all instances; use None sentinels instead.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]
        if wavelet_down_levels is None:
            wavelet_down_levels = {'16': 3}

        self.size = size
        # Extracts multi-scale appearance features and neural textures.
        self.reference_encoder = Encoder_wiflow(size, 3, channels, num_labels,
                                                match_kernels, blur_kernel)

        # Extracts multi-scale pose features from the skeleton map.
        self.skeleton_encoder = Encoder_wiflow(
            size,
            semantic_dim,
            channels,
        )

        # Renders the output image from pose features plus warped textures.
        self.target_image_renderer = Decoder_wiflow_wavelet_fuse25(
            size, channels, num_labels, match_kernels, blur_kernel,
            wavelet_down_levels, window_size)

    def _cal_temp(self, module):
        """Count the trainable parameters of ``module`` (debug helper)."""
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    def forward(self, source_image, skeleton, kp_skeleton):
        """Generate the target-pose image.

        Args:
            source_image: reference person image tensor.
            skeleton: target pose / semantic map tensor.
            kp_skeleton: target keypoint tensor.

        Returns:
            dict with 'fake_image' plus auxiliary flow / mask diagnostics.
        """
        output_dict = {}
        recoder = collections.defaultdict(list)
        skeleton_feature_list, source_feature_list = [], []
        skeleton_feature = self.skeleton_encoder(
            skeleton, out_list=skeleton_feature_list)
        # The reference pass is run for its side effects: it fills `recoder`
        # (neural textures) and `source_feature_list`.
        _ = self.reference_encoder(
            source_image, recoder, out_list=source_feature_list)
        neural_textures = recoder['neural_textures']

        output_dict['fake_image'], delta_x_all, delta_y_all, delta_list, last_flow_all, mask_all_h, mask_all_l = \
            self.target_image_renderer(skeleton_feature, neural_textures, skeleton_feature_list,
                                       source_feature_list, kp_skeleton, recoder)
        output_dict['info'] = recoder
        output_dict['delta_x'] = delta_x_all
        output_dict['delta_y'] = delta_y_all
        output_dict['delta_list'] = delta_list
        output_dict['last_flow_all'] = last_flow_all
        output_dict['mask_all_h'] = mask_all_h
        output_dict['mask_all_l'] = mask_all_l
        return output_dict
|
||||
@@ -0,0 +1,346 @@
|
||||
from math import sqrt
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base_function import EqualConv2d, EqualLinear
|
||||
|
||||
|
||||
def TVLoss(x):
    """Total-variation loss: mean absolute difference between neighbouring
    pixels along the height and width axes of a (B, C, H, W) tensor."""
    vertical_diff = x[:, :, 1:, :] - x[:, :, :-1, :]
    horizontal_diff = x[:, :, :, 1:] - x[:, :, :, :-1]

    return vertical_diff.abs().mean() + horizontal_diff.abs().mean()
|
||||
|
||||
|
||||
class MaskStyle(nn.Module):
    """Per-scale style-modulated mask predictors.

    Builds one StyledConvBlock per resolution from 16 (2**4) up to
    2**log_size, stored in ``netM`` keyed by the resolution as a string.
    Channel width at each scale is ``channels_multiplier`` times the
    backbone width for that resolution.
    """

    def __init__(self, channels, log_size, style_in, channels_multiplier=2):
        super().__init__()
        self.log_size = log_size
        self.netM = nn.ModuleDict()
        for exp in range(4, log_size + 1):
            resolution = 2**exp
            width = channels_multiplier * channels[resolution]
            self.netM[str(resolution)] = StyledConvBlock(
                width,
                width,
                latent_dim=style_in,
                padding='zero',
                actvn='lrelu',
                normalize_affine_output=False,
                modulated_conv=True)
|
||||
|
||||
|
||||
class StyleFlow(nn.Module):
    """Per-scale flow estimation heads for resolutions 2**4 .. 2**log_size.

    For each scale (keyed by the resolution as a string):
      - ``netStyle``: style-modulated block mapping scale features to a
        49-channel correlation volume.
      - ``netF``: style-modulated block regressing a 2-channel flow from
        the 49-channel correlation volume.
      - ``netRefine``: plain conv stack (128 -> 64 -> 32 -> 2) refining a
        2-channel flow from concatenated features (2 * out_channel input).
    """

    def __init__(self, channels, log_size, style_in):
        super().__init__()
        self.log_size = log_size
        # Shared construction options for the style-modulated blocks.
        padding_type = 'zero'
        actvn = 'lrelu'
        normalize_mlp = False
        modulated_conv = True

        self.netRefine = nn.ModuleDict()
        self.netStyle = nn.ModuleDict()
        self.netF = nn.ModuleDict()

        for i in range(4, self.log_size + 1):
            out_channel = channels[2**i]

            netRefine_layer = torch.nn.Sequential(
                torch.nn.Conv2d(
                    2 * out_channel,
                    out_channels=128,
                    kernel_size=3,
                    stride=1,
                    padding=1),
                torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                torch.nn.Conv2d(
                    in_channels=128,
                    out_channels=64,
                    kernel_size=3,
                    stride=1,
                    padding=1),
                torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                torch.nn.Conv2d(
                    in_channels=64,
                    out_channels=32,
                    kernel_size=3,
                    stride=1,
                    padding=1),
                torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                torch.nn.Conv2d(
                    in_channels=32,
                    out_channels=2,
                    kernel_size=3,
                    stride=1,
                    padding=1))

            # 49 channels: correlation scores for a 7x7 matching window.
            # NOTE(review): the 7x7 interpretation is inferred from 49 = 7*7;
            # confirm against the decoder that consumes these channels.
            style_block = StyledConvBlock(
                out_channel,
                49,
                latent_dim=style_in,
                padding=padding_type,
                actvn=actvn,
                normalize_affine_output=normalize_mlp,
                modulated_conv=modulated_conv)

            style_F_block = Styled_F_ConvBlock(
                49,
                2,
                latent_dim=style_in,
                padding=padding_type,
                actvn=actvn,
                normalize_affine_output=normalize_mlp,
                modulated_conv=modulated_conv)

            scale = str(2**i)
            self.netRefine[scale] = (netRefine_layer)
            self.netStyle[scale] = (style_block)
            self.netF[scale] = (style_F_block)
|
||||
|
||||
|
||||
class StyledConvBlock(nn.Module):
    """Two-layer conv block with optional style modulation.

    With ``modulated_conv=True`` both convolutions are ModulatedConv2d driven
    by a latent style vector; otherwise plain EqualConv2d layers with an
    explicit padding layer are used. Each convolution is followed by an
    activation scaled by ``actvn_gain``.
    """

    def __init__(self,
                 fin,
                 fout,
                 latent_dim=256,
                 padding='zero',
                 actvn='lrelu',
                 normalize_affine_output=False,
                 modulated_conv=False):
        super(StyledConvBlock, self).__init__()
        # Explicit padding is only needed on the plain-conv path; the
        # modulated conv pads internally.
        if not modulated_conv:
            if padding == 'reflect':
                padding_layer = nn.ReflectionPad2d
            else:
                padding_layer = nn.ZeroPad2d

        if modulated_conv:
            conv2d = ModulatedConv2d
        else:
            conv2d = EqualConv2d

        # sqrt(2) gain compensates activation variance loss (StyleGAN-style)
        # in the modulated path only.
        if modulated_conv:
            self.actvn_gain = sqrt(2)
        else:
            self.actvn_gain = 1.0

        self.modulated_conv = modulated_conv

        if actvn == 'relu':
            activation = nn.ReLU(True)
        else:
            activation = nn.LeakyReLU(0.2, True)

        if self.modulated_conv:
            self.conv0 = conv2d(
                fin,
                fout,
                kernel_size=3,
                padding_type=padding,
                upsample=False,
                latent_dim=latent_dim,
                normalize_mlp=normalize_affine_output)
        else:
            conv0 = conv2d(fin, fout, kernel_size=3)

            seq0 = [padding_layer(1), conv0]
            self.conv0 = nn.Sequential(*seq0)

        self.actvn0 = activation

        if self.modulated_conv:
            self.conv1 = conv2d(
                fout,
                fout,
                kernel_size=3,
                padding_type=padding,
                downsample=False,
                latent_dim=latent_dim,
                normalize_mlp=normalize_affine_output)
        else:
            conv1 = conv2d(fout, fout, kernel_size=3)
            seq1 = [padding_layer(1), conv1]
            self.conv1 = nn.Sequential(*seq1)

        # Note: actvn0 and actvn1 intentionally share one (stateless) module.
        self.actvn1 = activation

    def forward(self, input, latent=None):
        """Apply conv0 -> act -> conv1 -> act; ``latent`` is required when
        ``modulated_conv`` is True and ignored otherwise."""
        if self.modulated_conv:
            out = self.conv0(input, latent)
        else:
            out = self.conv0(input)

        out = self.actvn0(out) * self.actvn_gain

        if self.modulated_conv:
            out = self.conv1(out, latent)
        else:
            out = self.conv1(out)

        out = self.actvn1(out) * self.actvn_gain

        return out
|
||||
|
||||
|
||||
class Styled_F_ConvBlock(nn.Module):
    """Style-modulated two-layer conv head mapping ``fin`` -> 128 -> ``fout``.

    Unlike StyledConvBlock, only the first convolution is followed by an
    activation; the second convolution's output is returned raw (used to
    regress a flow field).
    """

    def __init__(self,
                 fin,
                 fout,
                 latent_dim=256,
                 padding='zero',
                 actvn='lrelu',
                 normalize_affine_output=False,
                 modulated_conv=False):
        super(Styled_F_ConvBlock, self).__init__()
        # Explicit padding is only needed on the plain-conv path; the
        # modulated conv pads internally.
        if not modulated_conv:
            if padding == 'reflect':
                padding_layer = nn.ReflectionPad2d
            else:
                padding_layer = nn.ZeroPad2d

        if modulated_conv:
            conv2d = ModulatedConv2d
        else:
            conv2d = EqualConv2d

        # sqrt(2) gain compensates activation variance loss in the
        # modulated path only.
        if modulated_conv:
            self.actvn_gain = sqrt(2)
        else:
            self.actvn_gain = 1.0

        self.modulated_conv = modulated_conv

        if actvn == 'relu':
            activation = nn.ReLU(True)
        else:
            activation = nn.LeakyReLU(0.2, True)

        if self.modulated_conv:
            self.conv0 = conv2d(
                fin,
                128,
                kernel_size=3,
                padding_type=padding,
                upsample=False,
                latent_dim=latent_dim,
                normalize_mlp=normalize_affine_output)
        else:
            conv0 = conv2d(fin, 128, kernel_size=3)

            seq0 = [padding_layer(1), conv0]
            self.conv0 = nn.Sequential(*seq0)

        self.actvn0 = activation

        if self.modulated_conv:
            self.conv1 = conv2d(
                128,
                fout,
                kernel_size=3,
                padding_type=padding,
                downsample=False,
                latent_dim=latent_dim,
                normalize_mlp=normalize_affine_output)
        else:
            conv1 = conv2d(128, fout, kernel_size=3)
            seq1 = [padding_layer(1), conv1]
            self.conv1 = nn.Sequential(*seq1)

    def forward(self, input, latent=None):
        """Apply conv0 -> act -> conv1; the final conv output is returned
        without activation (raw regression output)."""
        if self.modulated_conv:
            out = self.conv0(input, latent)
        else:
            out = self.conv0(input)

        out = self.actvn0(out) * self.actvn_gain

        if self.modulated_conv:
            out = self.conv1(out, latent)
        else:
            out = self.conv1(out)

        return out
|
||||
|
||||
|
||||
class ModulatedConv2d(nn.Module):
    """StyleGAN2-style modulated convolution.

    The weight is scaled per-sample by a style projected from ``latent``
    (modulation) and, for kernel_size > 1, renormalized per output channel
    (demodulation). The batch dimension is folded into the channel dimension
    so one grouped conv applies a different weight to every sample.
    """

    def __init__(self,
                 fin,
                 fout,
                 kernel_size,
                 padding_type='zero',
                 upsample=False,
                 downsample=False,
                 latent_dim=512,
                 normalize_mlp=False):
        # `upsample` / `downsample` are accepted for interface compatibility
        # but are not used by this implementation.
        super(ModulatedConv2d, self).__init__()
        self.in_channels = fin
        self.out_channels = fout
        self.kernel_size = kernel_size
        # Same-size output; assumes an odd kernel size.
        padding_size = kernel_size // 2

        # (sic) "demudulate": misspelling of "demodulate" kept as-is so any
        # external attribute access keeps working. 1x1 convs skip it.
        if kernel_size == 1:
            self.demudulate = False
        else:
            self.demudulate = True

        self.weight = nn.Parameter(
            torch.Tensor(fout, fin, kernel_size, kernel_size))
        self.bias = nn.Parameter(torch.Tensor(1, fout, 1, 1))

        if normalize_mlp:
            # NOTE(review): PixelNorm is not imported in this module, so
            # normalize_mlp=True would raise NameError — confirm.
            self.mlp_class_std = nn.Sequential(
                EqualLinear(latent_dim, fin), PixelNorm())
        else:
            self.mlp_class_std = EqualLinear(latent_dim, fin)

        if padding_type == 'reflect':
            self.padding = nn.ReflectionPad2d(padding_size)
        else:
            self.padding = nn.ZeroPad2d(padding_size)

        self.weight.data.normal_()
        self.bias.data.zero_()

    def forward(self, input, latent):
        """Apply the style-modulated convolution.

        Args:
            input: (batch, in_channels, H, W) feature map.
            latent: per-sample style vector fed to ``mlp_class_std``.

        Returns:
            (batch, out_channels, H, W) feature map.
        """
        # Equalized learning rate: rescale by He-init factor at runtime.
        fan_in = self.weight.data.size(1) * self.weight.data[0][0].numel()
        weight = self.weight * sqrt(2 / fan_in)
        weight = weight.view(1, self.out_channels, self.in_channels,
                             self.kernel_size, self.kernel_size)

        # Modulate: scale input channels by the per-sample style.
        s = self.mlp_class_std(latent).view(-1, 1, self.in_channels, 1, 1)
        weight = s * weight
        if self.demudulate:
            # Demodulate: normalize each output filter to ~unit norm.
            d = torch.rsqrt((weight**2).sum(4).sum(3).sum(2) + 1e-5).view(
                -1, self.out_channels, 1, 1, 1)
            weight = (d * weight).view(-1, self.in_channels, self.kernel_size,
                                       self.kernel_size)
        else:
            weight = weight.view(-1, self.in_channels, self.kernel_size,
                                 self.kernel_size)

        batch, _, height, width = input.shape

        # Fold batch into channels and run one grouped conv (groups=batch)
        # so each sample is convolved with its own modulated weight.
        input = input.reshape(1, -1, height, width)
        input = self.padding(input)
        out = F.conv2d(
            input, weight, groups=batch).view(batch, self.out_channels, height,
                                              width) + self.bias

        return out
|
||||
121
modelscope/models/cv/human_image_generation/generators/tps.py
Normal file
121
modelscope/models/cv/human_image_generation/generators/tps.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class TPS(nn.Module):
    """Thin-plate-spline warping driven by keypoints.

    ``forward`` fits a TPS transform whose control points are the driving
    keypoints (with zero target positions — see ``trans``) and returns a
    dense (bs, h, w, 2) sampling grid for the source image.
    """

    def __init__(self, mode='kp'):
        super().__init__()
        # Only the keypoint-driven ('kp') mode is implemented.
        self.mode = mode

    def trans(self, kp_1):
        """Fit TPS parameters from control points ``kp_1``.

        kp_1: (bs, groups, n, 2). Requires ``self.bs`` to be set beforehand
        (done in ``forward``). Stores theta (affine part), control_points
        and control_params on ``self``.
        """
        if self.mode == 'kp':
            device = kp_1.device
            kp_type = kp_1.type()
            self.gs = kp_1.shape[1]
            n = kp_1.shape[2]
            # Radial basis matrix K_ij = r^2 log(r^2) between control points
            # (1e-9 guards log(0) on the diagonal).
            K = torch.norm(
                kp_1[:, :, :, None] - kp_1[:, :, None, :], dim=4, p=2)
            K = K**2
            K = K * torch.log(K + 1e-9)

            one1 = torch.ones(self.bs, kp_1.shape[1], kp_1.shape[2],
                              1).to(device).type(kp_type)
            kp_1p = torch.cat([kp_1, one1], 3)

            # Assemble the standard TPS system L = [[K, P], [P^T, 0]].
            zero = torch.zeros(self.bs, kp_1.shape[1], 3,
                               3).to(device).type(kp_type)
            P = torch.cat([kp_1p, zero], 2)
            L = torch.cat([K, kp_1p.permute(0, 1, 3, 2)], 2)
            L = torch.cat([L, P], 3)

            # Right-hand side: the target positions are all zeros.
            # NOTE(review): zero targets look like a deliberate choice of
            # this pipeline (warp relative to the origin) — confirm.
            zero = torch.zeros(self.bs, kp_1.shape[1], 3,
                               2).to(device).type(kp_type)
            kp_substitute = torch.zeros(kp_1.shape).to(device).type(kp_type)
            Y = torch.cat([kp_substitute, zero], 2)
            # Small ridge term keeps L invertible.
            one = torch.eye(L.shape[2]).expand(
                L.shape).to(device).type(kp_type) * 0.01
            L = L + one

            param = torch.matmul(torch.inverse(L), Y)
            # Affine coefficients are the last 3 rows of the solution.
            self.theta = param[:, :, n:, :].permute(0, 1, 3, 2)

            self.control_points = kp_1
            # Non-affine (radial) weights: first n rows.
            self.control_params = param[:, :, :n, :]
        else:
            raise Exception('Error TPS mode')

    def transform_frame(self, frame):
        """Warp a dense identity grid for ``frame`` with the fitted TPS."""
        grid = make_coordinate_grid(
            frame.shape[2:], type=frame.type()).unsqueeze(0).to(frame.device)
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        shape = [self.bs, frame.shape[2], frame.shape[3], 2]
        if self.mode == 'kp':
            # NOTE(review): inserting the group axis only reshapes cleanly to
            # the caller's (bs, h, w, 2) view when gs == 1 — confirm.
            shape.insert(1, self.gs)
        grid = self.warp_coordinates(grid).view(*shape)
        return grid

    def warp_coordinates(self, coordinates):
        """Apply the fitted TPS (affine + radial terms) to ``coordinates``."""
        theta = self.theta.type(coordinates.type()).to(coordinates.device)
        control_points = self.control_points.type(coordinates.type()).to(
            coordinates.device)
        control_params = self.control_params.type(coordinates.type()).to(
            coordinates.device)

        if self.mode == 'kp':
            # Affine component: A x + b.
            transformed = torch.matmul(theta[:, :, :, :2],
                                       coordinates.permute(
                                           0, 2, 1)) + theta[:, :, :, 2:]

            # Radial component: U(r) = r^2 log(r^2) against each control
            # point, weighted by control_params.
            distances = coordinates.view(
                coordinates.shape[0], 1, 1, -1, 2) - control_points.view(
                    self.bs, control_points.shape[1], -1, 1, 2)
            distances = distances**2
            result = distances.sum(-1)
            result = result * torch.log(result + 1e-9)
            result = torch.matmul(result.permute(0, 1, 3, 2), control_params)
            transformed = transformed.permute(0, 1, 3, 2) + result

        else:
            raise Exception('Error TPS mode')

        return transformed

    def preprocess_kp(self, kp_1):
        '''
        kp_1: (b, ntps*nkp, 2)
        '''
        # NOTE(review): fills entries equal to -1 with -1. — a no-op as
        # written; presumably a placeholder for a different fill value or
        # a mask-propagation step — confirm.
        kp_mask = (kp_1 == -1)
        num_keypoints = kp_1.shape[1]
        kp_1 = kp_1.masked_fill(kp_mask, -1.)
        return kp_1, num_keypoints

    def forward(self, source_image, kp_driving):
        """Return a (bs, h, w, 2) sampling grid warping ``source_image``
        according to ``kp_driving``."""
        bs, _, h, w = source_image.shape
        self.bs = bs
        kp_driving, num_keypoints = self.preprocess_kp(kp_driving)
        kp_1 = kp_driving.view(bs, -1, num_keypoints, 2)
        self.trans(kp_1)
        grid = self.transform_frame(source_image)
        grid = grid.view(bs, h, w, 2)
        return grid
|
||||
|
||||
|
||||
def make_coordinate_grid(spatial_size, type):
    """Create an (h, w, 2) grid of (x, y) coordinates spanning [-1, 1]².

    Args:
        spatial_size: (h, w) pair.
        type: torch tensor type string (e.g. ``'torch.FloatTensor'``).
            The name shadows the builtin, but it is part of the public
            keyword interface (callers pass ``type=...``), so it is kept.
    """
    h, w = spatial_size
    # 1-D coordinates normalized to [-1, 1] along each axis.
    xs = 2 * (torch.arange(w).type(type) / (w - 1)) - 1
    ys = 2 * (torch.arange(h).type(type) / (h - 1)) - 1

    # Broadcast to full (h, w) planes and pair them up as (x, y).
    xx = xs.view(1, -1).expand(h, -1)
    yy = ys.view(-1, 1).expand(-1, w)

    return torch.stack([xx, yy], dim=2)
|
||||
@@ -0,0 +1,182 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def get_wav(in_channels, pool=True):
    """Build the four fixed Haar wavelet filters (LL, LH, HL, HH) as frozen
    grouped conv layers mapping ``in_channels`` -> ``2 * in_channels``.

    Args:
        in_channels: number of input feature channels.
        pool: if True use stride-2 ``nn.Conv2d`` (analysis / downsample);
            otherwise ``nn.ConvTranspose2d`` (synthesis / upsample).

    Returns:
        Tuple of four conv layers with non-trainable analytic weights.
    """
    low = 1 / np.sqrt(2) * np.ones((1, 2))
    high = 1 / np.sqrt(2) * np.ones((1, 2))
    high[0, 0] = -1 * high[0, 0]

    # 2x2 separable Haar kernels, in (LL, LH, HL, HH) order.
    kernels = [
        np.transpose(low) * low,
        np.transpose(low) * high,
        np.transpose(high) * low,
        np.transpose(high) * high,
    ]

    conv_cls = nn.Conv2d if pool else nn.ConvTranspose2d
    layers = []
    for kernel in kernels:
        layer = conv_cls(
            in_channels,
            in_channels * 2,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=False,
            groups=in_channels)
        # Analytic filters: frozen, never trained.
        layer.weight.requires_grad = False
        layer.weight.data = torch.from_numpy(kernel).float().unsqueeze(
            0).unsqueeze(0).expand(in_channels * 2, -1, -1, -1)
        layers.append(layer)

    return tuple(layers)
|
||||
|
||||
|
||||
class WavePool(nn.Module):
    """Haar wavelet analysis pool: splits input into LL/LH/HL/HH subbands."""

    def __init__(self, in_channels):
        super(WavePool, self).__init__()
        # Frozen analytic filters; see get_wav().
        self.LL, self.LH, self.HL, self.HH = get_wav(in_channels)

    def forward(self, x):
        """Return the four subband responses of ``x`` as a tuple."""
        subbands = (self.LL(x), self.LH(x), self.HL(x), self.HH(x))
        return subbands
|
||||
|
||||
|
||||
def get_wav_two(in_channels, out_channels=None, pool=True):
    """Wavelet decomposition/reconstruction filters built as conv layers.

    Returns the four fixed Haar filters (LL, LH, HL, HH) as frozen grouped
    convolutions. ``pool=True`` yields analysis (stride-2 Conv2d) layers,
    ``pool=False`` synthesis (ConvTranspose2d) layers. ``out_channels``
    defaults to ``in_channels``.

    NOTE(review): the weight data is always expanded to ``in_channels``
    filters, mirroring the original behavior even when ``out_channels``
    differs — confirm call sites use out_channels == in_channels.
    """
    low = 1 / np.sqrt(2) * np.ones((1, 2))
    high = 1 / np.sqrt(2) * np.ones((1, 2))
    high[0, 0] = -1 * high[0, 0]

    # 2x2 separable Haar kernels, in (LL, LH, HL, HH) order.
    kernels = [
        np.transpose(low) * low,
        np.transpose(low) * high,
        np.transpose(high) * low,
        np.transpose(high) * high,
    ]

    conv_cls = nn.Conv2d if pool else nn.ConvTranspose2d
    if out_channels is None:
        out_channels = in_channels

    layers = []
    for kernel in kernels:
        layer = conv_cls(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=False,
            groups=in_channels)
        # Analytic filters: frozen, never trained.
        layer.weight.requires_grad = False
        layer.weight.data = torch.from_numpy(kernel).float().unsqueeze(
            0).unsqueeze(0).expand(in_channels, -1, -1, -1)
        layers.append(layer)

    return tuple(layers)
|
||||
|
||||
|
||||
class WavePool2(nn.Module):
    """Haar analysis pool with a configurable output channel count."""

    def __init__(self, in_channels, out_channels=None):
        super(WavePool2, self).__init__()
        # Frozen analytic filters; see get_wav_two().
        self.LL, self.LH, self.HL, self.HH = get_wav_two(
            in_channels, out_channels)

    def forward(self, x):
        """Return the four subband responses of ``x`` as a tuple."""
        subbands = (self.LL(x), self.LH(x), self.HL(x), self.HH(x))
        return subbands
|
||||
|
||||
|
||||
class WaveUnpool(nn.Module):
    """Haar wavelet synthesis: recombine LL/LH/HL/HH subbands.

    ``option_unpool='sum'`` adds the four upsampled subbands;
    ``'cat5'`` concatenates them with ``original`` along channels.
    """

    def __init__(self, in_channels, out_channels=None, option_unpool='cat5'):
        super(WaveUnpool, self).__init__()
        self.in_channels = in_channels
        self.option_unpool = option_unpool
        # Transposed (synthesis) versions of the fixed Haar filters.
        self.LL, self.LH, self.HL, self.HH = get_wav_two(
            self.in_channels, out_channels, pool=False)

    def forward(self, LL, LH, HL, HH, original=None):
        if self.option_unpool == 'sum':
            combined = self.LL(LL) + self.LH(LH)
            combined = combined + self.HL(HL)
            return combined + self.HH(HH)
        if self.option_unpool == 'cat5' and original is not None:
            pieces = [
                self.LL(LL),
                self.LH(LH),
                self.HL(HL),
                self.HH(HH),
                original,
            ]
            return torch.cat(pieces, dim=1)
        # 'cat5' without an original frame (or any other option) is invalid.
        raise NotImplementedError
|
||||
@@ -0,0 +1,268 @@
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
|
||||
import math
|
||||
import random
|
||||
from ast import Global
|
||||
from pickle import GLOBAL
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .generators.extraction_distribution_model_flow25 import \
|
||||
Generator as Generator
|
||||
|
||||
# torchvision >= 0.9 replaced the integer PIL resampling constants with the
# InterpolationMode enum; pick the spelling matching the installed version.
# NOTE(review): only the minor version component is parsed, so a 1.x
# torchvision release would be misclassified — confirm.
tv_version = int(torchvision.__version__.split('.')[1])
if tv_version > 8:
    from torchvision.transforms.functional import InterpolationMode
    resize_method = InterpolationMode.BICUBIC
    resize_nearest = InterpolationMode.NEAREST
else:
    resize_method = Image.BICUBIC
    resize_nearest = Image.NEAREST

logger = get_logger()
|
||||
|
||||
|
||||
def get_random_params(size, scale_param, use_flip=False):
    """Sample random scale-then-crop augmentation parameters.

    Args:
        size: (w, h) of the source image.
        scale_param: upper bound of the random upscaling ratio.
        use_flip: when True, flip with probability 0.1.

    Returns:
        dict with 'crop_param' (x, y, w, h), 'scale_size' (new_h, new_w)
        and the resolved 'use_flip' boolean.
    """
    w, h = size
    scale = random.random() * scale_param

    # Only roll the flip dice when flipping is allowed at all.
    if use_flip:
        use_flip = random.random() > 0.9

    scaled_w = int(w * (1.0 + scale))
    scaled_h = int(h * (1.0 + scale))
    # Random top-left corner such that the w x h crop stays inside the
    # upscaled image.
    crop_x = random.randint(0, np.maximum(0, scaled_w - w))
    crop_y = random.randint(0, np.maximum(0, scaled_h - h))
    return {
        'crop_param': (crop_x, crop_y, w, h),
        'scale_size': (scaled_h, scaled_w),
        'use_flip': use_flip,
    }
|
||||
|
||||
|
||||
def get_transform(param, method=resize_method, normalize=True, toTensor=True):
    """Compose the image preprocessing pipeline described by ``param``.

    ``param`` is the dict produced by :func:`get_random_params`; the resize,
    crop and flip steps are included only when the corresponding keys are
    present / truthy. Optionally appends tensor conversion and (-1, 1)
    normalization.
    """
    steps = []
    if param.get('scale_size') is not None:
        steps.append(
            transforms.Resize(param['scale_size'], interpolation=method))

    if param.get('crop_param') is not None:
        # The lambda reads the dict at call time, matching the original
        # late-binding behavior.
        steps.append(
            transforms.Lambda(lambda img: __crop(img, param['crop_param'])))

    if param['use_flip']:
        steps.append(transforms.Lambda(lambda img: __flip(img)))

    if toTensor:
        steps.append(transforms.ToTensor())

    if normalize:
        steps.append(
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(steps)
|
||||
|
||||
|
||||
def __crop(img, pos):
    """Crop ``img`` (PIL Image) to the ``tw`` x ``th`` box anchored at
    ``(x1, y1)``."""
    x1, y1, tw, th = pos
    box = (x1, y1, x1 + tw, y1 + th)
    return img.crop(box)
|
||||
|
||||
|
||||
def __flip(img):
    """Mirror the image horizontally (F is torchvision.transforms.functional)."""
    return F.hflip(img)
|
||||
|
||||
|
||||
def normalize():
    """Return the (-1, 1) channel normalization transform used here."""
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    return transforms.Normalize(mean, std)
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, device):
    """Load the EMA generator weights from ``checkpoint_path`` into ``model``.

    Drops the stray 'target_image_renderer.weight' entry some checkpoints
    carry, moves the model to ``device`` and switches it to eval mode.

    Args:
        model: module whose state dict matches ``params['net_G_ema']``.
        checkpoint_path: path of the torch-serialized checkpoint.
        device: target device string or ``torch.device``.

    Returns:
        The same ``model`` instance, ready for inference.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    params = torch.load(checkpoint_path, map_location=device)
    state_dict = params['net_G_ema']
    # pop(..., None) replaces the redundant `.keys()` membership test
    # followed by a conditional pop.
    state_dict.pop('target_image_renderer.weight', None)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
|
||||
|
||||
|
||||
@MODELS.register_module(
    Tasks.human_image_generation, module_name=Models.human_image_generation)
class FreqHPTForHumanImageGeneration(TorchModel):
    """FreqHPT human image generation model loaded from `model_dir`.

    Builds the FreqHPT generator with fixed 512-resolution hyper-parameters
    and loads EMA weights from the checkpoint in ``model_dir``.

    Args:
        model_dir (str): the model path.
    """

    def __init__(self, model_dir, device_id=0, *args, **kwargs):
        # NOTE(review): device_id is forwarded to TorchModel, but the device
        # selection below only distinguishes cuda vs cpu — confirm intent.
        super().__init__(
            model_dir=model_dir, device_id=device_id, *args, **kwargs)

        if torch.cuda.is_available():
            self.device = 'cuda'
            logger.info('Use GPU')
        else:
            self.device = 'cpu'
            logger.info('Use CPU')

        # Fixed FreqHPT architecture hyper-parameters; dict keys are
        # feature-map resolutions, `False` disables the feature at a scale.
        size = 512
        semantic_dim = 20
        channels = {
            16: 256,
            32: 256,
            64: 256,
            128: 128,
            256: 128,
            512: 64,
            1024: 32
        }
        num_labels = {16: 16, 32: 32, 64: 64, 128: 64, 256: 64, 512: False}
        match_kernels = {16: False, 32: 3, 64: 3, 128: 3, 256: 3, 512: False}
        wavelet_down_levels = {16: False, 32: 1, 64: 2, 128: 3, 256: 3, 512: 3}
        self.model = Generator(
            size,
            semantic_dim,
            channels,
            num_labels,
            match_kernels,
            wavelet_down_levels=wavelet_down_levels)
        # Load EMA weights and switch the generator to eval mode.
        self.model = load_checkpoint(
            self.model, model_dir + '/' + ModelFile.TORCH_MODEL_BIN_FILE,
            self.device)

    def forward(self, x, y, z):
        """Run the generator.

        NOTE(review): argument semantics inferred from
        Generator.forward(source_image, skeleton, kp_skeleton) — confirm.

        Args:
            x: source (reference) image tensor.
            y: target skeleton / semantic map tensor.
            z: target keypoint tensor.

        Returns:
            The generator output dict (fake image plus auxiliary data).
        """
        pred_result = self.model(x, y, z)
        return pred_result
|
||||
|
||||
|
||||
def trans_keypoins(keypoints, param, img_size, offset=None):
    """Map raw keypoints (176x256 dataset space) into the image space.

    Mutates ``keypoints`` in place (and returns it) after applying the
    dataset border crop, resize scaling and optional augmentation
    crop/scale from ``param``; also returns a copy normalized to [-1, 1].
    Entries that were -1 (missing joints) stay -1 in both outputs.

    Args:
        keypoints (np.ndarray): (N, 2) array of (x, y) positions, -1 = missing.
        param (dict): may carry 'scale_size' ((h, w)) and 'crop_param'
            ((x, y, w, h)).
        img_size (tuple): (height, width) of the target image.
        offset: pass 40 to signal the dataset border was already removed.

    Returns:
        tuple: (keypoints in pixel space, keypoints normalized to [-1, 1]).
    """
    missing_mask = keypoints == -1

    # Crop the 40px white border of the original dataset unless the caller
    # indicates it was already removed.
    if offset != 40:
        keypoints[:, 0] = keypoints[:, 0] - 40

    # Scale from the 176x256 dataset resolution to the working image size.
    img_h, img_w = img_size
    scale_w = 1.0 / 176.0 * img_w
    scale_h = 1.0 / 256.0 * img_h

    scale_size = param.get('scale_size')
    if scale_size is not None:
        new_h, new_w = scale_size
        scale_w = scale_w / img_w * new_w
        scale_h = scale_h / img_h * new_h

    crop_param = param.get('crop_param')
    if crop_param is not None:
        crop_x, crop_y = crop_param[0], crop_param[1]
    else:
        crop_x, crop_y = 0, 0

    keypoints[:, 0] = keypoints[:, 0] * scale_w - crop_x
    keypoints[:, 1] = keypoints[:, 1] * scale_h - crop_y

    normalized_kp = keypoints.copy()
    normalized_kp[:, 0] = normalized_kp[:, 0] / img_w * 2 - 1
    normalized_kp[:, 1] = normalized_kp[:, 1] / img_h * 2 - 1
    normalized_kp[missing_mask] = -1

    keypoints[missing_mask] = -1
    return keypoints, normalized_kp
|
||||
|
||||
|
||||
def get_label_tensor(path, img, param):
    """Render a pose label tensor from an 18-joint keypoint file.

    Draws OpenPose-style joints (circles) and limbs (ellipses) onto a
    canvas matching *img*'s spatial size, then appends per-limb distance
    maps, yielding a (3 + 17)-channel tensor.

    Args:
        path (str): text file readable by ``np.loadtxt`` with 18 (x, y) rows.
        img (torch.Tensor): CHW reference tensor; only its H/W are used.
        param (dict): augmentation params forwarded to ``trans_keypoins``.

    Returns:
        tuple: (label tensor, keypoints normalized to [-1, 1]).
    """
    # 1-based joint index pairs describing limb connectivity (OpenPose order).
    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15],
               [15, 17], [1, 16], [16, 18], [3, 17], [6, 18]]

    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
              [170, 255, 0], [85, 255, 0], [0, 255, 0], [0, 255, 85],
              [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255],
              [0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255],
              [255, 0, 170], [255, 0, 85]]
    # Canvas shares the (H, W) of img (CHW layout -> shape[1], shape[2]).
    canvas = np.zeros((img.shape[1], img.shape[2], 3)).astype(np.uint8)
    keypoint = np.loadtxt(path)
    keypoint, normalized_kp = trans_keypoins(keypoint, param, img.shape[1:])
    stickwidth = 4
    # Draw a filled circle for every visible joint.
    for i in range(18):
        x, y = keypoint[i, 0:2]
        if x == -1 or y == -1:
            continue
        cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
    joints = []
    # NOTE(review): only the first 17 of the 19 limbSeq entries are drawn —
    # confirm the trailing [3, 17] / [6, 18] pairs are intentionally unused.
    for i in range(17):
        Y = keypoint[np.array(limbSeq[i]) - 1, 0]
        X = keypoint[np.array(limbSeq[i]) - 1, 1]
        cur_canvas = canvas.copy()
        # A limb with any missing endpoint contributes an all-zero mask.
        if -1 in Y or -1 in X:
            joints.append(np.zeros_like(cur_canvas[:, :, 0]))
            continue
        mX = np.mean(X)
        mY = np.mean(Y)
        length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
        angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
        # Rotated ellipse approximating the limb segment.
        polygon = cv2.ellipse2Poly(
            (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0,
            360, 1)
        cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
        canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)

        joint = np.zeros_like(cur_canvas[:, :, 0])
        cv2.fillConvexPoly(joint, polygon, 255)
        # NOTE(review): this blends `joint` with itself (0.4 + 0.6 of the
        # same image), which is a no-op — possibly meant to blend with
        # another layer; verify against the training-data pipeline.
        joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
        joints.append(joint)
    pose = F.to_tensor(
        Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)))

    # Stack an L1 distance transform of each limb mask as extra channels.
    tensors_dist = 0
    e = 1
    for i in range(len(joints)):
        im_dist = cv2.distanceTransform(255 - joints[i], cv2.DIST_L1, 3)
        im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
        tensor_dist = F.to_tensor(Image.fromarray(im_dist))
        tensors_dist = tensor_dist if e == 1 else torch.cat(
            [tensors_dist, tensor_dist])
        e += 1

    label_tensor = torch.cat((pose, tensors_dist), dim=0)
    return label_tensor, normalized_kp
|
||||
|
||||
|
||||
def get_image_tensor(path):
    """Load an image and apply the standard augmentation transform.

    Returns the transformed CHW tensor together with the sampled
    augmentation parameters, so pose labels can be mapped into the
    same coordinate space.
    """
    image = Image.open(path)
    aug_params = get_random_params(image.size, 0)
    pipeline = get_transform(aug_params, normalize=True, toTensor=True)
    return pipeline(image), aug_params
|
||||
|
||||
|
||||
def infer(genmodel, image_path, target_label_path, device):
    """Generate a person image in the target pose.

    Args:
        genmodel: generator callable returning a dict with 'fake_image'
            (values in [-1, 1]).
        image_path (str): path of the reference person image.
        target_label_path (str): path of the target pose keypoint file.
        device: torch device for inference.

    Returns:
        np.ndarray: uint8 BGR image (H, W, 3), OpenCV convention.
    """
    ref_tensor, param = get_image_tensor(image_path)
    target_label_tensor, target_kp = get_label_tensor(target_label_path,
                                                      ref_tensor, param)

    # Add the batch dimension and move everything to the target device.
    ref_tensor = ref_tensor.unsqueeze(0).to(device)
    target_label_tensor = target_label_tensor.unsqueeze(0).to(device)
    target_kp = torch.from_numpy(target_kp).unsqueeze(0).to(device)
    output_dict = genmodel(ref_tensor, target_label_tensor, target_kp)
    output_image = output_dict['fake_image'][0]

    # Map [-1, 1] -> [0, 255] uint8 HWC, then RGB -> BGR for OpenCV output.
    output_image = output_image.clamp_(-1, 1)
    image = (output_image + 1) * 0.5
    image = image.detach().cpu().squeeze().numpy()
    image = np.transpose(image, (1, 2, 0)) * 255
    image = np.uint8(image)
    bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    return bgr
|
||||
0
modelscope/ops/human_image_generation/__init__.py
Normal file
0
modelscope/ops/human_image_generation/__init__.py
Normal file
118
modelscope/ops/human_image_generation/fused_act.py
Normal file
118
modelscope/ops/human_image_generation/fused_act.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
module_path = os.path.dirname(__file__)
|
||||
fused = load(
|
||||
'fused',
|
||||
sources=[
|
||||
os.path.join(module_path, 'fused_bias_act.cpp'),
|
||||
os.path.join(module_path, 'fused_bias_act_kernel.cu'),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class FusedLeakyReLUFunctionBackward(Function):
    """Backward pass of the fused bias + leaky-ReLU CUDA op.

    Implemented as its own autograd Function so the op supports double
    backward (gradients of gradients).
    """

    @staticmethod
    def forward(ctx, grad_output, out, bias, negative_slope, scale):
        # `out` (the forward activation) is enough to recompute the ReLU
        # mask; act=3, grad=1 selects the leaky-ReLU gradient kernel.
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        grad_input = fused.fused_bias_act(grad_output, empty, out, 3, 1,
                                          negative_slope, scale)

        # Bias gradient sums over every dim except the channel dim (1).
        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        if bias:
            grad_bias = grad_input.sum(dim).detach()

        else:
            grad_bias = empty

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        # Second-order gradient reuses the same grad kernel (act=3, grad=1).
        gradgrad_out = fused.fused_bias_act(gradgrad_input, gradgrad_bias, out,
                                            3, 1, ctx.negative_slope,
                                            ctx.scale)

        return gradgrad_out, None, None, None, None
|
||||
|
||||
|
||||
class FusedLeakyReLUFunction(Function):
    """autograd Function wrapping the fused bias + leaky-ReLU CUDA kernel."""

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)

        # Remember whether a real bias was supplied so backward can return
        # None for grad_bias when there was none.
        ctx.bias = bias is not None

        if bias is None:
            bias = empty

        # act=3 (leaky ReLU), grad=0 (forward pass).
        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope,
                                   scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale)

        if not ctx.bias:
            grad_bias = None

        return grad_input, grad_bias, None, None
|
||||
|
||||
|
||||
class FusedLeakyReLU(nn.Module):
    """Module form of the fused bias + scaled leaky-ReLU activation.

    Args:
        channel (int): number of channels; size of the learned bias.
        bias (bool): learn a per-channel bias added before the activation.
        negative_slope (float): leaky-ReLU slope for negative inputs.
        scale (float): output multiplier (sqrt(2) keeps unit variance in
            StyleGAN2-style networks).
    """

    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2**0.5):
        super().__init__()

        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))

        else:
            self.bias = None

        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope,
                                self.scale)
|
||||
|
||||
|
||||
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
    """Fused bias-add + leaky-ReLU + scaling.

    Dispatches to the CUDA extension on GPU tensors and to a pure-PyTorch
    equivalent on CPU.

    Args:
        input (torch.Tensor): input activations.
        bias (torch.Tensor | None): per-channel bias broadcast over dim 1.
        negative_slope (float): leaky-ReLU slope for negative values.
        scale (float): multiplier applied to the activation output.

    Returns:
        torch.Tensor: ``leaky_relu(input + bias) * scale``.
    """
    if input.device.type == 'cpu':
        # Bug fix: the CPU fallback previously hard-coded negative_slope=0.2,
        # ignoring the caller's value; forward the parameter instead.
        if bias is not None:
            # Reshape bias to (1, C, 1, ..., 1) so it broadcasts over dim 1.
            rest_dim = [1] * (input.ndim - bias.ndim - 1)
            return (F.leaky_relu(
                input + bias.view(1, bias.shape[0], *rest_dim),
                negative_slope=negative_slope) * scale)

        else:
            return F.leaky_relu(input, negative_slope=negative_slope) * scale

    else:
        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
|
||||
21
modelscope/ops/human_image_generation/fused_bias_act.cpp
Normal file
21
modelscope/ops/human_image_generation/fused_bias_act.cpp
Normal file
@@ -0,0 +1,21 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
|
||||
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
||||
int act, int grad, float alpha, float scale);
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
// Python-facing entry point: validates device placement then forwards to the
// CUDA implementation. Note only CUDA residency is checked here; the op
// itself makes the tensors contiguous, so CHECK_CONTIGUOUS is not needed.
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                             int act, int grad, float alpha, float scale) {
    CHECK_CUDA(input);
    CHECK_CUDA(bias);

    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
||||
//
|
||||
// This work is made available under the Nvidia Source Code License-NC.
|
||||
// To view a copy of this license, visit
|
||||
// https://nvlabs.github.io/stylegan2/license.html
|
||||
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
|
||||
// Element-wise fused bias-add + activation kernel.
// Each thread handles `loop_x` strided elements. `act * 10 + grad` encodes
// the operation: act 1 = linear, act 3 = leaky ReLU; grad 0 = forward,
// grad 1 = first derivative (w.r.t. the pre-activation, using `p_ref` as the
// saved forward output), grad 2 = second derivative (zero for these acts).
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

    scalar_t zero = 0.0;

    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
        scalar_t x = p_x[xi];

        // Bias is indexed per channel: (flat_index / inner_size) % channels.
        if (use_bias) {
            x += p_b[(xi / step_b) % size_b];
        }

        scalar_t ref = use_ref ? p_ref[xi] : zero;

        scalar_t y;

        switch (act * 10 + grad) {
            default:
            case 10: y = x; break;
            case 11: y = x; break;
            case 12: y = 0.0; break;

            // Leaky ReLU forward; gradient keys off the sign of the saved
            // forward output `ref`.
            case 30: y = (x > 0.0) ? x : x * alpha; break;
            case 31: y = (ref > 0.0) ? x : x * alpha; break;
            case 32: y = 0.0; break;
        }

        out[xi] = y * scale;
    }
}
|
||||
|
||||
|
||||
// Host-side launcher for fused_bias_act_kernel: makes inputs contiguous,
// derives the per-channel bias stride, and dispatches over dtypes.
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    // Empty tensors act as "not provided" flags for bias/reference.
    int use_bias = b.numel() ? 1 : 0;
    int use_ref = ref.numel() ? 1 : 0;

    int size_x = x.numel();
    int size_b = b.numel();
    int step_b = 1;

    // step_b = product of dims after the channel dim (dim 1), i.e. the
    // number of consecutive elements sharing one bias value.
    for (int i = 1 + 1; i < x.dim(); i++) {
        step_b *= x.size(i);
    }

    // Each thread processes loop_x elements; 128-thread blocks.
    int loop_x = 4;
    int block_size = 4 * 32;
    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(),
            ref.data_ptr<scalar_t>(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}
|
||||
23
modelscope/ops/human_image_generation/upfirdn2d.cpp
Normal file
23
modelscope/ops/human_image_generation/upfirdn2d.cpp
Normal file
@@ -0,0 +1,23 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
|
||||
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
|
||||
int up_x, int up_y, int down_x, int down_y,
|
||||
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
// Python-facing entry point for the upsample-FIR-downsample op: validates
// that both tensors live on the GPU, then forwards to the CUDA launcher.
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    CHECK_CUDA(input);
    CHECK_CUDA(kernel);

    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
|
||||
}
|
||||
208
modelscope/ops/human_image_generation/upfirdn2d.py
Normal file
208
modelscope/ops/human_image_generation/upfirdn2d.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import os
|
||||
from collections import abc
|
||||
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
module_path = os.path.dirname(__file__)
|
||||
upfirdn2d_op = load(
|
||||
'upfirdn2d',
|
||||
sources=[
|
||||
os.path.join(module_path, 'upfirdn2d.cpp'),
|
||||
os.path.join(module_path, 'upfirdn2d_kernel.cu'),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class UpFirDn2dBackward(Function):
    """Backward pass of UpFirDn2d, itself differentiable (double backward).

    The gradient of upfirdn2d is another upfirdn2d with up/down swapped,
    the kernel flipped, and adjusted ("gradient") padding.
    """

    @staticmethod
    def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
                in_size, out_size):

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        # Kernel expects (N*C, H, W, 1) layout.
        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        # Transposed op: swap up/down, use the flipped kernel and g_pad.
        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
                                     in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        # Stash the forward-op geometry for the double-backward pass.
        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
                                                ctx.in_size[3], 1)

        # Double backward re-applies the original (forward) upfirdn2d.
        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
                                         ctx.out_size[0], ctx.out_size[1])

        return gradgrad_out, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class UpFirDn2d(Function):
    """autograd Function wrapping the CUDA upsample-FIR-downsample op."""

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        # CUDA kernel operates on (N*C, H, W, 1).
        input = input.reshape(-1, in_h, in_w, 1)

        # Save both the kernel and its 180-degree flip for backward.
        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        # Padding for the transposed (gradient) op.
        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y,
                                     pad_x0, pad_x1, pad_y0, pad_y1)
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = None

        # Only the input is differentiable; the kernel gets no gradient here.
        if ctx.needs_input_grad[0]:
            grad_input = UpFirDn2dBackward.apply(
                grad_output,
                kernel,
                grad_kernel,
                ctx.up,
                ctx.down,
                ctx.pad,
                ctx.g_pad,
                ctx.in_size,
                ctx.out_size,
            )

        return grad_input, None, None, None, None
|
||||
|
||||
|
||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, apply an FIR filter, then downsample a NCHW tensor.

    Scalar `up`/`down` are expanded to (x, y) pairs and a 2-element `pad`
    to (x0, x1, y0, y1). CPU tensors use the pure-PyTorch reference path;
    CUDA tensors use the compiled extension via the autograd Function.
    """
    up_pair = up if isinstance(up, abc.Iterable) else (up, up)
    down_pair = down if isinstance(down, abc.Iterable) else (down, down)
    pad_quad = (pad[0], pad[1], pad[0], pad[1]) if len(pad) == 2 else pad

    if input.device.type == 'cpu':
        return upfirdn2d_native(input, kernel, *up_pair, *down_pair, *pad_quad)
    return UpFirDn2d.apply(input, kernel, up_pair, down_pair, pad_quad)
|
||||
|
||||
|
||||
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
                     pad_y0, pad_y1):
    """Pure-PyTorch reference for upfirdn2d (CPU fallback).

    Zero-interleave by (up_x, up_y), pad (negative pads crop), convolve
    with the flipped FIR kernel, then keep every (down_x, down_y)-th
    sample. Input/output layout is NCHW.
    """
    _, channel, in_h, in_w = input.shape
    # Fold channels into the batch dim; kernel math runs on (N*C, H, W, 1).
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Upsample: insert up-1 zeros after each sample along W and H.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding via F.pad; negative padding handled by slicing below.
    out = F.pad(
        out,
        [0, 0,
         max(pad_x0, 0),
         max(pad_x1, 0),
         max(pad_y0, 0),
         max(pad_y1, 0)])
    out = out[:,
              max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
              max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]

    # Convolve with the 180-degree-flipped kernel (true convolution, since
    # conv2d is cross-correlation).
    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    # Downsample by striding.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

    return out.view(-1, channel, out_h, out_w)
|
||||
369
modelscope/ops/human_image_generation/upfirdn2d_kernel.cu
Normal file
369
modelscope/ops/human_image_generation/upfirdn2d_kernel.cu
Normal file
@@ -0,0 +1,369 @@
|
||||
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
||||
//
|
||||
// This work is made available under the Nvidia Source Code License-NC.
|
||||
// To view a copy of this license, visit
|
||||
// https://nvlabs.github.io/stylegan2/license.html
|
||||
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// Integer division rounding toward negative infinity (C's `/` truncates
// toward zero, which is wrong for negative numerators here).
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
    int c = a / b;

    if (c * b > a) {
        c--;
    }

    return c;
}
|
||||
|
||||
// Geometry and tiling parameters shared by the upfirdn2d kernels.
// Tensor layout is (major_dim, in_h, in_w, minor_dim), i.e. N*C folded
// into major_dim with a trailing minor dimension of 1 for NCHW inputs.
struct UpFirDn2DKernelParams {
    int up_x;      // upsampling factors
    int up_y;
    int down_x;    // downsampling factors
    int down_y;
    int pad_x0;    // padding before/after in x and y (may be negative)
    int pad_x1;
    int pad_y0;
    int pad_y1;

    int major_dim; // leading (batch*channel) dimension
    int in_h;
    int in_w;
    int minor_dim; // trailing dimension
    int kernel_h;
    int kernel_w;
    int out_h;
    int out_w;
    int loop_major; // how many major indices one block iterates over
    int loop_x;     // how many x-tiles one block iterates over
};
|
||||
|
||||
// Generic (untiled) upfirdn2d kernel used for kernel sizes / factors that
// have no specialized tiled instantiation. One thread computes one output
// element per (loop_major, loop_x) iteration, reading input and filter taps
// straight from global memory.
template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  // Map the output row back to the contributing input rows / kernel taps
  // in the virtual upsampled+padded space.
  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      // Same mapping along x for this output column.
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      // Pointer strides: input walks forward, kernel walks backward by the
      // upsampling factor (only every up-th tap hits a non-zero sample).
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      // Accumulate the h x w window of contributing taps.
      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}
|
||||
|
||||
// Tiled upfirdn2d kernel, specialized at compile time for small up/down
// factors and kernel sizes. Each block stages the (flipped) filter and one
// input tile in shared memory, then every thread accumulates output pixels
// of a tile_out_h x tile_out_w tile.
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  // Input tile footprint needed to produce one output tile.
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  // blockIdx.x encodes (tile row, minor index); decompose it.
  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  // Bitwise | on the comparisons avoids short-circuit branches (all
  // operands are cheap); semantics match logical-or here.
  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  // Stage the filter in shared memory, flipped in both axes so the inner
  // loop is a plain correlation; zero-pad up to the template kernel size.
  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      // Top-left of this tile in the virtual upsampled+padded space, and
      // the corresponding input coordinates.
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      // Cooperative load of the input tile; out-of-bounds reads become 0.
      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      // Each thread accumulates output pixels from shared memory.
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        // Phase of the first non-zero tap for this output position.
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}
|
||||
|
||||
// Host-side launcher: fills the parameter struct, picks a specialized tiled
// kernel instantiation when the up/down factors and kernel size match one of
// the common cases (modes 1-6), and otherwise falls back to the generic
// kernel. Input layout is (major, in_h, in_w, minor).
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  // Select a tiled specialization; later matches override earlier ones, so
  // the most specific (smallest-kernel) mode wins. mode -1 = generic path.
  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    // Tiled path: one block per (tile row x minor, tile column, major chunk).
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    // Generic path: per-element threads.
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      // NOTE(review): mode 6 (kernel <= 2) reuses the 4x4-tap template like
      // mode 5; correct since taps beyond p.kernel_* are zero-padded.
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}
|
||||
@@ -1513,6 +1513,11 @@ TASK_OUTPUTS = {
|
||||
# "output_img": np.ndarray with shape [height, width, 3]
|
||||
# }
|
||||
Tasks.image_try_on: [OutputKeys.OUTPUT_IMG],
|
||||
# Tasks.human_image_generation result for a single sample
|
||||
# {
|
||||
# "output_img": np.ndarray with shape [height, width, 3]
|
||||
# }
|
||||
Tasks.human_image_generation: [OutputKeys.OUTPUT_IMG],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -224,7 +224,10 @@ TASK_INPUTS = {
|
||||
InputKeys.IMAGE: InputType.IMAGE,
|
||||
InputKeys.IMAGE: InputType.IMAGE
|
||||
},
|
||||
|
||||
Tasks.human_image_generation: {
|
||||
InputKeys.IMAGE: InputType.IMAGE,
|
||||
'target_pose_path': InputType.TEXT
|
||||
},
|
||||
# ============ nlp tasks ===================
|
||||
Tasks.chat: {
|
||||
'text': InputType.TEXT,
|
||||
|
||||
60
modelscope/pipelines/cv/human_image_generation_pipeline.py
Normal file
60
modelscope/pipelines/cv/human_image_generation_pipeline.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.human_image_generation import \
|
||||
human_image_generation_infer
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.human_image_generation, module_name=Pipelines.human_image_generation)
class FreqHPTForHumanImageGenerationPipeline(Pipeline):
    """Human image generation pipeline (FreqHPT).

    Given a source person image and a target-pose file, generates an image
    of the same person rendered in the target pose.

    Examples:
        >>> human_image_generation = pipeline(Tasks.human_image_generation, model='damo/cv_FreqHPT_human-image-generation')
        >>> input_images = {'source_img_path': '/your_path/source_img.jpg',
        >>>                 'target_pose_path': '/your_path/target_pose.txt'}
        >>> result = human_image_generation(input_images)
        >>> result[OutputKeys.OUTPUT_IMG]
    """

    # Keys that must be present in the input dict passed to the pipeline.
    _REQUIRED_KEYS = ('source_img_path', 'target_pose_path')

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create human image generation pipeline for prediction

        Args:
            model: model id on modelscope hub (or a local model directory).
        """
        super().__init__(model=model, **kwargs)
        # Kept for backward compatibility with any caller that reads it;
        # inference itself uses self.model (set by the base class).
        self.model_path = model
        logger.info('load model done')
        if torch.cuda.is_available():
            self.device = 'cuda'
            logger.info('Use GPU')
        else:
            self.device = 'cpu'
            logger.info('Use CPU')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        # Validate up front so a missing key fails with a clear message
        # instead of a bare KeyError deep inside forward().
        for key in self._REQUIRED_KEYS:
            if key not in input:
                raise ValueError(
                    f'human-image-generation input must contain "{key}"; '
                    f'got keys: {list(input)}')
        return input

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # forward() already produces the final output dict; nothing to do.
        return inputs

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        # Delegate to the model-side inference helper. The result is the
        # generated image (np.ndarray, presumably [height, width, 3] —
        # TODO confirm against human_image_generation_infer).
        generated = human_image_generation_infer.infer(
            self.model, input['source_img_path'], input['target_pose_path'],
            self.device)
        return {OutputKeys.OUTPUT_IMG: generated}
|
||||
@@ -98,6 +98,7 @@ class CVTasks(object):
|
||||
controllable_image_generation = 'controllable-image-generation'
|
||||
text_to_360panorama_image = 'text-to-360panorama-image'
|
||||
image_try_on = 'image-try-on'
|
||||
human_image_generation = 'human-image-generation'
|
||||
|
||||
# video recognition
|
||||
live_category = 'live-category'
|
||||
|
||||
47
tests/pipelines/test_human_image_generation.py
Normal file
47
tests/pipelines/test_human_image_generation.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class HumanImageGenerationTest(unittest.TestCase):
    """Unit tests for the human-image-generation pipeline."""

    def setUp(self) -> None:
        # Model id and sample inputs shared by every test case.
        self.model_id = 'damo/cv_FreqHPT_human-image-generation'
        self.input = {
            'source_img_path':
            'data/test/images/human_image_generation_source_img.jpg',
            'target_pose_path':
            'data/test/images/human_image_generation_target_pose.txt'
        }

    def pipeline_inference(self, pipeline: Pipeline, input: str):
        # Run the pipeline and persist the generated image for inspection.
        output = pipeline(input)
        logger.info(output)
        cv2.imwrite('result.jpg', output[OutputKeys.OUTPUT_IMG])

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        pipe = pipeline(
            Tasks.human_image_generation,
            model=self.model_id,
            revision='v1.0.1')
        self.pipeline_inference(pipe, self.input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        pipe = pipeline(Tasks.human_image_generation)
        self.pipeline_inference(pipe, self.input)


if __name__ == '__main__':
    unittest.main()
|
||||
Reference in New Issue
Block a user