mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-23 03:29:27 +01:00
DALL-E 2: 修复dev/dalle2_1分支问题,增加测试代码,本地测试通过
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10037492
This commit is contained in:
@@ -72,6 +72,7 @@ class Models(object):
|
|||||||
gemm = 'gemm-generative-multi-modal'
|
gemm = 'gemm-generative-multi-modal'
|
||||||
mplug = 'mplug'
|
mplug = 'mplug'
|
||||||
diffusion = 'diffusion-text-to-image-synthesis'
|
diffusion = 'diffusion-text-to-image-synthesis'
|
||||||
|
multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis'
|
||||||
team = 'team-multi-modal-similarity'
|
team = 'team-multi-modal-similarity'
|
||||||
video_clip = 'video-clip-multi-modal-embedding'
|
video_clip = 'video-clip-multi-modal-embedding'
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ if TYPE_CHECKING:
|
|||||||
from .ofa_for_all_tasks import OfaForAllTasks
|
from .ofa_for_all_tasks import OfaForAllTasks
|
||||||
from .ofa_for_text_to_image_synthesis_model import \
|
from .ofa_for_text_to_image_synthesis_model import \
|
||||||
OfaForTextToImageSynthesis
|
OfaForTextToImageSynthesis
|
||||||
|
from .multi_stage_diffusion import \
|
||||||
|
MultiStageDiffusionForTextToImageSynthesis
|
||||||
|
|
||||||
else:
|
else:
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
@@ -25,7 +27,9 @@ else:
|
|||||||
'mplug_for_all_tasks': ['MPlugForAllTasks'],
|
'mplug_for_all_tasks': ['MPlugForAllTasks'],
|
||||||
'ofa_for_all_tasks': ['OfaForAllTasks'],
|
'ofa_for_all_tasks': ['OfaForAllTasks'],
|
||||||
'ofa_for_text_to_image_synthesis_model':
|
'ofa_for_text_to_image_synthesis_model':
|
||||||
['OfaForTextToImageSynthesis']
|
['OfaForTextToImageSynthesis'],
|
||||||
|
'multi_stage_diffusion':
|
||||||
|
['MultiStageDiffusionForTextToImageSynthesis']
|
||||||
}
|
}
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
from .model import MultiStageDiffusionForTextToImageSynthesis
|
||||||
318
modelscope/models/multi_modal/multi_stage_diffusion/clip.py
Normal file
318
modelscope/models/multi_modal/multi_stage_diffusion/clip.py
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
# The implementation here is modified based on OpenAI CLIP, publicly available at https://github.com/openai/CLIP.
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
__all__ = ['CLIP']
|
||||||
|
|
||||||
|
|
||||||
|
def to_fp16(m):
|
||||||
|
if isinstance(m, (nn.Linear, nn.Conv2d)):
|
||||||
|
m.weight.data = m.weight.data.half()
|
||||||
|
if m.bias is not None:
|
||||||
|
m.bias.data = m.bias.data.half()
|
||||||
|
elif hasattr(m, 'head'):
|
||||||
|
p = getattr(m, 'head')
|
||||||
|
p.data = p.data.half()
|
||||||
|
|
||||||
|
|
||||||
|
class QuickGELU(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * torch.sigmoid(1.702 * x)
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.LayerNorm):
|
||||||
|
r"""Subclass of nn.LayerNorm to handle fp16.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return super(LayerNorm, self).forward(x.float()).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
|
||||||
|
assert dim % num_heads == 0
|
||||||
|
super(SelfAttention, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = 1.0 / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.to_qkv = nn.Linear(dim, dim * 3)
|
||||||
|
self.attn_dropout = nn.Dropout(attn_dropout)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.proj_dropout = nn.Dropout(proj_dropout)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
r"""x: [B, L, C].
|
||||||
|
mask: [*, L, L].
|
||||||
|
"""
|
||||||
|
b, l, _, n = *x.size(), self.num_heads
|
||||||
|
|
||||||
|
# compute query, key, and value
|
||||||
|
q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1)
|
||||||
|
q = q.reshape(l, b * n, -1).transpose(0, 1)
|
||||||
|
k = k.reshape(l, b * n, -1).transpose(0, 1)
|
||||||
|
v = v.reshape(l, b * n, -1).transpose(0, 1)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
attn = self.scale * torch.bmm(q, k.transpose(1, 2))
|
||||||
|
if mask is not None:
|
||||||
|
attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf'))
|
||||||
|
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
||||||
|
attn = self.attn_dropout(attn)
|
||||||
|
|
||||||
|
# gather context
|
||||||
|
x = torch.bmm(attn, v)
|
||||||
|
x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1)
|
||||||
|
|
||||||
|
# output
|
||||||
|
x = self.proj(x)
|
||||||
|
x = self.proj_dropout(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
|
||||||
|
super(AttentionBlock, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.norm1 = LayerNorm(dim)
|
||||||
|
self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout)
|
||||||
|
self.norm2 = LayerNorm(dim)
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim),
|
||||||
|
nn.Dropout(proj_dropout))
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
x = x + self.attn(self.norm1(x), mask)
|
||||||
|
x = x + self.mlp(self.norm2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VisionTransformer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
image_size=224,
|
||||||
|
patch_size=16,
|
||||||
|
dim=768,
|
||||||
|
out_dim=512,
|
||||||
|
num_heads=12,
|
||||||
|
num_layers=12,
|
||||||
|
attn_dropout=0.0,
|
||||||
|
proj_dropout=0.0,
|
||||||
|
embedding_dropout=0.0):
|
||||||
|
assert image_size % patch_size == 0
|
||||||
|
super(VisionTransformer, self).__init__()
|
||||||
|
self.image_size = image_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.dim = dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.num_patches = (image_size // patch_size)**2
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
gain = 1.0 / math.sqrt(dim)
|
||||||
|
self.patch_embedding = nn.Conv2d(
|
||||||
|
3, dim, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||||
|
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
||||||
|
self.pos_embedding = nn.Parameter(
|
||||||
|
gain * torch.randn(1, self.num_patches + 1, dim))
|
||||||
|
self.dropout = nn.Dropout(embedding_dropout)
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
self.pre_norm = LayerNorm(dim)
|
||||||
|
self.transformer = nn.Sequential(*[
|
||||||
|
AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
self.post_norm = LayerNorm(dim)
|
||||||
|
|
||||||
|
# head
|
||||||
|
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, dtype = x.size(0), self.head.dtype
|
||||||
|
x = x.type(dtype)
|
||||||
|
|
||||||
|
# patch-embedding
|
||||||
|
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c]
|
||||||
|
x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x],
|
||||||
|
dim=1)
|
||||||
|
x = self.dropout(x + self.pos_embedding.type(dtype))
|
||||||
|
x = self.pre_norm(x)
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
x = self.transformer(x)
|
||||||
|
|
||||||
|
# head
|
||||||
|
x = self.post_norm(x)
|
||||||
|
x = torch.mm(x[:, 0, :], self.head)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def fp16(self):
|
||||||
|
return self.apply(to_fp16)
|
||||||
|
|
||||||
|
|
||||||
|
class TextTransformer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
vocab_size,
|
||||||
|
text_len,
|
||||||
|
dim=512,
|
||||||
|
out_dim=512,
|
||||||
|
num_heads=8,
|
||||||
|
num_layers=12,
|
||||||
|
attn_dropout=0.0,
|
||||||
|
proj_dropout=0.0,
|
||||||
|
embedding_dropout=0.0):
|
||||||
|
super(TextTransformer, self).__init__()
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.text_len = text_len
|
||||||
|
self.dim = dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
self.token_embedding = nn.Embedding(vocab_size, dim)
|
||||||
|
self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim))
|
||||||
|
self.dropout = nn.Dropout(embedding_dropout)
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
self.transformer = nn.ModuleList([
|
||||||
|
AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
self.norm = LayerNorm(dim)
|
||||||
|
|
||||||
|
# head
|
||||||
|
gain = 1.0 / math.sqrt(dim)
|
||||||
|
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
||||||
|
|
||||||
|
# causal attention mask
|
||||||
|
self.register_buffer('attn_mask',
|
||||||
|
torch.tril(torch.ones(1, text_len, text_len)))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
eot, dtype = x.argmax(dim=-1), self.head.dtype
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
x = self.dropout(
|
||||||
|
self.token_embedding(x).type(dtype)
|
||||||
|
+ self.pos_embedding.type(dtype))
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
for block in self.transformer:
|
||||||
|
x = block(x, self.attn_mask)
|
||||||
|
|
||||||
|
# head
|
||||||
|
x = self.norm(x)
|
||||||
|
x = torch.mm(x[torch.arange(x.size(0)), eot], self.head)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def fp16(self):
|
||||||
|
return self.apply(to_fp16)
|
||||||
|
|
||||||
|
|
||||||
|
class CLIP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
embed_dim=512,
|
||||||
|
image_size=224,
|
||||||
|
patch_size=16,
|
||||||
|
vision_dim=768,
|
||||||
|
vision_heads=12,
|
||||||
|
vision_layers=12,
|
||||||
|
vocab_size=49408,
|
||||||
|
text_len=77,
|
||||||
|
text_dim=512,
|
||||||
|
text_heads=8,
|
||||||
|
text_layers=12,
|
||||||
|
attn_dropout=0.0,
|
||||||
|
proj_dropout=0.0,
|
||||||
|
embedding_dropout=0.0):
|
||||||
|
super(CLIP, self).__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.image_size = image_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.vision_dim = vision_dim
|
||||||
|
self.vision_heads = vision_heads
|
||||||
|
self.vision_layers = vision_layers
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.text_len = text_len
|
||||||
|
self.text_dim = text_dim
|
||||||
|
self.text_heads = text_heads
|
||||||
|
self.text_layers = text_layers
|
||||||
|
|
||||||
|
# models
|
||||||
|
self.visual = VisionTransformer(
|
||||||
|
image_size=image_size,
|
||||||
|
patch_size=patch_size,
|
||||||
|
dim=vision_dim,
|
||||||
|
out_dim=embed_dim,
|
||||||
|
num_heads=vision_heads,
|
||||||
|
num_layers=vision_layers,
|
||||||
|
attn_dropout=attn_dropout,
|
||||||
|
proj_dropout=proj_dropout,
|
||||||
|
embedding_dropout=embedding_dropout)
|
||||||
|
self.textual = TextTransformer(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
text_len=text_len,
|
||||||
|
dim=text_dim,
|
||||||
|
out_dim=embed_dim,
|
||||||
|
num_heads=text_heads,
|
||||||
|
num_layers=text_layers,
|
||||||
|
attn_dropout=attn_dropout,
|
||||||
|
proj_dropout=proj_dropout,
|
||||||
|
embedding_dropout=embedding_dropout)
|
||||||
|
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
||||||
|
|
||||||
|
def forward(self, imgs, txt_tokens):
|
||||||
|
r"""imgs: [B, C, H, W] of torch.float32.
|
||||||
|
txt_tokens: [B, T] of torch.long.
|
||||||
|
"""
|
||||||
|
xi = self.visual(imgs)
|
||||||
|
xt = self.textual(txt_tokens)
|
||||||
|
|
||||||
|
# normalize features
|
||||||
|
xi = F.normalize(xi, p=2, dim=1)
|
||||||
|
xt = F.normalize(xt, p=2, dim=1)
|
||||||
|
|
||||||
|
# logits
|
||||||
|
scale = self.log_scale.exp()
|
||||||
|
logits_i2t = scale * torch.mm(xi, xt.t())
|
||||||
|
logits_t2i = scale * torch.mm(xt, xi.t())
|
||||||
|
return logits_i2t, logits_t2i
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
# embeddings
|
||||||
|
nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
|
||||||
|
nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1)
|
||||||
|
|
||||||
|
# attentions
|
||||||
|
for modality in ['visual', 'textual']:
|
||||||
|
dim = self.vision_dim if modality == 'visual' else 'textual'
|
||||||
|
transformer = getattr(self, modality).transformer
|
||||||
|
proj_gain = (1.0 / math.sqrt(dim)) * (
|
||||||
|
1.0 / math.sqrt(2 * transformer.num_layers))
|
||||||
|
attn_gain = 1.0 / math.sqrt(dim)
|
||||||
|
mlp_gain = 1.0 / math.sqrt(2.0 * dim)
|
||||||
|
for block in transformer.layers:
|
||||||
|
nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
|
||||||
|
nn.init.normal_(block.attn.proj.weight, std=proj_gain)
|
||||||
|
nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
|
||||||
|
nn.init.normal_(block.mlp[2].weight, std=proj_gain)
|
||||||
|
|
||||||
|
def fp16(self):
|
||||||
|
return self.apply(to_fp16)
|
||||||
322
modelscope/models/multi_modal/multi_stage_diffusion/decoder.py
Normal file
322
modelscope/models/multi_modal/multi_stage_diffusion/decoder.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
__all__ = ['Decoder']
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoidal_embedding(timesteps, dim):
|
||||||
|
# check input
|
||||||
|
half = dim // 2
|
||||||
|
timesteps = timesteps.float()
|
||||||
|
|
||||||
|
# compute sinusoidal embedding
|
||||||
|
sinusoid = torch.outer(
|
||||||
|
timesteps, torch.pow(10000,
|
||||||
|
-torch.arange(half).to(timesteps).div(half)))
|
||||||
|
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||||
|
if dim % 2 != 0:
|
||||||
|
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Resample(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_dim, out_dim, scale_factor, use_conv=False):
|
||||||
|
assert scale_factor in [0.5, 1.0, 2.0]
|
||||||
|
super(Resample, self).__init__()
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
self.use_conv = use_conv
|
||||||
|
|
||||||
|
# layers
|
||||||
|
if scale_factor == 2.0:
|
||||||
|
self.resample = nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
|
||||||
|
nn.Conv2d(in_dim, out_dim, 3, padding=1)
|
||||||
|
if use_conv else nn.Identity())
|
||||||
|
elif scale_factor == 0.5:
|
||||||
|
self.resample = nn.Conv2d(
|
||||||
|
in_dim, out_dim, 3, stride=2,
|
||||||
|
padding=1) if use_conv else nn.AvgPool2d(
|
||||||
|
kernel_size=2, stride=2)
|
||||||
|
else:
|
||||||
|
self.resample = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.resample(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_dim,
|
||||||
|
embed_dim,
|
||||||
|
out_dim,
|
||||||
|
use_scale_shift_norm=True,
|
||||||
|
scale_factor=1.0,
|
||||||
|
dropout=0.0):
|
||||||
|
super(ResidualBlock, self).__init__()
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.use_scale_shift_norm = use_scale_shift_norm
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.layer1 = nn.Sequential(
|
||||||
|
nn.GroupNorm(32, in_dim), nn.SiLU(),
|
||||||
|
nn.Conv2d(in_dim, out_dim, 3, padding=1))
|
||||||
|
self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False)
|
||||||
|
self.embedding = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim,
|
||||||
|
out_dim * 2 if use_scale_shift_norm else out_dim))
|
||||||
|
self.layer2 = nn.Sequential(
|
||||||
|
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
||||||
|
nn.Conv2d(out_dim, out_dim, 3, padding=1))
|
||||||
|
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
|
||||||
|
in_dim, out_dim, 1)
|
||||||
|
|
||||||
|
# zero out the last layer params
|
||||||
|
nn.init.zeros_(self.layer2[-1].weight)
|
||||||
|
|
||||||
|
def forward(self, x, e):
|
||||||
|
identity = self.resample(x)
|
||||||
|
x = self.layer1[-1](self.resample(self.layer1[:-1](x)))
|
||||||
|
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
|
||||||
|
if self.use_scale_shift_norm:
|
||||||
|
scale, shift = e.chunk(2, dim=1)
|
||||||
|
x = self.layer2[0](x) * (1 + scale) + shift
|
||||||
|
x = self.layer2[1:](x)
|
||||||
|
else:
|
||||||
|
x = x + e
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = x + self.shortcut(identity)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
|
||||||
|
# consider head_dim first, then num_heads
|
||||||
|
num_heads = dim // head_dim if head_dim else num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
assert num_heads * head_dim == dim
|
||||||
|
super(AttentionBlock, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.context_dim = context_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.scale = math.pow(head_dim, -0.25)
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.norm = nn.GroupNorm(32, dim)
|
||||||
|
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||||
|
if context_dim is not None:
|
||||||
|
self.context_kv = nn.Linear(context_dim, dim * 2)
|
||||||
|
self.proj = nn.Conv2d(dim, dim, 1)
|
||||||
|
|
||||||
|
# zero out the last layer params
|
||||||
|
nn.init.zeros_(self.proj.weight)
|
||||||
|
|
||||||
|
def forward(self, x, context=None):
|
||||||
|
r"""x: [B, C, H, W].
|
||||||
|
context: [B, L, C] or None.
|
||||||
|
"""
|
||||||
|
identity = x
|
||||||
|
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
|
||||||
|
|
||||||
|
# compute query, key, value
|
||||||
|
x = self.norm(x)
|
||||||
|
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
|
||||||
|
if context is not None:
|
||||||
|
ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
|
||||||
|
d).permute(0, 2, 3,
|
||||||
|
1).chunk(
|
||||||
|
2, dim=1)
|
||||||
|
k = torch.cat([ck, k], dim=-1)
|
||||||
|
v = torch.cat([cv, v], dim=-1)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
|
||||||
|
attn = F.softmax(attn, dim=-1)
|
||||||
|
|
||||||
|
# gather context
|
||||||
|
x = torch.matmul(v, attn.transpose(-1, -2))
|
||||||
|
x = x.reshape(b, c, h, w)
|
||||||
|
|
||||||
|
# output
|
||||||
|
x = self.proj(x)
|
||||||
|
return x + identity
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_dim=3,
|
||||||
|
dim=512,
|
||||||
|
y_dim=512,
|
||||||
|
context_dim=512,
|
||||||
|
out_dim=6,
|
||||||
|
dim_mult=[1, 2, 3, 4],
|
||||||
|
num_heads=None,
|
||||||
|
head_dim=64,
|
||||||
|
num_res_blocks=3,
|
||||||
|
attn_scales=[1 / 2, 1 / 4, 1 / 8],
|
||||||
|
resblock_resample=True,
|
||||||
|
use_scale_shift_norm=True,
|
||||||
|
dropout=0.1):
|
||||||
|
embed_dim = dim * 4
|
||||||
|
super(Decoder, self).__init__()
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.dim = dim
|
||||||
|
self.y_dim = y_dim
|
||||||
|
self.context_dim = context_dim
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.dim_mult = dim_mult
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.attn_scales = attn_scales
|
||||||
|
self.resblock_resample = resblock_resample
|
||||||
|
self.use_scale_shift_norm = use_scale_shift_norm
|
||||||
|
|
||||||
|
# params
|
||||||
|
enc_dims = [dim * u for u in [1] + dim_mult]
|
||||||
|
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||||
|
shortcut_dims = []
|
||||||
|
scale = 1.0
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
self.time_embedding = nn.Sequential(
|
||||||
|
nn.Linear(dim, embed_dim), nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim, embed_dim))
|
||||||
|
self.y_embedding = nn.Sequential(
|
||||||
|
nn.Linear(y_dim, embed_dim), nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim, embed_dim))
|
||||||
|
self.context_embedding = nn.Sequential(
|
||||||
|
nn.Linear(y_dim, embed_dim), nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim, context_dim * 4))
|
||||||
|
|
||||||
|
# encoder
|
||||||
|
self.encoder = nn.ModuleList(
|
||||||
|
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
|
||||||
|
shortcut_dims.append(dim)
|
||||||
|
for i, (in_dim,
|
||||||
|
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
||||||
|
for j in range(num_res_blocks):
|
||||||
|
# residual (+attention) blocks
|
||||||
|
block = nn.ModuleList([
|
||||||
|
ResidualBlock(in_dim, embed_dim, out_dim,
|
||||||
|
use_scale_shift_norm, 1.0, dropout)
|
||||||
|
])
|
||||||
|
if scale in attn_scales:
|
||||||
|
block.append(
|
||||||
|
AttentionBlock(out_dim, context_dim, num_heads,
|
||||||
|
head_dim))
|
||||||
|
in_dim = out_dim
|
||||||
|
self.encoder.append(block)
|
||||||
|
shortcut_dims.append(out_dim)
|
||||||
|
|
||||||
|
# downsample
|
||||||
|
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
|
||||||
|
if resblock_resample:
|
||||||
|
downsample = ResidualBlock(out_dim, embed_dim, out_dim,
|
||||||
|
use_scale_shift_norm, 0.5,
|
||||||
|
dropout)
|
||||||
|
else:
|
||||||
|
downsample = Resample(
|
||||||
|
out_dim, out_dim, 0.5, use_conv=True)
|
||||||
|
shortcut_dims.append(out_dim)
|
||||||
|
scale /= 2.0
|
||||||
|
self.encoder.append(downsample)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.middle = nn.ModuleList([
|
||||||
|
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
|
||||||
|
1.0, dropout),
|
||||||
|
AttentionBlock(out_dim, context_dim, num_heads, head_dim),
|
||||||
|
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
|
||||||
|
1.0, dropout)
|
||||||
|
])
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
self.decoder = nn.ModuleList()
|
||||||
|
for i, (in_dim,
|
||||||
|
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
|
||||||
|
for j in range(num_res_blocks + 1):
|
||||||
|
# residual (+attention) blocks
|
||||||
|
block = nn.ModuleList([
|
||||||
|
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
|
||||||
|
out_dim, use_scale_shift_norm, 1.0, dropout)
|
||||||
|
])
|
||||||
|
if scale in attn_scales:
|
||||||
|
block.append(
|
||||||
|
AttentionBlock(out_dim, context_dim, num_heads,
|
||||||
|
head_dim))
|
||||||
|
in_dim = out_dim
|
||||||
|
|
||||||
|
# upsample
|
||||||
|
if i != len(dim_mult) - 1 and j == num_res_blocks:
|
||||||
|
if resblock_resample:
|
||||||
|
upsample = ResidualBlock(out_dim, embed_dim, out_dim,
|
||||||
|
use_scale_shift_norm, 2.0,
|
||||||
|
dropout)
|
||||||
|
else:
|
||||||
|
upsample = Resample(
|
||||||
|
out_dim, out_dim, 2.0, use_conv=True)
|
||||||
|
scale *= 2.0
|
||||||
|
block.append(upsample)
|
||||||
|
self.decoder.append(block)
|
||||||
|
|
||||||
|
# head
|
||||||
|
self.head = nn.Sequential(
|
||||||
|
nn.GroupNorm(32, out_dim), nn.SiLU(),
|
||||||
|
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
|
||||||
|
|
||||||
|
# zero out the last layer params
|
||||||
|
nn.init.zeros_(self.head[-1].weight)
|
||||||
|
|
||||||
|
def forward(self, x, t, y):
|
||||||
|
# embeddings
|
||||||
|
e = self.time_embedding(sinusoidal_embedding(
|
||||||
|
t, self.dim)) + self.y_embedding(y)
|
||||||
|
context = self.context_embedding(y).view(-1, 4, self.context_dim)
|
||||||
|
|
||||||
|
# encoder
|
||||||
|
xs = []
|
||||||
|
for block in self.encoder:
|
||||||
|
x = self._forward_single(block, x, e, context)
|
||||||
|
xs.append(x)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
for block in self.middle:
|
||||||
|
x = self._forward_single(block, x, e, context)
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
for block in self.decoder:
|
||||||
|
x = torch.cat([x, xs.pop()], dim=1)
|
||||||
|
x = self._forward_single(block, x, e, context)
|
||||||
|
|
||||||
|
# head
|
||||||
|
x = self.head(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _forward_single(self, module, x, e, context):
|
||||||
|
if isinstance(module, ResidualBlock):
|
||||||
|
x = module(x, e)
|
||||||
|
elif isinstance(module, AttentionBlock):
|
||||||
|
x = module(x, context)
|
||||||
|
elif isinstance(module, nn.ModuleList):
|
||||||
|
for block in module:
|
||||||
|
x = self._forward_single(block, x, e, context)
|
||||||
|
else:
|
||||||
|
x = module(x)
|
||||||
|
return x
|
||||||
@@ -0,0 +1,641 @@
|
|||||||
|
# The implementation here is modified based on latent diffusion, publicly available
|
||||||
|
# at https://github.com/CompVis/latent-diffusion.
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
__all__ = ['GaussianDiffusion', 'beta_schedule']
|
||||||
|
|
||||||
|
|
||||||
|
def kl_divergence(mu1, logvar1, mu2, logvar2):
|
||||||
|
u1 = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
|
||||||
|
u2 = ((mu1 - mu2)**2) * torch.exp(-logvar2)
|
||||||
|
return 0.5 * (u1 + u2)
|
||||||
|
|
||||||
|
|
||||||
|
def standard_normal_cdf(x):
|
||||||
|
r"""A fast approximation of the cumulative distribution function of the standard normal.
|
||||||
|
"""
|
||||||
|
return 0.5 * (1.0 + torch.tanh(
|
||||||
|
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||||
|
|
||||||
|
|
||||||
|
def discretized_gaussian_log_likelihood(x0, mean, log_scale):
|
||||||
|
assert x0.shape == mean.shape == log_scale.shape
|
||||||
|
cx = x0 - mean
|
||||||
|
inv_stdv = torch.exp(-log_scale)
|
||||||
|
cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
|
||||||
|
cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
|
||||||
|
log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
|
||||||
|
log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
|
||||||
|
cdf_delta = cdf_plus - cdf_min
|
||||||
|
log_probs = torch.where(
|
||||||
|
x0 < -0.999, log_cdf_plus,
|
||||||
|
torch.where(x0 > 0.999, log_one_minus_cdf_min,
|
||||||
|
torch.log(cdf_delta.clamp(min=1e-12))))
|
||||||
|
assert log_probs.shape == x0.shape
|
||||||
|
return log_probs
|
||||||
|
|
||||||
|
|
||||||
|
def _i(tensor, t, x):
|
||||||
|
r"""Index tensor using t and format the output according to x.
|
||||||
|
"""
|
||||||
|
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
|
||||||
|
return tensor[t].view(shape).to(x)
|
||||||
|
|
||||||
|
|
||||||
|
def beta_schedule(schedule,
|
||||||
|
num_timesteps=1000,
|
||||||
|
init_beta=None,
|
||||||
|
last_beta=None):
|
||||||
|
if schedule == 'linear':
|
||||||
|
scale = 1000.0 / num_timesteps
|
||||||
|
init_beta = init_beta or scale * 0.0001
|
||||||
|
last_beta = last_beta or scale * 0.02
|
||||||
|
return torch.linspace(
|
||||||
|
init_beta, last_beta, num_timesteps, dtype=torch.float64)
|
||||||
|
elif schedule == 'quadratic':
|
||||||
|
init_beta = init_beta or 0.0015
|
||||||
|
last_beta = last_beta or 0.0195
|
||||||
|
return torch.linspace(
|
||||||
|
init_beta**0.5, last_beta**0.5, num_timesteps,
|
||||||
|
dtype=torch.float64)**2
|
||||||
|
elif schedule == 'cosine':
|
||||||
|
betas = []
|
||||||
|
for step in range(num_timesteps):
|
||||||
|
t1 = step / num_timesteps
|
||||||
|
t2 = (step + 1) / num_timesteps
|
||||||
|
fn_t1 = math.cos((t1 + 0.008) / 1.008 * math.pi / 2)**2
|
||||||
|
fn_t2 = math.cos((t2 + 0.008) / 1.008 * math.pi / 2)**2
|
||||||
|
betas.append(min(1.0 - fn_t2 / fn_t1, 0.999))
|
||||||
|
return torch.tensor(betas, dtype=torch.float64)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unsupported schedule: {schedule}')
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianDiffusion(object):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
betas,
|
||||||
|
mean_type='eps',
|
||||||
|
var_type='learned_range',
|
||||||
|
loss_type='mse',
|
||||||
|
rescale_timesteps=False):
|
||||||
|
# check input
|
||||||
|
if not isinstance(betas, torch.DoubleTensor):
|
||||||
|
betas = torch.tensor(betas, dtype=torch.float64)
|
||||||
|
assert min(betas) > 0 and max(betas) <= 1
|
||||||
|
assert mean_type in ['x0', 'x_{t-1}', 'eps']
|
||||||
|
assert var_type in [
|
||||||
|
'learned', 'learned_range', 'fixed_large', 'fixed_small'
|
||||||
|
]
|
||||||
|
assert loss_type in [
|
||||||
|
'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1'
|
||||||
|
]
|
||||||
|
self.betas = betas
|
||||||
|
self.num_timesteps = len(betas)
|
||||||
|
self.mean_type = mean_type
|
||||||
|
self.var_type = var_type
|
||||||
|
self.loss_type = loss_type
|
||||||
|
self.rescale_timesteps = rescale_timesteps
|
||||||
|
|
||||||
|
# alphas
|
||||||
|
alphas = 1 - self.betas
|
||||||
|
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||||
|
self.alphas_cumprod_prev = torch.cat(
|
||||||
|
[alphas.new_ones([1]), self.alphas_cumprod[:-1]])
|
||||||
|
self.alphas_cumprod_next = torch.cat(
|
||||||
|
[self.alphas_cumprod[1:],
|
||||||
|
alphas.new_zeros([1])])
|
||||||
|
|
||||||
|
# q(x_t | x_{t-1})
|
||||||
|
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||||
|
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
|
||||||
|
- self.alphas_cumprod)
|
||||||
|
self.log_one_minus_alphas_cumprod = torch.log(1.0
|
||||||
|
- self.alphas_cumprod)
|
||||||
|
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
|
||||||
|
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
|
||||||
|
- 1)
|
||||||
|
|
||||||
|
# q(x_{t-1} | x_t, x_0)
|
||||||
|
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
|
||||||
|
1.0 - self.alphas_cumprod)
|
||||||
|
self.posterior_log_variance_clipped = torch.log(
|
||||||
|
self.posterior_variance.clamp(1e-20))
|
||||||
|
self.posterior_mean_coef1 = betas * torch.sqrt(
|
||||||
|
self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||||
|
self.posterior_mean_coef2 = (
|
||||||
|
1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
|
||||||
|
1.0 - self.alphas_cumprod)
|
||||||
|
|
||||||
|
def q_sample(self, x0, t, noise=None):
|
||||||
|
r"""Sample from q(x_t | x_0).
|
||||||
|
"""
|
||||||
|
noise = torch.randn_like(x0) if noise is None else noise
|
||||||
|
u1 = _i(self.sqrt_alphas_cumprod, t, x0) * x0
|
||||||
|
u2 = _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise
|
||||||
|
return u1 + u2
|
||||||
|
|
||||||
|
def q_mean_variance(self, x0, t):
|
||||||
|
r"""Distribution of q(x_t | x_0).
|
||||||
|
"""
|
||||||
|
mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
|
||||||
|
var = _i(1.0 - self.alphas_cumprod, t, x0)
|
||||||
|
log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
|
||||||
|
return mu, var, log_var
|
||||||
|
|
||||||
|
def q_posterior_mean_variance(self, x0, xt, t):
|
||||||
|
r"""Distribution of q(x_{t-1} | x_t, x_0).
|
||||||
|
"""
|
||||||
|
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
|
||||||
|
self.posterior_mean_coef2, t, xt) * xt
|
||||||
|
var = _i(self.posterior_variance, t, xt)
|
||||||
|
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||||
|
return mu, var, log_var
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample(self,
|
||||||
|
xt,
|
||||||
|
t,
|
||||||
|
model,
|
||||||
|
model_kwargs={},
|
||||||
|
clamp=None,
|
||||||
|
percentile=None,
|
||||||
|
condition_fn=None,
|
||||||
|
guide_scale=None):
|
||||||
|
r"""Sample from p(x_{t-1} | x_t).
|
||||||
|
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||||
|
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||||
|
"""
|
||||||
|
# predict distribution of p(x_{t-1} | x_t)
|
||||||
|
mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||||
|
clamp, percentile,
|
||||||
|
guide_scale)
|
||||||
|
|
||||||
|
# random sample (with optional conditional function)
|
||||||
|
noise = torch.randn_like(xt)
|
||||||
|
shape = (-1, *((1, ) * (xt.ndim - 1)))
|
||||||
|
mask = t.ne(0).float().view(shape) # no noise when t == 0
|
||||||
|
if condition_fn is not None:
|
||||||
|
grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
|
||||||
|
mu = mu.float() + var * grad.float()
|
||||||
|
xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
|
||||||
|
return xt_1, x0
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_loop(self,
|
||||||
|
noise,
|
||||||
|
model,
|
||||||
|
model_kwargs={},
|
||||||
|
clamp=None,
|
||||||
|
percentile=None,
|
||||||
|
condition_fn=None,
|
||||||
|
guide_scale=None):
|
||||||
|
r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
|
||||||
|
"""
|
||||||
|
# prepare input
|
||||||
|
b = noise.size(0)
|
||||||
|
xt = noise
|
||||||
|
|
||||||
|
# diffusion process
|
||||||
|
for step in torch.arange(self.num_timesteps).flip(0):
|
||||||
|
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||||
|
xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
|
||||||
|
percentile, condition_fn, guide_scale)
|
||||||
|
return xt
|
||||||
|
|
||||||
|
def p_mean_variance(self,
|
||||||
|
xt,
|
||||||
|
t,
|
||||||
|
model,
|
||||||
|
model_kwargs={},
|
||||||
|
clamp=None,
|
||||||
|
percentile=None,
|
||||||
|
guide_scale=None):
|
||||||
|
r"""Distribution of p(x_{t-1} | x_t).
|
||||||
|
"""
|
||||||
|
# predict distribution
|
||||||
|
if guide_scale is None:
|
||||||
|
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
||||||
|
else:
|
||||||
|
# classifier-free guidance
|
||||||
|
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
|
||||||
|
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
|
||||||
|
y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
|
||||||
|
u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
|
||||||
|
cond = self.var_type.startswith('fixed')
|
||||||
|
dim = y_out.size(1) if cond else y_out.size(1) // 2
|
||||||
|
u1 = u_out[:, :dim]
|
||||||
|
u2 = guide_scale * (y_out[:, :dim] - u_out[:, :dim])
|
||||||
|
out = torch.cat([u1 + u2, y_out[:, dim:]], dim=1)
|
||||||
|
|
||||||
|
# compute variance
|
||||||
|
if self.var_type == 'learned':
|
||||||
|
out, log_var = out.chunk(2, dim=1)
|
||||||
|
var = torch.exp(log_var)
|
||||||
|
elif self.var_type == 'learned_range':
|
||||||
|
out, fraction = out.chunk(2, dim=1)
|
||||||
|
min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||||
|
max_log_var = _i(torch.log(self.betas), t, xt)
|
||||||
|
fraction = (fraction + 1) / 2.0
|
||||||
|
log_var = fraction * max_log_var + (1 - fraction) * min_log_var
|
||||||
|
var = torch.exp(log_var)
|
||||||
|
elif self.var_type == 'fixed_large':
|
||||||
|
var = _i(
|
||||||
|
torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
|
||||||
|
xt)
|
||||||
|
log_var = torch.log(var)
|
||||||
|
elif self.var_type == 'fixed_small':
|
||||||
|
var = _i(self.posterior_variance, t, xt)
|
||||||
|
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||||
|
|
||||||
|
# compute mean and x0
|
||||||
|
if self.mean_type == 'x_{t-1}':
|
||||||
|
mu = out # x_{t-1}
|
||||||
|
u1 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu
|
||||||
|
u2 = _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
|
||||||
|
xt) * xt
|
||||||
|
x0 = u1 - u2
|
||||||
|
elif self.mean_type == 'x0':
|
||||||
|
x0 = out
|
||||||
|
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||||
|
elif self.mean_type == 'eps':
|
||||||
|
u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
|
||||||
|
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
|
||||||
|
x0 = u1 - u2
|
||||||
|
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||||
|
|
||||||
|
# restrict the range of x0
|
||||||
|
if percentile is not None:
|
||||||
|
assert percentile > 0 and percentile <= 1 # e.g., 0.995
|
||||||
|
s = torch.quantile(
|
||||||
|
x0.flatten(1).abs(), percentile,
|
||||||
|
dim=1).clamp_(1.0).view(-1, 1, 1, 1)
|
||||||
|
x0 = torch.min(s, torch.max(-s, x0)) / s
|
||||||
|
elif clamp is not None:
|
||||||
|
x0 = x0.clamp(-clamp, clamp)
|
||||||
|
return mu, var, log_var, x0
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def ddim_sample(self,
|
||||||
|
xt,
|
||||||
|
t,
|
||||||
|
model,
|
||||||
|
model_kwargs={},
|
||||||
|
clamp=None,
|
||||||
|
percentile=None,
|
||||||
|
condition_fn=None,
|
||||||
|
guide_scale=None,
|
||||||
|
ddim_timesteps=20,
|
||||||
|
eta=0.0):
|
||||||
|
r"""Sample from p(x_{t-1} | x_t) using DDIM.
|
||||||
|
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||||
|
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||||
|
"""
|
||||||
|
stride = self.num_timesteps // ddim_timesteps
|
||||||
|
|
||||||
|
# predict distribution of p(x_{t-1} | x_t)
|
||||||
|
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
|
||||||
|
percentile, guide_scale)
|
||||||
|
if condition_fn is not None:
|
||||||
|
# x0 -> eps
|
||||||
|
alpha = _i(self.alphas_cumprod, t, xt)
|
||||||
|
u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0)
|
||||||
|
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||||
|
eps = u1 / u2
|
||||||
|
eps = eps - (1 - alpha).sqrt() * condition_fn(
|
||||||
|
xt, self._scale_timesteps(t), **model_kwargs)
|
||||||
|
|
||||||
|
# eps -> x0
|
||||||
|
u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
|
||||||
|
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
||||||
|
x0 = u1 - u2
|
||||||
|
|
||||||
|
# derive variables
|
||||||
|
u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0)
|
||||||
|
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||||
|
eps = u1 / u2
|
||||||
|
alphas = _i(self.alphas_cumprod, t, xt)
|
||||||
|
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
||||||
|
u1 = (1 - alphas_prev) / (1 - alphas)
|
||||||
|
u2 = (1 - alphas / alphas_prev)
|
||||||
|
sigmas = eta * torch.sqrt(u1 * u2)
|
||||||
|
|
||||||
|
# random sample
|
||||||
|
noise = torch.randn_like(xt)
|
||||||
|
direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
|
||||||
|
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
||||||
|
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
|
||||||
|
return xt_1, x0
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def ddim_sample_loop(self,
|
||||||
|
noise,
|
||||||
|
model,
|
||||||
|
model_kwargs={},
|
||||||
|
clamp=None,
|
||||||
|
percentile=None,
|
||||||
|
condition_fn=None,
|
||||||
|
guide_scale=None,
|
||||||
|
ddim_timesteps=20,
|
||||||
|
eta=0.0):
|
||||||
|
# prepare input
|
||||||
|
b = noise.size(0)
|
||||||
|
xt = noise
|
||||||
|
|
||||||
|
# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
|
||||||
|
steps = (1 + torch.arange(0, self.num_timesteps,
|
||||||
|
self.num_timesteps // ddim_timesteps)).clamp(
|
||||||
|
0, self.num_timesteps - 1).flip(0)
|
||||||
|
for step in steps:
|
||||||
|
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||||
|
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
|
||||||
|
percentile, condition_fn, guide_scale,
|
||||||
|
ddim_timesteps, eta)
|
||||||
|
return xt
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def ddim_reverse_sample(self,
|
||||||
|
xt,
|
||||||
|
t,
|
||||||
|
model,
|
||||||
|
model_kwargs={},
|
||||||
|
clamp=None,
|
||||||
|
percentile=None,
|
||||||
|
guide_scale=None,
|
||||||
|
ddim_timesteps=20):
|
||||||
|
r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
|
||||||
|
"""
|
||||||
|
stride = self.num_timesteps // ddim_timesteps
|
||||||
|
|
||||||
|
# predict distribution of p(x_{t-1} | x_t)
|
||||||
|
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
|
||||||
|
percentile, guide_scale)
|
||||||
|
|
||||||
|
# derive variables
|
||||||
|
u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0)
|
||||||
|
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||||
|
eps = u1 / u2
|
||||||
|
|
||||||
|
alphas_next = _i(
|
||||||
|
torch.cat(
|
||||||
|
[self.alphas_cumprod,
|
||||||
|
self.alphas_cumprod.new_zeros([1])]),
|
||||||
|
(t + stride).clamp(0, self.num_timesteps), xt)
|
||||||
|
|
||||||
|
# reverse sample
|
||||||
|
mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
|
||||||
|
return mu, x0
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def ddim_reverse_sample_loop(self,
|
||||||
|
x0,
|
||||||
|
model,
|
||||||
|
model_kwargs={},
|
||||||
|
clamp=None,
|
||||||
|
percentile=None,
|
||||||
|
guide_scale=None,
|
||||||
|
ddim_timesteps=20):
|
||||||
|
# prepare input
|
||||||
|
b = x0.size(0)
|
||||||
|
xt = x0
|
||||||
|
|
||||||
|
# reconstruction steps
|
||||||
|
steps = torch.arange(0, self.num_timesteps,
|
||||||
|
self.num_timesteps // ddim_timesteps)
|
||||||
|
for step in steps:
|
||||||
|
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||||
|
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
|
||||||
|
percentile, guide_scale,
|
||||||
|
ddim_timesteps)
|
||||||
|
return xt
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def plms_sample(self,
|
||||||
|
xt,
|
||||||
|
t,
|
||||||
|
model,
|
||||||
|
model_kwargs={},
|
||||||
|
clamp=None,
|
||||||
|
percentile=None,
|
||||||
|
condition_fn=None,
|
||||||
|
guide_scale=None,
|
||||||
|
plms_timesteps=20):
|
||||||
|
r"""Sample from p(x_{t-1} | x_t) using PLMS.
|
||||||
|
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||||
|
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||||
|
"""
|
||||||
|
stride = self.num_timesteps // plms_timesteps
|
||||||
|
|
||||||
|
# function for compute eps
|
||||||
|
def compute_eps(xt, t):
|
||||||
|
# predict distribution of p(x_{t-1} | x_t)
|
||||||
|
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||||
|
clamp, percentile, guide_scale)
|
||||||
|
|
||||||
|
# condition
|
||||||
|
if condition_fn is not None:
|
||||||
|
# x0 -> eps
|
||||||
|
alpha = _i(self.alphas_cumprod, t, xt)
|
||||||
|
u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0)
|
||||||
|
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||||
|
eps = u1 / u2
|
||||||
|
eps = eps - (1 - alpha).sqrt() * condition_fn(
|
||||||
|
xt, self._scale_timesteps(t), **model_kwargs)
|
||||||
|
|
||||||
|
# eps -> x0
|
||||||
|
u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
|
||||||
|
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
||||||
|
x0 = u1 - u2
|
||||||
|
|
||||||
|
# derive eps
|
||||||
|
u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0)
|
||||||
|
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||||
|
eps = u1 / u2
|
||||||
|
return eps
|
||||||
|
|
||||||
|
# function for compute x_0 and x_{t-1}
|
||||||
|
def compute_x0(eps, t):
|
||||||
|
# eps -> x0
|
||||||
|
u1 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt
|
||||||
|
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
||||||
|
x0 = u1 - u2
|
||||||
|
|
||||||
|
# deterministic sample
|
||||||
|
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
||||||
|
direction = torch.sqrt(1 - alphas_prev) * eps
|
||||||
|
# mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
||||||
|
xt_1 = torch.sqrt(alphas_prev) * x0 + direction
|
||||||
|
return xt_1, x0
|
||||||
|
|
||||||
|
# PLMS sample
|
||||||
|
eps = compute_eps(xt, t)
|
||||||
|
if len(eps_cache) == 0:
|
||||||
|
# 2nd order pseudo improved Euler
|
||||||
|
xt_1, x0 = compute_x0(eps, t)
|
||||||
|
eps_next = compute_eps(xt_1, (t - stride).clamp(0))
|
||||||
|
eps_prime = (eps + eps_next) / 2.0
|
||||||
|
elif len(eps_cache) == 1:
|
||||||
|
# 2nd order pseudo linear multistep (Adams-Bashforth)
|
||||||
|
eps_prime = (3 * eps - eps_cache[-1]) / 2.0
|
||||||
|
elif len(eps_cache) == 2:
|
||||||
|
# 3nd order pseudo linear multistep (Adams-Bashforth)
|
||||||
|
eps_prime = (23 * eps - 16 * eps_cache[-1]
|
||||||
|
+ 5 * eps_cache[-2]) / 12.0
|
||||||
|
elif len(eps_cache) >= 3:
|
||||||
|
# 4nd order pseudo linear multistep (Adams-Bashforth)
|
||||||
|
eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
|
||||||
|
- 9 * eps_cache[-3]) / 24.0
|
||||||
|
xt_1, x0 = compute_x0(eps_prime, t)
|
||||||
|
return xt_1, x0, eps
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def plms_sample_loop(self,
|
||||||
|
noise,
|
||||||
|
model,
|
||||||
|
model_kwargs={},
|
||||||
|
clamp=None,
|
||||||
|
percentile=None,
|
||||||
|
condition_fn=None,
|
||||||
|
guide_scale=None,
|
||||||
|
plms_timesteps=20):
|
||||||
|
# prepare input
|
||||||
|
b = noise.size(0)
|
||||||
|
xt = noise
|
||||||
|
|
||||||
|
# diffusion process
|
||||||
|
steps = (1 + torch.arange(0, self.num_timesteps,
|
||||||
|
self.num_timesteps // plms_timesteps)).clamp(
|
||||||
|
0, self.num_timesteps - 1).flip(0)
|
||||||
|
eps_cache = []
|
||||||
|
for step in steps:
|
||||||
|
# PLMS sampling step
|
||||||
|
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||||
|
xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
|
||||||
|
percentile, condition_fn,
|
||||||
|
guide_scale, plms_timesteps,
|
||||||
|
eps_cache)
|
||||||
|
|
||||||
|
# update eps cache
|
||||||
|
eps_cache.append(eps)
|
||||||
|
if len(eps_cache) >= 4:
|
||||||
|
eps_cache.pop(0)
|
||||||
|
return xt
|
||||||
|
|
||||||
|
def loss(self, x0, t, model, model_kwargs={}, noise=None, input_x0=None):
|
||||||
|
noise = torch.randn_like(x0) if noise is None else noise
|
||||||
|
input_x0 = x0 if input_x0 is None else input_x0
|
||||||
|
xt = self.q_sample(input_x0, t, noise=noise)
|
||||||
|
|
||||||
|
# compute loss
|
||||||
|
if self.loss_type in ['kl', 'rescaled_kl']:
|
||||||
|
loss, _ = self.variational_lower_bound(x0, xt, t, model,
|
||||||
|
model_kwargs)
|
||||||
|
if self.loss_type == 'rescaled_kl':
|
||||||
|
loss = loss * self.num_timesteps
|
||||||
|
elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
|
||||||
|
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
||||||
|
|
||||||
|
# VLB for variation
|
||||||
|
loss_vlb = 0.0
|
||||||
|
if self.var_type in ['learned', 'learned_range']:
|
||||||
|
out, var = out.chunk(2, dim=1)
|
||||||
|
frozen = torch.cat([
|
||||||
|
out.detach(), var
|
||||||
|
], dim=1) # learn var without affecting the prediction of mean
|
||||||
|
loss_vlb, _ = self.variational_lower_bound(
|
||||||
|
x0, xt, t, model=lambda *args, **kwargs: frozen)
|
||||||
|
if self.loss_type.startswith('rescaled_'):
|
||||||
|
loss_vlb = loss_vlb * self.num_timesteps / 1000.0
|
||||||
|
|
||||||
|
# MSE/L1 for x0/eps
|
||||||
|
target = {
|
||||||
|
'eps': noise,
|
||||||
|
'x0': x0,
|
||||||
|
'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
|
||||||
|
}[self.mean_type]
|
||||||
|
loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2
|
||||||
|
).abs().flatten(1).mean(dim=1)
|
||||||
|
|
||||||
|
# total loss
|
||||||
|
loss = loss + loss_vlb
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def variational_lower_bound(self,
|
||||||
|
x0,
|
||||||
|
xt,
|
||||||
|
t,
|
||||||
|
model,
|
||||||
|
model_kwargs={},
|
||||||
|
clamp=None,
|
||||||
|
percentile=None):
|
||||||
|
# compute groundtruth and predicted distributions
|
||||||
|
mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
|
||||||
|
mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||||
|
clamp, percentile)
|
||||||
|
|
||||||
|
# compute KL loss
|
||||||
|
kl = kl_divergence(mu1, log_var1, mu2, log_var2)
|
||||||
|
kl = kl.flatten(1).mean(dim=1) / math.log(2.0)
|
||||||
|
|
||||||
|
# compute discretized NLL loss (for p(x0 | x1) only)
|
||||||
|
nll = -discretized_gaussian_log_likelihood(
|
||||||
|
x0, mean=mu2, log_scale=0.5 * log_var2)
|
||||||
|
nll = nll.flatten(1).mean(dim=1) / math.log(2.0)
|
||||||
|
|
||||||
|
# NLL for p(x0 | x1) and KL otherwise
|
||||||
|
vlb = torch.where(t == 0, nll, kl)
|
||||||
|
return vlb, x0
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def variational_lower_bound_loop(self,
|
||||||
|
x0,
|
||||||
|
model,
|
||||||
|
model_kwargs={},
|
||||||
|
clamp=None,
|
||||||
|
percentile=None):
|
||||||
|
r"""Compute the entire variational lower bound, measured in bits-per-dim.
|
||||||
|
"""
|
||||||
|
# prepare input and output
|
||||||
|
b = x0.size(0)
|
||||||
|
metrics = {'vlb': [], 'mse': [], 'x0_mse': []}
|
||||||
|
|
||||||
|
# loop
|
||||||
|
for step in torch.arange(self.num_timesteps).flip(0):
|
||||||
|
# compute VLB
|
||||||
|
t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
|
||||||
|
noise = torch.randn_like(x0)
|
||||||
|
xt = self.q_sample(x0, t, noise)
|
||||||
|
vlb, pred_x0 = self.variational_lower_bound(
|
||||||
|
x0, xt, t, model, model_kwargs, clamp, percentile)
|
||||||
|
|
||||||
|
# predict eps from x0
|
||||||
|
u1 = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0)
|
||||||
|
u2 = _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||||
|
eps = u1 / u2
|
||||||
|
|
||||||
|
# collect metrics
|
||||||
|
metrics['vlb'].append(vlb)
|
||||||
|
metrics['x0_mse'].append(
|
||||||
|
(pred_x0 - x0).square().flatten(1).mean(dim=1))
|
||||||
|
metrics['mse'].append(
|
||||||
|
(eps - noise).square().flatten(1).mean(dim=1))
|
||||||
|
metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}
|
||||||
|
|
||||||
|
# compute the prior KL term for VLB, measured in bits-per-dim
|
||||||
|
mu, _, log_var = self.q_mean_variance(x0, t)
|
||||||
|
kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
|
||||||
|
torch.zeros_like(log_var))
|
||||||
|
kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)
|
||||||
|
|
||||||
|
# update metrics
|
||||||
|
metrics['prior_bits_per_dim'] = kl_prior
|
||||||
|
metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def _scale_timesteps(self, t):
|
||||||
|
if self.rescale_timesteps:
|
||||||
|
return t.float() * 1000.0 / self.num_timesteps
|
||||||
|
return t
|
||||||
265
modelscope/models/multi_modal/multi_stage_diffusion/model.py
Normal file
265
modelscope/models/multi_modal/multi_stage_diffusion/model.py
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os.path as osp
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.cuda.amp as amp
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from modelscope.metainfo import Models
|
||||||
|
from modelscope.models import TorchModel
|
||||||
|
from modelscope.models.builder import MODELS
|
||||||
|
from modelscope.models.multi_modal.multi_stage_diffusion.clip import CLIP
|
||||||
|
from modelscope.models.multi_modal.multi_stage_diffusion.decoder import Decoder
|
||||||
|
from modelscope.models.multi_modal.multi_stage_diffusion.gaussian_diffusion import (
|
||||||
|
GaussianDiffusion, beta_schedule)
|
||||||
|
from modelscope.models.multi_modal.multi_stage_diffusion.prior import Prior
|
||||||
|
from modelscope.models.multi_modal.multi_stage_diffusion.tokenizer import (
|
||||||
|
CLIPTokenizer, XGLMTokenizer)
|
||||||
|
from modelscope.models.multi_modal.multi_stage_diffusion.upsampler import (
|
||||||
|
Upsampler256, Upsampler1024)
|
||||||
|
from modelscope.models.multi_modal.multi_stage_diffusion.xglm import XGLM
|
||||||
|
from modelscope.utils.constant import ModelFile, Tasks
|
||||||
|
from modelscope.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
__all__ = ['MultiStageDiffusionForTextToImageSynthesis']
|
||||||
|
|
||||||
|
|
||||||
|
def make_diffusion(schedule,
|
||||||
|
num_timesteps=1000,
|
||||||
|
init_beta=None,
|
||||||
|
last_beta=None,
|
||||||
|
mean_type='eps',
|
||||||
|
var_type='fixed_small'):
|
||||||
|
betas = beta_schedule(schedule, num_timesteps, init_beta, last_beta)
|
||||||
|
diffusion = GaussianDiffusion(
|
||||||
|
betas, mean_type=mean_type, var_type=var_type)
|
||||||
|
return diffusion
|
||||||
|
|
||||||
|
|
||||||
|
class UnCLIP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, model_dir):
|
||||||
|
super(UnCLIP, self).__init__()
|
||||||
|
self.model_dir = model_dir
|
||||||
|
self.config = json.load(open(f'{model_dir}/{ModelFile.CONFIGURATION}'))
|
||||||
|
|
||||||
|
# modules
|
||||||
|
self.clip = CLIP(**self.config['clip']).fp16()
|
||||||
|
self.xglm = XGLM(**self.config['xglm'])
|
||||||
|
self.prior = Prior(**self.config['prior'])
|
||||||
|
self.decoder = Decoder(**self.config['decoder'])
|
||||||
|
self.upsampler256 = Upsampler256(**self.config['upsampler256'])
|
||||||
|
self.upsampler1024 = Upsampler1024(**self.config['upsampler1024'])
|
||||||
|
|
||||||
|
# diffusions
|
||||||
|
self.prior_diffusion = make_diffusion(**self.config['prior_diffusion'])
|
||||||
|
self.decoder_diffusion = make_diffusion(
|
||||||
|
**self.config['decoder_diffusion'])
|
||||||
|
self.upsampler256_diffusion = make_diffusion(
|
||||||
|
**self.config['upsampler256_diffusion'])
|
||||||
|
self.upsampler1024_diffusion = make_diffusion(
|
||||||
|
**self.config['upsampler1024_diffusion'])
|
||||||
|
|
||||||
|
# tokenizers
|
||||||
|
self.clip_tokenizer = CLIPTokenizer(
|
||||||
|
bpe_path=f'{model_dir}/bpe_simple_vocab_16e6.txt.gz')
|
||||||
|
self.xglm_tokenizer = XGLMTokenizer(model_dir=model_dir)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError(
|
||||||
|
'"forward" is not implemented. Use "synthesis" instead.')
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def synthesis(self,
|
||||||
|
text='A photo of a confused grizzly bear in calculus class.',
|
||||||
|
tokenizer='clip',
|
||||||
|
batch_size=4,
|
||||||
|
timesteps_prior=100,
|
||||||
|
timesteps_64=50,
|
||||||
|
timesteps_256=20,
|
||||||
|
timesteps_1024=20,
|
||||||
|
guide_prior=3.0,
|
||||||
|
guide_64=7.0,
|
||||||
|
guide_256=3.0,
|
||||||
|
guide_1024=3.0,
|
||||||
|
eta_prior=0.0,
|
||||||
|
eta_64=0.0,
|
||||||
|
eta_256=0.0,
|
||||||
|
eta_1024=0.0):
|
||||||
|
device = next(self.parameters()).device
|
||||||
|
|
||||||
|
# check params
|
||||||
|
assert all([
|
||||||
|
t > 0 and t <= 1000 for t in
|
||||||
|
[timesteps_prior, timesteps_64, timesteps_256, timesteps_1024]
|
||||||
|
])
|
||||||
|
assert all([
|
||||||
|
g > 1 and g < 15
|
||||||
|
for g in [guide_prior, guide_64, guide_256, guide_1024]
|
||||||
|
])
|
||||||
|
assert all([
|
||||||
|
e >= 0 and e <= 1.0
|
||||||
|
for e in [eta_prior, eta_64, eta_256, eta_1024]
|
||||||
|
])
|
||||||
|
assert batch_size >= 1 and batch_size <= 16
|
||||||
|
|
||||||
|
# tokenize the text
|
||||||
|
if tokenizer == 'clip':
|
||||||
|
y = F.normalize(
|
||||||
|
self.clip.textual(self.clip_tokenizer([text]).to(device)),
|
||||||
|
p=2,
|
||||||
|
dim=1)
|
||||||
|
zero_y = F.normalize(
|
||||||
|
self.clip.textual(self.clip_tokenizer(['']).to(device)),
|
||||||
|
p=2,
|
||||||
|
dim=1)
|
||||||
|
elif tokenizer == 'xglm':
|
||||||
|
y = F.normalize(
|
||||||
|
self.xglm(*to_device(self.xglm_tokenizer([text]), device)),
|
||||||
|
p=2,
|
||||||
|
dim=1)
|
||||||
|
zero_y = F.normalize(
|
||||||
|
self.xglm(*to_device(self.xglm_tokenizer(['']), device)),
|
||||||
|
p=2,
|
||||||
|
dim=1)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f'Expected tokenizer to be one of "clip" or "xglm", but got {tokenizer}'
|
||||||
|
)
|
||||||
|
y = math.sqrt(y.size(1)) * y.repeat(batch_size, 1)
|
||||||
|
zero_y = math.sqrt(zero_y.size(1)) * zero_y.repeat(batch_size, 1)
|
||||||
|
|
||||||
|
# synthesis
|
||||||
|
with amp.autocast(enabled=True):
|
||||||
|
# prior
|
||||||
|
x0 = self.prior_diffusion.ddim_sample_loop(
|
||||||
|
noise=torch.randn_like(y),
|
||||||
|
model=self.prior,
|
||||||
|
model_kwargs=[{
|
||||||
|
'y': y
|
||||||
|
}, {
|
||||||
|
'y': zero_y
|
||||||
|
}],
|
||||||
|
guide_scale=guide_prior,
|
||||||
|
ddim_timesteps=timesteps_prior,
|
||||||
|
eta=eta_prior)
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
imgs64 = self.decoder_diffusion.ddim_sample_loop(
|
||||||
|
noise=torch.randn(batch_size, 3, 64, 64).to(device),
|
||||||
|
model=self.decoder,
|
||||||
|
model_kwargs=[{
|
||||||
|
'y': x0
|
||||||
|
}, {
|
||||||
|
'y': torch.zeros_like(x0)
|
||||||
|
}],
|
||||||
|
guide_scale=guide_64,
|
||||||
|
percentile=0.995,
|
||||||
|
ddim_timesteps=timesteps_64,
|
||||||
|
eta=eta_64).clamp_(-1, 1)
|
||||||
|
|
||||||
|
# upsampler256
|
||||||
|
imgs256 = F.interpolate(
|
||||||
|
imgs64, scale_factor=4.0, mode='bilinear', align_corners=False)
|
||||||
|
imgs256 = self.upsampler256_diffusion.ddim_sample_loop(
|
||||||
|
noise=torch.randn_like(imgs256),
|
||||||
|
model=self.upsampler256,
|
||||||
|
model_kwargs=[{
|
||||||
|
'y': y,
|
||||||
|
'concat': imgs256
|
||||||
|
}, {
|
||||||
|
'y': zero_y,
|
||||||
|
'concat': imgs256
|
||||||
|
}],
|
||||||
|
guide_scale=guide_256,
|
||||||
|
percentile=0.995,
|
||||||
|
ddim_timesteps=timesteps_256,
|
||||||
|
eta=eta_256).clamp_(-1, 1)
|
||||||
|
|
||||||
|
# upsampler1024
|
||||||
|
imgs1024 = F.interpolate(
|
||||||
|
imgs256,
|
||||||
|
scale_factor=4.0,
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False)
|
||||||
|
imgs1024 = self.upsampler1024_diffusion.ddim_sample_loop(
|
||||||
|
noise=torch.randn_like(imgs1024),
|
||||||
|
model=self.upsampler1024,
|
||||||
|
model_kwargs=[{
|
||||||
|
'y': y,
|
||||||
|
'concat': imgs1024
|
||||||
|
}, {
|
||||||
|
'y': zero_y,
|
||||||
|
'concat': imgs1024
|
||||||
|
}],
|
||||||
|
guide_scale=guide_1024,
|
||||||
|
percentile=0.995,
|
||||||
|
ddim_timesteps=timesteps_1024,
|
||||||
|
eta=eta_1024).clamp_(-1, 1)
|
||||||
|
|
||||||
|
# output ([B, C, H, W] within range [0, 1])
|
||||||
|
imgs1024 = imgs1024.add_(1).mul_(255 / 2.0).permute(0, 2, 3, 1).cpu()
|
||||||
|
imgs1024 = [
|
||||||
|
Image.fromarray(np.array(u, dtype=np.uint8)) for u in imgs1024
|
||||||
|
]
|
||||||
|
return imgs1024
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module(
|
||||||
|
Tasks.text_to_image_synthesis, module_name=Models.multi_stage_diffusion)
|
||||||
|
class MultiStageDiffusionForTextToImageSynthesis(TorchModel):
|
||||||
|
|
||||||
|
def __init__(self, model_dir, device_id=-1):
|
||||||
|
super().__init__(model_dir=model_dir, device_id=device_id)
|
||||||
|
model = UnCLIP(model_dir=model_dir)
|
||||||
|
pretrained_params = torch.load(
|
||||||
|
osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu')
|
||||||
|
model.load_state_dict(pretrained_params)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
self.device_id = device_id
|
||||||
|
if self.device_id >= 0:
|
||||||
|
self.device = torch.device(f'cuda:{self.device_id}')
|
||||||
|
model.to('cuda:{}'.format(self.device_id))
|
||||||
|
logger.info('Use GPU: {}'.format(self.device_id))
|
||||||
|
else:
|
||||||
|
self.device = torch.device('cpu')
|
||||||
|
logger.info('Use CPU for inference')
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
if not isinstance(input, dict):
|
||||||
|
raise ValueError(
|
||||||
|
f'Expected the input to be a dictionary, but got {type(input)}'
|
||||||
|
)
|
||||||
|
if 'text' not in input:
|
||||||
|
raise ValueError('input should contain "text", but not found')
|
||||||
|
|
||||||
|
# ddim sampling
|
||||||
|
imgs = self.model.synthesis(
|
||||||
|
text=input.get('text'),
|
||||||
|
tokenizer=input.get('tokenizer', 'clip'),
|
||||||
|
batch_size=input.get('batch_size', 4),
|
||||||
|
timesteps_prior=input.get('timesteps_prior', 100),
|
||||||
|
timesteps_64=input.get('timesteps_64', 50),
|
||||||
|
timesteps_256=input.get('timesteps_256', 20),
|
||||||
|
timesteps_1024=input.get('timesteps_1024', 20),
|
||||||
|
guide_prior=input.get('guide_prior', 3.0),
|
||||||
|
guide_64=input.get('guide_64', 7.0),
|
||||||
|
guide_256=input.get('guide_256', 3.0),
|
||||||
|
guide_1024=input.get('guide_1024', 3.0),
|
||||||
|
eta_prior=input.get('eta_prior', 0.0),
|
||||||
|
eta_64=input.get('eta_64', 0.0),
|
||||||
|
eta_256=input.get('eta_256', 0.0),
|
||||||
|
eta_1024=input.get('eta_1024', 0.0))
|
||||||
|
imgs = [np.array(u)[..., ::-1] for u in imgs]
|
||||||
|
return imgs
|
||||||
170
modelscope/models/multi_modal/multi_stage_diffusion/prior.py
Normal file
170
modelscope/models/multi_modal/multi_stage_diffusion/prior.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
__all__ = ['Prior']
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoidal_embedding(timesteps, dim):
|
||||||
|
# check input
|
||||||
|
half = dim // 2
|
||||||
|
timesteps = timesteps.float()
|
||||||
|
|
||||||
|
# compute sinusoidal embedding
|
||||||
|
sinusoid = torch.outer(
|
||||||
|
timesteps, torch.pow(10000,
|
||||||
|
-torch.arange(half).to(timesteps).div(half)))
|
||||||
|
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||||
|
if dim % 2 != 0:
|
||||||
|
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads):
|
||||||
|
assert dim % num_heads == 0
|
||||||
|
super(SelfAttention, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = math.pow(self.head_dim, -0.25)
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.to_qkv = nn.Linear(dim, dim * 3)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x, mask):
|
||||||
|
b, l, n, c = *x.shape[:2], self.num_heads, self.head_dim
|
||||||
|
|
||||||
|
# compute query, key, value
|
||||||
|
q, k, v = self.to_qkv(x).view(b, l, n * 3, c).chunk(3, dim=2)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale)
|
||||||
|
if mask is not None:
|
||||||
|
attn = attn.masked_fill(mask[:, :, :l, :l] == 0, float('-inf'))
|
||||||
|
attn = F.softmax(attn.float(), dim=-1).type(attn.dtype)
|
||||||
|
|
||||||
|
# gather context
|
||||||
|
x = torch.einsum('bnij,bjnc->binc', attn, v)
|
||||||
|
x = x.reshape(b, l, -1)
|
||||||
|
|
||||||
|
# output
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads):
|
||||||
|
super(AttentionBlock, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
|
self.attn = SelfAttention(dim, num_heads)
|
||||||
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
self.ffn = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
x = x + self.attn(self.norm1(x), mask)
|
||||||
|
x = x + self.ffn(self.norm2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Prior(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim=2048, clip_dim=768, num_heads=32, num_layers=24):
|
||||||
|
super(Prior, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.clip_dim = clip_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
self.text_embedding = nn.Sequential(
|
||||||
|
nn.Linear(clip_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||||
|
self.time_embedding = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||||
|
self.vision_embedding = nn.Sequential(
|
||||||
|
nn.Linear(clip_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||||
|
self.eos_embedding = nn.Parameter(torch.zeros(1, 1, dim))
|
||||||
|
self.pos_embedding = nn.Parameter(torch.zeros(1, 4, dim))
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[AttentionBlock(dim, num_heads) for _ in range(num_layers)])
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
|
||||||
|
# head
|
||||||
|
self.head = nn.Linear(dim, clip_dim)
|
||||||
|
|
||||||
|
# causal attention mask
|
||||||
|
self.register_buffer('attn_mask', torch.tril(torch.ones(1, 1, 4, 4)))
|
||||||
|
|
||||||
|
# initialize weights
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def forward(self, x, t, y):
|
||||||
|
r"""x: [B, C].
|
||||||
|
t: [B].
|
||||||
|
y: [B, C].
|
||||||
|
"""
|
||||||
|
b = x.size(0)
|
||||||
|
|
||||||
|
# embeddings of shape [B, L + 4, C]
|
||||||
|
u1 = sinusoidal_embedding(t, self.dim)
|
||||||
|
u2 = [
|
||||||
|
self.text_embedding(y).unsqueeze(1),
|
||||||
|
self.time_embedding(u1).unsqueeze(1),
|
||||||
|
self.vision_embedding(x).unsqueeze(1),
|
||||||
|
self.eos_embedding.repeat(b, 1, 1)
|
||||||
|
]
|
||||||
|
x = self.pos_embedding + torch.cat(u2, dim=1)
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x, self.attn_mask)
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
# head
|
||||||
|
x = self.head(x[:, -1])
|
||||||
|
return x
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
std = 0.02 / math.sqrt(2.0 * self.num_layers)
|
||||||
|
for name, m in self.named_modules():
|
||||||
|
if name.endswith('attn.proj') or name.endswith('ffn.2'):
|
||||||
|
# smaller std for output layers
|
||||||
|
nn.init.normal_(m.weight, std=std)
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
elif isinstance(m, (nn.Linear, nn.Embedding)):
|
||||||
|
nn.init.normal_(m.weight, std=0.02)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
nn.init.ones_(m.weight)
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
def param_groups(self):
|
||||||
|
groups = [{
|
||||||
|
'params': [
|
||||||
|
p for n, p in self.named_parameters()
|
||||||
|
if 'norm' in n or n.endswith('bias')
|
||||||
|
],
|
||||||
|
'weight_decay':
|
||||||
|
0.0
|
||||||
|
}, {
|
||||||
|
'params': [
|
||||||
|
p for n, p in self.named_parameters()
|
||||||
|
if not ('norm' in n or n.endswith('bias'))
|
||||||
|
]
|
||||||
|
}]
|
||||||
|
return groups
|
||||||
199
modelscope/models/multi_modal/multi_stage_diffusion/tokenizer.py
Normal file
199
modelscope/models/multi_modal/multi_stage_diffusion/tokenizer.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
# The implementation here is modified based on OpenAI CLIP, publicly available at https://github.com/openai/CLIP.
|
||||||
|
|
||||||
|
import gzip
|
||||||
|
import html
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import ftfy
|
||||||
|
import regex as re
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
__all__ = ['CLIPTokenizer', 'XGLMTokenizer']
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def bytes_to_unicode():
|
||||||
|
"""
|
||||||
|
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||||
|
The reversible bpe codes work on unicode strings.
|
||||||
|
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||||
|
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||||
|
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||||
|
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||||
|
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||||
|
"""
|
||||||
|
bs = list(range(ord('!'),
|
||||||
|
ord('~') + 1)) + list(range(
|
||||||
|
ord('¡'),
|
||||||
|
ord('¬') + 1)) + list(range(ord('®'),
|
||||||
|
ord('ÿ') + 1))
|
||||||
|
cs = bs[:]
|
||||||
|
n = 0
|
||||||
|
for b in range(2**8):
|
||||||
|
if b not in bs:
|
||||||
|
bs.append(b)
|
||||||
|
cs.append(2**8 + n)
|
||||||
|
n += 1
|
||||||
|
cs = [chr(n) for n in cs]
|
||||||
|
return dict(zip(bs, cs))
|
||||||
|
|
||||||
|
|
||||||
|
def get_pairs(word):
|
||||||
|
"""Return set of symbol pairs in a word.
|
||||||
|
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||||
|
"""
|
||||||
|
pairs = set()
|
||||||
|
prev_char = word[0]
|
||||||
|
for char in word[1:]:
|
||||||
|
pairs.add((prev_char, char))
|
||||||
|
prev_char = char
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
|
def basic_clean(text):
|
||||||
|
text = ftfy.fix_text(text)
|
||||||
|
text = html.unescape(html.unescape(text))
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def whitespace_clean(text):
|
||||||
|
text = re.sub(r'\s+', ' ', text)
|
||||||
|
text = text.strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleTokenizer(object):
|
||||||
|
|
||||||
|
def __init__(self, bpe_path):
|
||||||
|
self.byte_encoder = bytes_to_unicode()
|
||||||
|
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||||
|
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
|
||||||
|
merges = merges[1:49152 - 256 - 2 + 1]
|
||||||
|
merges = [tuple(merge.split()) for merge in merges]
|
||||||
|
vocab = list(bytes_to_unicode().values())
|
||||||
|
vocab = vocab + [v + '</w>' for v in vocab]
|
||||||
|
for merge in merges:
|
||||||
|
vocab.append(''.join(merge))
|
||||||
|
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||||
|
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||||
|
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||||
|
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||||
|
self.cache = {
|
||||||
|
'<|startoftext|>': '<|startoftext|>',
|
||||||
|
'<|endoftext|>': '<|endoftext|>'
|
||||||
|
}
|
||||||
|
self.pat = re.compile(
|
||||||
|
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||||
|
re.IGNORECASE)
|
||||||
|
|
||||||
|
def bpe(self, token):
|
||||||
|
if token in self.cache:
|
||||||
|
return self.cache[token]
|
||||||
|
word = tuple(token[:-1]) + (token[-1] + '</w>', )
|
||||||
|
pairs = get_pairs(word)
|
||||||
|
|
||||||
|
if not pairs:
|
||||||
|
return token + '</w>'
|
||||||
|
|
||||||
|
while True:
|
||||||
|
bigram = min(
|
||||||
|
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||||
|
if bigram not in self.bpe_ranks:
|
||||||
|
break
|
||||||
|
first, second = bigram
|
||||||
|
new_word = []
|
||||||
|
i = 0
|
||||||
|
while i < len(word):
|
||||||
|
try:
|
||||||
|
j = word.index(first, i)
|
||||||
|
new_word.extend(word[i:j])
|
||||||
|
i = j
|
||||||
|
except Exception:
|
||||||
|
new_word.extend(word[i:])
|
||||||
|
break
|
||||||
|
|
||||||
|
if word[i] == first and i < len(word) - 1 and word[
|
||||||
|
i + 1] == second:
|
||||||
|
new_word.append(first + second)
|
||||||
|
i += 2
|
||||||
|
else:
|
||||||
|
new_word.append(word[i])
|
||||||
|
i += 1
|
||||||
|
new_word = tuple(new_word)
|
||||||
|
word = new_word
|
||||||
|
if len(word) == 1:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pairs = get_pairs(word)
|
||||||
|
word = ' '.join(word)
|
||||||
|
self.cache[token] = word
|
||||||
|
return word
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
bpe_tokens = []
|
||||||
|
text = whitespace_clean(basic_clean(text)).lower()
|
||||||
|
for token in re.findall(self.pat, text):
|
||||||
|
token = ''.join(self.byte_encoder[b]
|
||||||
|
for b in token.encode('utf-8'))
|
||||||
|
bpe_tokens.extend(self.encoder[bpe_token]
|
||||||
|
for bpe_token in self.bpe(token).split(' '))
|
||||||
|
return bpe_tokens
|
||||||
|
|
||||||
|
def decode(self, tokens):
|
||||||
|
text = ''.join([self.decoder[token] for token in tokens])
|
||||||
|
text = bytearray([self.byte_decoder[c] for c in text]).decode(
|
||||||
|
'utf-8', errors='replace').replace('</w>', ' ')
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTokenizer(object):
|
||||||
|
r"""CLIP tokenizer, adapted from https://github.com/openai/CLIP.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, bpe_path, length=77):
|
||||||
|
self.bpe_path = bpe_path
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
# init tokenizer
|
||||||
|
self.tokenizer = SimpleTokenizer(bpe_path=bpe_path)
|
||||||
|
self.sos_token = self.tokenizer.encoder['<|startoftext|>']
|
||||||
|
self.eos_token = self.tokenizer.encoder['<|endoftext|>']
|
||||||
|
self.vocab_size = len(self.tokenizer.encoder)
|
||||||
|
|
||||||
|
def __call__(self, sequence):
|
||||||
|
if isinstance(sequence, str):
|
||||||
|
return torch.LongTensor(self._tokenizer(sequence))
|
||||||
|
elif isinstance(sequence, list):
|
||||||
|
return torch.LongTensor([self._tokenizer(u) for u in sequence])
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f'Expected the "sequence" to be a string or a list, but got {type(sequence)}'
|
||||||
|
)
|
||||||
|
|
||||||
|
def _tokenizer(self, text):
|
||||||
|
tokens = self.tokenizer.encode(text)[:self.length - 2]
|
||||||
|
tokens = [self.sos_token] + tokens + [self.eos_token]
|
||||||
|
tokens = tokens + [0] * (self.length - len(tokens))
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
class XGLMTokenizer(object):
|
||||||
|
r"""A wrapper of HuggingFace's XGLM tokenizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_dir, length=77, **kwargs):
|
||||||
|
self.length = length
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, **kwargs)
|
||||||
|
self.vocab_size = self.tokenizer.vocab_size
|
||||||
|
|
||||||
|
def __call__(self, sequence, **kwargs):
|
||||||
|
_kwargs = {
|
||||||
|
'return_tensors': 'pt',
|
||||||
|
'padding': 'max_length',
|
||||||
|
'truncation': True,
|
||||||
|
'max_length': self.length
|
||||||
|
}
|
||||||
|
_kwargs.update(**kwargs)
|
||||||
|
tokens = self.tokenizer(sequence, **_kwargs)
|
||||||
|
return tokens.input_ids, tokens.attention_mask
|
||||||
466
modelscope/models/multi_modal/multi_stage_diffusion/upsampler.py
Normal file
466
modelscope/models/multi_modal/multi_stage_diffusion/upsampler.py
Normal file
@@ -0,0 +1,466 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
__all__ = ['Upsampler256', 'Upsampler1024']
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoidal_embedding(timesteps, dim):
|
||||||
|
# check input
|
||||||
|
half = dim // 2
|
||||||
|
timesteps = timesteps.float()
|
||||||
|
|
||||||
|
# compute sinusoidal embedding
|
||||||
|
sinusoid = torch.outer(
|
||||||
|
timesteps, torch.pow(10000,
|
||||||
|
-torch.arange(half).to(timesteps).div(half)))
|
||||||
|
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||||
|
if dim % 2 != 0:
|
||||||
|
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Resample(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_dim, out_dim, scale_factor, use_conv=False):
|
||||||
|
assert scale_factor in [0.5, 1.0, 2.0]
|
||||||
|
super(Resample, self).__init__()
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
self.use_conv = use_conv
|
||||||
|
|
||||||
|
# layers
|
||||||
|
if scale_factor == 2.0:
|
||||||
|
self.resample = nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor=scale_factor, mode='nearest'),
|
||||||
|
nn.Conv2d(in_dim, out_dim, 3, padding=1)
|
||||||
|
if use_conv else nn.Identity())
|
||||||
|
elif scale_factor == 0.5:
|
||||||
|
self.resample = nn.Conv2d(
|
||||||
|
in_dim, out_dim, 3, stride=2,
|
||||||
|
padding=1) if use_conv else nn.AvgPool2d(
|
||||||
|
kernel_size=2, stride=2)
|
||||||
|
else:
|
||||||
|
self.resample = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.resample(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_dim,
|
||||||
|
embed_dim,
|
||||||
|
out_dim,
|
||||||
|
use_scale_shift_norm=True,
|
||||||
|
scale_factor=1.0,
|
||||||
|
dropout=0.0):
|
||||||
|
super(ResidualBlock, self).__init__()
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.use_scale_shift_norm = use_scale_shift_norm
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.layer1 = nn.Sequential(
|
||||||
|
nn.GroupNorm(32, in_dim), nn.SiLU(),
|
||||||
|
nn.Conv2d(in_dim, out_dim, 3, padding=1))
|
||||||
|
self.resample = Resample(in_dim, in_dim, scale_factor, use_conv=False)
|
||||||
|
self.embedding = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim,
|
||||||
|
out_dim * 2 if use_scale_shift_norm else out_dim))
|
||||||
|
self.layer2 = nn.Sequential(
|
||||||
|
nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
|
||||||
|
nn.Conv2d(out_dim, out_dim, 3, padding=1))
|
||||||
|
self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
|
||||||
|
in_dim, out_dim, 1)
|
||||||
|
|
||||||
|
# zero out the last layer params
|
||||||
|
nn.init.zeros_(self.layer2[-1].weight)
|
||||||
|
|
||||||
|
def forward(self, x, e):
|
||||||
|
identity = self.resample(x)
|
||||||
|
x = self.layer1[-1](self.resample(self.layer1[:-1](x)))
|
||||||
|
e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
|
||||||
|
if self.use_scale_shift_norm:
|
||||||
|
scale, shift = e.chunk(2, dim=1)
|
||||||
|
x = self.layer2[0](x) * (1 + scale) + shift
|
||||||
|
x = self.layer2[1:](x)
|
||||||
|
else:
|
||||||
|
x = x + e
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = x + self.shortcut(identity)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
|
||||||
|
# consider head_dim first, then num_heads
|
||||||
|
num_heads = dim // head_dim if head_dim else num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
assert num_heads * head_dim == dim
|
||||||
|
super(AttentionBlock, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.context_dim = context_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.scale = math.pow(head_dim, -0.25)
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.norm = nn.GroupNorm(32, dim)
|
||||||
|
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||||
|
if context_dim is not None:
|
||||||
|
self.context_kv = nn.Linear(context_dim, dim * 2)
|
||||||
|
self.proj = nn.Conv2d(dim, dim, 1)
|
||||||
|
|
||||||
|
# zero out the last layer params
|
||||||
|
nn.init.zeros_(self.proj.weight)
|
||||||
|
|
||||||
|
def forward(self, x, context=None):
|
||||||
|
r"""x: [B, C, H, W].
|
||||||
|
context: [B, L, C] or None.
|
||||||
|
"""
|
||||||
|
identity = x
|
||||||
|
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
|
||||||
|
|
||||||
|
# compute query, key, value
|
||||||
|
x = self.norm(x)
|
||||||
|
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
|
||||||
|
if context is not None:
|
||||||
|
ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
|
||||||
|
d).permute(0, 2, 3,
|
||||||
|
1).chunk(
|
||||||
|
2, dim=1)
|
||||||
|
k = torch.cat([ck, k], dim=-1)
|
||||||
|
v = torch.cat([cv, v], dim=-1)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
|
||||||
|
attn = F.softmax(attn, dim=-1)
|
||||||
|
|
||||||
|
# gather context
|
||||||
|
x = torch.matmul(v, attn.transpose(-1, -2))
|
||||||
|
x = x.reshape(b, c, h, w)
|
||||||
|
|
||||||
|
# output
|
||||||
|
x = self.proj(x)
|
||||||
|
return x + identity
|
||||||
|
|
||||||
|
|
||||||
|
class Upsampler256(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_dim=6,
|
||||||
|
dim=320,
|
||||||
|
y_dim=768,
|
||||||
|
context_dim=512,
|
||||||
|
out_dim=3,
|
||||||
|
dim_mult=[1, 2, 3, 4],
|
||||||
|
num_heads=None,
|
||||||
|
head_dim=64,
|
||||||
|
num_res_blocks=3,
|
||||||
|
attn_scales=[1 / 8],
|
||||||
|
resblock_resample=True,
|
||||||
|
use_scale_shift_norm=True,
|
||||||
|
dropout=0.1):
|
||||||
|
embed_dim = dim * 4
|
||||||
|
super(Upsampler256, self).__init__()
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.dim = dim
|
||||||
|
self.y_dim = y_dim
|
||||||
|
self.context_dim = context_dim
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.dim_mult = dim_mult
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.attn_scales = attn_scales
|
||||||
|
self.resblock_resample = resblock_resample
|
||||||
|
self.use_scale_shift_norm = use_scale_shift_norm
|
||||||
|
|
||||||
|
# params
|
||||||
|
enc_dims = [dim * u for u in [1] + dim_mult]
|
||||||
|
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||||
|
shortcut_dims = []
|
||||||
|
scale = 1.0
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
self.time_embedding = nn.Sequential(
|
||||||
|
nn.Linear(dim, embed_dim), nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim, embed_dim))
|
||||||
|
self.y_embedding = nn.Sequential(
|
||||||
|
nn.Linear(y_dim, embed_dim), nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim, embed_dim))
|
||||||
|
self.context_embedding = nn.Sequential(
|
||||||
|
nn.Linear(y_dim, embed_dim), nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim, context_dim * 4))
|
||||||
|
|
||||||
|
# encoder
|
||||||
|
self.encoder = nn.ModuleList(
|
||||||
|
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
|
||||||
|
shortcut_dims.append(dim)
|
||||||
|
for i, (in_dim,
|
||||||
|
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
||||||
|
for j in range(num_res_blocks):
|
||||||
|
# residual (+attention) blocks
|
||||||
|
block = nn.ModuleList([
|
||||||
|
ResidualBlock(in_dim, embed_dim, out_dim,
|
||||||
|
use_scale_shift_norm, 1.0, dropout)
|
||||||
|
])
|
||||||
|
if scale in attn_scales:
|
||||||
|
block.append(
|
||||||
|
AttentionBlock(out_dim, context_dim, num_heads,
|
||||||
|
head_dim))
|
||||||
|
in_dim = out_dim
|
||||||
|
self.encoder.append(block)
|
||||||
|
shortcut_dims.append(out_dim)
|
||||||
|
|
||||||
|
# downsample
|
||||||
|
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
|
||||||
|
if resblock_resample:
|
||||||
|
downsample = ResidualBlock(out_dim, embed_dim, out_dim,
|
||||||
|
use_scale_shift_norm, 0.5,
|
||||||
|
dropout)
|
||||||
|
else:
|
||||||
|
downsample = Resample(
|
||||||
|
out_dim, out_dim, 0.5, use_conv=True)
|
||||||
|
shortcut_dims.append(out_dim)
|
||||||
|
scale /= 2.0
|
||||||
|
self.encoder.append(downsample)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.middle = nn.ModuleList([
|
||||||
|
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
|
||||||
|
1.0, dropout),
|
||||||
|
AttentionBlock(out_dim, context_dim, num_heads, head_dim),
|
||||||
|
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
|
||||||
|
1.0, dropout)
|
||||||
|
])
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
self.decoder = nn.ModuleList()
|
||||||
|
for i, (in_dim,
|
||||||
|
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
|
||||||
|
for j in range(num_res_blocks + 1):
|
||||||
|
# residual (+attention) blocks
|
||||||
|
block = nn.ModuleList([
|
||||||
|
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
|
||||||
|
out_dim, use_scale_shift_norm, 1.0, dropout)
|
||||||
|
])
|
||||||
|
if scale in attn_scales:
|
||||||
|
block.append(
|
||||||
|
AttentionBlock(out_dim, context_dim, num_heads,
|
||||||
|
head_dim))
|
||||||
|
in_dim = out_dim
|
||||||
|
|
||||||
|
# upsample
|
||||||
|
if i != len(dim_mult) - 1 and j == num_res_blocks:
|
||||||
|
if resblock_resample:
|
||||||
|
upsample = ResidualBlock(out_dim, embed_dim, out_dim,
|
||||||
|
use_scale_shift_norm, 2.0,
|
||||||
|
dropout)
|
||||||
|
else:
|
||||||
|
upsample = Resample(
|
||||||
|
out_dim, out_dim, 2.0, use_conv=True)
|
||||||
|
scale *= 2.0
|
||||||
|
block.append(upsample)
|
||||||
|
self.decoder.append(block)
|
||||||
|
|
||||||
|
# head
|
||||||
|
self.head = nn.Sequential(
|
||||||
|
nn.GroupNorm(32, out_dim), nn.SiLU(),
|
||||||
|
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
|
||||||
|
|
||||||
|
# zero out the last layer params
|
||||||
|
nn.init.zeros_(self.head[-1].weight)
|
||||||
|
|
||||||
|
def forward(self, x, t, y, concat):
|
||||||
|
# embeddings
|
||||||
|
x = torch.cat([x, concat], dim=1)
|
||||||
|
e = self.time_embedding(sinusoidal_embedding(
|
||||||
|
t, self.dim)) + self.y_embedding(y)
|
||||||
|
context = self.context_embedding(y).view(-1, 4, self.context_dim)
|
||||||
|
|
||||||
|
# encoder
|
||||||
|
xs = []
|
||||||
|
for block in self.encoder:
|
||||||
|
x = self._forward_single(block, x, e, context)
|
||||||
|
xs.append(x)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
for block in self.middle:
|
||||||
|
x = self._forward_single(block, x, e, context)
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
for block in self.decoder:
|
||||||
|
x = torch.cat([x, xs.pop()], dim=1)
|
||||||
|
x = self._forward_single(block, x, e, context)
|
||||||
|
|
||||||
|
# head
|
||||||
|
x = self.head(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _forward_single(self, module, x, e, context):
|
||||||
|
if isinstance(module, ResidualBlock):
|
||||||
|
x = module(x, e)
|
||||||
|
elif isinstance(module, AttentionBlock):
|
||||||
|
x = module(x, context)
|
||||||
|
elif isinstance(module, nn.ModuleList):
|
||||||
|
for block in module:
|
||||||
|
x = self._forward_single(block, x, e, context)
|
||||||
|
else:
|
||||||
|
x = module(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Upsampler1024(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_dim=6,
|
||||||
|
dim=192,
|
||||||
|
y_dim=768,
|
||||||
|
out_dim=3,
|
||||||
|
dim_mult=[1, 1, 2, 2, 4, 4],
|
||||||
|
num_res_blocks=2,
|
||||||
|
resblock_resample=True,
|
||||||
|
use_scale_shift_norm=True,
|
||||||
|
dropout=0.0):
|
||||||
|
embed_dim = dim * 4
|
||||||
|
super(Upsampler1024, self).__init__()
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.dim = dim
|
||||||
|
self.y_dim = y_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.dim_mult = dim_mult
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resblock_resample = resblock_resample
|
||||||
|
self.use_scale_shift_norm = use_scale_shift_norm
|
||||||
|
|
||||||
|
# params
|
||||||
|
enc_dims = [dim * u for u in [1] + dim_mult]
|
||||||
|
dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||||
|
shortcut_dims = []
|
||||||
|
scale = 1.0
|
||||||
|
|
||||||
|
# embedding
|
||||||
|
self.time_embedding = nn.Sequential(
|
||||||
|
nn.Linear(dim, embed_dim), nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim, embed_dim))
|
||||||
|
self.y_embedding = nn.Sequential(
|
||||||
|
nn.Linear(y_dim, embed_dim), nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim, embed_dim))
|
||||||
|
|
||||||
|
# encoder
|
||||||
|
self.encoder = nn.ModuleList(
|
||||||
|
[nn.Conv2d(self.in_dim, dim, 3, padding=1)])
|
||||||
|
shortcut_dims.append(dim)
|
||||||
|
for i, (in_dim,
|
||||||
|
out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
|
||||||
|
for j in range(num_res_blocks):
|
||||||
|
# residual block
|
||||||
|
block = nn.ModuleList([
|
||||||
|
ResidualBlock(in_dim, embed_dim, out_dim,
|
||||||
|
use_scale_shift_norm, 1.0, dropout)
|
||||||
|
])
|
||||||
|
shortcut_dims.append(out_dim)
|
||||||
|
in_dim = out_dim
|
||||||
|
self.encoder.append(block)
|
||||||
|
|
||||||
|
# downsample
|
||||||
|
if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
|
||||||
|
if resblock_resample:
|
||||||
|
downsample = ResidualBlock(out_dim, embed_dim, out_dim,
|
||||||
|
use_scale_shift_norm, 0.5,
|
||||||
|
dropout)
|
||||||
|
else:
|
||||||
|
downsample = Resample(
|
||||||
|
out_dim, out_dim, 0.5, use_conv=True)
|
||||||
|
shortcut_dims.append(out_dim)
|
||||||
|
scale /= 2.0
|
||||||
|
self.encoder.append(downsample)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.middle = nn.ModuleList([
|
||||||
|
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
|
||||||
|
1.0, dropout),
|
||||||
|
ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm,
|
||||||
|
1.0, dropout)
|
||||||
|
])
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
self.decoder = nn.ModuleList()
|
||||||
|
for i, (in_dim,
|
||||||
|
out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
|
||||||
|
for j in range(num_res_blocks + 1):
|
||||||
|
# residual block
|
||||||
|
block = nn.ModuleList([
|
||||||
|
ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim,
|
||||||
|
out_dim, use_scale_shift_norm, 1.0, dropout)
|
||||||
|
])
|
||||||
|
in_dim = out_dim
|
||||||
|
|
||||||
|
# upsample
|
||||||
|
if i != len(dim_mult) - 1 and j == num_res_blocks:
|
||||||
|
if resblock_resample:
|
||||||
|
upsample = ResidualBlock(out_dim, embed_dim, out_dim,
|
||||||
|
use_scale_shift_norm, 2.0,
|
||||||
|
dropout)
|
||||||
|
else:
|
||||||
|
upsample = Resample(
|
||||||
|
out_dim, out_dim, 2.0, use_conv=True)
|
||||||
|
scale *= 2.0
|
||||||
|
block.append(upsample)
|
||||||
|
self.decoder.append(block)
|
||||||
|
|
||||||
|
# head
|
||||||
|
self.head = nn.Sequential(
|
||||||
|
nn.GroupNorm(32, out_dim), nn.SiLU(),
|
||||||
|
nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
|
||||||
|
|
||||||
|
# zero out the last layer params
|
||||||
|
nn.init.zeros_(self.head[-1].weight)
|
||||||
|
|
||||||
|
def forward(self, x, t, y, concat):
|
||||||
|
# embedding
|
||||||
|
x = torch.cat([x, concat], dim=1)
|
||||||
|
e = self.time_embedding(sinusoidal_embedding(
|
||||||
|
t, self.dim)) + self.y_embedding(y)
|
||||||
|
|
||||||
|
# encoder
|
||||||
|
xs = []
|
||||||
|
for block in self.encoder:
|
||||||
|
x = self._forward_single(block, x, e)
|
||||||
|
xs.append(x)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
for block in self.middle:
|
||||||
|
x = self._forward_single(block, x, e)
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
for block in self.decoder:
|
||||||
|
x = torch.cat([x, xs.pop()], dim=1)
|
||||||
|
x = self._forward_single(block, x, e)
|
||||||
|
|
||||||
|
# head
|
||||||
|
x = self.head(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _forward_single(self, module, x, e):
|
||||||
|
if isinstance(module, ResidualBlock):
|
||||||
|
x = module(x, e)
|
||||||
|
elif isinstance(module, nn.ModuleList):
|
||||||
|
for block in module:
|
||||||
|
x = self._forward_single(block, x, e)
|
||||||
|
else:
|
||||||
|
x = module(x)
|
||||||
|
return x
|
||||||
205
modelscope/models/multi_modal/multi_stage_diffusion/xglm.py
Normal file
205
modelscope/models/multi_modal/multi_stage_diffusion/xglm.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
# The implementation here is modified based on HuggingFace XGLM, publicly available
|
||||||
|
# at https://github.com/huggingface/transformers.
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
__all__ = ['XGLM']
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoidal_embedding(seq_len, dim, pad_token=None):
|
||||||
|
half = dim // 2
|
||||||
|
sinusoid = torch.outer(
|
||||||
|
torch.arange(seq_len, dtype=torch.float32),
|
||||||
|
torch.pow(10000,
|
||||||
|
-torch.arange(half, dtype=torch.float32).div(half - 1)))
|
||||||
|
x = torch.cat([torch.sin(sinusoid), torch.cos(sinusoid)], dim=1)
|
||||||
|
if dim % 2 == 1:
|
||||||
|
x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
|
||||||
|
if pad_token is not None:
|
||||||
|
x[pad_token, :] = 0
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SinusoidalEmbedding(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, seq_len, dim, pad_token):
|
||||||
|
super(SinusoidalEmbedding, self).__init__()
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.dim = dim
|
||||||
|
self.pad_token = pad_token
|
||||||
|
self.register_buffer('weight',
|
||||||
|
sinusoidal_embedding(seq_len + 2, dim, pad_token))
|
||||||
|
|
||||||
|
def forward(self, tokens):
|
||||||
|
mask = tokens.ne(self.pad_token).long()
|
||||||
|
indices = torch.cumsum(mask, dim=1) * mask + self.pad_token
|
||||||
|
pos_embeds = self.weight.index_select(0, indices.view(-1)).view(
|
||||||
|
*tokens.shape, -1)
|
||||||
|
return pos_embeds
|
||||||
|
|
||||||
|
|
||||||
|
class GELU(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads, dropout=0.1):
|
||||||
|
assert dim % num_heads == 0
|
||||||
|
super(SelfAttention, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = 1.0 / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.q = nn.Linear(dim, dim)
|
||||||
|
self.k = nn.Linear(dim, dim)
|
||||||
|
self.v = nn.Linear(dim, dim)
|
||||||
|
self.o = nn.Linear(dim, dim)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
r"""x: [B, L, C].
|
||||||
|
mask: [B, *, L, L] or None.
|
||||||
|
"""
|
||||||
|
b, l, n, c = *x.shape[:2], self.num_heads, self.head_dim
|
||||||
|
|
||||||
|
# compute query, key, value
|
||||||
|
q = self.q(x).view(b, l, n, c)
|
||||||
|
k = self.k(x).view(b, l, n, c)
|
||||||
|
v = self.v(x).view(b, l, n, c)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
attn = self.scale * torch.einsum('binc,bjnc->bnij', q, k)
|
||||||
|
if mask is not None:
|
||||||
|
attn = attn.masked_fill(mask == 0, float('-inf'))
|
||||||
|
attn = F.softmax(attn, dim=-1)
|
||||||
|
attn = self.dropout(attn)
|
||||||
|
|
||||||
|
# gather context
|
||||||
|
x = torch.einsum('bnij,bjnc->binc', attn, v)
|
||||||
|
x = x.reshape(b, l, -1)
|
||||||
|
|
||||||
|
# output
|
||||||
|
x = self.o(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, ffn_dim, ffn_act, num_heads, dropout=0.1):
|
||||||
|
assert ffn_act in ['gelu', 'relu']
|
||||||
|
super(AttentionBlock, self).__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.ffn_dim = ffn_dim
|
||||||
|
self.ffn_act = ffn_act
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
|
self.attn = SelfAttention(dim, num_heads, dropout)
|
||||||
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
self.ffn = nn.Sequential(
|
||||||
|
nn.Linear(dim, ffn_dim),
|
||||||
|
GELU() if ffn_act == 'gelu' else nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(ffn_dim, dim), nn.Dropout(dropout))
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
x = x + self.attn(self.norm1(x), mask)
|
||||||
|
x = x + self.ffn(self.norm2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class XGLM(nn.Module):
|
||||||
|
r"""A multilingual GPT model with an embedding head.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
vocab_size=256008,
|
||||||
|
max_seq_len=2048,
|
||||||
|
dim=1024,
|
||||||
|
ffn_dim=4096,
|
||||||
|
ffn_act='gelu',
|
||||||
|
embed_dim=768,
|
||||||
|
num_heads=16,
|
||||||
|
num_layers=24,
|
||||||
|
pad_token=1,
|
||||||
|
dropout=0.1):
|
||||||
|
super(XGLM, self).__init__()
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.dim = dim
|
||||||
|
self.ffn_dim = ffn_dim
|
||||||
|
self.ffn_act = ffn_act
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.pad_token = pad_token
|
||||||
|
self.scale = math.sqrt(dim) # rescale token embedings
|
||||||
|
|
||||||
|
# layers
|
||||||
|
self.token_embedding = nn.Embedding(vocab_size, dim, pad_token)
|
||||||
|
self.pos_embedding = SinusoidalEmbedding(max_seq_len, dim, pad_token)
|
||||||
|
self.eos_embedding = nn.Parameter(torch.randn(1, 1, dim))
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
AttentionBlock(dim, ffn_dim, ffn_act, num_heads, dropout)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
self.head = nn.Linear(dim, embed_dim, bias=False)
|
||||||
|
|
||||||
|
# causal attention mask
|
||||||
|
self.register_buffer(
|
||||||
|
'attn_mask',
|
||||||
|
torch.tril(torch.ones(1, 1, 1 + max_seq_len, 1 + max_seq_len)))
|
||||||
|
|
||||||
|
# init weights
|
||||||
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
|
def forward(self, tokens, mask=None):
|
||||||
|
r"""tokens: [B, L].
|
||||||
|
mask: [B, L].
|
||||||
|
"""
|
||||||
|
b, seq_len = tokens.size(0), 1 + tokens.size(1)
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
x = self.scale * self.token_embedding(tokens)
|
||||||
|
x = torch.cat([x, self.eos_embedding.repeat(b, 1, 1)], dim=1)
|
||||||
|
# x = x + self.pos_embedding(tokens)
|
||||||
|
x = self.dropout(x)
|
||||||
|
|
||||||
|
# attention mask
|
||||||
|
if mask is None:
|
||||||
|
mask = self.attn_mask[:, :, :seq_len, :seq_len].repeat(b, 1, 1, 1)
|
||||||
|
else:
|
||||||
|
mask = self.attn_mask[:, :, :seq_len, :seq_len] * torch.cat(
|
||||||
|
[mask, torch.zeros_like(mask[:, :1])], dim=1).view(
|
||||||
|
b, 1, 1, seq_len)
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x, mask)
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
# head
|
||||||
|
logits = self.head(x[:, -1])
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def init_weights(self, m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
nn.init.normal_(m.weight, std=0.02)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
elif isinstance(m, nn.Embedding):
|
||||||
|
nn.init.normal_(m.weight, std=0.02)
|
||||||
|
if m.padding_idx is not None:
|
||||||
|
nn.init.zeros_(m.weight[m.padding_idx])
|
||||||
@@ -3,7 +3,8 @@ from typing import Any, Dict, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modelscope.metainfo import Pipelines
|
from modelscope.metainfo import Pipelines
|
||||||
from modelscope.models.multi_modal import OfaForTextToImageSynthesis
|
from modelscope.models.multi_modal import (
|
||||||
|
MultiStageDiffusionForTextToImageSynthesis, OfaForTextToImageSynthesis)
|
||||||
from modelscope.outputs import OutputKeys
|
from modelscope.outputs import OutputKeys
|
||||||
from modelscope.pipelines.base import Input, Model, Pipeline
|
from modelscope.pipelines.base import Input, Model, Pipeline
|
||||||
from modelscope.pipelines.builder import PIPELINES
|
from modelscope.pipelines.builder import PIPELINES
|
||||||
@@ -48,7 +49,9 @@ class TextToImageSynthesisPipeline(Pipeline):
|
|||||||
return input
|
return input
|
||||||
|
|
||||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
if isinstance(self.model, OfaForTextToImageSynthesis):
|
if isinstance(self.model,
|
||||||
|
(OfaForTextToImageSynthesis,
|
||||||
|
MultiStageDiffusionForTextToImageSynthesis)):
|
||||||
return self.model(input)
|
return self.model(input)
|
||||||
return self.model.generate(input)
|
return self.model.generate(input)
|
||||||
|
|
||||||
|
|||||||
40
tests/pipelines/test_multi_stage_diffusion.py
Normal file
40
tests/pipelines/test_multi_stage_diffusion.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from modelscope.models import Model
|
||||||
|
from modelscope.outputs import OutputKeys
|
||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
from modelscope.utils.test_utils import test_level
|
||||||
|
|
||||||
|
|
||||||
|
class MultiStageDiffusionTest(unittest.TestCase):
|
||||||
|
model_id = 'damo/cv_diffusion_text-to-image-synthesis'
|
||||||
|
test_text = {'text': 'Photograph of a baby chicken wearing sunglasses'}
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
'skip test since the pretrained model is not publicly available')
|
||||||
|
def test_run_with_model_from_modelhub(self):
|
||||||
|
model = Model.from_pretrained(self.model_id)
|
||||||
|
pipe_line_text_to_image_synthesis = pipeline(
|
||||||
|
task=Tasks.text_to_image_synthesis, model=model)
|
||||||
|
img = pipe_line_text_to_image_synthesis(
|
||||||
|
self.test_text)[OutputKeys.OUTPUT_IMG]
|
||||||
|
print(np.sum(np.abs(img)))
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
'skip test since the pretrained model is not publicly available')
|
||||||
|
def test_run_with_model_name(self):
|
||||||
|
pipe_line_text_to_image_synthesis = pipeline(
|
||||||
|
task=Tasks.text_to_image_synthesis, model=self.model_id)
|
||||||
|
img = pipe_line_text_to_image_synthesis(
|
||||||
|
self.test_text)[OutputKeys.OUTPUT_IMG]
|
||||||
|
print(np.sum(np.abs(img)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user