Merge remote-tracking branch 'origin/master' into nlp/space/dst

Author: 智丞
Date: 2022-07-02 22:02:47 +08:00
13 changed files with 3101 additions and 0 deletions

View File

@@ -29,6 +29,7 @@ class Models(object):
ofa = 'ofa'
clip = 'clip-multi-modal-embedding'
mplug = 'mplug'
imagen = 'imagen-text-to-image-synthesis'
class Pipelines(object):
@@ -71,6 +72,7 @@ class Pipelines(object):
image_caption = 'image-captioning'
multi_modal_embedding = 'multi-modal-embedding'
visual_question_answering = 'visual-question-answering'
text_to_image_synthesis = 'text-to-image-synthesis'
class Trainers(object):

View File

@@ -1,4 +1,5 @@
from .clip.clip_model import CLIPForMultiModalEmbedding
from .image_captioning_model import OfaForImageCaptioning
from .imagen.imagen_model import ImagenForTextToImageSynthesis
from .mplug_for_visual_question_answering import \
MPlugForVisualQuestionAnswering

View File

@@ -0,0 +1,595 @@
import math
import torch
__all__ = ['GaussianDiffusion', 'beta_schedule']
def kl_divergence(mu1, logvar1, mu2, logvar2):
a = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
b = ((mu1 - mu2)**2) * torch.exp(-logvar2)
return 0.5 * (a + b)
def standard_normal_cdf(x):
return 0.5 * (1.0 + torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
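# Hedged sanity check, not in the original module: the tanh-based
# approximation above tracks the exact normal CDF (via torch.erf)
# closely, with max absolute error well under 1e-2 on [-3, 3].
_x = torch.linspace(-3.0, 3.0, 7)
_exact = 0.5 * (1.0 + torch.erf(_x / math.sqrt(2.0)))
assert (standard_normal_cdf(_x) - _exact).abs().max().item() < 1e-2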
def discretized_gaussian_log_likelihood(x0, mean, log_scale):
assert x0.shape == mean.shape == log_scale.shape
cx = x0 - mean
inv_stdv = torch.exp(-log_scale)
cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x0 < -0.999, log_cdf_plus,
torch.where(x0 > 0.999, log_one_minus_cdf_min,
torch.log(cdf_delta.clamp(min=1e-12))))
assert log_probs.shape == x0.shape
return log_probs
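# Hedged sketch, not in the original module: evaluating the discretized
# log-likelihood of x0 against its own mean gives the log-mass of one
# 1/255-wide pixel bin, so exponentiating recovers a probability in (0, 1).
_x0 = torch.zeros(1, 3, 8, 8)
_ll = discretized_gaussian_log_likelihood(
    _x0, mean=_x0, log_scale=torch.full_like(_x0, -5.0))
assert 0.0 < _ll.exp().mean().item() < 1.0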
def _i(tensor, t, x):
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
return tensor[t].view(shape).to(x)
def cosine_fn(u):
return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2
def beta_schedule(schedule,
num_timesteps=1000,
init_beta=None,
last_beta=None):
if schedule == 'linear':
scale = 1000.0 / num_timesteps
init_beta = init_beta or scale * 0.0001
last_beta = last_beta or scale * 0.02
return torch.linspace(
init_beta, last_beta, num_timesteps, dtype=torch.float64)
elif schedule == 'quadratic':
init_beta = init_beta or 0.0015
last_beta = last_beta or 0.0195
return torch.linspace(
init_beta**0.5, last_beta**0.5, num_timesteps,
dtype=torch.float64)**2
elif schedule == 'cosine':
betas = []
for step in range(num_timesteps):
t1 = step / num_timesteps
t2 = (step + 1) / num_timesteps
betas.append(min(1.0 - cosine_fn(t2) / cosine_fn(t1), 0.999))
return torch.tensor(betas, dtype=torch.float64)
else:
raise ValueError(f'Unsupported schedule: {schedule}')
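# Hedged usage sketch, not in the original module: each supported schedule
# returns a float64 tensor of length num_timesteps with betas inside (0, 1].
_betas = beta_schedule('cosine', num_timesteps=1000)
assert _betas.dtype == torch.float64 and _betas.shape == (1000, )
assert _betas.min() > 0 and _betas.max() <= 1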
class GaussianDiffusion(object):
def __init__(self,
betas,
mean_type='eps',
var_type='learned_range',
loss_type='mse',
rescale_timesteps=False):
# check input
if not isinstance(betas, torch.DoubleTensor):
betas = torch.tensor(betas, dtype=torch.float64)
assert min(betas) > 0 and max(betas) <= 1
assert mean_type in ['x0', 'x_{t-1}', 'eps']
assert var_type in [
'learned', 'learned_range', 'fixed_large', 'fixed_small'
]
assert loss_type in [
'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1'
]
self.betas = betas
self.num_timesteps = len(betas)
self.mean_type = mean_type
self.var_type = var_type
self.loss_type = loss_type
self.rescale_timesteps = rescale_timesteps
# alphas
alphas = 1 - self.betas
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
self.alphas_cumprod_prev = torch.cat(
[alphas.new_ones([1]), self.alphas_cumprod[:-1]])
self.alphas_cumprod_next = torch.cat(
[self.alphas_cumprod[1:],
alphas.new_zeros([1])])
# q(x_t | x_{t-1})
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
- self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = torch.log(1.0
- self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
- 1)
# q(x_{t-1} | x_t, x_0)
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
1.0 - self.alphas_cumprod)
self.posterior_log_variance_clipped = torch.log(
self.posterior_variance.clamp(1e-20))
self.posterior_mean_coef1 = betas * torch.sqrt(
self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.posterior_mean_coef2 = (
1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
1.0 - self.alphas_cumprod)
def q_sample(self, x0, t, noise=None):
noise = torch.randn_like(x0) if noise is None else noise
return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i(
self.sqrt_one_minus_alphas_cumprod, t, x0) * noise
def q_mean_variance(self, x0, t):
mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
var = _i(1.0 - self.alphas_cumprod, t, x0)
log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
return mu, var, log_var
def q_posterior_mean_variance(self, x0, xt, t):
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
self.posterior_mean_coef2, t, xt) * xt
var = _i(self.posterior_variance, t, xt)
log_var = _i(self.posterior_log_variance_clipped, t, xt)
return mu, var, log_var
@torch.no_grad()
def p_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None):
# predict distribution of p(x_{t-1} | x_t)
mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
clamp, percentile,
guide_scale)
# random sample (with optional conditional function)
noise = torch.randn_like(xt)
shape = (-1, ) + ((1, ) * (xt.ndim - 1))
mask = t.ne(0).float().view(*shape) # no noise when t == 0
if condition_fn is not None:
grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
mu = mu.float() + var * grad.float()
xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
return xt_1, x0
@torch.no_grad()
def p_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None):
# prepare input
b, c, h, w = noise.size()
xt = noise
# diffusion process
for step in torch.arange(self.num_timesteps).flip(0):
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn, guide_scale)
return xt
def p_mean_variance(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None):
# predict distribution
if guide_scale is None:
out = model(xt, self._scale_timesteps(t), **model_kwargs)
else:
# classifier-free guidance
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: unconditional kwargs)
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
assert self.mean_type == 'eps'
y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
a = u_out[:, :3]
b = guide_scale * (y_out[:, :3] - u_out[:, :3])
c = y_out[:, 3:]
out = torch.cat([a + b, c], dim=1)
# compute variance
if self.var_type == 'learned':
out, log_var = out.chunk(2, dim=1)
var = torch.exp(log_var)
elif self.var_type == 'learned_range':
out, fraction = out.chunk(2, dim=1)
min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
max_log_var = _i(torch.log(self.betas), t, xt)
fraction = (fraction + 1) / 2.0
log_var = fraction * max_log_var + (1 - fraction) * min_log_var
var = torch.exp(log_var)
elif self.var_type == 'fixed_large':
var = _i(
torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
xt)
log_var = torch.log(var)
elif self.var_type == 'fixed_small':
var = _i(self.posterior_variance, t, xt)
log_var = _i(self.posterior_log_variance_clipped, t, xt)
# compute mean and x0
if self.mean_type == 'x_{t-1}':
mu = out # x_{t-1}
x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i(
self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
xt) * xt
elif self.mean_type == 'x0':
x0 = out
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
elif self.mean_type == 'eps':
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
self.sqrt_recipm1_alphas_cumprod, t, xt) * out
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
# restrict the range of x0
if percentile is not None:
assert percentile > 0 and percentile <= 1 # e.g., 0.995
s = torch.quantile(
x0.flatten(1).abs(), percentile,
dim=1).clamp_(1.0).view(-1, 1, 1, 1)
x0 = torch.min(s, torch.max(-s, x0)) / s
elif clamp is not None:
x0 = x0.clamp(-clamp, clamp)
return mu, var, log_var, x0
@torch.no_grad()
def ddim_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
ddim_timesteps=20,
eta=0.0):
stride = self.num_timesteps // ddim_timesteps
# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
percentile, guide_scale)
if condition_fn is not None:
# x0 -> eps
alpha = _i(self.alphas_cumprod, t, xt)
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
self.sqrt_recipm1_alphas_cumprod, t, xt)
eps = eps - (1 - alpha).sqrt() * condition_fn(
xt, self._scale_timesteps(t), **model_kwargs)
# eps -> x0
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
# derive variables
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
self.sqrt_recipm1_alphas_cumprod, t, xt)
alphas = _i(self.alphas_cumprod, t, xt)
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
a = (1 - alphas_prev) / (1 - alphas)
b = (1 - alphas / alphas_prev)
sigmas = eta * torch.sqrt(a * b)
# random sample
noise = torch.randn_like(xt)
direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
return xt_1, x0
@torch.no_grad()
def ddim_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
ddim_timesteps=20,
eta=0.0):
# prepare input
b, c, h, w = noise.size()
xt = noise
# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
steps = (1 + torch.arange(0, self.num_timesteps,
self.num_timesteps // ddim_timesteps)).clamp(
0, self.num_timesteps - 1).flip(0)
for step in steps:
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn, guide_scale,
ddim_timesteps, eta)
return xt
@torch.no_grad()
def ddim_reverse_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None,
ddim_timesteps=20):
stride = self.num_timesteps // ddim_timesteps
# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
percentile, guide_scale)
# derive variables
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
self.sqrt_recipm1_alphas_cumprod, t, xt)
alphas_next = _i(
torch.cat(
[self.alphas_cumprod,
self.alphas_cumprod.new_zeros([1])]),
(t + stride).clamp(0, self.num_timesteps), xt)
# reverse sample
mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
return mu, x0
@torch.no_grad()
def ddim_reverse_sample_loop(self,
x0,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None,
ddim_timesteps=20):
# prepare input
b, c, h, w = x0.size()
xt = x0
# reconstruction steps
steps = torch.arange(0, self.num_timesteps,
self.num_timesteps // ddim_timesteps)
for step in steps:
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
percentile, guide_scale,
ddim_timesteps)
return xt
@torch.no_grad()
def plms_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
plms_timesteps=20,
eps_cache=None):
stride = self.num_timesteps // plms_timesteps
eps_cache = [] if eps_cache is None else eps_cache
# function for compute eps
def compute_eps(xt, t):
# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
clamp, percentile, guide_scale)
# condition
if condition_fn is not None:
# x0 -> eps
alpha = _i(self.alphas_cumprod, t, xt)
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
- x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
eps = eps - (1 - alpha).sqrt() * condition_fn(
xt, self._scale_timesteps(t), **model_kwargs)
# eps -> x0
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
# derive eps
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
self.sqrt_recipm1_alphas_cumprod, t, xt)
return eps
# function for compute x_0 and x_{t-1}
def compute_x0(eps, t):
# eps -> x0
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
# deterministic sample
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
direction = torch.sqrt(1 - alphas_prev) * eps
xt_1 = torch.sqrt(alphas_prev) * x0 + direction
return xt_1, x0
# PLMS sample
eps = compute_eps(xt, t)
if len(eps_cache) == 0:
# 2nd order pseudo improved Euler
xt_1, x0 = compute_x0(eps, t)
eps_next = compute_eps(xt_1, (t - stride).clamp(0))
eps_prime = (eps + eps_next) / 2.0
elif len(eps_cache) == 1:
# 2nd order pseudo linear multistep (Adams-Bashforth)
eps_prime = (3 * eps - eps_cache[-1]) / 2.0
elif len(eps_cache) == 2:
# 3rd order pseudo linear multistep (Adams-Bashforth)
eps_prime = (23 * eps - 16 * eps_cache[-1]
+ 5 * eps_cache[-2]) / 12.0
elif len(eps_cache) >= 3:
# 4th order pseudo linear multistep (Adams-Bashforth)
eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
- 9 * eps_cache[-3]) / 24.0
xt_1, x0 = compute_x0(eps_prime, t)
return xt_1, x0, eps
@torch.no_grad()
def plms_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
plms_timesteps=20):
# prepare input
b, c, h, w = noise.size()
xt = noise
# diffusion process
steps = (1 + torch.arange(0, self.num_timesteps,
self.num_timesteps // plms_timesteps)).clamp(
0, self.num_timesteps - 1).flip(0)
eps_cache = []
for step in steps:
# PLMS sampling step
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn,
guide_scale, plms_timesteps,
eps_cache)
# update eps cache
eps_cache.append(eps)
if len(eps_cache) >= 4:
eps_cache.pop(0)
return xt
def loss(self, x0, t, model, model_kwargs={}, noise=None):
noise = torch.randn_like(x0) if noise is None else noise
xt = self.q_sample(x0, t, noise=noise)
# compute loss
if self.loss_type in ['kl', 'rescaled_kl']:
loss, _ = self.variational_lower_bound(x0, xt, t, model,
model_kwargs)
if self.loss_type == 'rescaled_kl':
loss = loss * self.num_timesteps
elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
out = model(xt, self._scale_timesteps(t), **model_kwargs)
# VLB for the learned variance
loss_vlb = 0.0
if self.var_type in ['learned', 'learned_range']:
out, var = out.chunk(2, dim=1)
frozen = torch.cat([
out.detach(), var
], dim=1) # learn var without affecting the prediction of mean
loss_vlb, _ = self.variational_lower_bound(
x0, xt, t, model=lambda *args, **kwargs: frozen)
if self.loss_type.startswith('rescaled_'):
loss_vlb = loss_vlb * self.num_timesteps / 1000.0
# MSE/L1 for x0/eps
target = {
'eps': noise,
'x0': x0,
'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
}[self.mean_type]
loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2
).abs().flatten(1).mean(dim=1)
# total loss
loss = loss + loss_vlb
return loss
def variational_lower_bound(self,
x0,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None):
# compute groundtruth and predicted distributions
mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
clamp, percentile)
# compute KL loss
kl = kl_divergence(mu1, log_var1, mu2, log_var2)
kl = kl.flatten(1).mean(dim=1) / math.log(2.0)
# compute discretized NLL loss (for p(x0 | x1) only)
nll = -discretized_gaussian_log_likelihood(
x0, mean=mu2, log_scale=0.5 * log_var2)
nll = nll.flatten(1).mean(dim=1) / math.log(2.0)
# NLL for p(x0 | x1) and KL otherwise
vlb = torch.where(t == 0, nll, kl)
return vlb, x0
@torch.no_grad()
def variational_lower_bound_loop(self,
x0,
model,
model_kwargs={},
clamp=None,
percentile=None):
# prepare input and output
b, c, h, w = x0.size()
metrics = {'vlb': [], 'mse': [], 'x0_mse': []}
# loop
for step in torch.arange(self.num_timesteps).flip(0):
# compute VLB
t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
noise = torch.randn_like(x0)
xt = self.q_sample(x0, t, noise)
vlb, pred_x0 = self.variational_lower_bound(
x0, xt, t, model, model_kwargs, clamp, percentile)
# predict eps from x0
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
self.sqrt_recipm1_alphas_cumprod, t, xt)
# collect metrics
metrics['vlb'].append(vlb)
metrics['x0_mse'].append(
(pred_x0 - x0).square().flatten(1).mean(dim=1))
metrics['mse'].append(
(eps - noise).square().flatten(1).mean(dim=1))
metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}
# compute the prior KL term for VLB, measured in bits-per-dim
mu, _, log_var = self.q_mean_variance(x0, t)
kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
torch.zeros_like(log_var))
kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)
# update metrics
metrics['prior_bits_per_dim'] = kl_prior
metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
return metrics
def _scale_timesteps(self, t):
if self.rescale_timesteps:
return t.float() * 1000.0 / self.num_timesteps
return t
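# Hedged end-to-end sketch, not in the original module: drive the class
# with a toy eps-predictor. `toy_model` is purely illustrative; with
# var_type='fixed_small' the model only outputs the 3 noise channels.
def toy_model(xt, t, **kwargs):
    return torch.zeros_like(xt)  # pretend the predicted noise is zero

diffusion = GaussianDiffusion(
    beta_schedule('cosine'), mean_type='eps', var_type='fixed_small')
sample = diffusion.ddim_sample_loop(
    noise=torch.randn(1, 3, 64, 64), model=toy_model, ddim_timesteps=20)
assert sample.shape == (1, 3, 64, 64)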

View File

@@ -0,0 +1,255 @@
import os.path as osp
from typing import Any, Dict
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.imagen.diffusion import (GaussianDiffusion,
beta_schedule)
from modelscope.models.multi_modal.imagen.structbert import (BertConfig,
BertModel)
from modelscope.models.multi_modal.imagen.tokenizer import FullTokenizer
from modelscope.models.multi_modal.imagen.unet_generator import ImagenGenerator
from modelscope.models.multi_modal.imagen.unet_imagen_upsampler_256 import \
SuperResUNet256
from modelscope.models.multi_modal.imagen.unet_upsampler_1024 import \
ImagenUpsampler1024
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
__all__ = ['ImagenForTextToImageSynthesis']
def make_diffusion(schedule,
num_timesteps=1000,
init_beta=None,
last_beta=None,
var_type='fixed_small'):
betas = beta_schedule(schedule, num_timesteps, init_beta, last_beta)
diffusion = GaussianDiffusion(betas, var_type=var_type)
return diffusion
class Tokenizer(object):
def __init__(self, vocab_file, seq_len=64):
self.vocab_file = vocab_file
self.seq_len = seq_len
self.tokenizer = FullTokenizer(
vocab_file=vocab_file, do_lower_case=True)
def __call__(self, text):
# tokenization
tokens = self.tokenizer.tokenize(text)
tokens = ['[CLS]'] + tokens[:self.seq_len - 2] + ['[SEP]']
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
input_mask = [1] * len(input_ids)
segment_ids = [0] * len(input_ids)
# padding
input_ids += [0] * (self.seq_len - len(input_ids))
input_mask += [0] * (self.seq_len - len(input_mask))
segment_ids += [0] * (self.seq_len - len(segment_ids))
assert len(input_ids) == len(input_mask) == len(
segment_ids) == self.seq_len
# convert to tensors
input_ids = torch.LongTensor(input_ids)
input_mask = torch.LongTensor(input_mask)
segment_ids = torch.LongTensor(segment_ids)
return input_ids, segment_ids, input_mask
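# Hedged usage sketch, not in the original file: 'vocab.txt' stands in for
# a real BERT-style WordPiece vocabulary shipped with the model directory.
tokenizer = Tokenizer(vocab_file='vocab.txt', seq_len=64)
input_ids, segment_ids, input_mask = tokenizer('a photo of a corgi')
assert input_ids.shape == (64, ) and input_mask.sum() >= 2  # [CLS] + [SEP]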
class ImagenModel(nn.Module):
def __init__(self, model_dir):
super(ImagenModel, self).__init__()
# including text and generator config
model_config = json.load(
open('{}/imagen_config.json'.format(model_dir)))
# text encoder
text_config = model_config['text_config']
self.text_encoder = BertModel(BertConfig.from_dict(text_config))
# generator (64x64)
generator_config = model_config['generator_config']
self.unet_generator = ImagenGenerator(**generator_config)
# imagen upsampler (256x256)
imagen_upsampler_256_config = model_config[
'imagen_upsampler_256_config']
self.unet_imagen_upsampler_256 = SuperResUNet256(
**imagen_upsampler_256_config)
# upsampler (1024x1024)
upsampler_1024_config = model_config['upsampler_1024_config']
self.unet_upsampler_1024 = ImagenUpsampler1024(**upsampler_1024_config)
def forward(self, noise, timesteps, input_ids, token_type_ids,
attention_mask):
context, y = self.text_encoder(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask)
context = context[-1]
x = self.unet_generator(noise, timesteps, y, context, attention_mask)
x = self.unet_imagen_upsampler_256(noise, timesteps, x,
torch.zeros_like(timesteps), y,
context, attention_mask)
x = self.unet_upsampler_1024(x, timesteps, x)
return x
@MODELS.register_module(
Tasks.text_to_image_synthesis, module_name=Models.imagen)
class ImagenForTextToImageSynthesis(Model):
def __init__(self, model_dir, device_id=-1):
super().__init__(model_dir=model_dir, device_id=device_id)
imagen_model = ImagenModel(model_dir=model_dir)
pretrained_params = torch.load(
osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu')
imagen_model.load_state_dict(pretrained_params)
imagen_model.eval()
self.device_id = device_id
if self.device_id >= 0:
self.device = torch.device(f'cuda:{self.device_id}')
imagen_model.to('cuda:{}'.format(self.device_id))
logger.info('Use GPU: {}'.format(self.device_id))
else:
self.device = torch.device('cpu')
logger.info('Use CPU for inference')
# modules
self.text_encoder = imagen_model.text_encoder
self.unet_generator = imagen_model.unet_generator
self.unet_imagen_upsampler_256 = imagen_model.unet_imagen_upsampler_256
self.unet_upsampler_1024 = imagen_model.unet_upsampler_1024
# text tokenizer
vocab_path = '{}/vocab.txt'.format(model_dir)
self.tokenizer = Tokenizer(vocab_file=vocab_path, seq_len=64)
# diffusion process
diffusion_params = json.load(
open('{}/diffusion_config.json'.format(model_dir)))
self.diffusion_generator = make_diffusion(
**diffusion_params['generator_config'])
self.diffusion_imagen_upsampler_256 = make_diffusion(
**diffusion_params['imagen_upsampler_256_config'])
self.diffusion_upsampler_1024 = make_diffusion(
**diffusion_params['upsampler_1024_config'])
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
if not all([key in input for key in ('text', 'noise', 'timesteps')]):
raise ValueError(
f'input should contain "text", "noise", and "timesteps", but got {input.keys()}'
)
noise = input['noise'].to(self.device)
timesteps = input['timesteps'].to(self.device)
input_ids, token_type_ids, attention_mask = self.tokenizer(
input['text'])
input_ids = input_ids.to(self.device).unsqueeze(0)
token_type_ids = token_type_ids.to(self.device).unsqueeze(0)
attention_mask = attention_mask.to(self.device).unsqueeze(0)
context, y = self.text_encoder(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask)
context = context[-1]
x = self.unet_generator(noise, timesteps, y, context, attention_mask)
x = self.unet_imagen_upsampler_256(noise, timesteps, x,
torch.zeros_like(timesteps), y,
context, attention_mask)
x = self.unet_upsampler_1024(x, timesteps, x)
img = x.clamp(-1, 1).add(1).mul(127.5)
img = img.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
return img
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
@torch.no_grad()
def generate(self, input: Dict[str, Any]) -> Dict[str, Any]:
if 'text' not in input:
raise ValueError(
f'input should contain "text", but got {input.keys()}')
# encode text
input_ids, token_type_ids, attention_mask = self.tokenizer(
input['text'])
input_ids = input_ids.to(self.device).unsqueeze(0)
token_type_ids = token_type_ids.to(self.device).unsqueeze(0)
attention_mask = attention_mask.to(self.device).unsqueeze(0)
context, y = self.text_encoder(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask)
context = context[-1]
# generation
img = self.diffusion_generator.ddim_sample_loop(
noise=torch.randn(1, 3, 64, 64).to(self.device),
model=self.unet_generator,
model_kwargs=[{
'y': y,
'context': context,
'mask': attention_mask
}, {
'y': torch.zeros_like(y),
'context': torch.zeros_like(context),
'mask': attention_mask
}],
percentile=input.get('generator_percentile', 0.995),
guide_scale=input.get('generator_guide_scale', 5.0),
ddim_timesteps=input.get('generator_ddim_timesteps', 250),
eta=input.get('generator_ddim_eta', 0.0))
# upsampling (64->256)
img = F.interpolate(
img, scale_factor=4.0, mode='bilinear', align_corners=False)
img = self.diffusion_imagen_upsampler_256.ddim_sample_loop(
noise=torch.randn_like(img),
model=self.unet_imagen_upsampler_256,
model_kwargs=[{
'lx': img,
'lt': torch.zeros(1).to(self.device),
'y': y,
'context': context,
'mask': attention_mask
}, {
'lx': img,
'lt': torch.zeros(1).to(self.device),
'y': torch.zeros_like(y),
'context': torch.zeros_like(context),
'mask': torch.zeros_like(attention_mask)
}],
percentile=input.get('generator_percentile', 0.995),
guide_scale=input.get('generator_guide_scale', 5.0),
ddim_timesteps=input.get('generator_ddim_timesteps', 50),
eta=input.get('generator_ddim_eta', 0.0))
# upsampling (256->1024)
img = F.interpolate(
img, scale_factor=4.0, mode='bilinear', align_corners=False)
img = self.diffusion_upsampler_1024.ddim_sample_loop(
noise=torch.randn_like(img),
model=self.unet_upsampler_1024,
model_kwargs={'concat': img},
percentile=input.get('upsampler_1024_percentile', 0.995),
ddim_timesteps=input.get('upsampler_1024_ddim_timesteps', 20),
eta=input.get('upsampler_1024_ddim_eta', 0.0))
# output
img = img.clamp(-1, 1).add(1).mul(127.5).squeeze(0).permute(
1, 2, 0).cpu().numpy().astype(np.uint8)
return img
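# Hedged usage sketch, not in the original file: '/path/to/model_dir' is a
# placeholder directory; the dict keys are exactly the ones generate() reads.
model = ImagenForTextToImageSynthesis(model_dir='/path/to/model_dir')
img = model.generate({
    'text': 'a watercolor painting of a fox',
    'generator_ddim_timesteps': 250,
    'generator_guide_scale': 5.0,
})
# img is an HxWx3 uint8 numpy array with values in [0, 255]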

View File

@@ -0,0 +1,936 @@
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team and Alibaba Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import, division, print_function
import copy
import math
import json
import numpy as np
import six
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
def __init__(self,
vocab_size,
hidden_size=768,
emb_size=-1,
num_hidden_layers=12,
transformer_type='original',
transition_function='linear',
weighted_transformer=0,
num_rolled_layers=3,
num_attention_heads=12,
intermediate_size=3072,
hidden_act='gelu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
attention_type='self',
rezero=False,
pre_ln=False,
squeeze_excitation=False,
transfer_matrix=False,
dim_dropout=False,
roberta_style=False,
set_mask_zero=False,
init_scale=False,
safer_fp16=False,
grad_checkpoint=False):
"""Constructs BertConfig.
Args:
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.emb_size = emb_size
self.num_hidden_layers = num_hidden_layers
self.transformer_type = transformer_type
self.transition_function = transition_function
self.weighted_transformer = weighted_transformer
self.num_rolled_layers = num_rolled_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.attention_type = attention_type
self.rezero = rezero
self.pre_ln = pre_ln
self.squeeze_excitation = squeeze_excitation
self.transfer_matrix = transfer_matrix
self.dim_dropout = dim_dropout
self.set_mask_zero = set_mask_zero
self.roberta_style = roberta_style
self.init_scale = init_scale
self.safer_fp16 = safer_fp16
self.grad_checkpoint = grad_checkpoint
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, 'r') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n'
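# Hedged usage sketch, not in the original file: from_dict simply copies
# every key/value pair onto a default-initialized config.
config = BertConfig.from_dict({
    'vocab_size': 21128,
    'hidden_size': 768,
    'num_hidden_layers': 12
})
assert config.hidden_size == 768
assert '"vocab_size": 21128' in config.to_json_string()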
class BERTLayerNorm(nn.Module):
def __init__(self, config, variance_epsilon=1e-12, special_size=None):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BERTLayerNorm, self).__init__()
self.config = config
hidden_size = special_size if special_size is not None else config.hidden_size
self.gamma = nn.Parameter(torch.ones(hidden_size))
self.beta = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = variance_epsilon if not config.roberta_style else 1e-5
def forward(self, x):
previous_type = x.type()
if self.config.safer_fp16:
x = x.float()
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
if self.config.safer_fp16:
return (self.gamma * x + self.beta).type(previous_type)
else:
return self.gamma * x + self.beta
class BERTEmbeddings(nn.Module):
def __init__(self, config):
super(BERTEmbeddings, self).__init__()
"""Construct the embedding module from word, position and token_type embeddings.
"""
hidden_size = config.hidden_size if config.emb_size < 0 else config.emb_size
self.word_embeddings = nn.Embedding(
config.vocab_size,
hidden_size,
padding_idx=1 if config.roberta_style else None)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings,
hidden_size,
padding_idx=1 if config.roberta_style else None)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
hidden_size)
self.config = config
self.proj = None if config.emb_size < 0 else nn.Linear(
config.emb_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BERTLayerNorm(config, special_size=hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None, adv_embedding=None):
seq_length = input_ids.size(1)
if not self.config.roberta_style:
position_ids = torch.arange(
seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
else:
mask = input_ids.ne(1).int()
position_ids = (torch.cumsum(mask, dim=1).type_as(mask)
* mask).long() + 1
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(
input_ids) if adv_embedding is None else adv_embedding
if self.config.set_mask_zero:
words_embeddings[input_ids == 103] = 0.
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
if not self.config.roberta_style:
embeddings = words_embeddings + position_embeddings + token_type_embeddings
else:
embeddings = words_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
if self.proj is not None:
embeddings = self.proj(embeddings)
embeddings = self.dropout(embeddings)
return embeddings, words_embeddings
class BERTFactorizedAttention(nn.Module):
def __init__(self, config):
super(BERTFactorizedAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
'The hidden size (%d) is not a multiple of the number of attention '
'heads (%d)' %
(config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size
/ config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x, *size):
new_x_shape = x.size()[:-1] + (self.num_attention_heads,
self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(*size)
def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, 0, 2, 3, 1)
key_layer = self.transpose_for_scores(mixed_key_layer, 0, 2, 1, 3)
value_layer = self.transpose_for_scores(mixed_value_layer, 0, 2, 1, 3)
s_attention_scores = query_layer + attention_mask
s_attention_probs = nn.Softmax(dim=-1)(s_attention_scores)
s_attention_probs = self.dropout(s_attention_probs)
c_attention_probs = nn.Softmax(dim=-1)(key_layer)
s_context_layer = torch.matmul(s_attention_probs, value_layer)
context_layer = torch.matmul(c_attention_probs, s_context_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (
self.all_head_size, )
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
def dim_dropout(x, p=0, dim=-1, training=False):
if not training or p == 0:
return x
a = (1 - p)
b = (x.data.new(x.size()).zero_() + 1)
dropout_mask = torch.bernoulli(a * b)
return dropout_mask * (dropout_mask.size(dim) / torch.sum(
dropout_mask, dim=dim, keepdim=True)) * x
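# Hedged sanity sketch, not in the original file: dim_dropout rescales the
# surviving entries so each slice keeps its sum along `dim` (here 5.0),
# provided at least one entry in the slice survives.
_x = torch.ones(2, 5)
_y = dim_dropout(_x, p=0.4, dim=-1, training=True)
print(_y.sum(dim=-1))  # each row sums to 5.0 unless the whole row dropped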
class BERTSelfAttention(nn.Module):
def __init__(self, config):
super(BERTSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
'The hidden size (%d) is not a multiple of the number of attention '
'heads (%d)' %
(config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size
/ config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.config = config
if config.pre_ln:
self.LayerNorm = BERTLayerNorm(config)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads,
self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask, head_mask=None):
if self.config.pre_ln:
hidden_states = self.LayerNorm(hidden_states)
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(
self.attention_head_size)
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
if head_mask is not None and not self.training:
for i, mask in enumerate(head_mask):
if head_mask[i] == 1:
attention_scores[:, i, :, :] = 0.
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.config.dim_dropout:
attention_probs = self.dropout(attention_probs)
else:
attention_probs = dim_dropout(
attention_probs,
p=self.config.attention_probs_dropout_prob,
dim=-1,
training=self.training)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (
self.all_head_size, )
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class BERTSelfOutput(nn.Module):
def __init__(self, config):
super(BERTSelfOutput, self).__init__()
self.config = config
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if not config.pre_ln and not config.rezero:
self.LayerNorm = BERTLayerNorm(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if config.rezero:
self.res_factor = nn.Parameter(
torch.Tensor(1).fill_(0.99).to(
dtype=next(self.parameters()).dtype))
self.factor = nn.Parameter(
torch.ones(1).to(dtype=next(self.parameters()).dtype))
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
if not self.config.rezero and not self.config.pre_ln:
hidden_states = self.LayerNorm(hidden_states + input_tensor)
elif self.config.rezero:
hidden_states = hidden_states + self.factor * input_tensor
else:
pass
return hidden_states
class BERTAttention(nn.Module):
def __init__(self, config):
super(BERTAttention, self).__init__()
if config.attention_type.lower() == 'self':
self.self = BERTSelfAttention(config)
elif config.attention_type.lower() == 'factorized':
self.self = BERTFactorizedAttention(config)
else:
raise ValueError(
'Attention type must be in [self, factorized], but got {}'.format(
config.attention_type))
self.output = BERTSelfOutput(config)
def forward(self, input_tensor, attention_mask, head_mask=None):
self_output = self.self(input_tensor, attention_mask, head_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output
class DepthwiseSeparableConv1d(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
bias=False):
super(DepthwiseSeparableConv1d, self).__init__()
padding = (kernel_size - 1) // 2
self.depthwise = nn.Conv1d(
in_channels,
in_channels,
kernel_size,
stride,
padding,
dilation,
groups=in_channels,
bias=bias)
self.pointwise = nn.Conv1d(
in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
def forward(self, x):
x = self.depthwise(x)
x = self.pointwise(x)
return x
class BERTIntermediate(nn.Module):
def __init__(self, config):
super(BERTIntermediate, self).__init__()
self.config = config
if self.config.pre_ln:
self.LayerNorm = BERTLayerNorm(config)
self.intermediate_act_fn = gelu
if config.transition_function.lower() == 'linear':
self.dense = nn.Linear(config.hidden_size,
config.intermediate_size)
elif config.transition_function.lower() == 'cnn':
self.cnn = DepthwiseSeparableConv1d(
config.hidden_size, 4 * config.hidden_size, kernel_size=7)
elif config.transition_function.lower() == 'rnn':
raise NotImplementedError(
'rnn transition function is not implemented yet')
else:
raise ValueError('Only linear/cnn/rnn transition functions are supported')
def forward(self, hidden_states):
if self.config.pre_ln:
hidden_states = self.LayerNorm(hidden_states)
if self.config.transition_function.lower() == 'linear':
hidden_states = self.dense(hidden_states)
elif self.config.transition_function.lower() == 'cnn':
hidden_states = self.cnn(hidden_states.transpose(-1,
-2)).transpose(
-1, -2)
else:
pass
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class SqueezeExcitationBlock(nn.Module):
def __init__(self, config):
super(SqueezeExcitationBlock, self).__init__()
self.down_sampling = nn.Linear(config.hidden_size,
config.hidden_size // 4)
self.up_sampling = nn.Linear(config.hidden_size // 4,
config.hidden_size)
def forward(self, hidden_states):
squeeze = torch.mean(hidden_states, 1, keepdim=True)
excitation = torch.sigmoid(
self.up_sampling(gelu(self.down_sampling(squeeze))))
return hidden_states * excitation
class BERTOutput(nn.Module):
def __init__(self, config):
super(BERTOutput, self).__init__()
self.config = config
if config.transition_function.lower() == 'linear':
self.dense = nn.Linear(config.intermediate_size,
config.hidden_size)
elif config.transition_function.lower() == 'cnn':
self.cnn = DepthwiseSeparableConv1d(
4 * config.hidden_size, config.hidden_size, kernel_size=7)
elif config.transition_function.lower() == 'rnn':
raise NotImplementedError(
'rnn transition function is not implemented yet')
else:
raise ValueError('Only linear/cnn/rnn transition functions are supported')
if not config.pre_ln and not config.rezero:
self.LayerNorm = BERTLayerNorm(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if config.squeeze_excitation:
self.SEblock = SqueezeExcitationBlock(config)
if config.rezero:
self.res_factor = nn.Parameter(
torch.Tensor(1).fill_(0.99).to(
dtype=next(self.parameters()).dtype))
self.factor = nn.Parameter(
torch.ones(1).to(dtype=next(self.parameters()).dtype))
def forward(self, hidden_states, input_tensor):
if self.config.transition_function.lower() == 'linear':
hidden_states = self.dense(hidden_states)
elif self.config.transition_function.lower() == 'cnn':
hidden_states = self.cnn(hidden_states.transpose(-1,
-2)).transpose(
-1, -2)
else:
pass
hidden_states = self.dropout(hidden_states)
if self.config.squeeze_excitation:
hidden_states = self.SEblock(hidden_states)
if not self.config.rezero and not self.config.pre_ln:
hidden_states = self.LayerNorm(hidden_states + input_tensor)
elif self.config.rezero:
hidden_states = hidden_states + self.factor * input_tensor
else:
pass
return hidden_states
class BERTLayer(nn.Module):
def __init__(self, config):
super(BERTLayer, self).__init__()
self.attention = BERTAttention(config)
self.intermediate = BERTIntermediate(config)
self.output = BERTOutput(config)
def forward(self, hidden_states, attention_mask, head_mask=None):
attention_output = self.attention(hidden_states, attention_mask,
head_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return attention_output, layer_output
class BERTWeightedLayer(nn.Module):
def __init__(self, config):
super(BERTWeightedLayer, self).__init__()
self.config = config
self.self = BERTSelfAttention(config)
self.attention_head_size = self.self.attention_head_size
self.w_o = nn.ModuleList([
nn.Linear(self.attention_head_size, config.hidden_size)
for _ in range(config.num_attention_heads)
])
self.w_kp = torch.rand(config.num_attention_heads)
self.w_kp = nn.Parameter(self.w_kp / self.w_kp.sum())
self.w_a = torch.rand(config.num_attention_heads)
self.w_a = nn.Parameter(self.w_a / self.w_a.sum())
self.intermediate = BERTIntermediate(config)
self.output = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = BERTLayerNorm(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, attention_mask):
self_output = self.self(hidden_states, attention_mask)
self_outputs = self_output.split(self.self.attention_head_size, dim=-1)
self_outputs = [
self.w_o[i](self_outputs[i]) for i in range(len(self_outputs))
]
self_outputs = [
self.dropout(self_outputs[i]) for i in range(len(self_outputs))
]
self_outputs = [
kappa * output for kappa, output in zip(self.w_kp, self_outputs)
]
self_outputs = [
self.intermediate(self_outputs[i])
for i in range(len(self_outputs))
]
self_outputs = [
self.output(self_outputs[i]) for i in range(len(self_outputs))
]
self_outputs = [
self.dropout(self_outputs[i]) for i in range(len(self_outputs))
]
self_outputs = [
alpha * output for alpha, output in zip(self.w_a, self_outputs)
]
output = sum(self_outputs)
return self.LayerNorm(hidden_states + output)
class BERTEncoder(nn.Module):
def __init__(self, config):
super(BERTEncoder, self).__init__()
self.layer = nn.ModuleList()
for _ in range(config.num_hidden_layers):
if config.weighted_transformer:
self.layer.append(BERTWeightedLayer(config))
else:
self.layer.append(BERTLayer(config))
if config.rezero:
for index, layer in enumerate(self.layer):
layer.output.res_factor = nn.Parameter(
torch.Tensor(1).fill_(1.).to(
dtype=next(self.parameters()).dtype))
layer.output.factor = nn.Parameter(
torch.Tensor(1).fill_(1).to(
dtype=next(self.parameters()).dtype))
layer.attention.output.res_factor = layer.output.res_factor
layer.attention.output.factor = layer.output.factor
self.config = config
def forward(self,
hidden_states,
attention_mask,
epoch_id=-1,
head_masks=None):
all_encoder_layers = [hidden_states]
if epoch_id != -1:
detach_index = int(len(self.layer) / 3) * (2 - epoch_id) - 1
else:
detach_index = -1
for index, layer_module in enumerate(self.layer):
if head_masks is None:
if not self.config.grad_checkpoint:
self_out, hidden_states = layer_module(
hidden_states, attention_mask, None)
else:
self_out, hidden_states = torch.utils.checkpoint.checkpoint(
layer_module, hidden_states, attention_mask, None)
else:
self_out, hidden_states = layer_module(hidden_states,
attention_mask,
head_masks[index])
if detach_index == index:
hidden_states.detach_()
all_encoder_layers.append(self_out)
all_encoder_layers.append(hidden_states)
return all_encoder_layers
class BERTEncoderRolled(nn.Module):
def __init__(self, config):
super(BERTEncoderRolled, self).__init__()
layer = BERTLayer(config)
self.config = config
self.layer = nn.ModuleList(
[copy.deepcopy(layer) for _ in range(config.num_rolled_layers)])
def forward(self,
hidden_states,
attention_mask,
epoch_id=-1,
head_masks=None):
all_encoder_layers = [hidden_states]
for i in range(self.config.num_hidden_layers):
if self.config.transformer_type.lower() == 'universal':
_, hidden_states = self.layer[i % self.config.num_rolled_layers](
hidden_states, attention_mask)
elif self.config.transformer_type.lower() == 'albert':
a = i // (
self.config.num_hidden_layers
// self.config.num_rolled_layers)
_, hidden_states = self.layer[a](hidden_states, attention_mask)
all_encoder_layers.append(hidden_states)
return all_encoder_layers
class BERTEncoderACT(nn.Module):
def __init__(self, config):
super(BERTEncoderACT, self).__init__()
self.layer = BERTLayer(config)
p = nn.Linear(config.hidden_size, 1)
self.p = nn.ModuleList(
[copy.deepcopy(p) for _ in range(config.num_hidden_layers)])
# Following act paper, set bias init ones
for module in self.p:
module.bias.data.fill_(1.)
self.config = config
self.act_max_steps = config.num_hidden_layers
self.threshold = 0.99
def should_continue(self, halting_probability, n_updates):
return (halting_probability.lt(self.threshold).__and__(
n_updates.lt(self.act_max_steps))).any()
def forward(self, hidden_states, attention_mask):
all_encoder_layers = [hidden_states]
batch_size, seq_len, hdim = hidden_states.size()
device = hidden_states.device
halting_probability = torch.zeros(batch_size, seq_len, device=device)
remainders = torch.zeros(batch_size, seq_len, device=device)
n_updates = torch.zeros(batch_size, seq_len, device=device)
for i in range(self.act_max_steps):
p = torch.sigmoid(self.p[i](hidden_states).squeeze(2))
still_running = halting_probability.lt(1.0).float()
new_halted = (halting_probability + p * still_running).gt(
self.threshold).float() * still_running
still_running = (halting_probability + p * still_running).le(
self.threshold).float() * still_running
halting_probability = halting_probability + p * still_running
remainders = remainders + new_halted * (1 - halting_probability)
halting_probability = halting_probability + new_halted * remainders
n_updates = n_updates + still_running + new_halted
update_weights = (p * still_running
+ new_halted * remainders).unsqueeze(2)
_, transformed_states = self.layer(hidden_states, attention_mask)
hidden_states = transformed_states * update_weights + hidden_states * (
1 - update_weights)
all_encoder_layers.append(hidden_states)
if not self.should_continue(halting_probability, n_updates):
break
return all_encoder_layers, torch.mean(n_updates + remainders)
class BERTPooler(nn.Module):
def __init__(self, config):
super(BERTPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertModel(nn.Module):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
model = modeling.BertModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config: BertConfig):
"""Constructor for BertModel.
Args:
config: `BertConfig` instance.
"""
super(BertModel, self).__init__()
self.config = config
self.embeddings = BERTEmbeddings(config)
if config.transformer_type.lower() == 'original':
self.encoder = BERTEncoder(config)
elif config.transformer_type.lower() == 'universal':
self.encoder = BERTEncoderRolled(config)
elif config.transformer_type.lower() == 'albert':
self.encoder = BERTEncoderRolled(config)
elif config.transformer_type.lower() == 'act':
self.encoder = BERTEncoderACT(config)
elif config.transformer_type.lower() == 'textnas':
from textnas_final import op_dict, input_dict, skip_dict
self.encoder = TextNASEncoder(config, op_dict, input_dict,
skip_dict)
else:
raise ValueError('Unsupported transformer type: {}'.format(
config.transformer_type.lower()))
self.pooler = BERTPooler(config)
def forward(self,
input_ids,
token_type_ids=None,
attention_mask=None,
epoch_id=-1,
head_masks=None,
adv_embedding=None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(
dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output, word_embeddings = self.embeddings(
input_ids, token_type_ids, adv_embedding)
if self.config.transformer_type.lower() == 'act':
all_encoder_layers, act_loss = self.encoder(
embedding_output, extended_attention_mask)
elif self.config.transformer_type.lower() == 'reformer':
sequence_output = self.encoder(embedding_output)
all_encoder_layers = [sequence_output, sequence_output]
else:
all_encoder_layers = self.encoder(embedding_output,
extended_attention_mask,
epoch_id, head_masks)
all_encoder_layers.insert(0, word_embeddings)
sequence_output = all_encoder_layers[-1]
if not self.config.safer_fp16:
pooled_output = self.pooler(sequence_output)
else:
pooled_output = sequence_output[:, 0]
return all_encoder_layers, pooled_output
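# Hedged illustration, not in the original file, of the additive mask that
# forward() builds: kept positions become 0.0, masked ones -10000.0.
_mask = torch.LongTensor([[1, 1, 0]])
_extended = (1.0 - _mask.unsqueeze(1).unsqueeze(2).float()) * -10000.0
assert _extended.shape == (1, 1, 1, 3) and _extended[0, 0, 0, 2] == -10000.0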
class BertForSequenceClassificationMultiTask(nn.Module):
"""BERT model for classification.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
config = BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
label_list = [['0', '1']]
model = BertForSequenceClassificationMultiTask(config, label_list, 'bert')
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, label_list, core_encoder):
super(BertForSequenceClassificationMultiTask, self).__init__()
if core_encoder.lower() == 'bert':
self.bert = BertModel(config)
elif core_encoder.lower() == 'lstm':
self.bert = LSTMModel(config)
else:
raise ValueError(
'Only lstm or bert core encoders are supported, but got {}'.format(core_encoder))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.ModuleList()
for label in label_list:
self.classifier.append(nn.Linear(config.hidden_size, len(label)))
self.label_list = label_list
def init_weights(module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(
mean=0.0, std=config.initializer_range)
elif isinstance(module, BERTLayerNorm):
module.beta.data.normal_(
mean=0.0, std=config.initializer_range)
module.gamma.data.normal_(
mean=0.0, std=config.initializer_range)
if isinstance(module, nn.Linear):
module.bias.data.zero_()
self.apply(init_weights)
def forward(self,
input_ids,
token_type_ids,
attention_mask,
labels=None,
labels_index=None,
epoch_id=-1,
head_masks=None,
adv_embedding=None,
return_embedding=False,
loss_weight=None):
all_encoder_layers, pooled_output = self.bert(input_ids,
token_type_ids,
attention_mask, epoch_id,
head_masks,
adv_embedding)
pooled_output = self.dropout(pooled_output)
logits = [classifier(pooled_output) for classifier in self.classifier]
if labels is not None:
loss_fct = CrossEntropyLoss(reduction='none')
regression_loss_fct = nn.MSELoss(reduction='none')
labels_lst = torch.unbind(labels, 1)
loss_lst = []
for index, (label, logit) in enumerate(zip(labels_lst, logits)):
if len(self.label_list[index]) != 1:
loss = loss_fct(logit, label.long())
else:
loss = regression_loss_fct(logit.squeeze(-1), label)
labels_mask = (labels_index == index).to(
dtype=next(self.parameters()).dtype)
if loss_weight is not None:
loss = loss * loss_weight[index]
loss = torch.mean(loss * labels_mask)
loss_lst.append(loss)
if not return_embedding:
return sum(loss_lst), logits
else:
return sum(loss_lst), logits, all_encoder_layers[0]
else:
return logits
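The per-task loss masking in `forward` above is worth unpacking: `labels_index` tags each batch example with the task it belongs to, every head computes an unreduced loss over the whole batch, and the mask zeroes out examples from other tasks before averaging. A minimal sketch with made-up values (illustrative, not repository code):

```python
import torch

# 4 examples, 2 tasks; labels_index says which head each example trains.
labels_index = torch.tensor([0, 0, 1, 1])
loss_task0 = torch.tensor([0.7, 0.2, 0.9, 0.4])  # unreduced loss from head 0
loss_task1 = torch.tensor([0.5, 0.1, 0.3, 0.8])  # unreduced loss from head 1

total = 0.0
for index, loss in enumerate([loss_task0, loss_task1]):
    labels_mask = (labels_index == index).float()
    total = total + torch.mean(loss * labels_mask)
print(total)  # tensor(0.5000) = (0.7 + 0.2)/4 + (0.3 + 0.8)/4
```

Note that the mean is taken over the full batch, so each task is implicitly weighted by its share of the batch; the optional `loss_weight` argument can compensate for that.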

View File

@@ -0,0 +1,333 @@
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team and Alibaba inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import, division, print_function
import collections
import unicodedata
import six
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode('utf-8', 'ignore')
else:
raise ValueError('Unsupported string type: %s' % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode('utf-8', 'ignore')
elif isinstance(text, unicode):
return text
else:
raise ValueError('Unsupported string type: %s' % (type(text)))
else:
        raise ValueError('Not running on Python 2 or Python 3?')
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode('utf-8', 'ignore')
else:
raise ValueError('Unsupported string type: %s' % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode('utf-8')
else:
raise ValueError('Unsupported string type: %s' % (type(text)))
else:
        raise ValueError('Not running on Python 2 or Python 3?')
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, 'r') as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_tokens_to_ids(vocab, tokens):
"""Converts a sequence of tokens into ids using the vocab."""
ids = []
for token in tokens:
ids.append(vocab[token])
return ids
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a peice of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_tokens_to_ids(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return [self.inv_vocab[i] for i in ids]
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(' '.join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize('NFD', text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == 'Mn':
continue
output.append(char)
return ''.join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return [''.join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(' ')
output.append(char)
output.append(' ')
else:
output.append(char)
return ''.join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
        # like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF)
or (cp >= 0x20000 and cp <= 0x2A6DF)
or (cp >= 0x2A700 and cp <= 0x2B73F)
or (cp >= 0x2B740 and cp <= 0x2B81F)
or (cp >= 0x2B820 and cp <= 0x2CEAF)
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F)):
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(' ')
else:
output.append(char)
return ''.join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
          already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = ''.join(chars[start:end])
if start > 0:
substr = '##' + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == ' ' or char == '\t' or char == '\n' or char == '\r':
return True
cat = unicodedata.category(char)
if cat == 'Zs':
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == '\t' or char == '\n' or char == '\r':
return False
cat = unicodedata.category(char)
if cat.startswith('C'):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64)
or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith('P'):
return True
return False
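For reference, a toy run of the greedy longest-match-first WordPiece algorithm implemented above; the vocabulary entries here are made up for illustration (a real vocab file has one token per line):

```python
vocab = {'[UNK]': 0, 'un': 1, '##aff': 2, '##able': 3}
tokenizer = WordpieceTokenizer(vocab=vocab)
print(tokenizer.tokenize('unaffable'))   # ['un', '##aff', '##able']
print(tokenizer.tokenize('unknowable'))  # ['[UNK]'] - no full segmentation
```

A word is emitted as `unk_token` whenever any part of it cannot be covered by vocabulary pieces, rather than being partially tokenized.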

View File

@@ -0,0 +1,319 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['ImagenGenerator']
def sinusoidal_embedding(timesteps, dim):
# check input
half = dim // 2
timesteps = timesteps.float()
# compute sinusoidal embedding
sinusoid = torch.outer(
timesteps, torch.pow(10000,
-torch.arange(half).to(timesteps).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
if dim % 2 != 0:
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
return x
class Resample(nn.Module):
def __init__(self, in_dim, out_dim, scale_factor, use_conv=False):
assert scale_factor in [0.5, 1.0, 2.0]
super(Resample, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.scale_factor = scale_factor
self.use_conv = use_conv
# layers
if scale_factor == 2.0:
self.resample = nn.Sequential(
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
nn.Conv2d(in_dim, out_dim, 3, padding=1)
if use_conv else nn.Identity())
elif scale_factor == 0.5:
self.resample = nn.Conv2d(
in_dim, out_dim, 3, stride=2,
padding=1) if use_conv else nn.AvgPool2d(
kernel_size=2, stride=2)
else:
self.resample = nn.Identity()
def forward(self, x):
return self.resample(x)
class ResidualBlock(nn.Module):
def __init__(self,
in_dim,
embed_dim,
out_dim,
use_scale_shift_norm=True,
scale_factor=1.0,
dropout=0.0):
super(ResidualBlock, self).__init__()
self.in_dim = in_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.use_scale_shift_norm = use_scale_shift_norm
self.scale_factor = scale_factor
# layers
self.layer1 = nn.Sequential(
nn.GroupNorm(32, in_dim), nn.SiLU(),
nn.Conv2d(in_dim, out_dim, 3, padding=1))
self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False)
self.embedding = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim,
out_dim * 2 if use_scale_shift_norm else out_dim))
self.layer2 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
nn.Conv2d(out_dim, out_dim, 3, padding=1))
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
in_dim, out_dim, 1)
# zero out the last layer params
nn.init.zeros_(self.layer2[-1].weight)
def forward(self, x, e):
identity = self.resample(x)
x = self.layer1[-1](self.resample(self.layer1[:-1](x)))
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
if self.use_scale_shift_norm:
scale, shift = e.chunk(2, dim=1)
x = self.layer2[0](x) * (1 + scale) + shift
x = self.layer2[1:](x)
else:
x = x + e
x = self.layer2(x)
x = x + self.shortcut(identity)
return x
class AttentionBlock(nn.Module):
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
# consider head_dim first, then num_heads
num_heads = dim // head_dim if head_dim else num_heads
head_dim = dim // num_heads
assert num_heads * head_dim == dim
super(AttentionBlock, self).__init__()
self.dim = dim
self.context_dim = context_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = math.pow(head_dim, -0.25)
# layers
self.norm = nn.GroupNorm(32, dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
if context_dim is not None:
self.context_kv = nn.Linear(context_dim, dim * 2)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x, context=None, mask=None):
identity = x
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
x = self.norm(x)
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
if context is not None:
ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
d).permute(0, 2, 3,
1).chunk(
2, dim=1)
k = torch.cat([ck, k], dim=-1)
v = torch.cat([cv, v], dim=-1)
# compute attention
attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
if mask is not None:
assert context is not None
full_mask = x.new_ones((b, 1, q.size(-1), k.size(-1)))
full_mask[:, 0, :, :-q.size(-1)] = mask.unsqueeze(1)
attn = attn.masked_fill(full_mask == 0, float('-inf'))
attn = F.softmax(attn, dim=-1)
# gather context
x = torch.matmul(v, attn.transpose(-1, -2))
x = x.reshape(b, c, h, w)
# output
x = self.proj(x)
return x + identity
class ImagenGenerator(nn.Module):
def __init__(self,
in_dim=3,
dim=512,
text_dim=1024,
context_dim=512,
out_dim=6,
dim_mult=[1, 2, 3, 4],
num_heads=None,
head_dim=64,
num_res_blocks=3,
attn_scales=[1 / 2, 1 / 4, 1 / 8],
resblock_resample=True,
use_scale_shift_norm=True,
dropout=0.0):
embed_dim = dim * 4
super(ImagenGenerator, self).__init__()
self.in_dim = in_dim
self.dim = dim
self.text_dim = text_dim
self.context_dim = context_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.dim_mult = dim_mult
self.num_heads = num_heads
self.head_dim = head_dim
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.resblock_resample = resblock_resample
self.use_scale_shift_norm = use_scale_shift_norm
# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0
# embeddings
self.time_embedding = nn.Sequential(
nn.Linear(dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
self.pool_embedding = nn.Sequential(
nn.LayerNorm(text_dim), nn.Linear(text_dim, embed_dim))
self.text_embedding = nn.Sequential(
nn.LayerNorm(text_dim), nn.Linear(text_dim, context_dim),
nn.SiLU(), nn.Linear(context_dim, context_dim))
# encoder
self.encoder = nn.ModuleList(
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
shortcut_dims.append(dim)
for i, (in_dim,
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
# residual (+attention) blocks
                block = nn.ModuleList([
                    ResidualBlock(in_dim, embed_dim, out_dim,
                                  use_scale_shift_norm, 1.0, dropout)
                ])
if scale in attn_scales:
block.append(
AttentionBlock(out_dim, context_dim, num_heads,
head_dim))
in_dim = out_dim
self.encoder.append(block)
shortcut_dims.append(out_dim)
# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
if resblock_resample:
downsample = ResidualBlock(out_dim, embed_dim, out_dim,
use_scale_shift_norm, 0.5,
dropout)
else:
downsample = Resample(
out_dim, out_dim, 0.5, use_conv=True)
shortcut_dims.append(out_dim)
scale /= 2.0
self.encoder.append(downsample)
# middle
self.middle = nn.ModuleList([
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
1.0, dropout),
AttentionBlock(out_dim, context_dim, num_heads, head_dim),
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
1.0, dropout)
])
# decoder
self.decoder = nn.ModuleList()
for i, (in_dim,
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
# residual (+attention) blocks
block = nn.ModuleList([
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
out_dim, use_scale_shift_norm, 1.0, dropout)
])
if scale in attn_scales:
block.append(
AttentionBlock(out_dim, context_dim, num_heads,
head_dim))
in_dim = out_dim
# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
if resblock_resample:
upsample = ResidualBlock(out_dim, embed_dim, out_dim,
use_scale_shift_norm, 2.0,
dropout)
else:
upsample = Resample(
out_dim, out_dim, 2.0, use_conv=True)
scale *= 2.0
block.append(upsample)
self.decoder.append(block)
# head
self.head = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
# zero out the last layer params
nn.init.zeros_(self.head[-1].weight)
def forward(self, x, t, y, context, mask=None):
# embeddings
e = self.time_embedding(sinusoidal_embedding(
t, self.dim)) + self.pool_embedding(y)
context = self.text_embedding(context)
# encoder
xs = []
for block in self.encoder:
x = self._forward_single(block, x, e, context, mask)
xs.append(x)
# middle
for block in self.middle:
x = self._forward_single(block, x, e, context, mask)
# decoder
for block in self.decoder:
x = torch.cat([x, xs.pop()], dim=1)
x = self._forward_single(block, x, e, context, mask)
# head
x = self.head(x)
return x
def _forward_single(self, module, x, e, context, mask):
if isinstance(module, ResidualBlock):
x = module(x, e)
elif isinstance(module, AttentionBlock):
x = module(x, context, mask)
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e, context, mask)
else:
x = module(x)
return x
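A minimal smoke test of the base generator above, with `dim` scaled down from the 512 default so it runs quickly on CPU; the tensor shapes are inferred from the `forward` signature and are assumptions, not repository code:

```python
import torch

# H and W must be divisible by 8: dim_mult=[1, 2, 3, 4] implies three
# downsampling stages.
model = ImagenGenerator(in_dim=3, dim=64, text_dim=1024, context_dim=512)
x = torch.randn(2, 3, 64, 64)        # noisy images
t = torch.randint(0, 1000, (2, ))    # diffusion timesteps
y = torch.randn(2, 1024)             # pooled text embedding
context = torch.randn(2, 8, 1024)    # per-token text embeddings
mask = torch.ones(2, 8)              # valid-token mask
out = model(x, t, y, context, mask)
print(out.shape)  # torch.Size([2, 6, 64, 64]): 3 eps + 3 variance channels
```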

View File

@@ -0,0 +1,337 @@
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['SuperResUNet256']
def sinusoidal_embedding(timesteps, dim):
# check input
half = dim // 2
timesteps = timesteps.float()
# compute sinusoidal embedding
sinusoid = torch.outer(
timesteps, torch.pow(10000,
-torch.arange(half).to(timesteps).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
if dim % 2 != 0:
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
return x
class Resample(nn.Module):
def __init__(self, in_dim, out_dim, scale_factor, use_conv=False):
assert scale_factor in [0.5, 1.0, 2.0]
super(Resample, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.scale_factor = scale_factor
self.use_conv = use_conv
# layers
if scale_factor == 2.0:
self.resample = nn.Sequential(
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
nn.Conv2d(in_dim, out_dim, 3, padding=1)
if use_conv else nn.Identity())
elif scale_factor == 0.5:
self.resample = nn.Conv2d(
in_dim, out_dim, 3, stride=2,
padding=1) if use_conv else nn.AvgPool2d(
kernel_size=2, stride=2)
else:
self.resample = nn.Identity()
def forward(self, x):
return self.resample(x)
class ResidualBlock(nn.Module):
def __init__(self,
in_dim,
embed_dim,
out_dim,
use_scale_shift_norm=True,
scale_factor=1.0,
dropout=0.0):
super(ResidualBlock, self).__init__()
self.in_dim = in_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.use_scale_shift_norm = use_scale_shift_norm
self.scale_factor = scale_factor
# layers
self.layer1 = nn.Sequential(
nn.GroupNorm(32, in_dim), nn.SiLU(),
nn.Conv2d(in_dim, out_dim, 3, padding=1))
self.resample_x = Resample(in_dim, in_dim, scale_factor)
self.resample_i = Resample(in_dim, in_dim, scale_factor)
self.embedding = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim,
out_dim * 2 if use_scale_shift_norm else out_dim))
self.layer2 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
nn.Conv2d(out_dim, out_dim, 3, padding=1))
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
in_dim, out_dim, 1)
# zero out the last layer params
nn.init.zeros_(self.layer2[-1].weight)
def forward(self, x, e):
identity = self.resample_i(x)
x = self.layer1[-1](self.resample_x(self.layer1[:-1](x)))
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1)
if self.use_scale_shift_norm:
scale, shift = e.chunk(2, dim=1)
x = self.layer2[0](x) * (1 + scale) + shift
x = self.layer2[1:](x)
else:
x = x + e
x = self.layer2(x)
x = x + self.shortcut(identity)
return x
class AttentionBlock(nn.Module):
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
# consider head_dim first, then num_heads
num_heads = dim // head_dim if head_dim else num_heads
head_dim = dim // num_heads
assert num_heads * head_dim == dim
super(AttentionBlock, self).__init__()
self.dim = dim
self.context_dim = context_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = math.pow(head_dim, -0.25)
# layers
self.norm = nn.GroupNorm(32, dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
if context_dim is not None:
self.context_kv = nn.Linear(context_dim, dim * 2)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x, context=None, mask=None):
r"""x: [B, C, H, W].
context: [B, L, C] or None.
mask: [B, L] or None.
"""
identity = x
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
x = self.norm(x)
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
if context is not None:
ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
d).permute(0, 2, 3,
1).chunk(
2, dim=1)
k = torch.cat([k, ck], dim=-1)
v = torch.cat([v, cv], dim=-1)
# compute attention
attn = torch.einsum('bndi,bndj->bnij', q * self.scale, k * self.scale)
if mask is not None:
pad_mask = mask.new_ones((b, 1, 1, h * w))
mask = torch.cat((pad_mask, mask.unsqueeze(1).unsqueeze(1)),
dim=-1)
attn = attn.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(attn, dim=-1)
# gather context
x = torch.einsum('bnij,bndj->bndi', attn, v)
x = x.reshape(b, c, h, w)
# output
x = self.proj(x)
return x + identity
class SuperResUNet256(nn.Module):
def __init__(self,
in_dim=6,
out_dim=3,
dim=256,
text_dim=1024,
context_dim=512,
dim_mult=[1, 2, 2, 3, 4],
num_heads=None,
head_dim=64,
num_res_blocks=2,
attn_scales=[1 / 16],
resblock_resample=True,
use_conv=True,
use_scale_shift_norm=True,
dropout=0.1):
embed_dim = dim * 4
super(SuperResUNet256, self).__init__()
self.in_dim = in_dim
self.dim = dim
self.out_dim = out_dim
self.dim_mult = dim_mult
self.num_heads = num_heads
self.head_dim = head_dim
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.resblock_resample = resblock_resample
self.use_conv = use_conv
self.use_scale_shift_norm = use_scale_shift_norm
# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0
# embeddings
self.time_embedding = nn.Sequential(
nn.Linear(dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
self.noise_time_embedding = nn.Sequential(
nn.Linear(dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
self.pool_embedding = nn.Sequential(
nn.LayerNorm(text_dim), nn.Linear(text_dim, embed_dim))
self.text_embedding = nn.Sequential(
nn.LayerNorm(text_dim), nn.Linear(text_dim, context_dim),
nn.SiLU(), nn.Linear(context_dim, context_dim))
# encoder
self.encoder = nn.ModuleList(
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
shortcut_dims.append(dim)
for i, (in_dim,
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
# residual (+attention) blocks
block = nn.ModuleList([
ResidualBlock(in_dim, embed_dim, out_dim,
use_scale_shift_norm, 1.0, dropout)
])
if scale in attn_scales:
block.append(
AttentionBlock(out_dim, context_dim, num_heads,
head_dim))
shortcut_dims.append(out_dim)
in_dim = out_dim
self.encoder.append(block)
# downsample
if i != len(dim_mult) - 1:
if resblock_resample:
downsample = ResidualBlock(out_dim, embed_dim, out_dim,
use_scale_shift_norm, 0.5,
dropout)
else:
downsample = Resample(out_dim, out_dim, 0.5, use_conv)
shortcut_dims.append(out_dim)
scale /= 2.0
self.encoder.append(downsample)
# middle
self.middle = nn.ModuleList([
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
1.0, dropout),
AttentionBlock(out_dim, context_dim, num_heads, head_dim),
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
1.0, dropout)
])
# decoder
self.decoder = nn.ModuleList()
for i, (in_dim,
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
# residual (+attention) blocks
block = nn.ModuleList([
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
out_dim, use_scale_shift_norm, 1.0, dropout)
])
if scale in attn_scales:
block.append(
AttentionBlock(out_dim, context_dim, num_heads,
head_dim))
in_dim = out_dim
# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
if resblock_resample:
upsample = ResidualBlock(out_dim, embed_dim, out_dim,
use_scale_shift_norm, 2.0,
dropout)
else:
upsample = Resample(out_dim, out_dim, 2.0, use_conv)
scale *= 2.0
block.append(upsample)
self.decoder.append(block)
# head
self.head = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
# zero out the last layer params
nn.init.zeros_(self.head[-1].weight)
def forward(self, x, t, lx, lt, y, context, mask):
assert context.shape[:-1] == mask.shape
# embeddings
t = self.time_embedding(sinusoidal_embedding(t, self.dim)) \
+ self.noise_time_embedding(sinusoidal_embedding(lt, self.dim)) \
+ self.pool_embedding(y)
context = self.text_embedding(context)
if lx.shape[-2:] != x.shape[-2:]:
lx = F.interpolate(
lx, x.shape[-2:], mode='bilinear', align_corners=False)
x = torch.cat([x, lx], dim=1)
# encoder
xs = []
for block in self.encoder:
x = self._forward_single(block, x, t, context, mask)
xs.append(x)
# middle
for block in self.middle:
x = self._forward_single(block, x, t, context, mask)
# decoder
for block in self.decoder:
x = torch.cat([x, xs.pop()], dim=1)
x = self._forward_single(block, x, t, context, mask)
# head
x = self.head(x)
return x
def _forward_single(self, module, x, t, context, mask):
if isinstance(module, ResidualBlock):
x = module(x, t)
elif isinstance(module, AttentionBlock):
x = module(x, context, mask)
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, t, context, mask)
else:
x = module(x)
return x
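The super-resolution variant takes two timesteps: `t` for the image being denoised and `lt` for the noise augmentation applied to the low-res conditioning image, which is bilinearly resized to the target size and concatenated channel-wise. A scaled-down smoke test under the same assumptions as before (illustrative, not repository code):

```python
import torch

# H and W must be divisible by 16: dim_mult=[1, 2, 2, 3, 4] implies four
# downsampling stages; attention only runs at the 1/16 scale.
model = SuperResUNet256(dim=64)
x = torch.randn(2, 3, 64, 64)        # noisy high-res target
lx = torch.randn(2, 3, 16, 16)       # (noise-augmented) low-res condition
t = torch.randint(0, 1000, (2, ))    # diffusion timestep
lt = torch.randint(0, 1000, (2, ))   # noise-augmentation timestep
y = torch.randn(2, 1024)             # pooled text embedding
context = torch.randn(2, 8, 1024)    # per-token text embeddings
mask = torch.ones(2, 8)
out = model(x, t, lx, lt, y, context, mask)
print(out.shape)  # torch.Size([2, 3, 64, 64])
```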

View File

@@ -0,0 +1,240 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['ImagenUpsampler1024']
def sinusoidal_embedding(timesteps, dim):
# check input
half = dim // 2
timesteps = timesteps.float()
# compute sinusoidal embedding
sinusoid = torch.outer(
timesteps, torch.pow(10000,
-torch.arange(half).to(timesteps).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
if dim % 2 != 0:
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
return x
class Resample(nn.Module):
def __init__(self, in_dim, out_dim, scale_factor, use_conv=False):
assert scale_factor in [0.5, 1.0, 2.0]
super(Resample, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.scale_factor = scale_factor
self.use_conv = use_conv
# layers
if scale_factor == 2.0:
self.resample = nn.Sequential(
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
nn.Conv2d(in_dim, out_dim, 3, padding=1)
if use_conv else nn.Identity())
elif scale_factor == 0.5:
self.resample = nn.Conv2d(
in_dim, out_dim, 3, stride=2,
padding=1) if use_conv else nn.AvgPool2d(
kernel_size=2, stride=2)
else:
self.resample = nn.Identity()
def forward(self, x):
return self.resample(x)
class ResidualBlock(nn.Module):
def __init__(self,
in_dim,
embed_dim,
out_dim,
use_scale_shift_norm=True,
scale_factor=1.0,
dropout=0.0):
super(ResidualBlock, self).__init__()
self.in_dim = in_dim
self.embed_dim = embed_dim
self.out_dim = out_dim
self.use_scale_shift_norm = use_scale_shift_norm
self.scale_factor = scale_factor
# layers
self.layer1 = nn.Sequential(
nn.GroupNorm(32, in_dim), nn.SiLU(),
nn.Conv2d(in_dim, out_dim, 3, padding=1))
self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False)
self.embedding = nn.Sequential(
nn.SiLU(),
nn.Linear(embed_dim,
out_dim * 2 if use_scale_shift_norm else out_dim))
self.layer2 = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
nn.Conv2d(out_dim, out_dim, 3, padding=1))
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
in_dim, out_dim, 1)
# zero out the last layer params
nn.init.zeros_(self.layer2[-1].weight)
def forward(self, x, e):
identity = self.resample(x)
x = self.layer1[-1](self.resample(self.layer1[:-1](x)))
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
if self.use_scale_shift_norm:
scale, shift = e.chunk(2, dim=1)
x = self.layer2[0](x) * (1 + scale) + shift
x = self.layer2[1:](x)
else:
x = x + e
x = self.layer2(x)
x = x + self.shortcut(identity)
return x
class ImagenUpsampler1024(nn.Module):
def __init__(self,
in_dim=6,
dim=192,
out_dim=3,
dim_mult=[1, 1, 2, 2, 4, 4],
num_res_blocks=2,
resblock_resample=True,
use_scale_shift_norm=True,
dropout=0.0):
embed_dim = dim * 4
super(ImagenUpsampler1024, self).__init__()
self.in_dim = in_dim
self.dim = dim
self.out_dim = out_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.resblock_resample = resblock_resample
self.use_scale_shift_norm = use_scale_shift_norm
# params
enc_dims = [dim * u for u in [1] + dim_mult]
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
shortcut_dims = []
scale = 1.0
# embedding
self.time_embedding = nn.Sequential(
nn.Linear(dim, embed_dim), nn.SiLU(),
nn.Linear(embed_dim, embed_dim))
# encoder
self.encoder = nn.ModuleList(
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
shortcut_dims.append(dim)
for i, (in_dim,
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
for j in range(num_res_blocks):
# residual block
block = nn.ModuleList([
ResidualBlock(in_dim, embed_dim, out_dim,
use_scale_shift_norm, 1.0, dropout)
])
shortcut_dims.append(out_dim)
in_dim = out_dim
self.encoder.append(block)
# downsample
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
if resblock_resample:
downsample = ResidualBlock(out_dim, embed_dim, out_dim,
use_scale_shift_norm, 0.5,
dropout)
else:
downsample = Resample(
out_dim, out_dim, 0.5, use_conv=True)
shortcut_dims.append(out_dim)
scale /= 2.0
self.encoder.append(downsample)
# middle
self.middle = nn.ModuleList([
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
1.0, dropout),
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
1.0, dropout)
])
# decoder
self.decoder = nn.ModuleList()
for i, (in_dim,
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
for j in range(num_res_blocks + 1):
# residual block
block = nn.ModuleList([
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
out_dim, use_scale_shift_norm, 1.0, dropout)
])
in_dim = out_dim
# upsample
if i != len(dim_mult) - 1 and j == num_res_blocks:
if resblock_resample:
upsample = ResidualBlock(out_dim, embed_dim, out_dim,
use_scale_shift_norm, 2.0,
dropout)
else:
upsample = Resample(
out_dim, out_dim, 2.0, use_conv=True)
scale *= 2.0
block.append(upsample)
self.decoder.append(block)
# head
self.head = nn.Sequential(
nn.GroupNorm(32, out_dim), nn.SiLU(),
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
# zero out the last layer params
nn.init.zeros_(self.head[-1].weight)
def forward(self, x, t, concat):
# embedding
if concat is not None:
if concat.shape[-2:] != x.shape[-2:]:
concat = F.interpolate(
concat, x.shape[-2:], mode='bilinear', align_corners=False)
x = torch.cat([x, concat], dim=1)
e = self.time_embedding(sinusoidal_embedding(t, self.dim))
# encoder
xs = []
for block in self.encoder:
x = self._forward_single(block, x, e)
xs.append(x)
# middle
for block in self.middle:
x = self._forward_single(block, x, e)
# decoder
for block in self.decoder:
x = torch.cat([x, xs.pop()], dim=1)
x = self._forward_single(block, x, e)
# head
x = self.head(x)
return x
def _forward_single(self, module, x, e):
if isinstance(module, ResidualBlock):
x = module(x, e)
elif isinstance(module, nn.ModuleList):
for block in module:
x = self._forward_single(block, x, e)
else:
x = module(x)
return x
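The final upsampler drops text conditioning entirely and conditions only on the timestep and the low-res image. Another scaled-down, illustrative smoke test:

```python
import torch

# H and W must be divisible by 32: dim_mult=[1, 1, 2, 2, 4, 4] implies five
# downsampling stages; this stage has no attention blocks.
model = ImagenUpsampler1024(dim=32)
x = torch.randn(2, 3, 64, 64)        # noisy high-res target
concat = torch.randn(2, 3, 16, 16)   # low-res condition, resized internally
t = torch.randint(0, 1000, (2, ))
out = model(x, t, concat)
print(out.shape)  # torch.Size([2, 3, 64, 64])
```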

View File

@@ -1,6 +1,7 @@
try:
from .image_captioning_pipeline import ImageCaptionPipeline
from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline
from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline
from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline
except ModuleNotFoundError as e:
if str(e) == "No module named 'torch'":

View File

@@ -0,0 +1,37 @@
from typing import Any, Dict, Union
from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ..base import Model, Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys
logger = get_logger()
@PIPELINES.register_module(
Tasks.text_to_image_synthesis,
module_name=Pipelines.text_to_image_synthesis)
class TextToImageSynthesisPipeline(Pipeline):
    def __init__(self, model: Union[str, Model], device_id: int = -1):
if isinstance(model, str):
pipe_model = Model.from_pretrained(model)
elif isinstance(model, Model):
pipe_model = model
else:
            raise NotImplementedError(
                f'expecting a Model instance or str, but got {type(model)}.')
super().__init__(model=pipe_model)
def preprocess(self, input: Input) -> Dict[str, Any]:
return input
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
return self.model.generate(input)
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return {OutputKeys.OUTPUT_IMG: inputs}

View File

@@ -0,0 +1,45 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import numpy as np
from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class TextToImageSynthesisTest(unittest.TestCase):
model_id = 'damo/cv_imagen_text-to-image-synthesis_tiny'
    test_text = {'text': '宇航员'}  # Chinese for 'astronaut'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
pipe_line_text_to_image_synthesis = pipeline(
task=Tasks.text_to_image_synthesis, model=model)
img = pipe_line_text_to_image_synthesis(
self.test_text)[OutputKeys.OUTPUT_IMG]
print(np.sum(np.abs(img)))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_name(self):
pipe_line_text_to_image_synthesis = pipeline(
task=Tasks.text_to_image_synthesis, model=self.model_id)
img = pipe_line_text_to_image_synthesis(
self.test_text)[OutputKeys.OUTPUT_IMG]
print(np.sum(np.abs(img)))
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipe_line_text_to_image_synthesis = pipeline(
task=Tasks.text_to_image_synthesis)
img = pipe_line_text_to_image_synthesis(
self.test_text)[OutputKeys.OUTPUT_IMG]
print(np.sum(np.abs(img)))
if __name__ == '__main__':
unittest.main()