diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 22c2d99e..d7217d57 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -72,6 +72,7 @@ class Models(object): gemm = 'gemm-generative-multi-modal' mplug = 'mplug' diffusion = 'diffusion-text-to-image-synthesis' + multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis' team = 'team-multi-modal-similarity' video_clip = 'video-clip-multi-modal-embedding' diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py index 9219a281..0053da43 100644 --- a/modelscope/models/multi_modal/__init__.py +++ b/modelscope/models/multi_modal/__init__.py @@ -14,6 +14,8 @@ if TYPE_CHECKING: from .ofa_for_all_tasks import OfaForAllTasks from .ofa_for_text_to_image_synthesis_model import \ OfaForTextToImageSynthesis + from .multi_stage_diffusion import \ + MultiStageDiffusionForTextToImageSynthesis else: _import_structure = { @@ -25,7 +27,9 @@ else: 'mplug_for_all_tasks': ['MPlugForAllTasks'], 'ofa_for_all_tasks': ['OfaForAllTasks'], 'ofa_for_text_to_image_synthesis_model': - ['OfaForTextToImageSynthesis'] + ['OfaForTextToImageSynthesis'], + 'multi_stage_diffusion': + ['MultiStageDiffusionForTextToImageSynthesis'] } import sys diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/__init__.py b/modelscope/models/multi_modal/multi_stage_diffusion/__init__.py new file mode 100644 index 00000000..accbb56e --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/__init__.py @@ -0,0 +1 @@ +from .model import MultiStageDiffusionForTextToImageSynthesis diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/clip.py b/modelscope/models/multi_modal/multi_stage_diffusion/clip.py new file mode 100644 index 00000000..54e971f7 --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/clip.py @@ -0,0 +1,318 @@ +# The implementation here is modified based on OpenAI CLIP, publicly available at https://github.com/openai/CLIP. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['CLIP'] + + +def to_fp16(m): + if isinstance(m, (nn.Linear, nn.Conv2d)): + m.weight.data = m.weight.data.half() + if m.bias is not None: + m.bias.data = m.bias.data.half() + elif hasattr(m, 'head'): + p = getattr(m, 'head') + p.data = p.data.half() + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + r"""Subclass of nn.LayerNorm to handle fp16. + """ + + def forward(self, x): + return super(LayerNorm, self).forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): + assert dim % num_heads == 0 + super(SelfAttention, self).__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.attn_dropout = nn.Dropout(attn_dropout) + self.proj = nn.Linear(dim, dim) + self.proj_dropout = nn.Dropout(proj_dropout) + + def forward(self, x, mask=None): + r"""x: [B, L, C]. + mask: [*, L, L]. 
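+
+        Illustrative shape flow (a sketch, not extra behaviour): with
+        dim=512 and num_heads=8, an input of shape [2, 77, 512] yields
+        q/k/v of shape [2 * 8, 77, 64], attention weights of shape
+        [2 * 8, 77, 77], and an output reshaped back to [2, 77, 512].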
+ """ + b, l, _, n = *x.size(), self.num_heads + + # compute query, key, and value + q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1) + q = q.reshape(l, b * n, -1).transpose(0, 1) + k = k.reshape(l, b * n, -1).transpose(0, 1) + v = v.reshape(l, b * n, -1).transpose(0, 1) + + # compute attention + attn = self.scale * torch.bmm(q, k.transpose(1, 2)) + if mask is not None: + attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf')) + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + attn = self.attn_dropout(attn) + + # gather context + x = torch.bmm(attn, v) + x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1) + + # output + x = self.proj(x) + x = self.proj_dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): + super(AttentionBlock, self).__init__() + self.dim = dim + self.num_heads = num_heads + + # layers + self.norm1 = LayerNorm(dim) + self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout) + self.norm2 = LayerNorm(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim), + nn.Dropout(proj_dropout)) + + def forward(self, x, mask=None): + x = x + self.attn(self.norm1(x), mask) + x = x + self.mlp(self.norm2(x)) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + out_dim=512, + num_heads=12, + num_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + assert image_size % patch_size == 0 + super(VisionTransformer, self).__init__() + self.image_size = image_size + self.patch_size = patch_size + self.dim = dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.num_patches = (image_size // patch_size)**2 + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter( + gain * torch.randn(1, self.num_patches + 1, dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim) + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim) + + # head + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + + def forward(self, x): + b, dtype = x.size(0), self.head.dtype + x = x.type(dtype) + + # patch-embedding + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c] + x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x], + dim=1) + x = self.dropout(x + self.pos_embedding.type(dtype)) + x = self.pre_norm(x) + + # transformer + x = self.transformer(x) + + # head + x = self.post_norm(x) + x = torch.mm(x[:, 0, :], self.head) + return x + + def fp16(self): + return self.apply(to_fp16) + + +class TextTransformer(nn.Module): + + def __init__(self, + vocab_size, + text_len, + dim=512, + out_dim=512, + num_heads=8, + num_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + super(TextTransformer, self).__init__() + self.vocab_size = vocab_size + self.text_len = text_len + self.dim = dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim) + self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim)) + self.dropout = 
nn.Dropout(embedding_dropout) + + # transformer + self.transformer = nn.ModuleList([ + AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) + for _ in range(num_layers) + ]) + self.norm = LayerNorm(dim) + + # head + gain = 1.0 / math.sqrt(dim) + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + + # causal attention mask + self.register_buffer('attn_mask', + torch.tril(torch.ones(1, text_len, text_len))) + + def forward(self, x): + eot, dtype = x.argmax(dim=-1), self.head.dtype + + # embeddings + x = self.dropout( + self.token_embedding(x).type(dtype) + + self.pos_embedding.type(dtype)) + + # transformer + for block in self.transformer: + x = block(x, self.attn_mask) + + # head + x = self.norm(x) + x = torch.mm(x[torch.arange(x.size(0)), eot], self.head) + return x + + def fp16(self): + return self.apply(to_fp16) + + +class CLIP(nn.Module): + + def __init__(self, + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_heads=12, + vision_layers=12, + vocab_size=49408, + text_len=77, + text_dim=512, + text_heads=8, + text_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + super(CLIP, self).__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vocab_size = vocab_size + self.text_len = text_len + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout) + self.textual = TextTransformer( + vocab_size=vocab_size, + text_len=text_len, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_tokens): + r"""imgs: [B, C, H, W] of torch.float32. + txt_tokens: [B, T] of torch.long. 
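+
+        Returns two [B, B] logit matrices (image-to-text and text-to-image),
+        both scaled by exp(log_scale). A rough usage sketch with the default
+        configuration (224x224 images, 77 text tokens):
+            >>> clip = CLIP()
+            >>> imgs = torch.randn(2, 3, 224, 224)
+            >>> tokens = torch.randint(0, clip.vocab_size, (2, 77))
+            >>> logits_i2t, logits_t2i = clip(imgs, tokens)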
+ """ + xi = self.visual(imgs) + xt = self.textual(txt_tokens) + + # normalize features + xi = F.normalize(xi, p=2, dim=1) + xt = F.normalize(xt, p=2, dim=1) + + # logits + scale = self.log_scale.exp() + logits_i2t = scale * torch.mm(xi, xt.t()) + logits_t2i = scale * torch.mm(xt, xi.t()) + return logits_i2t, logits_t2i + + def init_weights(self): + # embeddings + nn.init.normal_(self.textual.token_embedding.weight, std=0.02) + nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1) + + # attentions + for modality in ['visual', 'textual']: + dim = self.vision_dim if modality == 'visual' else 'textual' + transformer = getattr(self, modality).transformer + proj_gain = (1.0 / math.sqrt(dim)) * ( + 1.0 / math.sqrt(2 * transformer.num_layers)) + attn_gain = 1.0 / math.sqrt(dim) + mlp_gain = 1.0 / math.sqrt(2.0 * dim) + for block in transformer.layers: + nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) + nn.init.normal_(block.attn.proj.weight, std=proj_gain) + nn.init.normal_(block.mlp[0].weight, std=mlp_gain) + nn.init.normal_(block.mlp[2].weight, std=proj_gain) + + def fp16(self): + return self.apply(to_fp16) diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/decoder.py b/modelscope/models/multi_modal/multi_stage_diffusion/decoder.py new file mode 100644 index 00000000..17daedaf --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/decoder.py @@ -0,0 +1,322 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['Decoder'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, scale_factor, use_conv=False): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.scale_factor = scale_factor + self.use_conv = use_conv + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(in_dim, out_dim, 3, padding=1) + if use_conv else nn.Identity()) + elif scale_factor == 0.5: + self.resample = nn.Conv2d( + in_dim, out_dim, 3, stride=2, + padding=1) if use_conv else nn.AvgPool2d( + kernel_size=2, stride=2) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + scale_factor=1.0, + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.scale_factor = scale_factor + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, 
padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e): + identity = self.resample(x) + x = self.layer1[-1](self.resample(self.layer1[:-1](x))) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. + """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, + d).permute(0, 2, 3, + 1).chunk( + 2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + + # compute attention + attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.matmul(v, attn.transpose(-1, -2)) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class Decoder(nn.Module): + + def __init__(self, + in_dim=3, + dim=512, + y_dim=512, + context_dim=512, + out_dim=6, + dim_mult=[1, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=3, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + resblock_resample=True, + use_scale_shift_norm=True, + dropout=0.1): + embed_dim = dim * 4 + super(Decoder, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.context_dim = context_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.resblock_resample = resblock_resample + self.use_scale_shift_norm = use_scale_shift_norm + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.y_embedding = nn.Sequential( + nn.Linear(y_dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.context_embedding = nn.Sequential( + nn.Linear(y_dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, context_dim * 4)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) 
in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim, embed_dim, out_dim, + use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + self.encoder.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + if resblock_resample: + downsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 0.5, + dropout) + else: + downsample = Resample( + out_dim, out_dim, 0.5, use_conv=True) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.encoder.append(downsample) + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout), + AttentionBlock(out_dim, context_dim, num_heads, head_dim), + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + if resblock_resample: + upsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 2.0, + dropout) + else: + upsample = Resample( + out_dim, out_dim, 2.0, use_conv=True) + scale *= 2.0 + block.append(upsample) + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, y): + # embeddings + e = self.time_embedding(sinusoidal_embedding( + t, self.dim)) + self.y_embedding(y) + context = self.context_embedding(y).view(-1, 4, self.context_dim) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, e, context) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, e, context) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e, context) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, e, context): + if isinstance(module, ResidualBlock): + x = module(x, e) + elif isinstance(module, AttentionBlock): + x = module(x, context) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context) + else: + x = module(x) + return x diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/gaussian_diffusion.py b/modelscope/models/multi_modal/multi_stage_diffusion/gaussian_diffusion.py new file mode 100644 index 00000000..a4fc52e0 --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/gaussian_diffusion.py @@ -0,0 +1,641 @@ +# The implementation here is modified based on latent diffusion, publicly available +# at https://github.com/CompVis/latent-diffusion. 
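+#
+# Notation used below: betas parameterizes q(x_t | x_{t-1})
+# = N(sqrt(1 - beta_t) * x_{t-1}, beta_t * I); alphas_cumprod is the running
+# product of (1 - beta_t), so q(x_t | x_0)
+# = N(sqrt(alphas_cumprod_t) * x_0, (1 - alphas_cumprod_t) * I), which is
+# what q_sample draws from.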
+ +import math + +import torch + +__all__ = ['GaussianDiffusion', 'beta_schedule'] + + +def kl_divergence(mu1, logvar1, mu2, logvar2): + u1 = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + u2 = ((mu1 - mu2)**2) * torch.exp(-logvar2) + return 0.5 * (u1 + u2) + + +def standard_normal_cdf(x): + r"""A fast approximation of the cumulative distribution function of the standard normal. + """ + return 0.5 * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x0, mean, log_scale): + assert x0.shape == mean.shape == log_scale.shape + cx = x0 - mean + inv_stdv = torch.exp(-log_scale) + cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) + cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x0 < -0.999, log_cdf_plus, + torch.where(x0 > 0.999, log_one_minus_cdf_min, + torch.log(cdf_delta.clamp(min=1e-12)))) + assert log_probs.shape == x0.shape + return log_probs + + +def _i(tensor, t, x): + r"""Index tensor using t and format the output according to x. + """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t].view(shape).to(x) + + +def beta_schedule(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None): + if schedule == 'linear': + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + last_beta = last_beta or scale * 0.02 + return torch.linspace( + init_beta, last_beta, num_timesteps, dtype=torch.float64) + elif schedule == 'quadratic': + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'cosine': + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + fn_t1 = math.cos((t1 + 0.008) / 1.008 * math.pi / 2)**2 + fn_t2 = math.cos((t2 + 0.008) / 1.008 * math.pi / 2)**2 + betas.append(min(1.0 - fn_t2 / fn_t1, 0.999)) + return torch.tensor(betas, dtype=torch.float64) + else: + raise ValueError(f'Unsupported schedule: {schedule}') + + +class GaussianDiffusion(object): + + def __init__(self, + betas, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + rescale_timesteps=False): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type + self.var_type = var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + 
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( + 1.0 - self.alphas_cumprod) + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). + """ + noise = torch.randn_like(x0) if noise is None else noise + u1 = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + u2 = _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + return u1 + u2 + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, + guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + shape = (-1, *((1, ) * (xt.ndim - 1))) + mask = t.ne(0).float().view(shape) # no noise when t == 0 + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). 
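+
+        When guide_scale is not None, model_kwargs must be a two-element list
+        of conditional and unconditional kwargs, and the prediction is mixed
+        as u + guide_scale * (y - u) (classifier-free guidance). For
+        var_type 'learned_range', the second half of the model output is a
+        fraction in [-1, 1] that interpolates between the clipped posterior
+        log-variance and log(beta_t).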
+ """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + cond = self.var_type.startswith('fixed') + dim = y_out.size(1) if cond else y_out.size(1) // 2 + u1 = u_out[:, :dim] + u2 = guide_scale * (y_out[:, :dim] - u_out[:, :dim]) + out = torch.cat([u1 + u2, y_out[:, dim:]], dim=1) + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i( + torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, + xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + u1 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu + u2 = _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + xt) * xt + x0 = u1 - u2 + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + x0 = u1 - u2 + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
+ """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = u1 / u2 + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + x0 = u1 - u2 + + # derive variables + u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = u1 / u2 + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + u1 = (1 - alphas_prev) / (1 - alphas) + u2 = (1 - alphas / alphas_prev) + sigmas = eta * torch.sqrt(u1 * u2) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). 
+ """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + + # derive variables + u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = u1 / u2 + + alphas_next = _i( + torch.cat( + [self.alphas_cumprod, + self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, + percentile, guide_scale, + ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = u1 / u2 + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + x0 = u1 - u2 + + # derive eps + u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = u1 / u2 + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + x0 = u1 - u2 + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + # mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] + - 9 * 
eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // plms_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, + guide_scale, plms_timesteps, + eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None, input_x0=None): + noise = torch.randn_like(x0) if noise is None else noise + input_x0 = x0 if input_x0 is None else input_x0 + xt = self.q_sample(input_x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, + model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([ + out.detach(), var + ], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound( + x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] + }[self.mean_type] + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2 + ).abs().flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, + x0, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood( + x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. 
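+
+        Returns a dict with per-timestep 'vlb', 'mse' (on eps) and 'x0_mse'
+        tensors of shape [B, num_timesteps], plus 'prior_bits_per_dim' and
+        'total_bits_per_dim' of shape [B].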
+ """ + # prepare input and output + b = x0.size(0) + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + noise = torch.randn_like(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound( + x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) + u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = u1 / u2 + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append( + (pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append( + (eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), + torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/model.py b/modelscope/models/multi_modal/multi_stage_diffusion/model.py new file mode 100644 index 00000000..c2d83b34 --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/model.py @@ -0,0 +1,265 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os.path as osp +from typing import Any, Dict + +import json +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.multi_stage_diffusion.clip import CLIP +from modelscope.models.multi_modal.multi_stage_diffusion.decoder import Decoder +from modelscope.models.multi_modal.multi_stage_diffusion.gaussian_diffusion import ( + GaussianDiffusion, beta_schedule) +from modelscope.models.multi_modal.multi_stage_diffusion.prior import Prior +from modelscope.models.multi_modal.multi_stage_diffusion.tokenizer import ( + CLIPTokenizer, XGLMTokenizer) +from modelscope.models.multi_modal.multi_stage_diffusion.upsampler import ( + Upsampler256, Upsampler1024) +from modelscope.models.multi_modal.multi_stage_diffusion.xglm import XGLM +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['MultiStageDiffusionForTextToImageSynthesis'] + + +def make_diffusion(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None, + mean_type='eps', + var_type='fixed_small'): + betas = beta_schedule(schedule, num_timesteps, init_beta, last_beta) + diffusion = GaussianDiffusion( + betas, mean_type=mean_type, var_type=var_type) + return diffusion + + +class UnCLIP(nn.Module): + + def __init__(self, model_dir): + super(UnCLIP, self).__init__() + self.model_dir = model_dir + self.config = json.load(open(f'{model_dir}/{ModelFile.CONFIGURATION}')) + + # modules + self.clip = CLIP(**self.config['clip']).fp16() + self.xglm = XGLM(**self.config['xglm']) + self.prior = 
Prior(**self.config['prior']) + self.decoder = Decoder(**self.config['decoder']) + self.upsampler256 = Upsampler256(**self.config['upsampler256']) + self.upsampler1024 = Upsampler1024(**self.config['upsampler1024']) + + # diffusions + self.prior_diffusion = make_diffusion(**self.config['prior_diffusion']) + self.decoder_diffusion = make_diffusion( + **self.config['decoder_diffusion']) + self.upsampler256_diffusion = make_diffusion( + **self.config['upsampler256_diffusion']) + self.upsampler1024_diffusion = make_diffusion( + **self.config['upsampler1024_diffusion']) + + # tokenizers + self.clip_tokenizer = CLIPTokenizer( + bpe_path=f'{model_dir}/bpe_simple_vocab_16e6.txt.gz') + self.xglm_tokenizer = XGLMTokenizer(model_dir=model_dir) + + def forward(self, *args, **kwargs): + raise NotImplementedError( + '"forward" is not implemented. Use "synthesis" instead.') + + @torch.no_grad() + def synthesis(self, + text='A photo of a confused grizzly bear in calculus class.', + tokenizer='clip', + batch_size=4, + timesteps_prior=100, + timesteps_64=50, + timesteps_256=20, + timesteps_1024=20, + guide_prior=3.0, + guide_64=7.0, + guide_256=3.0, + guide_1024=3.0, + eta_prior=0.0, + eta_64=0.0, + eta_256=0.0, + eta_1024=0.0): + device = next(self.parameters()).device + + # check params + assert all([ + t > 0 and t <= 1000 for t in + [timesteps_prior, timesteps_64, timesteps_256, timesteps_1024] + ]) + assert all([ + g > 1 and g < 15 + for g in [guide_prior, guide_64, guide_256, guide_1024] + ]) + assert all([ + e >= 0 and e <= 1.0 + for e in [eta_prior, eta_64, eta_256, eta_1024] + ]) + assert batch_size >= 1 and batch_size <= 16 + + # tokenize the text + if tokenizer == 'clip': + y = F.normalize( + self.clip.textual(self.clip_tokenizer([text]).to(device)), + p=2, + dim=1) + zero_y = F.normalize( + self.clip.textual(self.clip_tokenizer(['']).to(device)), + p=2, + dim=1) + elif tokenizer == 'xglm': + y = F.normalize( + self.xglm(*to_device(self.xglm_tokenizer([text]), device)), + p=2, + dim=1) + zero_y = F.normalize( + self.xglm(*to_device(self.xglm_tokenizer(['']), device)), + p=2, + dim=1) + else: + raise ValueError( + f'Expected tokenizer to be one of "clip" or "xglm", but got {tokenizer}' + ) + y = math.sqrt(y.size(1)) * y.repeat(batch_size, 1) + zero_y = math.sqrt(zero_y.size(1)) * zero_y.repeat(batch_size, 1) + + # synthesis + with amp.autocast(enabled=True): + # prior + x0 = self.prior_diffusion.ddim_sample_loop( + noise=torch.randn_like(y), + model=self.prior, + model_kwargs=[{ + 'y': y + }, { + 'y': zero_y + }], + guide_scale=guide_prior, + ddim_timesteps=timesteps_prior, + eta=eta_prior) + + # decoder + imgs64 = self.decoder_diffusion.ddim_sample_loop( + noise=torch.randn(batch_size, 3, 64, 64).to(device), + model=self.decoder, + model_kwargs=[{ + 'y': x0 + }, { + 'y': torch.zeros_like(x0) + }], + guide_scale=guide_64, + percentile=0.995, + ddim_timesteps=timesteps_64, + eta=eta_64).clamp_(-1, 1) + + # upsampler256 + imgs256 = F.interpolate( + imgs64, scale_factor=4.0, mode='bilinear', align_corners=False) + imgs256 = self.upsampler256_diffusion.ddim_sample_loop( + noise=torch.randn_like(imgs256), + model=self.upsampler256, + model_kwargs=[{ + 'y': y, + 'concat': imgs256 + }, { + 'y': zero_y, + 'concat': imgs256 + }], + guide_scale=guide_256, + percentile=0.995, + ddim_timesteps=timesteps_256, + eta=eta_256).clamp_(-1, 1) + + # upsampler1024 + imgs1024 = F.interpolate( + imgs256, + scale_factor=4.0, + mode='bilinear', + align_corners=False) + imgs1024 = 
self.upsampler1024_diffusion.ddim_sample_loop( + noise=torch.randn_like(imgs1024), + model=self.upsampler1024, + model_kwargs=[{ + 'y': y, + 'concat': imgs1024 + }, { + 'y': zero_y, + 'concat': imgs1024 + }], + guide_scale=guide_1024, + percentile=0.995, + ddim_timesteps=timesteps_1024, + eta=eta_1024).clamp_(-1, 1) + + # output ([B, C, H, W] within range [0, 1]) + imgs1024 = imgs1024.add_(1).mul_(255 / 2.0).permute(0, 2, 3, 1).cpu() + imgs1024 = [ + Image.fromarray(np.array(u, dtype=np.uint8)) for u in imgs1024 + ] + return imgs1024 + + +@MODELS.register_module( + Tasks.text_to_image_synthesis, module_name=Models.multi_stage_diffusion) +class MultiStageDiffusionForTextToImageSynthesis(TorchModel): + + def __init__(self, model_dir, device_id=-1): + super().__init__(model_dir=model_dir, device_id=device_id) + model = UnCLIP(model_dir=model_dir) + pretrained_params = torch.load( + osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu') + model.load_state_dict(pretrained_params) + model.eval() + + self.device_id = device_id + if self.device_id >= 0: + self.device = torch.device(f'cuda:{self.device_id}') + model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + self.device = torch.device('cpu') + logger.info('Use CPU for inference') + self.model = model + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + if not isinstance(input, dict): + raise ValueError( + f'Expected the input to be a dictionary, but got {type(input)}' + ) + if 'text' not in input: + raise ValueError('input should contain "text", but not found') + + # ddim sampling + imgs = self.model.synthesis( + text=input.get('text'), + tokenizer=input.get('tokenizer', 'clip'), + batch_size=input.get('batch_size', 4), + timesteps_prior=input.get('timesteps_prior', 100), + timesteps_64=input.get('timesteps_64', 50), + timesteps_256=input.get('timesteps_256', 20), + timesteps_1024=input.get('timesteps_1024', 20), + guide_prior=input.get('guide_prior', 3.0), + guide_64=input.get('guide_64', 7.0), + guide_256=input.get('guide_256', 3.0), + guide_1024=input.get('guide_1024', 3.0), + eta_prior=input.get('eta_prior', 0.0), + eta_64=input.get('eta_64', 0.0), + eta_256=input.get('eta_256', 0.0), + eta_1024=input.get('eta_1024', 0.0)) + imgs = [np.array(u)[..., ::-1] for u in imgs] + return imgs diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/prior.py b/modelscope/models/multi_modal/multi_stage_diffusion/prior.py new file mode 100644 index 00000000..380fa467 --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/prior.py @@ -0,0 +1,170 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
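+
+# The prior is a small diffusion transformer that maps a text embedding
+# (from CLIP or XGLM, see UnCLIP.synthesis in model.py) to a CLIP image
+# embedding, following the two-stage unCLIP design; the decoder then turns
+# that image embedding into a 64x64 image, and the upsamplers refine it to
+# 256x256 and 1024x1024.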
+ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['Prior'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads): + assert dim % num_heads == 0 + super(SelfAttention, self).__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = math.pow(self.head_dim, -0.25) + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.proj = nn.Linear(dim, dim) + + def forward(self, x, mask): + b, l, n, c = *x.shape[:2], self.num_heads, self.head_dim + + # compute query, key, value + q, k, v = self.to_qkv(x).view(b, l, n * 3, c).chunk(3, dim=2) + + # compute attention + attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale) + if mask is not None: + attn = attn.masked_fill(mask[:, :, :l, :l] == 0, float('-inf')) + attn = F.softmax(attn.float(), dim=-1).type(attn.dtype) + + # gather context + x = torch.einsum('bnij,bjnc->binc', attn, v) + x = x.reshape(b, l, -1) + + # output + x = self.proj(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads): + super(AttentionBlock, self).__init__() + self.dim = dim + self.num_heads = num_heads + + # layers + self.norm1 = nn.LayerNorm(dim) + self.attn = SelfAttention(dim, num_heads) + self.norm2 = nn.LayerNorm(dim) + self.ffn = nn.Sequential( + nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)) + + def forward(self, x, mask=None): + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class Prior(nn.Module): + + def __init__(self, dim=2048, clip_dim=768, num_heads=32, num_layers=24): + super(Prior, self).__init__() + self.dim = dim + self.clip_dim = clip_dim + self.num_heads = num_heads + self.num_layers = num_layers + + # embeddings + self.text_embedding = nn.Sequential( + nn.Linear(clip_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_embedding = nn.Sequential( + nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.vision_embedding = nn.Sequential( + nn.Linear(clip_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.eos_embedding = nn.Parameter(torch.zeros(1, 1, dim)) + self.pos_embedding = nn.Parameter(torch.zeros(1, 4, dim)) + + # transformer + self.blocks = nn.ModuleList( + [AttentionBlock(dim, num_heads) for _ in range(num_layers)]) + self.norm = nn.LayerNorm(dim) + + # head + self.head = nn.Linear(dim, clip_dim) + + # causal attention mask + self.register_buffer('attn_mask', torch.tril(torch.ones(1, 1, 4, 4))) + + # initialize weights + self.init_weights() + + def forward(self, x, t, y): + r"""x: [B, C]. + t: [B]. + y: [B, C]. 
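+
+        The three inputs plus a learned EOS token are packed into a length-4
+        sequence [text, time, image, eos]; after the causally masked blocks,
+        the denoised image embedding is read from the EOS position, giving an
+        output of shape [B, clip_dim].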
+ """ + b = x.size(0) + + # embeddings of shape [B, L + 4, C] + u1 = sinusoidal_embedding(t, self.dim) + u2 = [ + self.text_embedding(y).unsqueeze(1), + self.time_embedding(u1).unsqueeze(1), + self.vision_embedding(x).unsqueeze(1), + self.eos_embedding.repeat(b, 1, 1) + ] + x = self.pos_embedding + torch.cat(u2, dim=1) + + # transformer + for block in self.blocks: + x = block(x, self.attn_mask) + x = self.norm(x) + + # head + x = self.head(x[:, -1]) + return x + + def init_weights(self): + std = 0.02 / math.sqrt(2.0 * self.num_layers) + for name, m in self.named_modules(): + if name.endswith('attn.proj') or name.endswith('ffn.2'): + # smaller std for output layers + nn.init.normal_(m.weight, std=std) + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.Linear, nn.Embedding)): + nn.init.normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': + 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/tokenizer.py b/modelscope/models/multi_modal/multi_stage_diffusion/tokenizer.py new file mode 100644 index 00000000..6fd9bebe --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/tokenizer.py @@ -0,0 +1,199 @@ +# The implementation here is modified based on OpenAI CLIP, publicly available at https://github.com/openai/CLIP. + +import gzip +import html +from functools import lru_cache + +import ftfy +import regex as re +import torch +from transformers import AutoTokenizer + +__all__ = ['CLIPTokenizer', 'XGLMTokenizer'] + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + + def __init__(self, bpe_path): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + '<|startoftext|>': '<|startoftext|>', + '<|endoftext|>': '<|endoftext|>' + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors='replace').replace('', ' ') + return text + + +class CLIPTokenizer(object): + r"""CLIP tokenizer, adapted from https://github.com/openai/CLIP. 
+ """ + + def __init__(self, bpe_path, length=77): + self.bpe_path = bpe_path + self.length = length + + # init tokenizer + self.tokenizer = SimpleTokenizer(bpe_path=bpe_path) + self.sos_token = self.tokenizer.encoder['<|startoftext|>'] + self.eos_token = self.tokenizer.encoder['<|endoftext|>'] + self.vocab_size = len(self.tokenizer.encoder) + + def __call__(self, sequence): + if isinstance(sequence, str): + return torch.LongTensor(self._tokenizer(sequence)) + elif isinstance(sequence, list): + return torch.LongTensor([self._tokenizer(u) for u in sequence]) + else: + raise TypeError( + f'Expected the "sequence" to be a string or a list, but got {type(sequence)}' + ) + + def _tokenizer(self, text): + tokens = self.tokenizer.encode(text)[:self.length - 2] + tokens = [self.sos_token] + tokens + [self.eos_token] + tokens = tokens + [0] * (self.length - len(tokens)) + return tokens + + +class XGLMTokenizer(object): + r"""A wrapper of HuggingFace's XGLM tokenizer. + """ + + def __init__(self, model_dir, length=77, **kwargs): + self.length = length + self.tokenizer = AutoTokenizer.from_pretrained(model_dir, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + _kwargs = { + 'return_tensors': 'pt', + 'padding': 'max_length', + 'truncation': True, + 'max_length': self.length + } + _kwargs.update(**kwargs) + tokens = self.tokenizer(sequence, **_kwargs) + return tokens.input_ids, tokens.attention_mask diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/upsampler.py b/modelscope/models/multi_modal/multi_stage_diffusion/upsampler.py new file mode 100644 index 00000000..4e99a514 --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/upsampler.py @@ -0,0 +1,466 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['Upsampler256', 'Upsampler1024'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, scale_factor, use_conv=False): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.scale_factor = scale_factor + self.use_conv = use_conv + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(in_dim, out_dim, 3, padding=1) + if use_conv else nn.Identity()) + elif scale_factor == 0.5: + self.resample = nn.Conv2d( + in_dim, out_dim, 3, stride=2, + padding=1) if use_conv else nn.AvgPool2d( + kernel_size=2, stride=2) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + scale_factor=1.0, + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.scale_factor = scale_factor + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e): + identity = self.resample(x) + x = self.layer1[-1](self.resample(self.layer1[:-1](x))) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. 
+ """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, + d).permute(0, 2, 3, + 1).chunk( + 2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + + # compute attention + attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.matmul(v, attn.transpose(-1, -2)) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class Upsampler256(nn.Module): + + def __init__(self, + in_dim=6, + dim=320, + y_dim=768, + context_dim=512, + out_dim=3, + dim_mult=[1, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=3, + attn_scales=[1 / 8], + resblock_resample=True, + use_scale_shift_norm=True, + dropout=0.1): + embed_dim = dim * 4 + super(Upsampler256, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.context_dim = context_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.resblock_resample = resblock_resample + self.use_scale_shift_norm = use_scale_shift_norm + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.y_embedding = nn.Sequential( + nn.Linear(y_dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.context_embedding = nn.Sequential( + nn.Linear(y_dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, context_dim * 4)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim, embed_dim, out_dim, + use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, context_dim, num_heads, + head_dim)) + in_dim = out_dim + self.encoder.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + if resblock_resample: + downsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 0.5, + dropout) + else: + downsample = Resample( + out_dim, out_dim, 0.5, use_conv=True) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.encoder.append(downsample) + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout), + AttentionBlock(out_dim, context_dim, num_heads, head_dim), + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, use_scale_shift_norm, 1.0, dropout) + ]) + if scale in attn_scales: + block.append( + AttentionBlock(out_dim, 
context_dim, num_heads, + head_dim)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + if resblock_resample: + upsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 2.0, + dropout) + else: + upsample = Resample( + out_dim, out_dim, 2.0, use_conv=True) + scale *= 2.0 + block.append(upsample) + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, y, concat): + # embeddings + x = torch.cat([x, concat], dim=1) + e = self.time_embedding(sinusoidal_embedding( + t, self.dim)) + self.y_embedding(y) + context = self.context_embedding(y).view(-1, 4, self.context_dim) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, e, context) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, e, context) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e, context) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, e, context): + if isinstance(module, ResidualBlock): + x = module(x, e) + elif isinstance(module, AttentionBlock): + x = module(x, context) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context) + else: + x = module(x) + return x + + +class Upsampler1024(nn.Module): + + def __init__(self, + in_dim=6, + dim=192, + y_dim=768, + out_dim=3, + dim_mult=[1, 1, 2, 2, 4, 4], + num_res_blocks=2, + resblock_resample=True, + use_scale_shift_norm=True, + dropout=0.0): + embed_dim = dim * 4 + super(Upsampler1024, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.resblock_resample = resblock_resample + self.use_scale_shift_norm = use_scale_shift_norm + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embedding + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.y_embedding = nn.Sequential( + nn.Linear(y_dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual block + block = nn.ModuleList([ + ResidualBlock(in_dim, embed_dim, out_dim, + use_scale_shift_norm, 1.0, dropout) + ]) + shortcut_dims.append(out_dim) + in_dim = out_dim + self.encoder.append(block) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + if resblock_resample: + downsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 0.5, + dropout) + else: + downsample = Resample( + out_dim, out_dim, 0.5, use_conv=True) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.encoder.append(downsample) + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout), + ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, + 1.0, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, 
(in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual block + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, use_scale_shift_norm, 1.0, dropout) + ]) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + if resblock_resample: + upsample = ResidualBlock(out_dim, embed_dim, out_dim, + use_scale_shift_norm, 2.0, + dropout) + else: + upsample = Resample( + out_dim, out_dim, 2.0, use_conv=True) + scale *= 2.0 + block.append(upsample) + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, y, concat): + # embedding + x = torch.cat([x, concat], dim=1) + e = self.time_embedding(sinusoidal_embedding( + t, self.dim)) + self.y_embedding(y) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, e) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, e) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, e): + if isinstance(module, ResidualBlock): + x = module(x, e) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e) + else: + x = module(x) + return x diff --git a/modelscope/models/multi_modal/multi_stage_diffusion/xglm.py b/modelscope/models/multi_modal/multi_stage_diffusion/xglm.py new file mode 100644 index 00000000..8a0b3ff1 --- /dev/null +++ b/modelscope/models/multi_modal/multi_stage_diffusion/xglm.py @@ -0,0 +1,205 @@ +# The implementation here is modified based on HuggingFace XGLM, publicly available +# at https://github.com/huggingface/transformers. 
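+#
+# Defines a decoder-only multilingual transformer (XGLM) with a linear head: a learned
+# end-of-sequence embedding is appended to the token sequence and the hidden state at that
+# position is projected to a fixed-size vector (`embed_dim`), so the module acts as a text
+# encoder rather than a language model.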
+ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['XGLM'] + + +def sinusoidal_embedding(seq_len, dim, pad_token=None): + half = dim // 2 + sinusoid = torch.outer( + torch.arange(seq_len, dtype=torch.float32), + torch.pow(10000, + -torch.arange(half, dtype=torch.float32).div(half - 1))) + x = torch.cat([torch.sin(sinusoid), torch.cos(sinusoid)], dim=1) + if dim % 2 == 1: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + if pad_token is not None: + x[pad_token, :] = 0 + return x + + +class SinusoidalEmbedding(nn.Module): + + def __init__(self, seq_len, dim, pad_token): + super(SinusoidalEmbedding, self).__init__() + self.seq_len = seq_len + self.dim = dim + self.pad_token = pad_token + self.register_buffer('weight', + sinusoidal_embedding(seq_len + 2, dim, pad_token)) + + def forward(self, tokens): + mask = tokens.ne(self.pad_token).long() + indices = torch.cumsum(mask, dim=1) * mask + self.pad_token + pos_embeds = self.weight.index_select(0, indices.view(-1)).view( + *tokens.shape, -1) + return pos_embeds + + +class GELU(nn.Module): + + def forward(self, x): + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, dropout=0.1): + assert dim % num_heads == 0 + super(SelfAttention, self).__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask=None): + r"""x: [B, L, C]. + mask: [B, *, L, L] or None. + """ + b, l, n, c = *x.shape[:2], self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, l, n, c) + k = self.k(x).view(b, l, n, c) + v = self.v(x).view(b, l, n, c) + + # compute attention + attn = self.scale * torch.einsum('binc,bjnc->bnij', q, k) + if mask is not None: + attn = attn.masked_fill(mask == 0, float('-inf')) + attn = F.softmax(attn, dim=-1) + attn = self.dropout(attn) + + # gather context + x = torch.einsum('bnij,bjnc->binc', attn, v) + x = x.reshape(b, l, -1) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, ffn_dim, ffn_act, num_heads, dropout=0.1): + assert ffn_act in ['gelu', 'relu'] + super(AttentionBlock, self).__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.ffn_act = ffn_act + self.num_heads = num_heads + + # layers + self.norm1 = nn.LayerNorm(dim) + self.attn = SelfAttention(dim, num_heads, dropout) + self.norm2 = nn.LayerNorm(dim) + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), + GELU() if ffn_act == 'gelu' else nn.ReLU(inplace=True), + nn.Linear(ffn_dim, dim), nn.Dropout(dropout)) + + def forward(self, x, mask=None): + x = x + self.attn(self.norm1(x), mask) + x = x + self.ffn(self.norm2(x)) + return x + + +class XGLM(nn.Module): + r"""A multilingual GPT model with an embedding head. 
+ """ + + def __init__(self, + vocab_size=256008, + max_seq_len=2048, + dim=1024, + ffn_dim=4096, + ffn_act='gelu', + embed_dim=768, + num_heads=16, + num_layers=24, + pad_token=1, + dropout=0.1): + super(XGLM, self).__init__() + self.vocab_size = vocab_size + self.max_seq_len = max_seq_len + self.dim = dim + self.ffn_dim = ffn_dim + self.ffn_act = ffn_act + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.pad_token = pad_token + self.scale = math.sqrt(dim) # rescale token embedings + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim, pad_token) + self.pos_embedding = SinusoidalEmbedding(max_seq_len, dim, pad_token) + self.eos_embedding = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + AttentionBlock(dim, ffn_dim, ffn_act, num_heads, dropout) + for _ in range(num_layers) + ]) + self.norm = nn.LayerNorm(dim) + self.head = nn.Linear(dim, embed_dim, bias=False) + + # causal attention mask + self.register_buffer( + 'attn_mask', + torch.tril(torch.ones(1, 1, 1 + max_seq_len, 1 + max_seq_len))) + + # init weights + self.apply(self.init_weights) + + def forward(self, tokens, mask=None): + r"""tokens: [B, L]. + mask: [B, L]. + """ + b, seq_len = tokens.size(0), 1 + tokens.size(1) + + # embeddings + x = self.scale * self.token_embedding(tokens) + x = torch.cat([x, self.eos_embedding.repeat(b, 1, 1)], dim=1) + # x = x + self.pos_embedding(tokens) + x = self.dropout(x) + + # attention mask + if mask is None: + mask = self.attn_mask[:, :, :seq_len, :seq_len].repeat(b, 1, 1, 1) + else: + mask = self.attn_mask[:, :, :seq_len, :seq_len] * torch.cat( + [mask, torch.zeros_like(mask[:, :1])], dim=1).view( + b, 1, 1, seq_len) + + # transformer + for block in self.blocks: + x = block(x, mask) + x = self.norm(x) + + # head + logits = self.head(x[:, -1]) + return logits + + def init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, std=0.02) + if m.padding_idx is not None: + nn.init.zeros_(m.weight[m.padding_idx]) diff --git a/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py b/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py index 406538cf..f402cc29 100644 --- a/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py +++ b/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py @@ -3,7 +3,8 @@ from typing import Any, Dict, Optional import torch from modelscope.metainfo import Pipelines -from modelscope.models.multi_modal import OfaForTextToImageSynthesis +from modelscope.models.multi_modal import ( + MultiStageDiffusionForTextToImageSynthesis, OfaForTextToImageSynthesis) from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Model, Pipeline from modelscope.pipelines.builder import PIPELINES @@ -48,7 +49,9 @@ class TextToImageSynthesisPipeline(Pipeline): return input def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - if isinstance(self.model, OfaForTextToImageSynthesis): + if isinstance(self.model, + (OfaForTextToImageSynthesis, + MultiStageDiffusionForTextToImageSynthesis)): return self.model(input) return self.model.generate(input) diff --git a/tests/pipelines/test_multi_stage_diffusion.py b/tests/pipelines/test_multi_stage_diffusion.py new file mode 100644 index 00000000..f4e63ce0 --- /dev/null +++ 
b/tests/pipelines/test_multi_stage_diffusion.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import numpy as np +import torch + +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class MultiStageDiffusionTest(unittest.TestCase): + model_id = 'damo/cv_diffusion_text-to-image-synthesis' + test_text = {'text': 'Photograph of a baby chicken wearing sunglasses'} + + @unittest.skip( + 'skip test since the pretrained model is not publicly available') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + pipe_line_text_to_image_synthesis = pipeline( + task=Tasks.text_to_image_synthesis, model=model) + img = pipe_line_text_to_image_synthesis( + self.test_text)[OutputKeys.OUTPUT_IMG] + print(np.sum(np.abs(img))) + + @unittest.skip( + 'skip test since the pretrained model is not publicly available') + def test_run_with_model_name(self): + pipe_line_text_to_image_synthesis = pipeline( + task=Tasks.text_to_image_synthesis, model=self.model_id) + img = pipe_line_text_to_image_synthesis( + self.test_text)[OutputKeys.OUTPUT_IMG] + print(np.sum(np.abs(img))) + + +if __name__ == '__main__': + unittest.main()
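For reference, a minimal shape check of the new denoiser interface introduced above (not part of the patch). The reduced configuration is an assumption chosen only so the smoke test runs quickly on CPU; the released model uses the `Upsampler256` constructor defaults, and only the argument names and the `forward(x, t, y, concat)` signature are taken from the code in this diff.

import torch

from modelscope.models.multi_modal.multi_stage_diffusion.upsampler import Upsampler256

# Reduced, hypothetical configuration for a quick smoke test.
net = Upsampler256(
    dim=64,
    context_dim=64,
    dim_mult=[1, 2],
    head_dim=32,
    num_res_blocks=1,
    attn_scales=[1 / 2])

b = 2
x = torch.randn(b, 3, 32, 32)        # noisy image at the stage resolution
concat = torch.randn(b, 3, 32, 32)   # conditioning image, resized to match x
t = torch.randint(0, 1000, (b, ))    # diffusion timesteps
y = torch.randn(b, 768)              # image/text embedding (y_dim defaults to 768)

out = net(x, t, y, concat)
assert out.shape == x.shape          # the denoiser preserves the spatial resolution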