mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
[to #42322933]image2image_translation codes
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9526987
This commit is contained in:
3
data/test/images/img2img_input.jpg
Normal file
3
data/test/images/img2img_input.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7e4cbf844cd16a892a7d2f2764b1537c346675d3b0145016d6836441ba907366
|
||||
size 9195
|
||||
3
data/test/images/img2img_input_mask.png
Normal file
3
data/test/images/img2img_input_mask.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:33b3d3076e191fa92511bf69fa76e1222b3b3be0049e711c948a1218b587510c
|
||||
size 4805
|
||||
3
data/test/images/img2img_input_masked_img.png
Normal file
3
data/test/images/img2img_input_masked_img.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:99c2b02a927b86ff194287ea4c5a05349dd800cff2b523212d1dad378c252feb
|
||||
size 103334
|
||||
@@ -77,6 +77,7 @@ class Pipelines(object):
|
||||
face_image_generation = 'gan-face-image-generation'
|
||||
style_transfer = 'AAMS-style-transfer'
|
||||
image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
|
||||
image2image_translation = 'image-to-image-translation'
|
||||
live_category = 'live-category'
|
||||
video_category = 'video-category'
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .transforms import * # noqa F403
|
||||
@@ -0,0 +1,121 @@
|
||||
import math
|
||||
import random
|
||||
|
||||
import torchvision.transforms.functional as TF
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
__all__ = [
|
||||
'Identity', 'PadToSquare', 'RandomScale', 'RandomRotate',
|
||||
'RandomGaussianBlur', 'RandomCrop'
|
||||
]
|
||||
|
||||
|
||||
class Identity(object):
|
||||
|
||||
def __call__(self, *args):
|
||||
if len(args) == 0:
|
||||
return None
|
||||
elif len(args) == 1:
|
||||
return args[0]
|
||||
else:
|
||||
return args
|
||||
|
||||
|
||||
class PadToSquare(object):
|
||||
|
||||
def __init__(self, fill=(255, 255, 255)):
|
||||
self.fill = fill
|
||||
|
||||
def __call__(self, img):
|
||||
w, h = img.size
|
||||
if w != h:
|
||||
if w > h:
|
||||
t = (w - h) // 2
|
||||
b = w - h - t
|
||||
padding = (0, t, 0, b)
|
||||
else:
|
||||
left = (h - w) // 2
|
||||
right = h - w - l
|
||||
padding = (left, 0, right, 0)
|
||||
img = TF.pad(img, padding, fill=self.fill)
|
||||
return img
|
||||
|
||||
|
||||
class RandomScale(object):
|
||||
|
||||
def __init__(self,
|
||||
min_scale=0.5,
|
||||
max_scale=2.0,
|
||||
min_ratio=0.8,
|
||||
max_ratio=1.25):
|
||||
self.min_scale = min_scale
|
||||
self.max_scale = max_scale
|
||||
self.min_ratio = min_ratio
|
||||
self.max_ratio = max_ratio
|
||||
|
||||
def __call__(self, img):
|
||||
w, h = img.size
|
||||
scale = 2**random.uniform(
|
||||
math.log2(self.min_scale), math.log2(self.max_scale))
|
||||
ratio = 2**random.uniform(
|
||||
math.log2(self.min_ratio), math.log2(self.max_ratio))
|
||||
ow = int(w * scale * math.sqrt(ratio))
|
||||
oh = int(h * scale / math.sqrt(ratio))
|
||||
img = img.resize((ow, oh), Image.BILINEAR)
|
||||
return img
|
||||
|
||||
|
||||
class RandomRotate(object):
|
||||
|
||||
def __init__(self,
|
||||
min_angle=-10.0,
|
||||
max_angle=10.0,
|
||||
padding=(255, 255, 255),
|
||||
p=0.5):
|
||||
self.min_angle = min_angle
|
||||
self.max_angle = max_angle
|
||||
self.padding = padding
|
||||
self.p = p
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p:
|
||||
angle = random.uniform(self.min_angle, self.max_angle)
|
||||
img = img.rotate(angle, Image.BILINEAR, fillcolor=self.padding)
|
||||
return img
|
||||
|
||||
|
||||
class RandomGaussianBlur(object):
|
||||
|
||||
def __init__(self, radius=5, p=0.5):
|
||||
self.radius = radius
|
||||
self.p = p
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p:
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=self.radius))
|
||||
return img
|
||||
|
||||
|
||||
class RandomCrop(object):
|
||||
|
||||
def __init__(self, size, padding=(255, 255, 255)):
|
||||
self.size = size
|
||||
self.padding = padding
|
||||
|
||||
def __call__(self, img):
|
||||
# pad
|
||||
w, h = img.size
|
||||
pad_w = max(0, self.size - w)
|
||||
pad_h = max(0, self.size - h)
|
||||
if pad_w > 0 or pad_h > 0:
|
||||
half_w = pad_w // 2
|
||||
half_h = pad_h // 2
|
||||
pad = (half_w, half_h, pad_w - half_w, pad_h - half_h)
|
||||
img = TF.pad(img, pad, fill=self.padding)
|
||||
|
||||
# crop
|
||||
w, h = img.size
|
||||
x1 = random.randint(0, w - self.size)
|
||||
y1 = random.randint(0, h - self.size)
|
||||
img = img.crop((x1, y1, x1 + self.size, y1 + self.size))
|
||||
return img
|
||||
@@ -0,0 +1,323 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['UNet']
|
||||
|
||||
|
||||
def sinusoidal_embedding(timesteps, dim):
|
||||
# check input
|
||||
half = dim // 2
|
||||
timesteps = timesteps.float()
|
||||
|
||||
# compute sinusoidal embedding
|
||||
sinusoid = torch.outer(
|
||||
timesteps, torch.pow(10000,
|
||||
-torch.arange(half).to(timesteps).div(half)))
|
||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||
if dim % 2 != 0:
|
||||
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
|
||||
return x
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, scale_factor=1.0):
|
||||
assert scale_factor in [0.5, 1.0, 2.0]
|
||||
super(Resample, self).__init__()
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def forward(self, x):
|
||||
if self.scale_factor == 2.0:
|
||||
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
elif self.scale_factor == 0.5:
|
||||
x = F.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, embed_dim, out_dim, dropout=0.0):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.embed_dim = embed_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.layer1 = nn.Sequential(
|
||||
nn.GroupNorm(32, in_dim), nn.SiLU(),
|
||||
nn.Conv2d(in_dim, out_dim, 3, padding=1))
|
||||
self.embedding = nn.Sequential(nn.SiLU(),
|
||||
nn.Linear(embed_dim, out_dim))
|
||||
self.layer2 = nn.Sequential(
|
||||
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
||||
nn.Conv2d(out_dim, out_dim, 3, padding=1))
|
||||
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
|
||||
in_dim, out_dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.layer2[-1].weight)
|
||||
|
||||
def forward(self, x, y):
|
||||
identity = x
|
||||
x = self.layer1(x)
|
||||
x = x + self.embedding(y).unsqueeze(-1).unsqueeze(-1)
|
||||
x = self.layer2(x)
|
||||
x = x + self.shortcut(identity)
|
||||
return x
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
|
||||
def __init__(self, dim, context_dim=None, num_heads=8, dropout=0.0):
|
||||
assert dim % num_heads == 0
|
||||
assert context_dim is None or context_dim % num_heads == 0
|
||||
context_dim = context_dim or dim
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
self.dim = dim
|
||||
self.context_dim = context_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = math.pow(self.head_dim, -0.25)
|
||||
|
||||
# layers
|
||||
self.q = nn.Linear(dim, dim, bias=False)
|
||||
self.k = nn.Linear(context_dim, dim, bias=False)
|
||||
self.v = nn.Linear(context_dim, dim, bias=False)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# check inputs
|
||||
context = x if context is None else context
|
||||
b, n, c = x.size(0), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.q(x).view(b, -1, n, c)
|
||||
k = self.k(context).view(b, -1, n, c)
|
||||
v = self.v(context).view(b, -1, n, c)
|
||||
|
||||
# compute attention
|
||||
attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale)
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
# gather context
|
||||
x = torch.einsum('bnij,bjnc->binc', attn, v)
|
||||
x = x.reshape(b, -1, n * c)
|
||||
|
||||
# output
|
||||
x = self.o(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class GLU(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super(GLU, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.proj = nn.Linear(in_dim, out_dim * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim, context_dim, num_heads, dropout=0.0):
|
||||
super(TransformerBlock, self).__init__()
|
||||
self.dim = dim
|
||||
self.context_dim = context_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
# input
|
||||
self.norm1 = nn.GroupNorm(32, dim, eps=1e-6, affine=True)
|
||||
self.conv1 = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# self attention
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.self_attn = MultiHeadAttention(dim, None, num_heads, dropout)
|
||||
|
||||
# cross attention
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.cross_attn = MultiHeadAttention(dim, context_dim, num_heads,
|
||||
dropout)
|
||||
|
||||
# ffn
|
||||
self.norm4 = nn.LayerNorm(dim)
|
||||
self.ffn = nn.Sequential(
|
||||
GLU(dim, dim * 4), nn.Dropout(dropout), nn.Linear(dim * 4, dim))
|
||||
|
||||
# output
|
||||
self.conv2 = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.conv2.weight)
|
||||
|
||||
def forward(self, x, context):
|
||||
b, c, h, w = x.size()
|
||||
identity = x
|
||||
|
||||
# input
|
||||
x = self.norm1(x)
|
||||
x = self.conv1(x).view(b, c, -1).transpose(1, 2)
|
||||
|
||||
# attention
|
||||
x = x + self.self_attn(self.norm2(x))
|
||||
x = x + self.cross_attn(self.norm3(x), context)
|
||||
x = x + self.ffn(self.norm4(x))
|
||||
|
||||
# output
|
||||
x = x.transpose(1, 2).view(b, c, h, w)
|
||||
x = self.conv2(x)
|
||||
return x + identity
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
resolution=64,
|
||||
in_dim=3,
|
||||
dim=192,
|
||||
context_dim=512,
|
||||
out_dim=3,
|
||||
dim_mult=[1, 2, 3, 5],
|
||||
num_heads=1,
|
||||
head_dim=None,
|
||||
num_res_blocks=2,
|
||||
attn_scales=[1 / 2, 1 / 4, 1 / 8],
|
||||
num_classes=1001,
|
||||
dropout=0.0):
|
||||
embed_dim = dim * 4
|
||||
super(UNet, self).__init__()
|
||||
self.resolution = resolution
|
||||
self.in_dim = in_dim
|
||||
self.dim = dim
|
||||
self.context_dim = context_dim
|
||||
self.out_dim = out_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.num_classes = num_classes
|
||||
|
||||
# params
|
||||
enc_dims = [dim * u for u in [1] + dim_mult]
|
||||
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
shortcut_dims = []
|
||||
scale = 1.0
|
||||
|
||||
# embeddings
|
||||
self.time_embedding = nn.Sequential(
|
||||
nn.Linear(dim, embed_dim), nn.SiLU(),
|
||||
nn.Linear(embed_dim, embed_dim))
|
||||
self.label_embedding = nn.Embedding(num_classes, context_dim)
|
||||
|
||||
# encoder
|
||||
self.encoder = nn.ModuleList(
|
||||
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
|
||||
shortcut_dims.append(dim)
|
||||
for i, (in_dim,
|
||||
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
||||
for j in range(num_res_blocks):
|
||||
# residual (+attention) blocks
|
||||
block = nn.ModuleList(
|
||||
[ResidualBlock(in_dim, embed_dim, out_dim, dropout)])
|
||||
if scale in attn_scales:
|
||||
block.append(
|
||||
TransformerBlock(out_dim, context_dim, num_heads))
|
||||
in_dim = out_dim
|
||||
self.encoder.append(block)
|
||||
shortcut_dims.append(out_dim)
|
||||
|
||||
# downsample
|
||||
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
|
||||
self.encoder.append(
|
||||
nn.Conv2d(out_dim, out_dim, 3, stride=2, padding=1))
|
||||
shortcut_dims.append(out_dim)
|
||||
scale /= 2.0
|
||||
|
||||
# middle
|
||||
self.middle = nn.ModuleList([
|
||||
ResidualBlock(out_dim, embed_dim, out_dim, dropout),
|
||||
TransformerBlock(out_dim, context_dim, num_heads),
|
||||
ResidualBlock(out_dim, embed_dim, out_dim, dropout)
|
||||
])
|
||||
|
||||
# decoder
|
||||
self.decoder = nn.ModuleList()
|
||||
for i, (in_dim,
|
||||
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
|
||||
for j in range(num_res_blocks + 1):
|
||||
# residual (+attention) blocks
|
||||
block = nn.ModuleList([
|
||||
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
|
||||
out_dim, dropout)
|
||||
])
|
||||
if scale in attn_scales:
|
||||
block.append(
|
||||
TransformerBlock(out_dim, context_dim, num_heads,
|
||||
dropout))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample
|
||||
if i != len(dim_mult) - 1 and j == num_res_blocks:
|
||||
block.append(
|
||||
nn.Sequential(
|
||||
Resample(scale_factor=2.0),
|
||||
nn.Conv2d(out_dim, out_dim, 3, padding=1)))
|
||||
scale *= 2.0
|
||||
self.decoder.append(block)
|
||||
|
||||
# head
|
||||
self.head = nn.Sequential(
|
||||
nn.GroupNorm(32, out_dim), nn.SiLU(),
|
||||
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.head[-1].weight)
|
||||
|
||||
def forward(self, x, t, y, concat=None):
|
||||
# embeddings
|
||||
if concat is not None:
|
||||
x = torch.cat([x, concat], dim=1)
|
||||
t = self.time_embedding(sinusoidal_embedding(t, self.dim))
|
||||
y = self.label_embedding(y)
|
||||
|
||||
# encoder
|
||||
xs = []
|
||||
for block in self.encoder:
|
||||
x = self._forward_single(block, x, t, y)
|
||||
xs.append(x)
|
||||
|
||||
# middle
|
||||
for block in self.middle:
|
||||
x = self._forward_single(block, x, t, y)
|
||||
|
||||
# decoder
|
||||
for block in self.decoder:
|
||||
x = torch.cat([x, xs.pop()], dim=1)
|
||||
x = self._forward_single(block, x, t, y)
|
||||
|
||||
# head
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
def _forward_single(self, module, x, t, y):
|
||||
if isinstance(module, ResidualBlock):
|
||||
x = module(x, t)
|
||||
elif isinstance(module, TransformerBlock):
|
||||
x = module(x, y)
|
||||
elif isinstance(module, nn.ModuleList):
|
||||
for block in module:
|
||||
x = self._forward_single(block, x, t, y)
|
||||
else:
|
||||
x = module(x)
|
||||
return x
|
||||
@@ -0,0 +1,2 @@
|
||||
from .autoencoder import * # noqa F403
|
||||
from .clip import * # noqa F403
|
||||
@@ -0,0 +1,412 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['VQAutoencoder', 'KLAutoencoder', 'PatchDiscriminator']
|
||||
|
||||
|
||||
def group_norm(dim):
|
||||
return nn.GroupNorm(32, dim, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, scale_factor):
|
||||
super(Resample, self).__init__()
|
||||
self.dim = dim
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
# layers
|
||||
if scale_factor == 2.0:
|
||||
self.resample = nn.Sequential(
|
||||
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
|
||||
nn.Conv2d(dim, dim, 3, padding=1))
|
||||
elif scale_factor == 0.5:
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=2, padding=0))
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.resample(x)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
group_norm(in_dim), nn.SiLU(),
|
||||
nn.Conv2d(in_dim, out_dim, 3, padding=1), group_norm(out_dim),
|
||||
nn.SiLU(), nn.Dropout(dropout),
|
||||
nn.Conv2d(out_dim, out_dim, 3, padding=1))
|
||||
self.shortcut = nn.Conv2d(in_dim, out_dim,
|
||||
1) if in_dim != out_dim else nn.Identity()
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.residual[-1].weight)
|
||||
|
||||
def forward(self, x):
|
||||
return self.residual(x) + self.shortcut(x)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim):
|
||||
super(AttentionBlock, self).__init__()
|
||||
self.dim = dim
|
||||
self.scale = math.pow(dim, -0.25)
|
||||
|
||||
# layers
|
||||
self.norm = group_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
b, c, h, w = x.size()
|
||||
|
||||
# compute query, key, value
|
||||
x = self.norm(x)
|
||||
q, k, v = self.to_qkv(x).view(b, c * 3, -1).chunk(3, dim=1)
|
||||
|
||||
# compute attention
|
||||
attn = torch.einsum('bci,bcj->bij', q * self.scale, k * self.scale)
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
|
||||
# gather context
|
||||
x = torch.einsum('bij,bcj->bci', attn, v)
|
||||
x = x.reshape(b, c, h, w)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
return x + identity
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=3,
|
||||
dim_mult=[1, 2, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
dropout=0.0):
|
||||
super(Encoder, self).__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
|
||||
# params
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = nn.Conv2d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
downsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
downsamples.append(Resample(out_dim, scale_factor=0.5))
|
||||
scale /= 2.0
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
|
||||
ResidualBlock(out_dim, out_dim, dropout))
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
group_norm(out_dim), nn.SiLU(),
|
||||
nn.Conv2d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.downsamples(x)
|
||||
x = self.middle(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=3,
|
||||
dim_mult=[1, 2, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
dropout=0.0):
|
||||
super(Decoder, self).__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
|
||||
# params
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2**(len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv1 = nn.Conv2d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout))
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
upsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
upsamples.append(Resample(out_dim, scale_factor=2.0))
|
||||
scale *= 2.0
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
group_norm(out_dim), nn.SiLU(),
|
||||
nn.Conv2d(out_dim, 3, 3, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.middle(x)
|
||||
x = self.upsamples(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
class VectorQuantizer(nn.Module):
|
||||
|
||||
def __init__(self, codebook_size=8192, z_dim=3, beta=0.25):
|
||||
super(VectorQuantizer, self).__init__()
|
||||
self.codebook_size = codebook_size
|
||||
self.z_dim = z_dim
|
||||
self.beta = beta
|
||||
|
||||
# init codebook
|
||||
eps = math.sqrt(1.0 / codebook_size)
|
||||
self.codebook = nn.Parameter(
|
||||
torch.empty(codebook_size, z_dim).uniform_(-eps, eps))
|
||||
|
||||
def forward(self, z):
|
||||
# preprocess
|
||||
b, c, h, w = z.size()
|
||||
flatten = z.permute(0, 2, 3, 1).reshape(-1, c)
|
||||
|
||||
# quantization
|
||||
with torch.no_grad():
|
||||
tokens = torch.cdist(flatten, self.codebook).argmin(dim=1)
|
||||
quantized = F.embedding(tokens,
|
||||
self.codebook).view(b, h, w,
|
||||
c).permute(0, 3, 1, 2)
|
||||
|
||||
# compute loss
|
||||
codebook_loss = F.mse_loss(quantized, z.detach())
|
||||
commitment_loss = F.mse_loss(quantized.detach(), z)
|
||||
loss = codebook_loss + self.beta * commitment_loss
|
||||
|
||||
# perplexity
|
||||
counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype)
|
||||
# dist.all_reduce(counts)
|
||||
p = counts / counts.sum()
|
||||
perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10)))
|
||||
|
||||
# postprocess
|
||||
tokens = tokens.view(b, h, w)
|
||||
quantized = z + (quantized - z).detach()
|
||||
return quantized, tokens, loss, perplexity
|
||||
|
||||
|
||||
class VQAutoencoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=3,
|
||||
dim_mult=[1, 2, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
dropout=0.0,
|
||||
codebook_size=8192,
|
||||
beta=0.25):
|
||||
super(VQAutoencoder, self).__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.codebook_size = codebook_size
|
||||
self.beta = beta
|
||||
|
||||
# blocks
|
||||
self.encoder = Encoder(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, dropout)
|
||||
self.conv1 = nn.Conv2d(z_dim, z_dim, 1)
|
||||
self.quantizer = VectorQuantizer(codebook_size, z_dim, beta)
|
||||
self.conv2 = nn.Conv2d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, dropout)
|
||||
|
||||
def forward(self, x):
|
||||
z = self.encoder(x)
|
||||
z = self.conv1(z)
|
||||
z, tokens, loss, perplexity = self.quantizer(z)
|
||||
z = self.conv2(z)
|
||||
x = self.decoder(z)
|
||||
return x, tokens, loss, perplexity
|
||||
|
||||
def encode(self, imgs):
|
||||
z = self.encoder(imgs)
|
||||
z = self.conv1(z)
|
||||
return z
|
||||
|
||||
def decode(self, z):
|
||||
r"""Absort the quantizer in the decoder.
|
||||
"""
|
||||
z = self.quantizer(z)[0]
|
||||
z = self.conv2(z)
|
||||
imgs = self.decoder(z)
|
||||
return imgs
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_to_tokens(self, imgs):
|
||||
# preprocess
|
||||
z = self.encoder(imgs)
|
||||
z = self.conv1(z)
|
||||
|
||||
# quantization
|
||||
b, c, h, w = z.size()
|
||||
flatten = z.permute(0, 2, 3, 1).reshape(-1, c)
|
||||
tokens = torch.cdist(flatten, self.quantizer.codebook).argmin(dim=1)
|
||||
return tokens.view(b, -1)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_from_tokens(self, tokens):
|
||||
# dequantization
|
||||
z = F.embedding(tokens, self.quantizer.codebook)
|
||||
|
||||
# postprocess
|
||||
b, l, c = z.size()
|
||||
h = w = int(math.sqrt(l))
|
||||
z = z.view(b, h, w, c).permute(0, 3, 1, 2)
|
||||
z = self.conv2(z)
|
||||
imgs = self.decoder(z)
|
||||
return imgs
|
||||
|
||||
|
||||
class KLAutoencoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
dropout=0.0):
|
||||
super(KLAutoencoder, self).__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
|
||||
# blocks
|
||||
self.encoder = Encoder(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, dropout)
|
||||
self.conv1 = nn.Conv2d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = nn.Conv2d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, dropout)
|
||||
|
||||
def forward(self, x):
|
||||
mu, log_var = self.encode(x)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x = self.decode(z)
|
||||
return x, mu, log_var
|
||||
|
||||
def encode(self, x):
|
||||
x = self.encoder(x)
|
||||
mu, log_var = self.conv1(x).chunk(2, dim=1)
|
||||
return mu, log_var
|
||||
|
||||
def decode(self, z):
|
||||
x = self.conv2(z)
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = torch.exp(0.5 * log_var)
|
||||
eps = torch.randn_like(std)
|
||||
return eps * std + mu
|
||||
|
||||
|
||||
class PatchDiscriminator(nn.Module):
|
||||
|
||||
def __init__(self, in_dim=3, dim=64, num_layers=3):
|
||||
super(PatchDiscriminator, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.dim = dim
|
||||
self.num_layers = num_layers
|
||||
|
||||
# params
|
||||
dims = [dim * min(8, 2**u) for u in range(num_layers + 1)]
|
||||
|
||||
# layers
|
||||
layers = [
|
||||
nn.Conv2d(in_dim, dim, 4, stride=2, padding=1),
|
||||
nn.LeakyReLU(0.2)
|
||||
]
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
stride = 1 if i == num_layers - 1 else 2
|
||||
layers += [
|
||||
nn.Conv2d(
|
||||
in_dim, out_dim, 4, stride=stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_dim),
|
||||
nn.LeakyReLU(0.2)
|
||||
]
|
||||
layers += [nn.Conv2d(out_dim, 1, 4, stride=1, padding=1)]
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
# initialize weights
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
def init_weights(self, m):
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.normal_(m.weight, 0.0, 0.02)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.normal_(m.weight, 1.0, 0.02)
|
||||
nn.init.zeros_(m.bias)
|
||||
418
modelscope/models/cv/image_to_image_translation/models/clip.py
Normal file
418
modelscope/models/cv/image_to_image_translation/models/clip.py
Normal file
@@ -0,0 +1,418 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import modelscope.models.cv.image_to_image_translation.ops as ops # for using differentiable all_gather
|
||||
|
||||
__all__ = [
|
||||
'CLIP', 'clip_vit_b_32', 'clip_vit_b_16', 'clip_vit_l_14',
|
||||
'clip_vit_l_14_336px', 'clip_vit_h_16'
|
||||
]
|
||||
|
||||
|
||||
def to_fp16(m):
|
||||
if isinstance(m, (nn.Linear, nn.Conv2d)):
|
||||
m.weight.data = m.weight.data.half()
|
||||
if m.bias is not None:
|
||||
m.bias.data = m.bias.data.half()
|
||||
elif hasattr(m, 'head'):
|
||||
p = getattr(m, 'head')
|
||||
p.data = p.data.half()
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
r"""Subclass of nn.LayerNorm to handle fp16.
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
return super(LayerNorm, self).forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
|
||||
assert dim % num_heads == 0
|
||||
super(SelfAttention, self).__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = 1.0 / math.sqrt(self.head_dim)
|
||||
|
||||
# layers
|
||||
self.to_qkv = nn.Linear(dim, dim * 3)
|
||||
self.attn_dropout = nn.Dropout(attn_dropout)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_dropout = nn.Dropout(proj_dropout)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
r"""x: [B, L, C].
|
||||
mask: [*, L, L].
|
||||
"""
|
||||
b, l, _, n = *x.size(), self.num_heads
|
||||
|
||||
# compute query, key, and value
|
||||
q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1)
|
||||
q = q.reshape(l, b * n, -1).transpose(0, 1)
|
||||
k = k.reshape(l, b * n, -1).transpose(0, 1)
|
||||
v = v.reshape(l, b * n, -1).transpose(0, 1)
|
||||
|
||||
# compute attention
|
||||
attn = self.scale * torch.bmm(q, k.transpose(1, 2))
|
||||
if mask is not None:
|
||||
attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf'))
|
||||
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
||||
attn = self.attn_dropout(attn)
|
||||
|
||||
# gather context
|
||||
x = torch.bmm(attn, v)
|
||||
x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
x = self.proj_dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
|
||||
super(AttentionBlock, self).__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
# layers
|
||||
self.norm1 = LayerNorm(dim)
|
||||
self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout)
|
||||
self.norm2 = LayerNorm(dim)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim),
|
||||
nn.Dropout(proj_dropout))
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x = x + self.attn(self.norm1(x), mask)
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
image_size=224,
|
||||
patch_size=16,
|
||||
dim=768,
|
||||
out_dim=512,
|
||||
num_heads=12,
|
||||
num_layers=12,
|
||||
attn_dropout=0.0,
|
||||
proj_dropout=0.0,
|
||||
embedding_dropout=0.0):
|
||||
assert image_size % patch_size == 0
|
||||
super(VisionTransformer, self).__init__()
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.num_patches = (image_size // patch_size)**2
|
||||
|
||||
# embeddings
|
||||
gain = 1.0 / math.sqrt(dim)
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
3, dim, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
||||
self.pos_embedding = nn.Parameter(
|
||||
gain * torch.randn(1, self.num_patches + 1, dim))
|
||||
self.dropout = nn.Dropout(embedding_dropout)
|
||||
|
||||
# transformer
|
||||
self.pre_norm = LayerNorm(dim)
|
||||
self.transformer = nn.Sequential(*[
|
||||
AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
self.post_norm = LayerNorm(dim)
|
||||
|
||||
# head
|
||||
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
||||
|
||||
def forward(self, x):
|
||||
b, dtype = x.size(0), self.head.dtype
|
||||
x = x.type(dtype)
|
||||
|
||||
# patch-embedding
|
||||
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c]
|
||||
x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x],
|
||||
dim=1)
|
||||
x = self.dropout(x + self.pos_embedding.type(dtype))
|
||||
x = self.pre_norm(x)
|
||||
|
||||
# transformer
|
||||
x = self.transformer(x)
|
||||
|
||||
# head
|
||||
x = self.post_norm(x)
|
||||
x = torch.mm(x[:, 0, :], self.head)
|
||||
return x
|
||||
|
||||
def fp16(self):
|
||||
return self.apply(to_fp16)
|
||||
|
||||
|
||||
class TextTransformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
text_len,
|
||||
dim=512,
|
||||
out_dim=512,
|
||||
num_heads=8,
|
||||
num_layers=12,
|
||||
attn_dropout=0.0,
|
||||
proj_dropout=0.0,
|
||||
embedding_dropout=0.0):
|
||||
super(TextTransformer, self).__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.text_len = text_len
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
|
||||
# embeddings
|
||||
self.token_embedding = nn.Embedding(vocab_size, dim)
|
||||
self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim))
|
||||
self.dropout = nn.Dropout(embedding_dropout)
|
||||
|
||||
# transformer
|
||||
self.transformer = nn.ModuleList([
|
||||
AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
self.norm = LayerNorm(dim)
|
||||
|
||||
# head
|
||||
gain = 1.0 / math.sqrt(dim)
|
||||
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
||||
|
||||
# causal attention mask
|
||||
self.register_buffer('attn_mask',
|
||||
torch.tril(torch.ones(1, text_len, text_len)))
|
||||
|
||||
def forward(self, x):
|
||||
eot, dtype = x.argmax(dim=-1), self.head.dtype
|
||||
|
||||
# embeddings
|
||||
x = self.dropout(
|
||||
self.token_embedding(x).type(dtype)
|
||||
+ self.pos_embedding.type(dtype))
|
||||
|
||||
# transformer
|
||||
for block in self.transformer:
|
||||
x = block(x, self.attn_mask)
|
||||
|
||||
# head
|
||||
x = self.norm(x)
|
||||
x = torch.mm(x[torch.arange(x.size(0)), eot], self.head)
|
||||
return x
|
||||
|
||||
def fp16(self):
|
||||
return self.apply(to_fp16)
|
||||
|
||||
|
||||
class CLIP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
embed_dim=512,
|
||||
image_size=224,
|
||||
patch_size=16,
|
||||
vision_dim=768,
|
||||
vision_heads=12,
|
||||
vision_layers=12,
|
||||
vocab_size=49408,
|
||||
text_len=77,
|
||||
text_dim=512,
|
||||
text_heads=8,
|
||||
text_layers=12,
|
||||
attn_dropout=0.0,
|
||||
proj_dropout=0.0,
|
||||
embedding_dropout=0.0):
|
||||
super(CLIP, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.vision_dim = vision_dim
|
||||
self.vision_heads = vision_heads
|
||||
self.vision_layers = vision_layers
|
||||
self.vocab_size = vocab_size
|
||||
self.text_len = text_len
|
||||
self.text_dim = text_dim
|
||||
self.text_heads = text_heads
|
||||
self.text_layers = text_layers
|
||||
|
||||
# models
|
||||
self.visual = VisionTransformer(
|
||||
image_size=image_size,
|
||||
patch_size=patch_size,
|
||||
dim=vision_dim,
|
||||
out_dim=embed_dim,
|
||||
num_heads=vision_heads,
|
||||
num_layers=vision_layers,
|
||||
attn_dropout=attn_dropout,
|
||||
proj_dropout=proj_dropout,
|
||||
embedding_dropout=embedding_dropout)
|
||||
self.textual = TextTransformer(
|
||||
vocab_size=vocab_size,
|
||||
text_len=text_len,
|
||||
dim=text_dim,
|
||||
out_dim=embed_dim,
|
||||
num_heads=text_heads,
|
||||
num_layers=text_layers,
|
||||
attn_dropout=attn_dropout,
|
||||
proj_dropout=proj_dropout,
|
||||
embedding_dropout=embedding_dropout)
|
||||
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
||||
|
||||
def forward(self, imgs, txt_tokens):
|
||||
r"""imgs: [B, C, H, W] of torch.float32.
|
||||
txt_tokens: [B, T] of torch.long.
|
||||
"""
|
||||
xi = self.visual(imgs)
|
||||
xt = self.textual(txt_tokens)
|
||||
|
||||
# normalize features
|
||||
xi = F.normalize(xi, p=2, dim=1)
|
||||
xt = F.normalize(xt, p=2, dim=1)
|
||||
|
||||
# gather features from all ranks
|
||||
full_xi = ops.diff_all_gather(xi)
|
||||
full_xt = ops.diff_all_gather(xt)
|
||||
|
||||
# logits
|
||||
scale = self.log_scale.exp()
|
||||
logits_i2t = scale * torch.mm(xi, full_xt.t())
|
||||
logits_t2i = scale * torch.mm(xt, full_xi.t())
|
||||
|
||||
# labels
|
||||
labels = torch.arange(
|
||||
len(xi) * ops.get_rank(),
|
||||
len(xi) * (ops.get_rank() + 1),
|
||||
dtype=torch.long,
|
||||
device=xi.device)
|
||||
return logits_i2t, logits_t2i, labels
|
||||
|
||||
def init_weights(self):
|
||||
# embeddings
|
||||
nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
|
||||
nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1)
|
||||
|
||||
# attentions
|
||||
for modality in ['visual', 'textual']:
|
||||
dim = self.vision_dim if modality == 'visual' else 'textual'
|
||||
transformer = getattr(self, modality).transformer
|
||||
proj_gain = (1.0 / math.sqrt(dim)) * (
|
||||
1.0 / math.sqrt(2 * transformer.num_layers))
|
||||
attn_gain = 1.0 / math.sqrt(dim)
|
||||
mlp_gain = 1.0 / math.sqrt(2.0 * dim)
|
||||
for block in transformer.layers:
|
||||
nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
|
||||
nn.init.normal_(block.attn.proj.weight, std=proj_gain)
|
||||
nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
|
||||
nn.init.normal_(block.mlp[2].weight, std=proj_gain)
|
||||
|
||||
def param_groups(self):
|
||||
groups = [{
|
||||
'params': [
|
||||
p for n, p in self.named_parameters()
|
||||
if 'norm' in n or n.endswith('bias')
|
||||
],
|
||||
'weight_decay':
|
||||
0.0
|
||||
}, {
|
||||
'params': [
|
||||
p for n, p in self.named_parameters()
|
||||
if not ('norm' in n or n.endswith('bias'))
|
||||
]
|
||||
}]
|
||||
return groups
|
||||
|
||||
def fp16(self):
|
||||
return self.apply(to_fp16)
|
||||
|
||||
|
||||
def clip_vit_b_32(**kwargs):
|
||||
return CLIP(
|
||||
embed_dim=512,
|
||||
image_size=224,
|
||||
patch_size=32,
|
||||
vision_dim=768,
|
||||
vision_heads=12,
|
||||
vision_layers=12,
|
||||
text_dim=512,
|
||||
text_heads=8,
|
||||
text_layers=12,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def clip_vit_b_16(**kwargs):
|
||||
return CLIP(
|
||||
embed_dim=512,
|
||||
image_size=224,
|
||||
patch_size=16,
|
||||
vision_dim=768,
|
||||
vision_heads=12,
|
||||
vision_layers=12,
|
||||
text_dim=512,
|
||||
text_heads=8,
|
||||
text_layers=12,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def clip_vit_l_14(**kwargs):
|
||||
return CLIP(
|
||||
embed_dim=768,
|
||||
image_size=224,
|
||||
patch_size=14,
|
||||
vision_dim=1024,
|
||||
vision_heads=16,
|
||||
vision_layers=24,
|
||||
text_dim=768,
|
||||
text_heads=12,
|
||||
text_layers=12,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def clip_vit_l_14_336px(**kwargs):
|
||||
return CLIP(
|
||||
embed_dim=768,
|
||||
image_size=336,
|
||||
patch_size=14,
|
||||
vision_dim=1024,
|
||||
vision_heads=16,
|
||||
vision_layers=24,
|
||||
text_dim=768,
|
||||
text_heads=12,
|
||||
text_layers=12,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def clip_vit_h_16(**kwargs):
|
||||
return CLIP(
|
||||
embed_dim=1024,
|
||||
image_size=256,
|
||||
patch_size=16,
|
||||
vision_dim=1280,
|
||||
vision_heads=16,
|
||||
vision_layers=32,
|
||||
text_dim=1024,
|
||||
text_heads=16,
|
||||
text_layers=24,
|
||||
**kwargs)
|
||||
@@ -0,0 +1,8 @@
|
||||
from .degradation import * # noqa F403
|
||||
from .diffusion import * # noqa F403
|
||||
from .losses import * # noqa F403
|
||||
from .metrics import * # noqa F403
|
||||
from .random_color import * # noqa F403
|
||||
from .random_mask import * # noqa F403
|
||||
from .svd import * # noqa F403
|
||||
from .utils import * # noqa F403
|
||||
663
modelscope/models/cv/image_to_image_translation/ops/apps.py
Normal file
663
modelscope/models/cv/image_to_image_translation/ops/apps.py
Normal file
@@ -0,0 +1,663 @@
|
||||
# APPs that facilitate the use of pretrained neural networks.
|
||||
|
||||
import os.path as osp
|
||||
|
||||
import artist.data as data
|
||||
import artist.models as models
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as T
|
||||
from artist import DOWNLOAD_TO_CACHE
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from .utils import parallel, read_image
|
||||
|
||||
__all__ = [
|
||||
'FeatureExtractor', 'Classifier', 'Text2Image', 'Sole2Shoe', 'ImageParser',
|
||||
'TextImageMatch', 'taobao_feature_extractor', 'singleton_classifier',
|
||||
'orientation_classifier', 'fashion_text2image', 'mindalle_text2image',
|
||||
'sole2shoe', 'sole_parser', 'sod_foreground_parser',
|
||||
'fashion_text_image_match'
|
||||
]
|
||||
|
||||
|
||||
class ImageFolder(Dataset):
|
||||
|
||||
def __init__(self, paths, transforms=None):
|
||||
self.paths = paths
|
||||
self.transforms = transforms
|
||||
|
||||
def __getitem__(self, index):
|
||||
img = read_image(self.paths[index])
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
if self.transforms is not None:
|
||||
img = self.transforms(img)
|
||||
return img
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
|
||||
class FeatureExtractor(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model='InceptionV1',
|
||||
checkpoint='models/inception-v1/1218shoes.v9_7.140.0.1520000',
|
||||
resolution=224,
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
batch_size=64,
|
||||
device=torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125
|
||||
self.resolution = resolution
|
||||
self.batch_size = batch_size
|
||||
self.device = device
|
||||
|
||||
# init model
|
||||
self.net = getattr(
|
||||
models,
|
||||
model)(num_classes=None).eval().requires_grad_(False).to(device)
|
||||
self.net.load_state_dict(
|
||||
torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device))
|
||||
|
||||
# data transforms
|
||||
self.transforms = T.Compose([
|
||||
data.PadToSquare(),
|
||||
T.Resize(resolution),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean, std)
|
||||
])
|
||||
|
||||
def __call__(self, imgs, num_workers=0):
|
||||
r"""imgs: Either a PIL.Image or a list of PIL.Image instances.
|
||||
"""
|
||||
# preprocess
|
||||
if isinstance(imgs, Image.Image):
|
||||
imgs = [imgs]
|
||||
assert isinstance(imgs,
|
||||
(tuple, list)) and isinstance(imgs[0], Image.Image)
|
||||
imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0)
|
||||
|
||||
# forward
|
||||
feats = []
|
||||
for batch in imgs.split(self.batch_size, dim=0):
|
||||
batch = batch.to(self.device, non_blocking=True)
|
||||
feats.append(self.net(batch))
|
||||
return torch.cat(feats, dim=0)
|
||||
|
||||
def batch_process(self, paths):
|
||||
# init dataloader
|
||||
dataloader = DataLoader(
|
||||
dataset=ImageFolder(paths, self.transforms),
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
pin_memory=True,
|
||||
num_workers=8,
|
||||
prefetch_factor=2)
|
||||
|
||||
# forward
|
||||
feats = []
|
||||
for step, batch in enumerate(dataloader, 1):
|
||||
print(f'Step: {step}/{len(dataloader)}', flush=True)
|
||||
batch = batch.to(self.device, non_blocking=True)
|
||||
feats.append(self.net(batch))
|
||||
return torch.cat(feats)
|
||||
|
||||
|
||||
class Classifier(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model='InceptionV1',
|
||||
checkpoint='models/classifier/shoes+apparel+bag-sgdetect-211230.pth',
|
||||
num_classes=1,
|
||||
resolution=224,
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
batch_size=64,
|
||||
device=torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125
|
||||
self.num_classes = num_classes
|
||||
self.resolution = resolution
|
||||
self.batch_size = batch_size
|
||||
self.device = device
|
||||
|
||||
# init model
|
||||
self.net = getattr(models, model)(
|
||||
num_classes=num_classes).eval().requires_grad_(False).to(device)
|
||||
self.net.load_state_dict(
|
||||
torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device))
|
||||
|
||||
# data transforms
|
||||
self.transforms = T.Compose([
|
||||
data.PadToSquare(),
|
||||
T.Resize(resolution),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean, std)
|
||||
])
|
||||
|
||||
def __call__(self, imgs, num_workers=0):
|
||||
r"""imgs: Either a PIL.Image or a list of PIL.Image instances.
|
||||
"""
|
||||
# preprocess
|
||||
if isinstance(imgs, Image.Image):
|
||||
imgs = [imgs]
|
||||
assert isinstance(imgs,
|
||||
(tuple, list)) and isinstance(imgs[0], Image.Image)
|
||||
imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0)
|
||||
|
||||
# forward
|
||||
scores = []
|
||||
for batch in imgs.split(self.batch_size, dim=0):
|
||||
batch = batch.to(self.device, non_blocking=True)
|
||||
logits = self.net(batch)
|
||||
scores.append(logits.sigmoid() if self.num_classes == # noqa W504
|
||||
1 else logits.softmax(dim=1))
|
||||
return torch.cat(scores, dim=0)
|
||||
|
||||
|
||||
class Text2Image(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vqgan_dim=128,
|
||||
vqgan_z_dim=256,
|
||||
vqgan_dim_mult=[1, 1, 2, 2, 4],
|
||||
vqgan_num_res_blocks=2,
|
||||
vqgan_attn_scales=[1.0 / 16],
|
||||
vqgan_codebook_size=975,
|
||||
vqgan_beta=0.25,
|
||||
gpt_txt_vocab_size=21128,
|
||||
gpt_txt_seq_len=64,
|
||||
gpt_img_seq_len=1024,
|
||||
gpt_dim=1024,
|
||||
gpt_num_heads=16,
|
||||
gpt_num_layers=24,
|
||||
vqgan_checkpoint='models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth',
|
||||
gpt_checkpoint='models/seq2seq_gpt/text2image_shoes+apparels_step400k.pth',
|
||||
tokenizer=data.BertTokenizer(name='bert-base-chinese', length=64),
|
||||
batch_size=16,
|
||||
device=torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125
|
||||
self.tokenizer = tokenizer
|
||||
self.batch_size = batch_size
|
||||
self.device = device
|
||||
|
||||
# init VQGAN model
|
||||
self.vqgan = models.VQGAN(
|
||||
dim=vqgan_dim,
|
||||
z_dim=vqgan_z_dim,
|
||||
dim_mult=vqgan_dim_mult,
|
||||
num_res_blocks=vqgan_num_res_blocks,
|
||||
attn_scales=vqgan_attn_scales,
|
||||
codebook_size=vqgan_codebook_size,
|
||||
beta=vqgan_beta).eval().requires_grad_(False).to(device)
|
||||
self.vqgan.load_state_dict(
|
||||
torch.load(
|
||||
DOWNLOAD_TO_CACHE(vqgan_checkpoint), map_location=device))
|
||||
|
||||
# init GPT model
|
||||
self.gpt = models.Seq2SeqGPT(
|
||||
src_vocab_size=gpt_txt_vocab_size,
|
||||
tar_vocab_size=vqgan_codebook_size,
|
||||
src_seq_len=gpt_txt_seq_len,
|
||||
tar_seq_len=gpt_img_seq_len,
|
||||
dim=gpt_dim,
|
||||
num_heads=gpt_num_heads,
|
||||
num_layers=gpt_num_layers).eval().requires_grad_(False).to(device)
|
||||
self.gpt.load_state_dict(
|
||||
torch.load(DOWNLOAD_TO_CACHE(gpt_checkpoint), map_location=device))
|
||||
|
||||
def __call__(self,
|
||||
txts,
|
||||
top_k=64,
|
||||
top_p=None,
|
||||
temperature=0.6,
|
||||
use_fp16=True):
|
||||
# preprocess
|
||||
if isinstance(txts, str):
|
||||
txts = [txts]
|
||||
assert isinstance(txts, (tuple, list)) and isinstance(txts[0], str)
|
||||
txt_tokens = torch.LongTensor([self.tokenizer(u) for u in txts])
|
||||
|
||||
# forward
|
||||
out_imgs = []
|
||||
for batch in txt_tokens.split(self.batch_size, dim=0):
|
||||
# sample
|
||||
batch = batch.to(self.device, non_blocking=True)
|
||||
with amp.autocast(enabled=use_fp16):
|
||||
img_tokens = self.gpt.sample(batch, top_k, top_p, temperature)
|
||||
|
||||
# decode
|
||||
imgs = self.vqgan.decode_from_tokens(img_tokens)
|
||||
imgs = self._whiten_borders(imgs)
|
||||
imgs = imgs.clamp_(-1, 1).add_(1).mul_(125.0).permute(
|
||||
0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
||||
imgs = [Image.fromarray(u) for u in imgs]
|
||||
|
||||
# append
|
||||
out_imgs += imgs
|
||||
return out_imgs
|
||||
|
||||
def _whiten_borders(self, imgs):
|
||||
r"""Remove border artifacts.
|
||||
"""
|
||||
imgs[:, :, :18, :] = 1
|
||||
imgs[:, :, :, :18] = 1
|
||||
imgs[:, :, -18:, :] = 1
|
||||
imgs[:, :, :, -18:] = 1
|
||||
return imgs
|
||||
|
||||
|
||||
class Sole2Shoe(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vqgan_dim=128,
|
||||
vqgan_z_dim=256,
|
||||
vqgan_dim_mult=[1, 1, 2, 2, 4],
|
||||
vqgan_num_res_blocks=2,
|
||||
vqgan_attn_scales=[1.0 / 16],
|
||||
vqgan_codebook_size=975,
|
||||
vqgan_beta=0.25,
|
||||
src_resolution=256,
|
||||
tar_resolution=512,
|
||||
gpt_dim=1024,
|
||||
gpt_num_heads=16,
|
||||
gpt_num_layers=24,
|
||||
vqgan_checkpoint='models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth',
|
||||
gpt_checkpoint='models/seq2seq_gpt/sole2shoe-step300k-220104.pth',
|
||||
batch_size=12,
|
||||
device=torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125
|
||||
self.batch_size = batch_size
|
||||
self.device = device
|
||||
src_seq_len = (src_resolution // 16)**2
|
||||
tar_seq_len = (tar_resolution // 16)**2
|
||||
|
||||
# init VQGAN model
|
||||
self.vqgan = models.VQGAN(
|
||||
dim=vqgan_dim,
|
||||
z_dim=vqgan_z_dim,
|
||||
dim_mult=vqgan_dim_mult,
|
||||
num_res_blocks=vqgan_num_res_blocks,
|
||||
attn_scales=vqgan_attn_scales,
|
||||
codebook_size=vqgan_codebook_size,
|
||||
beta=vqgan_beta).eval().requires_grad_(False).to(device)
|
||||
self.vqgan.load_state_dict(
|
||||
torch.load(
|
||||
DOWNLOAD_TO_CACHE(vqgan_checkpoint), map_location=device))
|
||||
|
||||
# init GPT model
|
||||
self.gpt = models.Seq2SeqGPT(
|
||||
src_vocab_size=vqgan_codebook_size,
|
||||
tar_vocab_size=vqgan_codebook_size,
|
||||
src_seq_len=src_seq_len,
|
||||
tar_seq_len=tar_seq_len,
|
||||
dim=gpt_dim,
|
||||
num_heads=gpt_num_heads,
|
||||
num_layers=gpt_num_layers).eval().requires_grad_(False).to(device)
|
||||
self.gpt.load_state_dict(
|
||||
torch.load(DOWNLOAD_TO_CACHE(gpt_checkpoint), map_location=device))
|
||||
|
||||
# data transforms
|
||||
self.transforms = T.Compose([
|
||||
data.PadToSquare(),
|
||||
T.Resize(src_resolution),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
])
|
||||
|
||||
def __call__(self,
|
||||
sole_imgs,
|
||||
top_k=64,
|
||||
top_p=None,
|
||||
temperature=0.6,
|
||||
use_fp16=True,
|
||||
num_workers=0):
|
||||
# preprocess
|
||||
if isinstance(sole_imgs, Image.Image):
|
||||
sole_imgs = [sole_imgs]
|
||||
assert isinstance(sole_imgs, (tuple, list)) and isinstance(
|
||||
sole_imgs[0], Image.Image)
|
||||
sole_imgs = torch.stack(
|
||||
parallel(self.transforms, sole_imgs, num_workers), dim=0)
|
||||
|
||||
# forward
|
||||
out_imgs = []
|
||||
for batch in sole_imgs.split(self.batch_size, dim=0):
|
||||
# sample
|
||||
batch = batch.to(self.device)
|
||||
with amp.autocast(enabled=use_fp16):
|
||||
sole_tokens = self.vqgan.encode_to_tokens(batch)
|
||||
shoe_tokens = self.gpt.sample(sole_tokens, top_k, top_p,
|
||||
temperature)
|
||||
|
||||
# decode
|
||||
shoe_imgs = self.vqgan.decode_from_tokens(shoe_tokens)
|
||||
shoe_imgs = self._whiten_borders(shoe_imgs)
|
||||
shoe_imgs = shoe_imgs.clamp_(-1, 1).add_(1).mul_(125.0).permute(
|
||||
0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
||||
shoe_imgs = [Image.fromarray(u) for u in shoe_imgs]
|
||||
|
||||
# append
|
||||
out_imgs += shoe_imgs
|
||||
return out_imgs
|
||||
|
||||
def _whiten_borders(self, imgs):
|
||||
r"""Remove border artifacts.
|
||||
"""
|
||||
imgs[:, :, :18, :] = 1
|
||||
imgs[:, :, :, :18] = 1
|
||||
imgs[:, :, -18:, :] = 1
|
||||
imgs[:, :, :, -18:] = 1
|
||||
return imgs
|
||||
|
||||
|
||||
class ImageParser(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model='SPNet',
|
||||
num_classes=2,
|
||||
resolution=800,
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
model_with_softmax=False,
|
||||
checkpoint='models/spnet/sole_segmentation_211219.pth',
|
||||
batch_size=16,
|
||||
device=torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125
|
||||
self.batch_size = batch_size
|
||||
self.device = device
|
||||
|
||||
# init model
|
||||
if checkpoint.endswith('.pt'):
|
||||
self.net = torch.jit.load(
|
||||
DOWNLOAD_TO_CACHE(checkpoint)).eval().to(device)
|
||||
[p.requires_grad_(False) for p in self.net.parameters()]
|
||||
else:
|
||||
self.net = getattr(models, model)(
|
||||
num_classes=num_classes,
|
||||
pretrained=False).eval().requires_grad_(False).to(device)
|
||||
self.net.load_state_dict(
|
||||
torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device))
|
||||
self.softmax = (lambda x, dim: x) if model_with_softmax else F.softmax
|
||||
|
||||
# data transforms
|
||||
self.transforms = T.Compose([
|
||||
data.PadToSquare(),
|
||||
T.Resize(resolution),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean, std)
|
||||
])
|
||||
|
||||
def __call__(self, imgs, num_workers=0):
|
||||
# preprocess
|
||||
if isinstance(imgs, Image.Image):
|
||||
imgs = [imgs]
|
||||
assert isinstance(imgs,
|
||||
(tuple, list)) and isinstance(imgs[0], Image.Image)
|
||||
sizes = [u.size for u in imgs]
|
||||
imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0)
|
||||
|
||||
# forward
|
||||
masks = []
|
||||
for batch in imgs.split(self.batch_size, dim=0):
|
||||
batch = batch.to(self.device, non_blocking=True)
|
||||
masks.append(self.softmax(self.net(batch), dim=1))
|
||||
|
||||
# postprocess
|
||||
masks = torch.cat(masks, dim=0).unsqueeze(1)
|
||||
masks = [
|
||||
F.interpolate(u, v, mode='bilinear', align_corners=False)
|
||||
for u, v in zip(masks, sizes)
|
||||
]
|
||||
return masks
|
||||
|
||||
|
||||
class TextImageMatch(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=512,
|
||||
image_size=224,
|
||||
patch_size=32,
|
||||
vision_dim=768,
|
||||
vision_heads=12,
|
||||
vision_layers=12,
|
||||
vocab_size=21128,
|
||||
text_len=77,
|
||||
text_dim=512,
|
||||
text_heads=8,
|
||||
text_layers=12,
|
||||
mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
std=[0.26862954, 0.26130258, 0.27577711],
|
||||
checkpoint='models/clip/clip_shoes+apparels_step84k_210105.pth',
|
||||
tokenizer=data.BertTokenizer(name='bert-base-chinese', length=77),
|
||||
batch_size=64,
|
||||
device=torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')): # noqa E125
|
||||
self.tokenizer = tokenizer
|
||||
self.batch_size = batch_size
|
||||
self.device = device
|
||||
|
||||
# init model
|
||||
self.clip = models.CLIP(
|
||||
embed_dim=embed_dim,
|
||||
image_size=image_size,
|
||||
patch_size=patch_size,
|
||||
vision_dim=vision_dim,
|
||||
vision_heads=vision_heads,
|
||||
vision_layers=vision_layers,
|
||||
vocab_size=vocab_size,
|
||||
text_len=text_len,
|
||||
text_dim=text_dim,
|
||||
text_heads=text_heads,
|
||||
text_layers=text_layers).eval().requires_grad_(False).to(device)
|
||||
self.clip.load_state_dict(
|
||||
torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device))
|
||||
|
||||
# transforms
|
||||
scale_size = int(image_size * 8 / 7)
|
||||
self.transforms = T.Compose([
|
||||
data.PadToSquare(),
|
||||
T.Resize(scale_size),
|
||||
T.CenterCrop(image_size),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean, std)
|
||||
])
|
||||
|
||||
def __call__(self, imgs, txts, num_workers=0):
|
||||
# preprocess
|
||||
assert isinstance(imgs,
|
||||
(tuple, list)) and isinstance(imgs[0], Image.Image)
|
||||
assert isinstance(txts, (tuple, list)) and isinstance(txts[0], str)
|
||||
txt_tokens = torch.LongTensor([self.tokenizer(u) for u in txts])
|
||||
imgs = torch.stack(parallel(self.transforms, imgs, num_workers), dim=0)
|
||||
|
||||
# forward
|
||||
scores = []
|
||||
for img_batch, txt_batch in zip(
|
||||
imgs.split(self.batch_size, dim=0),
|
||||
txt_tokens.split(self.batch_size, dim=0)):
|
||||
img_batch = img_batch.to(self.device)
|
||||
txt_batch = txt_batch.to(self.device)
|
||||
xi = F.normalize(self.clip.visual(img_batch), p=2, dim=1)
|
||||
xt = F.normalize(self.clip.textual(txt_batch), p=2, dim=1)
|
||||
scores.append((xi * xt).sum(dim=1))
|
||||
return torch.cat(scores, dim=0)
|
||||
|
||||
|
||||
def taobao_feature_extractor(category='shoes', **kwargs):
|
||||
r"""Pretrained taobao-search feature extractors.
|
||||
"""
|
||||
assert category in ['softall', 'shoes', 'bag']
|
||||
checkpoint = osp.join(
|
||||
'models/inception-v1', {
|
||||
'softall': '1214softall_10.10.0.5000',
|
||||
'shoes': '1218shoes.v9_7.140.0.1520000',
|
||||
'bag': '0926bag.v9_6.29.0.140000'
|
||||
}[category])
|
||||
app = FeatureExtractor(
|
||||
model='InceptionV1',
|
||||
checkpoint=checkpoint,
|
||||
resolution=224,
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
**kwargs)
|
||||
return app
|
||||
|
||||
|
||||
def singleton_classifier(**kwargs):
|
||||
r"""Pretrained classifier that finds single-object images.
|
||||
Supports shoes, apparel, and bag images.
|
||||
"""
|
||||
app = Classifier(
|
||||
model='InceptionV1',
|
||||
checkpoint='models/classifier/shoes+apparel+bag-sgdetect-211230.pth',
|
||||
num_classes=1,
|
||||
resolution=224,
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
**kwargs)
|
||||
return app
|
||||
|
||||
|
||||
def orientation_classifier(**kwargs):
|
||||
r"""Shoes orientation classifier.
|
||||
"""
|
||||
app = Classifier(
|
||||
model='InceptionV1',
|
||||
checkpoint='models/classifier/shoes-oriendetect-20211026.pth',
|
||||
num_classes=1,
|
||||
resolution=224,
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
**kwargs)
|
||||
return app
|
||||
|
||||
|
||||
def fashion_text2image(**kwargs):
|
||||
r"""Fashion text-to-image generator.
|
||||
Supports shoe and apparel image generation.
|
||||
"""
|
||||
app = Text2Image(
|
||||
vqgan_dim=128,
|
||||
vqgan_z_dim=256,
|
||||
vqgan_dim_mult=[1, 1, 2, 2, 4],
|
||||
vqgan_num_res_blocks=2,
|
||||
vqgan_attn_scales=[1.0 / 16],
|
||||
vqgan_codebook_size=975,
|
||||
vqgan_beta=0.25,
|
||||
gpt_txt_vocab_size=21128,
|
||||
gpt_txt_seq_len=64,
|
||||
gpt_img_seq_len=1024,
|
||||
gpt_dim=1024,
|
||||
gpt_num_heads=16,
|
||||
gpt_num_layers=24,
|
||||
vqgan_checkpoint= # noqa E251
|
||||
'models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth',
|
||||
gpt_checkpoint= # noqa E251
|
||||
'models/seq2seq_gpt/text2image_shoes+apparels_step400k.pth',
|
||||
tokenizer=data.BertTokenizer(name='bert-base-chinese', length=64),
|
||||
**kwargs)
|
||||
return app
|
||||
|
||||
|
||||
def mindalle_text2image(**kwargs):
|
||||
r"""Pretrained text2image generator with weights copied from minDALL-E.
|
||||
"""
|
||||
app = Text2Image(
|
||||
vqgan_dim=128,
|
||||
vqgan_z_dim=256,
|
||||
vqgan_dim_mult=[1, 1, 2, 2, 4],
|
||||
vqgan_num_res_blocks=2,
|
||||
vqgan_attn_scales=[1.0 / 16],
|
||||
vqgan_codebook_size=16384,
|
||||
vqgan_beta=0.25,
|
||||
gpt_txt_vocab_size=16384,
|
||||
gpt_txt_seq_len=64,
|
||||
gpt_img_seq_len=256,
|
||||
gpt_dim=1536,
|
||||
gpt_num_heads=24,
|
||||
gpt_num_layers=42,
|
||||
vqgan_checkpoint='models/minDALLE/1.3B_vqgan.pth',
|
||||
gpt_checkpoint='models/minDALLE/1.3B_gpt.pth',
|
||||
tokenizer=data.BPETokenizer(length=64),
|
||||
**kwargs)
|
||||
return app
|
||||
|
||||
|
||||
def sole2shoe(**kwargs):
|
||||
app = Sole2Shoe(
|
||||
vqgan_dim=128,
|
||||
vqgan_z_dim=256,
|
||||
vqgan_dim_mult=[1, 1, 2, 2, 4],
|
||||
vqgan_num_res_blocks=2,
|
||||
vqgan_attn_scales=[1.0 / 16],
|
||||
vqgan_codebook_size=975,
|
||||
vqgan_beta=0.25,
|
||||
src_resolution=256,
|
||||
tar_resolution=512,
|
||||
gpt_dim=1024,
|
||||
gpt_num_heads=16,
|
||||
gpt_num_layers=24,
|
||||
vqgan_checkpoint= # noqa E251
|
||||
'models/vqgan/vqgan_shoes+apparels_step10k_vocab975.pth',
|
||||
gpt_checkpoint='models/seq2seq_gpt/sole2shoe-step300k-220104.pth',
|
||||
**kwargs)
|
||||
return app
|
||||
|
||||
|
||||
def sole_parser(**kwargs):
|
||||
app = ImageParser(
|
||||
model='SPNet',
|
||||
num_classes=2,
|
||||
resolution=800,
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
model_with_softmax=False,
|
||||
checkpoint='models/spnet/sole_segmentation_211219.pth',
|
||||
**kwargs)
|
||||
return app
|
||||
|
||||
|
||||
def sod_foreground_parser(**kwargs):
|
||||
app = ImageParser(
|
||||
model=None,
|
||||
num_classes=None,
|
||||
resolution=448,
|
||||
mean=[0.488431, 0.466275, 0.403686],
|
||||
std=[0.222627, 0.21949, 0.22549],
|
||||
model_with_softmax=True,
|
||||
checkpoint='models/semseg/sod_model_20201228.pt',
|
||||
**kwargs)
|
||||
return app
|
||||
|
||||
|
||||
def fashion_text_image_match(**kwargs):
|
||||
app = TextImageMatch(
|
||||
embed_dim=512,
|
||||
image_size=224,
|
||||
patch_size=32,
|
||||
vision_dim=768,
|
||||
vision_heads=12,
|
||||
vision_layers=12,
|
||||
vocab_size=21128,
|
||||
text_len=77,
|
||||
text_dim=512,
|
||||
text_heads=8,
|
||||
text_layers=12,
|
||||
mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
std=[0.26862954, 0.26130258, 0.27577711],
|
||||
checkpoint='models/clip/clip_shoes+apparels_step84k_210105.pth',
|
||||
tokenizer=data.BertTokenizer(name='bert-base-chinese', length=77),
|
||||
**kwargs)
|
||||
return app
|
||||
1074
modelscope/models/cv/image_to_image_translation/ops/degradation.py
Normal file
1074
modelscope/models/cv/image_to_image_translation/ops/degradation.py
Normal file
File diff suppressed because it is too large
Load Diff
598
modelscope/models/cv/image_to_image_translation/ops/diffusion.py
Normal file
598
modelscope/models/cv/image_to_image_translation/ops/diffusion.py
Normal file
@@ -0,0 +1,598 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from .losses import discretized_gaussian_log_likelihood, kl_divergence
|
||||
|
||||
__all__ = ['GaussianDiffusion', 'beta_schedule']
|
||||
|
||||
|
||||
def _i(tensor, t, x):
|
||||
r"""Index tensor using t and format the output according to x.
|
||||
"""
|
||||
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
|
||||
return tensor[t].view(shape).to(x)
|
||||
|
||||
|
||||
def beta_schedule(schedule,
|
||||
num_timesteps=1000,
|
||||
init_beta=None,
|
||||
last_beta=None):
|
||||
if schedule == 'linear':
|
||||
scale = 1000.0 / num_timesteps
|
||||
init_beta = init_beta or scale * 0.0001
|
||||
last_beta = last_beta or scale * 0.02
|
||||
return torch.linspace(
|
||||
init_beta, last_beta, num_timesteps, dtype=torch.float64)
|
||||
elif schedule == 'quadratic':
|
||||
init_beta = init_beta or 0.0015
|
||||
last_beta = last_beta or 0.0195
|
||||
return torch.linspace(
|
||||
init_beta**0.5, last_beta**0.5, num_timesteps,
|
||||
dtype=torch.float64)**2
|
||||
elif schedule == 'cosine':
|
||||
betas = []
|
||||
for step in range(num_timesteps):
|
||||
t1 = step / num_timesteps
|
||||
t2 = (step + 1) / num_timesteps
|
||||
|
||||
# fn = lambda u: math.cos((u + 0.008) / 1.008 * math.pi / 2)**2
|
||||
def fn(u):
|
||||
return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2
|
||||
|
||||
betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
|
||||
return torch.tensor(betas, dtype=torch.float64)
|
||||
else:
|
||||
raise ValueError(f'Unsupported schedule: {schedule}')
|
||||
|
||||
|
||||
class GaussianDiffusion(object):
|
||||
|
||||
def __init__(self,
|
||||
betas,
|
||||
mean_type='eps',
|
||||
var_type='learned_range',
|
||||
loss_type='mse',
|
||||
rescale_timesteps=False):
|
||||
# check input
|
||||
if not isinstance(betas, torch.DoubleTensor):
|
||||
betas = torch.tensor(betas, dtype=torch.float64)
|
||||
assert min(betas) > 0 and max(betas) <= 1
|
||||
assert mean_type in ['x0', 'x_{t-1}', 'eps']
|
||||
assert var_type in [
|
||||
'learned', 'learned_range', 'fixed_large', 'fixed_small'
|
||||
]
|
||||
assert loss_type in [
|
||||
'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1'
|
||||
]
|
||||
self.betas = betas
|
||||
self.num_timesteps = len(betas)
|
||||
self.mean_type = mean_type
|
||||
self.var_type = var_type
|
||||
self.loss_type = loss_type
|
||||
self.rescale_timesteps = rescale_timesteps
|
||||
|
||||
# alphas
|
||||
alphas = 1 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
self.alphas_cumprod_prev = torch.cat(
|
||||
[alphas.new_ones([1]), self.alphas_cumprod[:-1]])
|
||||
self.alphas_cumprod_next = torch.cat(
|
||||
[self.alphas_cumprod[1:],
|
||||
alphas.new_zeros([1])])
|
||||
|
||||
# q(x_t | x_{t-1})
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
|
||||
- self.alphas_cumprod)
|
||||
self.log_one_minus_alphas_cumprod = torch.log(1.0
|
||||
- self.alphas_cumprod)
|
||||
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
|
||||
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
|
||||
- 1)
|
||||
|
||||
# q(x_{t-1} | x_t, x_0)
|
||||
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
|
||||
1.0 - self.alphas_cumprod)
|
||||
self.posterior_log_variance_clipped = torch.log(
|
||||
self.posterior_variance.clamp(1e-20))
|
||||
self.posterior_mean_coef1 = betas * torch.sqrt(
|
||||
self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (
|
||||
1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
|
||||
1.0 - self.alphas_cumprod)
|
||||
|
||||
def q_sample(self, x0, t, noise=None):
|
||||
r"""Sample from q(x_t | x_0).
|
||||
"""
|
||||
noise = torch.randn_like(x0) if noise is None else noise
|
||||
return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i(
|
||||
self.sqrt_one_minus_alphas_cumprod, t, x0) * noise
|
||||
|
||||
def q_mean_variance(self, x0, t):
|
||||
r"""Distribution of q(x_t | x_0).
|
||||
"""
|
||||
mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
|
||||
var = _i(1.0 - self.alphas_cumprod, t, x0)
|
||||
log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
|
||||
return mu, var, log_var
|
||||
|
||||
def q_posterior_mean_variance(self, x0, xt, t):
|
||||
r"""Distribution of q(x_{t-1} | x_t, x_0).
|
||||
"""
|
||||
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
|
||||
self.posterior_mean_coef2, t, xt) * xt
|
||||
var = _i(self.posterior_variance, t, xt)
|
||||
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
return mu, var, log_var
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None):
|
||||
r"""Sample from p(x_{t-1} | x_t).
|
||||
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||
"""
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||
clamp, percentile,
|
||||
guide_scale)
|
||||
|
||||
# random sample (with optional conditional function)
|
||||
noise = torch.randn_like(xt)
|
||||
# no noise when t == 0
|
||||
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
||||
if condition_fn is not None:
|
||||
grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
|
||||
mu = mu.float() + var * grad.float()
|
||||
xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
|
||||
return xt_1, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None):
|
||||
r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
|
||||
"""
|
||||
# prepare input
|
||||
b, c, h, w = noise.size()
|
||||
xt = noise
|
||||
|
||||
# diffusion process
|
||||
for step in torch.arange(self.num_timesteps).flip(0):
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn, guide_scale)
|
||||
return xt
|
||||
|
||||
def p_mean_variance(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None):
|
||||
r"""Distribution of p(x_{t-1} | x_t).
|
||||
"""
|
||||
# predict distribution
|
||||
if guide_scale is None:
|
||||
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
||||
else:
|
||||
# classifier-free guidance
|
||||
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
|
||||
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
|
||||
assert self.mean_type == 'eps'
|
||||
y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
|
||||
u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
|
||||
out = torch.cat(
|
||||
[
|
||||
u_out[:, :3] + guide_scale * # noqa W504
|
||||
(y_out[:, :3] - u_out[:, :3]),
|
||||
y_out[:, 3:]
|
||||
],
|
||||
dim=1)
|
||||
|
||||
# compute variance
|
||||
if self.var_type == 'learned':
|
||||
out, log_var = out.chunk(2, dim=1)
|
||||
var = torch.exp(log_var)
|
||||
elif self.var_type == 'learned_range':
|
||||
out, fraction = out.chunk(2, dim=1)
|
||||
min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
max_log_var = _i(torch.log(self.betas), t, xt)
|
||||
fraction = (fraction + 1) / 2.0
|
||||
log_var = fraction * max_log_var + (1 - fraction) * min_log_var
|
||||
var = torch.exp(log_var)
|
||||
elif self.var_type == 'fixed_large':
|
||||
var = _i(
|
||||
torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
|
||||
xt)
|
||||
log_var = torch.log(var)
|
||||
elif self.var_type == 'fixed_small':
|
||||
var = _i(self.posterior_variance, t, xt)
|
||||
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
|
||||
# compute mean and x0
|
||||
if self.mean_type == 'x_{t-1}':
|
||||
mu = out # x_{t-1}
|
||||
x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i(
|
||||
self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
|
||||
xt) * xt
|
||||
elif self.mean_type == 'x0':
|
||||
x0 = out
|
||||
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||
elif self.mean_type == 'eps':
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt) * out
|
||||
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||
|
||||
# restrict the range of x0
|
||||
if percentile is not None:
|
||||
assert percentile > 0 and percentile <= 1 # e.g., 0.995
|
||||
s = torch.quantile(
|
||||
x0.flatten(1).abs(), percentile,
|
||||
dim=1).clamp_(1.0).view(-1, 1, 1, 1)
|
||||
x0 = torch.min(s, torch.max(-s, x0)) / s
|
||||
elif clamp is not None:
|
||||
x0 = x0.clamp(-clamp, clamp)
|
||||
return mu, var, log_var, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20,
|
||||
eta=0.0):
|
||||
stride = self.num_timesteps // ddim_timesteps
|
||||
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale)
|
||||
if condition_fn is not None:
|
||||
# x0 -> eps
|
||||
alpha = _i(self.alphas_cumprod, t, xt)
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
eps = eps - (1 - alpha).sqrt() * condition_fn(
|
||||
xt, self._scale_timesteps(t), **model_kwargs)
|
||||
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
||||
|
||||
# derive variables
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
alphas = _i(self.alphas_cumprod, t, xt)
|
||||
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
||||
sigmas = eta * torch.sqrt((1 - alphas_prev) / # noqa W504
|
||||
(1 - alphas) * # noqa W504
|
||||
(1 - alphas / alphas_prev))
|
||||
|
||||
# random sample
|
||||
noise = torch.randn_like(xt)
|
||||
direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
|
||||
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
||||
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
|
||||
return xt_1, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20,
|
||||
eta=0.0):
|
||||
# prepare input
|
||||
b, c, h, w = noise.size()
|
||||
xt = noise
|
||||
|
||||
# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
|
||||
steps = (1 + torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // ddim_timesteps)).clamp(
|
||||
0, self.num_timesteps - 1).flip(0)
|
||||
for step in steps:
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn, guide_scale,
|
||||
ddim_timesteps, eta)
|
||||
return xt
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_reverse_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20):
|
||||
r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
|
||||
"""
|
||||
stride = self.num_timesteps // ddim_timesteps
|
||||
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale)
|
||||
|
||||
# derive variables
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
alphas_next = _i(
|
||||
torch.cat(
|
||||
[self.alphas_cumprod,
|
||||
self.alphas_cumprod.new_zeros([1])]),
|
||||
(t + stride).clamp(0, self.num_timesteps), xt)
|
||||
|
||||
# reverse sample
|
||||
mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
|
||||
return mu, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_reverse_sample_loop(self,
|
||||
x0,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20):
|
||||
# prepare input
|
||||
b, c, h, w = x0.size()
|
||||
xt = x0
|
||||
|
||||
# reconstruction steps
|
||||
steps = torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // ddim_timesteps)
|
||||
for step in steps:
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale,
|
||||
ddim_timesteps)
|
||||
return xt
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
plms_timesteps=20):
|
||||
r"""Sample from p(x_{t-1} | x_t) using PLMS.
|
||||
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||
"""
|
||||
stride = self.num_timesteps // plms_timesteps
|
||||
|
||||
# function for compute eps
|
||||
def compute_eps(xt, t):
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||
clamp, percentile, guide_scale)
|
||||
|
||||
# condition
|
||||
if condition_fn is not None:
|
||||
# x0 -> eps
|
||||
alpha = _i(self.alphas_cumprod, t, xt)
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
|
||||
- x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
eps = eps - (1 - alpha).sqrt() * condition_fn(
|
||||
xt, self._scale_timesteps(t), **model_kwargs)
|
||||
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
||||
|
||||
# derive eps
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
return eps
|
||||
|
||||
# function for compute x_0 and x_{t-1}
|
||||
def compute_x0(eps, t):
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
||||
|
||||
# deterministic sample
|
||||
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
||||
direction = torch.sqrt(1 - alphas_prev) * eps
|
||||
# mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
||||
xt_1 = torch.sqrt(alphas_prev) * x0 + direction
|
||||
return xt_1, x0
|
||||
|
||||
# PLMS sample
|
||||
eps = compute_eps(xt, t)
|
||||
if len(eps_cache) == 0:
|
||||
# 2nd order pseudo improved Euler
|
||||
xt_1, x0 = compute_x0(eps, t)
|
||||
eps_next = compute_eps(xt_1, (t - stride).clamp(0))
|
||||
eps_prime = (eps + eps_next) / 2.0
|
||||
elif len(eps_cache) == 1:
|
||||
# 2nd order pseudo linear multistep (Adams-Bashforth)
|
||||
eps_prime = (3 * eps - eps_cache[-1]) / 2.0
|
||||
elif len(eps_cache) == 2:
|
||||
# 3nd order pseudo linear multistep (Adams-Bashforth)
|
||||
eps_prime = (23 * eps - 16 * eps_cache[-1]
|
||||
+ 5 * eps_cache[-2]) / 12.0
|
||||
elif len(eps_cache) >= 3:
|
||||
# 4nd order pseudo linear multistep (Adams-Bashforth)
|
||||
eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
|
||||
- 9 * eps_cache[-3]) / 24.0
|
||||
xt_1, x0 = compute_x0(eps_prime, t)
|
||||
return xt_1, x0, eps
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
plms_timesteps=20):
|
||||
# prepare input
|
||||
b, c, h, w = noise.size()
|
||||
xt = noise
|
||||
|
||||
# diffusion process
|
||||
steps = (1 + torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // plms_timesteps)).clamp(
|
||||
0, self.num_timesteps - 1).flip(0)
|
||||
eps_cache = []
|
||||
for step in steps:
|
||||
# PLMS sampling step
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn,
|
||||
guide_scale, plms_timesteps,
|
||||
eps_cache)
|
||||
|
||||
# update eps cache
|
||||
eps_cache.append(eps)
|
||||
if len(eps_cache) >= 4:
|
||||
eps_cache.pop(0)
|
||||
return xt
|
||||
|
||||
def loss(self, x0, t, model, model_kwargs={}, noise=None):
|
||||
noise = torch.randn_like(x0) if noise is None else noise
|
||||
xt = self.q_sample(x0, t, noise=noise)
|
||||
|
||||
# compute loss
|
||||
if self.loss_type in ['kl', 'rescaled_kl']:
|
||||
loss, _ = self.variational_lower_bound(x0, xt, t, model,
|
||||
model_kwargs)
|
||||
if self.loss_type == 'rescaled_kl':
|
||||
loss = loss * self.num_timesteps
|
||||
elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
|
||||
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
||||
|
||||
# VLB for variation
|
||||
loss_vlb = 0.0
|
||||
if self.var_type in ['learned', 'learned_range']:
|
||||
out, var = out.chunk(2, dim=1)
|
||||
frozen = torch.cat([
|
||||
out.detach(), var
|
||||
], dim=1) # learn var without affecting the prediction of mean
|
||||
loss_vlb, _ = self.variational_lower_bound(
|
||||
x0, xt, t, model=lambda *args, **kwargs: frozen)
|
||||
if self.loss_type.startswith('rescaled_'):
|
||||
loss_vlb = loss_vlb * self.num_timesteps / 1000.0
|
||||
|
||||
# MSE/L1 for x0/eps
|
||||
target = {
|
||||
'eps': noise,
|
||||
'x0': x0,
|
||||
'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
|
||||
}[self.mean_type]
|
||||
loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2
|
||||
).abs().flatten(1).mean(dim=1)
|
||||
|
||||
# total loss
|
||||
loss = loss + loss_vlb
|
||||
return loss
|
||||
|
||||
def variational_lower_bound(self,
|
||||
x0,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None):
|
||||
# compute groundtruth and predicted distributions
|
||||
mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
|
||||
mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||
clamp, percentile)
|
||||
|
||||
# compute KL loss
|
||||
kl = kl_divergence(mu1, log_var1, mu2, log_var2)
|
||||
kl = kl.flatten(1).mean(dim=1) / math.log(2.0)
|
||||
|
||||
# compute discretized NLL loss (for p(x0 | x1) only)
|
||||
nll = -discretized_gaussian_log_likelihood(
|
||||
x0, mean=mu2, log_scale=0.5 * log_var2)
|
||||
nll = nll.flatten(1).mean(dim=1) / math.log(2.0)
|
||||
|
||||
# NLL for p(x0 | x1) and KL otherwise
|
||||
vlb = torch.where(t == 0, nll, kl)
|
||||
return vlb, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def variational_lower_bound_loop(self,
|
||||
x0,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None):
|
||||
r"""Compute the entire variational lower bound, measured in bits-per-dim.
|
||||
"""
|
||||
# prepare input and output
|
||||
b, c, h, w = x0.size()
|
||||
metrics = {'vlb': [], 'mse': [], 'x0_mse': []}
|
||||
|
||||
# loop
|
||||
for step in torch.arange(self.num_timesteps).flip(0):
|
||||
# compute VLB
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
|
||||
noise = torch.randn_like(x0)
|
||||
xt = self.q_sample(x0, t, noise)
|
||||
vlb, pred_x0 = self.variational_lower_bound(
|
||||
x0, xt, t, model, model_kwargs, clamp, percentile)
|
||||
|
||||
# predict eps from x0
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
|
||||
# collect metrics
|
||||
metrics['vlb'].append(vlb)
|
||||
metrics['x0_mse'].append(
|
||||
(pred_x0 - x0).square().flatten(1).mean(dim=1))
|
||||
metrics['mse'].append(
|
||||
(eps - noise).square().flatten(1).mean(dim=1))
|
||||
metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}
|
||||
|
||||
# compute the prior KL term for VLB, measured in bits-per-dim
|
||||
mu, _, log_var = self.q_mean_variance(x0, t)
|
||||
kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
|
||||
torch.zeros_like(log_var))
|
||||
kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)
|
||||
|
||||
# update metrics
|
||||
metrics['prior_bits_per_dim'] = kl_prior
|
||||
metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
|
||||
return metrics
|
||||
|
||||
def _scale_timesteps(self, t):
|
||||
if self.rescale_timesteps:
|
||||
return t.float() * 1000.0 / self.num_timesteps
|
||||
return t
|
||||
@@ -0,0 +1,35 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood']
|
||||
|
||||
|
||||
def kl_divergence(mu1, logvar1, mu2, logvar2):
|
||||
return 0.5 * (
|
||||
-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + # noqa W504
|
||||
((mu1 - mu2)**2) * torch.exp(-logvar2))
|
||||
|
||||
|
||||
def standard_normal_cdf(x):
|
||||
r"""A fast approximation of the cumulative distribution function of the standard normal.
|
||||
"""
|
||||
return 0.5 * (1.0 + torch.tanh(
|
||||
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
def discretized_gaussian_log_likelihood(x0, mean, log_scale):
|
||||
assert x0.shape == mean.shape == log_scale.shape
|
||||
cx = x0 - mean
|
||||
inv_stdv = torch.exp(-log_scale)
|
||||
cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
|
||||
cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
|
||||
log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
|
||||
log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
log_probs = torch.where(
|
||||
x0 < -0.999, log_cdf_plus,
|
||||
torch.where(x0 > 0.999, log_one_minus_cdf_min,
|
||||
torch.log(cdf_delta.clamp(min=1e-12))))
|
||||
assert log_probs.shape == x0.shape
|
||||
return log_probs
|
||||
126
modelscope/models/cv/image_to_image_translation/ops/metrics.py
Normal file
126
modelscope/models/cv/image_to_image_translation/ops/metrics.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import numpy as np
|
||||
import scipy.linalg as linalg
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
'get_fid_net', 'get_is_net', 'compute_fid', 'compute_prdc', 'compute_is'
|
||||
]
|
||||
|
||||
|
||||
def get_fid_net(resize_input=True, normalize_input=True):
|
||||
r"""InceptionV3 network for the evaluation of Fréchet Inception Distance (FID).
|
||||
|
||||
Args:
|
||||
resize_input: whether or not to resize the input to (299, 299).
|
||||
normalize_input: whether or not to normalize the input from range (0, 1) to range(-1, 1).
|
||||
"""
|
||||
from artist.models import InceptionV3
|
||||
return InceptionV3(
|
||||
output_blocks=(3, ),
|
||||
resize_input=resize_input,
|
||||
normalize_input=normalize_input,
|
||||
requires_grad=False,
|
||||
use_fid_inception=True).eval().requires_grad_(False)
|
||||
|
||||
|
||||
def get_is_net(resize_input=True, normalize_input=True):
|
||||
r"""InceptionV3 network for the evaluation of Inception Score (IS).
|
||||
|
||||
Args:
|
||||
resize_input: whether or not to resize the input to (299, 299).
|
||||
normalize_input: whether or not to normalize the input from range (0, 1) to range(-1, 1).
|
||||
"""
|
||||
from artist.models import InceptionV3
|
||||
return InceptionV3(
|
||||
output_blocks=(4, ),
|
||||
resize_input=resize_input,
|
||||
normalize_input=normalize_input,
|
||||
requires_grad=False,
|
||||
use_fid_inception=False).eval().requires_grad_(False)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_fid(real_feats, fake_feats, eps=1e-6):
|
||||
r"""Compute Fréchet Inception Distance (FID).
|
||||
|
||||
Args:
|
||||
real_feats: [N, C].
|
||||
fake_feats: [N, C].
|
||||
"""
|
||||
# check inputs
|
||||
if isinstance(real_feats, torch.Tensor):
|
||||
real_feats = real_feats.cpu().numpy().astype(np.float_)
|
||||
if isinstance(fake_feats, torch.Tensor):
|
||||
fake_feats = fake_feats.cpu().numpy().astype(np.float_)
|
||||
|
||||
# real statistics
|
||||
mu1 = np.mean(real_feats, axis=0)
|
||||
sigma1 = np.cov(real_feats, rowvar=False)
|
||||
|
||||
# fake statistics
|
||||
mu2 = np.mean(fake_feats, axis=0)
|
||||
sigma2 = np.cov(fake_feats, rowvar=False)
|
||||
|
||||
# compute covmean
|
||||
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
||||
if not np.isfinite(covmean).all():
|
||||
print(
|
||||
f'FID calculation produces singular product; adding {eps} to diagonal of cov',
|
||||
flush=True)
|
||||
offset = np.eye(sigma1.shape[0]) * eps
|
||||
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
||||
|
||||
# numerical error might give slight imaginary component
|
||||
if np.iscomplexobj(covmean):
|
||||
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
||||
m = np.max(np.abs(covmean.imag))
|
||||
raise ValueError('Imaginary component {}'.format(m))
|
||||
covmean = covmean.real
|
||||
|
||||
# compute Fréchet distance
|
||||
diff = mu1 - mu2
|
||||
fid = diff.dot(diff) + np.trace(sigma1) + np.trace(
|
||||
sigma2) - 2 * np.trace(covmean)
|
||||
return fid.item()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_prdc(real_feats, fake_feats, knn=5):
|
||||
r"""Compute precision, recall, density, and coverage given two manifolds.
|
||||
|
||||
Args:
|
||||
real_feats: [N, C].
|
||||
fake_feats: [N, C].
|
||||
knn: the number of nearest neighbors to consider.
|
||||
"""
|
||||
# distances
|
||||
real_kth = -(-torch.cdist(real_feats, real_feats)).topk(
|
||||
k=knn, dim=1)[0][:, -1]
|
||||
fake_kth = -(-torch.cdist(fake_feats, fake_feats)).topk(
|
||||
k=knn, dim=1)[0][:, -1]
|
||||
dists = torch.cdist(real_feats, fake_feats)
|
||||
|
||||
# metrics
|
||||
precision = (dists < real_kth.unsqueeze(1)).any(
|
||||
dim=0).float().mean().item()
|
||||
recall = (dists < fake_kth.unsqueeze(0)).any(dim=1).float().mean().item()
|
||||
density = (dists < real_kth.unsqueeze(1)).float().sum(
|
||||
dim=0).mean().item() / knn
|
||||
coverage = (dists.min(dim=1)[0] < real_kth).float().mean().item()
|
||||
return precision, recall, density, coverage
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_is(logits, num_splits=10):
|
||||
preds = logits.softmax(dim=1).cpu().numpy()
|
||||
split_scores = []
|
||||
for k in range(num_splits):
|
||||
part = preds[k * (len(logits) // num_splits):(k + 1)
|
||||
* (len(logits) // num_splits), :]
|
||||
py = np.mean(part, axis=0)
|
||||
scores = []
|
||||
for i in range(part.shape[0]):
|
||||
pyx = part[i, :]
|
||||
scores.append(entropy(pyx, py))
|
||||
split_scores.append(np.exp(np.mean(scores)))
|
||||
return np.mean(split_scores), np.std(split_scores)
|
||||
@@ -0,0 +1,220 @@
|
||||
import colorsys
|
||||
import random
|
||||
|
||||
__all__ = ['RandomColor', 'rand_color']
|
||||
|
||||
COLORMAP = {
|
||||
'blue': {
|
||||
'hue_range': [179, 257],
|
||||
'lower_bounds': [[20, 100], [30, 86], [40, 80], [50, 74], [60, 60],
|
||||
[70, 52], [80, 44], [90, 39], [100, 35]]
|
||||
},
|
||||
'green': {
|
||||
'hue_range': [63, 178],
|
||||
'lower_bounds': [[30, 100], [40, 90], [50, 85], [60, 81], [70, 74],
|
||||
[80, 64], [90, 50], [100, 40]]
|
||||
},
|
||||
'monochrome': {
|
||||
'hue_range': [0, 0],
|
||||
'lower_bounds': [[0, 0], [100, 0]]
|
||||
},
|
||||
'orange': {
|
||||
'hue_range': [19, 46],
|
||||
'lower_bounds': [[20, 100], [30, 93], [40, 88], [50, 86], [60, 85],
|
||||
[70, 70], [100, 70]]
|
||||
},
|
||||
'pink': {
|
||||
'hue_range': [283, 334],
|
||||
'lower_bounds': [[20, 100], [30, 90], [40, 86], [60, 84], [80, 80],
|
||||
[90, 75], [100, 73]]
|
||||
},
|
||||
'purple': {
|
||||
'hue_range': [258, 282],
|
||||
'lower_bounds': [[20, 100], [30, 87], [40, 79], [50, 70], [60, 65],
|
||||
[70, 59], [80, 52], [90, 45], [100, 42]]
|
||||
},
|
||||
'red': {
|
||||
'hue_range': [-26, 18],
|
||||
'lower_bounds': [[20, 100], [30, 92], [40, 89], [50, 85], [60, 78],
|
||||
[70, 70], [80, 60], [90, 55], [100, 50]]
|
||||
},
|
||||
'yellow': {
|
||||
'hue_range': [47, 62],
|
||||
'lower_bounds': [[25, 100], [40, 94], [50, 89], [60, 86], [70, 84],
|
||||
[80, 82], [90, 80], [100, 75]]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class RandomColor(object):
|
||||
|
||||
def __init__(self, seed=None):
|
||||
self.colormap = COLORMAP
|
||||
self.random = random.Random(seed)
|
||||
|
||||
for color_name, color_attrs in self.colormap.items():
|
||||
lower_bounds = color_attrs['lower_bounds']
|
||||
s_min = lower_bounds[0][0]
|
||||
s_max = lower_bounds[len(lower_bounds) - 1][0]
|
||||
|
||||
b_min = lower_bounds[len(lower_bounds) - 1][1]
|
||||
b_max = lower_bounds[0][1]
|
||||
|
||||
self.colormap[color_name]['saturation_range'] = [s_min, s_max]
|
||||
self.colormap[color_name]['brightness_range'] = [b_min, b_max]
|
||||
|
||||
def generate(self, hue=None, luminosity=None, count=1, format_='hex'):
|
||||
colors = []
|
||||
for _ in range(count):
|
||||
# First we pick a hue (H)
|
||||
H = self.pick_hue(hue)
|
||||
|
||||
# Then use H to determine saturation (S)
|
||||
S = self.pick_saturation(H, hue, luminosity)
|
||||
|
||||
# Then use S and H to determine brightness (B).
|
||||
B = self.pick_brightness(H, S, luminosity)
|
||||
|
||||
# Then we return the HSB color in the desired format
|
||||
colors.append(self.set_format([H, S, B], format_))
|
||||
|
||||
return colors
|
||||
|
||||
def pick_hue(self, hue):
|
||||
hue_range = self.get_hue_range(hue)
|
||||
hue = self.random_within(hue_range)
|
||||
|
||||
# Instead of storing red as two seperate ranges,
|
||||
# we group them, using negative numbers
|
||||
if (hue < 0):
|
||||
hue += 360
|
||||
|
||||
return hue
|
||||
|
||||
def pick_saturation(self, hue, hue_name, luminosity):
|
||||
|
||||
if luminosity == 'random':
|
||||
return self.random_within([0, 100])
|
||||
|
||||
if hue_name == 'monochrome':
|
||||
return 0
|
||||
|
||||
saturation_range = self.get_saturation_range(hue)
|
||||
|
||||
s_min = saturation_range[0]
|
||||
s_max = saturation_range[1]
|
||||
|
||||
if luminosity == 'bright':
|
||||
s_min = 55
|
||||
elif luminosity == 'dark':
|
||||
s_min = s_max - 10
|
||||
elif luminosity == 'light':
|
||||
s_max = 55
|
||||
|
||||
return self.random_within([s_min, s_max])
|
||||
|
||||
def pick_brightness(self, H, S, luminosity):
|
||||
b_min = self.get_minimum_brightness(H, S)
|
||||
b_max = 100
|
||||
|
||||
if luminosity == 'dark':
|
||||
b_max = b_min + 20
|
||||
elif luminosity == 'light':
|
||||
b_min = (b_max + b_min) / 2
|
||||
elif luminosity == 'random':
|
||||
b_min = 0
|
||||
b_max = 100
|
||||
|
||||
return self.random_within([b_min, b_max])
|
||||
|
||||
def set_format(self, hsv, format_):
|
||||
if 'hsv' in format_:
|
||||
color = hsv
|
||||
elif 'rgb' in format_:
|
||||
color = self.hsv_to_rgb(hsv)
|
||||
elif 'hex' in format_:
|
||||
r, g, b = self.hsv_to_rgb(hsv)
|
||||
return '#%02x%02x%02x' % (r, g, b)
|
||||
else:
|
||||
return 'unrecognized format'
|
||||
|
||||
if 'Array' in format_ or format_ == 'hex':
|
||||
return color
|
||||
else:
|
||||
prefix = format_[:3]
|
||||
color_values = [str(x) for x in color]
|
||||
return '%s(%s)' % (prefix, ', '.join(color_values))
|
||||
|
||||
def get_minimum_brightness(self, H, S):
|
||||
lower_bounds = self.get_color_info(H)['lower_bounds']
|
||||
|
||||
for i in range(len(lower_bounds) - 1):
|
||||
s1 = lower_bounds[i][0]
|
||||
v1 = lower_bounds[i][1]
|
||||
|
||||
s2 = lower_bounds[i + 1][0]
|
||||
v2 = lower_bounds[i + 1][1]
|
||||
|
||||
if s1 <= S <= s2:
|
||||
m = (v2 - v1) / (s2 - s1)
|
||||
b = v1 - m * s1
|
||||
|
||||
return m * S + b
|
||||
|
||||
return 0
|
||||
|
||||
def get_hue_range(self, color_input):
|
||||
if color_input and color_input.isdigit():
|
||||
number = int(color_input)
|
||||
|
||||
if 0 < number < 360:
|
||||
return [number, number]
|
||||
|
||||
elif color_input and color_input in self.colormap:
|
||||
color = self.colormap[color_input]
|
||||
if 'hue_range' in color:
|
||||
return color['hue_range']
|
||||
|
||||
else:
|
||||
return [0, 360]
|
||||
|
||||
def get_saturation_range(self, hue):
|
||||
return self.get_color_info(hue)['saturation_range']
|
||||
|
||||
def get_color_info(self, hue):
|
||||
# Maps red colors to make picking hue easier
|
||||
if 334 <= hue <= 360:
|
||||
hue -= 360
|
||||
|
||||
for color_name, color in self.colormap.items():
|
||||
if color['hue_range'] and color['hue_range'][0] <= hue <= color[
|
||||
'hue_range'][1]:
|
||||
return self.colormap[color_name]
|
||||
|
||||
# this should probably raise an exception
|
||||
return 'Color not found'
|
||||
|
||||
def random_within(self, r):
|
||||
return self.random.randint(int(r[0]), int(r[1]))
|
||||
|
||||
@classmethod
|
||||
def hsv_to_rgb(cls, hsv):
|
||||
h, s, v = hsv
|
||||
h = 1 if h == 0 else h
|
||||
h = 359 if h == 360 else h
|
||||
|
||||
h = float(h) / 360
|
||||
s = float(s) / 100
|
||||
v = float(v) / 100
|
||||
|
||||
rgb = colorsys.hsv_to_rgb(h, s, v)
|
||||
return [int(c * 255) for c in rgb]
|
||||
|
||||
|
||||
def rand_color():
|
||||
generator = RandomColor()
|
||||
hue = random.choice(list(COLORMAP.keys()))
|
||||
color = generator.generate(hue=hue, count=1, format_='rgb')[0]
|
||||
color = color[color.find('(') + 1:color.find(')')]
|
||||
color = tuple([int(u) for u in color.split(',')])
|
||||
return color
|
||||
@@ -0,0 +1,79 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['make_irregular_mask', 'make_rectangle_mask', 'make_uncrop']
|
||||
|
||||
|
||||
def make_irregular_mask(w,
|
||||
h,
|
||||
max_angle=4,
|
||||
max_length=200,
|
||||
max_width=100,
|
||||
min_strokes=1,
|
||||
max_strokes=5,
|
||||
mode='line'):
|
||||
# initialize mask
|
||||
assert mode in ['line', 'circle', 'square']
|
||||
mask = np.zeros((h, w), np.float32)
|
||||
|
||||
# draw strokes
|
||||
num_strokes = np.random.randint(min_strokes, max_strokes + 1)
|
||||
for i in range(num_strokes):
|
||||
x1 = np.random.randint(w)
|
||||
y1 = np.random.randint(h)
|
||||
for j in range(1 + np.random.randint(5)):
|
||||
angle = 0.01 + np.random.randint(max_angle)
|
||||
if i % 2 == 0:
|
||||
angle = 2 * 3.1415926 - angle
|
||||
length = 10 + np.random.randint(max_length)
|
||||
radius = 5 + np.random.randint(max_width)
|
||||
x2 = np.clip((x1 + length * np.sin(angle)).astype(np.int32), 0, w)
|
||||
y2 = np.clip((y1 + length * np.cos(angle)).astype(np.int32), 0, h)
|
||||
if mode == 'line':
|
||||
cv2.line(mask, (x1, y1), (x2, y2), 1.0, radius)
|
||||
elif mode == 'circle':
|
||||
cv2.circle(
|
||||
mask, (x1, y1), radius=radius, color=1.0, thickness=-1)
|
||||
elif mode == 'square':
|
||||
radius = radius // 2
|
||||
mask[y1 - radius:y1 + radius, x1 - radius:x1 + radius] = 1
|
||||
x1, y1 = x2, y2
|
||||
return mask
|
||||
|
||||
|
||||
def make_rectangle_mask(w,
|
||||
h,
|
||||
margin=10,
|
||||
min_size=30,
|
||||
max_size=150,
|
||||
min_strokes=1,
|
||||
max_strokes=4):
|
||||
# initialize mask
|
||||
mask = np.zeros((h, w), np.float32)
|
||||
|
||||
# draw rectangles
|
||||
num_strokes = np.random.randint(min_strokes, max_strokes + 1)
|
||||
for i in range(num_strokes):
|
||||
box_w = np.random.randint(min_size, max_size)
|
||||
box_h = np.random.randint(min_size, max_size)
|
||||
x1 = np.random.randint(margin, w - margin - box_w + 1)
|
||||
y1 = np.random.randint(margin, h - margin - box_h + 1)
|
||||
mask[y1:y1 + box_h, x1:x1 + box_w] = 1
|
||||
return mask
|
||||
|
||||
|
||||
def make_uncrop(w, h):
|
||||
# initialize mask
|
||||
mask = np.zeros((h, w), np.float32)
|
||||
|
||||
# randomly halve the image
|
||||
side = np.random.choice([0, 1, 2, 3])
|
||||
if side == 0:
|
||||
mask[:h // 2, :] = 1
|
||||
elif side == 1:
|
||||
mask[h // 2:, :] = 1
|
||||
elif side == 2:
|
||||
mask[:, :w // 2] = 1
|
||||
elif side == 2:
|
||||
mask[:, w // 2:] = 1
|
||||
return mask
|
||||
152
modelscope/models/cv/image_to_image_translation/ops/svd.py
Normal file
152
modelscope/models/cv/image_to_image_translation/ops/svd.py
Normal file
@@ -0,0 +1,152 @@
|
||||
r"""SVD of linear degradation matrices described in the paper
|
||||
``Denoising Diffusion Restoration Models.''
|
||||
@article{kawar2022denoising,
|
||||
title={Denoising Diffusion Restoration Models},
|
||||
author={Bahjat Kawar and Michael Elad and Stefano Ermon and Jiaming Song},
|
||||
year={2022},
|
||||
journal={arXiv preprint arXiv:2201.11793},
|
||||
}
|
||||
"""
|
||||
import torch
|
||||
|
||||
__all__ = ['SVD', 'IdentitySVD', 'DenoiseSVD', 'ColorizationSVD']
|
||||
|
||||
|
||||
class SVD(object):
|
||||
r"""SVD decomposition of a matrix, i.e., H = UDV^T.
|
||||
NOTE: assume that all inputs (i.e., h, x) are of shape [B, CHW].
|
||||
"""
|
||||
|
||||
def __init__(self, h):
|
||||
self.u, self.d, self.v = torch.svd(h, some=False)
|
||||
self.ut = self.u.t()
|
||||
self.vt = self.v.t()
|
||||
self.d[self.d < 1e-3] = 0
|
||||
|
||||
def U(self, x):
|
||||
return torch.matmul(self.u, x)
|
||||
|
||||
def Ut(self, x):
|
||||
return torch.matmul(self.ut, x)
|
||||
|
||||
def V(self, x):
|
||||
return torch.matmul(self.v, x)
|
||||
|
||||
def Vt(self, x):
|
||||
return torch.matmul(self.vt, x)
|
||||
|
||||
@property
|
||||
def D(self):
|
||||
return self.d
|
||||
|
||||
def H(self, x):
|
||||
return self.U(self.D * self.Vt(x)[:, :self.D.size(0)])
|
||||
|
||||
def Ht(self, x):
|
||||
return self.V(self._pad(self.D * self.Ut(x)[:, :self.D.size(0)]))
|
||||
|
||||
def Hinv(self, x):
|
||||
r"""Multiplies x by the pseudo inverse of H.
|
||||
"""
|
||||
x = self.Ut(x)
|
||||
x[:, :self.D.size(0)] = x[:, :self.D.size(0)] / self.D
|
||||
return self.V(self._pad(x))
|
||||
|
||||
def _pad(self, x):
|
||||
o = x.new_zeros(x.size(0), self.v.size(0))
|
||||
o[:, :self.u.size(0)] = x.view(x.size(0), -1)
|
||||
return o
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
r"""Update the data type and device of UDV matrices.
|
||||
"""
|
||||
for k, v in self.__dict__.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
setattr(self, k, v.to(*args, **kwargs))
|
||||
return self
|
||||
|
||||
|
||||
class IdentitySVD(SVD):
|
||||
|
||||
def __init__(self, c, h, w):
|
||||
self.d = torch.ones(c * h * w)
|
||||
|
||||
def U(self, x):
|
||||
return x.clone()
|
||||
|
||||
def Ut(self, x):
|
||||
return x.clone()
|
||||
|
||||
def V(self, x):
|
||||
return x.clone()
|
||||
|
||||
def Vt(self, x):
|
||||
return x.clone()
|
||||
|
||||
def H(self, x):
|
||||
return x.clone()
|
||||
|
||||
def Ht(self, x):
|
||||
return x.clone()
|
||||
|
||||
def Hinv(self, x):
|
||||
return x.clone()
|
||||
|
||||
def _pad(self, x):
|
||||
return x.clone()
|
||||
|
||||
|
||||
class DenoiseSVD(SVD):
|
||||
|
||||
def __init__(self, c, h, w):
|
||||
self.num_entries = c * h * w
|
||||
self.d = torch.ones(self.num_entries)
|
||||
|
||||
def U(self, x):
|
||||
return x.clone()
|
||||
|
||||
def Ut(self, x):
|
||||
return x.clone()
|
||||
|
||||
def V(self, x):
|
||||
return x.clone()
|
||||
|
||||
def Vt(self, x):
|
||||
return x.clone()
|
||||
|
||||
def _pad(self, x):
|
||||
return x.clone()
|
||||
|
||||
|
||||
class ColorizationSVD(SVD):
|
||||
|
||||
def __init__(self, c, h, w):
|
||||
self.color_dim = c
|
||||
self.num_pixels = h * w
|
||||
self.u, self.d, self.v = torch.svd(torch.ones(1, c) / c, some=False)
|
||||
self.vt = self.v.t()
|
||||
|
||||
def U(self, x):
|
||||
return self.u[0, 0] * x
|
||||
|
||||
def Ut(self, x):
|
||||
return self.u[0, 0] * x
|
||||
|
||||
def V(self, x):
|
||||
return torch.einsum('ij,bjn->bin', self.v,
|
||||
x.view(x.size(0), self.color_dim,
|
||||
self.num_pixels)).flatten(1)
|
||||
|
||||
def Vt(self, x):
|
||||
return torch.einsum('ij,bjn->bin', self.vt,
|
||||
x.view(x.size(0), self.color_dim,
|
||||
self.num_pixels)).flatten(1)
|
||||
|
||||
@property
|
||||
def D(self):
|
||||
return self.d.repeat(self.num_pixels)
|
||||
|
||||
def _pad(self, x):
|
||||
o = x.new_zeros(x.size(0), self.color_dim * self.num_pixels)
|
||||
o[:, :self.num_pixels] = x
|
||||
return o
|
||||
224
modelscope/models/cv/image_to_image_translation/ops/utils.py
Normal file
224
modelscope/models/cv/image_to_image_translation/ops/utils.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import base64
|
||||
import binascii
|
||||
import hashlib
|
||||
import math
|
||||
import os
|
||||
import os.path as osp
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
from multiprocessing.pool import ThreadPool as Pool
|
||||
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
|
||||
from .random_color import rand_color
|
||||
|
||||
__all__ = [
|
||||
'ceil_divide', 'to_device', 'rand_name', 'ema', 'parallel', 'unzip',
|
||||
'load_state_dict', 'inverse_indices', 'detect_duplicates', 'md5', 'rope',
|
||||
'format_state', 'breakup_grid', 'viz_anno_geometry', 'image_to_base64'
|
||||
]
|
||||
|
||||
TFS_CLIENT = None
|
||||
|
||||
|
||||
def ceil_divide(a, b):
|
||||
return int(math.ceil(a / b))
|
||||
|
||||
|
||||
def to_device(batch, device, non_blocking=False):
|
||||
if isinstance(batch, (list, tuple)):
|
||||
return type(batch)([to_device(u, device, non_blocking) for u in batch])
|
||||
elif isinstance(batch, dict):
|
||||
return type(batch)([(k, to_device(v, device, non_blocking))
|
||||
for k, v in batch.items()])
|
||||
elif isinstance(batch, torch.Tensor):
|
||||
return batch.to(device, non_blocking=non_blocking)
|
||||
return batch
|
||||
|
||||
|
||||
def rand_name(length=8, suffix=''):
|
||||
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
|
||||
if suffix:
|
||||
if not suffix.startswith('.'):
|
||||
suffix = '.' + suffix
|
||||
name += suffix
|
||||
return name
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ema(net_ema, net, beta, copy_buffer=False):
|
||||
assert 0.0 <= beta <= 1.0
|
||||
for p_ema, p in zip(net_ema.parameters(), net.parameters()):
|
||||
p_ema.copy_(p.lerp(p_ema, beta))
|
||||
if copy_buffer:
|
||||
for b_ema, b in zip(net_ema.buffers(), net.buffers()):
|
||||
b_ema.copy_(b)
|
||||
|
||||
|
||||
def parallel(func, args_list, num_workers=32, timeout=None):
|
||||
assert isinstance(args_list, list)
|
||||
if not isinstance(args_list[0], tuple):
|
||||
args_list = [(args, ) for args in args_list]
|
||||
if num_workers == 0:
|
||||
return [func(*args) for args in args_list]
|
||||
with Pool(processes=num_workers) as pool:
|
||||
results = [pool.apply_async(func, args) for args in args_list]
|
||||
results = [res.get(timeout=timeout) for res in results]
|
||||
return results
|
||||
|
||||
|
||||
def unzip(filename, dst_dir=None):
|
||||
if dst_dir is None:
|
||||
dst_dir = osp.dirname(filename)
|
||||
with zipfile.ZipFile(filename, 'r') as zip_ref:
|
||||
zip_ref.extractall(dst_dir)
|
||||
|
||||
|
||||
def load_state_dict(module, state_dict, drop_prefix=''):
|
||||
# find incompatible key-vals
|
||||
src, dst = state_dict, module.state_dict()
|
||||
if drop_prefix:
|
||||
src = type(src)([
|
||||
(k[len(drop_prefix):] if k.startswith(drop_prefix) else k, v)
|
||||
for k, v in src.items()
|
||||
])
|
||||
missing = [k for k in dst if k not in src]
|
||||
unexpected = [k for k in src if k not in dst]
|
||||
unmatched = [
|
||||
k for k in src.keys() & dst.keys() if src[k].shape != dst[k].shape
|
||||
]
|
||||
|
||||
# keep only compatible key-vals
|
||||
incompatible = set(unexpected + unmatched)
|
||||
src = type(src)([(k, v) for k, v in src.items() if k not in incompatible])
|
||||
module.load_state_dict(src, strict=False)
|
||||
|
||||
# report incompatible key-vals
|
||||
if len(missing) != 0:
|
||||
print(' Missing: ' + ', '.join(missing), flush=True)
|
||||
if len(unexpected) != 0:
|
||||
print(' Unexpected: ' + ', '.join(unexpected), flush=True)
|
||||
if len(unmatched) != 0:
|
||||
print(' Shape unmatched: ' + ', '.join(unmatched), flush=True)
|
||||
|
||||
|
||||
def inverse_indices(indices):
|
||||
r"""Inverse map of indices.
|
||||
E.g., if A[indices] == B, then B[inv_indices] == A.
|
||||
"""
|
||||
inv_indices = torch.empty_like(indices)
|
||||
inv_indices[indices] = torch.arange(len(indices)).to(indices)
|
||||
return inv_indices
|
||||
|
||||
|
||||
def detect_duplicates(feats, thr=0.9):
|
||||
assert feats.ndim == 2
|
||||
|
||||
# compute simmat
|
||||
feats = F.normalize(feats, p=2, dim=1)
|
||||
simmat = torch.mm(feats, feats.T)
|
||||
simmat.triu_(1)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# detect duplicates
|
||||
mask = ~simmat.gt(thr).any(dim=0)
|
||||
return torch.where(mask)[0]
|
||||
|
||||
|
||||
def md5(filename):
|
||||
with open(filename, 'rb') as f:
|
||||
return hashlib.md5(f.read()).hexdigest()
|
||||
|
||||
|
||||
def rope(x):
|
||||
r"""Apply rotary position embedding on x of shape [B, *(spatial dimensions), C].
|
||||
"""
|
||||
# reshape
|
||||
shape = x.shape
|
||||
x = x.view(x.size(0), -1, x.size(-1))
|
||||
l, c = x.shape[-2:]
|
||||
assert c % 2 == 0
|
||||
half = c // 2
|
||||
|
||||
# apply rotary position embedding on x
|
||||
sinusoid = torch.outer(
|
||||
torch.arange(l).to(x),
|
||||
torch.pow(10000, -torch.arange(half).to(x).div(half)))
|
||||
sin, cos = torch.sin(sinusoid), torch.cos(sinusoid)
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
|
||||
|
||||
# reshape back
|
||||
return x.view(shape)
|
||||
|
||||
|
||||
def format_state(state, filename=None):
|
||||
r"""For comparing/aligning state_dict.
|
||||
"""
|
||||
content = '\n'.join([f'{k}\t{tuple(v.shape)}' for k, v in state.items()])
|
||||
if filename:
|
||||
with open(filename, 'w') as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def breakup_grid(img, grid_size):
|
||||
r"""The inverse operator of ``torchvision.utils.make_grid``.
|
||||
"""
|
||||
# params
|
||||
nrow = img.height // grid_size
|
||||
ncol = img.width // grid_size
|
||||
wrow = wcol = 2 # NOTE: use default values here
|
||||
|
||||
# collect grids
|
||||
grids = []
|
||||
for i in range(nrow):
|
||||
for j in range(ncol):
|
||||
x1 = j * grid_size + (j + 1) * wcol
|
||||
y1 = i * grid_size + (i + 1) * wrow
|
||||
grids.append(img.crop((x1, y1, x1 + grid_size, y1 + grid_size)))
|
||||
return grids
|
||||
|
||||
|
||||
def viz_anno_geometry(item):
|
||||
r"""Visualize an annotation item from SmartLabel.
|
||||
"""
|
||||
if isinstance(item, str):
|
||||
item = json.loads(item)
|
||||
assert isinstance(item, dict)
|
||||
|
||||
# read image
|
||||
orig_img = read_image(item['image_url'], retry=100)
|
||||
img = cv2.cvtColor(np.asarray(orig_img), cv2.COLOR_BGR2RGB)
|
||||
|
||||
# loop over geometries
|
||||
for geometry in item['sd_result']['items']:
|
||||
# params
|
||||
poly_img = img.copy()
|
||||
color = rand_color()
|
||||
points = np.array(geometry['meta']['geometry']).round().astype(int)
|
||||
line_color = tuple([int(u * 0.55) for u in color])
|
||||
|
||||
# draw polygons
|
||||
poly_img = cv2.fillPoly(poly_img, pts=[points], color=color)
|
||||
poly_img = cv2.polylines(
|
||||
poly_img,
|
||||
pts=[points],
|
||||
isClosed=True,
|
||||
color=line_color,
|
||||
thickness=2)
|
||||
|
||||
# mixing
|
||||
img = np.clip(0.25 * img + 0.75 * poly_img, 0, 255).astype(np.uint8)
|
||||
return orig_img, Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||
|
||||
|
||||
def image_to_base64(img, format='JPEG'):
|
||||
buffer = BytesIO()
|
||||
img.save(buffer, format=format)
|
||||
code = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
return code
|
||||
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
|
||||
from .image_color_enhance_pipeline import ImageColorEnhancePipeline
|
||||
from .image_colorization_pipeline import ImageColorizationPipeline
|
||||
from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline
|
||||
from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
|
||||
from .video_category_pipeline import VideoCategoryPipeline
|
||||
from .image_matting_pipeline import ImageMattingPipeline
|
||||
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
|
||||
|
||||
325
modelscope/pipelines/cv/image_to_image_translation_pipeline.py
Normal file
325
modelscope/pipelines/cv/image_to_image_translation_pipeline.py
Normal file
@@ -0,0 +1,325 @@
|
||||
import io
|
||||
import os.path as osp
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from torchvision.utils import save_image
|
||||
|
||||
import modelscope.models.cv.image_to_image_translation.data as data
|
||||
import modelscope.models.cv.image_to_image_translation.models as models
|
||||
import modelscope.models.cv.image_to_image_translation.ops as ops
|
||||
from modelscope.fileio import File
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.image_to_image_translation.model_translation import \
|
||||
UNet
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import load_image
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def save_grid(imgs, filename, nrow=5):
|
||||
save_image(
|
||||
imgs.clamp(-1, 1), filename, range=(-1, 1), normalize=True, nrow=nrow)
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_generation, module_name=Pipelines.image2image_translation)
|
||||
class Image2ImageTranslationPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
use `model` to create a kws pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
config_path = osp.join(self.model, ModelFile.CONFIGURATION)
|
||||
logger.info(f'loading config from {config_path}')
|
||||
self.cfg = Config.from_file(config_path)
|
||||
if torch.cuda.is_available():
|
||||
self._device = torch.device('cuda')
|
||||
else:
|
||||
self._device = torch.device('cpu')
|
||||
self.repetition = 4
|
||||
# load autoencoder model
|
||||
ae_model_path = osp.join(self.model, self.cfg.ModelPath.ae_model_path)
|
||||
logger.info(f'loading autoencoder model from {ae_model_path}')
|
||||
self.autoencoder = models.VQAutoencoder(
|
||||
dim=self.cfg.Params.ae.ae_dim,
|
||||
z_dim=self.cfg.Params.ae.ae_z_dim,
|
||||
dim_mult=self.cfg.Params.ae.ae_dim_mult,
|
||||
attn_scales=self.cfg.Params.ae.ae_attn_scales,
|
||||
codebook_size=self.cfg.Params.ae.ae_codebook_size).eval(
|
||||
).requires_grad_(False).to(self._device) # noqa E123
|
||||
self.autoencoder.load_state_dict(
|
||||
torch.load(ae_model_path, map_location=self._device))
|
||||
logger.info('load autoencoder model done')
|
||||
|
||||
# load palette model
|
||||
palette_model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
|
||||
logger.info(f'loading palette model from {palette_model_path}')
|
||||
self.palette = UNet(
|
||||
resolution=self.cfg.Params.unet.unet_resolution,
|
||||
in_dim=self.cfg.Params.unet.unet_in_dim,
|
||||
dim=self.cfg.Params.unet.unet_dim,
|
||||
context_dim=self.cfg.Params.unet.unet_context_dim,
|
||||
out_dim=self.cfg.Params.unet.unet_out_dim,
|
||||
dim_mult=self.cfg.Params.unet.unet_dim_mult,
|
||||
num_heads=self.cfg.Params.unet.unet_num_heads,
|
||||
head_dim=None,
|
||||
num_res_blocks=self.cfg.Params.unet.unet_res_blocks,
|
||||
attn_scales=self.cfg.Params.unet.unet_attn_scales,
|
||||
num_classes=self.cfg.Params.unet.unet_num_classes + 1,
|
||||
dropout=self.cfg.Params.unet.unet_dropout).eval().requires_grad_(
|
||||
False).to(self._device)
|
||||
self.palette.load_state_dict(
|
||||
torch.load(palette_model_path, map_location=self._device))
|
||||
logger.info('load palette model done')
|
||||
|
||||
# diffusion
|
||||
logger.info('Initialization diffusion ...')
|
||||
betas = ops.beta_schedule(self.cfg.Params.diffusion.schedule,
|
||||
self.cfg.Params.diffusion.num_timesteps)
|
||||
self.diffusion = ops.GaussianDiffusion(
|
||||
betas=betas,
|
||||
mean_type=self.cfg.Params.diffusion.mean_type,
|
||||
var_type=self.cfg.Params.diffusion.var_type,
|
||||
loss_type=self.cfg.Params.diffusion.loss_type,
|
||||
rescale_timesteps=False)
|
||||
|
||||
self.transforms = T.Compose([
|
||||
data.PadToSquare(),
|
||||
T.Resize(
|
||||
self.cfg.DATA.scale_size,
|
||||
interpolation=T.InterpolationMode.BICUBIC),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std)
|
||||
])
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
if len(input) == 3: # colorization
|
||||
_, input_type, save_path = input
|
||||
elif len(input) == 4: # uncropping or in-painting
|
||||
_, meta, input_type, save_path = input
|
||||
if input_type == 0: # uncropping
|
||||
assert meta in ['up', 'down', 'left', 'right']
|
||||
direction = meta
|
||||
|
||||
list_ = []
|
||||
for i in range(len(input) - 2):
|
||||
input_img = input[i]
|
||||
if input_img in ['up', 'down', 'left', 'right']:
|
||||
continue
|
||||
if isinstance(input_img, str):
|
||||
if input_type == 2 and i == 0:
|
||||
logger.info('Loading image by origin way ... ')
|
||||
bytes = File.read(input_img)
|
||||
img = Image.open(io.BytesIO(bytes))
|
||||
assert len(img.split()) == 4
|
||||
else:
|
||||
img = load_image(input_img)
|
||||
elif isinstance(input_img, PIL.Image.Image):
|
||||
img = input_img.convert('RGB')
|
||||
elif isinstance(input_img, np.ndarray):
|
||||
if len(input_img.shape) == 2:
|
||||
input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2BGR)
|
||||
img = input_img[:, :, ::-1]
|
||||
img = Image.fromarray(img.astype('uint8')).convert('RGB')
|
||||
else:
|
||||
raise TypeError(f'input should be either str, PIL.Image,'
|
||||
f' np.array, but got {type(input)}')
|
||||
list_.append(img)
|
||||
img_list = []
|
||||
if input_type != 2:
|
||||
for img in list_:
|
||||
img = self.transforms(img)
|
||||
imgs = torch.unsqueeze(img, 0)
|
||||
imgs = imgs.to(self._device)
|
||||
img_list.append(imgs)
|
||||
elif input_type == 2:
|
||||
mask, masked_img = list_[0], list_[1]
|
||||
img = self.transforms(masked_img.convert('RGB'))
|
||||
mask = torch.from_numpy(
|
||||
np.array(
|
||||
mask.resize((img.shape[2], img.shape[1])),
|
||||
dtype=np.float32)[:, :, -1] / 255.0).unsqueeze(0)
|
||||
img = (1 - mask) * img + mask * torch.randn_like(img).clamp_(-1, 1)
|
||||
imgs = img.unsqueeze(0).to(self._device)
|
||||
b, c, h, w = imgs.shape
|
||||
y = torch.LongTensor([self.cfg.Classes.class_id]).to(self._device)
|
||||
|
||||
if input_type == 0:
|
||||
assert len(img_list) == 1
|
||||
result = {
|
||||
'image_data': img_list[0],
|
||||
'c': c,
|
||||
'h': h,
|
||||
'w': w,
|
||||
'direction': direction,
|
||||
'type': input_type,
|
||||
'y': y,
|
||||
'save_path': save_path
|
||||
}
|
||||
elif input_type == 1:
|
||||
assert len(img_list) == 1
|
||||
result = {
|
||||
'image_data': img_list[0],
|
||||
'c': c,
|
||||
'h': h,
|
||||
'w': w,
|
||||
'type': input_type,
|
||||
'y': y,
|
||||
'save_path': save_path
|
||||
}
|
||||
elif input_type == 2:
|
||||
result = {
|
||||
'image_data': imgs,
|
||||
# 'image_mask': mask,
|
||||
'c': c,
|
||||
'h': h,
|
||||
'w': w,
|
||||
'type': input_type,
|
||||
'y': y,
|
||||
'save_path': save_path
|
||||
}
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
type_ = input['type']
|
||||
if type_ == 0:
|
||||
# Uncropping
|
||||
img = input['image_data']
|
||||
direction = input['direction']
|
||||
y = input['y']
|
||||
|
||||
# fix seed
|
||||
torch.manual_seed(1 * 8888)
|
||||
torch.cuda.manual_seed(1 * 8888)
|
||||
|
||||
logger.info(f'Processing {direction} uncropping')
|
||||
img = img.clone()
|
||||
i_y = y.repeat(self.repetition, 1)
|
||||
if direction == 'up':
|
||||
img[:, :, input['h'] // 2:, :] = torch.randn_like(
|
||||
img[:, :, input['h'] // 2:, :])
|
||||
elif direction == 'down':
|
||||
img[:, :, :input['h'] // 2, :] = torch.randn_like(
|
||||
img[:, :, :input['h'] // 2, :])
|
||||
elif direction == 'left':
|
||||
img[:, :, :,
|
||||
input['w'] // 2:] = torch.randn_like(img[:, :, :,
|
||||
input['w'] // 2:])
|
||||
elif direction == 'right':
|
||||
img[:, :, :, :input['w'] // 2] = torch.randn_like(
|
||||
img[:, :, :, :input['w'] // 2])
|
||||
i_concat = self.autoencoder.encode(img).repeat(
|
||||
self.repetition, 1, 1, 1)
|
||||
|
||||
# sample images
|
||||
x0 = self.diffusion.ddim_sample_loop(
|
||||
noise=torch.randn_like(i_concat),
|
||||
model=self.palette,
|
||||
model_kwargs=[{
|
||||
'y': i_y,
|
||||
'concat': i_concat
|
||||
}, {
|
||||
'y':
|
||||
torch.full_like(i_y,
|
||||
self.cfg.Params.unet.unet_num_classes),
|
||||
'concat':
|
||||
i_concat
|
||||
}],
|
||||
guide_scale=1.0,
|
||||
clamp=None,
|
||||
ddim_timesteps=50,
|
||||
eta=1.0)
|
||||
i_gen_imgs = self.autoencoder.decode(x0)
|
||||
save_grid(i_gen_imgs, input['save_path'], nrow=4)
|
||||
return {OutputKeys.OUTPUT_IMG: i_gen_imgs}
|
||||
|
||||
elif type_ == 1:
|
||||
# Colorization #
|
||||
img = input['image_data']
|
||||
y = input['y']
|
||||
# fix seed
|
||||
torch.manual_seed(1 * 8888)
|
||||
torch.cuda.manual_seed(1 * 8888)
|
||||
|
||||
logger.info('Processing Colorization')
|
||||
img = img.clone()
|
||||
img = img.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)
|
||||
i_concat = self.autoencoder.encode(img).repeat(
|
||||
self.repetition, 1, 1, 1)
|
||||
i_y = y.repeat(self.repetition, 1)
|
||||
|
||||
# sample images
|
||||
x0 = self.diffusion.ddim_sample_loop(
|
||||
noise=torch.randn_like(i_concat),
|
||||
model=self.palette,
|
||||
model_kwargs=[{
|
||||
'y': i_y,
|
||||
'concat': i_concat
|
||||
}, {
|
||||
'y':
|
||||
torch.full_like(i_y,
|
||||
self.cfg.Params.unet.unet_num_classes),
|
||||
'concat':
|
||||
i_concat
|
||||
}],
|
||||
guide_scale=1.0,
|
||||
clamp=None,
|
||||
ddim_timesteps=50,
|
||||
eta=0.0)
|
||||
i_gen_imgs = self.autoencoder.decode(x0)
|
||||
save_grid(i_gen_imgs, input['save_path'], nrow=4)
|
||||
return {OutputKeys.OUTPUT_IMG: i_gen_imgs}
|
||||
elif type_ == 2:
|
||||
# Combination #
|
||||
logger.info('Processing Combination')
|
||||
|
||||
# prepare inputs
|
||||
img = input['image_data']
|
||||
concat = self.autoencoder.encode(img).repeat(
|
||||
self.repetition, 1, 1, 1)
|
||||
y = torch.LongTensor([126]).unsqueeze(0).to(self._device).repeat(
|
||||
self.repetition, 1)
|
||||
|
||||
# sample images
|
||||
x0 = self.diffusion.ddim_sample_loop(
|
||||
noise=torch.randn_like(concat),
|
||||
model=self.palette,
|
||||
model_kwargs=[{
|
||||
'y': y,
|
||||
'concat': concat
|
||||
}, {
|
||||
'y':
|
||||
torch.full_like(y, self.cfg.Params.unet.unet_num_classes),
|
||||
'concat':
|
||||
concat
|
||||
}],
|
||||
guide_scale=1.0,
|
||||
clamp=None,
|
||||
ddim_timesteps=50,
|
||||
eta=1.0)
|
||||
i_gen_imgs = self.autoencoder.decode(x0)
|
||||
save_grid(i_gen_imgs, input['save_path'], nrow=4)
|
||||
return {OutputKeys.OUTPUT_IMG: i_gen_imgs}
|
||||
else:
|
||||
raise TypeError(
|
||||
f'input type should be 0 (Uncropping), 1 (Colorization), 2 (Combation)'
|
||||
f' but got {type_}')
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return inputs
|
||||
38
tests/pipelines/test_image2image_translation.py
Normal file
38
tests/pipelines/test_image2image_translation.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
from modelscope.fileio import File
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class Image2ImageTranslationTest(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_modelhub(self):
|
||||
r"""We provide three translation modes, i.e., uncropping, colorization and combination.
|
||||
You can pass the following parameters for different mode.
|
||||
1. Uncropping Mode:
|
||||
result = img2img_gen_pipeline(('data/test/images/img2img_input.jpg', 'left', 0, 'result.jpg'))
|
||||
2. Colorization Mode:
|
||||
result = img2img_gen_pipeline(('data/test/images/img2img_input.jpg', 1, 'result.jpg'))
|
||||
3. Combination Mode:
|
||||
just like the following code.
|
||||
"""
|
||||
img2img_gen_pipeline = pipeline(
|
||||
Tasks.image_generation,
|
||||
model='damo/cv_latent_diffusion_image2image_translation')
|
||||
result = img2img_gen_pipeline(
|
||||
('data/test/images/img2img_input_mask.png',
|
||||
'data/test/images/img2img_input_masked_img.png', 2,
|
||||
'result.jpg')) # combination mode
|
||||
|
||||
print(f'output: {result}.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user