diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index bfb44e11..00a6f24e 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -29,6 +29,7 @@ class Models(object): ofa = 'ofa' clip = 'clip-multi-modal-embedding' mplug = 'mplug' + imagen = 'imagen-text-to-image-synthesis' class Pipelines(object): @@ -71,6 +72,7 @@ class Pipelines(object): image_caption = 'image-captioning' multi_modal_embedding = 'multi-modal-embedding' visual_question_answering = 'visual-question-answering' + text_to_image_synthesis = 'text-to-image-synthesis' class Trainers(object): diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py index 4ed9809b..a69491af 100644 --- a/modelscope/models/multi_modal/__init__.py +++ b/modelscope/models/multi_modal/__init__.py @@ -1,4 +1,5 @@ from .clip.clip_model import CLIPForMultiModalEmbedding from .image_captioning_model import OfaForImageCaptioning +from .imagen.imagen_model import ImagenForTextToImageSynthesis from .mplug_for_visual_question_answering import \ MPlugForVisualQuestionAnswering diff --git a/modelscope/models/multi_modal/imagen/__init__.py b/modelscope/models/multi_modal/imagen/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/multi_modal/imagen/diffusion.py b/modelscope/models/multi_modal/imagen/diffusion.py new file mode 100644 index 00000000..d71fe0ae --- /dev/null +++ b/modelscope/models/multi_modal/imagen/diffusion.py @@ -0,0 +1,595 @@ +import math + +import torch + +__all__ = ['GaussianDiffusion', 'beta_schedule'] + + +def kl_divergence(mu1, logvar1, mu2, logvar2): + a = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + b = ((mu1 - mu2)**2) * torch.exp(-logvar2) + return 0.5 * (a + b) + + +def standard_normal_cdf(x): + return 0.5 * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x0, mean, log_scale): + assert x0.shape == mean.shape == log_scale.shape + cx = x0 - mean + inv_stdv = torch.exp(-log_scale) + cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) + cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x0 < -0.999, log_cdf_plus, + torch.where(x0 > 0.999, log_one_minus_cdf_min, + torch.log(cdf_delta.clamp(min=1e-12)))) + assert log_probs.shape == x0.shape + return log_probs + + +def _i(tensor, t, x): + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t].view(shape).to(x) + + +def cosine_fn(u): + return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 + + +def beta_schedule(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None): + if schedule == 'linear': + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + last_beta = last_beta or scale * 0.02 + return torch.linspace( + init_beta, last_beta, num_timesteps, dtype=torch.float64) + elif schedule == 'quadratic': + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'cosine': + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + betas.append(min(1.0 - cosine_fn(t2) / cosine_fn(t1), 0.999)) + return torch.tensor(betas, dtype=torch.float64) + else: + raise ValueError(f'Unsupported 
schedule: {schedule}') + + +class GaussianDiffusion(object): + + def __init__(self, + betas, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + rescale_timesteps=False): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type + self.var_type = var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( + 1.0 - self.alphas_cumprod) + + def q_sample(self, x0, t, noise=None): + noise = torch.randn_like(x0) if noise is None else noise + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i( + self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + + def q_mean_variance(self, x0, t): + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, + guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + shape = (-1, ) + ((1, ) * (xt.ndim - 1)) + mask = t.ne(0).float().view(*shape) # no noise when t == 0 + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process + for step in 
torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + assert self.mean_type == 'eps' + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + a = u_out[:, :3] + b = guide_scale * (y_out[:, :3] - u_out[:, :3]) + c = y_out[:, 3:] + out = torch.cat([a + b, c], dim=1) + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i( + torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, + xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + xt) * xt + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + 
a = (1 - alphas_prev) / (1 - alphas) + b = (1 - alphas / alphas_prev) + sigmas = eta * torch.sqrt(a * b) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas_next = _i( + torch.cat( + [self.alphas_cumprod, + self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + # prepare input + b, c, h, w = x0.size() + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, + percentile, guide_scale, + ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + - x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + 
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] + - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // plms_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, + guide_scale, plms_timesteps, + eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None): + noise = torch.randn_like(x0) if noise is None else noise + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, + model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([ + out.detach(), var + ], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound( + x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] + }[self.mean_type] + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2 + ).abs().flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, + x0, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = 
-discretized_gaussian_log_likelihood( + x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None): + # prepare input and output + b, c, h, w = x0.size() + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + noise = torch.randn_like(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound( + x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append( + (pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append( + (eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), + torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t diff --git a/modelscope/models/multi_modal/imagen/imagen_model.py b/modelscope/models/multi_modal/imagen/imagen_model.py new file mode 100644 index 00000000..e394ccf2 --- /dev/null +++ b/modelscope/models/multi_modal/imagen/imagen_model.py @@ -0,0 +1,255 @@ +import os.path as osp +from typing import Any, Dict + +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + +from modelscope.metainfo import Models +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.imagen.diffusion import (GaussianDiffusion, + beta_schedule) +from modelscope.models.multi_modal.imagen.structbert import (BertConfig, + BertModel) +from modelscope.models.multi_modal.imagen.tokenizer import FullTokenizer +from modelscope.models.multi_modal.imagen.unet_generator import ImagenGenerator +from modelscope.models.multi_modal.imagen.unet_imagen_upsampler_256 import \ + SuperResUNet256 +from modelscope.models.multi_modal.imagen.unet_upsampler_1024 import \ + ImagenUpsampler1024 +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['ImagenForTextToImageSynthesis'] + + +def make_diffusion(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None, + var_type='fixed_small'): + betas = beta_schedule(schedule, num_timesteps, init_beta, last_beta) + diffusion = GaussianDiffusion(betas, var_type=var_type) + return diffusion + + +class Tokenizer(object): + + def __init__(self, vocab_file, seq_len=64): + self.vocab_file = vocab_file + self.seq_len = seq_len + self.tokenizer = FullTokenizer( + vocab_file=vocab_file, do_lower_case=True) + + def __call__(self, text): + # tokenization + tokens = 
self.tokenizer.tokenize(text)
+        tokens = ['[CLS]'] + tokens[:self.seq_len - 2] + ['[SEP]']
+        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+        input_mask = [1] * len(input_ids)
+        segment_ids = [0] * len(input_ids)
+
+        # padding
+        input_ids += [0] * (self.seq_len - len(input_ids))
+        input_mask += [0] * (self.seq_len - len(input_mask))
+        segment_ids += [0] * (self.seq_len - len(segment_ids))
+        assert len(input_ids) == len(input_mask) == len(
+            segment_ids) == self.seq_len
+
+        # convert to tensors
+        input_ids = torch.LongTensor(input_ids)
+        input_mask = torch.LongTensor(input_mask)
+        segment_ids = torch.LongTensor(segment_ids)
+        return input_ids, segment_ids, input_mask
+
+
+class ImagenModel(nn.Module):
+
+    def __init__(self, model_dir):
+        super(ImagenModel, self).__init__()
+        # including text and generator config
+        model_config = json.load(
+            open('{}/imagen_config.json'.format(model_dir)))
+
+        # text encoder
+        text_config = model_config['text_config']
+        self.text_encoder = BertModel(BertConfig.from_dict(text_config))
+
+        # generator (64x64)
+        generator_config = model_config['generator_config']
+        self.unet_generator = ImagenGenerator(**generator_config)
+
+        # imagen upsampler (256x256)
+        imagen_upsampler_256_config = model_config[
+            'imagen_upsampler_256_config']
+        self.unet_imagen_upsampler_256 = SuperResUNet256(
+            **imagen_upsampler_256_config)
+
+        # dalle2 upsampler (1024x1024)
+        upsampler_1024_config = model_config['upsampler_1024_config']
+        self.unet_upsampler_1024 = ImagenUpsampler1024(**upsampler_1024_config)
+
+    def forward(self, noise, timesteps, input_ids, token_type_ids,
+                attention_mask):
+        context, y = self.text_encoder(
+            input_ids=input_ids,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask)
+        context = context[-1]
+        x = self.unet_generator(noise, timesteps, y, context, attention_mask)
+        x = self.unet_imagen_upsampler_256(noise, timesteps, x,
+                                           torch.zeros_like(timesteps), y,
+                                           context, attention_mask)
+        x = self.unet_upsampler_1024(x, timesteps, x)
+        return x
+
+
+@MODELS.register_module(
+    Tasks.text_to_image_synthesis, module_name=Models.imagen)
+class ImagenForTextToImageSynthesis(Model):
+
+    def __init__(self, model_dir, device_id=-1):
+        super().__init__(model_dir=model_dir, device_id=device_id)
+        imagen_model = ImagenModel(model_dir=model_dir)
+        pretrained_params = torch.load(
+            osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu')
+        imagen_model.load_state_dict(pretrained_params)
+        imagen_model.eval()
+
+        self.device_id = device_id
+        if self.device_id >= 0:
+            self.device = torch.device(f'cuda:{self.device_id}')
+            imagen_model.to('cuda:{}'.format(self.device_id))
+            logger.info('Use GPU: {}'.format(self.device_id))
+        else:
+            self.device = torch.device('cpu')
+            logger.info('Use CPU for inference')
+
+        # modules
+        self.text_encoder = imagen_model.text_encoder
+        self.unet_generator = imagen_model.unet_generator
+        self.unet_imagen_upsampler_256 = imagen_model.unet_imagen_upsampler_256
+        self.unet_upsampler_1024 = imagen_model.unet_upsampler_1024
+
+        # text tokenizer
+        vocab_path = '{}/vocab.txt'.format(model_dir)
+        self.tokenizer = Tokenizer(vocab_file=vocab_path, seq_len=64)
+
+        # diffusion process
+        diffusion_params = json.load(
+            open('{}/diffusion_config.json'.format(model_dir)))
+        self.diffusion_generator = make_diffusion(
+            **diffusion_params['generator_config'])
+        self.diffusion_imagen_upsampler_256 = make_diffusion(
+            **diffusion_params['imagen_upsampler_256_config'])
+        self.diffusion_upsampler_1024 = make_diffusion(
+            **diffusion_params['upsampler_1024_config'])
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        if not all([key in input for key in ('text', 'noise', 'timesteps')]):
+            raise ValueError(
+                f'input should contain "text", "noise", and "timesteps", but got {input.keys()}'
+            )
+        # noise and timesteps are expected as tensors in the validated input dict
+        noise = input['noise'].to(self.device)
+        timesteps = input['timesteps'].to(self.device)
+        input_ids, token_type_ids, attention_mask = self.tokenizer(
+            input['text'])
+        input_ids = input_ids.to(self.device).unsqueeze(0)
+        token_type_ids = token_type_ids.to(self.device).unsqueeze(0)
+        attention_mask = attention_mask.to(self.device).unsqueeze(0)
+        context, y = self.text_encoder(
+            input_ids=input_ids,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask)
+        context = context[-1]
+        x = self.unet_generator(noise, timesteps, y, context, attention_mask)
+        x = self.unet_imagen_upsampler_256(noise, timesteps, x,
+                                           torch.zeros_like(timesteps), y,
+                                           context, attention_mask)
+        x = self.unet_upsampler_1024(x, timesteps, x)
+        img = x.clamp(-1, 1).add(1).mul(127.5)
+        img = img.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
+        return img
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
+
+    @torch.no_grad()
+    def generate(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        if 'text' not in input:
+            raise ValueError(
+                f'input should contain "text", but got {input.keys()}')
+
+        # encode text
+        input_ids, token_type_ids, attention_mask = self.tokenizer(
+            input['text'])
+        input_ids = input_ids.to(self.device).unsqueeze(0)
+        token_type_ids = token_type_ids.to(self.device).unsqueeze(0)
+        attention_mask = attention_mask.to(self.device).unsqueeze(0)
+        context, y = self.text_encoder(
+            input_ids=input_ids,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask)
+        context = context[-1]
+
+        # generation (64x64 base sample with classifier-free guidance)
+        img = self.diffusion_generator.ddim_sample_loop(
+            noise=torch.randn(1, 3, 64, 64).to(self.device),
+            model=self.unet_generator,
+            model_kwargs=[{
+                'y': y,
+                'context': context,
+                'mask': attention_mask
+            }, {
+                'y': torch.zeros_like(y),
+                'context': torch.zeros_like(context),
+                'mask': attention_mask
+            }],
+            percentile=input.get('generator_percentile', 0.995),
+            guide_scale=input.get('generator_guide_scale', 5.0),
+            ddim_timesteps=input.get('generator_ddim_timesteps', 250),
+            eta=input.get('generator_ddim_eta', 0.0))
+
+        # upsampling (64->256)
+        img = F.interpolate(
+            img, scale_factor=4.0, mode='bilinear', align_corners=False)
+        img = self.diffusion_imagen_upsampler_256.ddim_sample_loop(
+            noise=torch.randn_like(img),
+            model=self.unet_imagen_upsampler_256,
+            model_kwargs=[{
+                'lx': img,
+                'lt': torch.zeros(1).to(self.device),
+                'y': y,
+                'context': context,
+                'mask': attention_mask
+            }, {
+                'lx': img,
+                'lt': torch.zeros(1).to(self.device),
+                'y': torch.zeros_like(y),
+                'context': torch.zeros_like(context),
+                'mask': torch.zeros_like(attention_mask)
+            }],
+            percentile=input.get('generator_percentile', 0.995),
+            guide_scale=input.get('generator_guide_scale', 5.0),
+            ddim_timesteps=input.get('generator_ddim_timesteps', 50),
+            eta=input.get('generator_ddim_eta', 0.0))
+
+        # upsampling (256->1024)
+        img = F.interpolate(
+            img, scale_factor=4.0, mode='bilinear', align_corners=False)
+        img = self.diffusion_upsampler_1024.ddim_sample_loop(
+            noise=torch.randn_like(img),
+            model=self.unet_upsampler_1024,
+            model_kwargs={'concat': img},
+            percentile=input.get('upsampler_1024_percentile', 0.995),
+            ddim_timesteps=input.get('upsampler_1024_ddim_timesteps', 20),
+            eta=input.get('upsampler_1024_ddim_eta', 0.0))
+
+        # output
+        img = img.clamp(-1,
1).add(1).mul(127.5).squeeze(0).permute( + 1, 2, 0).cpu().numpy().astype(np.uint8) + return img diff --git a/modelscope/models/multi_modal/imagen/structbert.py b/modelscope/models/multi_modal/imagen/structbert.py new file mode 100644 index 00000000..219e642f --- /dev/null +++ b/modelscope/models/multi_modal/imagen/structbert.py @@ -0,0 +1,936 @@ +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team and Alibaba inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from __future__ import absolute_import, division, print_function +import copy +import math + +import json +import numpy as np +import six +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +class BertConfig(object): + """Configuration class to store the configuration of a `BertModel`. + """ + + def __init__(self, + vocab_size, + hidden_size=768, + emb_size=-1, + num_hidden_layers=12, + transformer_type='original', + transition_function='linear', + weighted_transformer=0, + num_rolled_layers=3, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + attention_type='self', + rezero=False, + pre_ln=False, + squeeze_excitation=False, + transfer_matrix=False, + dim_dropout=False, + roberta_style=False, + set_mask_zero=False, + init_scale=False, + safer_fp16=False, + grad_checkpoint=False): + """Constructs BertConfig. + + Args: + vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. 
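+            emb_size: Optional factorized embedding width; when >= 0 the token and
+                position embeddings use this size and are projected to `hidden_size`.
+            pre_ln: Whether to apply LayerNorm before the attention and feed-forward
+                blocks (pre-LN) instead of after them.
+            rezero: Whether to use learnable residual scaling factors instead of
+                LayerNorm in the residual branches.
+            roberta_style: Whether to use RoBERTa-style padding index and position
+                ids (and drop token-type embeddings).
+            safer_fp16: Whether to compute LayerNorm in fp32 and return the raw
+                [CLS] state instead of the pooled output, for fp16 stability.
+            grad_checkpoint: Whether to run encoder layers under
+                torch.utils.checkpoint to save memory.
+            (Brief descriptions of these StructBERT-specific flags, inferred from
+            how they are used in the code below.)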
+ """ + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.emb_size = emb_size + self.num_hidden_layers = num_hidden_layers + self.transformer_type = transformer_type + self.transition_function = transition_function + self.weighted_transformer = weighted_transformer + self.num_rolled_layers = num_rolled_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.attention_type = attention_type + self.rezero = rezero + self.pre_ln = pre_ln + self.squeeze_excitation = squeeze_excitation + self.transfer_matrix = transfer_matrix + self.dim_dropout = dim_dropout + self.set_mask_zero = set_mask_zero + self.roberta_style = roberta_style + self.init_scale = init_scale + self.safer_fp16 = safer_fp16 + self.grad_checkpoint = grad_checkpoint + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size=None) + for (key, value) in six.iteritems(json_object): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, 'r') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' + + +class BERTLayerNorm(nn.Module): + + def __init__(self, config, variance_epsilon=1e-12, special_size=None): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BERTLayerNorm, self).__init__() + self.config = config + hidden_size = special_size if special_size is not None else config.hidden_size + self.gamma = nn.Parameter(torch.ones(hidden_size)) + self.beta = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = variance_epsilon if not config.roberta_style else 1e-5 + + def forward(self, x): + previous_type = x.type() + if self.config.safer_fp16: + x = x.float() + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + if self.config.safer_fp16: + return (self.gamma * x + self.beta).type(previous_type) + else: + return self.gamma * x + self.beta + + +class BERTEmbeddings(nn.Module): + + def __init__(self, config): + super(BERTEmbeddings, self).__init__() + """Construct the embedding module from word, position and token_type embeddings. 
+ """ + hidden_size = config.hidden_size if config.emb_size < 0 else config.emb_size + self.word_embeddings = nn.Embedding( + config.vocab_size, + hidden_size, + padding_idx=1 if config.roberta_style else None) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, + hidden_size, + padding_idx=1 if config.roberta_style else None) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + hidden_size) + self.config = config + self.proj = None if config.emb_size < 0 else nn.Linear( + config.emb_size, config.hidden_size) + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BERTLayerNorm(config, special_size=hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None, adv_embedding=None): + seq_length = input_ids.size(1) + if not self.config.roberta_style: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + else: + mask = input_ids.ne(1).int() + position_ids = (torch.cumsum(mask, dim=1).type_as(mask) + * mask).long() + 1 + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings( + input_ids) if adv_embedding is None else adv_embedding + if self.config.set_mask_zero: + words_embeddings[input_ids == 103] = 0. + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + if not self.config.roberta_style: + embeddings = words_embeddings + position_embeddings + token_type_embeddings + else: + embeddings = words_embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + if self.proj is not None: + embeddings = self.proj(embeddings) + embeddings = self.dropout(embeddings) + else: + return embeddings, words_embeddings + + +class BERTFactorizedAttention(nn.Module): + + def __init__(self, config): + super(BERTFactorizedAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x, *size): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(size) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer, 0, 2, 3, 1) + key_layer = self.transpose_for_scores(mixed_key_layer, 0, 2, 1, 3) + value_layer = self.transpose_for_scores(mixed_value_layer, 0, 2, 1, 3) + + s_attention_scores = query_layer + attention_mask + s_attention_probs = 
nn.Softmax(dim=-1)(s_attention_scores) + s_attention_probs = self.dropout(s_attention_probs) + + c_attention_probs = nn.Softmax(dim=-1)(key_layer) + s_context_layer = torch.matmul(s_attention_probs, value_layer) + context_layer = torch.matmul(c_attention_probs, s_context_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +def dim_dropout(x, p=0, dim=-1, training=False): + if not training or p == 0: + return x + a = (1 - p) + b = (x.data.new(x.size()).zero_() + 1) + dropout_mask = torch.bernoulli(a * b) + return dropout_mask * (dropout_mask.size(dim) / torch.sum( + dropout_mask, dim=dim, keepdim=True)) * x + + +class BERTSelfAttention(nn.Module): + + def __init__(self, config): + super(BERTSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + 'The hidden size (%d) is not a multiple of the number of attention ' + 'heads (%d)' % + (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size + / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.config = config + if config.pre_ln: + self.LayerNorm = BERTLayerNorm(config) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, head_mask=None): + if self.config.pre_ln: + hidden_states = self.LayerNorm(hidden_states) + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + if head_mask is not None and not self.training: + for i, mask in enumerate(head_mask): + if head_mask[i] == 1: + attention_scores[:, i, :, :] = 0. + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
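+        # When `dim_dropout` is enabled, the helper defined above draws a
+        # Bernoulli keep mask with probability (1 - p) and rescales the kept
+        # entries by dim_size / num_kept along `dim`, so each row of attention
+        # weights keeps roughly its original total mass (like inverted dropout,
+        # but normalized by the realized keep count).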
+ if not self.config.dim_dropout: + attention_probs = self.dropout(attention_probs) + else: + attention_probs = dim_dropout( + attention_probs, + p=self.config.attention_probs_dropout_prob, + dim=-1, + training=self.training) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BERTSelfOutput(nn.Module): + + def __init__(self, config): + super(BERTSelfOutput, self).__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if not config.pre_ln and not config.rezero: + self.LayerNorm = BERTLayerNorm(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if config.rezero: + self.res_factor = nn.Parameter( + torch.Tensor(1).fill_(0.99).to( + dtype=next(self.parameters()).dtype)) + self.factor = nn.Parameter( + torch.ones(1).to(dtype=next(self.parameters()).dtype)) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + if not self.config.rezero and not self.config.pre_ln: + hidden_states = self.LayerNorm(hidden_states + input_tensor) + elif self.config.rezero: + hidden_states = hidden_states + self.factor * input_tensor + else: + pass + return hidden_states + + +class BERTAttention(nn.Module): + + def __init__(self, config): + super(BERTAttention, self).__init__() + if config.attention_type.lower() == 'self': + self.self = BERTSelfAttention(config) + elif config.attention_type.lower() == 'factorized': + self.self = BERTFactorizedAttention(config) + else: + raise ValueError( + 'Attention type must in [self, factorized], but got {}'.format( + config.attention_type)) + self.output = BERTSelfOutput(config) + + def forward(self, input_tensor, attention_mask, head_mask=None): + self_output = self.self(input_tensor, attention_mask, head_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class DepthwiseSeparableConv1d(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=False): + super(DepthwiseSeparableConv1d, self).__init__() + padding = (kernel_size - 1) // 2 + self.depthwise = nn.Conv1d( + in_channels, + in_channels, + kernel_size, + stride, + padding, + dilation, + groups=in_channels, + bias=bias) + self.pointwise = nn.Conv1d( + in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) + + def forward(self, x): + x = self.depthwise(x) + x = self.pointwise(x) + return x + + +class BERTIntermediate(nn.Module): + + def __init__(self, config): + super(BERTIntermediate, self).__init__() + self.config = config + if self.config.pre_ln: + self.LayerNorm = BERTLayerNorm(config) + self.intermediate_act_fn = gelu + if config.transition_function.lower() == 'linear': + self.dense = nn.Linear(config.hidden_size, + config.intermediate_size) + elif config.transition_function.lower() == 'cnn': + self.cnn = DepthwiseSeparableConv1d( + config.hidden_size, 4 * config.hidden_size, kernel_size=7) + elif config.config.hidden_size.lower() == 'rnn': + raise NotImplementedError( + 'rnn transition function is not implemented yet') + else: + raise ValueError('Only support linear/cnn/rnn') + + def forward(self, hidden_states): + if self.config.pre_ln: + hidden_states = self.LayerNorm(hidden_states) + if 
self.config.transition_function.lower() == 'linear': + hidden_states = self.dense(hidden_states) + elif self.config.transition_function.lower() == 'cnn': + hidden_states = self.cnn(hidden_states.transpose(-1, + -2)).transpose( + -1, -2) + else: + pass + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SqueezeExcitationBlock(nn.Module): + + def __init__(self, config): + super(SqueezeExcitationBlock, self).__init__() + self.down_sampling = nn.Linear(config.hidden_size, + config.hidden_size // 4) + self.up_sampling = nn.Linear(config.hidden_size // 4, + config.hidden_size) + + def forward(self, hidden_states): + squeeze = torch.mean(hidden_states, 1, keepdim=True) + excitation = torch.sigmoid( + self.up_sampling(gelu(self.down_sampling(squeeze)))) + return hidden_states * excitation + + +class BERTOutput(nn.Module): + + def __init__(self, config): + super(BERTOutput, self).__init__() + self.config = config + if config.transition_function.lower() == 'linear': + self.dense = nn.Linear(config.intermediate_size, + config.hidden_size) + elif config.transition_function.lower() == 'cnn': + self.cnn = DepthwiseSeparableConv1d( + 4 * config.hidden_size, config.hidden_size, kernel_size=7) + elif config.config.hidden_size.lower() == 'rnn': + raise NotImplementedError( + 'rnn transition function is not implemented yet') + else: + raise ValueError('Only support linear/cnn/rnn') + if not config.pre_ln and not config.rezero: + self.LayerNorm = BERTLayerNorm(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if config.squeeze_excitation: + self.SEblock = SqueezeExcitationBlock(config) + if config.rezero: + self.res_factor = nn.Parameter( + torch.Tensor(1).fill_(0.99).to( + dtype=next(self.parameters()).dtype)) + self.factor = nn.Parameter( + torch.ones(1).to(dtype=next(self.parameters()).dtype)) + + def forward(self, hidden_states, input_tensor): + if self.config.transition_function.lower() == 'linear': + hidden_states = self.dense(hidden_states) + elif self.config.transition_function.lower() == 'cnn': + hidden_states = self.cnn(hidden_states.transpose(-1, + -2)).transpose( + -1, -2) + else: + pass + hidden_states = self.dropout(hidden_states) + if self.config.squeeze_excitation: + hidden_states = self.SEblock(hidden_states) + if not self.config.rezero and not self.config.pre_ln: + hidden_states = self.LayerNorm(hidden_states + input_tensor) + elif self.config.rezero: + hidden_states = hidden_states + self.factor * input_tensor + else: + pass + return hidden_states + + +class BERTLayer(nn.Module): + + def __init__(self, config): + super(BERTLayer, self).__init__() + self.attention = BERTAttention(config) + self.intermediate = BERTIntermediate(config) + self.output = BERTOutput(config) + + def forward(self, hidden_states, attention_mask, head_mask=None): + attention_output = self.attention(hidden_states, attention_mask, + head_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return attention_output, layer_output + + +class BERTWeightedLayer(nn.Module): + + def __init__(self, config): + super(BERTWeightedLayer, self).__init__() + self.config = config + self.self = BERTSelfAttention(config) + self.attention_head_size = self.self.attention_head_size + + self.w_o = nn.ModuleList([ + nn.Linear(self.attention_head_size, config.hidden_size) + for _ in range(config.num_attention_heads) + ]) + self.w_kp = torch.rand(config.num_attention_heads) + self.w_kp = 
nn.Parameter(self.w_kp / self.w_kp.sum()) + self.w_a = torch.rand(config.num_attention_heads) + self.w_a = nn.Parameter(self.w_a / self.w_a.sum()) + + self.intermediate = BERTIntermediate(config) + self.output = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BERTLayerNorm(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, attention_mask): + self_output = self.self(hidden_states, attention_mask) + self_outputs = self_output.split(self.self.attention_head_size, dim=-1) + self_outputs = [ + self.w_o[i](self_outputs[i]) for i in range(len(self_outputs)) + ] + self_outputs = [ + self.dropout(self_outputs[i]) for i in range(len(self_outputs)) + ] + self_outputs = [ + kappa * output for kappa, output in zip(self.w_kp, self_outputs) + ] + self_outputs = [ + self.intermediate(self_outputs[i]) + for i in range(len(self_outputs)) + ] + self_outputs = [ + self.output(self_outputs[i]) for i in range(len(self_outputs)) + ] + self_outputs = [ + self.dropout(self_outputs[i]) for i in range(len(self_outputs)) + ] + self_outputs = [ + alpha * output for alpha, output in zip(self.w_a, self_outputs) + ] + output = sum(self_outputs) + return self.LayerNorm(hidden_states + output) + + +class BERTEncoder(nn.Module): + + def __init__(self, config): + super(BERTEncoder, self).__init__() + self.layer = nn.ModuleList() + for _ in range(config.num_hidden_layers): + if config.weighted_transformer: + self.layer.append(BERTWeightedLayer(config)) + else: + self.layer.append(BERTLayer(config)) + if config.rezero: + for index, layer in enumerate(self.layer): + layer.output.res_factor = nn.Parameter( + torch.Tensor(1).fill_(1.).to( + dtype=next(self.parameters()).dtype)) + layer.output.factor = nn.Parameter( + torch.Tensor(1).fill_(1).to( + dtype=next(self.parameters()).dtype)) + layer.attention.output.res_factor = layer.output.res_factor + layer.attention.output.factor = layer.output.factor + self.config = config + + def forward(self, + hidden_states, + attention_mask, + epoch_id=-1, + head_masks=None): + all_encoder_layers = [hidden_states] + if epoch_id != -1: + detach_index = int(len(self.layer) / 3) * (2 - epoch_id) - 1 + else: + detach_index = -1 + for index, layer_module in enumerate(self.layer): + if head_masks is None: + if not self.config.grad_checkpoint: + self_out, hidden_states = layer_module( + hidden_states, attention_mask, None) + else: + self_out, hidden_states = torch.utils.checkpoint.checkpoint( + layer_module, hidden_states, attention_mask, None) + else: + self_out, hidden_states = layer_module(hidden_states, + attention_mask, + head_masks[index]) + if detach_index == index: + hidden_states.detach_() + all_encoder_layers.append(self_out) + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BERTEncoderRolled(nn.Module): + + def __init__(self, config): + super(BERTEncoderRolled, self).__init__() + layer = BERTLayer(config) + self.config = config + self.layer = nn.ModuleList( + [copy.deepcopy(layer) for _ in range(config.num_rolled_layers)]) + + def forward(self, + hidden_states, + attention_mask, + epoch_id=-1, + head_masks=None): + all_encoder_layers = [hidden_states] + for i in range(self.config.num_hidden_layers): + if self.config.transformer_type.lower() == 'universal': + hidden_states = self.layer[i % self.config.num_rolled_layers]( + hidden_states, attention_mask) + elif self.config.transformer_type.lower() == 'albert': + a = i // ( + self.config.num_hidden_layers + // 
self.config.num_rolled_layers) + hidden_states = self.layer[a](hidden_states, attention_mask) + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BERTEncoderACT(nn.Module): + + def __init__(self, config): + super(BERTEncoderACT, self).__init__() + self.layer = BERTLayer(config) + p = nn.Linear(config.hidden_size, 1) + self.p = nn.ModuleList( + [copy.deepcopy(p) for _ in range(config.num_hidden_layers)]) + # Following act paper, set bias init ones + for module in self.p: + module.bias.data.fill_(1.) + self.config = config + self.act_max_steps = config.num_hidden_layers + self.threshold = 0.99 + + def should_continue(self, halting_probability, n_updates): + return (halting_probability.lt(self.threshold).__and__( + n_updates.lt(self.act_max_steps))).any() + + def forward(self, hidden_states, attention_mask): + all_encoder_layers = [hidden_states] + batch_size, seq_len, hdim = hidden_states.size() + halting_probability = torch.zeros(batch_size, seq_len).cuda() + remainders = torch.zeros(batch_size, seq_len).cuda() + n_updates = torch.zeros(batch_size, seq_len).cuda() + for i in range(self.act_max_steps): + p = torch.sigmoid(self.p[i](hidden_states).squeeze(2)) + still_running = halting_probability.lt(1.0).float() + new_halted = (halting_probability + p * still_running).gt( + self.threshold).float() * still_running + still_running = (halting_probability + p * still_running).le( + self.threshold).float() * still_running + halting_probability = halting_probability + p * still_running + remainders = remainders + new_halted * (1 - halting_probability) + halting_probability = halting_probability + new_halted * remainders + n_updates = n_updates + still_running + new_halted + update_weights = (p * still_running + + new_halted * remainders).unsqueeze(2) + transformed_states = self.layer(hidden_states, attention_mask) + hidden_states = transformed_states * update_weights + hidden_states * ( + 1 - update_weights) + all_encoder_layers.append(hidden_states) + if not self.should_continue(halting_probability, n_updates): + break + return all_encoder_layers, torch.mean(n_updates + remainders) + + +class BERTPooler(nn.Module): + + def __init__(self, config): + super(BERTPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertModel(nn.Module): + """BERT model ("Bidirectional Embedding Representations from a Transformer"). + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) + + config = modeling.BertConfig(vocab_size=32000, hidden_size=512, + num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) + + model = modeling.BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config: BertConfig): + """Constructor for BertModel. + + Args: + config: `BertConfig` instance. 
+ """ + super(BertModel, self).__init__() + self.config = config + self.embeddings = BERTEmbeddings(config) + if config.transformer_type.lower() == 'original': + self.encoder = BERTEncoder(config) + elif config.transformer_type.lower() == 'universal': + self.encoder = BERTEncoderRolled(config) + elif config.transformer_type.lower() == 'albert': + self.encoder = BERTEncoderRolled(config) + elif config.transformer_type.lower() == 'act': + self.encoder = BERTEncoderACT(config) + elif config.transformer_type.lower() == 'textnas': + from textnas_final import op_dict, input_dict, skip_dict + self.encoder = TextNASEncoder(config, op_dict, input_dict, + skip_dict) + else: + raise ValueError('Not support transformer type: {}'.format( + config.transformer_type.lower())) + self.pooler = BERTPooler(config) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + epoch_id=-1, + head_masks=None, + adv_embedding=None): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output, word_embeddings = self.embeddings( + input_ids, token_type_ids, adv_embedding) + if self.config.transformer_type.lower() == 'act': + all_encoder_layers, act_loss = self.encoder( + embedding_output, extended_attention_mask) + elif self.config.transformer_type.lower() == 'reformer': + sequence_output = self.encoder(embedding_output) + all_encoder_layers = [sequence_output, sequence_output] + else: + all_encoder_layers = self.encoder(embedding_output, + extended_attention_mask, + epoch_id, head_masks) + all_encoder_layers.insert(0, word_embeddings) + sequence_output = all_encoder_layers[-1] + if not self.config.safer_fp16: + pooled_output = self.pooler(sequence_output) + else: + pooled_output = sequence_output[:, 0] + return all_encoder_layers, pooled_output + + +class BertForSequenceClassificationMultiTask(nn.Module): + """BERT model for classification. + This module is composed of the BERT model with a linear layer on top of + the pooled output. 
+ + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) + + config = BertConfig(vocab_size=32000, hidden_size=512, + num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) + + num_labels = 2 + + model = BertForSequenceClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config, label_list, core_encoder): + super(BertForSequenceClassificationMultiTask, self).__init__() + if core_encoder.lower() == 'bert': + self.bert = BertModel(config) + elif core_encoder.lower() == 'lstm': + self.bert = LSTMModel(config) + else: + raise ValueError( + 'Only support lstm or bert, but got {}'.format(core_encoder)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.ModuleList() + for label in label_list: + self.classifier.append(nn.Linear(config.hidden_size, len(label))) + self.label_list = label_list + + def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=config.initializer_range) + elif isinstance(module, BERTLayerNorm): + module.beta.data.normal_( + mean=0.0, std=config.initializer_range) + module.gamma.data.normal_( + mean=0.0, std=config.initializer_range) + if isinstance(module, nn.Linear): + module.bias.data.zero_() + + self.apply(init_weights) + + def forward(self, + input_ids, + token_type_ids, + attention_mask, + labels=None, + labels_index=None, + epoch_id=-1, + head_masks=None, + adv_embedding=None, + return_embedding=False, + loss_weight=None): + all_encoder_layers, pooled_output = self.bert(input_ids, + token_type_ids, + attention_mask, epoch_id, + head_masks, + adv_embedding) + pooled_output = self.dropout(pooled_output) + logits = [classifier(pooled_output) for classifier in self.classifier] + if labels is not None: + loss_fct = CrossEntropyLoss(reduction='none') + regression_loss_fct = nn.MSELoss(reduction='none') + labels_lst = torch.unbind(labels, 1) + loss_lst = [] + for index, (label, logit) in enumerate(zip(labels_lst, logits)): + if len(self.label_list[index]) != 1: + loss = loss_fct(logit, label.long()) + else: + loss = regression_loss_fct(logit.squeeze(-1), label) + labels_mask = (labels_index == index).to( + dtype=next(self.parameters()).dtype) + if loss_weight is not None: + loss = loss * loss_weight[index] + loss = torch.mean(loss * labels_mask) + loss_lst.append(loss) + if not return_embedding: + return sum(loss_lst), logits + else: + return sum(loss_lst), logits, all_encoder_layers[0] + else: + return logits diff --git a/modelscope/models/multi_modal/imagen/tokenizer.py b/modelscope/models/multi_modal/imagen/tokenizer.py new file mode 100644 index 00000000..82c09661 --- /dev/null +++ b/modelscope/models/multi_modal/imagen/tokenizer.py @@ -0,0 +1,333 @@ +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team and Alibaba inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
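# --- Editor's sketch (not part of the patch): the per-task loss masking used by
# BertForSequenceClassificationMultiTask.forward above. Labels carry one column per task,
# every head computes an unreduced loss over the whole batch, and a (labels_index == task)
# mask zeroes out examples that belong to other tasks before averaging. All tensors below
# are toy values.
import torch
import torch.nn as nn

logits_per_task = [torch.randn(4, 3), torch.randn(4, 2)]   # two heads: 3-way and 2-way
labels = torch.tensor([[2, 0], [0, 1], [1, 0], [1, 1]])    # [B, num_tasks], one column per head
labels_index = torch.tensor([0, 1, 0, 1])                  # task id of each example
loss_fct = nn.CrossEntropyLoss(reduction='none')

total = 0.0
for index, (label, logit) in enumerate(zip(torch.unbind(labels, 1), logits_per_task)):
    loss = loss_fct(logit, label)                          # [B], computed for every example
    mask = (labels_index == index).float()
    total = total + torch.mean(loss * mask)                # other tasks' examples contribute 0
print(total)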
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes.""" + +from __future__ import absolute_import, division, print_function +import collections +import unicodedata + +import six + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode('utf-8', 'ignore') + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode('utf-8', 'ignore') + elif isinstance(text, unicode): + return text + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + else: + raise ValueError('Not running on Python2 or Python 3?') + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode('utf-8', 'ignore') + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): + return text.encode('utf-8') + else: + raise ValueError('Unsupported string type: %s' % (type(text))) + else: + raise ValueError('Not running on Python2 or Python 3?') + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, 'r') as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def convert_tokens_to_ids(vocab, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(vocab[token]) + return ids + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a peice of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class FullTokenizer(object): + """Runs end-to-end tokenziation.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_tokens_to_ids(self, tokens): + return convert_tokens_to_ids(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return [self.inv_vocab[i] for i in ids] + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. 
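# --- Editor's usage sketch for the FullTokenizer defined above; 'vocab.txt' and its contents
# are throwaway placeholders, not files shipped with this patch.
with open('vocab.txt', 'w') as f:
    f.write('\n'.join(['[UNK]', 'un', '##aff', '##able', 'hello', ',', 'world', '!']))

tokenizer = FullTokenizer(vocab_file='vocab.txt', do_lower_case=True)
tokens = tokenizer.tokenize('Hello, unaffable world!')
print(tokens)                                    # ['hello', ',', 'un', '##aff', '##able', 'world', '!']
print(tokenizer.convert_tokens_to_ids(tokens))   # ids in the order they appear in vocab.txt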
+ + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(' '.join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize('NFD', text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == 'Mn': + continue + output.append(char) + return ''.join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return [''.join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(' ') + output.append(char) + output.append(' ') + else: + output.append(char) + return ''.join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. 
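# --- Editor's sketch (not part of the patch): the effect of the CJK handling above.
# _tokenize_chinese_chars pads every CJK codepoint with spaces, so each Chinese character
# becomes its own token, while Latin-script words are left whole for WordPiece.
basic = BasicTokenizer(do_lower_case=True)
print(basic.tokenize('ModelScope生成宇航员图像'))
# ['modelscope', '生', '成', '宇', '航', '员', '图', '像']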
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF) + or (cp >= 0x20000 and cp <= 0x2A6DF) + or (cp >= 0x2A700 and cp <= 0x2B73F) + or (cp >= 0x2B740 and cp <= 0x2B81F) + or (cp >= 0x2B820 and cp <= 0x2CEAF) + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F)): + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(' ') + else: + output.append(char) + return ''.join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. + """ + + text = convert_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = ''.join(chars[start:end]) + if start > 0: + substr = '##' + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == ' ' or char == '\t' or char == '\n' or char == '\r': + return True + cat = unicodedata.category(char) + if cat == 'Zs': + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == '\t' or char == '\n' or char == '\r': + return False + cat = unicodedata.category(char) + if cat.startswith('C'): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. 
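# --- Editor's sketch (not part of the patch): the greedy longest-match-first behaviour of
# WordpieceTokenizer above, with a throwaway in-memory vocab (a plain dict works because
# only membership is tested).
toy_vocab = {'[UNK]': 0, 'un': 1, '##aff': 2, '##able': 3, 'runnable': 4}
wp = WordpieceTokenizer(vocab=toy_vocab)
print(wp.tokenize('runnable'))       # ['runnable']  -- the whole word matches first
print(wp.tokenize('unaffable'))      # ['un', '##aff', '##able']
print(wp.tokenize('xyz unaffable'))  # ['[UNK]', 'un', '##aff', '##able']  -- no piece covers 'xyz'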
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) + or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith('P'): + return True + return False diff --git a/modelscope/models/multi_modal/imagen/unet_generator.py b/modelscope/models/multi_modal/imagen/unet_generator.py new file mode 100644 index 00000000..2b780a36 --- /dev/null +++ b/modelscope/models/multi_modal/imagen/unet_generator.py @@ -0,0 +1,319 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['ImagenGenerator'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, scale_factor, use_conv=False): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.scale_factor = scale_factor + self.use_conv = use_conv + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(in_dim, out_dim, 3, padding=1) + if use_conv else nn.Identity()) + elif scale_factor == 0.5: + self.resample = nn.Conv2d( + in_dim, out_dim, 3, stride=2, + padding=1) if use_conv else nn.AvgPool2d( + kernel_size=2, stride=2) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + scale_factor=1.0, + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.scale_factor = scale_factor + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e): + identity = self.resample(x) + x = self.layer1[-1](self.resample(self.layer1[:-1](x))) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = 
num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None, mask=None): + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, + d).permute(0, 2, 3, + 1).chunk( + 2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + + # compute attention + attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) + if mask is not None: + assert context is not None + full_mask = x.new_ones((b, 1, q.size(-1), k.size(-1))) + full_mask[:, 0, :, :-q.size(-1)] = mask.unsqueeze(1) + attn = attn.masked_fill(full_mask == 0, float('-inf')) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.matmul(v, attn.transpose(-1, -2)) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class ImagenGenerator(nn.Module): + + def __init__(self, + in_dim=3, + dim=512, + text_dim=1024, + context_dim=512, + out_dim=6, + dim_mult=[1, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=3, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + resblock_resample=True, + use_scale_shift_norm=True, + dropout=0.0): + embed_dim = dim * 4 + super(ImagenGenerator, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.text_dim = text_dim + self.context_dim = context_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.resblock_resample = resblock_resample + self.use_scale_shift_norm = use_scale_shift_norm + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.pool_embedding = nn.Sequential( + nn.LayerNorm(text_dim), nn.Linear(text_dim, embed_dim)) + self.text_embedding = nn.Sequential( + nn.LayerNorm(text_dim), nn.Linear(text_dim, context_dim), + nn.SiLU(), nn.Linear(context_dim, context_dim)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList( + [ResidualBlock(in_dim, embed_dim, out_dim, dropout)]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + self.encoder.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + if resblock_resample: + downsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 0.5, + dropout) + else: + downsample = Resample( + out_dim, out_dim, 0.5, use_conv=True) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.encoder.append(downsample) + + # middle + self.middle = nn.ModuleList([ 
+ ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout), + AttentionBlock(out_dim, context_dim, num_heads, head_dim), + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + if resblock_resample: + upsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 2.0, + dropout) + else: + upsample = Resample( + out_dim, out_dim, 2.0, use_conv=True) + scale *= 2.0 + block.append(upsample) + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, y, context, mask=None): + # embeddings + e = self.time_embedding(sinusoidal_embedding( + t, self.dim)) + self.pool_embedding(y) + context = self.text_embedding(context) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, e, context, mask) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, e, context, mask) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e, context, mask) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, e, context, mask): + if isinstance(module, ResidualBlock): + x = module(x, e) + elif isinstance(module, AttentionBlock): + x = module(x, context, mask) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context, mask) + else: + x = module(x) + return x diff --git a/modelscope/models/multi_modal/imagen/unet_imagen_upsampler_256.py b/modelscope/models/multi_modal/imagen/unet_imagen_upsampler_256.py new file mode 100644 index 00000000..0da8b805 --- /dev/null +++ b/modelscope/models/multi_modal/imagen/unet_imagen_upsampler_256.py @@ -0,0 +1,337 @@ +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['SuperResUNet256'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, scale_factor, use_conv=False): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.scale_factor = scale_factor + self.use_conv = use_conv + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(in_dim, out_dim, 3, padding=1) + if use_conv else nn.Identity()) + elif scale_factor == 0.5: + self.resample 
= nn.Conv2d( + in_dim, out_dim, 3, stride=2, + padding=1) if use_conv else nn.AvgPool2d( + kernel_size=2, stride=2) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + scale_factor=1.0, + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.scale_factor = scale_factor + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample_x = Resample(in_dim, in_dim, scale_factor) + self.resample_i = Resample(in_dim, in_dim, scale_factor) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e): + identity = self.resample_i(x) + x = self.layer1[-1](self.resample_x(self.layer1[:-1](x))) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None, mask=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. + mask: [B, L] or None. 
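# --- Editor's sketch (not part of the patch) of the attention pattern shared by the
# AttentionBlock modules in these UNets: spatial features become per-head tokens of shape
# [B, heads, head_dim, H*W]; when text context is given, its projected keys/values are
# concatenated along the token axis so image queries attend to both image and text tokens.
# The projections themselves (to_qkv, context_kv) are omitted; inputs are assumed to be
# already projected dummy tensors.
import torch
import torch.nn.functional as F

def toy_attention(x, context_kv=None):
    b, n, d, l = x.shape                  # [B, heads, head_dim, H*W]
    scale = d ** -0.25                    # same split scaling as the blocks above
    q = k = v = x
    if context_kv is not None:
        ck, cv = context_kv               # each [B, heads, head_dim, L_text]
        k = torch.cat([ck, k], dim=-1)
        v = torch.cat([cv, v], dim=-1)
    attn = torch.matmul(q.transpose(-1, -2) * scale, k * scale)  # [B, heads, H*W, L_kv]
    attn = F.softmax(attn, dim=-1)
    out = torch.matmul(v, attn.transpose(-1, -2))                # [B, heads, head_dim, H*W]
    return out

x = torch.randn(2, 8, 64, 16 * 16)
ctx = (torch.randn(2, 8, 64, 77), torch.randn(2, 8, 64, 77))
print(toy_attention(x, ctx).shape)        # torch.Size([2, 8, 64, 256])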
+ """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, + d).permute(0, 2, 3, + 1).chunk( + 2, dim=1) + k = torch.cat([k, ck], dim=-1) + v = torch.cat([v, cv], dim=-1) + + # compute attention + attn = torch.einsum('bndi,bndj->bnij', q * self.scale, k * self.scale) + if mask is not None: + pad_mask = mask.new_ones((b, 1, 1, h * w)) + mask = torch.cat((pad_mask, mask.unsqueeze(1).unsqueeze(1)), + dim=-1) + attn = attn.masked_fill(mask == 0, float('-inf')) + + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.einsum('bnij,bndj->bndi', attn, v) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class SuperResUNet256(nn.Module): + + def __init__(self, + in_dim=6, + out_dim=3, + dim=256, + text_dim=1024, + context_dim=512, + dim_mult=[1, 2, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=2, + attn_scales=[1 / 16], + resblock_resample=True, + use_conv=True, + use_scale_shift_norm=True, + dropout=0.1): + embed_dim = dim * 4 + super(SuperResUNet256, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.resblock_resample = resblock_resample + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.noise_time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.pool_embedding = nn.Sequential( + nn.LayerNorm(text_dim), nn.Linear(text_dim, embed_dim)) + self.text_embedding = nn.Sequential( + nn.LayerNorm(text_dim), nn.Linear(text_dim, context_dim), + nn.SiLU(), nn.Linear(context_dim, context_dim)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim, embed_dim, out_dim, + use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + shortcut_dims.append(out_dim) + in_dim = out_dim + self.encoder.append(block) + + # downsample + if i != len(dim_mult) - 1: + if resblock_resample: + downsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 0.5, + dropout) + else: + downsample = Resample(out_dim, out_dim, 0.5, use_conv) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.encoder.append(downsample) + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout), + AttentionBlock(out_dim, context_dim, num_heads, head_dim), + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in 
range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + if resblock_resample: + upsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 2.0, + dropout) + else: + upsample = Resample(out_dim, out_dim, 2.0, use_conv) + scale *= 2.0 + block.append(upsample) + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, lx, lt, y, context, mask): + assert context.shape[:-1] == mask.shape + + # embeddings + t = self.time_embedding(sinusoidal_embedding(t, self.dim)) \ + + self.noise_time_embedding(sinusoidal_embedding(lt, self.dim)) \ + + self.pool_embedding(y) + + context = self.text_embedding(context) + + if lx.shape[-2:] != x.shape[-2:]: + lx = F.interpolate( + lx, x.shape[-2:], mode='bilinear', align_corners=False) + x = torch.cat([x, lx], dim=1) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, t, context, mask) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, t, context, mask) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, t, context, mask) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, t, context, mask): + if isinstance(module, ResidualBlock): + x = module(x, t) + elif isinstance(module, AttentionBlock): + x = module(x, context, mask) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, t, context, mask) + else: + x = module(x) + return x diff --git a/modelscope/models/multi_modal/imagen/unet_upsampler_1024.py b/modelscope/models/multi_modal/imagen/unet_upsampler_1024.py new file mode 100644 index 00000000..07d3648c --- /dev/null +++ b/modelscope/models/multi_modal/imagen/unet_upsampler_1024.py @@ -0,0 +1,240 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['ImagenUpsampler1024'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, scale_factor, use_conv=False): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.scale_factor = scale_factor + self.use_conv = use_conv + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(in_dim, out_dim, 3, padding=1) + if use_conv else nn.Identity()) + elif scale_factor == 0.5: + self.resample = nn.Conv2d( + in_dim, out_dim, 3, stride=2, + padding=1) if use_conv else nn.AvgPool2d( + kernel_size=2, stride=2) + else: + self.resample = 
nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + scale_factor=1.0, + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.scale_factor = scale_factor + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e): + identity = self.resample(x) + x = self.layer1[-1](self.resample(self.layer1[:-1](x))) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class ImagenUpsampler1024(nn.Module): + + def __init__(self, + in_dim=6, + dim=192, + out_dim=3, + dim_mult=[1, 1, 2, 2, 4, 4], + num_res_blocks=2, + resblock_resample=True, + use_scale_shift_norm=True, + dropout=0.0): + embed_dim = dim * 4 + super(ImagenUpsampler1024, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.resblock_resample = resblock_resample + self.use_scale_shift_norm = use_scale_shift_norm + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embedding + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual block + block = nn.ModuleList([ + ResidualBlock(in_dim, embed_dim, out_dim, + use_scale_shift_norm, 1.0, dropout) + ]) + shortcut_dims.append(out_dim) + in_dim = out_dim + self.encoder.append(block) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + if resblock_resample: + downsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 0.5, + dropout) + else: + downsample = Resample( + out_dim, out_dim, 0.5, use_conv=True) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.encoder.append(downsample) + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout), + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual block + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + 
out_dim, use_scale_shift_norm, 1.0, dropout) + ]) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + if resblock_resample: + upsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 2.0, + dropout) + else: + upsample = Resample( + out_dim, out_dim, 2.0, use_conv=True) + scale *= 2.0 + block.append(upsample) + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, concat): + # embedding + if concat is not None: + if concat.shape[-2:] != x.shape[-2:]: + concat = F.interpolate( + concat, x.shape[-2:], mode='bilinear', align_corners=False) + x = torch.cat([x, concat], dim=1) + e = self.time_embedding(sinusoidal_embedding(t, self.dim)) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, e) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, e) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, e): + if isinstance(module, ResidualBlock): + x = module(x, e) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e) + else: + x = module(x) + return x diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py index 49b07cce..76c22238 100644 --- a/modelscope/pipelines/multi_modal/__init__.py +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -1,6 +1,7 @@ try: from .image_captioning_pipeline import ImageCaptionPipeline from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline + from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline except ModuleNotFoundError as e: if str(e) == "No module named 'torch'": diff --git a/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py b/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py new file mode 100644 index 00000000..edffe1f2 --- /dev/null +++ b/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py @@ -0,0 +1,37 @@ +from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Input +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from ..base import Model, Pipeline +from ..builder import PIPELINES +from ..outputs import OutputKeys + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.text_to_image_synthesis, + module_name=Pipelines.text_to_image_synthesis) +class TextToImageSynthesisPipeline(Pipeline): + + def __init__(self, model: str, device_id: int = -1): + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError( + f'execpting a Model instance or str, but get {type(model)}.') + + super().__init__(model=pipe_model) + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return self.model.generate(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return {OutputKeys.OUTPUT_IMG: inputs} diff --git 
a/tests/pipelines/test_text_to_image_synthesis.py b/tests/pipelines/test_text_to_image_synthesis.py new file mode 100644 index 00000000..d5ce990d --- /dev/null +++ b/tests/pipelines/test_text_to_image_synthesis.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import numpy as np + +from modelscope.models import Model +from modelscope.pipelines import pipeline +from modelscope.pipelines.outputs import OutputKeys +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TextToImageSynthesisTest(unittest.TestCase): + model_id = 'damo/cv_imagen_text-to-image-synthesis_tiny' + test_text = {'text': '宇航员'} + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + pipe_line_text_to_image_synthesis = pipeline( + task=Tasks.text_to_image_synthesis, model=model) + img = pipe_line_text_to_image_synthesis( + self.test_text)[OutputKeys.OUTPUT_IMG] + print(np.sum(np.abs(img))) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipe_line_text_to_image_synthesis = pipeline( + task=Tasks.text_to_image_synthesis, model=self.model_id) + img = pipe_line_text_to_image_synthesis( + self.test_text)[OutputKeys.OUTPUT_IMG] + print(np.sum(np.abs(img))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipe_line_text_to_image_synthesis = pipeline( + task=Tasks.text_to_image_synthesis) + img = pipe_line_text_to_image_synthesis( + self.test_text)[OutputKeys.OUTPUT_IMG] + print(np.sum(np.abs(img))) + + +if __name__ == '__main__': + unittest.main()
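# --- Editor's end-to-end usage sketch matching the pipeline and test above. The model id,
# task name and output key come from the test file; treating the returned value as a single
# HxWxC array and saving it with Pillow are my own assumptions, not part of the patch.
import numpy as np
from PIL import Image

from modelscope.pipelines import pipeline
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.constant import Tasks

text_to_image = pipeline(
    task=Tasks.text_to_image_synthesis,
    model='damo/cv_imagen_text-to-image-synthesis_tiny')
result = text_to_image({'text': '宇航员'})       # same input dict as the unit test ('astronaut')
img = np.asarray(result[OutputKeys.OUTPUT_IMG])
Image.fromarray(img.astype(np.uint8)).save('astronaut.png')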