Merge remote-tracking branch 'origin/master' into nlp/space/dst
modelscope/metainfo.py
@@ -29,6 +29,7 @@ class Models(object):
     ofa = 'ofa'
     clip = 'clip-multi-modal-embedding'
     mplug = 'mplug'
+    imagen = 'imagen-text-to-image-synthesis'


 class Pipelines(object):
@@ -71,6 +72,7 @@ class Pipelines(object):
     image_caption = 'image-captioning'
     multi_modal_embedding = 'multi-modal-embedding'
     visual_question_answering = 'visual-question-answering'
+    text_to_image_synthesis = 'text-to-image-synthesis'


 class Trainers(object):
modelscope/models/multi_modal/__init__.py
@@ -1,4 +1,5 @@
 from .clip.clip_model import CLIPForMultiModalEmbedding
 from .image_captioning_model import OfaForImageCaptioning
+from .imagen.imagen_model import ImagenForTextToImageSynthesis
 from .mplug_for_visual_question_answering import \
     MPlugForVisualQuestionAnswering
0    modelscope/models/multi_modal/imagen/__init__.py    (new file)
595  modelscope/models/multi_modal/imagen/diffusion.py   (new file)
@@ -0,0 +1,595 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ['GaussianDiffusion', 'beta_schedule']
|
||||
|
||||
|
||||
def kl_divergence(mu1, logvar1, mu2, logvar2):
|
||||
a = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
|
||||
b = ((mu1 - mu2)**2) * torch.exp(-logvar2)
|
||||
return 0.5 * (a + b)
|
||||
|
||||
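# Note: for diagonal Gaussians N(mu1, exp(logvar1)) and N(mu2, exp(logvar2)),
# the element-wise KL divergence computed above is
#     KL = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2)
#                 + (mu1 - mu2) ** 2 * exp(-logvar2) - 1)
# which is exactly 0.5 * (a + b) with the a and b defined in kl_divergence().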
|
||||
def standard_normal_cdf(x):
|
||||
return 0.5 * (1.0 + torch.tanh(
|
||||
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
def discretized_gaussian_log_likelihood(x0, mean, log_scale):
|
||||
assert x0.shape == mean.shape == log_scale.shape
|
||||
cx = x0 - mean
|
||||
inv_stdv = torch.exp(-log_scale)
|
||||
cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
|
||||
cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
|
||||
log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
|
||||
log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
log_probs = torch.where(
|
||||
x0 < -0.999, log_cdf_plus,
|
||||
torch.where(x0 > 0.999, log_one_minus_cdf_min,
|
||||
torch.log(cdf_delta.clamp(min=1e-12))))
|
||||
assert log_probs.shape == x0.shape
|
||||
return log_probs
|
||||
|
||||
|
||||
def _i(tensor, t, x):
|
||||
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
|
||||
return tensor[t].view(shape).to(x)
|
||||
|
||||
|
||||
def cosine_fn(u):
|
||||
return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2
|
||||
|
||||
|
||||
def beta_schedule(schedule,
|
||||
num_timesteps=1000,
|
||||
init_beta=None,
|
||||
last_beta=None):
|
||||
if schedule == 'linear':
|
||||
scale = 1000.0 / num_timesteps
|
||||
init_beta = init_beta or scale * 0.0001
|
||||
last_beta = last_beta or scale * 0.02
|
||||
return torch.linspace(
|
||||
init_beta, last_beta, num_timesteps, dtype=torch.float64)
|
||||
elif schedule == 'quadratic':
|
||||
init_beta = init_beta or 0.0015
|
||||
last_beta = last_beta or 0.0195
|
||||
return torch.linspace(
|
||||
init_beta**0.5, last_beta**0.5, num_timesteps,
|
||||
dtype=torch.float64)**2
|
||||
elif schedule == 'cosine':
|
||||
betas = []
|
||||
for step in range(num_timesteps):
|
||||
t1 = step / num_timesteps
|
||||
t2 = (step + 1) / num_timesteps
|
||||
betas.append(min(1.0 - cosine_fn(t2) / cosine_fn(t1), 0.999))
|
||||
return torch.tensor(betas, dtype=torch.float64)
|
||||
else:
|
||||
raise ValueError(f'Unsupported schedule: {schedule}')
|
||||
|
||||
|
||||
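# Illustrative usage sketch (not part of the original file; values are only
# assumptions for the demo): build a schedule and wrap it in GaussianDiffusion.
#
#     betas = beta_schedule('cosine', num_timesteps=1000)
#     assert betas.shape == (1000,) and betas.min() > 0 and betas.max() <= 0.999
#     diffusion = GaussianDiffusion(betas, mean_type='eps', var_type='fixed_small')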
class GaussianDiffusion(object):
|
||||
|
||||
def __init__(self,
|
||||
betas,
|
||||
mean_type='eps',
|
||||
var_type='learned_range',
|
||||
loss_type='mse',
|
||||
rescale_timesteps=False):
|
||||
# check input
|
||||
if not isinstance(betas, torch.DoubleTensor):
|
||||
betas = torch.tensor(betas, dtype=torch.float64)
|
||||
assert min(betas) > 0 and max(betas) <= 1
|
||||
assert mean_type in ['x0', 'x_{t-1}', 'eps']
|
||||
assert var_type in [
|
||||
'learned', 'learned_range', 'fixed_large', 'fixed_small'
|
||||
]
|
||||
assert loss_type in [
|
||||
'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1'
|
||||
]
|
||||
self.betas = betas
|
||||
self.num_timesteps = len(betas)
|
||||
self.mean_type = mean_type
|
||||
self.var_type = var_type
|
||||
self.loss_type = loss_type
|
||||
self.rescale_timesteps = rescale_timesteps
|
||||
|
||||
# alphas
|
||||
alphas = 1 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
self.alphas_cumprod_prev = torch.cat(
|
||||
[alphas.new_ones([1]), self.alphas_cumprod[:-1]])
|
||||
self.alphas_cumprod_next = torch.cat(
|
||||
[self.alphas_cumprod[1:],
|
||||
alphas.new_zeros([1])])
|
||||
|
||||
# q(x_t | x_{t-1})
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
|
||||
- self.alphas_cumprod)
|
||||
self.log_one_minus_alphas_cumprod = torch.log(1.0
|
||||
- self.alphas_cumprod)
|
||||
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
|
||||
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
|
||||
- 1)
|
||||
|
||||
# q(x_{t-1} | x_t, x_0)
|
||||
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
|
||||
1.0 - self.alphas_cumprod)
|
||||
self.posterior_log_variance_clipped = torch.log(
|
||||
self.posterior_variance.clamp(1e-20))
|
||||
self.posterior_mean_coef1 = betas * torch.sqrt(
|
||||
self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (
|
||||
1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
|
||||
1.0 - self.alphas_cumprod)
|
||||
|
||||
def q_sample(self, x0, t, noise=None):
|
||||
noise = torch.randn_like(x0) if noise is None else noise
|
||||
return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i(
|
||||
self.sqrt_one_minus_alphas_cumprod, t, x0) * noise
|
||||
|
||||
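    # Note: q_sample implements the closed-form forward process
    #     x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise
    # e.g. (illustrative only) for a batch x0 scaled to [-1, 1]:
    #     t = torch.randint(0, diffusion.num_timesteps, (x0.size(0),))
    #     xt = diffusion.q_sample(x0, t)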
def q_mean_variance(self, x0, t):
|
||||
mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
|
||||
var = _i(1.0 - self.alphas_cumprod, t, x0)
|
||||
log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
|
||||
return mu, var, log_var
|
||||
|
||||
def q_posterior_mean_variance(self, x0, xt, t):
|
||||
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
|
||||
self.posterior_mean_coef2, t, xt) * xt
|
||||
var = _i(self.posterior_variance, t, xt)
|
||||
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
return mu, var, log_var
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None):
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||
clamp, percentile,
|
||||
guide_scale)
|
||||
|
||||
# random sample (with optional conditional function)
|
||||
noise = torch.randn_like(xt)
|
||||
shape = (-1, ) + ((1, ) * (xt.ndim - 1))
|
||||
mask = t.ne(0).float().view(*shape) # no noise when t == 0
|
||||
if condition_fn is not None:
|
||||
grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
|
||||
mu = mu.float() + var * grad.float()
|
||||
xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
|
||||
return xt_1, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None):
|
||||
# prepare input
|
||||
b, c, h, w = noise.size()
|
||||
xt = noise
|
||||
|
||||
# diffusion process
|
||||
for step in torch.arange(self.num_timesteps).flip(0):
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn, guide_scale)
|
||||
return xt
|
||||
|
||||
def p_mean_variance(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None):
|
||||
# predict distribution
|
||||
if guide_scale is None:
|
||||
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
||||
else:
|
||||
# classifier-free guidance
|
||||
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
|
||||
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
|
||||
assert self.mean_type == 'eps'
|
||||
y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
|
||||
u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
|
||||
a = u_out[:, :3]
|
||||
b = guide_scale * (y_out[:, :3] - u_out[:, :3])
|
||||
c = y_out[:, 3:]
|
||||
out = torch.cat([a + b, c], dim=1)
|
||||
|
||||
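            # Note: this is classifier-free guidance applied to the predicted eps
            # channels only, eps = eps_uncond + guide_scale * (eps_cond - eps_uncond),
            # while any extra channels (the learned variance) are taken from the
            # conditional branch unchanged.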
# compute variance
|
||||
if self.var_type == 'learned':
|
||||
out, log_var = out.chunk(2, dim=1)
|
||||
var = torch.exp(log_var)
|
||||
elif self.var_type == 'learned_range':
|
||||
out, fraction = out.chunk(2, dim=1)
|
||||
min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
max_log_var = _i(torch.log(self.betas), t, xt)
|
||||
fraction = (fraction + 1) / 2.0
|
||||
log_var = fraction * max_log_var + (1 - fraction) * min_log_var
|
||||
var = torch.exp(log_var)
|
||||
elif self.var_type == 'fixed_large':
|
||||
var = _i(
|
||||
torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
|
||||
xt)
|
||||
log_var = torch.log(var)
|
||||
elif self.var_type == 'fixed_small':
|
||||
var = _i(self.posterior_variance, t, xt)
|
||||
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
|
||||
# compute mean and x0
|
||||
if self.mean_type == 'x_{t-1}':
|
||||
mu = out # x_{t-1}
|
||||
x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i(
|
||||
self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
|
||||
xt) * xt
|
||||
elif self.mean_type == 'x0':
|
||||
x0 = out
|
||||
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||
elif self.mean_type == 'eps':
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt) * out
|
||||
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||
|
||||
# restrict the range of x0
|
||||
if percentile is not None:
|
||||
assert percentile > 0 and percentile <= 1 # e.g., 0.995
|
||||
s = torch.quantile(
|
||||
x0.flatten(1).abs(), percentile,
|
||||
dim=1).clamp_(1.0).view(-1, 1, 1, 1)
|
||||
x0 = torch.min(s, torch.max(-s, x0)) / s
|
||||
elif clamp is not None:
|
||||
x0 = x0.clamp(-clamp, clamp)
|
||||
return mu, var, log_var, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20,
|
||||
eta=0.0):
|
||||
stride = self.num_timesteps // ddim_timesteps
|
||||
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale)
|
||||
if condition_fn is not None:
|
||||
# x0 -> eps
|
||||
alpha = _i(self.alphas_cumprod, t, xt)
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
eps = eps - (1 - alpha).sqrt() * condition_fn(
|
||||
xt, self._scale_timesteps(t), **model_kwargs)
|
||||
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
||||
|
||||
# derive variables
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
alphas = _i(self.alphas_cumprod, t, xt)
|
||||
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
||||
a = (1 - alphas_prev) / (1 - alphas)
|
||||
b = (1 - alphas / alphas_prev)
|
||||
sigmas = eta * torch.sqrt(a * b)
|
||||
|
||||
# random sample
|
||||
noise = torch.randn_like(xt)
|
||||
direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
|
||||
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
||||
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
|
||||
return xt_1, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20,
|
||||
eta=0.0):
|
||||
# prepare input
|
||||
b, c, h, w = noise.size()
|
||||
xt = noise
|
||||
|
||||
# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
|
||||
steps = (1 + torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // ddim_timesteps)).clamp(
|
||||
0, self.num_timesteps - 1).flip(0)
|
||||
for step in steps:
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn, guide_scale,
|
||||
ddim_timesteps, eta)
|
||||
return xt
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_reverse_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20):
|
||||
stride = self.num_timesteps // ddim_timesteps
|
||||
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale)
|
||||
|
||||
# derive variables
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
alphas_next = _i(
|
||||
torch.cat(
|
||||
[self.alphas_cumprod,
|
||||
self.alphas_cumprod.new_zeros([1])]),
|
||||
(t + stride).clamp(0, self.num_timesteps), xt)
|
||||
|
||||
# reverse sample
|
||||
mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
|
||||
return mu, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_reverse_sample_loop(self,
|
||||
x0,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20):
|
||||
# prepare input
|
||||
b, c, h, w = x0.size()
|
||||
xt = x0
|
||||
|
||||
# reconstruction steps
|
||||
steps = torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // ddim_timesteps)
|
||||
for step in steps:
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale,
|
||||
ddim_timesteps)
|
||||
return xt
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
                    plms_timesteps=20,
                    eps_cache=None):
        stride = self.num_timesteps // plms_timesteps
        eps_cache = [] if eps_cache is None else eps_cache
|
||||
|
||||
# function for compute eps
|
||||
def compute_eps(xt, t):
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||
clamp, percentile, guide_scale)
|
||||
|
||||
# condition
|
||||
if condition_fn is not None:
|
||||
# x0 -> eps
|
||||
alpha = _i(self.alphas_cumprod, t, xt)
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
|
||||
- x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
eps = eps - (1 - alpha).sqrt() * condition_fn(
|
||||
xt, self._scale_timesteps(t), **model_kwargs)
|
||||
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
||||
|
||||
# derive eps
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
return eps
|
||||
|
||||
# function for compute x_0 and x_{t-1}
|
||||
def compute_x0(eps, t):
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
||||
|
||||
# deterministic sample
|
||||
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
||||
direction = torch.sqrt(1 - alphas_prev) * eps
|
||||
xt_1 = torch.sqrt(alphas_prev) * x0 + direction
|
||||
return xt_1, x0
|
||||
|
||||
# PLMS sample
|
||||
eps = compute_eps(xt, t)
|
||||
if len(eps_cache) == 0:
|
||||
# 2nd order pseudo improved Euler
|
||||
xt_1, x0 = compute_x0(eps, t)
|
||||
eps_next = compute_eps(xt_1, (t - stride).clamp(0))
|
||||
eps_prime = (eps + eps_next) / 2.0
|
||||
elif len(eps_cache) == 1:
|
||||
# 2nd order pseudo linear multistep (Adams-Bashforth)
|
||||
eps_prime = (3 * eps - eps_cache[-1]) / 2.0
|
||||
elif len(eps_cache) == 2:
|
||||
            # 3rd order pseudo linear multistep (Adams-Bashforth)
|
||||
eps_prime = (23 * eps - 16 * eps_cache[-1]
|
||||
+ 5 * eps_cache[-2]) / 12.0
|
||||
elif len(eps_cache) >= 3:
|
||||
            # 4th order pseudo linear multistep (Adams-Bashforth)
|
||||
eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
|
||||
- 9 * eps_cache[-3]) / 24.0
|
||||
xt_1, x0 = compute_x0(eps_prime, t)
|
||||
return xt_1, x0, eps
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
plms_timesteps=20):
|
||||
# prepare input
|
||||
b, c, h, w = noise.size()
|
||||
xt = noise
|
||||
|
||||
# diffusion process
|
||||
steps = (1 + torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // plms_timesteps)).clamp(
|
||||
0, self.num_timesteps - 1).flip(0)
|
||||
eps_cache = []
|
||||
for step in steps:
|
||||
# PLMS sampling step
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn,
|
||||
guide_scale, plms_timesteps,
|
||||
eps_cache)
|
||||
|
||||
# update eps cache
|
||||
eps_cache.append(eps)
|
||||
if len(eps_cache) >= 4:
|
||||
eps_cache.pop(0)
|
||||
return xt
|
||||
|
||||
def loss(self, x0, t, model, model_kwargs={}, noise=None):
|
||||
noise = torch.randn_like(x0) if noise is None else noise
|
||||
xt = self.q_sample(x0, t, noise=noise)
|
||||
|
||||
# compute loss
|
||||
if self.loss_type in ['kl', 'rescaled_kl']:
|
||||
loss, _ = self.variational_lower_bound(x0, xt, t, model,
|
||||
model_kwargs)
|
||||
if self.loss_type == 'rescaled_kl':
|
||||
loss = loss * self.num_timesteps
|
||||
elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
|
||||
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
||||
|
||||
# VLB for variation
|
||||
loss_vlb = 0.0
|
||||
if self.var_type in ['learned', 'learned_range']:
|
||||
out, var = out.chunk(2, dim=1)
|
||||
frozen = torch.cat([
|
||||
out.detach(), var
|
||||
], dim=1) # learn var without affecting the prediction of mean
|
||||
loss_vlb, _ = self.variational_lower_bound(
|
||||
x0, xt, t, model=lambda *args, **kwargs: frozen)
|
||||
if self.loss_type.startswith('rescaled_'):
|
||||
loss_vlb = loss_vlb * self.num_timesteps / 1000.0
|
||||
|
||||
# MSE/L1 for x0/eps
|
||||
target = {
|
||||
'eps': noise,
|
||||
'x0': x0,
|
||||
'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
|
||||
}[self.mean_type]
|
||||
loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2
|
||||
).abs().flatten(1).mean(dim=1)
|
||||
|
||||
# total loss
|
||||
loss = loss + loss_vlb
|
||||
return loss
|
||||
|
||||
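    # Note: with var_type in ('learned', 'learned_range') this is the hybrid
    # objective of Improved DDPM: an MSE/L1 term on the mean prediction plus a
    # VLB term that trains the variance head; the prediction is detached inside
    # the VLB branch so the variance loss does not disturb the eps/x0 estimate.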
def variational_lower_bound(self,
|
||||
x0,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None):
|
||||
# compute groundtruth and predicted distributions
|
||||
mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
|
||||
mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||
clamp, percentile)
|
||||
|
||||
# compute KL loss
|
||||
kl = kl_divergence(mu1, log_var1, mu2, log_var2)
|
||||
kl = kl.flatten(1).mean(dim=1) / math.log(2.0)
|
||||
|
||||
# compute discretized NLL loss (for p(x0 | x1) only)
|
||||
nll = -discretized_gaussian_log_likelihood(
|
||||
x0, mean=mu2, log_scale=0.5 * log_var2)
|
||||
nll = nll.flatten(1).mean(dim=1) / math.log(2.0)
|
||||
|
||||
# NLL for p(x0 | x1) and KL otherwise
|
||||
vlb = torch.where(t == 0, nll, kl)
|
||||
return vlb, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def variational_lower_bound_loop(self,
|
||||
x0,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None):
|
||||
# prepare input and output
|
||||
b, c, h, w = x0.size()
|
||||
metrics = {'vlb': [], 'mse': [], 'x0_mse': []}
|
||||
|
||||
# loop
|
||||
for step in torch.arange(self.num_timesteps).flip(0):
|
||||
# compute VLB
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
|
||||
noise = torch.randn_like(x0)
|
||||
xt = self.q_sample(x0, t, noise)
|
||||
vlb, pred_x0 = self.variational_lower_bound(
|
||||
x0, xt, t, model, model_kwargs, clamp, percentile)
|
||||
|
||||
# predict eps from x0
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
|
||||
# collect metrics
|
||||
metrics['vlb'].append(vlb)
|
||||
metrics['x0_mse'].append(
|
||||
(pred_x0 - x0).square().flatten(1).mean(dim=1))
|
||||
metrics['mse'].append(
|
||||
(eps - noise).square().flatten(1).mean(dim=1))
|
||||
metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}
|
||||
|
||||
# compute the prior KL term for VLB, measured in bits-per-dim
|
||||
mu, _, log_var = self.q_mean_variance(x0, t)
|
||||
kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
|
||||
torch.zeros_like(log_var))
|
||||
kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)
|
||||
|
||||
# update metrics
|
||||
metrics['prior_bits_per_dim'] = kl_prior
|
||||
metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
|
||||
return metrics
|
||||
|
||||
def _scale_timesteps(self, t):
|
||||
if self.rescale_timesteps:
|
||||
return t.float() * 1000.0 / self.num_timesteps
|
||||
return t
|
||||
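For orientation, a minimal sketch of how this GaussianDiffusion class is driven end to end; the toy denoiser and the shapes below are assumptions for illustration, not part of the diff:

import torch
from modelscope.models.multi_modal.imagen.diffusion import GaussianDiffusion, beta_schedule

# toy denoiser that predicts eps with the same shape as its input
def toy_model(x, t, **kwargs):
    return torch.zeros_like(x)

betas = beta_schedule('linear', num_timesteps=1000)
diffusion = GaussianDiffusion(betas, mean_type='eps', var_type='fixed_small')

noise = torch.randn(1, 3, 64, 64)
sample = diffusion.ddim_sample_loop(noise, toy_model, ddim_timesteps=20)  # (1, 3, 64, 64)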
255  modelscope/models/multi_modal/imagen/imagen_model.py  (new file)
@@ -0,0 +1,255 @@
|
||||
import os.path as osp
|
||||
from typing import Any, Dict
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.multi_modal.imagen.diffusion import (GaussianDiffusion,
|
||||
beta_schedule)
|
||||
from modelscope.models.multi_modal.imagen.structbert import (BertConfig,
|
||||
BertModel)
|
||||
from modelscope.models.multi_modal.imagen.tokenizer import FullTokenizer
|
||||
from modelscope.models.multi_modal.imagen.unet_generator import ImagenGenerator
|
||||
from modelscope.models.multi_modal.imagen.unet_imagen_upsampler_256 import \
|
||||
SuperResUNet256
|
||||
from modelscope.models.multi_modal.imagen.unet_upsampler_1024 import \
|
||||
ImagenUpsampler1024
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = ['ImagenForTextToImageSynthesis']
|
||||
|
||||
|
||||
def make_diffusion(schedule,
|
||||
num_timesteps=1000,
|
||||
init_beta=None,
|
||||
last_beta=None,
|
||||
var_type='fixed_small'):
|
||||
betas = beta_schedule(schedule, num_timesteps, init_beta, last_beta)
|
||||
diffusion = GaussianDiffusion(betas, var_type=var_type)
|
||||
return diffusion
|
||||
|
||||
|
||||
class Tokenizer(object):
|
||||
|
||||
def __init__(self, vocab_file, seq_len=64):
|
||||
self.vocab_file = vocab_file
|
||||
self.seq_len = seq_len
|
||||
self.tokenizer = FullTokenizer(
|
||||
vocab_file=vocab_file, do_lower_case=True)
|
||||
|
||||
def __call__(self, text):
|
||||
# tokenization
|
||||
tokens = self.tokenizer.tokenize(text)
|
||||
tokens = ['[CLS]'] + tokens[:self.seq_len - 2] + ['[SEP]']
|
||||
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
segment_ids = [0] * len(input_ids)
|
||||
|
||||
# padding
|
||||
input_ids += [0] * (self.seq_len - len(input_ids))
|
||||
input_mask += [0] * (self.seq_len - len(input_mask))
|
||||
segment_ids += [0] * (self.seq_len - len(segment_ids))
|
||||
assert len(input_ids) == len(input_mask) == len(
|
||||
segment_ids) == self.seq_len
|
||||
|
||||
# convert to tensors
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
input_mask = torch.LongTensor(input_mask)
|
||||
segment_ids = torch.LongTensor(segment_ids)
|
||||
return input_ids, segment_ids, input_mask
|
||||
|
||||
|
||||
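# Illustrative usage (the vocab path is a placeholder): the tokenizer pads every
# prompt to seq_len tokens and returns LongTensors ready for BertModel.
#
#     tokenizer = Tokenizer(vocab_file='/path/to/vocab.txt', seq_len=64)
#     input_ids, segment_ids, input_mask = tokenizer('a photo of a corgi')
#     assert input_ids.shape == (64,)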
class ImagenModel(nn.Module):
|
||||
|
||||
def __init__(self, model_dir):
|
||||
super(ImagenModel, self).__init__()
|
||||
# including text and generator config
|
||||
model_config = json.load(
|
||||
open('{}/imagen_config.json'.format(model_dir)))
|
||||
|
||||
# text encoder
|
||||
text_config = model_config['text_config']
|
||||
self.text_encoder = BertModel(BertConfig.from_dict(text_config))
|
||||
|
||||
# generator (64x64)
|
||||
generator_config = model_config['generator_config']
|
||||
self.unet_generator = ImagenGenerator(**generator_config)
|
||||
|
||||
# imagen upsampler (256x256)
|
||||
imagen_upsampler_256_config = model_config[
|
||||
'imagen_upsampler_256_config']
|
||||
self.unet_imagen_upsampler_256 = SuperResUNet256(
|
||||
**imagen_upsampler_256_config)
|
||||
|
||||
# dalle2 upsampler (1024x1024)
|
||||
upsampler_1024_config = model_config['upsampler_1024_config']
|
||||
self.unet_upsampler_1024 = ImagenUpsampler1024(**upsampler_1024_config)
|
||||
|
||||
def forward(self, noise, timesteps, input_ids, token_type_ids,
|
||||
attention_mask):
|
||||
context, y = self.text_encoder(
|
||||
input_ids=input_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask)
|
||||
context = context[-1]
|
||||
x = self.unet_generator(noise, timesteps, y, context, attention_mask)
|
||||
x = self.unet_imagen_upsampler_256(noise, timesteps, x,
|
||||
torch.zeros_like(timesteps), y,
|
||||
context, attention_mask)
|
||||
        x = self.unet_upsampler_1024(x, timesteps, x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.text_to_image_synthesis, module_name=Models.imagen)
|
||||
class ImagenForTextToImageSynthesis(Model):
|
||||
|
||||
def __init__(self, model_dir, device_id=-1):
|
||||
super().__init__(model_dir=model_dir, device_id=device_id)
|
||||
imagen_model = ImagenModel(model_dir=model_dir)
|
||||
pretrained_params = torch.load(
|
||||
osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu')
|
||||
imagen_model.load_state_dict(pretrained_params)
|
||||
imagen_model.eval()
|
||||
|
||||
self.device_id = device_id
|
||||
if self.device_id >= 0:
|
||||
self.device = torch.device(f'cuda:{self.device_id}')
|
||||
imagen_model.to('cuda:{}'.format(self.device_id))
|
||||
logger.info('Use GPU: {}'.format(self.device_id))
|
||||
else:
|
||||
self.device = torch.device('cpu')
|
||||
logger.info('Use CPU for inference')
|
||||
|
||||
# modules
|
||||
self.text_encoder = imagen_model.text_encoder
|
||||
self.unet_generator = imagen_model.unet_generator
|
||||
self.unet_imagen_upsampler_256 = imagen_model.unet_imagen_upsampler_256
|
||||
self.unet_upsampler_1024 = imagen_model.unet_upsampler_1024
|
||||
|
||||
# text tokenizer
|
||||
vocab_path = '{}/vocab.txt'.format(model_dir)
|
||||
self.tokenizer = Tokenizer(vocab_file=vocab_path, seq_len=64)
|
||||
|
||||
# diffusion process
|
||||
diffusion_params = json.load(
|
||||
open('{}/diffusion_config.json'.format(model_dir)))
|
||||
self.diffusion_generator = make_diffusion(
|
||||
**diffusion_params['generator_config'])
|
||||
self.diffusion_imagen_upsampler_256 = make_diffusion(
|
||||
**diffusion_params['imagen_upsampler_256_config'])
|
||||
self.diffusion_upsampler_1024 = make_diffusion(
|
||||
**diffusion_params['upsampler_1024_config'])
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if not all([key in input for key in ('text', 'noise', 'timesteps')]):
|
||||
raise ValueError(
|
||||
                f'input should contain "text", "noise", and "timesteps", but got {input.keys()}'
|
||||
)
|
||||
        noise, timesteps = input['noise'], input['timesteps']
        input_ids, token_type_ids, attention_mask = self.tokenizer(
            input['text'])
|
||||
input_ids = input_ids.to(self.device).unsqueeze(0)
|
||||
token_type_ids = token_type_ids.to(self.device).unsqueeze(0)
|
||||
attention_mask = attention_mask.to(self.device).unsqueeze(0)
|
||||
context, y = self.text_encoder(
|
||||
input_ids=input_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask)
|
||||
context = context[-1]
|
||||
x = self.unet_generator(noise, timesteps, y, context, attention_mask)
|
||||
x = self.unet_imagen_upsampler_256(noise, timesteps, x,
|
||||
torch.zeros_like(timesteps), y,
|
||||
context, attention_mask)
|
||||
        x = self.unet_upsampler_1024(x, timesteps, x)
|
||||
img = x.clamp(-1, 1).add(1).mul(127.5)
|
||||
img = img.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
|
||||
return img
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return inputs
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if 'text' not in input:
|
||||
raise ValueError(
|
||||
f'input should contain "text", but got {input.keys()}')
|
||||
|
||||
# encode text
|
||||
input_ids, token_type_ids, attention_mask = self.tokenizer(
|
||||
input['text'])
|
||||
input_ids = input_ids.to(self.device).unsqueeze(0)
|
||||
token_type_ids = token_type_ids.to(self.device).unsqueeze(0)
|
||||
attention_mask = attention_mask.to(self.device).unsqueeze(0)
|
||||
context, y = self.text_encoder(
|
||||
input_ids=input_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask)
|
||||
context = context[-1]
|
||||
|
||||
# generation
|
||||
img = self.diffusion_generator.ddim_sample_loop(
|
||||
noise=torch.randn(1, 3, 64, 64).to(self.device),
|
||||
model=self.unet_generator,
|
||||
model_kwargs=[{
|
||||
'y': y,
|
||||
'context': context,
|
||||
'mask': attention_mask
|
||||
}, {
|
||||
'y': torch.zeros_like(y),
|
||||
'context': torch.zeros_like(context),
|
||||
'mask': attention_mask
|
||||
}],
|
||||
percentile=input.get('generator_percentile', 0.995),
|
||||
guide_scale=input.get('generator_guide_scale', 5.0),
|
||||
ddim_timesteps=input.get('generator_ddim_timesteps', 250),
|
||||
eta=input.get('generator_ddim_eta', 0.0))
|
||||
|
||||
# upsampling (64->256)
|
||||
img = F.interpolate(
|
||||
img, scale_factor=4.0, mode='bilinear', align_corners=False)
|
||||
img = self.diffusion_imagen_upsampler_256.ddim_sample_loop(
|
||||
noise=torch.randn_like(img),
|
||||
model=self.unet_imagen_upsampler_256,
|
||||
model_kwargs=[{
|
||||
'lx': img,
|
||||
'lt': torch.zeros(1).to(self.device),
|
||||
'y': y,
|
||||
'context': context,
|
||||
'mask': attention_mask
|
||||
}, {
|
||||
'lx': img,
|
||||
'lt': torch.zeros(1).to(self.device),
|
||||
'y': torch.zeros_like(y),
|
||||
'context': torch.zeros_like(context),
|
||||
'mask': torch.zeros_like(attention_mask)
|
||||
}],
|
||||
percentile=input.get('generator_percentile', 0.995),
|
||||
guide_scale=input.get('generator_guide_scale', 5.0),
|
||||
ddim_timesteps=input.get('generator_ddim_timesteps', 50),
|
||||
eta=input.get('generator_ddim_eta', 0.0))
|
||||
|
||||
# upsampling (256->1024)
|
||||
img = F.interpolate(
|
||||
img, scale_factor=4.0, mode='bilinear', align_corners=False)
|
||||
img = self.diffusion_upsampler_1024.ddim_sample_loop(
|
||||
noise=torch.randn_like(img),
|
||||
model=self.unet_upsampler_1024,
|
||||
model_kwargs={'concat': img},
|
||||
percentile=input.get('upsampler_1024_percentile', 0.995),
|
||||
ddim_timesteps=input.get('upsampler_1024_ddim_timesteps', 20),
|
||||
eta=input.get('upsampler_1024_ddim_eta', 0.0))
|
||||
|
||||
# output
|
||||
img = img.clamp(-1, 1).add(1).mul(127.5).squeeze(0).permute(
|
||||
1, 2, 0).cpu().numpy().astype(np.uint8)
|
||||
return img
|
||||
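A rough usage sketch for the new model class (illustrative only; the model directory path and the prompt are placeholders). The directory must provide the weights file referenced by ModelFile.TORCH_MODEL_BIN_FILE plus imagen_config.json, diffusion_config.json and vocab.txt, as read by the constructor above; pass device_id=-1 to run on CPU.

from modelscope.models.multi_modal.imagen.imagen_model import ImagenForTextToImageSynthesis

model = ImagenForTextToImageSynthesis('/path/to/imagen_model_dir', device_id=0)
img = model.generate({'text': 'a watercolor painting of a fox'})
# img is an HxWx3 uint8 numpy array (1024x1024 after the two upsampling stages)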
936  modelscope/models/multi_modal/imagen/structbert.py  (new file)
@@ -0,0 +1,936 @@
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team and Alibaba inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch BERT model."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import copy
|
||||
import math
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import six
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
|
||||
def gelu(x):
|
||||
"""Implementation of the gelu activation function.
|
||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
||||
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
class BertConfig(object):
|
||||
"""Configuration class to store the configuration of a `BertModel`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
hidden_size=768,
|
||||
emb_size=-1,
|
||||
num_hidden_layers=12,
|
||||
transformer_type='original',
|
||||
transition_function='linear',
|
||||
weighted_transformer=0,
|
||||
num_rolled_layers=3,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act='gelu',
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
initializer_range=0.02,
|
||||
attention_type='self',
|
||||
rezero=False,
|
||||
pre_ln=False,
|
||||
squeeze_excitation=False,
|
||||
transfer_matrix=False,
|
||||
dim_dropout=False,
|
||||
roberta_style=False,
|
||||
set_mask_zero=False,
|
||||
init_scale=False,
|
||||
safer_fp16=False,
|
||||
grad_checkpoint=False):
|
||||
"""Constructs BertConfig.
|
||||
|
||||
Args:
|
||||
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
|
||||
hidden_size: Size of the encoder layers and the pooler layer.
|
||||
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads: Number of attention heads for each attention layer in
|
||||
the Transformer encoder.
|
||||
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
||||
layer in the Transformer encoder.
|
||||
hidden_act: The non-linear activation function (function or string) in the
|
||||
encoder and pooler.
|
||||
            hidden_dropout_prob: The dropout probability for all fully connected
|
||||
layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob: The dropout ratio for the attention
|
||||
probabilities.
|
||||
max_position_embeddings: The maximum sequence length that this model might
|
||||
ever be used with. Typically set this to something large just in case
|
||||
(e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
||||
`BertModel`.
|
||||
            initializer_range: The stddev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
"""
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.emb_size = emb_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.transformer_type = transformer_type
|
||||
self.transition_function = transition_function
|
||||
self.weighted_transformer = weighted_transformer
|
||||
self.num_rolled_layers = num_rolled_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.attention_type = attention_type
|
||||
self.rezero = rezero
|
||||
self.pre_ln = pre_ln
|
||||
self.squeeze_excitation = squeeze_excitation
|
||||
self.transfer_matrix = transfer_matrix
|
||||
self.dim_dropout = dim_dropout
|
||||
self.set_mask_zero = set_mask_zero
|
||||
self.roberta_style = roberta_style
|
||||
self.init_scale = init_scale
|
||||
self.safer_fp16 = safer_fp16
|
||||
self.grad_checkpoint = grad_checkpoint
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object):
|
||||
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
|
||||
config = BertConfig(vocab_size=None)
|
||||
for (key, value) in six.iteritems(json_object):
|
||||
config.__dict__[key] = value
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `BertConfig` from a json file of parameters."""
|
||||
with open(json_file, 'r') as reader:
|
||||
text = reader.read()
|
||||
return cls.from_dict(json.loads(text))
|
||||
|
||||
def to_dict(self):
|
||||
"""Serializes this instance to a Python dictionary."""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
return output
|
||||
|
||||
def to_json_string(self):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n'
|
||||
|
||||
|
||||
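# Illustrative sketch (the dict keys and values are assumptions matching the
# constructor arguments above): this is how imagen_model.py builds its text encoder.
#
#     text_config = {'vocab_size': 30522, 'hidden_size': 768, 'num_hidden_layers': 12}
#     config = BertConfig.from_dict(text_config)
#     encoder = BertModel(config)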
class BERTLayerNorm(nn.Module):
|
||||
|
||||
def __init__(self, config, variance_epsilon=1e-12, special_size=None):
|
||||
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
||||
"""
|
||||
super(BERTLayerNorm, self).__init__()
|
||||
self.config = config
|
||||
hidden_size = special_size if special_size is not None else config.hidden_size
|
||||
self.gamma = nn.Parameter(torch.ones(hidden_size))
|
||||
self.beta = nn.Parameter(torch.zeros(hidden_size))
|
||||
self.variance_epsilon = variance_epsilon if not config.roberta_style else 1e-5
|
||||
|
||||
def forward(self, x):
|
||||
previous_type = x.type()
|
||||
if self.config.safer_fp16:
|
||||
x = x.float()
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
if self.config.safer_fp16:
|
||||
return (self.gamma * x + self.beta).type(previous_type)
|
||||
else:
|
||||
return self.gamma * x + self.beta
|
||||
|
||||
|
||||
class BERTEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTEmbeddings, self).__init__()
|
||||
"""Construct the embedding module from word, position and token_type embeddings.
|
||||
"""
|
||||
hidden_size = config.hidden_size if config.emb_size < 0 else config.emb_size
|
||||
self.word_embeddings = nn.Embedding(
|
||||
config.vocab_size,
|
||||
hidden_size,
|
||||
padding_idx=1 if config.roberta_style else None)
|
||||
self.position_embeddings = nn.Embedding(
|
||||
config.max_position_embeddings,
|
||||
hidden_size,
|
||||
padding_idx=1 if config.roberta_style else None)
|
||||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
|
||||
hidden_size)
|
||||
self.config = config
|
||||
self.proj = None if config.emb_size < 0 else nn.Linear(
|
||||
config.emb_size, config.hidden_size)
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = BERTLayerNorm(config, special_size=hidden_size)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, adv_embedding=None):
|
||||
seq_length = input_ids.size(1)
|
||||
if not self.config.roberta_style:
|
||||
position_ids = torch.arange(
|
||||
seq_length, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
else:
|
||||
mask = input_ids.ne(1).int()
|
||||
position_ids = (torch.cumsum(mask, dim=1).type_as(mask)
|
||||
* mask).long() + 1
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
words_embeddings = self.word_embeddings(
|
||||
input_ids) if adv_embedding is None else adv_embedding
|
||||
if self.config.set_mask_zero:
|
||||
words_embeddings[input_ids == 103] = 0.
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
if not self.config.roberta_style:
|
||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
||||
else:
|
||||
embeddings = words_embeddings + position_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
        if self.proj is not None:
            embeddings = self.proj(embeddings)
            embeddings = self.dropout(embeddings)
        return embeddings, words_embeddings
|
||||
|
||||
|
||||
class BERTFactorizedAttention(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTFactorizedAttention, self).__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
'The hidden size (%d) is not a multiple of the number of attention '
|
||||
'heads (%d)' %
|
||||
(config.hidden_size, config.num_attention_heads))
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size
|
||||
/ config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x, *size):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads,
|
||||
self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(size)
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer, 0, 2, 3, 1)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer, 0, 2, 1, 3)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer, 0, 2, 1, 3)
|
||||
|
||||
s_attention_scores = query_layer + attention_mask
|
||||
s_attention_probs = nn.Softmax(dim=-1)(s_attention_scores)
|
||||
s_attention_probs = self.dropout(s_attention_probs)
|
||||
|
||||
c_attention_probs = nn.Softmax(dim=-1)(key_layer)
|
||||
s_context_layer = torch.matmul(s_attention_probs, value_layer)
|
||||
context_layer = torch.matmul(c_attention_probs, s_context_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (
|
||||
self.all_head_size, )
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
return context_layer
|
||||
|
||||
|
||||
def dim_dropout(x, p=0, dim=-1, training=False):
|
||||
if not training or p == 0:
|
||||
return x
|
||||
a = (1 - p)
|
||||
b = (x.data.new(x.size()).zero_() + 1)
|
||||
dropout_mask = torch.bernoulli(a * b)
|
||||
return dropout_mask * (dropout_mask.size(dim) / torch.sum(
|
||||
dropout_mask, dim=dim, keepdim=True)) * x
|
||||
|
||||
|
||||
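# Note: dim_dropout zeroes individual entries with probability p and then rescales
# each slice along `dim` by size(dim) / (number of surviving entries), so the
# expected magnitude along that dimension is preserved (a per-slice variant of
# inverted dropout).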
class BERTSelfAttention(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTSelfAttention, self).__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
'The hidden size (%d) is not a multiple of the number of attention '
|
||||
'heads (%d)' %
|
||||
(config.hidden_size, config.num_attention_heads))
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size
|
||||
/ config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.config = config
|
||||
if config.pre_ln:
|
||||
self.LayerNorm = BERTLayerNorm(config)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads,
|
||||
self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(self, hidden_states, attention_mask, head_mask=None):
|
||||
if self.config.pre_ln:
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer,
|
||||
key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / math.sqrt(
|
||||
self.attention_head_size)
|
||||
        # Apply the attention mask (precomputed for all layers in BertModel forward() function)
|
||||
if head_mask is not None and not self.training:
|
||||
for i, mask in enumerate(head_mask):
|
||||
if head_mask[i] == 1:
|
||||
attention_scores[:, i, :, :] = 0.
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
if not self.config.dim_dropout:
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
else:
|
||||
attention_probs = dim_dropout(
|
||||
attention_probs,
|
||||
p=self.config.attention_probs_dropout_prob,
|
||||
dim=-1,
|
||||
training=self.training)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (
|
||||
self.all_head_size, )
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
return context_layer
|
||||
|
||||
|
||||
class BERTSelfOutput(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTSelfOutput, self).__init__()
|
||||
self.config = config
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
if not config.pre_ln and not config.rezero:
|
||||
self.LayerNorm = BERTLayerNorm(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
if config.rezero:
|
||||
self.res_factor = nn.Parameter(
|
||||
torch.Tensor(1).fill_(0.99).to(
|
||||
dtype=next(self.parameters()).dtype))
|
||||
self.factor = nn.Parameter(
|
||||
torch.ones(1).to(dtype=next(self.parameters()).dtype))
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
if not self.config.rezero and not self.config.pre_ln:
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
elif self.config.rezero:
|
||||
hidden_states = hidden_states + self.factor * input_tensor
|
||||
else:
|
||||
pass
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BERTAttention(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTAttention, self).__init__()
|
||||
if config.attention_type.lower() == 'self':
|
||||
self.self = BERTSelfAttention(config)
|
||||
elif config.attention_type.lower() == 'factorized':
|
||||
self.self = BERTFactorizedAttention(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
                'Attention type must be in [self, factorized], but got {}'.format(
|
||||
config.attention_type))
|
||||
self.output = BERTSelfOutput(config)
|
||||
|
||||
def forward(self, input_tensor, attention_mask, head_mask=None):
|
||||
self_output = self.self(input_tensor, attention_mask, head_mask)
|
||||
attention_output = self.output(self_output, input_tensor)
|
||||
return attention_output
|
||||
|
||||
|
||||
class DepthwiseSeparableConv1d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
bias=False):
|
||||
super(DepthwiseSeparableConv1d, self).__init__()
|
||||
padding = (kernel_size - 1) // 2
|
||||
self.depthwise = nn.Conv1d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups=in_channels,
|
||||
bias=bias)
|
||||
self.pointwise = nn.Conv1d(
|
||||
in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.depthwise(x)
|
||||
x = self.pointwise(x)
|
||||
return x
|
||||
|
||||
|
||||
class BERTIntermediate(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTIntermediate, self).__init__()
|
||||
self.config = config
|
||||
if self.config.pre_ln:
|
||||
self.LayerNorm = BERTLayerNorm(config)
|
||||
self.intermediate_act_fn = gelu
|
||||
if config.transition_function.lower() == 'linear':
|
||||
self.dense = nn.Linear(config.hidden_size,
|
||||
config.intermediate_size)
|
||||
elif config.transition_function.lower() == 'cnn':
|
||||
self.cnn = DepthwiseSeparableConv1d(
|
||||
config.hidden_size, 4 * config.hidden_size, kernel_size=7)
|
||||
        elif config.transition_function.lower() == 'rnn':
|
||||
raise NotImplementedError(
|
||||
'rnn transition function is not implemented yet')
|
||||
else:
|
||||
raise ValueError('Only support linear/cnn/rnn')
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if self.config.pre_ln:
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
if self.config.transition_function.lower() == 'linear':
|
||||
hidden_states = self.dense(hidden_states)
|
||||
elif self.config.transition_function.lower() == 'cnn':
|
||||
hidden_states = self.cnn(hidden_states.transpose(-1,
|
||||
-2)).transpose(
|
||||
-1, -2)
|
||||
else:
|
||||
pass
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SqueezeExcitationBlock(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(SqueezeExcitationBlock, self).__init__()
|
||||
self.down_sampling = nn.Linear(config.hidden_size,
|
||||
config.hidden_size // 4)
|
||||
self.up_sampling = nn.Linear(config.hidden_size // 4,
|
||||
config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
squeeze = torch.mean(hidden_states, 1, keepdim=True)
|
||||
excitation = torch.sigmoid(
|
||||
self.up_sampling(gelu(self.down_sampling(squeeze))))
|
||||
return hidden_states * excitation
|
||||
|
||||
|
||||
class BERTOutput(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTOutput, self).__init__()
|
||||
self.config = config
|
||||
if config.transition_function.lower() == 'linear':
|
||||
self.dense = nn.Linear(config.intermediate_size,
|
||||
config.hidden_size)
|
||||
elif config.transition_function.lower() == 'cnn':
|
||||
self.cnn = DepthwiseSeparableConv1d(
|
||||
4 * config.hidden_size, config.hidden_size, kernel_size=7)
|
||||
        elif config.transition_function.lower() == 'rnn':
|
||||
raise NotImplementedError(
|
||||
'rnn transition function is not implemented yet')
|
||||
else:
|
||||
raise ValueError('Only support linear/cnn/rnn')
|
||||
if not config.pre_ln and not config.rezero:
|
||||
self.LayerNorm = BERTLayerNorm(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
if config.squeeze_excitation:
|
||||
self.SEblock = SqueezeExcitationBlock(config)
|
||||
if config.rezero:
|
||||
self.res_factor = nn.Parameter(
|
||||
torch.Tensor(1).fill_(0.99).to(
|
||||
dtype=next(self.parameters()).dtype))
|
||||
self.factor = nn.Parameter(
|
||||
torch.ones(1).to(dtype=next(self.parameters()).dtype))
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
if self.config.transition_function.lower() == 'linear':
|
||||
hidden_states = self.dense(hidden_states)
|
||||
elif self.config.transition_function.lower() == 'cnn':
|
||||
hidden_states = self.cnn(hidden_states.transpose(-1,
|
||||
-2)).transpose(
|
||||
-1, -2)
|
||||
else:
|
||||
pass
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
if self.config.squeeze_excitation:
|
||||
hidden_states = self.SEblock(hidden_states)
|
||||
if not self.config.rezero and not self.config.pre_ln:
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
elif self.config.rezero:
|
||||
hidden_states = hidden_states + self.factor * input_tensor
|
||||
else:
|
||||
pass
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BERTLayer(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTLayer, self).__init__()
|
||||
self.attention = BERTAttention(config)
|
||||
self.intermediate = BERTIntermediate(config)
|
||||
self.output = BERTOutput(config)
|
||||
|
||||
def forward(self, hidden_states, attention_mask, head_mask=None):
|
||||
attention_output = self.attention(hidden_states, attention_mask,
|
||||
head_mask)
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return attention_output, layer_output
|
||||
|
||||
|
||||
class BERTWeightedLayer(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTWeightedLayer, self).__init__()
|
||||
self.config = config
|
||||
self.self = BERTSelfAttention(config)
|
||||
self.attention_head_size = self.self.attention_head_size
|
||||
|
||||
self.w_o = nn.ModuleList([
|
||||
nn.Linear(self.attention_head_size, config.hidden_size)
|
||||
for _ in range(config.num_attention_heads)
|
||||
])
|
||||
self.w_kp = torch.rand(config.num_attention_heads)
|
||||
self.w_kp = nn.Parameter(self.w_kp / self.w_kp.sum())
|
||||
self.w_a = torch.rand(config.num_attention_heads)
|
||||
self.w_a = nn.Parameter(self.w_a / self.w_a.sum())
|
||||
|
||||
self.intermediate = BERTIntermediate(config)
|
||||
self.output = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.LayerNorm = BERTLayerNorm(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
self_output = self.self(hidden_states, attention_mask)
|
||||
self_outputs = self_output.split(self.self.attention_head_size, dim=-1)
|
||||
self_outputs = [
|
||||
self.w_o[i](self_outputs[i]) for i in range(len(self_outputs))
|
||||
]
|
||||
self_outputs = [
|
||||
self.dropout(self_outputs[i]) for i in range(len(self_outputs))
|
||||
]
|
||||
self_outputs = [
|
||||
kappa * output for kappa, output in zip(self.w_kp, self_outputs)
|
||||
]
|
||||
self_outputs = [
|
||||
self.intermediate(self_outputs[i])
|
||||
for i in range(len(self_outputs))
|
||||
]
|
||||
self_outputs = [
|
||||
self.output(self_outputs[i]) for i in range(len(self_outputs))
|
||||
]
|
||||
self_outputs = [
|
||||
self.dropout(self_outputs[i]) for i in range(len(self_outputs))
|
||||
]
|
||||
self_outputs = [
|
||||
alpha * output for alpha, output in zip(self.w_a, self_outputs)
|
||||
]
|
||||
output = sum(self_outputs)
|
||||
return self.LayerNorm(hidden_states + output)
|
||||
|
||||
|
||||
class BERTEncoder(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTEncoder, self).__init__()
|
||||
self.layer = nn.ModuleList()
|
||||
for _ in range(config.num_hidden_layers):
|
||||
if config.weighted_transformer:
|
||||
self.layer.append(BERTWeightedLayer(config))
|
||||
else:
|
||||
self.layer.append(BERTLayer(config))
|
||||
if config.rezero:
|
||||
for index, layer in enumerate(self.layer):
|
||||
layer.output.res_factor = nn.Parameter(
|
||||
torch.Tensor(1).fill_(1.).to(
|
||||
dtype=next(self.parameters()).dtype))
|
||||
layer.output.factor = nn.Parameter(
|
||||
torch.Tensor(1).fill_(1).to(
|
||||
dtype=next(self.parameters()).dtype))
|
||||
layer.attention.output.res_factor = layer.output.res_factor
|
||||
layer.attention.output.factor = layer.output.factor
|
||||
self.config = config
|
||||
|
||||
def forward(self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
epoch_id=-1,
|
||||
head_masks=None):
|
||||
all_encoder_layers = [hidden_states]
|
||||
if epoch_id != -1:
|
||||
detach_index = int(len(self.layer) / 3) * (2 - epoch_id) - 1
|
||||
else:
|
||||
detach_index = -1
|
||||
for index, layer_module in enumerate(self.layer):
|
||||
if head_masks is None:
|
||||
if not self.config.grad_checkpoint:
|
||||
self_out, hidden_states = layer_module(
|
||||
hidden_states, attention_mask, None)
|
||||
else:
|
||||
self_out, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
layer_module, hidden_states, attention_mask, None)
|
||||
else:
|
||||
self_out, hidden_states = layer_module(hidden_states,
|
||||
attention_mask,
|
||||
head_masks[index])
|
||||
if detach_index == index:
|
||||
hidden_states.detach_()
|
||||
all_encoder_layers.append(self_out)
|
||||
all_encoder_layers.append(hidden_states)
|
||||
return all_encoder_layers
|
||||
|
||||
|
||||
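# Parameter-shared encoder: num_rolled_layers layers are reused across num_hidden_layers steps, cyclically for the 'universal' setting and in contiguous groups for the 'albert' setting.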
class BERTEncoderRolled(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTEncoderRolled, self).__init__()
|
||||
layer = BERTLayer(config)
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList(
|
||||
[copy.deepcopy(layer) for _ in range(config.num_rolled_layers)])
|
||||
|
||||
def forward(self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
epoch_id=-1,
|
||||
head_masks=None):
|
||||
all_encoder_layers = [hidden_states]
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
if self.config.transformer_type.lower() == 'universal':
|
||||
hidden_states = self.layer[i % self.config.num_rolled_layers](
|
||||
hidden_states, attention_mask)
|
||||
elif self.config.transformer_type.lower() == 'albert':
|
||||
a = i // (
|
||||
self.config.num_hidden_layers
|
||||
// self.config.num_rolled_layers)
|
||||
hidden_states = self.layer[a](hidden_states, attention_mask)
|
||||
all_encoder_layers.append(hidden_states)
|
||||
return all_encoder_layers
|
||||
|
||||
|
||||
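# Adaptive Computation Time (ACT) encoder: a single shared layer is applied up to num_hidden_layers times, and per-token halting probabilities decide when each position stops being updated.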
class BERTEncoderACT(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTEncoderACT, self).__init__()
|
||||
self.layer = BERTLayer(config)
|
||||
p = nn.Linear(config.hidden_size, 1)
|
||||
self.p = nn.ModuleList(
|
||||
[copy.deepcopy(p) for _ in range(config.num_hidden_layers)])
|
||||
# Following the ACT paper, initialize the halting bias to ones
|
||||
for module in self.p:
|
||||
module.bias.data.fill_(1.)
|
||||
self.config = config
|
||||
self.act_max_steps = config.num_hidden_layers
|
||||
self.threshold = 0.99
|
||||
|
||||
def should_continue(self, halting_probability, n_updates):
|
||||
return (halting_probability.lt(self.threshold).__and__(
|
||||
n_updates.lt(self.act_max_steps))).any()
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
all_encoder_layers = [hidden_states]
|
||||
batch_size, seq_len, hdim = hidden_states.size()
|
||||
halting_probability = torch.zeros(batch_size, seq_len).cuda()
|
||||
remainders = torch.zeros(batch_size, seq_len).cuda()
|
||||
n_updates = torch.zeros(batch_size, seq_len).cuda()
|
||||
for i in range(self.act_max_steps):
|
||||
p = torch.sigmoid(self.p[i](hidden_states).squeeze(2))
|
||||
still_running = halting_probability.lt(1.0).float()
|
||||
new_halted = (halting_probability + p * still_running).gt(
|
||||
self.threshold).float() * still_running
|
||||
still_running = (halting_probability + p * still_running).le(
|
||||
self.threshold).float() * still_running
|
||||
halting_probability = halting_probability + p * still_running
|
||||
remainders = remainders + new_halted * (1 - halting_probability)
|
||||
halting_probability = halting_probability + new_halted * remainders
|
||||
n_updates = n_updates + still_running + new_halted
|
||||
update_weights = (p * still_running
|
||||
+ new_halted * remainders).unsqueeze(2)
|
||||
transformed_states = self.layer(hidden_states, attention_mask)
|
||||
hidden_states = transformed_states * update_weights + hidden_states * (
|
||||
1 - update_weights)
|
||||
all_encoder_layers.append(hidden_states)
|
||||
if not self.should_continue(halting_probability, n_updates):
|
||||
break
|
||||
return all_encoder_layers, torch.mean(n_updates + remainders)
|
||||
|
||||
|
||||
class BERTPooler(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(BERTPooler, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
|
||||
class BertModel(nn.Module):
|
||||
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
|
||||
|
||||
config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
|
||||
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
||||
|
||||
model = modeling.BertModel(config=config)
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: BertConfig):
|
||||
"""Constructor for BertModel.
|
||||
|
||||
Args:
|
||||
config: `BertConfig` instance.
|
||||
"""
|
||||
super(BertModel, self).__init__()
|
||||
self.config = config
|
||||
self.embeddings = BERTEmbeddings(config)
|
||||
if config.transformer_type.lower() == 'original':
|
||||
self.encoder = BERTEncoder(config)
|
||||
elif config.transformer_type.lower() == 'universal':
|
||||
self.encoder = BERTEncoderRolled(config)
|
||||
elif config.transformer_type.lower() == 'albert':
|
||||
self.encoder = BERTEncoderRolled(config)
|
||||
elif config.transformer_type.lower() == 'act':
|
||||
self.encoder = BERTEncoderACT(config)
|
||||
elif config.transformer_type.lower() == 'textnas':
|
||||
from textnas_final import op_dict, input_dict, skip_dict
|
||||
self.encoder = TextNASEncoder(config, op_dict, input_dict,
|
||||
skip_dict)
|
||||
else:
|
||||
raise ValueError('Unsupported transformer type: {}'.format(
|
||||
config.transformer_type.lower()))
|
||||
self.pooler = BERTPooler(config)
|
||||
|
||||
def forward(self,
|
||||
input_ids,
|
||||
token_type_ids=None,
|
||||
attention_mask=None,
|
||||
epoch_id=-1,
|
||||
head_masks=None,
|
||||
adv_embedding=None):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(
|
||||
dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
embedding_output, word_embeddings = self.embeddings(
|
||||
input_ids, token_type_ids, adv_embedding)
|
||||
if self.config.transformer_type.lower() == 'act':
|
||||
all_encoder_layers, act_loss = self.encoder(
|
||||
embedding_output, extended_attention_mask)
|
||||
elif self.config.transformer_type.lower() == 'reformer':
|
||||
sequence_output = self.encoder(embedding_output)
|
||||
all_encoder_layers = [sequence_output, sequence_output]
|
||||
else:
|
||||
all_encoder_layers = self.encoder(embedding_output,
|
||||
extended_attention_mask,
|
||||
epoch_id, head_masks)
|
||||
all_encoder_layers.insert(0, word_embeddings)
|
||||
sequence_output = all_encoder_layers[-1]
|
||||
if not self.config.safer_fp16:
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
else:
|
||||
pooled_output = sequence_output[:, 0]
|
||||
return all_encoder_layers, pooled_output
|
||||
|
||||
|
||||
class BertForSequenceClassificationMultiTask(nn.Module):
|
||||
"""BERT model for classification.
|
||||
This module is composed of the BERT model with a linear layer on top of
|
||||
the pooled output.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
# Already been converted into WordPiece token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
|
||||
|
||||
config = BertConfig(vocab_size=32000, hidden_size=512,
|
||||
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
||||
|
||||
num_labels = 2
|
||||
|
||||
model = BertForSequenceClassification(config, num_labels)
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config, label_list, core_encoder):
|
||||
super(BertForSequenceClassificationMultiTask, self).__init__()
|
||||
if core_encoder.lower() == 'bert':
|
||||
self.bert = BertModel(config)
|
||||
elif core_encoder.lower() == 'lstm':
|
||||
self.bert = LSTMModel(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Only support lstm or bert, but got {}'.format(core_encoder))
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.ModuleList()
|
||||
for label in label_list:
|
||||
self.classifier.append(nn.Linear(config.hidden_size, len(label)))
|
||||
self.label_list = label_list
|
||||
|
||||
def init_weights(module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(
|
||||
mean=0.0, std=config.initializer_range)
|
||||
elif isinstance(module, BERTLayerNorm):
|
||||
module.beta.data.normal_(
|
||||
mean=0.0, std=config.initializer_range)
|
||||
module.gamma.data.normal_(
|
||||
mean=0.0, std=config.initializer_range)
|
||||
if isinstance(module, nn.Linear):
|
||||
module.bias.data.zero_()
|
||||
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
attention_mask,
|
||||
labels=None,
|
||||
labels_index=None,
|
||||
epoch_id=-1,
|
||||
head_masks=None,
|
||||
adv_embedding=None,
|
||||
return_embedding=False,
|
||||
loss_weight=None):
|
||||
all_encoder_layers, pooled_output = self.bert(input_ids,
|
||||
token_type_ids,
|
||||
attention_mask, epoch_id,
|
||||
head_masks,
|
||||
adv_embedding)
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = [classifier(pooled_output) for classifier in self.classifier]
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss(reduction='none')
|
||||
regression_loss_fct = nn.MSELoss(reduction='none')
|
||||
labels_lst = torch.unbind(labels, 1)
|
||||
loss_lst = []
|
||||
for index, (label, logit) in enumerate(zip(labels_lst, logits)):
|
||||
if len(self.label_list[index]) != 1:
|
||||
loss = loss_fct(logit, label.long())
|
||||
else:
|
||||
loss = regression_loss_fct(logit.squeeze(-1), label)
|
||||
labels_mask = (labels_index == index).to(
|
||||
dtype=next(self.parameters()).dtype)
|
||||
if loss_weight is not None:
|
||||
loss = loss * loss_weight[index]
|
||||
loss = torch.mean(loss * labels_mask)
|
||||
loss_lst.append(loss)
|
||||
if not return_embedding:
|
||||
return sum(loss_lst), logits
|
||||
else:
|
||||
return sum(loss_lst), logits, all_encoder_layers[0]
|
||||
else:
|
||||
return logits
|
||||
333
modelscope/models/multi_modal/imagen/tokenizer.py
Normal file
333
modelscope/models/multi_modal/imagen/tokenizer.py
Normal file
@@ -0,0 +1,333 @@
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team and Alibaba Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import collections
|
||||
import unicodedata
|
||||
|
||||
import six
|
||||
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode('utf-8', 'ignore')
|
||||
else:
|
||||
raise ValueError('Unsupported string type: %s' % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text.decode('utf-8', 'ignore')
|
||||
elif isinstance(text, unicode):
|
||||
return text
|
||||
else:
|
||||
raise ValueError('Unsupported string type: %s' % (type(text)))
|
||||
else:
|
||||
raise ValueError('Not running on Python2 or Python 3?')
|
||||
|
||||
|
||||
def printable_text(text):
|
||||
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||
|
||||
# These functions want `str` for both Python2 and Python3, but in one case
|
||||
# it's a Unicode string and in the other it's a byte string.
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode('utf-8', 'ignore')
|
||||
else:
|
||||
raise ValueError('Unsupported string type: %s' % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, unicode):
|
||||
return text.encode('utf-8')
|
||||
else:
|
||||
raise ValueError('Unsupported string type: %s' % (type(text)))
|
||||
else:
|
||||
raise ValueError('Not running on Python2 or Python 3?')
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with open(vocab_file, 'r') as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
|
||||
def convert_tokens_to_ids(vocab, tokens):
|
||||
"""Converts a sequence of tokens into ids using the vocab."""
|
||||
ids = []
|
||||
for token in tokens:
|
||||
ids.append(vocab[token])
|
||||
return ids
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a peice of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
class FullTokenizer(object):
|
||||
"""Runs end-to-end tokenziation."""
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True):
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return convert_tokens_to_ids(self.vocab, tokens)
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
return [self.inv_vocab[i] for i in ids]
|
||||
|
||||
|
||||
class BasicTokenizer(object):
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
def __init__(self, do_lower_case=True):
|
||||
"""Constructs a BasicTokenizer.
|
||||
|
||||
Args:
|
||||
do_lower_case: Whether to lower case the input.
|
||||
"""
|
||||
self.do_lower_case = do_lower_case
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = convert_to_unicode(text)
|
||||
text = self._clean_text(text)
|
||||
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||
# models. This is also applied to the English models now, but it doesn't
|
||||
# matter since the English models were not trained on any Chinese data
|
||||
# and generally don't have any Chinese data in them (there are Chinese
|
||||
# characters in the vocabulary because Wikipedia does have some Chinese
|
||||
# words in the English Wikipedia.).
|
||||
text = self._tokenize_chinese_chars(text)
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
split_tokens.extend(self._run_split_on_punc(token))
|
||||
|
||||
output_tokens = whitespace_tokenize(' '.join(split_tokens))
|
||||
return output_tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize('NFD', text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == 'Mn':
|
||||
continue
|
||||
output.append(char)
|
||||
return ''.join(output)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return [''.join(x) for x in output]
|
||||
|
||||
def _tokenize_chinese_chars(self, text):
|
||||
"""Adds whitespace around any CJK character."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if self._is_chinese_char(cp):
|
||||
output.append(' ')
|
||||
output.append(char)
|
||||
output.append(' ')
|
||||
else:
|
||||
output.append(char)
|
||||
return ''.join(output)
|
||||
|
||||
def _is_chinese_char(self, cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like all of the other languages.
|
||||
if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF)
|
||||
or (cp >= 0x20000 and cp <= 0x2A6DF)
|
||||
or (cp >= 0x2A700 and cp <= 0x2B73F)
|
||||
or (cp >= 0x2B740 and cp <= 0x2B81F)
|
||||
or (cp >= 0x2B820 and cp <= 0x2CEAF)
|
||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||
or (cp >= 0x2F800 and cp <= 0x2FA1F)):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(' ')
|
||||
else:
|
||||
output.append(char)
|
||||
return ''.join(output)
|
||||
|
||||
|
||||
class WordpieceTokenizer(object):
|
||||
"""Runs WordPiece tokenization."""
|
||||
|
||||
def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text into its word pieces.
|
||||
|
||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||
using the given vocabulary.
|
||||
|
||||
For example:
|
||||
input = "unaffable"
|
||||
output = ["un", "##aff", "##able"]
|
||||
|
||||
Args:
|
||||
text: A single token or whitespace separated tokens. This should have
|
||||
already been passed through `BasicTokenizer`.
|
||||
|
||||
Returns:
|
||||
A list of wordpiece tokens.
|
||||
"""
|
||||
|
||||
text = convert_to_unicode(text)
|
||||
|
||||
output_tokens = []
|
||||
for token in whitespace_tokenize(text):
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = ''.join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = '##' + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically control characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == ' ' or char == '\t' or char == '\n' or char == '\r':
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == 'Zs':
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == '\t' or char == '\n' or char == '\r':
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith('C'):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64)
|
||||
or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith('P'):
|
||||
return True
|
||||
return False
|
||||
319
modelscope/models/multi_modal/imagen/unet_generator.py
Normal file
319
modelscope/models/multi_modal/imagen/unet_generator.py
Normal file
@@ -0,0 +1,319 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['ImagenGenerator']
|
||||
|
||||
|
||||
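# Sinusoidal timestep embedding: concatenates cos and sin of the timestep at geometrically spaced frequencies, zero-padding the last column when dim is odd.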
def sinusoidal_embedding(timesteps, dim):
|
||||
# check input
|
||||
half = dim // 2
|
||||
timesteps = timesteps.float()
|
||||
|
||||
# compute sinusoidal embedding
|
||||
sinusoid = torch.outer(
|
||||
timesteps, torch.pow(10000,
|
||||
-torch.arange(half).to(timesteps).div(half)))
|
||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||
if dim % 2 != 0:
|
||||
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
|
||||
return x
|
||||
|
||||
|
||||
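# Spatial resampling helper: 2x nearest-neighbor upsampling (optionally followed by a 3x3 conv), 0.5x downsampling via strided conv or average pooling, or identity.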
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, scale_factor, use_conv=False):
|
||||
assert scale_factor in [0.5, 1.0, 2.0]
|
||||
super(Resample, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.scale_factor = scale_factor
|
||||
self.use_conv = use_conv
|
||||
|
||||
# layers
|
||||
if scale_factor == 2.0:
|
||||
self.resample = nn.Sequential(
|
||||
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
|
||||
nn.Conv2d(in_dim, out_dim, 3, padding=1)
|
||||
if use_conv else nn.Identity())
|
||||
elif scale_factor == 0.5:
|
||||
self.resample = nn.Conv2d(
|
||||
in_dim, out_dim, 3, stride=2,
|
||||
padding=1) if use_conv else nn.AvgPool2d(
|
||||
kernel_size=2, stride=2)
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.resample(x)
|
||||
|
||||
|
||||
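# GroupNorm/SiLU residual block conditioned on the embedding e; with use_scale_shift_norm the embedding yields a scale and shift applied after the second GroupNorm (FiLM-style), and scale_factor lets the block up- or downsample.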
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim,
|
||||
embed_dim,
|
||||
out_dim,
|
||||
use_scale_shift_norm=True,
|
||||
scale_factor=1.0,
|
||||
dropout=0.0):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.embed_dim = embed_dim
|
||||
self.out_dim = out_dim
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
# layers
|
||||
self.layer1 = nn.Sequential(
|
||||
nn.GroupNorm(32, in_dim), nn.SiLU(),
|
||||
nn.Conv2d(in_dim, out_dim, 3, padding=1))
|
||||
self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False)
|
||||
self.embedding = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(embed_dim,
|
||||
out_dim * 2 if use_scale_shift_norm else out_dim))
|
||||
self.layer2 = nn.Sequential(
|
||||
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
||||
nn.Conv2d(out_dim, out_dim, 3, padding=1))
|
||||
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
|
||||
in_dim, out_dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.layer2[-1].weight)
|
||||
|
||||
def forward(self, x, e):
|
||||
identity = self.resample(x)
|
||||
x = self.layer1[-1](self.resample(self.layer1[:-1](x)))
|
||||
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
|
||||
if self.use_scale_shift_norm:
|
||||
scale, shift = e.chunk(2, dim=1)
|
||||
x = self.layer2[0](x) * (1 + scale) + shift
|
||||
x = self.layer2[1:](x)
|
||||
else:
|
||||
x = x + e
|
||||
x = self.layer2(x)
|
||||
x = x + self.shortcut(identity)
|
||||
return x
|
||||
|
||||
|
||||
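# Multi-head attention over spatial positions; when a text context is provided, its projected keys and values are concatenated with the image keys and values (cross-attention), with optional masking of padded context tokens.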
class AttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
|
||||
# consider head_dim first, then num_heads
|
||||
num_heads = dim // head_dim if head_dim else num_heads
|
||||
head_dim = dim // num_heads
|
||||
assert num_heads * head_dim == dim
|
||||
super(AttentionBlock, self).__init__()
|
||||
self.dim = dim
|
||||
self.context_dim = context_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.scale = math.pow(head_dim, -0.25)
|
||||
|
||||
# layers
|
||||
self.norm = nn.GroupNorm(32, dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
if context_dim is not None:
|
||||
self.context_kv = nn.Linear(context_dim, dim * 2)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
identity = x
|
||||
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
x = self.norm(x)
|
||||
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
|
||||
if context is not None:
|
||||
ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
|
||||
d).permute(0, 2, 3,
|
||||
1).chunk(
|
||||
2, dim=1)
|
||||
k = torch.cat([ck, k], dim=-1)
|
||||
v = torch.cat([cv, v], dim=-1)
|
||||
|
||||
# compute attention
|
||||
attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
|
||||
if mask is not None:
|
||||
assert context is not None
|
||||
full_mask = x.new_ones((b, 1, q.size(-1), k.size(-1)))
|
||||
full_mask[:, 0, :, :-q.size(-1)] = mask.unsqueeze(1)
|
||||
attn = attn.masked_fill(full_mask == 0, float('-inf'))
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
|
||||
# gather context
|
||||
x = torch.matmul(v, attn.transpose(-1, -2))
|
||||
x = x.reshape(b, c, h, w)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
return x + identity
|
||||
|
||||
|
||||
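# Base text-conditioned UNet: the pooled text embedding is added to the timestep embedding, and per-token text features condition the attention blocks at the resolutions listed in attn_scales.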
class ImagenGenerator(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim=3,
|
||||
dim=512,
|
||||
text_dim=1024,
|
||||
context_dim=512,
|
||||
out_dim=6,
|
||||
dim_mult=[1, 2, 3, 4],
|
||||
num_heads=None,
|
||||
head_dim=64,
|
||||
num_res_blocks=3,
|
||||
attn_scales=[1 / 2, 1 / 4, 1 / 8],
|
||||
resblock_resample=True,
|
||||
use_scale_shift_norm=True,
|
||||
dropout=0.0):
|
||||
embed_dim = dim * 4
|
||||
super(ImagenGenerator, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.dim = dim
|
||||
self.text_dim = text_dim
|
||||
self.context_dim = context_dim
|
||||
self.embed_dim = embed_dim
|
||||
self.out_dim = out_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.resblock_resample = resblock_resample
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
# params
|
||||
enc_dims = [dim * u for u in [1] + dim_mult]
|
||||
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
shortcut_dims = []
|
||||
scale = 1.0
|
||||
|
||||
# embeddings
|
||||
self.time_embedding = nn.Sequential(
|
||||
nn.Linear(dim, embed_dim), nn.SiLU(),
|
||||
nn.Linear(embed_dim, embed_dim))
|
||||
self.pool_embedding = nn.Sequential(
|
||||
nn.LayerNorm(text_dim), nn.Linear(text_dim, embed_dim))
|
||||
self.text_embedding = nn.Sequential(
|
||||
nn.LayerNorm(text_dim), nn.Linear(text_dim, context_dim),
|
||||
nn.SiLU(), nn.Linear(context_dim, context_dim))
|
||||
|
||||
# encoder
|
||||
self.encoder = nn.ModuleList(
|
||||
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
|
||||
shortcut_dims.append(dim)
|
||||
for i, (in_dim,
|
||||
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
||||
for j in range(num_res_blocks):
|
||||
# residual (+attention) blocks
|
||||
block = nn.ModuleList(
|
||||
[ResidualBlock(in_dim, embed_dim, out_dim, dropout)])
|
||||
if scale in attn_scales:
|
||||
block.append(
|
||||
AttentionBlock(out_dim, context_dim, num_heads,
|
||||
head_dim))
|
||||
in_dim = out_dim
|
||||
self.encoder.append(block)
|
||||
shortcut_dims.append(out_dim)
|
||||
|
||||
# downsample
|
||||
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
|
||||
if resblock_resample:
|
||||
downsample = ResidualBlock(out_dim, embed_dim, out_dim,
|
||||
use_scale_shift_norm, 0.5,
|
||||
dropout)
|
||||
else:
|
||||
downsample = Resample(
|
||||
out_dim, out_dim, 0.5, use_conv=True)
|
||||
shortcut_dims.append(out_dim)
|
||||
scale /= 2.0
|
||||
self.encoder.append(downsample)
|
||||
|
||||
# middle
|
||||
self.middle = nn.ModuleList([
|
||||
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
|
||||
1.0, dropout),
|
||||
AttentionBlock(out_dim, context_dim, num_heads, head_dim),
|
||||
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
|
||||
1.0, dropout)
|
||||
])
|
||||
|
||||
# decoder
|
||||
self.decoder = nn.ModuleList()
|
||||
for i, (in_dim,
|
||||
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
|
||||
for j in range(num_res_blocks + 1):
|
||||
# residual (+attention) blocks
|
||||
block = nn.ModuleList([
|
||||
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
|
||||
out_dim, use_scale_shift_norm, 1.0, dropout)
|
||||
])
|
||||
if scale in attn_scales:
|
||||
block.append(
|
||||
AttentionBlock(out_dim, context_dim, num_heads,
|
||||
head_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample
|
||||
if i != len(dim_mult) - 1 and j == num_res_blocks:
|
||||
if resblock_resample:
|
||||
upsample = ResidualBlock(out_dim, embed_dim, out_dim,
|
||||
use_scale_shift_norm, 2.0,
|
||||
dropout)
|
||||
else:
|
||||
upsample = Resample(
|
||||
out_dim, out_dim, 2.0, use_conv=True)
|
||||
scale *= 2.0
|
||||
block.append(upsample)
|
||||
self.decoder.append(block)
|
||||
|
||||
# head
|
||||
self.head = nn.Sequential(
|
||||
nn.GroupNorm(32, out_dim), nn.SiLU(),
|
||||
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.head[-1].weight)
|
||||
|
||||
def forward(self, x, t, y, context, mask=None):
|
||||
# embeddings
|
||||
e = self.time_embedding(sinusoidal_embedding(
|
||||
t, self.dim)) + self.pool_embedding(y)
|
||||
context = self.text_embedding(context)
|
||||
|
||||
# encoder
|
||||
xs = []
|
||||
for block in self.encoder:
|
||||
x = self._forward_single(block, x, e, context, mask)
|
||||
xs.append(x)
|
||||
|
||||
# middle
|
||||
for block in self.middle:
|
||||
x = self._forward_single(block, x, e, context, mask)
|
||||
|
||||
# decoder
|
||||
for block in self.decoder:
|
||||
x = torch.cat([x, xs.pop()], dim=1)
|
||||
x = self._forward_single(block, x, e, context, mask)
|
||||
|
||||
# head
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
def _forward_single(self, module, x, e, context, mask):
|
||||
if isinstance(module, ResidualBlock):
|
||||
x = module(x, e)
|
||||
elif isinstance(module, AttentionBlock):
|
||||
x = module(x, context, mask)
|
||||
elif isinstance(module, nn.ModuleList):
|
||||
for block in module:
|
||||
x = self._forward_single(block, x, e, context, mask)
|
||||
else:
|
||||
x = module(x)
|
||||
return x
|
||||
@@ -0,0 +1,337 @@
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['SuperResUNet256']
|
||||
|
||||
|
||||
def sinusoidal_embedding(timesteps, dim):
|
||||
# check input
|
||||
half = dim // 2
|
||||
timesteps = timesteps.float()
|
||||
|
||||
# compute sinusoidal embedding
|
||||
sinusoid = torch.outer(
|
||||
timesteps, torch.pow(10000,
|
||||
-torch.arange(half).to(timesteps).div(half)))
|
||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||
if dim % 2 != 0:
|
||||
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
|
||||
return x
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, scale_factor, use_conv=False):
|
||||
assert scale_factor in [0.5, 1.0, 2.0]
|
||||
super(Resample, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.scale_factor = scale_factor
|
||||
self.use_conv = use_conv
|
||||
|
||||
# layers
|
||||
if scale_factor == 2.0:
|
||||
self.resample = nn.Sequential(
|
||||
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
|
||||
nn.Conv2d(in_dim, out_dim, 3, padding=1)
|
||||
if use_conv else nn.Identity())
|
||||
elif scale_factor == 0.5:
|
||||
self.resample = nn.Conv2d(
|
||||
in_dim, out_dim, 3, stride=2,
|
||||
padding=1) if use_conv else nn.AvgPool2d(
|
||||
kernel_size=2, stride=2)
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.resample(x)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim,
|
||||
embed_dim,
|
||||
out_dim,
|
||||
use_scale_shift_norm=True,
|
||||
scale_factor=1.0,
|
||||
dropout=0.0):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.embed_dim = embed_dim
|
||||
self.out_dim = out_dim
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
# layers
|
||||
self.layer1 = nn.Sequential(
|
||||
nn.GroupNorm(32, in_dim), nn.SiLU(),
|
||||
nn.Conv2d(in_dim, out_dim, 3, padding=1))
|
||||
self.resample_x = Resample(in_dim, in_dim, scale_factor)
|
||||
self.resample_i = Resample(in_dim, in_dim, scale_factor)
|
||||
self.embedding = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(embed_dim,
|
||||
out_dim * 2 if use_scale_shift_norm else out_dim))
|
||||
self.layer2 = nn.Sequential(
|
||||
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
||||
nn.Conv2d(out_dim, out_dim, 3, padding=1))
|
||||
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
|
||||
in_dim, out_dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.layer2[-1].weight)
|
||||
|
||||
def forward(self, x, e):
|
||||
identity = self.resample_i(x)
|
||||
x = self.layer1[-1](self.resample_x(self.layer1[:-1](x)))
|
||||
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1)
|
||||
if self.use_scale_shift_norm:
|
||||
scale, shift = e.chunk(2, dim=1)
|
||||
x = self.layer2[0](x) * (1 + scale) + shift
|
||||
x = self.layer2[1:](x)
|
||||
else:
|
||||
x = x + e
|
||||
x = self.layer2(x)
|
||||
x = x + self.shortcut(identity)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
|
||||
# consider head_dim first, then num_heads
|
||||
num_heads = dim // head_dim if head_dim else num_heads
|
||||
head_dim = dim // num_heads
|
||||
assert num_heads * head_dim == dim
|
||||
super(AttentionBlock, self).__init__()
|
||||
self.dim = dim
|
||||
self.context_dim = context_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.scale = math.pow(head_dim, -0.25)
|
||||
|
||||
# layers
|
||||
self.norm = nn.GroupNorm(32, dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
if context_dim is not None:
|
||||
self.context_kv = nn.Linear(context_dim, dim * 2)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
r"""x: [B, C, H, W].
|
||||
context: [B, L, C] or None.
|
||||
mask: [B, L] or None.
|
||||
"""
|
||||
identity = x
|
||||
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
x = self.norm(x)
|
||||
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
|
||||
if context is not None:
|
||||
ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
|
||||
d).permute(0, 2, 3,
|
||||
1).chunk(
|
||||
2, dim=1)
|
||||
k = torch.cat([k, ck], dim=-1)
|
||||
v = torch.cat([v, cv], dim=-1)
|
||||
|
||||
# compute attention
|
||||
attn = torch.einsum('bndi,bndj->bnij', q * self.scale, k * self.scale)
|
||||
if mask is not None:
|
||||
pad_mask = mask.new_ones((b, 1, 1, h * w))
|
||||
mask = torch.cat((pad_mask, mask.unsqueeze(1).unsqueeze(1)),
|
||||
dim=-1)
|
||||
attn = attn.masked_fill(mask == 0, float('-inf'))
|
||||
|
||||
attn = F.softmax(attn, dim=-1)
|
||||
|
||||
# gather context
|
||||
x = torch.einsum('bnij,bndj->bndi', attn, v)
|
||||
x = x.reshape(b, c, h, w)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
return x + identity
|
||||
|
||||
|
||||
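# Text-conditioned super-resolution UNet: the low-resolution image lx is resized and concatenated channel-wise with the input, and a separate conditioning timestep lt is embedded alongside the diffusion timestep t.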
class SuperResUNet256(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim=6,
|
||||
out_dim=3,
|
||||
dim=256,
|
||||
text_dim=1024,
|
||||
context_dim=512,
|
||||
dim_mult=[1, 2, 2, 3, 4],
|
||||
num_heads=None,
|
||||
head_dim=64,
|
||||
num_res_blocks=2,
|
||||
attn_scales=[1 / 16],
|
||||
resblock_resample=True,
|
||||
use_conv=True,
|
||||
use_scale_shift_norm=True,
|
||||
dropout=0.1):
|
||||
embed_dim = dim * 4
|
||||
super(SuperResUNet256, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.resblock_resample = resblock_resample
|
||||
self.use_conv = use_conv
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
# params
|
||||
enc_dims = [dim * u for u in [1] + dim_mult]
|
||||
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
shortcut_dims = []
|
||||
scale = 1.0
|
||||
|
||||
# embeddings
|
||||
self.time_embedding = nn.Sequential(
|
||||
nn.Linear(dim, embed_dim), nn.SiLU(),
|
||||
nn.Linear(embed_dim, embed_dim))
|
||||
self.noise_time_embedding = nn.Sequential(
|
||||
nn.Linear(dim, embed_dim), nn.SiLU(),
|
||||
nn.Linear(embed_dim, embed_dim))
|
||||
self.pool_embedding = nn.Sequential(
|
||||
nn.LayerNorm(text_dim), nn.Linear(text_dim, embed_dim))
|
||||
self.text_embedding = nn.Sequential(
|
||||
nn.LayerNorm(text_dim), nn.Linear(text_dim, context_dim),
|
||||
nn.SiLU(), nn.Linear(context_dim, context_dim))
|
||||
|
||||
# encoder
|
||||
self.encoder = nn.ModuleList(
|
||||
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
|
||||
shortcut_dims.append(dim)
|
||||
for i, (in_dim,
|
||||
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
||||
for j in range(num_res_blocks):
|
||||
# residual (+attention) blocks
|
||||
block = nn.ModuleList([
|
||||
ResidualBlock(in_dim, embed_dim, out_dim,
|
||||
use_scale_shift_norm, 1.0, dropout)
|
||||
])
|
||||
if scale in attn_scales:
|
||||
block.append(
|
||||
AttentionBlock(out_dim, context_dim, num_heads,
|
||||
head_dim))
|
||||
shortcut_dims.append(out_dim)
|
||||
in_dim = out_dim
|
||||
self.encoder.append(block)
|
||||
|
||||
# downsample
|
||||
if i != len(dim_mult) - 1:
|
||||
if resblock_resample:
|
||||
downsample = ResidualBlock(out_dim, embed_dim, out_dim,
|
||||
use_scale_shift_norm, 0.5,
|
||||
dropout)
|
||||
else:
|
||||
downsample = Resample(out_dim, out_dim, 0.5, use_conv)
|
||||
shortcut_dims.append(out_dim)
|
||||
scale /= 2.0
|
||||
self.encoder.append(downsample)
|
||||
|
||||
# middle
|
||||
self.middle = nn.ModuleList([
|
||||
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
|
||||
1.0, dropout),
|
||||
AttentionBlock(out_dim, context_dim, num_heads, head_dim),
|
||||
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
|
||||
1.0, dropout)
|
||||
])
|
||||
|
||||
# decoder
|
||||
self.decoder = nn.ModuleList()
|
||||
for i, (in_dim,
|
||||
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
|
||||
for j in range(num_res_blocks + 1):
|
||||
# residual (+attention) blocks
|
||||
block = nn.ModuleList([
|
||||
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
|
||||
out_dim, use_scale_shift_norm, 1.0, dropout)
|
||||
])
|
||||
if scale in attn_scales:
|
||||
block.append(
|
||||
AttentionBlock(out_dim, context_dim, num_heads,
|
||||
head_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample
|
||||
if i != len(dim_mult) - 1 and j == num_res_blocks:
|
||||
if resblock_resample:
|
||||
upsample = ResidualBlock(out_dim, embed_dim, out_dim,
|
||||
use_scale_shift_norm, 2.0,
|
||||
dropout)
|
||||
else:
|
||||
upsample = Resample(out_dim, out_dim, 2.0, use_conv)
|
||||
scale *= 2.0
|
||||
block.append(upsample)
|
||||
self.decoder.append(block)
|
||||
|
||||
# head
|
||||
self.head = nn.Sequential(
|
||||
nn.GroupNorm(32, out_dim), nn.SiLU(),
|
||||
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.head[-1].weight)
|
||||
|
||||
def forward(self, x, t, lx, lt, y, context, mask):
|
||||
assert context.shape[:-1] == mask.shape
|
||||
|
||||
# embeddings
|
||||
t = self.time_embedding(sinusoidal_embedding(t, self.dim)) \
|
||||
+ self.noise_time_embedding(sinusoidal_embedding(lt, self.dim)) \
|
||||
+ self.pool_embedding(y)
|
||||
|
||||
context = self.text_embedding(context)
|
||||
|
||||
if lx.shape[-2:] != x.shape[-2:]:
|
||||
lx = F.interpolate(
|
||||
lx, x.shape[-2:], mode='bilinear', align_corners=False)
|
||||
x = torch.cat([x, lx], dim=1)
|
||||
|
||||
# encoder
|
||||
xs = []
|
||||
for block in self.encoder:
|
||||
x = self._forward_single(block, x, t, context, mask)
|
||||
xs.append(x)
|
||||
|
||||
# middle
|
||||
for block in self.middle:
|
||||
x = self._forward_single(block, x, t, context, mask)
|
||||
|
||||
# decoder
|
||||
for block in self.decoder:
|
||||
x = torch.cat([x, xs.pop()], dim=1)
|
||||
x = self._forward_single(block, x, t, context, mask)
|
||||
|
||||
# head
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
def _forward_single(self, module, x, t, context, mask):
|
||||
if isinstance(module, ResidualBlock):
|
||||
x = module(x, t)
|
||||
elif isinstance(module, AttentionBlock):
|
||||
x = module(x, context, mask)
|
||||
elif isinstance(module, nn.ModuleList):
|
||||
for block in module:
|
||||
x = self._forward_single(block, x, t, context, mask)
|
||||
else:
|
||||
x = module(x)
|
||||
return x
|
||||
240
modelscope/models/multi_modal/imagen/unet_upsampler_1024.py
Normal file
240
modelscope/models/multi_modal/imagen/unet_upsampler_1024.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['ImagenUpsampler1024']
|
||||
|
||||
|
||||
def sinusoidal_embedding(timesteps, dim):
|
||||
# check input
|
||||
half = dim // 2
|
||||
timesteps = timesteps.float()
|
||||
|
||||
# compute sinusoidal embedding
|
||||
sinusoid = torch.outer(
|
||||
timesteps, torch.pow(10000,
|
||||
-torch.arange(half).to(timesteps).div(half)))
|
||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||
if dim % 2 != 0:
|
||||
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
|
||||
return x
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, scale_factor, use_conv=False):
|
||||
assert scale_factor in [0.5, 1.0, 2.0]
|
||||
super(Resample, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.scale_factor = scale_factor
|
||||
self.use_conv = use_conv
|
||||
|
||||
# layers
|
||||
if scale_factor == 2.0:
|
||||
self.resample = nn.Sequential(
|
||||
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
|
||||
nn.Conv2d(in_dim, out_dim, 3, padding=1)
|
||||
if use_conv else nn.Identity())
|
||||
elif scale_factor == 0.5:
|
||||
self.resample = nn.Conv2d(
|
||||
in_dim, out_dim, 3, stride=2,
|
||||
padding=1) if use_conv else nn.AvgPool2d(
|
||||
kernel_size=2, stride=2)
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.resample(x)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim,
|
||||
embed_dim,
|
||||
out_dim,
|
||||
use_scale_shift_norm=True,
|
||||
scale_factor=1.0,
|
||||
dropout=0.0):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.embed_dim = embed_dim
|
||||
self.out_dim = out_dim
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
# layers
|
||||
self.layer1 = nn.Sequential(
|
||||
nn.GroupNorm(32, in_dim), nn.SiLU(),
|
||||
nn.Conv2d(in_dim, out_dim, 3, padding=1))
|
||||
self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False)
|
||||
self.embedding = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(embed_dim,
|
||||
out_dim * 2 if use_scale_shift_norm else out_dim))
|
||||
self.layer2 = nn.Sequential(
|
||||
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
||||
nn.Conv2d(out_dim, out_dim, 3, padding=1))
|
||||
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
|
||||
in_dim, out_dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.layer2[-1].weight)
|
||||
|
||||
def forward(self, x, e):
|
||||
identity = self.resample(x)
|
||||
x = self.layer1[-1](self.resample(self.layer1[:-1](x)))
|
||||
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
|
||||
if self.use_scale_shift_norm:
|
||||
scale, shift = e.chunk(2, dim=1)
|
||||
x = self.layer2[0](x) * (1 + scale) + shift
|
||||
x = self.layer2[1:](x)
|
||||
else:
|
||||
x = x + e
|
||||
x = self.layer2(x)
|
||||
x = x + self.shortcut(identity)
|
||||
return x
|
||||
|
||||
|
||||
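# Final upsampler UNet: conditioned only on the diffusion timestep and the channel-wise concatenated low-resolution input; no text attention blocks are used at this stage.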
class ImagenUpsampler1024(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim=6,
|
||||
dim=192,
|
||||
out_dim=3,
|
||||
dim_mult=[1, 1, 2, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
resblock_resample=True,
|
||||
use_scale_shift_norm=True,
|
||||
dropout=0.0):
|
||||
embed_dim = dim * 4
|
||||
super(ImagenUpsampler1024, self).__init__()
|
||||
self.in_dim = in_dim
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resblock_resample = resblock_resample
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
# params
|
||||
enc_dims = [dim * u for u in [1] + dim_mult]
|
||||
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
shortcut_dims = []
|
||||
scale = 1.0
|
||||
|
||||
# embedding
|
||||
self.time_embedding = nn.Sequential(
|
||||
nn.Linear(dim, embed_dim), nn.SiLU(),
|
||||
nn.Linear(embed_dim, embed_dim))
|
||||
|
||||
# encoder
|
||||
self.encoder = nn.ModuleList(
|
||||
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
|
||||
shortcut_dims.append(dim)
|
||||
for i, (in_dim,
|
||||
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
||||
for j in range(num_res_blocks):
|
||||
# residual block
|
||||
block = nn.ModuleList([
|
||||
ResidualBlock(in_dim, embed_dim, out_dim,
|
||||
use_scale_shift_norm, 1.0, dropout)
|
||||
])
|
||||
shortcut_dims.append(out_dim)
|
||||
in_dim = out_dim
|
||||
self.encoder.append(block)
|
||||
|
||||
# downsample
|
||||
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
|
||||
if resblock_resample:
|
||||
downsample = ResidualBlock(out_dim, embed_dim, out_dim,
|
||||
use_scale_shift_norm, 0.5,
|
||||
dropout)
|
||||
else:
|
||||
downsample = Resample(
|
||||
out_dim, out_dim, 0.5, use_conv=True)
|
||||
shortcut_dims.append(out_dim)
|
||||
scale /= 2.0
|
||||
self.encoder.append(downsample)
|
||||
|
||||
# middle
|
||||
self.middle = nn.ModuleList([
|
||||
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
|
||||
1.0, dropout),
|
||||
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
|
||||
1.0, dropout)
|
||||
])
|
||||
|
||||
# decoder
|
||||
self.decoder = nn.ModuleList()
|
||||
for i, (in_dim,
|
||||
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
|
||||
for j in range(num_res_blocks + 1):
|
||||
# residual block
|
||||
block = nn.ModuleList([
|
||||
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
|
||||
out_dim, use_scale_shift_norm, 1.0, dropout)
|
||||
])
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample
|
||||
if i != len(dim_mult) - 1 and j == num_res_blocks:
|
||||
if resblock_resample:
|
||||
upsample = ResidualBlock(out_dim, embed_dim, out_dim,
|
||||
use_scale_shift_norm, 2.0,
|
||||
dropout)
|
||||
else:
|
||||
upsample = Resample(
|
||||
out_dim, out_dim, 2.0, use_conv=True)
|
||||
scale *= 2.0
|
||||
block.append(upsample)
|
||||
self.decoder.append(block)
|
||||
|
||||
# head
|
||||
self.head = nn.Sequential(
|
||||
nn.GroupNorm(32, out_dim), nn.SiLU(),
|
||||
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.head[-1].weight)
|
||||
|
||||
def forward(self, x, t, concat):
|
||||
# embedding
|
||||
if concat is not None:
|
||||
if concat.shape[-2:] != x.shape[-2:]:
|
||||
concat = F.interpolate(
|
||||
concat, x.shape[-2:], mode='bilinear', align_corners=False)
|
||||
x = torch.cat([x, concat], dim=1)
|
||||
e = self.time_embedding(sinusoidal_embedding(t, self.dim))
|
||||
|
||||
# encoder
|
||||
xs = []
|
||||
for block in self.encoder:
|
||||
x = self._forward_single(block, x, e)
|
||||
xs.append(x)
|
||||
|
||||
# middle
|
||||
for block in self.middle:
|
||||
x = self._forward_single(block, x, e)
|
||||
|
||||
# decoder
|
||||
for block in self.decoder:
|
||||
x = torch.cat([x, xs.pop()], dim=1)
|
||||
x = self._forward_single(block, x, e)
|
||||
|
||||
# head
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
def _forward_single(self, module, x, e):
|
||||
if isinstance(module, ResidualBlock):
|
||||
x = module(x, e)
|
||||
elif isinstance(module, nn.ModuleList):
|
||||
for block in module:
|
||||
x = self._forward_single(block, x, e)
|
||||
else:
|
||||
x = module(x)
|
||||
return x
|
||||
@@ -1,6 +1,7 @@
try:
    from .image_captioning_pipeline import ImageCaptionPipeline
    from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline
    from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline
    from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline
except ModuleNotFoundError as e:
    if str(e) == "No module named 'torch'":
@@ -0,0 +1,37 @@
from typing import Any, Dict, Union

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ..base import Model, Pipeline
from ..builder import PIPELINES
from ..outputs import OutputKeys

logger = get_logger()

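# Thin pipeline wrapper: preprocessing passes the input dict through unchanged, forward delegates to the model's generate(), and postprocessing wraps the result under OutputKeys.OUTPUT_IMG.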
@PIPELINES.register_module(
    Tasks.text_to_image_synthesis,
    module_name=Pipelines.text_to_image_synthesis)
class TextToImageSynthesisPipeline(Pipeline):

    def __init__(self, model: Union[Model, str], device_id: int = -1):
        if isinstance(model, str):
            pipe_model = Model.from_pretrained(model)
        elif isinstance(model, Model):
            pipe_model = model
        else:
            raise NotImplementedError(
                f'expecting a Model instance or str, but got {type(model)}.')

        super().__init__(model=pipe_model)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        return self.model.generate(input)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {OutputKeys.OUTPUT_IMG: inputs}
45
tests/pipelines/test_text_to_image_synthesis.py
Normal file
45
tests/pipelines/test_text_to_image_synthesis.py
Normal file
@@ -0,0 +1,45 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

import numpy as np

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

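# Smoke tests for the text-to-image synthesis pipeline using the prompt '宇航员' ('astronaut'); each test prints a simple checksum of the generated image array.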
class TextToImageSynthesisTest(unittest.TestCase):
    model_id = 'damo/cv_imagen_text-to-image-synthesis_tiny'
    test_text = {'text': '宇航员'}

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        pipe_line_text_to_image_synthesis = pipeline(
            task=Tasks.text_to_image_synthesis, model=model)
        img = pipe_line_text_to_image_synthesis(
            self.test_text)[OutputKeys.OUTPUT_IMG]
        print(np.sum(np.abs(img)))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipe_line_text_to_image_synthesis = pipeline(
            task=Tasks.text_to_image_synthesis, model=self.model_id)
        img = pipe_line_text_to_image_synthesis(
            self.test_text)[OutputKeys.OUTPUT_IMG]
        print(np.sum(np.abs(img)))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipe_line_text_to_image_synthesis = pipeline(
            task=Tasks.text_to_image_synthesis)
        img = pipe_line_text_to_image_synthesis(
            self.test_text)[OutputKeys.OUTPUT_IMG]
        print(np.sum(np.abs(img)))


if __name__ == '__main__':
    unittest.main()