Merge pull request #472 from kangzhao2/baishao_test

Add image2video
This commit is contained in:
Wang Qiang
2023-08-18 20:29:36 +08:00
committed by GitHub
16 changed files with 3690 additions and 0 deletions

View File

@@ -220,6 +220,7 @@ class Models(object):
stable_diffusion = 'stable-diffusion'
videocomposer = 'videocomposer'
text_to_360panorama_image = 'text-to-360panorama-image'
image_to_video_model = 'image-to-video-model'
# science models
unifold = 'unifold'
@@ -545,6 +546,7 @@ class Pipelines(object):
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
multimodal_dialogue = 'multimodal-dialogue'
llama2_text_generation_pipeline = 'llama2-text-generation-pipeline'
image_to_video_task_pipeline = 'image-to-video-task-pipeline'
# science tasks
protein_structure = 'unifold-protein-structure'

View File

@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .image_to_video_model import ImageToVideo
else:
_import_structure = {
'image_to_video_model': ['ImageToVideo'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,215 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import random
from copy import copy
from typing import Any, Dict
import torch
import torch.cuda.amp as amp
import modelscope.models.multi_modal.image_to_video.utils.transforms as data
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.image_to_video.modules import *
from modelscope.models.multi_modal.image_to_video.modules import (
AutoencoderKL, FrozenOpenCLIPVisualEmbedder, Img2VidSDUNet)
from modelscope.models.multi_modal.image_to_video.utils.config import cfg
from modelscope.models.multi_modal.image_to_video.utils.diffusion import \
GaussianDiffusion
from modelscope.models.multi_modal.image_to_video.utils.seed import setup_seed
from modelscope.models.multi_modal.image_to_video.utils.shedule import \
beta_schedule
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
__all__ = ['ImageToVideo']
logger = get_logger()
@MODELS.register_module(
Tasks.image_to_video, module_name=Models.image_to_video_model)
class ImageToVideo(TorchModel):
r"""
Image2Video aims to solve the task of generating high-definition videos based on input images.
Image2Video is a video generation basic model developed by Alibaba Cloud, with a parameter size
of approximately 2 billion. It has been pre trained on large-scale video and image data and
fine-tuned on a small amount of high-quality data. The data is widely distributed and diverse
in categories, and the model has good generalization ability for different types of data
Paper link: https://arxiv.org/abs/2306.02018
Attributes:
diffusion: diffusion model for DDIM.
autoencoder: decode the latent representation into visual space.
clip_encoder: encode the image into image embedding.
"""
def __init__(self, model_dir, *args, **kwargs):
r"""
Args:
model_dir (`str` or `os.PathLike`)
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co
or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`,
or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
this case, `from_tf` should be set to `True` and a configuration object should be provided as
`config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
`./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
`True`.
"""
super().__init__(model_dir=model_dir, *args, **kwargs)
self.config = Config.from_file(
osp.join(model_dir, ModelFile.CONFIGURATION))
# assign default value
cfg.batch_size = self.config.model.model_cfg.batch_size
cfg.target_fps = self.config.model.model_cfg.target_fps
cfg.max_frames = self.config.model.model_cfg.max_frames
cfg.latent_hei = self.config.model.model_cfg.latent_hei
cfg.latent_wid = self.config.model.model_cfg.latent_wid
cfg.model_path = osp.join(model_dir,
self.config.model.model_args.ckpt_unet)
self.device = torch.device(
'cuda') if torch.cuda.is_available() else torch.device('cpu')
if 'seed' in self.config.model.model_args.keys():
cfg.seed = self.config.model.model_args.seed
else:
cfg.seed = random.randint(0, 99999)
setup_seed(cfg.seed)
# transform
vid_trans = data.Compose([
data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])),
data.Resize(cfg.vit_resolution),
data.ToTensor(),
data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)
])
self.vid_trans = vid_trans
cfg.embedder.pretrained = osp.join(
model_dir, self.config.model.model_args.ckpt_clip)
clip_encoder = FrozenOpenCLIPVisualEmbedder(**cfg.embedder)
clip_encoder.model.to(self.device)
self.clip_encoder = clip_encoder
logger.info(f'Build encoder with {cfg.embedder.type}')
# [unet]
generator = Img2VidSDUNet(**cfg.UNet)
generator = generator.to(self.device)
generator.eval()
load_dict = torch.load(cfg.model_path, map_location='cpu')
ret = generator.load_state_dict(load_dict['state_dict'], strict=True)
self.generator = generator
logger.info('Load model {} path {}, with local status {}'.format(
cfg.UNet.type, cfg.model_path, ret))
# [diffusion]
betas = beta_schedule(
'linear_sd',
cfg.num_timesteps,
init_beta=0.00085,
last_beta=0.0120)
diffusion = GaussianDiffusion(
betas=betas,
mean_type=cfg.mean_type,
var_type=cfg.var_type,
loss_type=cfg.loss_type,
rescale_timesteps=False,
noise_strength=getattr(cfg, 'noise_strength', 0))
self.diffusion = diffusion
logger.info('Build diffusion with type of GaussianDiffusion')
# [auotoencoder]
cfg.auto_encoder.pretrained = osp.join(
model_dir, self.config.model.model_args.ckpt_autoencoder)
autoencoder = AutoencoderKL(**cfg.auto_encoder)
autoencoder.eval()
for param in autoencoder.parameters():
param.requires_grad = False
autoencoder.to(self.device)
self.autoencoder = autoencoder
torch.cuda.empty_cache()
zero_feature = torch.zeros(1, 1, cfg.UNet.input_dim).to(self.device)
self.zero_feature = zero_feature
self.fps_tensor = torch.tensor([cfg.target_fps],
dtype=torch.long,
device=self.device)
self.cfg = cfg
def forward(self, input: Dict[str, Any]):
r"""
The entry function of image to video task.
1. Using diffusion model to generate the video's latent representation.
2. Using autoencoder to decode the video's latent representation to visual space.
Args:
input (`Dict[Str, Any]`):
The input of the task
Returns:
A generated video (as pytorch tensor).
"""
vit_frame = input['vit_frame']
cfg = self.cfg
img_embedding = self.clip_encoder(vit_frame).unsqueeze(1)
noise = self.build_noise()
zero_feature = copy(self.zero_feature)
with torch.no_grad():
with amp.autocast(enabled=cfg.use_fp16):
model_kwargs = [{
'y': img_embedding,
'fps': self.fps_tensor
}, {
'y': zero_feature.repeat(cfg.batch_size, 1, 1),
'fps': self.fps_tensor
}]
gen_video = self.diffusion.ddim_sample_loop(
noise=noise,
model=self.generator,
model_kwargs=model_kwargs,
guide_scale=cfg.guide_scale,
ddim_timesteps=cfg.ddim_timesteps,
eta=0.0)
gen_video = 1. / cfg.scale_factor * gen_video
gen_video = rearrange(gen_video, 'b c f h w -> (b f) c h w')
chunk_size = min(cfg.decoder_bs, gen_video.shape[0])
gen_video_list = torch.chunk(
gen_video, gen_video.shape[0] // chunk_size, dim=0)
decode_generator = []
for vd_data in gen_video_list:
gen_frames = self.autoencoder.decode(vd_data)
decode_generator.append(gen_frames)
gen_video = torch.cat(decode_generator, dim=0)
gen_video = rearrange(
gen_video, '(b f) c h w -> b c f h w', b=cfg.batch_size)
return gen_video.type(torch.float32).cpu()
def build_noise(self):
cfg = self.cfg
noise = torch.randn(
[1, 4, cfg.max_frames, cfg.latent_hei,
cfg.latent_wid]).to(self.device)
if cfg.noise_strength > 0:
b, c, f, *_ = noise.shape
offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device)
noise = noise + cfg.noise_strength * offset_noise
return noise.contiguous()

View File

@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .autoencoder import *
from .embedder import *
from .unet_i2v import *

View File

@@ -0,0 +1,573 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import collections
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(
self.mean.shape).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar
+ torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
class ResnetBlock(nn.Module):
def __init__(self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(
x, scale_factor=2.0, mode='nearest')
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class Encoder(nn.Module):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type='vanilla',
**ignore_kwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1, ) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type='vanilla',
**ignorekwargs):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
logging.info('Working with z of shape {} = {} dimensions.'.format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z):
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
return h
class AutoencoderKL(nn.Module):
def __init__(self,
ddconfig,
embed_dim,
pretrained=None,
ignore_keys=[],
image_key='image',
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False,
**kwargs):
super().__init__()
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
assert ddconfig['double_z']
self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
ddconfig['z_channels'], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer('colorize',
torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None
if pretrained is not None:
self.init_from_ckpt(pretrained, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
sd_new = collections.OrderedDict()
for k in keys:
if k.find('first_stage_model') >= 0:
k_new = k.split('first_stage_model.')[-1]
sd_new[k_new] = sd[k]
self.load_state_dict(sd_new, strict=True)
logging.info(f'Restored from {path}')
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1,
2).to(memory_format=torch.contiguous_format).float()
return x
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
log['reconstructions'] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log['samples_ema'] = self.decode(
torch.randn_like(posterior_ema.sample()))
log['reconstructions_ema'] = xrec_ema
log['inputs'] = x
return log
def to_rgb(self, x):
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer('colorize',
torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x

View File

@@ -0,0 +1,82 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import numpy as np
import open_clip
import torch
import torch.nn as nn
import torchvision.transforms as T
class FrozenOpenCLIPVisualEmbedder(nn.Module):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = ['last', 'penultimate']
def __init__(self,
pretrained,
vit_resolution=(224, 224),
arch='ViT-H-14',
device='cuda',
max_length=77,
freeze=True,
layer='last',
**kwargs):
super().__init__()
assert layer in self.LAYERS
model, _, preprocess = open_clip.create_model_and_transforms(
arch, device=torch.device('cpu'), pretrained=pretrained)
del model.transformer
self.model = model
data_white = np.ones(
(vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8) * 255
self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == 'last':
self.layer_idx = 0
elif self.layer == 'penultimate':
self.layer_idx = 1
else:
raise NotImplementedError()
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, image):
z = self.model.encode_image(image.to(self.device))
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text)
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2)
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2)
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
):
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
def encode(self, text):
return self(text)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

View File

@@ -0,0 +1,161 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import os.path as osp
from datetime import datetime
import torch
from easydict import EasyDict
cfg = EasyDict(__name__='Config: VideoLDM Decoder')
# ---------------------------work dir--------------------------
cfg.work_dir = 'workspace/'
# ---------------------------Global Variable-----------------------------------
cfg.resolution = [448, 256]
# -----------------------------------------------------------------------------
# ---------------------------Dataset Parameter---------------------------------
cfg.mean = [0.5, 0.5, 0.5]
cfg.std = [0.5, 0.5, 0.5]
cfg.max_words = 1000
# PlaceHolder
cfg.vit_out_dim = 1024
cfg.vit_resolution = [224, 224]
cfg.depth_clamp = 10.0
cfg.misc_size = 384
cfg.depth_std = 20.0
cfg.frame_lens = 32
cfg.sample_fps = 8
cfg.batch_sizes = 1
# -----------------------------------------------------------------------------
# ---------------------------Mode Parameters-----------------------------------
# Diffusion
cfg.schedule = 'cosine'
cfg.num_timesteps = 1000
cfg.mean_type = 'v'
cfg.var_type = 'fixed_small'
cfg.loss_type = 'mse'
cfg.ddim_timesteps = 50
cfg.ddim_eta = 0.0
cfg.clamp = 1.0
cfg.share_noise = False
cfg.use_div_loss = False
cfg.noise_strength = 0.1
# classifier-free guidance
cfg.p_zero = 0.1
cfg.guide_scale = 3.0
# clip vision encoder
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
# Model
cfg.scale_factor = 0.18215
cfg.use_fp16 = True
cfg.temporal_attention = True
cfg.decoder_bs = 8
cfg.UNet = {
'type': 'Img2VidSDUNet',
'in_dim': 4,
'dim': 320,
'y_dim': cfg.vit_out_dim,
'context_dim': 1024,
'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
'dim_mult': [1, 2, 4, 4],
'num_heads': 8,
'head_dim': 64,
'num_res_blocks': 2,
'attn_scales': [1 / 1, 1 / 2, 1 / 4],
'dropout': 0.1,
'temporal_attention': cfg.temporal_attention,
'temporal_attn_times': 1,
'use_checkpoint': False,
'use_fps_condition': False,
'use_sim_mask': False,
'num_tokens': 4,
'default_fps': 8,
'input_dim': 1024
}
cfg.guidances = []
# auotoencoder from stabel diffusion
cfg.auto_encoder = {
'type': 'AutoencoderKL',
'ddconfig': {
'double_z': True,
'z_channels': 4,
'resolution': 256,
'in_channels': 3,
'out_ch': 3,
'ch': 128,
'ch_mult': [1, 2, 4, 4],
'num_res_blocks': 2,
'attn_resolutions': [],
'dropout': 0.0
},
'embed_dim': 4,
'pretrained': 'v2-1_512-ema-pruned.ckpt'
}
# clip embedder
cfg.embedder = {
'type': 'FrozenOpenCLIPVisualEmbedder',
'layer': 'penultimate',
'vit_resolution': [224, 224],
'pretrained': 'open_clip_pytorch_model.bin'
}
# -----------------------------------------------------------------------------
# ---------------------------Training Settings---------------------------------
# training and optimizer
cfg.ema_decay = 0.9999
cfg.num_steps = 600000
cfg.lr = 5e-5
cfg.weight_decay = 0.0
cfg.betas = (0.9, 0.999)
cfg.eps = 1.0e-8
cfg.chunk_size = 16
cfg.alpha = 0.7
cfg.save_ckp_interval = 1000
# -----------------------------------------------------------------------------
# ----------------------------Pretrain Settings---------------------------------
# Default: load 2d pretrain
cfg.fix_weight = False
cfg.load_match = False
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
# -----------------------------------------------------------------------------
# -----------------------------Visual-------------------------------------------
# Visual videos
cfg.viz_interval = 1000
cfg.visual_train = {
'type': 'VisualVideoTextDuringTrain',
}
cfg.visual_inference = {
'type': 'VisualGeneratedVideos',
}
cfg.inference_list_path = ''
# logging
cfg.log_interval = 100
# Default log_dir
cfg.log_dir = 'workspace/output_data'
# -----------------------------------------------------------------------------
# ---------------------------Others--------------------------------------------
# seed
cfg.seed = 8888
# -----------------------------------------------------------------------------

View File

@@ -0,0 +1,511 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import torch
__all__ = ['GaussianDiffusion', 'beta_schedule']
def _i(tensor, t, x):
r"""Index tensor using t and format the output according to x.
"""
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
if tensor.device != x.device:
tensor = tensor.to(x.device)
return tensor[t].view(shape).to(x)
def fn(u):
return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2
def beta_schedule(schedule,
num_timesteps=1000,
init_beta=None,
last_beta=None):
if schedule == 'linear':
scale = 1000.0 / num_timesteps
init_beta = init_beta or scale * 0.0001
last_beta = last_beta or scale * 0.02
return torch.linspace(
init_beta, last_beta, num_timesteps, dtype=torch.float64)
elif schedule == 'quadratic':
init_beta = init_beta or 0.0015
last_beta = last_beta or 0.0195
return torch.linspace(
init_beta**0.5, last_beta**0.5, num_timesteps,
dtype=torch.float64)**2
elif schedule == 'cosine':
betas = []
for step in range(num_timesteps):
t1 = step / num_timesteps
t2 = (step + 1) / num_timesteps
betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
return torch.tensor(betas, dtype=torch.float64)
else:
raise ValueError(f'Unsupported schedule: {schedule}')
class GaussianDiffusion(object):
def __init__(self,
betas,
mean_type='eps',
var_type='learned_range',
loss_type='mse',
epsilon=1e-12,
rescale_timesteps=False,
noise_strength=0.0):
# check input
if not isinstance(betas, torch.DoubleTensor):
betas = torch.tensor(betas, dtype=torch.float64)
assert min(betas) > 0 and max(betas) <= 1
assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v']
assert var_type in [
'learned', 'learned_range', 'fixed_large', 'fixed_small'
]
assert loss_type in [
'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1',
'charbonnier'
]
self.betas = betas
self.num_timesteps = len(betas)
self.mean_type = mean_type
self.var_type = var_type
self.loss_type = loss_type
self.epsilon = epsilon
self.rescale_timesteps = rescale_timesteps
self.noise_strength = noise_strength
# alphas
alphas = 1 - self.betas
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
self.alphas_cumprod_prev = torch.cat(
[alphas.new_ones([1]), self.alphas_cumprod[:-1]])
self.alphas_cumprod_next = torch.cat(
[self.alphas_cumprod[1:],
alphas.new_zeros([1])])
# q(x_t | x_{t-1})
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
- self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = torch.log(1.0
- self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
- 1)
# q(x_{t-1} | x_t, x_0)
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
1.0 - self.alphas_cumprod)
self.posterior_log_variance_clipped = torch.log(
self.posterior_variance.clamp(1e-20))
self.posterior_mean_coef1 = betas * torch.sqrt(
self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.posterior_mean_coef2 = (
1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
1.0 - self.alphas_cumprod)
def sample_loss(self, x0, noise=None):
if noise is None:
noise = torch.randn_like(x0)
if self.noise_strength > 0:
b, c, f, _, _ = x0.shape
offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device)
noise = noise + self.noise_strength * offset_noise
return noise
def q_sample(self, x0, t, noise=None):
r"""Sample from q(x_t | x_0).
"""
# noise = torch.randn_like(x0) if noise is None else noise
noise = self.sample_loss(x0, noise)
return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + (
_i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise)
def q_mean_variance(self, x0, t):
r"""Distribution of q(x_t | x_0).
"""
mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
var = _i(1.0 - self.alphas_cumprod, t, x0)
log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
return mu, var, log_var
def q_posterior_mean_variance(self, x0, xt, t):
r"""Distribution of q(x_{t-1} | x_t, x_0).
"""
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
self.posterior_mean_coef2, t, xt) * xt
var = _i(self.posterior_variance, t, xt)
log_var = _i(self.posterior_log_variance_clipped, t, xt)
return mu, var, log_var
@torch.no_grad()
def p_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None):
r"""Sample from p(x_{t-1} | x_t).
- condition_fn: for classifier-based guidance (guided-diffusion).
- guide_scale: for classifier-free guidance (glide/dalle-2).
"""
# predict distribution of p(x_{t-1} | x_t)
mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
clamp, percentile,
guide_scale)
# random sample (with optional conditional function)
noise = torch.randn_like(xt)
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
if condition_fn is not None:
grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
mu = mu.float() + var * grad.float()
xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
return xt_1, x0
@torch.no_grad()
def p_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None):
r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
"""
# prepare input
b = noise.size(0)
xt = noise
# diffusion process
for step in torch.arange(self.num_timesteps).flip(0):
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn, guide_scale)
return xt
def p_mean_variance(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None):
r"""Distribution of p(x_{t-1} | x_t).
"""
# predict distribution
if guide_scale is None:
out = model(xt, self._scale_timesteps(t), **model_kwargs)
else:
# classifier-free guidance
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
dim = y_out.size(1) if self.var_type.startswith(
'fixed') else y_out.size(1) // 2
out = torch.cat(
[
u_out[:, :dim] + guide_scale * # noqa
(y_out[:, :dim] - u_out[:, :dim]),
y_out[:, dim:]
],
dim=1)
# compute variance
if self.var_type == 'learned':
out, log_var = out.chunk(2, dim=1)
var = torch.exp(log_var)
elif self.var_type == 'learned_range':
out, fraction = out.chunk(2, dim=1)
min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
max_log_var = _i(torch.log(self.betas), t, xt)
fraction = (fraction + 1) / 2.0
log_var = fraction * max_log_var + (1 - fraction) * min_log_var
var = torch.exp(log_var)
elif self.var_type == 'fixed_large':
var = _i(
torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
xt)
log_var = torch.log(var)
elif self.var_type == 'fixed_small':
var = _i(self.posterior_variance, t, xt)
log_var = _i(self.posterior_log_variance_clipped, t, xt)
# compute mean and x0
if self.mean_type == 'x_{t-1}':
mu = out
x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - (
_i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
xt) * xt)
elif self.mean_type == 'x0':
x0 = out
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
elif self.mean_type == 'eps':
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out)
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
elif self.mean_type == 'v':
x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - (
_i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out)
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
# restrict the range of x0
if percentile is not None:
assert percentile > 0 and percentile <= 1
s = torch.quantile(
x0.flatten(1).abs(), percentile,
dim=1).clamp_(1.0).view(-1, 1, 1, 1)
x0 = torch.min(s, torch.max(-s, x0)) / s
elif clamp is not None:
x0 = x0.clamp(-clamp, clamp)
return mu, var, log_var, x0
@torch.no_grad()
def ddim_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
ddim_timesteps=20,
eta=0.0):
r"""Sample from p(x_{t-1} | x_t) using DDIM.
- condition_fn: for classifier-based guidance (guided-diffusion).
- guide_scale: for classifier-free guidance (glide/dalle-2).
"""
stride = self.num_timesteps // ddim_timesteps
# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
percentile, guide_scale)
if condition_fn is not None:
# x0 -> eps
alpha = _i(self.alphas_cumprod, t, xt)
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
eps = eps - (1 - alpha).sqrt() * condition_fn(
xt, self._scale_timesteps(t), **model_kwargs)
# eps -> x0
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps)
# derive variables
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
alphas = _i(self.alphas_cumprod, t, xt)
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * # noqa
(1 - alphas / alphas_prev))
# random sample
noise = torch.randn_like(xt)
direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
return xt_1, x0
@torch.no_grad()
def ddim_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
ddim_timesteps=20,
eta=0.0):
# prepare input
b = noise.size(0)
xt = noise
# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
steps = (1 + torch.arange(0, self.num_timesteps,
self.num_timesteps // ddim_timesteps)).clamp(
0, self.num_timesteps - 1).flip(0)
for step in steps:
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn, guide_scale,
ddim_timesteps, eta)
return xt
@torch.no_grad()
def ddim_reverse_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None,
ddim_timesteps=20):
r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
"""
stride = self.num_timesteps // ddim_timesteps
# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
percentile, guide_scale)
# derive variables
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
alphas_next = _i(
torch.cat(
[self.alphas_cumprod,
self.alphas_cumprod.new_zeros([1])]),
(t + stride).clamp(0, self.num_timesteps), xt)
# reverse sample
mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
return mu, x0
@torch.no_grad()
def ddim_reverse_sample_loop(self,
x0,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None,
ddim_timesteps=20):
# prepare input
b = x0.size(0)
xt = x0
# reconstruction steps
steps = torch.arange(0, self.num_timesteps,
self.num_timesteps // ddim_timesteps)
for step in steps:
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
percentile, guide_scale,
ddim_timesteps)
return xt
@torch.no_grad()
def plms_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
plms_timesteps=20):
r"""Sample from p(x_{t-1} | x_t) using PLMS.
- condition_fn: for classifier-based guidance (guided-diffusion).
- guide_scale: for classifier-free guidance (glide/dalle-2).
"""
stride = self.num_timesteps // plms_timesteps
# function for compute eps
def compute_eps(xt, t):
# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
clamp, percentile, guide_scale)
# condition
if condition_fn is not None:
# x0 -> eps
alpha = _i(self.alphas_cumprod, t, xt)
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
eps = eps - (1 - alpha).sqrt() * condition_fn(
xt, self._scale_timesteps(t), **model_kwargs)
# eps -> x0
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps)
# derive eps
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
return eps
# function for compute x_0 and x_{t-1}
def compute_x0(eps, t):
# eps -> x0
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps)
# deterministic sample
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
direction = torch.sqrt(1 - alphas_prev) * eps
xt_1 = torch.sqrt(alphas_prev) * x0 + direction
return xt_1, x0
# PLMS sample
eps = compute_eps(xt, t)
if len(eps_cache) == 0:
# 2nd order pseudo improved Euler
xt_1, x0 = compute_x0(eps, t)
eps_next = compute_eps(xt_1, (t - stride).clamp(0))
eps_prime = (eps + eps_next) / 2.0
elif len(eps_cache) == 1:
# 2nd order pseudo linear multistep (Adams-Bashforth)
eps_prime = (3 * eps - eps_cache[-1]) / 2.0
elif len(eps_cache) == 2:
# 3nd order pseudo linear multistep (Adams-Bashforth)
eps_prime = (23 * eps - 16 * eps_cache[-1]
+ 5 * eps_cache[-2]) / 12.0
elif len(eps_cache) >= 3:
# 4nd order pseudo linear multistep (Adams-Bashforth)
eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
- 9 * eps_cache[-3]) / 24.0
xt_1, x0 = compute_x0(eps_prime, t)
return xt_1, x0, eps
@torch.no_grad()
def plms_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
plms_timesteps=20):
# prepare input
b = noise.size(0)
xt = noise
# diffusion process
steps = (1 + torch.arange(0, self.num_timesteps,
self.num_timesteps // plms_timesteps)).clamp(
0, self.num_timesteps - 1).flip(0)
eps_cache = []
for step in steps:
# PLMS sampling step
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn,
guide_scale, plms_timesteps,
eps_cache)
# update eps cache
eps_cache.append(eps)
if len(eps_cache) >= 4:
eps_cache.pop(0)
return xt
def _scale_timesteps(self, t):
if self.rescale_timesteps:
return t.float() * 1000.0 / self.num_timesteps
return t

View File

@@ -0,0 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
import numpy as np
import torch
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True

View File

@@ -0,0 +1,60 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import torch
def fn(u):
return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2
def beta_schedule(schedule,
num_timesteps=1000,
init_beta=None,
last_beta=None):
'''
This code defines a function beta_schedule that generates a sequence of beta values based on the given input
parameters. These beta values can be used in video diffusion processes. The function has the following parameters:
schedule(str): Determines the type of beta schedule to be generated. It can be 'linear', 'linear_sd',
'quadratic', or 'cosine'.
num_timesteps(int, optional): The number of timesteps for the generated beta schedule. Default is 1000.
init_beta(float, optional): The initial beta value. If not provided, a default value is used based on the
chosen schedule.
last_beta(float, optional): The final beta value. If not provided, a default value is used based on the
chosen schedule.
The function returns a PyTorch tensor containing the generated beta values.
The beta schedule is determined by the schedule parameter:
1.Linear: Generates a linear sequence of beta values betweeninit_betaandlast_beta.
2.Linear_sd: Generates a linear sequence of beta values between the square root of init_beta and the square root
oflast_beta, and then squares the result.
3.Quadratic: Similar to the 'linear_sd' schedule, but with different default values forinit_betaandlast_beta.
4.Cosine: Generates a sequence of beta values based on a cosine function, ensuring the values are between 0
and 0.999.
If an unsupported schedule is provided, a ValueError is raised with a message indicating the issue.
'''
if schedule == 'linear':
scale = 1000.0 / num_timesteps
init_beta = init_beta or scale * 0.0001
last_beta = last_beta or scale * 0.02
return torch.linspace(
init_beta, last_beta, num_timesteps, dtype=torch.float64)
elif schedule == 'linear_sd':
return torch.linspace(
init_beta**0.5, last_beta**0.5, num_timesteps,
dtype=torch.float64)**2
elif schedule == 'quadratic':
init_beta = init_beta or 0.0015
last_beta = last_beta or 0.0195
return torch.linspace(
init_beta**0.5, last_beta**0.5, num_timesteps,
dtype=torch.float64)**2
elif schedule == 'cosine':
betas = []
for step in range(num_timesteps):
t1 = step / num_timesteps
t2 = (step + 1) / num_timesteps
betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
return torch.tensor(betas, dtype=torch.float64)
else:
raise ValueError(f'Unsupported schedule: {schedule}')

View File

@@ -0,0 +1,404 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import random
import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image, ImageFilter
__all__ = [
'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2',
'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',
'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize',
'ResizeRandomCrop', 'ExtractResizeRandomCrop', 'ExtractResizeAssignCrop'
]
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __getitem__(self, index):
if isinstance(index, slice):
return Compose(self.transforms[index])
else:
return self.transforms[index]
def __len__(self):
return len(self.transforms)
def __call__(self, rgb):
for t in self.transforms:
rgb = t(rgb)
return rgb
class Resize(object):
def __init__(self, size=256):
if isinstance(size, int):
size = (size, size)
self.size = size
def __call__(self, rgb):
if isinstance(rgb, list):
rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb]
else:
rgb = rgb.resize(self.size, Image.BILINEAR)
return rgb
class Rescale(object):
def __init__(self, size=256, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
def __call__(self, rgb):
w, h = rgb[0].size
scale = self.size / min(w, h)
out_w, out_h = int(round(w * scale)), int(round(h * scale))
rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb]
return rgb
class CenterCrop(object):
def __init__(self, size=224):
self.size = size
def __call__(self, rgb):
w, h = rgb[0].size
assert min(w, h) >= self.size
x1 = (w - self.size) // 2
y1 = (h - self.size) // 2
rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb]
return rgb
class ResizeRandomCrop(object):
def __init__(self, size=256, size_short=292):
self.size = size
self.size_short = size_short
def __call__(self, rgb):
# consistent crop between rgb and m
while min(rgb[0].size) >= 2 * self.size_short:
rgb = [
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
for u in rgb
]
scale = self.size_short / min(rgb[0].size)
rgb = [
u.resize((round(scale * u.width), round(scale * u.height)),
resample=Image.BICUBIC) for u in rgb
]
out_w = self.size
out_h = self.size
w, h = rgb[0].size
x1 = random.randint(0, w - out_w)
y1 = random.randint(0, h - out_h)
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
return rgb
class ExtractResizeRandomCrop(object):
def __init__(self, size=256, size_short=292):
self.size = size
self.size_short = size_short
def __call__(self, rgb):
# consistent crop between rgb and m
while min(rgb[0].size) >= 2 * self.size_short:
rgb = [
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
for u in rgb
]
scale = self.size_short / min(rgb[0].size)
rgb = [
u.resize((round(scale * u.width), round(scale * u.height)),
resample=Image.BICUBIC) for u in rgb
]
out_w = self.size
out_h = self.size
w, h = rgb[0].size
x1 = random.randint(0, w - out_w)
y1 = random.randint(0, h - out_h)
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
wh = [x1, y1, x1 + out_w, y1 + out_h]
return rgb, wh
class ExtractResizeAssignCrop(object):
def __init__(self, size=256, size_short=292):
self.size = size
self.size_short = size_short
def __call__(self, rgb, wh):
# consistent crop between rgb and m
while min(rgb[0].size) >= 2 * self.size_short:
rgb = [
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
for u in rgb
]
scale = self.size_short / min(rgb[0].size)
rgb = [
u.resize((round(scale * u.width), round(scale * u.height)),
resample=Image.BICUBIC) for u in rgb
]
rgb = [u.crop(wh) for u in rgb]
rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
return rgb
class CenterCropV2(object):
def __init__(self, size):
self.size = size
def __call__(self, img):
# fast resize
while min(img[0].size) >= 2 * self.size:
img = [
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
for u in img
]
scale = self.size / min(img[0].size)
img = [
u.resize((round(scale * u.width), round(scale * u.height)),
resample=Image.BICUBIC) for u in img
]
# center crop
x1 = (img[0].width - self.size) // 2
y1 = (img[0].height - self.size) // 2
img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
return img
class CenterCropWide(object):
def __init__(self, size):
self.size = size
def __call__(self, img):
if isinstance(img, list):
scale = min(img[0].size[0] / self.size[0],
img[0].size[1] / self.size[1])
img = [
u.resize((round(u.width // scale), round(u.height // scale)),
resample=Image.BOX) for u in img
]
# center crop
x1 = (img[0].width - self.size[0]) // 2
y1 = (img[0].height - self.size[1]) // 2
img = [
u.crop((x1, y1, x1 + self.size[0], y1 + self.size[1]))
for u in img
]
return img
else:
scale = min(img.size[0] / self.size[0], img.size[1] / self.size[1])
img = img.resize(
(round(img.width // scale), round(img.height // scale)),
resample=Image.BOX)
x1 = (img.width - self.size[0]) // 2
y1 = (img.height - self.size[1]) // 2
img = img.crop((x1, y1, x1 + self.size[0], y1 + self.size[1]))
return img
class RandomCrop(object):
def __init__(self, size=224, min_area=0.4):
self.size = size
self.min_area = min_area
def __call__(self, rgb):
# consistent crop between rgb and m
w, h = rgb[0].size
area = w * h
out_w, out_h = float('inf'), float('inf')
while out_w > w or out_h > h:
target_area = random.uniform(self.min_area, 1.0) * area
aspect_ratio = random.uniform(3. / 4., 4. / 3.)
out_w = int(round(math.sqrt(target_area * aspect_ratio)))
out_h = int(round(math.sqrt(target_area / aspect_ratio)))
x1 = random.randint(0, w - out_w)
y1 = random.randint(0, h - out_h)
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
return rgb
class RandomCropV2(object):
def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
if isinstance(size, (tuple, list)):
self.size = size
else:
self.size = (size, size)
self.min_area = min_area
self.ratio = ratio
def _get_params(self, img):
width, height = img.size
area = height * width
for _ in range(10):
target_area = random.uniform(self.min_area, 1.0) * area
log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
i = random.randint(0, height - h)
j = random.randint(0, width - w)
return i, j, h, w
# Fallback to central crop
in_ratio = float(width) / float(height)
if (in_ratio < min(self.ratio)):
w = width
h = int(round(w / min(self.ratio)))
elif (in_ratio > max(self.ratio)):
h = height
w = int(round(h * max(self.ratio)))
else:
w = width
h = height
i = (height - h) // 2
j = (width - w) // 2
return i, j, h, w
def __call__(self, rgb):
i, j, h, w = self._get_params(rgb[0])
rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb]
return rgb
class RandomHFlip(object):
def __init__(self, p=0.5):
self.p = p
def __call__(self, rgb):
if random.random() < self.p:
rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb]
return rgb
class GaussianBlur(object):
def __init__(self, sigmas=[0.1, 2.0], p=0.5):
self.sigmas = sigmas
self.p = p
def __call__(self, rgb):
if random.random() < self.p:
sigma = random.uniform(*self.sigmas)
rgb = [
u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb
]
return rgb
class ColorJitter(object):
def __init__(self,
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.1,
p=0.5):
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.hue = hue
self.p = p
def __call__(self, rgb):
if random.random() < self.p:
brightness, contrast, saturation, hue = self._random_params()
transforms = [
lambda f: F.adjust_brightness(f, brightness),
lambda f: F.adjust_contrast(f, contrast),
lambda f: F.adjust_saturation(f, saturation),
lambda f: F.adjust_hue(f, hue)
]
random.shuffle(transforms)
for t in transforms:
rgb = [t(u) for u in rgb]
return rgb
def _random_params(self):
brightness = random.uniform(
max(0, 1 - self.brightness), 1 + self.brightness)
contrast = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
saturation = random.uniform(
max(0, 1 - self.saturation), 1 + self.saturation)
hue = random.uniform(-self.hue, self.hue)
return brightness, contrast, saturation, hue
class RandomGray(object):
def __init__(self, p=0.2):
self.p = p
def __call__(self, rgb):
if random.random() < self.p:
rgb = [u.convert('L').convert('RGB') for u in rgb]
return rgb
class ToTensor(object):
def __call__(self, rgb):
if isinstance(rgb, list):
rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0)
else:
rgb = F.to_tensor(rgb)
return rgb
class Normalize(object):
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
self.mean = mean
self.std = std
def __call__(self, rgb):
rgb = rgb.clone()
rgb.clamp_(0, 1)
if not isinstance(self.mean, torch.Tensor):
self.mean = rgb.new_tensor(self.mean).view(-1)
if not isinstance(self.std, torch.Tensor):
self.std = rgb.new_tensor(self.std).view(-1)
if rgb.dim() == 4:
rgb.sub_(self.mean.view(1, -1, 1,
1)).div_(self.std.view(1, -1, 1, 1))
elif rgb.dim() == 3:
rgb.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1))
return rgb

View File

@@ -0,0 +1,104 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
from typing import Any, Dict, Optional
import cv2
import torch
from einops import rearrange
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.image_to_video, module_name=Pipelines.image_to_video_task_pipeline)
class ImageToVideoPipeline(Pipeline):
r""" Image To Video Pipeline.
Examples:
>>> from modelscope.pipelines import pipeline
>>> from modelscope.outputs import OutputKeys
>>> p = pipeline('image-to-video', 'damo/Image-to-Video')
>>> input = 'path_to_image'
>>> p(input,)
>>> {OutputKeys.OUTPUT_VIDEO: path-to-the-generated-video}
>>>
"""
def __init__(self, model: str, **kwargs):
"""
use `model` to create a kws pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
img_path = input
image = LoadImage.convert_to_img(img_path)
if image.mode != 'RGB':
image = image.convert('RGB')
vit_frame = self.model.vid_trans(image)
vit_frame = vit_frame.unsqueeze(0)
vit_frame = vit_frame.to(self.model.device)
return {'vit_frame': vit_frame}
def forward(self, input: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
video = self.model(input)
return {'video': video}
def postprocess(self, inputs: Dict[str, Any],
**post_params) -> Dict[str, Any]:
video = tensor2vid(inputs['video'], self.model.cfg.mean,
self.model.cfg.std)
output_video_path = post_params.get('output_video', None)
temp_video_file = False
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
temp_video_file = True
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
h, w, c = video[0].shape
video_writer = cv2.VideoWriter(
output_video_path, fourcc, fps=8, frameSize=(w, h))
for i in range(len(video)):
img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
video_writer.write(img)
video_writer.release()
if temp_video_file:
video_file_content = b''
with open(output_video_path, 'rb') as f:
video_file_content = f.read()
os.remove(output_video_path)
return {OutputKeys.OUTPUT_VIDEO: video_file_content}
else:
return {OutputKeys.OUTPUT_VIDEO: output_video_path}
def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
video = video.mul_(std).add_(mean)
video.clamp_(0, 1)
video = video * 255.0
images = rearrange(video, 'b c f h w -> b f h w c')[0]
images = [(img.numpy()).astype('uint8') for img in images]
return images

View File

@@ -256,6 +256,7 @@ class MultiModalTasks(object):
text_to_video_synthesis = 'text-to-video-synthesis'
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
multimodal_dialogue = 'multimodal-dialogue'
image_to_video = 'image-to-video'
class ScienceTasks(object):

View File

@@ -0,0 +1,28 @@
import sys
import unittest
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import DownloadMode, Tasks
from modelscope.utils.test_utils import test_level
class Image2VideoTest(unittest.TestCase):
def setUp(self) -> None:
self.task = Tasks.image_to_video
self.model_id = 'damo/Image-to-Video'
self.path = 'https://video-generation-wulanchabu.oss-cn-wulanchabu.aliyuncs.com/baishao/test.jpeg'
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
pipe = pipeline(task=self.task, model=self.model_id)
output_video_path = pipe(
self.path, output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO]
print(output_video_path)
if __name__ == '__main__':
unittest.main()