diff --git a/.gitignore b/.gitignore index 573669b2..790daab3 100644 --- a/.gitignore +++ b/.gitignore @@ -123,6 +123,7 @@ tensorboard.sh replace.sh result.png result.jpg +result.mp4 # Pytorch *.pth diff --git a/data/test/images/dogs.jpg b/data/test/images/dogs.jpg index 450a969d..1003c3fd 100644 --- a/data/test/images/dogs.jpg +++ b/data/test/images/dogs.jpg @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:78094cc48fbcfd9b6d321fe13619ecc72b65e006fc1b4c4458409ade9979486d -size 129862 +oid sha256:d53a77b0be82993ed44bbb9244cda42bf460f8dcdf87ff3cfdbfdc7191ff418d +size 121984 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 47b31428..564ad003 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -182,6 +182,7 @@ class Models(object): mplug = 'mplug' diffusion = 'diffusion-text-to-image-synthesis' multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis' + video_synthesis = 'latent-text-to-video-synthesis' team = 'team-multi-modal-similarity' video_clip = 'video-clip-multi-modal-embedding' mgeo = 'mgeo' @@ -478,6 +479,7 @@ class Pipelines(object): diffusers_stable_diffusion = 'diffusers-stable-diffusion' document_vl_embedding = 'document-vl-embedding' chinese_stable_diffusion = 'chinese-stable-diffusion' + text_to_video_synthesis = 'latent-text-to-video-synthesis' # latent-text-to-video-synthesis gridvlp_multi_modal_classification = 'gridvlp-multi-modal-classification' gridvlp_multi_modal_embedding = 'gridvlp-multi-modal-embedding' @@ -615,6 +617,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.text_to_image_synthesis: (Pipelines.text_to_image_synthesis, 'damo/cv_diffusion_text-to-image-synthesis_tiny'), + Tasks.text_to_video_synthesis: (Pipelines.text_to_video_synthesis, + 'damo/text-to-video-synthesis'), Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints, 'damo/cv_hrnetv2w32_body-2d-keypoints_image'), Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints, diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py index 4edf6212..452377b2 100644 --- a/modelscope/models/multi_modal/__init__.py +++ b/modelscope/models/multi_modal/__init__.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from .multi_stage_diffusion import \ MultiStageDiffusionForTextToImageSynthesis from .vldoc import VLDocForDocVLEmbedding + from .video_synthesis import TextToVideoSynthesis else: _import_structure = { @@ -32,6 +33,7 @@ else: 'multi_stage_diffusion': ['MultiStageDiffusionForTextToImageSynthesis'], 'vldoc': ['VLDocForDocVLEmbedding'], + 'video_synthesis': ['TextToVideoSynthesis'], } import sys diff --git a/modelscope/models/multi_modal/video_synthesis/__init__.py b/modelscope/models/multi_modal/video_synthesis/__init__.py new file mode 100644 index 00000000..7db72e7c --- /dev/null +++ b/modelscope/models/multi_modal/video_synthesis/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
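+#
+# Note: this package exposes only TextToVideoSynthesis. The symbol is declared in
+# the _import_structure mapping below and resolved lazily via LazyImportModule, so
+# the model code (and its torch / open_clip dependencies) is only imported when the
+# class is actually accessed.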
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .text_to_video_synthesis_model import TextToVideoSynthesis + +else: + _import_structure = { + 'text_to_video_synthesis_model': ['TextToVideoSynthesis'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/multi_modal/video_synthesis/autoencoder.py b/modelscope/models/multi_modal/video_synthesis/autoencoder.py new file mode 100644 index 00000000..7885f262 --- /dev/null +++ b/modelscope/models/multi_modal/video_synthesis/autoencoder.py @@ -0,0 +1,569 @@ +# Part of the implementation is borrowed and modified from latent-diffusion, +# publicly avaialbe at https://github.com/CompVis/latent-diffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['AutoencoderKL'] + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class DiagonalGaussianDistribution(object): + + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn( + self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +class ResnetBlock(nn.Module): + + def __init__(self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = 
torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Upsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate( + x, scale_factor=2.0, mode='nearest') + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode='constant', value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class Encoder(nn.Module): + + def __init__(self, + *, + ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1, ) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in 
range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = 
Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class AutoencoderKL(nn.Module): + + def __init__(self, + ddconfig, + embed_dim, + ckpt_path=None, + image_key='image', + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + assert ddconfig['double_z'] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'], + 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, + ddconfig['z_channels'], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer('colorize', + torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def init_from_ckpt(self, path): + sd = torch.load(path, map_location='cpu')['state_dict'] + keys = list(sd.keys()) + + import collections + sd_new = collections.OrderedDict() + + for k in keys: + if k.find('first_stage_model') >= 0: + k_new = k.split('first_stage_model.')[-1] + sd_new[k_new] = sd[k] + + self.load_state_dict(sd_new, strict=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, + 2).to(memory_format=torch.contiguous_format).float() + return x + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log['samples'] = self.decode(torch.randn_like(posterior.sample())) + log['reconstructions'] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + 
log['samples_ema'] = self.decode(
+                        torch.randn_like(posterior_ema.sample()))
+                    log['reconstructions_ema'] = xrec_ema
+        log['inputs'] = x
+        return log
+
+    def to_rgb(self, x):
+        assert self.image_key == 'segmentation'
+        if not hasattr(self, 'colorize'):
+            self.register_buffer('colorize',
+                                 torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+        return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+
+    def __init__(self, *args, vq_interface=False, **kwargs):
+        self.vq_interface = vq_interface
+        super().__init__()
+
+    def encode(self, x, *args, **kwargs):
+        return x
+
+    def decode(self, x, *args, **kwargs):
+        return x
+
+    def quantize(self, x, *args, **kwargs):
+        if self.vq_interface:
+            return x, None, [None, None, None]
+        return x
+
+    def forward(self, x, *args, **kwargs):
+        return x
diff --git a/modelscope/models/multi_modal/video_synthesis/diffusion.py b/modelscope/models/multi_modal/video_synthesis/diffusion.py
new file mode 100644
index 00000000..4eba1f13
--- /dev/null
+++ b/modelscope/models/multi_modal/video_synthesis/diffusion.py
@@ -0,0 +1,227 @@
+# Part of the implementation is borrowed and modified from latent-diffusion,
+# publicly available at https://github.com/CompVis/latent-diffusion.
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import torch
+
+__all__ = ['GaussianDiffusion', 'beta_schedule']
+
+
+def _i(tensor, t, x):
+    r"""Index tensor using t and format the output according to x.
+    """
+    shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
+    return tensor[t].view(shape).to(x)
+
+
+def beta_schedule(schedule,
+                  num_timesteps=1000,
+                  init_beta=None,
+                  last_beta=None):
+    if schedule == 'linear_sd':
+        return torch.linspace(
+            init_beta**0.5, last_beta**0.5, num_timesteps,
+            dtype=torch.float64)**2
+    else:
+        raise ValueError(f'Unsupported schedule: {schedule}')
+
+
+class GaussianDiffusion(object):
+    r""" Diffusion Model for DDIM.
+        "Denoising diffusion implicit models." by Song, Jiaming, Chenlin Meng, and Stefano Ermon.
+ See https://arxiv.org/abs/2010.02502 + """ + + def __init__(self, + betas, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + epsilon=1e-12, + rescale_timesteps=False): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1', + 'charbonnier' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type + self.var_type = var_type + self.loss_type = loss_type + self.epsilon = epsilon + self.rescale_timesteps = rescale_timesteps + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( + 1.0 - self.alphas_cumprod) + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). 
+ """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + dim = y_out.size(1) if self.var_type.startswith( + 'fixed') else y_out.size(1) // 2 + a = u_out[:, :dim] + b = guide_scale * (y_out[:, :dim] - u_out[:, :dim]) + c = y_out[:, dim:] + out = torch.cat([a + b, c], dim=1) + + # compute variance + if self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
+ """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + a = (1 - alphas_prev) / (1 - alphas) + b = (1 - alphas / alphas_prev) + sigmas = eta * torch.sqrt(a * b) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return xt + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t diff --git a/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py b/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py new file mode 100644 index 00000000..1bcd6eda --- /dev/null +++ b/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py @@ -0,0 +1,241 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import os +from os import path as osp +from typing import Any, Dict + +import open_clip +import torch +import torch.cuda.amp as amp +from einops import rearrange + +from modelscope.metainfo import Models +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.video_synthesis.autoencoder import \ + AutoencoderKL +from modelscope.models.multi_modal.video_synthesis.diffusion import ( + GaussianDiffusion, beta_schedule) +from modelscope.models.multi_modal.video_synthesis.unet_sd import UNetSD +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks + +__all__ = ['TextToVideoSynthesis'] + + +@MODELS.register_module( + Tasks.text_to_video_synthesis, module_name=Models.video_synthesis) +class TextToVideoSynthesis(Model): + r""" + task for text to video synthesis. + + Attributes: + sd_model: denosing model using in this task. + diffusion: diffusion model for DDIM. 
+ autoencoder: decode the latent representation into visual space with VQGAN. + clip_encoder: encode the text into text embedding. + """ + + def __init__(self, model_dir, *args, **kwargs): + r""" + Args: + model_dir (`str` or `os.PathLike`) + Can be either: + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co + or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`, + or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g, + `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to + `True`. + """ + super().__init__(model_dir=model_dir, *args, **kwargs) + self.device = torch.device('cuda') if torch.cuda.is_available() \ + else torch.device('cpu') + self.config = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + cfg = self.config.model.model_cfg + cfg['temporal_attention'] = True if cfg[ + 'temporal_attention'] == 'True' else False + + # Initialize unet + self.sd_model = UNetSD( + in_dim=cfg['unet_in_dim'], + dim=cfg['unet_dim'], + y_dim=cfg['unet_y_dim'], + context_dim=cfg['unet_context_dim'], + out_dim=cfg['unet_out_dim'], + dim_mult=cfg['unet_dim_mult'], + num_heads=cfg['unet_num_heads'], + head_dim=cfg['unet_head_dim'], + num_res_blocks=cfg['unet_res_blocks'], + attn_scales=cfg['unet_attn_scales'], + dropout=cfg['unet_dropout'], + temporal_attention=cfg['temporal_attention']) + self.sd_model.load_state_dict( + torch.load( + osp.join(model_dir, self.config.model.model_args.ckpt_unet)), + strict=True) + self.sd_model.eval() + self.sd_model.to(self.device) + + # Initialize diffusion + betas = beta_schedule( + 'linear_sd', + cfg['num_timesteps'], + init_beta=0.00085, + last_beta=0.0120) + self.diffusion = GaussianDiffusion( + betas=betas, + mean_type=cfg['mean_type'], + var_type=cfg['var_type'], + loss_type=cfg['loss_type'], + rescale_timesteps=False) + + # Initialize autoencoder + ddconfig = { + 'double_z': True, + 'z_channels': 4, + 'resolution': 256, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0 + } + self.autoencoder = AutoencoderKL( + ddconfig, 4, + osp.join(model_dir, self.config.model.model_args.ckpt_autoencoder)) + if self.config.model.model_args.tiny_gpu == 1: + self.autoencoder.to('cpu') + else: + self.autoencoder.to(self.device) + self.autoencoder.eval() + + # Initialize Open clip + self.clip_encoder = FrozenOpenCLIPEmbedder( + version=osp.join(model_dir, + self.config.model.model_args.ckpt_clip), + layer='penultimate') + if self.config.model.model_args.tiny_gpu == 1: + self.clip_encoder.to('cpu') + else: + self.clip_encoder.to(self.device) + + def forward(self, input: Dict[str, Any]): + r""" + The entry function of text to image synthesis task. + 1. 
Using diffusion model to generate the video's latent representation. + 2. Using vqgan model (autoencoder) to decode the video's latent representation to visual space. + + Args: + input (`Dict[Str, Any]`): + The input of the task + Returns: + A generated video (as pytorch tensor). + """ + y = input['text_emb'] + zero_y = input['text_emb_zero'] + context = torch.cat([zero_y, y], dim=0).to(self.device) + # synthesis + with torch.no_grad(): + num_sample = 1 # here let b = 1 + max_frames = self.config.model.model_args.max_frames + latent_h, latent_w = 32, 32 + with amp.autocast(enabled=True): + x0 = self.diffusion.ddim_sample_loop( + noise=torch.randn(num_sample, 4, max_frames, latent_h, + latent_w).to( + self.device), # shape: b c f h w + model=self.sd_model, + model_kwargs=[{ + 'y': + context[1].unsqueeze(0).repeat(num_sample, 1, 1) + }, { + 'y': + context[0].unsqueeze(0).repeat(num_sample, 1, 1) + }], + guide_scale=9.0, + ddim_timesteps=50, + eta=0.0) + + scale_factor = 0.18215 + video_data = 1. / scale_factor * x0 + bs_vd = video_data.shape[0] + video_data = rearrange(video_data, 'b c f h w -> (b f) c h w') + self.autoencoder.to(self.device) + video_data = self.autoencoder.decode(video_data) + if self.config.model.model_args.tiny_gpu == 1: + self.autoencoder.to('cpu') + video_data = rearrange( + video_data, '(b f) c h w -> b c f h w', b=bs_vd) + return video_data.type(torch.float32).cpu() + + +class FrozenOpenCLIPEmbedder(torch.nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = ['last', 'penultimate'] + + def __init__(self, + arch='ViT-H-14', + version='open_clip_pytorch_model.bin', + device='cuda', + max_length=77, + freeze=True, + layer='last'): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms( + arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == 'last': + self.layer_idx = 0 + elif self.layer == 'penultimate': + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) diff --git a/modelscope/models/multi_modal/video_synthesis/unet_sd.py b/modelscope/models/multi_modal/video_synthesis/unet_sd.py new file mode 100644 index 00000000..f3c764eb --- /dev/null +++ b/modelscope/models/multi_modal/video_synthesis/unet_sd.py @@ -0,0 +1,1098 @@ +# Part of the implementation is borrowed and modified from stable-diffusion, +# publicly avaialbe at https://github.com/Stability-AI/stablediffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. 
All rights reserved. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +__all__ = ['UNetSD'] + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +class UNetSD(nn.Module): + + def __init__(self, + in_dim=7, + dim=512, + y_dim=512, + context_dim=512, + out_dim=6, + dim_mult=[1, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=3, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + use_scale_shift_norm=True, + dropout=0.1, + temporal_attn_times=2, + temporal_attention=True, + use_checkpoint=False, + use_image_dataset=False, + use_fps_condition=False, + use_sim_mask=False): + embed_dim = dim * 4 + num_heads = num_heads if num_heads else dim // 32 + super(UNetSD, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.context_dim = context_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + # parameters for spatial/temporal attention + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.use_scale_shift_norm = use_scale_shift_norm + self.temporal_attn_times = temporal_attn_times + self.temporal_attention = temporal_attention + self.use_checkpoint = use_checkpoint + self.use_image_dataset = use_image_dataset + self.use_fps_condition = use_fps_condition + self.use_sim_mask = use_sim_mask + use_linear_in_temporal = False + transformer_depth = 1 + disabled_sa = False + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embed = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + + if self.use_fps_condition: + self.fps_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + nn.init.zeros_(self.fps_embedding[-1].weight) + nn.init.zeros_(self.fps_embedding[-1].bias) + + # encoder + self.input_blocks = nn.ModuleList() + init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + + if temporal_attention: + init_block.append( + TemporalTransformer( + dim, + num_heads, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + + self.input_blocks.append(init_block) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList([ + ResBlock( + in_dim, + embed_dim, + dropout, + out_channels=out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True)) + if self.temporal_attention: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + + in_dim = out_dim + self.input_blocks.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + downsample = Downsample( + 
out_dim, True, dims=2, out_channels=out_dim) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.input_blocks.append(downsample) + + # middle + self.middle_block = nn.ModuleList([ + ResBlock( + out_dim, + embed_dim, + dropout, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ), + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True) + ]) + + if self.temporal_attention: + self.middle_block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset, + )) + + self.middle_block.append( + ResBlock( + out_dim, + embed_dim, + dropout, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + )) + + # decoder + self.output_blocks = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResBlock( + in_dim + shortcut_dims.pop(), + embed_dim, + dropout, + out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=1024, + disable_self_attn=False, + use_linear=True)) + + if self.temporal_attention: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + upsample = Upsample( + out_dim, True, dims=2.0, out_channels=out_dim) + scale *= 2.0 + block.append(upsample) + self.output_blocks.append(block) + + # head + self.out = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.out[-1].weight) + + def forward( + self, + x, + t, + y, + fps=None, + video_mask=None, + focus_present_mask=None, + prob_focus_present=0., + mask_last_frame_num=0 # mask last frame num + ): + """ + prob_focus_present: probability at which a given batch sample will focus on the present + (0. is all off, 1. 
is completely arrested attention across time) + """ + batch, device = x.shape[0], x.device + self.batch = batch + + # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored + if mask_last_frame_num > 0: + focus_present_mask = None + video_mask[-mask_last_frame_num:] = False + else: + focus_present_mask = default( + focus_present_mask, lambda: prob_mask_like( + (batch, ), prob_focus_present, device=device)) + + time_rel_pos_bias = None + # embeddings + if self.use_fps_condition and fps is not None: + e = self.time_embed(sinusoidal_embedding( + t, self.dim)) + self.fps_embedding( + sinusoidal_embedding(fps, self.dim)) + else: + e = self.time_embed(sinusoidal_embedding(t, self.dim)) + context = y + + # repeat f times for spatial e and context + f = x.shape[2] + e = e.repeat_interleave(repeats=f, dim=0) + context = context.repeat_interleave(repeats=f, dim=0) + + # always in shape (b f) c h w, except for temporal layer + x = rearrange(x, 'b c f h w -> (b f) c h w') + # encoder + xs = [] + for block in self.input_blocks: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, + focus_present_mask, video_mask) + xs.append(x) + + # middle + for block in self.middle_block: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, + focus_present_mask, video_mask) + + # decoder + for block in self.output_blocks: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single( + block, + x, + e, + context, + time_rel_pos_bias, + focus_present_mask, + video_mask, + reference=xs[-1] if len(xs) > 0 else None) + + # head + x = self.out(x) + # reshape back to (b c f h w) + x = rearrange(x, '(b f) c h w -> b c f h w', b=batch) + return x + + def _forward_single(self, + module, + x, + e, + context, + time_rel_pos_bias, + focus_present_mask, + video_mask, + reference=None): + if isinstance(module, ResidualBlock): + x = x.contiguous() + x = module(x, e, reference) + elif isinstance(module, ResBlock): + x = x.contiguous() + x = module(x, e, self.batch) + elif isinstance(module, SpatialTransformer): + x = module(x, context) + elif isinstance(module, TemporalTransformer): + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x, context) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, CrossAttention): + x = module(x, context) + elif isinstance(module, BasicTransformerBlock): + x = module(x, context) + elif isinstance(module, FeedForward): + x = module(x, context) + elif isinstance(module, Upsample): + x = module(x) + elif isinstance(module, Downsample): + x = module(x) + elif isinstance(module, Resample): + x = module(x, reference) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context, + time_rel_pos_bias, focus_present_mask, + video_mask, reference) + else: + x = module(x) + return x + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class CrossAttention(nn.Module): + + def __init__(self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + 
self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (q, k, v)) + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data in spatial axis. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class TemporalTransformer(nn.Module): + """ + Transformer block for image-like data in temporal axis. + First, reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + only_self_att=True, + multiply_zero=False): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + checkpoint=use_checkpoint) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module( + nn.Conv1d( + inner_dim, in_channels, kernel_size=1, stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + if self.use_linear: + x = rearrange( + x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous() + x = self.proj_in(x) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + context[i] = rearrange( + context[i], '(b f) l con -> b f l con', + f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat( + context[i][j], + 'f l con -> (f r) l con', + r=(h * w) // self.frames, + f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange( + x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + + +class BasicTransformerBlock(nn.Module): + + def __init__(self, + dim, + n_heads, + d_head, + dropout=0., + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_cls = CrossAttention + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else + None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout) # 
is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + x = self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +# feedforward +class GEGLU(nn.Module): + + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class FeedForward(nn.Module): + + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear( + dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = nn.Conv2d( + self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode='nearest') + else: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.use_conv: + x = self.conv(x) + return x + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + :param use_temporal_conv: if True, use the temporal convolution. + :param use_image_dataset: if True, the temporal parameters will not be optimized. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + use_temporal_conv=True, + use_image_dataset=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.use_temporal_conv = use_temporal_conv + + self.in_layers = nn.Sequential( + nn.GroupNorm(32, channels), + nn.SiLU(), + nn.Conv2d(channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels + if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + nn.GroupNorm(32, self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) + + if self.use_temporal_conv: + self.temopral_conv = TemporalConvBlock_v2( + self.out_channels, + self.out_channels, + dropout=0.1, + use_image_dataset=use_image_dataset) + + def forward(self, x, emb, batch_size): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return self._forward(x, emb, batch_size) + + def _forward(self, x, emb, batch_size): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + h = self.skip_connection(x) + h + + if self.use_temporal_conv: + h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size) + h = self.temopral_conv(h) + h = rearrange(h, 'b c f h w -> (b f) c h w') + return h + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if self.use_conv: + self.op = nn.Conv2d( + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, mode): + assert mode in ['none', 'upsample', 'downsample'] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.mode = mode + + def forward(self, x, reference=None): + if self.mode == 'upsample': + assert reference is not None + x = F.interpolate(x, size=reference.shape[-2:], mode='nearest') + elif self.mode == 'downsample': + x = F.adaptive_avg_pool2d( + x, output_size=tuple(u // 2 for u in x.shape[-2:])) + return x + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + mode='none', + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.mode = mode + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, mode) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e, reference=None): + identity = self.resample(x, reference) + x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference)) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. 
+        """
+        identity = x
+        b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
+
+        # compute query, key, value
+        x = self.norm(x)
+        q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
+        if context is not None:
+            ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
+                                                      d).permute(0, 2, 3,
+                                                                 1).chunk(
+                                                                     2, dim=1)
+            k = torch.cat([ck, k], dim=-1)
+            v = torch.cat([cv, v], dim=-1)
+
+        # compute attention
+        attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
+        attn = F.softmax(attn, dim=-1)
+
+        # gather context
+        x = torch.matmul(v, attn.transpose(-1, -2))
+        x = x.reshape(b, c, h, w)
+        # output
+        x = self.proj(x)
+        return x + identity
+
+
+class TemporalConvBlock_v2(nn.Module):
+
+    def __init__(self,
+                 in_dim,
+                 out_dim=None,
+                 dropout=0.0,
+                 use_image_dataset=False):
+        super(TemporalConvBlock_v2, self).__init__()
+        if out_dim is None:
+            out_dim = in_dim  # int(1.5*in_dim)
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.use_image_dataset = use_image_dataset
+
+        # conv layers
+        self.conv1 = nn.Sequential(
+            nn.GroupNorm(32, in_dim), nn.SiLU(),
+            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
+        self.conv2 = nn.Sequential(
+            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
+            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
+        self.conv3 = nn.Sequential(
+            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
+            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
+        self.conv4 = nn.Sequential(
+            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
+            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
+
+        # zero out the last layer params, so the conv block starts as an identity mapping
+        nn.init.zeros_(self.conv4[-1].weight)
+        nn.init.zeros_(self.conv4[-1].bias)
+
+    def forward(self, x):
+        identity = x
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
+
+        if self.use_image_dataset:
+            x = identity + 0.0 * x
+        else:
+            x = identity + x
+        return x
+
+
+def prob_mask_like(shape, prob, device):
+    if prob == 1:
+        return torch.ones(shape, device=device, dtype=torch.bool)
+    elif prob == 0:
+        return torch.zeros(shape, device=device, dtype=torch.bool)
+    else:
+        mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
+        # avoid masking everything, which would cause a find_unused_parameters error
+        if mask.all():
+            mask[0] = False
+        return mask
+
+
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f'unsupported dimensions: {dims}')
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D average pooling module.
+    """
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f'unsupported dimensions: {dims}')
diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py
index e8ca1a3c..d199a43a 100644
--- a/modelscope/pipelines/multi_modal/__init__.py
+++ b/modelscope/pipelines/multi_modal/__init__.py
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
     from .video_captioning_pipeline import VideoCaptioningPipeline
     from .video_question_answering_pipeline import VideoQuestionAnsweringPipeline
     from .diffusers_wrapped import StableDiffusionWrapperPipeline, ChineseStableDiffusionPipeline
+    from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline
 else:
     _import_structure = {
         'image_captioning_pipeline': ['ImageCaptioningPipeline'],
@@ -39,7 +40,8 @@ else:
         'video_question_answering_pipeline':
         ['VideoQuestionAnsweringPipeline'],
         'diffusers_wrapped':
-        ['StableDiffusionWrapperPipeline', 'ChineseStableDiffusionPipeline']
+        ['StableDiffusionWrapperPipeline', 'ChineseStableDiffusionPipeline'],
+        'text_to_video_synthesis_pipeline': ['TextToVideoSynthesisPipeline'],
     }

     import sys
diff --git a/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py b/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py
new file mode 100644
index 00000000..ee6635a6
--- /dev/null
+++ b/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py
@@ -0,0 +1,90 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import tempfile
+from typing import Any, Dict, Optional
+
+import cv2
+import torch
+from einops import rearrange
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Model, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.text_to_video_synthesis,
+    module_name=Pipelines.text_to_video_synthesis)
+class TextToVideoSynthesisPipeline(Pipeline):
+    r""" Text To Video Synthesis Pipeline.
+
+    Examples:
+    >>> from modelscope.pipelines import pipeline
+    >>> from modelscope.outputs import OutputKeys
+
+    >>> p = pipeline('text-to-video-synthesis', 'damo/text-to-video-synthesis')
+    >>> test_text = {
+    >>>     'text': 'A panda eating bamboo on a rock.',
+    >>> }
+    >>> p(test_text,)
+
+    >>> {OutputKeys.OUTPUT_VIDEO: path-to-the-generated-video}
+    >>>
+    """
+
+    def __init__(self, model: str, **kwargs):
+        """
+        Use `model` to create a text-to-video synthesis pipeline for prediction.
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model, **kwargs)
+
+    def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
+        self.model.clip_encoder.to(self.model.device)
+        text_emb = self.model.clip_encoder(input['text'])
+        text_emb_zero = self.model.clip_encoder('')
+        if self.model.config.model.model_args.tiny_gpu == 1:
+            self.model.clip_encoder.to('cpu')
+        return {'text_emb': text_emb, 'text_emb_zero': text_emb_zero}
+
+    def forward(self, input: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        video = self.model(input)
+        return {'video': video}
+
+    def postprocess(self, inputs: Dict[str, Any],
+                    **post_params) -> Dict[str, Any]:
+        video = tensor2vid(inputs['video'])
+        output_video_path = post_params.get('output_video', None)
+        if output_video_path is None:
+            output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
+
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        h, w, c = video[0].shape
+        video_writer = cv2.VideoWriter(
+            output_video_path, fourcc, fps=8, frameSize=(w, h))
+        for i in range(len(video)):
+            img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
+            video_writer.write(img)
+        video_writer.release()
+        return {OutputKeys.OUTPUT_VIDEO: output_video_path}
+
+
+def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
+    mean = torch.tensor(
+        mean, device=video.device).reshape(1, -1, 1, 1, 1)  # ncfhw
+    std = torch.tensor(
+        std, device=video.device).reshape(1, -1, 1, 1, 1)  # ncfhw
+    video = video.mul_(std).add_(mean)  # unnormalize back to [0,1]
+    video.clamp_(0, 1)
+    images = rearrange(video, 'i c f h w -> f h (i w) c')
+    images = images.unbind(dim=0)
+    images = [(image.numpy() * 255).astype('uint8')
+              for image in images]  # f h w c
+    return images
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 667575ff..382a89d6 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -234,6 +234,7 @@ class MultiModalTasks(object):
     document_vl_embedding = 'document-vl-embedding'
     video_captioning = 'video-captioning'
     video_question_answering = 'video-question-answering'
+    text_to_video_synthesis = 'text-to-video-synthesis'


 class ScienceTasks(object):
diff --git a/modelscope/utils/error.py b/modelscope/utils/error.py
index 44e6b238..d3000130 100644
--- a/modelscope/utils/error.py
+++ b/modelscope/utils/error.py
@@ -153,3 +153,12 @@ MPI4PY_IMPORT_ERROR = """
 `pip install mpi4py' and with following the instruction to install openmpi,
 https://docs.open-mpi.org/en/v5.0.x/installing-open-mpi/quickstart.html`
 """
+
+# docstyle-ignore
+OPENCLIP_IMPORT_ERROR = """
+{0} requires the open_clip library but it was not found in your environment.
+You can install it with pip on Linux or macOS:
+`pip install open_clip_torch`
+Or you can check out the instructions on the
+installation page: https://github.com/mlfoundations/open_clip and follow the ones that match your environment.
+""" diff --git a/modelscope/utils/import_utils.py b/modelscope/utils/import_utils.py index 3517ea3d..ea123ed7 100644 --- a/modelscope/utils/import_utils.py +++ b/modelscope/utils/import_utils.py @@ -304,6 +304,7 @@ REQUIREMENTS_MAAPING = OrderedDict([ ('text2sql_lgesql', (is_package_available('text2sql_lgesql'), TEXT2SQL_LGESQL_IMPORT_ERROR)), ('mpi4py', (is_package_available('mpi4py'), MPI4PY_IMPORT_ERROR)), + ('open_clip', (is_package_available('open_clip'), OPENCLIP_IMPORT_ERROR)), ]) SYSTEM_PACKAGE = set(['os', 'sys', 'typing']) diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt index 8a86be8e..ef8037b6 100644 --- a/requirements/multi-modal.txt +++ b/requirements/multi-modal.txt @@ -12,11 +12,13 @@ rapidfuzz # which introduced compatability issues that are being investigated rouge_score<=0.0.4 sacrebleu +# scikit-video soundfile taming-transformers-rom1504 timm tokenizers torchvision transformers>=4.12.0 +# triton==2.0.0.dev20221120 unicodedata2 zhconv diff --git a/tests/pipelines/test_text_to_video_synthesis.py b/tests/pipelines/test_text_to_video_synthesis.py new file mode 100644 index 00000000..6463c155 --- /dev/null +++ b/tests/pipelines/test_text_to_video_synthesis.py @@ -0,0 +1,36 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class TextToVideoSynthesisTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.text_to_video_synthesis + self.model_id = 'damo/text-to-video-synthesis' + + test_text = { + 'text': 'A panda eating bamboo on a rock.', + } + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + pipe_line_text_to_video_synthesis = pipeline( + task=self.task, model=self.model_id) + output_video_path = pipe_line_text_to_video_synthesis( + self.test_text)[OutputKeys.OUTPUT_VIDEO] + print(output_video_path) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main()
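
Usage note (appendix, not part of the patch): a minimal sketch of how the new text-to-video pipeline is expected to be invoked, mirroring the class docstring and tests/pipelines/test_text_to_video_synthesis.py. Passing an `output_video` path at call time is an assumption based on `postprocess` reading `post_params.get('output_video', None)`; whether call-time kwargs actually reach `postprocess` depends on the base `Pipeline` implementation and should be verified.

    # Sketch only: assumes modelscope with the multi-modal requirements is installed
    # and that 'damo/text-to-video-synthesis' is accessible on the ModelScope hub.
    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline

    p = pipeline(task='text-to-video-synthesis',
                 model='damo/text-to-video-synthesis')
    result = p({'text': 'A panda eating bamboo on a rock.'})
    print(result[OutputKeys.OUTPUT_VIDEO])  # path to the generated .mp4 file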