diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 49bb9ca8..630d4aa5 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -221,6 +221,7 @@ class Models(object): videocomposer = 'videocomposer' text_to_360panorama_image = 'text-to-360panorama-image' image_to_video_model = 'image-to-video-model' + video_to_video_model = 'video-to-video-model' # science models unifold = 'unifold' @@ -547,6 +548,7 @@ class Pipelines(object): multimodal_dialogue = 'multimodal-dialogue' llama2_text_generation_pipeline = 'llama2-text-generation-pipeline' image_to_video_task_pipeline = 'image-to-video-task-pipeline' + video_to_video_pipeline = 'video-to-video-pipeline' # science tasks protein_structure = 'unifold-protein-structure' diff --git a/modelscope/models/multi_modal/video_to_video/__init__.py b/modelscope/models/multi_modal/video_to_video/__init__.py new file mode 100644 index 00000000..0e10a8eb --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .video_to_video_model import VideoToVideo + +else: + _import_structure = { + 'video_to_video_model': ['VideoToVideo'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/multi_modal/video_to_video/modules/__init__.py b/modelscope/models/multi_modal/video_to_video/modules/__init__.py new file mode 100644 index 00000000..6c882318 --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
def nonlinearity(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    """GroupNorm with the eps/affine settings used by the LDM reference VAE."""
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


@torch.no_grad()
def get_first_stage_encoding(encoder_posterior):
    """Turn an encoder output into a scaled latent.

    Accepts either a ``DiagonalGaussianDistribution`` (sampled) or a raw
    tensor (passed through), then applies the fixed LDM scale factor.
    """
    scale_factor = 0.18215
    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
        z = encoder_posterior.sample()
    elif isinstance(encoder_posterior, torch.Tensor):
        z = encoder_posterior
    else:
        raise NotImplementedError(
            f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
        )
    return scale_factor * z


class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents.

    ``parameters`` stacks mean and log-variance along channel dim 1.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp log-variance so exp() stays finite.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Degenerate distribution: all mass at the mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        # randn_like keeps device *and* dtype consistent with the parameters
        # (the previous randn(shape).to(device) silently produced fp32 noise).
        return self.mean + self.std * torch.randn_like(self.mean)

    def kl(self, other=None):
        """KL against N(0, I) (``other=None``) or another diagonal Gaussian."""
        if self.deterministic:
            # Was torch.Tensor([0.]) -- a CPU tensor even for CUDA latents.
            return torch.zeros(1, device=self.parameters.device)
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample`` under this Gaussian."""
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar
            + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean


class ResnetBlock(nn.Module):
    """VAE residual block: (norm-swish-conv) x2 with optional timestep proj."""

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            # 3x3 vs 1x1 projection on the skip path when channels change.
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = self.norm1(x)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
        return x + h


class AttnBlock(nn.Module):
    """Single-head spatial self-attention over the (h*w) positions."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = self.norm(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # attention weights: softmax(q.k^T / sqrt(c)) over key positions
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).permute(0, 2, 1)  # b, hw, c
        k = k.reshape(b, c, h * w)                   # b, c, hw
        w_ = torch.bmm(q, k) * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_).reshape(b, c, h, w)

        return x + self.proj_out(h_)


class Upsample(nn.Module):
    """2x nearest-neighbor upsample with an optional 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode='nearest')
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """2x downsample via strided conv (asymmetric pad) or average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)  # pad right/bottom only
            x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
class Encoder(nn.Module):
    """VAE encoder: image -> latent moments (2*z_channels when double_z)."""

    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 in_channels, resolution, z_channels, double_z=True,
                 use_linear_attn=False, attn_type='vanilla', **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # the autoencoder is not timestep-conditioned
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # input stem
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for lvl in range(self.num_resolutions):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_in = ch * in_ch_mult[lvl]
            block_out = ch * ch_mult[lvl]
            for _ in range(self.num_res_blocks):
                res_blocks.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                # attach attention only at the configured resolutions
                if curr_res in attn_resolutions:
                    attn_blocks.append(AttnBlock(block_in))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if lvl != self.num_resolutions - 1:
                stage.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(stage)

        # bottleneck
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in,
            temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in,
            temb_channels=self.temb_ch, dropout=dropout)

        # output head
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        temb = None  # kept for ResnetBlock's signature; never populated here

        # downsampling path
        hs = [self.conv_in(x)]
        for lvl in range(self.num_resolutions):
            for blk in range(self.num_res_blocks):
                h = self.down[lvl].block[blk](hs[-1], temb)
                if len(self.down[lvl].attn) > 0:
                    h = self.down[lvl].attn[blk](h)
                hs.append(h)
            if lvl != self.num_resolutions - 1:
                hs.append(self.down[lvl].downsample(hs[-1]))

        # bottleneck
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(hs[-1], temb)), temb)

        # output head
        h = self.norm_out(h)
        h = nonlinearity(h)
        return self.conv_out(h)


class Decoder(nn.Module):
    """VAE decoder: latent z -> image."""

    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 in_channels, resolution, z_channels, give_pre_end=False,
                 tanh_out=False, use_linear_attn=False, attn_type='vanilla',
                 **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # channel count / spatial size at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logger.info('Working with z of shape {} = {} dimensions.'.format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # bottleneck
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in,
            temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in,
            temb_channels=self.temb_ch, dropout=dropout)

        # upsampling path (built top-down, prepended so indices match levels)
        self.up = nn.ModuleList()
        for lvl in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_out = ch * ch_mult[lvl]
            # one extra res-block per level compared to the encoder
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(AttnBlock(block_in))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if lvl != 0:
                stage.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, stage)

        # output head
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        self.last_z_shape = z.shape
        temb = None

        h = self.conv_in(z)
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h, temb)), temb)

        for lvl in reversed(range(self.num_resolutions)):
            for blk in range(self.num_res_blocks + 1):
                h = self.up[lvl].block[blk](h, temb)
                if len(self.up[lvl].attn) > 0:
                    h = self.up[lvl].attn[blk](h)
            if lvl != 0:
                h = self.up[lvl].upsample(h)

        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return torch.tanh(h) if self.tanh_out else h
class AutoencoderKL(nn.Module):
    """KL-regularized image autoencoder (first stage of the diffusion model)."""

    def __init__(self, ddconfig, embed_dim, pretrained=None, ignore_keys=[],
                 image_key='image', colorize_nlabels=None, monitor=None,
                 ema_decay=None, learn_logvar=False, **kwargs):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig['double_z']
        # moments (mean+logvar) <-> embedding projections
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer('colorize', torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None

        if pretrained is not None:
            self.init_from_ckpt(pretrained, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load first-stage weights from a full-model checkpoint.

        Keys under ``first_stage_model.`` are re-rooted; keys whose re-rooted
        name starts with any prefix in ``ignore_keys`` are dropped (the
        original accepted ``ignore_keys`` but never used it).
        """
        sd = torch.load(path, map_location='cpu')['state_dict']
        sd_new = collections.OrderedDict()
        for k, v in sd.items():
            if 'first_stage_model' not in k:
                continue
            k_new = k.split('first_stage_model.')[-1]
            if any(k_new.startswith(ik) for ik in ignore_keys):
                continue
            sd_new[k_new] = v
        self.load_state_dict(sd_new, strict=True)
        logger.info(f'Restored from {path}')

    def on_train_batch_end(self, *args, **kwargs):
        # NOTE(review): no `model_ema` is ever created in this file; guard so
        # enabling use_ema does not crash. Confirm whether the EMA machinery
        # was meant to be ported.
        if self.use_ema and hasattr(self, 'model_ema'):
            self.model_ema(self)

    def encode(self, x):
        """Encode an image batch into a diagonal-Gaussian posterior."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        return DiagonalGaussianDistribution(moments)

    def decode(self, z):
        """Decode a latent batch back to image space."""
        z = self.post_quant_conv(z)
        return self.decoder(z)

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        z = posterior.sample() if sample_posterior else posterior.mode()
        return self.decode(z), posterior

    def get_input(self, batch, k):
        """Fetch images from a batch dict as contiguous NCHW float tensors."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        return x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Collect input / reconstruction / prior-sample images for logging."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        # BUG FIX: plain nn.Module has no `.device`; derive it from parameters.
        x = x.to(next(self.parameters()).device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log['samples'] = self.decode(torch.randn_like(posterior.sample()))
            log['reconstructions'] = xrec
            if (log_ema or self.use_ema) and hasattr(self, 'ema_scope'):
                # NOTE(review): `ema_scope` is not defined in this file;
                # guarded so the non-EMA path cannot crash.
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log['samples_ema'] = self.decode(
                        torch.randn_like(posterior_ema.sample()))
                    log['reconstructions_ema'] = xrec_ema
        log['inputs'] = x
        return log

    def to_rgb(self, x):
        """Project a segmentation map to 3 channels via a fixed random conv."""
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize', torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class IdentityFirstStage(torch.nn.Module):
    """No-op first stage used when operating directly in pixel space."""

    def __init__(self, *args, vq_interface=False, **kwargs):
        # call nn.Module's __init__ before setting attributes
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            # mimic a VQ model's (quant, loss, info) return shape
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
class FrozenOpenCLIPEmbedder(nn.Module):
    """
    Uses the OpenCLIP transformer encoder for text.

    The visual tower is deleted after loading, so only the text branch is kept.
    """
    LAYERS = ['last', 'penultimate']

    def __init__(self, pretrained, arch='ViT-H-14', device='cuda',
                 max_length=77, freeze=True, layer='penultimate'):
        super().__init__()
        assert layer in self.LAYERS
        # Load on CPU first; the caller is responsible for moving to `device`.
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=pretrained)

        del model.visual  # text-only usage
        self.model = model
        self.device = device
        self.max_length = max_length

        if freeze:
            self.freeze()
        self.layer = layer
        # How many final transformer blocks to skip.
        if self.layer == 'last':
            self.layer_idx = 0
        elif self.layer == 'penultimate':
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Put the model in eval mode and stop all gradients."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)
        return self.encode_with_transformer(tokens.to(self.device))

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        return self.model.ln_final(x)

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run the resblocks, stopping `layer_idx` blocks before the end."""
        # BUG FIX: `checkpoint` was never imported in this module and raised
        # NameError whenever grad_checkpointing was enabled.
        from torch.utils.checkpoint import checkpoint
        n_blocks = len(self.model.transformer.resblocks)
        for i, blk in enumerate(self.model.transformer.resblocks):
            if i == n_blocks - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)
USE_TEMPORAL_TRANSFORMER = True


class DropPath(nn.Module):
    r"""DropPath without rescaling; rows can be forced to drop (`zero`) or
    forced to survive (`keep`)."""

    def __init__(self, p):
        super(DropPath, self).__init__()
        self.p = p

    def forward(self, *args, zero=None, keep=None):
        # Inference: identity pass-through.
        if not self.training:
            return args[0] if len(args) == 1 else args

        x = args[0]
        batch = x.size(0)
        # Number of rows to drop is itself random: Binomial(batch, p).
        num_drop = (torch.rand(batch) < self.p).sum()

        # Candidate rows are those not pinned by `keep` or `zero`.
        candidates = x.new_ones(batch, dtype=torch.bool)
        if keep is not None:
            candidates[keep] = False
        if zero is not None:
            candidates[zero] = False

        chosen = torch.where(candidates)[0]
        chosen = chosen[torch.randperm(len(chosen))[:num_drop]]
        if zero is not None:
            chosen = torch.cat([chosen, torch.where(zero)[0]], dim=0)

        # Per-row 0/1 multiplier, broadcast to each argument's rank.
        scale = x.new_ones(batch)
        scale[chosen] = 0.0
        outputs = tuple(t * self.broadcast(scale, t) for t in args)
        return outputs[0] if len(args) == 1 else outputs

    def broadcast(self, src, dst):
        # Reshape (b,) -> (b, 1, 1, ...) to match dst's rank.
        assert src.size(0) == dst.size(0)
        return src.view((dst.size(0), ) + (1, ) * (dst.ndim - 1))


def sinusoidal_embedding(timesteps, dim):
    """Standard sinusoidal timestep embedding; zero-pads when dim is odd."""
    half = dim // 2
    timesteps = timesteps.float()

    freqs = torch.pow(10000, -torch.arange(half).to(timesteps).div(half))
    sinusoid = torch.outer(timesteps, freqs)
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    if dim % 2 != 0:
        x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
    return x
def exists(x):
    """True unless x is None (0 / '' / [] still count as existing)."""
    return x is not None


def default(val, d):
    """Return ``val`` if not None, else ``d`` (calling it when callable)."""
    if exists(val):
        return val
    return d() if callable(d) else d


def prob_mask_like(shape, prob, device):
    """Boolean mask with each entry True with probability ``prob``."""
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
        # avoid an all-True mask, which would cause DDP's
        # find_unused_parameters error downstream
        if mask.all():
            mask[0] = False
        return mask


class MemoryEfficientCrossAttention(nn.Module):
    """Multi-head cross-attention backed by xformers' memory-efficient kernel."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64,
                 dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # BUG FIX: was `self.attention_op: Optional[Any] = None`, but `typing`
        # is never imported in this module; annotations on attribute targets
        # are evaluated at runtime (PEP 526) and raised NameError.
        self.attention_op = None  # Optional[Any]: xformers op override

    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)  # self-attention when no context given
        k = self.to_k(context)
        v = self.to_v(context)

        # (b, n, h*d) -> (b*h, n, d) as xformers expects
        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3).reshape(b, t.shape[
                1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                    b * self.heads, t.shape[1], self.dim_head).contiguous(),
            (q, k, v),
        )

        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op)

        if exists(mask):
            raise NotImplementedError
        # (b*h, n, d) -> (b, n, h*d)
        out = (
            out.unsqueeze(0).reshape(
                b, self.heads, out.shape[1],
                self.dim_head).permute(0, 2, 1,
                                       3).reshape(b, out.shape[1],
                                                  self.heads * self.dim_head))
        return self.to_out(out)
class RelativePositionBias(nn.Module):
    """T5-style bucketed relative position bias for temporal attention."""

    def __init__(self, heads=8, num_buckets=32, max_distance=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position,
                                  num_buckets=32,
                                  max_distance=128):
        # Sign occupies half the buckets; nearby offsets get exact buckets,
        # distant ones share logarithmically-spaced buckets.
        bucket = 0
        n = -relative_position

        num_buckets //= 2
        bucket += (n < 0).long() * num_buckets
        n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)).long()
        val_if_large = torch.min(
            val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        bucket += torch.where(is_small, n, val_if_large)
        return bucket

    def forward(self, n, device):
        q_pos = torch.arange(n, dtype=torch.long, device=device)
        k_pos = torch.arange(n, dtype=torch.long, device=device)
        # pairwise offsets (n, n): key index minus query index
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(
            rel_pos,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)  # (n, n, heads)
        return values.permute(2, 0, 1)  # (heads, n, n)
_ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding) and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape back to an image.
    use_linear replaces the 1x1 convs with Linear layers for efficiency.
    """

    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.,
                 context_dim=None, disable_self_attn=False, use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        # one context per transformer depth
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                disable_self_attn=disable_self_attn,
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            # BUG FIX: was nn.Linear(in_channels, inner_dim) -- transposed.
            # proj_out maps the transformer width back to in_channels so the
            # residual `x + x_in` below is well-defined.
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
class CrossAttention(nn.Module):
    """Plain (non-xformers) multi-head cross-attention."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64,
                 dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = x if context is None else context  # self-attention fallback
        k = self.to_k(context)
        v = self.to_v(context)

        def _split_heads(t):
            # 'b n (h d) -> (b h) n d'
            b, n, _ = t.shape
            return t.reshape(b, n, h, -1).permute(0, 2, 1, 3).reshape(b * h, n, -1)

        q, k, v = (_split_heads(t) for t in (q, k, v))

        # force fp32 for the similarity matmul to avoid overflowing in fp16
        if os.environ.get('ATTN_PRECISION', 'fp32') == 'fp32':
            with torch.autocast(enabled=False, device_type='cuda'):
                q, k = q.float(), k.float()
                sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        else:
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        del q, k

        if mask is not None:
            # BUG FIX: the original used einops.repeat, which is not imported
            # in this module. Native equivalent of 'b j -> (b h) () j'.
            mask = mask.reshape(mask.shape[0], -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = mask[:, None, :].repeat_interleave(h, dim=0)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = torch.einsum('b i j, b j d -> b i d', sim, v)
        # '(b h) n d -> b n (h d)'
        bh, n, d = out.shape
        out = out.reshape(bh // h, h, n, d).permute(0, 2, 1, 3).reshape(bh // h, n, h * d)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    """Self-attn -> cross-attn -> feed-forward, each with pre-LayerNorm."""

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,
                 gated_ff=True, checkpoint=True, disable_self_attn=False):
        super().__init__()
        attn_cls = MemoryEfficientCrossAttention
        self.disable_self_attn = disable_self_attn
        # attn1 is self-attention unless disable_self_attn turns it into
        # cross-attention over `context`.
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward_(self, x, context=None):
        """Checkpointed entry point.

        BUG FIX: the original called an undefined `checkpoint(...)` helper
        (old ldm signature) with a non-existent `self._forward`; route
        through torch.utils.checkpoint instead.
        """
        from torch.utils.checkpoint import checkpoint
        if self.checkpoint:
            return checkpoint(self.forward, x, context)
        return self.forward(x, context)

    def forward(self, x, context=None):
        x = self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class GEGLU(nn.Module):
    """Gated-GELU projection: Linear to 2*dim_out, gate half with GELU."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
class FeedForward(nn.Module):
    """Transformer MLP: project up (GELU or gated GEGLU), dropout, project back."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        out_dim = dim if dim_out is None else dim_out
        if glu:
            stem = GEGLU(dim, hidden)
        else:
            stem = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())

        self.net = nn.Sequential(stem, nn.Dropout(dropout),
                                 nn.Linear(hidden, out_dim))

    def forward(self, x):
        return self.net(x)


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = nn.Conv2d(
                self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # video input: keep frame count, double spatial dims
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode='nearest')
        else:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            # NOTE(review): crops the first/last row after the 2x upsample;
            # kept verbatim from the original -- presumably compensates for
            # asymmetric padding elsewhere in this UNet. Confirm before
            # changing.
            x = x[..., 1:-1, :]
        if self.use_conv:
            x = self.conv(x)
        return x
class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param use_temporal_conv: append a temporal conv over the frame axis.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        use_temporal_conv=True,
        use_image_dataset=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_temporal_conv = use_temporal_conv

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels
                if use_scale_shift_norm else self.out_channels,
            ),
        )
        # zero-init the final conv so a fresh block starts as an identity
        # (inlined; equivalent to the file's zero_module helper)
        out_conv = nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)
        for p in out_conv.parameters():
            p.detach().zero_()
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            out_conv,
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            # BUG FIX: the original called an undefined `conv_nd`; this block
            # only builds 2D convs, so use nn.Conv2d directly.
            self.skip_connection = nn.Conv2d(
                channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)

        if self.use_temporal_conv:
            # NOTE: attribute name 'temopral_conv' (sic) kept so existing
            # checkpoints continue to load.
            self.temopral_conv = TemporalConvBlock_v2(
                self.out_channels,
                self.out_channels,
                dropout=0.1,
                use_image_dataset=use_image_dataset)

    def forward(self, x, emb, batch_size):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :param batch_size: videos per batch (N = batch_size * frames).
        :return: an [N x C x ...] Tensor of outputs.
        """
        return self._forward(x, emb, batch_size)

    def _forward(self, x, emb, batch_size):
        if self.updown:
            # resample between norm/act and the conv
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            # BUG FIX: was `th.chunk`, but no `th` alias exists in this module.
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        h = self.skip_connection(x) + h

        if self.use_temporal_conv:
            h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
            h = self.temopral_conv(h)
            h = rearrange(h, 'b c f h w -> (b f) c h w')
        return h
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=(2, 1)):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D signals keep the temporal axis (stride 1) and halve H/W.
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            # NOTE(review): asymmetric default padding (2, 1) pads height
            # more than width; presumably paired with the `x[..., 1:-1, :]`
            # crop in Upsample — confirm before changing.
            self.op = nn.Conv2d(
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding)
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class Resample(nn.Module):
    """Resolution adapter: 'upsample' matches a reference tensor's H/W via
    nearest interpolation, 'downsample' halves H/W with adaptive average
    pooling, 'none' is a pass-through."""

    def __init__(self, in_dim, out_dim, mode):
        assert mode in ['none', 'upsample', 'downsample']
        super(Resample, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.mode = mode

    def forward(self, x, reference=None):
        if self.mode == 'upsample':
            # Upsampling needs a reference tensor to know the target size.
            assert reference is not None
            x = F.interpolate(x, size=reference.shape[-2:], mode='nearest')
        elif self.mode == 'downsample':
            x = F.adaptive_avg_pool2d(
                x, output_size=tuple(u // 2 for u in x.shape[-2:]))
        return x


class ResidualBlock(nn.Module):
    """Residual block with embedding conditioning and optional built-in
    resampling (via :class:`Resample`); the second conv is zero-initialized
    so the block starts close to identity."""

    def __init__(self,
                 in_dim,
                 embed_dim,
                 out_dim,
                 use_scale_shift_norm=True,
                 mode='none',
                 dropout=0.0):
        super(ResidualBlock, self).__init__()
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.use_scale_shift_norm = use_scale_shift_norm
        self.mode = mode

        # layers
        self.layer1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv2d(in_dim, out_dim, 3, padding=1))
        self.resample = Resample(in_dim, in_dim, mode)
        self.embedding = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim,
                      out_dim * 2 if use_scale_shift_norm else out_dim))
        self.layer2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv2d(out_dim, out_dim, 3, padding=1))
        self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
            in_dim, out_dim, 1)

        # zero out the last layer params so the residual branch starts as
        # (near) identity
        nn.init.zeros_(self.layer2[-1].weight)

    def forward(self, x, e, reference=None):
        # Skip path is resampled the same way as the main path so shapes
        # always match at the final addition.
        identity = self.resample(x, reference)
        # GN/SiLU, resample, then conv (layer1[-1] is the conv).
        x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
        e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
        if self.use_scale_shift_norm:
            # AdaGN: normalize, then apply embedding-derived scale/shift.
            scale, shift = e.chunk(2, dim=1)
            x = self.layer2[0](x) * (1 + scale) + shift
            x = self.layer2[1:](x)
        else:
            x = x + e
            x = self.layer2(x)
        x = x + self.shortcut(identity)
        return x


class AttentionBlock(nn.Module):
    """Spatial self-attention over H*W positions, with optional
    cross-attention keys/values derived from `context`."""

    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
        # consider head_dim first, then num_heads
        num_heads = dim // head_dim if head_dim else num_heads
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(AttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Split the usual 1/sqrt(d) between q and k for numerical balance.
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params so the block starts as identity
        nn.init.zeros_(self.proj.weight)

    def forward(self, x, context=None):
        r"""x: [B, C, H, W].
        context: [B, L, C] or None.
        """
        identity = x
        b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        x = self.norm(x)
        q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
        if context is not None:
            # Prepend context-derived keys/values to the spatial ones.
            ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
                                                      d).permute(0, 2, 3,
                                                                 1).chunk(
                                                                     2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)

        # compute attention
        attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
        attn = F.softmax(attn, dim=-1)

        # gather context
        x = torch.matmul(v, attn.transpose(-1, -2))
        x = x.reshape(b, c, h, w)

        # output
        x = self.proj(x)
        return x + identity


class TemporalAttentionBlock(nn.Module):
    """Attention across the frame axis, applied independently at every
    spatial position (b (h w) f c layout); supports relative position
    bias, rotary embeddings, and several masking modes."""

    def __init__(self,
                 dim,
                 heads=4,
                 dim_head=32,
                 rotary_emb=None,
                 use_image_dataset=False,
                 use_sim_mask=False):
        super().__init__()
        # consider num_heads first, as pos_bias needs fixed num_heads
        dim_head = dim // heads
        assert heads * dim_head == dim
        self.use_image_dataset = use_image_dataset
        self.use_sim_mask = use_sim_mask

        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = nn.GroupNorm(32, dim)
        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, hidden_dim * 3)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self,
                x,
                pos_bias=None,
                focus_present_mask=None,
                video_mask=None):

        identity = x
        n, height, device = x.shape[2], x.shape[-2], x.device

        x = self.norm(x)
        x = rearrange(x, 'b c f h w -> b (h w) f c')

        qkv = self.to_qkv(x).chunk(3, dim=-1)

        if exists(focus_present_mask) and focus_present_mask.all():
            # if all batch samples are focusing on present
            # it would be equivalent to passing that token's values (v=qkv[-1]) through to the output
            values = qkv[-1]
            out = self.to_out(values)
            out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)

            return out + identity

        # split out heads
        q = rearrange(qkv[0], '... n (h d) -> ... h n d', h=self.heads)
        k = rearrange(qkv[1], '... n (h d) -> ... h n d', h=self.heads)
        v = rearrange(qkv[2], '... n (h d) -> ... h n d', h=self.heads)

        # scale

        q = q * self.scale

        # rotate positions into queries and keys for time attention
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        # similarity
        # shape [b (hw) h n n], n=f
        sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)

        # relative positional bias

        if exists(pos_bias):
            sim = sim + pos_bias

        if (focus_present_mask is None and video_mask is not None):
            # video_mask: [B, n]
            mask = video_mask[:, None, :] * video_mask[:, :, None]
            mask = mask.unsqueeze(1).unsqueeze(1)
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
        elif exists(focus_present_mask) and not (~focus_present_mask).all():
            # mixed batch: some samples attend only to themselves, the
            # rest attend to every frame
            attend_all_mask = torch.ones((n, n),
                                         device=device,
                                         dtype=torch.bool)
            attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)

            mask = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
            )

            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        if self.use_sim_mask:
            # causal (lower-triangular) mask over frames
            sim_mask = torch.tril(
                torch.ones((n, n), device=device, dtype=torch.bool),
                diagonal=0)
            sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max)

        # numerical stability
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values

        out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
        out = rearrange(out, '... h n d -> ... n (h d)')
        out = self.to_out(out)

        out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)

        if self.use_image_dataset:
            # keep the layer in the graph but contribute nothing for
            # image-only batches
            out = identity + 0 * out
        else:
            out = identity + out
        return out


class TemporalTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True,
                 only_self_att=True,
                 multiply_zero=False):
        super().__init__()
        self.multiply_zero = multiply_zero
        self.only_self_att = only_self_att
        self.use_adaptor = False
        if self.only_self_att:
            context_dim = None
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv1d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        if self.use_adaptor:
            # NOTE(review): dead code — use_adaptor is hard-coded False and
            # `frames` is not defined in this scope; would raise NameError
            # if ever enabled.
            self.adaptor_in = nn.Linear(frames, frames)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv1d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            # NOTE(review): dimensions look transposed (in_channels ->
            # inner_dim; the conv branch maps inner_dim -> in_channels).
            # The use_linear path is unused by Vid2VidSDUNet
            # (use_linear_in_temporal=False) — verify before enabling.
            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
        if self.use_adaptor:
            self.adaptor_out = nn.Linear(frames, frames)
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if self.only_self_att:
            context = None
        if not isinstance(context, list):
            context = [context]
        b, c, f, h, w = x.shape
        x_in = x
        x = self.norm(x)

        if not self.use_linear:
            # attend over frames independently at each spatial location
            x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
            x = self.proj_in(x)
        # [16384, 16, 320]
        if self.use_linear:
            # NOTE(review): `self.frames` is never assigned anywhere in
            # this class — this path raises AttributeError if use_linear
            # is ever set. Unused with the default configuration.
            x = rearrange(
                x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
            x = self.proj_in(x)

        if self.only_self_att:
            x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
            for i, block in enumerate(self.transformer_blocks):
                x = block(x)
            x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
        else:
            x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
            for i, block in enumerate(self.transformer_blocks):
                # NOTE(review): also depends on the undefined `self.frames`
                # (cross-attention path; unused when only_self_att=True).
                context[i] = rearrange(
                    context[i], '(b f) l con -> b f l con',
                    f=self.frames).contiguous()
                # calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
                for j in range(b):
                    context_i_j = repeat(
                        context[i][j],
                        'f l con -> (f r) l con',
                        r=(h * w) // self.frames,
                        f=self.frames).contiguous()
                    x[j] = block(x[j], context=context_i_j)

        if self.use_linear:
            x = self.proj_out(x)
            x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
            x = self.proj_out(x)
            x = rearrange(
                x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()

        if self.multiply_zero:
            # keep parameters in the graph while contributing nothing
            x = 0.0 * x + x_in
        else:
            x = x + x_in
        return x


class TemporalTransformerMultiBlock_doc_anchor:  # (no such class; see below)
    pass


class TemporalAttentionMultiBlock(nn.Module):
    """Stack of ``temporal_attn_times`` TemporalAttentionBlocks applied
    sequentially with shared call arguments."""

    def __init__(
        self,
        dim,
        heads=4,
        dim_head=32,
        rotary_emb=None,
        use_image_dataset=False,
        use_sim_mask=False,
        temporal_attn_times=1,
    ):
        super().__init__()
        self.att_layers = nn.ModuleList([
            TemporalAttentionBlock(dim, heads, dim_head, rotary_emb,
                                   use_image_dataset, use_sim_mask)
            for _ in range(temporal_attn_times)
        ])

    def forward(self,
                x,
                pos_bias=None,
                focus_present_mask=None,
                video_mask=None):
        for layer in self.att_layers:
            x = layer(x, pos_bias, focus_present_mask, video_mask)
        return x


class InitTemporalConvBlock(nn.Module):
    """Single zero-initialized temporal (3, 1, 1) conv with a residual
    connection — starts as identity."""

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(InitTemporalConvBlock, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # conv layers
        # NOTE(review): the residual add below requires in_dim == out_dim;
        # callers appear to always satisfy this.
        self.conv = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv[-1].weight)
        nn.init.zeros_(self.conv[-1].bias)

    def forward(self, x):
        identity = x
        x = self.conv(x)
        if self.use_image_dataset:
            x = identity + 0 * x
        else:
            x = identity + x
        return x


class TemporalConvBlock(nn.Module):
    """Two temporal (3, 1, 1) convs with a residual connection; the second
    conv is zero-initialized so the block starts as identity."""

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(TemporalConvBlock, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv2[-1].weight)
        nn.init.zeros_(self.conv2[-1].bias)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        if self.use_image_dataset:
            x = identity + 0 * x
        else:
            x = identity + x
        return x


class TemporalConvBlock_v2(nn.Module):
    """Four temporal (3, 1, 1) convs with a residual connection; the last
    conv is zero-initialized so the block starts as identity.
    NOTE(review): convs 2-4 map out_dim -> in_dim, so in_dim must equal
    out_dim for the chain (and the residual add) to be consistent."""

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(TemporalConvBlock_v2, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        if self.use_image_dataset:
            # zero contribution for image-only batches, parameters kept in
            # the autograd graph
            x = identity + 0.0 * x
        else:
            x = identity + x
        return x


class Vid2VidSDUNet(nn.Module):
    """3D UNet for video-to-video latent diffusion: spatial ResBlocks and
    SpatialTransformers interleaved with temporal attention/transformer
    layers, conditioned on a timestep embedding and text context ``y``.

    NOTE(review): ``dim_mult`` and ``attn_scales`` use mutable list
    defaults — harmless here since they are never mutated, but callers
    should not rely on identity of the defaults.
    """

    def __init__(self,
                 in_dim=4,
                 dim=320,
                 y_dim=1024,
                 context_dim=1024,
                 out_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_heads=8,
                 head_dim=64,
                 num_res_blocks=2,
                 attn_scales=[1 / 1, 1 / 2, 1 / 4],
                 use_scale_shift_norm=True,
                 dropout=0.1,
                 temporal_attn_times=1,
                 temporal_attention=True,
                 use_checkpoint=True,
                 use_image_dataset=False,
                 use_fps_condition=False,
                 use_sim_mask=False,
                 training=False,
                 inpainting=True):
        embed_dim = dim * 4
        num_heads = num_heads if num_heads else dim // 32
        super(Vid2VidSDUNet, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        self.y_dim = y_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.dim_mult = dim_mult
        # for temporal attention
        self.num_heads = num_heads
        # for spatial attention
        self.head_dim = head_dim
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.use_scale_shift_norm = use_scale_shift_norm
        self.temporal_attn_times = temporal_attn_times
        self.temporal_attention = temporal_attention
        self.use_checkpoint = use_checkpoint
        self.use_image_dataset = use_image_dataset
        self.use_fps_condition = use_fps_condition
        self.use_sim_mask = use_sim_mask
+ self.training = training + self.inpainting = inpainting + + use_linear_in_temporal = False + transformer_depth = 1 + disabled_sa = False + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embed = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + + if temporal_attention and not USE_TEMPORAL_TRANSFORMER: + self.rotary_emb = RotaryEmbedding(min(32, head_dim)) + self.time_rel_pos_bias = RelativePositionBias( + heads=num_heads, max_distance=32) + + if self.use_fps_condition: + self.fps_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + nn.init.zeros_(self.fps_embedding[-1].weight) + nn.init.zeros_(self.fps_embedding[-1].bias) + + # encoder + self.input_blocks = nn.ModuleList() + init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + # need an initial temporal attention? 
+ if temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + init_block.append( + TemporalTransformer( + dim, + num_heads, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + else: + init_block.append( + TemporalAttentionMultiBlock( + dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + temporal_attn_times=temporal_attn_times, + use_image_dataset=use_image_dataset)) + self.input_blocks.append(init_block) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + block = nn.ModuleList([ + ResBlock( + in_dim, + embed_dim, + dropout, + out_channels=out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True)) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + else: + block.append( + TemporalAttentionMultiBlock( + out_dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + use_image_dataset=use_image_dataset, + use_sim_mask=use_sim_mask, + temporal_attn_times=temporal_attn_times)) + in_dim = out_dim + self.input_blocks.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + downsample = Downsample( + out_dim, True, dims=2, out_channels=out_dim) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.input_blocks.append(downsample) + + self.middle_block = nn.ModuleList([ + ResBlock( + out_dim, + embed_dim, + dropout, + 
use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ), + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True) + ]) + + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + self.middle_block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset, + )) + else: + self.middle_block.append( + TemporalAttentionMultiBlock( + out_dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + use_image_dataset=use_image_dataset, + use_sim_mask=use_sim_mask, + temporal_attn_times=temporal_attn_times)) + + self.middle_block.append( + ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False)) + + # decoder + self.output_blocks = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + block = nn.ModuleList([ + ResBlock( + in_dim + shortcut_dims.pop(), + embed_dim, + dropout, + out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=1024, + disable_self_attn=False, + use_linear=True)) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + else: + block.append( + TemporalAttentionMultiBlock( + out_dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + use_image_dataset=use_image_dataset, + use_sim_mask=use_sim_mask, + temporal_attn_times=temporal_attn_times)) + in_dim = out_dim 
+ + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + upsample = Upsample( + out_dim, True, dims=2.0, out_channels=out_dim) + scale *= 2.0 + block.append(upsample) + self.output_blocks.append(block) + + # head + self.out = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.out[-1].weight) + + def forward(self, + x, + t, + y, + x_lr=None, + fps=None, + video_mask=None, + focus_present_mask=None, + prob_focus_present=0., + mask_last_frame_num=0): + + batch, x_c, x_f, x_h, x_w = x.shape + device = x.device + self.batch = batch + + # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored + if mask_last_frame_num > 0: + focus_present_mask = None + video_mask[-mask_last_frame_num:] = False + else: + focus_present_mask = default( + focus_present_mask, lambda: prob_mask_like( + (batch, ), prob_focus_present, device=device)) + + if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER: + time_rel_pos_bias = self.time_rel_pos_bias( + x.shape[2], device=x.device) + else: + time_rel_pos_bias = None + + # embeddings + e = self.time_embed(sinusoidal_embedding(t, self.dim)) + context = y + + # repeat f times for spatial e and context + e = e.repeat_interleave(repeats=x_f, dim=0) + context = context.repeat_interleave(repeats=x_f, dim=0) + + # always in shape (b f) c h w, except for temporal layer + x = rearrange(x, 'b c f h w -> (b f) c h w') + # encoder + xs = [] + for block in self.input_blocks: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, + focus_present_mask, video_mask) + xs.append(x) + + # middle + for block in self.middle_block: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, + focus_present_mask, video_mask) + + # decoder + for block in self.output_blocks: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single( + block, + x, + e, + context, + 
time_rel_pos_bias, + focus_present_mask, + video_mask, + reference=xs[-1] if len(xs) > 0 else None) + + # head + x = self.out(x) + + # reshape back to (b c f h w) + x = rearrange(x, '(b f) c h w -> b c f h w', b=batch) + return x + + def _forward_single(self, + module, + x, + e, + context, + time_rel_pos_bias, + focus_present_mask, + video_mask, + reference=None): + if isinstance(module, ResidualBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = x.contiguous() + x = module(x, e, reference) + elif isinstance(module, ResBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = x.contiguous() + x = module(x, e, self.batch) + elif isinstance(module, SpatialTransformer): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, TemporalTransformer): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x, context) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, CrossAttention): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, MemoryEfficientCrossAttention): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, BasicTransformerBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, FeedForward): + x = module(x, context) + elif isinstance(module, Upsample): + x = module(x) + elif isinstance(module, Downsample): + x = module(x) + elif isinstance(module, Resample): + x = module(x, reference) + elif isinstance(module, TemporalAttentionBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x, 
time_rel_pos_bias, focus_present_mask, video_mask) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, TemporalAttentionMultiBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x, time_rel_pos_bias, focus_present_mask, video_mask) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, InitTemporalConvBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, TemporalConvBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context, + time_rel_pos_bias, focus_present_mask, + video_mask, reference) + else: + x = module(x) + return x diff --git a/modelscope/models/multi_modal/video_to_video/utils/__init__.py b/modelscope/models/multi_modal/video_to_video/utils/__init__.py new file mode 100644 index 00000000..92654e51 --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os diff --git a/modelscope/models/multi_modal/video_to_video/utils/config.py b/modelscope/models/multi_modal/video_to_video/utils/config.py new file mode 100644 index 00000000..1c9586eb --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/utils/config.py @@ -0,0 +1,171 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 

import logging
import os
import os.path as osp
from datetime import datetime

import torch
from easydict import EasyDict

# Global configuration object for the video-to-video (VideoLDM decoder)
# pipeline; consumed by the model, diffusion, and pipeline code.
cfg = EasyDict(__name__='Config: VideoLDM Decoder')

# ---------------------------work dir--------------------------
cfg.work_dir = 'workspace/'

# ---------------------------Global Variable-----------------------------------
cfg.resolution = [448, 256]
cfg.max_frames = 32
# -----------------------------------------------------------------------------

# ---------------------------Dataset Parameter---------------------------------
cfg.mean = [0.5, 0.5, 0.5]
cfg.std = [0.5, 0.5, 0.5]
cfg.max_words = 1000

# PlaceHolder
cfg.vit_out_dim = 1024
cfg.vit_resolution = [224, 224]
cfg.depth_clamp = 10.0
cfg.misc_size = 384
cfg.depth_std = 20.0

cfg.frame_lens = 32
cfg.sample_fps = 8

cfg.batch_sizes = 1
# -----------------------------------------------------------------------------

# ---------------------------Model Parameters-----------------------------------
# Diffusion
cfg.schedule = 'cosine'
cfg.num_timesteps = 1000
# 'v'-prediction parameterization (see Salimans & Ho, progressive distillation)
cfg.mean_type = 'v'
cfg.var_type = 'fixed_small'
cfg.loss_type = 'mse'
cfg.ddim_timesteps = 50
cfg.ddim_eta = 0.0
cfg.clamp = 1.0
cfg.share_noise = False
cfg.use_div_loss = False
cfg.noise_strength = 0.1

# classifier-free guidance
cfg.p_zero = 0.1
cfg.guide_scale = 3.0

# clip vision encoder
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]

# Model
cfg.scale_factor = 0.18215
cfg.use_fp16 = True
cfg.temporal_attention = True
cfg.decoder_bs = 8

cfg.UNet = {
    'type': 'Vid2VidSDUNet',
    'in_dim': 4,
    'dim': 320,
    'y_dim': cfg.vit_out_dim,
    'context_dim': 1024,
    'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
    'dim_mult': [1, 2, 4, 4],
    'num_heads': 8,
    'head_dim': 64,
    'num_res_blocks': 2,
    'attn_scales': [1 / 1, 1 / 2, 1 / 4],
    'dropout': 0.1,
    'temporal_attention': cfg.temporal_attention,
    'temporal_attn_times': 1,
    'use_checkpoint': False,
    'use_fps_condition': False,
    'use_sim_mask': False,
    'num_tokens': 4,
    'default_fps': 8,
    'input_dim': 1024
}

cfg.guidances = []

# autoencoder from stable diffusion
cfg.auto_encoder = {
    'type': 'AutoencoderKL',
    'ddconfig': {
        'double_z': True,
        'z_channels': 4,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0
    },
    'embed_dim': 4,
    'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
}
# clip embedder
cfg.embedder = {
    'type': 'FrozenOpenCLIPEmbedder',
    'layer': 'penultimate',
    'vit_resolution': [224, 224],
    'pretrained': 'open_clip_pytorch_model.bin'
}
# -----------------------------------------------------------------------------

# ---------------------------Training Settings---------------------------------
# training and optimizer
cfg.ema_decay = 0.9999
cfg.num_steps = 600000
cfg.lr = 5e-5
cfg.weight_decay = 0.0
cfg.betas = (0.9, 0.999)
cfg.eps = 1.0e-8
cfg.chunk_size = 16
cfg.alpha = 0.7
cfg.save_ckp_interval = 1000
# -----------------------------------------------------------------------------

# ----------------------------Pretrain Settings---------------------------------
# Default: load 2d pretrain
cfg.fix_weight = False
cfg.load_match = False
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
# -----------------------------------------------------------------------------

# -----------------------------Visual-------------------------------------------
# Visual videos
cfg.viz_interval = 1000
cfg.visual_train = {
    'type': 'VisualVideoTextDuringTrain',
}
cfg.visual_inference = {
    'type': 'VisualGeneratedVideos',
}
cfg.inference_list_path = ''

# logging
cfg.log_interval = 100

# Default log_dir
cfg.log_dir = 'workspace/output_data'
# -----------------------------------------------------------------------------

# ---------------------------Others--------------------------------------------
# seed
cfg.seed = 8888
cfg.negative_prompt = 'worst quality, normal quality, low quality, low res, blurry, text, \
watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, \
sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting'

cfg.positive_prompt = ', cinematic, High Contrast, highly detailed, unreal engine, \
taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, \
32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, \
hyper sharpness, perfect without deformations, Unreal Engine 5, 4k render'

# -----------------------------------------------------------------------------
diff --git a/modelscope/models/multi_modal/video_to_video/utils/diffusion_sdedit.py b/modelscope/models/multi_modal/video_to_video/utils/diffusion_sdedit.py
new file mode 100644
index 00000000..be5b5f57
--- /dev/null
+++ b/modelscope/models/multi_modal/video_to_video/utils/diffusion_sdedit.py
@@ -0,0 +1,247 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+ +import random + +import torch + +from .schedules_sdedit import karras_schedule +from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun + +__all__ = ['GaussianDiffusion_SDEdit'] + + +def _i(tensor, t, x): + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t.to(tensor.device)].view(shape).to(x.device) + + +class GaussianDiffusion_SDEdit(object): + + def __init__(self, sigmas, prediction_type='eps'): + assert prediction_type in {'x0', 'eps', 'v'} + self.sigmas = sigmas + self.alphas = torch.sqrt(1 - sigmas**2) + self.num_timesteps = len(sigmas) + self.prediction_type = prediction_type + + def diffuse(self, x0, t, noise=None): + noise = torch.randn_like(x0) if noise is None else noise + xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise + return xt + + def denoise(self, + xt, + t, + s, + model, + model_kwargs={}, + guide_scale=None, + guide_rescale=None, + clamp=None, + percentile=None): + s = t - 1 if s is None else s + + # hyperparams + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_s = _i(self.alphas, s.clamp(0), xt) + alphas_s[s < 0] = 1. 
+ sigmas_s = torch.sqrt(1 - alphas_s**2) + + # precompute variables + betas = 1 - (alphas / alphas_s)**2 + coef1 = betas * alphas_s / sigmas**2 + coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2) + var = betas * (sigmas_s / sigmas)**2 + log_var = torch.log(var).clamp_(-20, 20) + + # prediction + if guide_scale is None: + assert isinstance(model_kwargs, dict) + out = model(xt, t=t, **model_kwargs) + else: + # classifier-free guidance + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, t=t, **model_kwargs[0]) + if guide_scale == 1.: + out = y_out + else: + u_out = model(xt, t=t, **model_kwargs[1]) + out = u_out + guide_scale * (y_out - u_out) + + if guide_rescale is not None: + assert guide_rescale >= 0 and guide_rescale <= 1 + ratio = ( + y_out.flatten(1).std(dim=1) / # noqa + (out.flatten(1).std(dim=1) + 1e-12) + ).view((-1, ) + (1, ) * (y_out.ndim - 1)) + out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0 + + # compute x0 + if self.prediction_type == 'x0': + x0 = out + elif self.prediction_type == 'eps': + x0 = (xt - sigmas * out) / alphas + elif self.prediction_type == 'v': + x0 = alphas * xt - sigmas * out + else: + raise NotImplementedError( + f'prediction_type {self.prediction_type} not implemented') + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 + s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1) + s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1)) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + + # recompute eps using the restricted x0 + eps = (xt - alphas * x0) / sigmas + + # compute mu (mean of posterior distribution) using the restricted x0 + mu = coef1 * x0 + coef2 * xt + return mu, var, log_var, x0, eps + + @torch.no_grad() + def sample(self, + noise, + model, + model_kwargs={}, + condition_fn=None, + guide_scale=None, + guide_rescale=None, + clamp=None, + percentile=None, + 
solver='euler_a', + steps=20, + t_max=None, + t_min=None, + discretization=None, + discard_penultimate_step=None, + return_intermediate=None, + show_progress=False, + seed=-1, + **kwargs): + # sanity check + assert isinstance(steps, (int, torch.LongTensor)) + assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1) + assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1) + assert discretization in (None, 'leading', 'linspace', 'trailing') + assert discard_penultimate_step in (None, True, False) + assert return_intermediate in (None, 'x0', 'xt') + + # function of diffusion solver + solver_fn = { + 'heun': sample_heun, + 'dpmpp_2m_sde': sample_dpmpp_2m_sde + }[solver] + + # options + schedule = 'karras' if 'karras' in solver else None + discretization = discretization or 'linspace' + seed = seed if seed >= 0 else random.randint(0, 2**31) + if isinstance(steps, torch.LongTensor): + discard_penultimate_step = False + if discard_penultimate_step is None: + discard_penultimate_step = True if solver in ( + 'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras', + 'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False + + # function for denoising xt to get x0 + intermediates = [] + + def model_fn(xt, sigma): + # denoising + t = self._sigma_to_t(sigma).repeat(len(xt)).round().long() + x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale, + guide_rescale, clamp, percentile)[-2] + + # collect intermediate outputs + if return_intermediate == 'xt': + intermediates.append(xt) + elif return_intermediate == 'x0': + intermediates.append(x0) + return x0 + + # get timesteps + if isinstance(steps, int): + steps += 1 if discard_penultimate_step else 0 + t_max = self.num_timesteps - 1 if t_max is None else t_max + t_min = 0 if t_min is None else t_min + + # discretize timesteps + if discretization == 'leading': + steps = torch.arange(t_min, t_max + 1, + (t_max - t_min + 1) / steps).flip(0) + elif discretization == 'linspace': + steps = 
torch.linspace(t_max, t_min, steps) + elif discretization == 'trailing': + steps = torch.arange(t_max, t_min - 1, + -((t_max - t_min + 1) / steps)) + else: + raise NotImplementedError( + f'{discretization} discretization not implemented') + steps = steps.clamp_(t_min, t_max) + steps = torch.as_tensor( + steps, dtype=torch.float32, device=noise.device) + + # get sigmas + sigmas = self._t_to_sigma(steps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + if schedule == 'karras': + if sigmas[0] == float('inf'): + sigmas = karras_schedule( + n=len(steps) - 1, + sigma_min=sigmas[sigmas > 0].min().item(), + sigma_max=sigmas[sigmas < float('inf')].max().item(), + rho=7.).to(sigmas) + sigmas = torch.cat([ + sigmas.new_tensor([float('inf')]), sigmas, + sigmas.new_zeros([1]) + ]) + else: + sigmas = karras_schedule( + n=len(steps), + sigma_min=sigmas[sigmas > 0].min().item(), + sigma_max=sigmas.max().item(), + rho=7.).to(sigmas) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + if discard_penultimate_step: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + + # sampling + x0 = solver_fn( + noise, model_fn, sigmas, show_progress=show_progress, **kwargs) + return (x0, intermediates) if return_intermediate is not None else x0 + + def _sigma_to_t(self, sigma): + if sigma == float('inf'): + t = torch.full_like(sigma, len(self.sigmas) - 1) + else: + log_sigmas = torch.sqrt(self.sigmas**2 / # noqa + (1 - self.sigmas**2)).log().to(sigma) + log_sigma = sigma.log() + dists = log_sigma - log_sigmas[:, None] + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp( + max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + low, high = log_sigmas[low_idx], log_sigmas[high_idx] + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + t = t.view(sigma.shape) + if t.ndim == 0: + t = t.unsqueeze(0) + return t + + def _t_to_sigma(self, t): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + log_sigmas = 
torch.sqrt(self.sigmas**2 / # noqa + (1 - self.sigmas**2)).log().to(t) + log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx] + log_sigma[torch.isnan(log_sigma) + | torch.isinf(log_sigma)] = float('inf') + return log_sigma.exp() diff --git a/modelscope/models/multi_modal/video_to_video/utils/schedules_sdedit.py b/modelscope/models/multi_modal/video_to_video/utils/schedules_sdedit.py new file mode 100644 index 00000000..06fd4e8a --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/utils/schedules_sdedit.py @@ -0,0 +1,85 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math + +import torch + + +def betas_to_sigmas(betas): + return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0)) + + +def sigmas_to_betas(sigmas): + square_alphas = 1 - sigmas**2 + betas = 1 - torch.cat( + [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]]) + return betas + + +def logsnrs_to_sigmas(logsnrs): + return torch.sqrt(torch.sigmoid(-logsnrs)) + + +def sigmas_to_logsnrs(sigmas): + square_sigmas = sigmas**2 + return torch.log(square_sigmas / (1 - square_sigmas)) + + +def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15): + t_min = math.atan(math.exp(-0.5 * logsnr_min)) + t_max = math.atan(math.exp(-0.5 * logsnr_max)) + t = torch.linspace(1, 0, n) + logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min))) + return logsnrs + + +def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2): + logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max) + logsnrs += 2 * math.log(1 / scale) + return logsnrs + + +def _logsnr_cosine_interp(n, + logsnr_min=-15, + logsnr_max=15, + scale_min=2, + scale_max=4): + t = torch.linspace(1, 0, n) + logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min) + logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max) + logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max + return logsnrs + + +def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0): + ramp = 
torch.linspace(1, 0, n) + min_inv_rho = sigma_min**(1 / rho) + max_inv_rho = sigma_max**(1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho + sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2)) + return sigmas + + +def logsnr_cosine_interp_schedule(n, + logsnr_min=-15, + logsnr_max=15, + scale_min=2, + scale_max=4): + return logsnrs_to_sigmas( + _logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min, scale_max)) + + +def noise_schedule(schedule='logsnr_cosine_interp', + n=1000, + zero_terminal_snr=False, + **kwargs): + # compute sigmas + sigmas = { + 'logsnr_cosine_interp': logsnr_cosine_interp_schedule + }[schedule](n, **kwargs) + + # post-processing + if zero_terminal_snr and sigmas.max() != 1.0: + scale = (1.0 - sigmas.min()) / (sigmas.max() - sigmas.min()) + sigmas = sigmas.min() + scale * (sigmas - sigmas.min()) + return sigmas diff --git a/modelscope/models/multi_modal/video_to_video/utils/seed.py b/modelscope/models/multi_modal/video_to_video/utils/seed.py new file mode 100644 index 00000000..df3c9c50 --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/utils/seed.py @@ -0,0 +1,14 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import random + +import numpy as np +import torch + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True diff --git a/modelscope/models/multi_modal/video_to_video/utils/solvers_sdedit.py b/modelscope/models/multi_modal/video_to_video/utils/solvers_sdedit.py new file mode 100644 index 00000000..8d00a39f --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/utils/solvers_sdedit.py @@ -0,0 +1,194 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
import torch

# tqdm only provides the cosmetic progress bar; fall back to a plain range()
# when it is not installed so sampling has no hard dependency on it.
try:
    from tqdm.auto import trange
except ImportError:

    def trange(n, **kwargs):
        return range(n)


def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """
    Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step.
    """
    if not eta:
        return sigma_to, 0.
    sigma_up = min(
        sigma_to,
        eta * (
            sigma_to**2 *  # noqa
            (sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5)
    sigma_down = (sigma_to**2 - sigma_up**2)**0.5
    return sigma_down, sigma_up


def get_scalings(sigma):
    """Return (c_out, c_in) input/output scalings for noise level sigma."""
    c_out = -sigma
    c_in = 1 / (sigma**2 + 1.**2)**0.5
    return c_out, c_in


@torch.no_grad()
def sample_heun(noise,
                model,
                sigmas,
                s_churn=0.,
                s_tmin=0.,
                s_tmax=float('inf'),
                s_noise=1.,
                show_progress=True):
    """
    Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    `model(x, sigma)` must return the denoised prediction; `sigmas` is a
    decreasing 1-D tensor ending in 0 (an infinite first entry triggers a
    plain Euler step from pure noise).
    """
    x = noise * sigmas[0]
    for i in trange(len(sigmas) - 1, disable=not show_progress):
        # Optional stochastic churn (inactive when s_churn == 0).
        gamma = 0.
        if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'):
            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
        if sigmas[i] == float('inf'):
            # Euler method (starting from pure noise)
            denoised = model(noise, sigma_hat)
            x = denoised + sigmas[i + 1] * (gamma + 1) * noise
        else:
            _, c_in = get_scalings(sigma_hat)
            denoised = model(x * c_in, sigma_hat)
            d = (x - denoised) / sigma_hat
            dt = sigmas[i + 1] - sigma_hat
            if sigmas[i + 1] == 0:
                # Euler method (final step)
                x = x + d * dt
            else:
                # Heun's method: average the derivative at both endpoints.
                x_2 = x + d * dt
                _, c_in = get_scalings(sigmas[i + 1])
                denoised_2 = model(x_2 * c_in, sigmas[i + 1])
                d_2 = (x_2 - denoised_2) / sigmas[i + 1]
                d_prime = (d + d_2) / 2
                x = x + d_prime * dt
    return x


class BatchedBrownianTree:
    """
    A wrapper around torchsde.BrownianTree that enables batches of entropy.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        # torchsde is only required by the SDE solver; import lazily so the
        # ODE-style solvers above work without it installed.
        import torchsde
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()
        self.batched = True
        try:
            # A sequence of seeds -> one tree per batch item; a scalar seed
            # raises TypeError on len() and falls back to a single tree.
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        self.trees = [
            torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs)
            for s in seed
        ]

    @staticmethod
    def sort(a, b):
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (
            self.sign * sign)
        return w if self.batched else w[0]


class BrownianTreeNoiseSampler:
    """
    A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self,
                 x,
                 sigma_min,
                 sigma_max,
                 seed=None,
                 transform=lambda x: x):
        self.transform = transform
        t0 = self.transform(torch.as_tensor(sigma_min))
        t1 = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        t0 = self.transform(torch.as_tensor(sigma))
        t1 = self.transform(torch.as_tensor(sigma_next))
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()


@torch.no_grad()
def sample_dpmpp_2m_sde(noise,
                        model,
                        sigmas,
                        eta=1.,
                        s_noise=1.,
                        solver_type='midpoint',
                        show_progress=True):
    """
    DPM-Solver++ (2M) SDE.

    `solver_type` selects the second-order correction ('heun' or 'midpoint').
    Requires torchsde for the Brownian noise sampler.
    """
    assert solver_type in {'heun', 'midpoint'}

    x = noise * sigmas[0]
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[
        sigmas < float('inf')].max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
    old_denoised = None
    h_last = None

    for i in trange(len(sigmas) - 1, disable=not show_progress):
        if sigmas[i] == float('inf'):
            # Euler method (starting from pure noise)
            denoised = model(noise, sigmas[i])
            x = denoised + sigmas[i + 1] * noise
        else:
            _, c_in = get_scalings(sigmas[i])
            denoised = model(x * c_in, sigmas[i])
            if sigmas[i + 1] == 0:
                # Denoising step
                x = denoised
            else:
                # DPM-Solver++(2M) SDE in log-sigma time.
                t, s = -sigmas[i].log(), -sigmas[i + 1].log()
                h = s - t
                eta_h = eta * h

                x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
                    (-h - eta_h).expm1().neg() * denoised

                if old_denoised is not None:
                    r = h_last / h
                    if solver_type == 'heun':
                        x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
                            (1 / r) * (denoised - old_denoised)
                    elif solver_type == 'midpoint':
                        x = x + 0.5 * (-h - eta_h).expm1().neg() * \
                            (1 / r) * (denoised - old_denoised)

                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[
                    i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

                old_denoised = denoised
                h_last = h
    return x
import math
import random

import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image, ImageFilter

__all__ = [
    'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2',
    'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',
    'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize',
    'ResizeRandomCrop', 'ExtractResizeRandomCrop', 'ExtractResizeAssignCrop'
]

# NOTE: unless stated otherwise, every transform below operates on either a
# single PIL image or a list of PIL images (the frames of one video clip) and
# applies the same parameters to all frames for temporal consistency.


class Compose(object):
    """Chain transforms; supports indexing/slicing into the pipeline."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __getitem__(self, index):
        # Slicing returns a sub-pipeline, indexing a single transform.
        if isinstance(index, slice):
            return Compose(self.transforms[index])
        else:
            return self.transforms[index]

    def __len__(self):
        return len(self.transforms)

    def __call__(self, rgb):
        for t in self.transforms:
            rgb = t(rgb)
        return rgb


class Resize(object):
    """Bilinear resize to a fixed (w, h) size."""

    def __init__(self, size=256):
        if isinstance(size, int):
            size = (size, size)
        self.size = size

    def __call__(self, rgb):
        if isinstance(rgb, list):
            rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb]
        else:
            rgb = rgb.resize(self.size, Image.BILINEAR)
        return rgb


class Rescale(object):
    """Scale so the shorter side equals `size`, keeping aspect ratio."""

    def __init__(self, size=256, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, rgb):
        w, h = rgb[0].size
        scale = self.size / min(w, h)
        out_w, out_h = int(round(w * scale)), int(round(h * scale))
        rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb]
        return rgb


class CenterCrop(object):
    """Square center crop; input must be at least `size` on both sides."""

    def __init__(self, size=224):
        self.size = size

    def __call__(self, rgb):
        w, h = rgb[0].size
        assert min(w, h) >= self.size
        x1 = (w - self.size) // 2
        y1 = (h - self.size) // 2
        rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb]
        return rgb


class ResizeRandomCrop(object):
    """Rescale shorter side to `size_short` then take a random square crop
    of `size`, identical across all frames."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):

        # consistent crop between rgb and m
        # Fast path: halve with BOX filter while far larger than the target.
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
                for u in rgb
            ]
        scale = self.size_short / min(rgb[0].size)
        rgb = [
            u.resize((round(scale * u.width), round(scale * u.height)),
                     resample=Image.BICUBIC) for u in rgb
        ]
        out_w = self.size
        out_h = self.size
        w, h = rgb[0].size
        x1 = random.randint(0, w - out_w)
        y1 = random.randint(0, h - out_h)

        rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
        return rgb


class ExtractResizeRandomCrop(object):
    """Like ResizeRandomCrop, but also returns the crop box so a second
    stream can be cropped identically (see ExtractResizeAssignCrop)."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):

        # consistent crop between rgb and m
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
                for u in rgb
            ]
        scale = self.size_short / min(rgb[0].size)
        rgb = [
            u.resize((round(scale * u.width), round(scale * u.height)),
                     resample=Image.BICUBIC) for u in rgb
        ]
        out_w = self.size
        out_h = self.size
        w, h = rgb[0].size
        x1 = random.randint(0, w - out_w)
        y1 = random.randint(0, h - out_h)

        rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
        # Crop box as (left, top, right, bottom) for reuse on another stream.
        wh = [x1, y1, x1 + out_w, y1 + out_h]
        return rgb, wh


class ExtractResizeAssignCrop(object):
    """Apply a crop box `wh` produced by ExtractResizeRandomCrop, then
    resize the crop to (size, size)."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb, wh):

        # consistent crop between rgb and m
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
                for u in rgb
            ]
        scale = self.size_short / min(rgb[0].size)
        rgb = [
            u.resize((round(scale * u.width), round(scale * u.height)),
                     resample=Image.BICUBIC) for u in rgb
        ]

        rgb = [u.crop(wh) for u in rgb]
        rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]

        return rgb


class CenterCropV2(object):
    """Center crop with a fast pyramid downscale before the final resize."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        # fast resize: halve with BOX filter while far larger than target
        while min(img[0].size) >= 2 * self.size:
            img = [
                u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
                for u in img
            ]
        scale = self.size / min(img[0].size)
        img = [
            u.resize((round(scale * u.width), round(scale * u.height)),
                     resample=Image.BICUBIC) for u in img
        ]

        # center crop
        x1 = (img[0].width - self.size) // 2
        y1 = (img[0].height - self.size) // 2
        img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
        return img


class CenterCropWide(object):
    """Center crop to a rectangular (w, h) `size`; accepts a single image
    or a list of frames."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        if isinstance(img, list):
            scale = min(img[0].size[0] / self.size[0],
                        img[0].size[1] / self.size[1])
            img = [
                u.resize((round(u.width // scale), round(u.height // scale)),
                         resample=Image.BOX) for u in img
            ]

            # center crop
            x1 = (img[0].width - self.size[0]) // 2
            y1 = (img[0].height - self.size[1]) // 2
            img = [
                u.crop((x1, y1, x1 + self.size[0], y1 + self.size[1]))
                for u in img
            ]
            return img
        else:
            scale = min(img.size[0] / self.size[0], img.size[1] / self.size[1])
            img = img.resize(
                (round(img.width // scale), round(img.height // scale)),
                resample=Image.BOX)
            x1 = (img.width - self.size[0]) // 2
            y1 = (img.height - self.size[1]) // 2
            img = img.crop((x1, y1, x1 + self.size[0], y1 + self.size[1]))
            return img


class RandomCrop(object):
    """Random area/aspect crop covering at least `min_area` of the frame,
    resized to (size, size); same crop for all frames."""

    def __init__(self, size=224, min_area=0.4):
        self.size = size
        self.min_area = min_area

    def __call__(self, rgb):

        # consistent crop between rgb and m
        w, h = rgb[0].size
        area = w * h
        out_w, out_h = float('inf'), float('inf')
        # Re-draw until the sampled crop fits inside the frame.
        while out_w > w or out_h > h:
            target_area = random.uniform(self.min_area, 1.0) * area
            aspect_ratio = random.uniform(3. / 4., 4. / 3.)
            out_w = int(round(math.sqrt(target_area * aspect_ratio)))
            out_h = int(round(math.sqrt(target_area / aspect_ratio)))
        x1 = random.randint(0, w - out_w)
        y1 = random.randint(0, h - out_h)

        rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
        rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]

        return rgb


class RandomCropV2(object):
    """torchvision-style RandomResizedCrop: bounded retries with a central
    fallback, applied consistently to all frames."""

    def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
        if isinstance(size, (tuple, list)):
            self.size = size
        else:
            self.size = (size, size)
        self.min_area = min_area
        self.ratio = ratio

    def _get_params(self, img):
        # Sample a crop box (top, left, height, width) from the first frame.
        width, height = img.size
        area = height * width

        for _ in range(10):
            target_area = random.uniform(self.min_area, 1.0) * area
            log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if (in_ratio < min(self.ratio)):
            w = width
            h = int(round(w / min(self.ratio)))
        elif (in_ratio > max(self.ratio)):
            h = height
            w = int(round(h * max(self.ratio)))
        else:
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def __call__(self, rgb):
        i, j, h, w = self._get_params(rgb[0])
        rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb]
        return rgb


class RandomHFlip(object):
    """Horizontally flip all frames with probability `p`."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, rgb):
        if random.random() < self.p:
            rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb]
        return rgb


class GaussianBlur(object):
    """With probability `p`, blur all frames with one sigma drawn uniformly
    from the `sigmas` interval."""

    def __init__(self, sigmas=[0.1, 2.0], p=0.5):
        self.sigmas = sigmas
        self.p = p

    def __call__(self, rgb):
        if random.random() < self.p:
            sigma = random.uniform(*self.sigmas)
            rgb = [
                u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb
            ]
        return rgb


class ColorJitter(object):
    """With probability `p`, apply brightness/contrast/saturation/hue jitter
    in a random order — the same parameters and order for all frames."""

    def __init__(self,
                 brightness=0.4,
                 contrast=0.4,
                 saturation=0.4,
                 hue=0.1,
                 p=0.5):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue
        self.p = p

    def __call__(self, rgb):
        if random.random() < self.p:
            brightness, contrast, saturation, hue = self._random_params()
            transforms = [
                lambda f: F.adjust_brightness(f, brightness),
                lambda f: F.adjust_contrast(f, contrast),
                lambda f: F.adjust_saturation(f, saturation),
                lambda f: F.adjust_hue(f, hue)
            ]
            random.shuffle(transforms)
            for t in transforms:
                rgb = [t(u) for u in rgb]

        return rgb

    def _random_params(self):
        # Multiplicative factors around 1.0 (additive for hue).
        brightness = random.uniform(
            max(0, 1 - self.brightness), 1 + self.brightness)
        contrast = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
        saturation = random.uniform(
            max(0, 1 - self.saturation), 1 + self.saturation)
        hue = random.uniform(-self.hue, self.hue)
        return brightness, contrast, saturation, hue


class RandomGray(object):
    """With probability `p`, convert all frames to 3-channel grayscale."""

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, rgb):
        if random.random() < self.p:
            rgb = [u.convert('L').convert('RGB') for u in rgb]
        return rgb


class ToTensor(object):
    """Convert PIL image(s) to float tensors in [0, 1]; a list of frames is
    stacked into a (T, C, H, W) tensor."""

    def __call__(self, rgb):
        if isinstance(rgb, list):
            rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0)
        else:
            rgb = F.to_tensor(rgb)

        return rgb


class Normalize(object):
    """Clamp to [0, 1] then normalize channel-wise with `mean`/`std`.

    Accepts (C, H, W) or (T, C, H, W) tensors. The python lists are converted
    to tensors lazily on the first call, using the first input's device/dtype
    # NOTE(review): that cached tensor is reused for later calls — assumes all
    # inputs share one device; confirm before mixing CPU/GPU inputs.
    """

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, rgb):
        rgb = rgb.clone()
        rgb.clamp_(0, 1)
        if not isinstance(self.mean, torch.Tensor):
            self.mean = rgb.new_tensor(self.mean).view(-1)
        if not isinstance(self.std, torch.Tensor):
            self.std = rgb.new_tensor(self.std).view(-1)
        if rgb.dim() == 4:
rgb.sub_(self.mean.view(1, -1, 1, + 1)).div_(self.std.view(1, -1, 1, 1)) + elif rgb.dim() == 3: + rgb.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1)) + return rgb diff --git a/modelscope/models/multi_modal/video_to_video/video_to_video_model.py b/modelscope/models/multi_modal/video_to_video/video_to_video_model.py new file mode 100755 index 00000000..283de03f --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/video_to_video_model.py @@ -0,0 +1,227 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import os.path as osp +import random +from copy import copy +from typing import Any, Dict + +import torch +import torch.cuda.amp as amp +import torch.nn.functional as F + +import modelscope.models.multi_modal.video_to_video.utils.transforms as data +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.video_to_video.modules import * +from modelscope.models.multi_modal.video_to_video.modules import ( + AutoencoderKL, FrozenOpenCLIPEmbedder, Vid2VidSDUNet, + get_first_stage_encoding) +from modelscope.models.multi_modal.video_to_video.utils.config import cfg +from modelscope.models.multi_modal.video_to_video.utils.diffusion_sdedit import \ + GaussianDiffusion_SDEdit +from modelscope.models.multi_modal.video_to_video.utils.schedules_sdedit import \ + noise_schedule +from modelscope.models.multi_modal.video_to_video.utils.seed import setup_seed +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +__all__ = ['VideoToVideo'] + +logger = get_logger() + + +@MODELS.register_module( + Tasks.video_to_video, module_name=Models.video_to_video_model) +class VideoToVideo(TorchModel): + r""" + Video2Video aims to solve the task of generating super-resolution videos based on input + video and text, which is a video generation basic model 
developed by Alibaba Cloud. + + Paper link: https://arxiv.org/abs/2306.02018 + + Attributes: + diffusion: diffusion model for DDIM. + autoencoder: decode the latent representation of input video into visual space. + clip_encoder: encode the text into text embedding. + """ + + def __init__(self, model_dir, *args, **kwargs): + r""" + Args: + model_dir (`str` or `os.PathLike`) + Can be either: + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co + or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`, + or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g, + `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to + `True`. 
+ """ + super().__init__(model_dir=model_dir, *args, **kwargs) + + self.config = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + + cfg.solver_mode = self.config.model.model_args.solver_mode + + # assign default value + cfg.batch_size = self.config.model.model_cfg.batch_size + cfg.target_fps = self.config.model.model_cfg.target_fps + cfg.max_frames = self.config.model.model_cfg.max_frames + cfg.latent_hei = self.config.model.model_cfg.latent_hei + cfg.latent_wid = self.config.model.model_cfg.latent_wid + cfg.model_path = osp.join(model_dir, + self.config.model.model_args.ckpt_unet) + + self.device = torch.device( + 'cuda') if torch.cuda.is_available() else torch.device('cpu') + + if 'seed' in self.config.model.model_args.keys(): + cfg.seed = self.config.model.model_args.seed + else: + cfg.seed = random.randint(0, 99999) + setup_seed(cfg.seed) + + # transform + vid_trans = data.Compose( + [data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std)]) + self.vid_trans = vid_trans + + cfg.embedder.pretrained = osp.join( + model_dir, self.config.model.model_args.ckpt_clip) + clip_encoder = FrozenOpenCLIPEmbedder( + pretrained=cfg.embedder.pretrained) + clip_encoder.model.to(self.device) + self.clip_encoder = clip_encoder + logger.info(f'Build encoder with {cfg.embedder.type}') + + # [unet] + generator = Vid2VidSDUNet() + generator = generator.to(self.device) + generator.eval() + load_dict = torch.load(cfg.model_path, map_location='cpu') + ret = generator.load_state_dict(load_dict['state_dict'], strict=True) + self.generator = generator + logger.info('Load model {} path {}, with local status {}'.format( + cfg.UNet.type, cfg.model_path, ret)) + + # [diffusion] + sigmas = noise_schedule( + schedule='logsnr_cosine_interp', + n=1000, + zero_terminal_snr=True, + scale_min=2.0, + scale_max=4.0) + diffusion = GaussianDiffusion_SDEdit( + sigmas=sigmas, prediction_type='v') + self.diffusion = diffusion + logger.info('Build diffusion with type of 
GaussianDiffusion_SDEdit') + + # [auotoencoder] + cfg.auto_encoder.pretrained = osp.join( + model_dir, self.config.model.model_args.ckpt_autoencoder) + autoencoder = AutoencoderKL(**cfg.auto_encoder) + autoencoder.eval() + for param in autoencoder.parameters(): + param.requires_grad = False + autoencoder.to(self.device) + self.autoencoder = autoencoder + torch.cuda.empty_cache() + + negative_prompt = cfg.negative_prompt + negative_y = clip_encoder(negative_prompt).detach() + self.negative_y = negative_y + + positive_prompt = cfg.positive_prompt + self.positive_prompt = positive_prompt + + self.cfg = cfg + + def forward(self, input: Dict[str, Any]): + r""" + The entry function of video to video task. + 1. Using CLIP to encode text into embeddings. + 2. Using diffusion model to generate the video's latent representation. + 3. Using autoencoder to decode the video's latent representation to visual space. + + Args: + input (`Dict[Str, Any]`): + The input of the task + Returns: + A generated video (as pytorch tensor). 
+ """ + + video_data = input['video_data'] + y = input['y'] + cfg = self.cfg + + video_data = F.interpolate( + video_data, size=(720, 1280), mode='bilinear') + video_data = video_data.unsqueeze(0) + video_data = video_data.to(self.device) + + batch_size, frames_num, _, _, _ = video_data.shape + video_data = rearrange(video_data, 'b f c h w -> (b f) c h w') + + video_data_list = torch.chunk( + video_data, video_data.shape[0] // 2, dim=0) + with torch.no_grad(): + decode_data = [] + for vd_data in video_data_list: + encoder_posterior = self.autoencoder.encode(vd_data) + tmp = get_first_stage_encoding(encoder_posterior).detach() + decode_data.append(tmp) + video_data_feature = torch.cat(decode_data, dim=0) + video_data_feature = rearrange( + video_data_feature, '(b f) c h w -> b c f h w', b=batch_size) + + with amp.autocast(enabled=True): + total_noise_levels = 600 + t = torch.randint( + total_noise_levels - 1, + total_noise_levels, (1, ), + dtype=torch.long).to(self.device) + + noise = torch.randn_like(video_data_feature) + noised_lr = self.diffusion.diffuse(video_data_feature, t, noise) + model_kwargs = [{'y': y}, {'y': self.negative_y}] + + gen_vid = self.diffusion.sample( + noise=noised_lr, + model=self.generator, + model_kwargs=model_kwargs, + guide_scale=7.5, + guide_rescale=0.2, + solver='dpmpp_2m_sde' if cfg.solver_mode == 'fast' else 'heun', + steps=30 if cfg.solver_mode == 'fast' else 50, + t_max=total_noise_levels - 1, + t_min=0, + discretization='trailing') + + scale_factor = 0.18215 + vid_tensor_feature = 1. 
/ scale_factor * gen_vid + + vid_tensor_feature = rearrange(vid_tensor_feature, + 'b c f h w -> (b f) c h w') + vid_tensor_feature_list = torch.chunk( + vid_tensor_feature, vid_tensor_feature.shape[0] // 2, dim=0) + decode_data = [] + for vd_data in vid_tensor_feature_list: + tmp = self.autoencoder.decode(vd_data) + decode_data.append(tmp) + vid_tensor_gen = torch.cat(decode_data, dim=0) + + gen_video = rearrange( + vid_tensor_gen, '(b f) c h w -> b c f h w', b=cfg.batch_size) + + return gen_video.type(torch.float32).cpu() diff --git a/modelscope/pipelines/multi_modal/video_to_video_pipeline.py b/modelscope/pipelines/multi_modal/video_to_video_pipeline.py new file mode 100644 index 00000000..36e6544d --- /dev/null +++ b/modelscope/pipelines/multi_modal/video_to_video_pipeline.py @@ -0,0 +1,140 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import tempfile +from typing import Any, Dict, Optional + +import cv2 +import torch +from einops import rearrange + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors.image import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_to_video, module_name=Pipelines.video_to_video_pipeline) +class VideoToVideoPipeline(Pipeline): + r""" Video To Video Pipeline, generating super-resolution videos based on input + video and text + + Examples: + >>> from modelscope.pipelines import pipeline + >>> from modelscope.outputs import OutputKeys + + >>> # YOUR_VIDEO_PATH: your video url or local position in low resolution + >>> # INPUT_TEXT: when we do video super-resolution, we will add the text content + >>> # into results + >>> # output_video_path: path-to-the-generated-video + + >>> p = pipeline('video-to-video', 
'damo/Video-to-Video') + >>> input = {"video_path":YOUR_VIDEO_PATH, "text": INPUT_TEXT} + >>> output_video_path = p(input,output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO] + + """ + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]: + vid_path = input['video_path'] + if 'text' in input.keys(): + text = input['text'] + else: + text = '' + + caption = text + self.model.positive_prompt + y = self.model.clip_encoder(caption).detach() + + max_frames = self.model.cfg.max_frames + + capture = cv2.VideoCapture(vid_path) + _fps = capture.get(cv2.CAP_PROP_FPS) + sample_fps = _fps + _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT) + stride = round(_fps / sample_fps) + start_frame = 0 + + pointer = 0 + frame_list = [] + while len(frame_list) < max_frames: + ret, frame = capture.read() + pointer += 1 + if (not ret) or (frame is None): + break + if pointer < start_frame: + continue + if pointer >= _total_frame_num + 1: + break + if (pointer - start_frame) % stride == 0: + frame = LoadImage.convert_to_img(frame) + frame_list.append(frame) + capture.release() + + video_data = self.model.vid_trans(frame_list) + + return {'video_data': video_data, 'y': y} + + def forward(self, input: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + video = self.model(input) + return {'video': video} + + def postprocess(self, inputs: Dict[str, Any], + **post_params) -> Dict[str, Any]: + video = tensor2vid(inputs['video'], self.model.cfg.mean, + self.model.cfg.std) + output_video_path = post_params.get('output_video', None) + temp_video_file = False + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name + temp_video_file = True + + temp_dir = tempfile.mkdtemp() + for fid, frame in enumerate(video): + tpth = 
os.path.join(temp_dir, '%06d.png' % (fid + 1)) + cv2.imwrite(tpth, frame[:, :, ::-1], + [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + + cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8.0 -i {temp_dir}/%06d.png \ + -vcodec libx264 -crf 17 -pix_fmt yuv420p {output_video_path}' + + status = os.system(cmd) + if status != 0: + logger.info('Save Video Error with {}'.format(status)) + os.system(f'rm -rf {temp_dir}') + + if temp_video_file: + video_file_content = b'' + with open(output_video_path, 'rb') as f: + video_file_content = f.read() + os.remove(output_video_path) + return {OutputKeys.OUTPUT_VIDEO: video_file_content} + else: + return {OutputKeys.OUTPUT_VIDEO: output_video_path} + + +def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + video = video * 255.0 + + images = rearrange(video, 'b c f h w -> b f h w c')[0] + images = [(img.numpy()).astype('uint8') for img in images] + + return images diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 9bfe3ac2..3bcad94c 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -257,6 +257,7 @@ class MultiModalTasks(object): efficient_diffusion_tuning = 'efficient-diffusion-tuning' multimodal_dialogue = 'multimodal-dialogue' image_to_video = 'image-to-video' + video_to_video = 'video-to-video' class ScienceTasks(object): diff --git a/tests/pipelines/test_video2video.py b/tests/pipelines/test_video2video.py new file mode 100644 index 00000000..fcd9a7e5 --- /dev/null +++ b/tests/pipelines/test_video2video.py @@ -0,0 +1,32 @@ +import sys +import unittest + +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import 
test_level + + +class Video2VideoTest(unittest.TestCase): + + def setUp(self) -> None: + self.task = Tasks.video_to_video + self.model_id = 'damo/Video-to-Video' + self.path = 'https://video-generation-wulanchabu.oss-cn-wulanchabu.aliyuncs.com/baishao/test.mp4' + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + pipe = pipeline(task=self.task, model=self.model_id) + p_input = { + 'video_path': self.path, + 'text': 'A panda is surfing on the sea' + } + + output_video_path = pipe( + p_input, output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO] + print(output_video_path) + + +if __name__ == '__main__': + unittest.main()