Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-16 08:17:45 +01:00
add text-to-video-synthesis
Text-to-video synthesis (text-to-video-synthesis) code. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11767775
Commit 0ca0a8c134 (parent a17598b13d), committed by wenmeng.zwm

.gitignore (vendored): 1 line changed
@@ -123,6 +123,7 @@ tensorboard.sh
replace.sh
result.png
result.jpg
result.mp4

# Pytorch
*.pth
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:78094cc48fbcfd9b6d321fe13619ecc72b65e006fc1b4c4458409ade9979486d
size 129862
oid sha256:d53a77b0be82993ed44bbb9244cda42bf460f8dcdf87ff3cfdbfdc7191ff418d
size 121984
@@ -182,6 +182,7 @@ class Models(object):
    mplug = 'mplug'
    diffusion = 'diffusion-text-to-image-synthesis'
    multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis'
    video_synthesis = 'latent-text-to-video-synthesis'
    team = 'team-multi-modal-similarity'
    video_clip = 'video-clip-multi-modal-embedding'
    mgeo = 'mgeo'
@@ -478,6 +479,7 @@ class Pipelines(object):
    diffusers_stable_diffusion = 'diffusers-stable-diffusion'
    document_vl_embedding = 'document-vl-embedding'
    chinese_stable_diffusion = 'chinese-stable-diffusion'
    text_to_video_synthesis = 'latent-text-to-video-synthesis'
    gridvlp_multi_modal_classification = 'gridvlp-multi-modal-classification'
    gridvlp_multi_modal_embedding = 'gridvlp-multi-modal-embedding'
@@ -615,6 +617,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.text_to_image_synthesis:
    (Pipelines.text_to_image_synthesis,
     'damo/cv_diffusion_text-to-image-synthesis_tiny'),
    Tasks.text_to_video_synthesis: (Pipelines.text_to_video_synthesis,
                                    'damo/text-to-video-synthesis'),
    Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints,
                              'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
    Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints,
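Taken together, the registry changes above (Models, Pipelines, DEFAULT_MODEL_FOR_PIPELINE) let the pipeline factory resolve the new task by name. A minimal usage sketch, not part of the commit, mirroring the pipeline docstring and test added further down; it assumes the 'damo/text-to-video-synthesis' model is reachable on the hub:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline

# The task string resolves to TextToVideoSynthesisPipeline and the default
# model registered above; the model weights are fetched on first use.
p = pipeline('text-to-video-synthesis', model='damo/text-to-video-synthesis')
result = p({'text': 'A panda eating bamboo on a rock.'})
print(result[OutputKeys.OUTPUT_VIDEO])  # path of the generated .mp4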
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
    from .multi_stage_diffusion import \
        MultiStageDiffusionForTextToImageSynthesis
    from .vldoc import VLDocForDocVLEmbedding
    from .video_synthesis import TextToVideoSynthesis

else:
    _import_structure = {
@@ -32,6 +33,7 @@ else:
        'multi_stage_diffusion':
        ['MultiStageDiffusionForTextToImageSynthesis'],
        'vldoc': ['VLDocForDocVLEmbedding'],
        'video_synthesis': ['TextToVideoSynthesis'],
    }

    import sys
modelscope/models/multi_modal/video_synthesis/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:

    from .text_to_video_synthesis_model import TextToVideoSynthesis

else:
    _import_structure = {
        'text_to_video_synthesis_model': ['TextToVideoSynthesis'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
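The new package __init__ follows the repo's LazyImportModule convention: importing the package itself is cheap, and the heavy model module (torch, open_clip, the UNet) is only loaded when the symbol is first accessed. An illustrative sketch of what that means for callers:

# Resolving the symbol below is what actually loads text_to_video_synthesis_model
# (and its torch/open_clip dependencies); importing the bare package does not.
from modelscope.models.multi_modal.video_synthesis import TextToVideoSynthesis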
modelscope/models/multi_modal/video_synthesis/autoencoder.py (new file, 569 lines)
@@ -0,0 +1,569 @@
# Part of the implementation is borrowed and modified from latent-diffusion,
# publicly available at https://github.com/CompVis/latent-diffusion.
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['AutoencoderKL']
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(
|
||||
self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(
|
||||
self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1, 2, 3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar
|
||||
+ torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(
|
||||
v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(
|
||||
x, scale_factor=2.0, mode='nearest')
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
**ignorekwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2**(self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
class AutoencoderKL(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
image_key='image',
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
ema_decay=None,
|
||||
learn_logvar=False):
|
||||
super().__init__()
|
||||
self.learn_logvar = learn_logvar
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
assert ddconfig['double_z']
|
||||
self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
|
||||
2 * embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
|
||||
ddconfig['z_channels'], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
|
||||
self.use_ema = ema_decay is not None
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path)
|
||||
|
||||
def init_from_ckpt(self, path):
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
|
||||
import collections
|
||||
sd_new = collections.OrderedDict()
|
||||
|
||||
for k in keys:
|
||||
if k.find('first_stage_model') >= 0:
|
||||
k_new = k.split('first_stage_model.')[-1]
|
||||
sd_new[k_new] = sd[k]
|
||||
|
||||
self.load_state_dict(sd_new, strict=True)
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1,
|
||||
2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log['reconstructions'] = xrec
|
||||
if log_ema or self.use_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, posterior_ema = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec_ema.shape[1] > 3
|
||||
xrec_ema = self.to_rgb(xrec_ema)
|
||||
log['samples_ema'] = self.decode(
|
||||
torch.randn_like(posterior_ema.sample()))
|
||||
log['reconstructions_ema'] = xrec_ema
|
||||
log['inputs'] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == 'segmentation'
|
||||
if not hasattr(self, 'colorize'):
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
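AutoencoderKL above is the first-stage model: it compresses RGB frames into a 4-channel latent at 1/8 spatial resolution and decodes latents back to pixels. A hedged round-trip sketch, not part of the commit, using the same ddconfig that TextToVideoSynthesis passes in later in this diff; it runs on random weights since no ckpt_path is given:

import torch

from modelscope.models.multi_modal.video_synthesis.autoencoder import \
    AutoencoderKL

# Same ddconfig as in the model file below.
ddconfig = {
    'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3,
    'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2,
    'attn_resolutions': [], 'dropout': 0.0
}
ae = AutoencoderKL(ddconfig, embed_dim=4).eval()  # random weights, no checkpoint

with torch.no_grad():
    frames = torch.randn(2, 3, 256, 256)   # a couple of frames as (b*f, c, h, w)
    posterior = ae.encode(frames)          # DiagonalGaussianDistribution
    z = posterior.sample()                 # (2, 4, 32, 32): 8x spatial downsampling
    recon = ae.decode(z)                   # back to (2, 3, 256, 256)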
modelscope/models/multi_modal/video_synthesis/diffusion.py (new file, 227 lines)
@@ -0,0 +1,227 @@
# Part of the implementation is borrowed and modified from latent-diffusion,
# publicly available at https://github.com/CompVis/latent-diffusion.
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
import torch
|
||||
|
||||
__all__ = ['GaussianDiffusion', 'beta_schedule']
|
||||
|
||||
|
||||
def _i(tensor, t, x):
|
||||
r"""Index tensor using t and format the output according to x.
|
||||
"""
|
||||
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
|
||||
return tensor[t].view(shape).to(x)
|
||||
|
||||
|
||||
def beta_schedule(schedule,
|
||||
num_timesteps=1000,
|
||||
init_beta=None,
|
||||
last_beta=None):
|
||||
if schedule == 'linear_sd':
|
||||
return torch.linspace(
|
||||
init_beta**0.5, last_beta**0.5, num_timesteps,
|
||||
dtype=torch.float64)**2
|
||||
else:
|
||||
raise ValueError(f'Unsupported schedule: {schedule}')
|
||||
|
||||
|
||||
class GaussianDiffusion(object):
|
||||
r""" Diffusion Model for DDIM.
|
||||
"Denoising diffusion implicit models." by Song, Jiaming, Chenlin Meng, and Stefano Ermon.
|
||||
See https://arxiv.org/abs/2010.02502
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
betas,
|
||||
mean_type='eps',
|
||||
var_type='learned_range',
|
||||
loss_type='mse',
|
||||
epsilon=1e-12,
|
||||
rescale_timesteps=False):
|
||||
# check input
|
||||
if not isinstance(betas, torch.DoubleTensor):
|
||||
betas = torch.tensor(betas, dtype=torch.float64)
|
||||
assert min(betas) > 0 and max(betas) <= 1
|
||||
assert mean_type in ['x0', 'x_{t-1}', 'eps']
|
||||
assert var_type in [
|
||||
'learned', 'learned_range', 'fixed_large', 'fixed_small'
|
||||
]
|
||||
assert loss_type in [
|
||||
'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1',
|
||||
'charbonnier'
|
||||
]
|
||||
self.betas = betas
|
||||
self.num_timesteps = len(betas)
|
||||
self.mean_type = mean_type
|
||||
self.var_type = var_type
|
||||
self.loss_type = loss_type
|
||||
self.epsilon = epsilon
|
||||
self.rescale_timesteps = rescale_timesteps
|
||||
|
||||
# alphas
|
||||
alphas = 1 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
self.alphas_cumprod_prev = torch.cat(
|
||||
[alphas.new_ones([1]), self.alphas_cumprod[:-1]])
|
||||
self.alphas_cumprod_next = torch.cat(
|
||||
[self.alphas_cumprod[1:],
|
||||
alphas.new_zeros([1])])
|
||||
|
||||
# q(x_t | x_{t-1})
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
|
||||
- self.alphas_cumprod)
|
||||
self.log_one_minus_alphas_cumprod = torch.log(1.0
|
||||
- self.alphas_cumprod)
|
||||
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
|
||||
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
|
||||
- 1)
|
||||
|
||||
# q(x_{t-1} | x_t, x_0)
|
||||
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
|
||||
1.0 - self.alphas_cumprod)
|
||||
self.posterior_log_variance_clipped = torch.log(
|
||||
self.posterior_variance.clamp(1e-20))
|
||||
self.posterior_mean_coef1 = betas * torch.sqrt(
|
||||
self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (
|
||||
1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
|
||||
1.0 - self.alphas_cumprod)
|
||||
|
||||
def p_mean_variance(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None):
|
||||
r"""Distribution of p(x_{t-1} | x_t).
|
||||
"""
|
||||
# predict distribution
|
||||
if guide_scale is None:
|
||||
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
||||
else:
|
||||
# classifier-free guidance
|
||||
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
|
||||
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
|
||||
y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
|
||||
u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
|
||||
dim = y_out.size(1) if self.var_type.startswith(
|
||||
'fixed') else y_out.size(1) // 2
|
||||
a = u_out[:, :dim]
|
||||
b = guide_scale * (y_out[:, :dim] - u_out[:, :dim])
|
||||
c = y_out[:, dim:]
|
||||
out = torch.cat([a + b, c], dim=1)
|
||||
|
||||
# compute variance
|
||||
if self.var_type == 'fixed_small':
|
||||
var = _i(self.posterior_variance, t, xt)
|
||||
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
|
||||
# compute mean and x0
|
||||
if self.mean_type == 'eps':
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt) * out
|
||||
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||
|
||||
# restrict the range of x0
|
||||
if percentile is not None:
|
||||
assert percentile > 0 and percentile <= 1 # e.g., 0.995
|
||||
s = torch.quantile(
|
||||
x0.flatten(1).abs(), percentile,
|
||||
dim=1).clamp_(1.0).view(-1, 1, 1, 1)
|
||||
x0 = torch.min(s, torch.max(-s, x0)) / s
|
||||
elif clamp is not None:
|
||||
x0 = x0.clamp(-clamp, clamp)
|
||||
return mu, var, log_var, x0
|
||||
|
||||
def q_posterior_mean_variance(self, x0, xt, t):
|
||||
r"""Distribution of q(x_{t-1} | x_t, x_0).
|
||||
"""
|
||||
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
|
||||
self.posterior_mean_coef2, t, xt) * xt
|
||||
var = _i(self.posterior_variance, t, xt)
|
||||
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
return mu, var, log_var
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20,
|
||||
eta=0.0):
|
||||
r"""Sample from p(x_{t-1} | x_t) using DDIM.
|
||||
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||
"""
|
||||
stride = self.num_timesteps // ddim_timesteps
|
||||
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale)
|
||||
if condition_fn is not None:
|
||||
# x0 -> eps
|
||||
alpha = _i(self.alphas_cumprod, t, xt)
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
eps = eps - (1 - alpha).sqrt() * condition_fn(
|
||||
xt, self._scale_timesteps(t), **model_kwargs)
|
||||
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
|
||||
|
||||
# derive variables
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
|
||||
self.sqrt_recipm1_alphas_cumprod, t, xt)
|
||||
alphas = _i(self.alphas_cumprod, t, xt)
|
||||
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
||||
a = (1 - alphas_prev) / (1 - alphas)
|
||||
b = (1 - alphas / alphas_prev)
|
||||
sigmas = eta * torch.sqrt(a * b)
|
||||
|
||||
# random sample
|
||||
noise = torch.randn_like(xt)
|
||||
direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
|
||||
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
||||
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
|
||||
return xt_1, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20,
|
||||
eta=0.0):
|
||||
# prepare input
|
||||
b = noise.size(0)
|
||||
xt = noise
|
||||
|
||||
# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
|
||||
steps = (1 + torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // ddim_timesteps)).clamp(
|
||||
0, self.num_timesteps - 1).flip(0)
|
||||
for step in steps:
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn, guide_scale,
|
||||
ddim_timesteps, eta)
|
||||
return xt
|
||||
|
||||
def _scale_timesteps(self, t):
|
||||
if self.rescale_timesteps:
|
||||
return t.float() * 1000.0 / self.num_timesteps
|
||||
return t
|
||||
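GaussianDiffusion above only defines the DDIM sampler; the denoising network is injected as a callable. A minimal sketch, not part of the commit, of how the pieces fit together, using the same beta_schedule arguments the model file below uses and a dummy eps-predictor standing in for UNetSD:

import torch

from modelscope.models.multi_modal.video_synthesis.diffusion import (
    GaussianDiffusion, beta_schedule)

# Same schedule arguments as in TextToVideoSynthesis below.
betas = beta_schedule(
    'linear_sd', num_timesteps=1000, init_beta=0.00085, last_beta=0.0120)
# 'fixed_small' matches the only variance branch shown in the excerpt above;
# the real config may choose a different var_type.
diffusion = GaussianDiffusion(
    betas=betas, mean_type='eps', var_type='fixed_small', loss_type='mse')


def dummy_unet(xt, t, **kwargs):
    # Stand-in for UNetSD: must return an eps prediction shaped like xt.
    return torch.zeros_like(xt)


noise = torch.randn(1, 4, 16, 32, 32)   # b c f h w latent video
x0 = diffusion.ddim_sample_loop(noise, dummy_unet, ddim_timesteps=10, eta=0.0)
print(x0.shape)   # torch.Size([1, 4, 16, 32, 32])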
modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py (new file, 241 lines)
@@ -0,0 +1,241 @@
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
|
||||
import os
|
||||
from os import path as osp
|
||||
from typing import Any, Dict
|
||||
|
||||
import open_clip
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
from einops import rearrange
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.multi_modal.video_synthesis.autoencoder import \
|
||||
AutoencoderKL
|
||||
from modelscope.models.multi_modal.video_synthesis.diffusion import (
|
||||
GaussianDiffusion, beta_schedule)
|
||||
from modelscope.models.multi_modal.video_synthesis.unet_sd import UNetSD
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
|
||||
__all__ = ['TextToVideoSynthesis']
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.text_to_video_synthesis, module_name=Models.video_synthesis)
|
||||
class TextToVideoSynthesis(Model):
|
||||
r"""
Task for text-to-video synthesis.

Attributes:
    sd_model: denoising model used in this task.
    diffusion: diffusion model for DDIM.
    autoencoder: decodes the latent representation into visual space with VQGAN.
    clip_encoder: encodes the text into a text embedding.
"""
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
r"""
|
||||
Args:
|
||||
model_dir (`str` or `os.PathLike`)
|
||||
Can be either:
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co
|
||||
or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`,
|
||||
or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- A path to a *directory* containing model weights saved using
|
||||
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
- A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
|
||||
this case, `from_tf` should be set to `True` and a configuration object should be provided as
|
||||
`config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
|
||||
PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
- A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
|
||||
`./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
|
||||
`True`.
|
||||
"""
|
||||
super().__init__(model_dir=model_dir, *args, **kwargs)
|
||||
self.device = torch.device('cuda') if torch.cuda.is_available() \
|
||||
else torch.device('cpu')
|
||||
self.config = Config.from_file(
|
||||
osp.join(model_dir, ModelFile.CONFIGURATION))
|
||||
cfg = self.config.model.model_cfg
|
||||
cfg['temporal_attention'] = True if cfg[
|
||||
'temporal_attention'] == 'True' else False
|
||||
|
||||
# Initialize unet
|
||||
self.sd_model = UNetSD(
|
||||
in_dim=cfg['unet_in_dim'],
|
||||
dim=cfg['unet_dim'],
|
||||
y_dim=cfg['unet_y_dim'],
|
||||
context_dim=cfg['unet_context_dim'],
|
||||
out_dim=cfg['unet_out_dim'],
|
||||
dim_mult=cfg['unet_dim_mult'],
|
||||
num_heads=cfg['unet_num_heads'],
|
||||
head_dim=cfg['unet_head_dim'],
|
||||
num_res_blocks=cfg['unet_res_blocks'],
|
||||
attn_scales=cfg['unet_attn_scales'],
|
||||
dropout=cfg['unet_dropout'],
|
||||
temporal_attention=cfg['temporal_attention'])
|
||||
self.sd_model.load_state_dict(
|
||||
torch.load(
|
||||
osp.join(model_dir, self.config.model.model_args.ckpt_unet)),
|
||||
strict=True)
|
||||
self.sd_model.eval()
|
||||
self.sd_model.to(self.device)
|
||||
|
||||
# Initialize diffusion
|
||||
betas = beta_schedule(
|
||||
'linear_sd',
|
||||
cfg['num_timesteps'],
|
||||
init_beta=0.00085,
|
||||
last_beta=0.0120)
|
||||
self.diffusion = GaussianDiffusion(
|
||||
betas=betas,
|
||||
mean_type=cfg['mean_type'],
|
||||
var_type=cfg['var_type'],
|
||||
loss_type=cfg['loss_type'],
|
||||
rescale_timesteps=False)
|
||||
|
||||
# Initialize autoencoder
|
||||
ddconfig = {
|
||||
'double_z': True,
|
||||
'z_channels': 4,
|
||||
'resolution': 256,
|
||||
'in_channels': 3,
|
||||
'out_ch': 3,
|
||||
'ch': 128,
|
||||
'ch_mult': [1, 2, 4, 4],
|
||||
'num_res_blocks': 2,
|
||||
'attn_resolutions': [],
|
||||
'dropout': 0.0
|
||||
}
|
||||
self.autoencoder = AutoencoderKL(
|
||||
ddconfig, 4,
|
||||
osp.join(model_dir, self.config.model.model_args.ckpt_autoencoder))
|
||||
if self.config.model.model_args.tiny_gpu == 1:
|
||||
self.autoencoder.to('cpu')
|
||||
else:
|
||||
self.autoencoder.to(self.device)
|
||||
self.autoencoder.eval()
|
||||
|
||||
# Initialize Open clip
|
||||
self.clip_encoder = FrozenOpenCLIPEmbedder(
|
||||
version=osp.join(model_dir,
|
||||
self.config.model.model_args.ckpt_clip),
|
||||
layer='penultimate')
|
||||
if self.config.model.model_args.tiny_gpu == 1:
|
||||
self.clip_encoder.to('cpu')
|
||||
else:
|
||||
self.clip_encoder.to(self.device)
|
||||
|
||||
def forward(self, input: Dict[str, Any]):
|
||||
r"""
|
||||
The entry function of the text-to-video synthesis task.
|
||||
1. Using diffusion model to generate the video's latent representation.
|
||||
2. Using vqgan model (autoencoder) to decode the video's latent representation to visual space.
|
||||
|
||||
Args:
|
||||
input (`Dict[Str, Any]`):
|
||||
The input of the task
|
||||
Returns:
|
||||
A generated video (as pytorch tensor).
|
||||
"""
|
||||
y = input['text_emb']
|
||||
zero_y = input['text_emb_zero']
|
||||
context = torch.cat([zero_y, y], dim=0).to(self.device)
|
||||
# synthesis
|
||||
with torch.no_grad():
|
||||
num_sample = 1 # here let b = 1
|
||||
max_frames = self.config.model.model_args.max_frames
|
||||
latent_h, latent_w = 32, 32
|
||||
with amp.autocast(enabled=True):
|
||||
x0 = self.diffusion.ddim_sample_loop(
|
||||
noise=torch.randn(num_sample, 4, max_frames, latent_h,
|
||||
latent_w).to(
|
||||
self.device), # shape: b c f h w
|
||||
model=self.sd_model,
|
||||
model_kwargs=[{
|
||||
'y':
|
||||
context[1].unsqueeze(0).repeat(num_sample, 1, 1)
|
||||
}, {
|
||||
'y':
|
||||
context[0].unsqueeze(0).repeat(num_sample, 1, 1)
|
||||
}],
|
||||
guide_scale=9.0,
|
||||
ddim_timesteps=50,
|
||||
eta=0.0)
|
||||
|
||||
scale_factor = 0.18215
|
||||
video_data = 1. / scale_factor * x0
|
||||
bs_vd = video_data.shape[0]
|
||||
video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
|
||||
self.autoencoder.to(self.device)
|
||||
video_data = self.autoencoder.decode(video_data)
|
||||
if self.config.model.model_args.tiny_gpu == 1:
|
||||
self.autoencoder.to('cpu')
|
||||
video_data = rearrange(
|
||||
video_data, '(b f) c h w -> b c f h w', b=bs_vd)
|
||||
return video_data.type(torch.float32).cpu()
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(torch.nn.Module):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = ['last', 'penultimate']
|
||||
|
||||
def __init__(self,
|
||||
arch='ViT-H-14',
|
||||
version='open_clip_pytorch_model.bin',
|
||||
device='cuda',
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer='last'):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
model, _, _ = open_clip.create_model_and_transforms(
|
||||
arch, device=torch.device('cpu'), pretrained=version)
|
||||
del model.visual
|
||||
self.model = model
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == 'last':
|
||||
self.layer_idx = 0
|
||||
elif self.layer == 'penultimate':
|
||||
self.layer_idx = 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
tokens = open_clip.tokenize(text)
|
||||
z = self.encode_with_transformer(tokens.to(self.device))
|
||||
return z
|
||||
|
||||
def encode_with_transformer(self, text):
|
||||
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
||||
x = x + self.model.positional_embedding
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.model.ln_final(x)
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
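FrozenOpenCLIPEmbedder is the text-conditioning front end: the pipeline's preprocess step (added below) calls it once with the prompt and once with an empty string to build the classifier-free guidance pair. A hedged sketch, not part of the commit; the import path is inferred from the package __init__ above, and it assumes a local ViT-H-14 OpenCLIP checkpoint at the given path:

import torch

from modelscope.models.multi_modal.video_synthesis.text_to_video_synthesis_model import \
    FrozenOpenCLIPEmbedder

# 'open_clip_pytorch_model.bin' must point at a ViT-H-14 checkpoint on disk;
# the encoder is frozen, so no gradients are tracked.
encoder = FrozenOpenCLIPEmbedder(
    version='open_clip_pytorch_model.bin', device='cpu', layer='penultimate')
with torch.no_grad():
    text_emb = encoder('A panda eating bamboo on a rock.')   # conditional branch
    text_emb_zero = encoder('')                              # unconditional branch
# For ViT-H-14 both embeddings should have shape (1, 77, 1024).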
modelscope/models/multi_modal/video_synthesis/unet_sd.py (new file, 1098 lines)
File diff suppressed because it is too large.
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
    from .video_captioning_pipeline import VideoCaptioningPipeline
    from .video_question_answering_pipeline import VideoQuestionAnsweringPipeline
    from .diffusers_wrapped import StableDiffusionWrapperPipeline, ChineseStableDiffusionPipeline
    from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline
else:
    _import_structure = {
        'image_captioning_pipeline': ['ImageCaptioningPipeline'],
@@ -39,7 +40,8 @@ else:
        'video_question_answering_pipeline':
        ['VideoQuestionAnsweringPipeline'],
        'diffusers_wrapped':
        ['StableDiffusionWrapperPipeline', 'ChineseStableDiffusionPipeline']
        ['StableDiffusionWrapperPipeline', 'ChineseStableDiffusionPipeline'],
        'text_to_video_synthesis_pipeline': ['TextToVideoSynthesisPipeline'],
    }

    import sys
modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py (new file, 89 lines)
@@ -0,0 +1,89 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Model, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.text_to_video_synthesis,
|
||||
module_name=Pipelines.text_to_video_synthesis)
|
||||
class TextToVideoSynthesisPipeline(Pipeline):
|
||||
r""" Text To Video Synthesis Pipeline.
|
||||
|
||||
Examples:
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> from modelscope.outputs import OutputKeys
|
||||
|
||||
>>> p = pipeline('text-to-video-synthesis', 'damo/text-to-video-synthesis')
|
||||
>>> test_text = {
|
||||
>>> 'text': 'A panda eating bamboo on a rock.',
|
||||
>>> }
|
||||
>>> p(test_text,)
|
||||
|
||||
>>> {OutputKeys.OUTPUT_VIDEO: path-to-the-generated-video}
|
||||
>>>
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
Use `model` to create a text-to-video synthesis pipeline for prediction.
Args:
    model: model id on the ModelScope hub.
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
|
||||
def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
|
||||
self.model.clip_encoder.to(self.model.device)
|
||||
text_emb = self.model.clip_encoder(input['text'])
|
||||
text_emb_zero = self.model.clip_encoder('')
|
||||
if self.model.config.model.model_args.tiny_gpu == 1:
|
||||
self.model.clip_encoder.to('cpu')
|
||||
return {'text_emb': text_emb, 'text_emb_zero': text_emb_zero}
|
||||
|
||||
def forward(self, input: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
video = self.model(input)
|
||||
return {'video': video}
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any],
|
||||
**post_params) -> Dict[str, Any]:
|
||||
video = tensor2vid(inputs['video'])
|
||||
output_video_path = post_params.get('output_video', None)
|
||||
if output_video_path is None:
|
||||
output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
h, w, c = video[0].shape
|
||||
video_writer = cv2.VideoWriter(
|
||||
output_video_path, fourcc, fps=8, frameSize=(w, h))
|
||||
for i in range(len(video)):
|
||||
img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
|
||||
video_writer.write(img)
|
||||
return {OutputKeys.OUTPUT_VIDEO: output_video_path}
|
||||
|
||||
|
||||
def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
|
||||
mean = torch.tensor(
|
||||
mean, device=video.device).reshape(1, -1, 1, 1, 1) # ncfhw
|
||||
std = torch.tensor(
|
||||
std, device=video.device).reshape(1, -1, 1, 1, 1) # ncfhw
|
||||
video = video.mul_(std).add_(mean) # unnormalize back to [0,1]
|
||||
video.clamp_(0, 1)
|
||||
images = rearrange(video, 'i c f h w -> f h (i w) c')
|
||||
images = images.unbind(dim=0)
|
||||
images = [(image.numpy() * 255).astype('uint8')
|
||||
for image in images] # f h w c
|
||||
return images
|
||||
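tensor2vid above un-normalizes the decoded frames: with mean = std = 0.5 it maps the autoencoder's [-1, 1] output back to [0, 1], then to uint8 HWC frames laid out per time step. A tiny hedged check of that mapping, not part of the commit; the import path is the one implied by the pipeline __init__ change above:

import torch

from modelscope.pipelines.multi_modal.text_to_video_synthesis_pipeline import \
    tensor2vid

# Three frames with constant values -1, 0 and +1 in every channel (b c f h w).
fake = torch.tensor([-1.0, 0.0, 1.0]).view(1, 1, 3, 1, 1).repeat(1, 3, 1, 1, 1)
frames = tensor2vid(fake)
print([f[0, 0].tolist() for f in frames])
# expected roughly: [[0, 0, 0], [127, 127, 127], [255, 255, 255]]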
@@ -234,6 +234,7 @@ class MultiModalTasks(object):
    document_vl_embedding = 'document-vl-embedding'
    video_captioning = 'video-captioning'
    video_question_answering = 'video-question-answering'
    text_to_video_synthesis = 'text-to-video-synthesis'


class ScienceTasks(object):
@@ -153,3 +153,12 @@ MPI4PY_IMPORT_ERROR = """
`pip install mpi4py' and with following the instruction to install openmpi,
https://docs.open-mpi.org/en/v5.0.x/installing-open-mpi/quickstart.html`
"""

# docstyle-ignore
OPENCLIP_IMPORT_ERROR = """
{0} requires the open_clip library but it was not found in your environment.
You can install it with pip on linux or mac:
`pip install open_clip_torch`
Or you can check out the instructions on the
installation page: https://github.com/mlfoundations/open_clip and follow the ones that match your environment.
"""
@@ -304,6 +304,7 @@ REQUIREMENTS_MAAPING = OrderedDict([
    ('text2sql_lgesql', (is_package_available('text2sql_lgesql'),
                         TEXT2SQL_LGESQL_IMPORT_ERROR)),
    ('mpi4py', (is_package_available('mpi4py'), MPI4PY_IMPORT_ERROR)),
    ('open_clip', (is_package_available('open_clip'), OPENCLIP_IMPORT_ERROR)),
])

SYSTEM_PACKAGE = set(['os', 'sys', 'typing'])
@@ -12,11 +12,13 @@ rapidfuzz
# which introduced compatibility issues that are being investigated
rouge_score<=0.0.4
sacrebleu
# scikit-video
soundfile
taming-transformers-rom1504
timm
tokenizers
torchvision
transformers>=4.12.0
# triton==2.0.0.dev20221120
unicodedata2
zhconv
tests/pipelines/test_text_to_video_synthesis.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class TextToVideoSynthesisTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.text_to_video_synthesis
        self.model_id = 'damo/text-to-video-synthesis'

    test_text = {
        'text': 'A panda eating bamboo on a rock.',
    }

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        pipe_line_text_to_video_synthesis = pipeline(
            task=self.task, model=self.model_id)
        output_video_path = pipe_line_text_to_video_synthesis(
            self.test_text)[OutputKeys.OUTPUT_VIDEO]
        print(output_video_path)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()