mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 13:15:06 +02:00
@@ -220,6 +220,7 @@ class Models(object):
|
||||
stable_diffusion = 'stable-diffusion'
|
||||
videocomposer = 'videocomposer'
|
||||
text_to_360panorama_image = 'text-to-360panorama-image'
|
||||
image_to_video_model = 'image-to-video-model'
|
||||
|
||||
# science models
|
||||
unifold = 'unifold'
|
||||
@@ -545,6 +546,7 @@ class Pipelines(object):
|
||||
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
|
||||
multimodal_dialogue = 'multimodal-dialogue'
|
||||
llama2_text_generation_pipeline = 'llama2-text-generation-pipeline'
|
||||
image_to_video_task_pipeline = 'image-to-video-task-pipeline'
|
||||
|
||||
# science tasks
|
||||
protein_structure = 'unifold-protein-structure'
|
||||
|
||||
24
modelscope/models/multi_modal/image_to_video/__init__.py
Normal file
24
modelscope/models/multi_modal/image_to_video/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from .image_to_video_model import ImageToVideo
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'image_to_video_model': ['ImageToVideo'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
215
modelscope/models/multi_modal/image_to_video/image_to_video_model.py
Executable file
215
modelscope/models/multi_modal/image_to_video/image_to_video_model.py
Executable file
@@ -0,0 +1,215 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
from copy import copy
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
|
||||
import modelscope.models.multi_modal.image_to_video.utils.transforms as data
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.multi_modal.image_to_video.modules import *
|
||||
from modelscope.models.multi_modal.image_to_video.modules import (
|
||||
AutoencoderKL, FrozenOpenCLIPVisualEmbedder, Img2VidSDUNet)
|
||||
from modelscope.models.multi_modal.image_to_video.utils.config import cfg
|
||||
from modelscope.models.multi_modal.image_to_video.utils.diffusion import \
|
||||
GaussianDiffusion
|
||||
from modelscope.models.multi_modal.image_to_video.utils.seed import setup_seed
|
||||
from modelscope.models.multi_modal.image_to_video.utils.shedule import \
|
||||
beta_schedule
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
__all__ = ['ImageToVideo']
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.image_to_video, module_name=Models.image_to_video_model)
|
||||
class ImageToVideo(TorchModel):
|
||||
r"""
|
||||
Image2Video aims to solve the task of generating high-definition videos based on input images.
|
||||
Image2Video is a video generation basic model developed by Alibaba Cloud, with a parameter size
|
||||
of approximately 2 billion. It has been pre trained on large-scale video and image data and
|
||||
fine-tuned on a small amount of high-quality data. The data is widely distributed and diverse
|
||||
in categories, and the model has good generalization ability for different types of data
|
||||
|
||||
Paper link: https://arxiv.org/abs/2306.02018
|
||||
|
||||
Attributes:
|
||||
diffusion: diffusion model for DDIM.
|
||||
autoencoder: decode the latent representation into visual space.
|
||||
clip_encoder: encode the image into image embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
r"""
|
||||
Args:
|
||||
model_dir (`str` or `os.PathLike`)
|
||||
Can be either:
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co
|
||||
or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`,
|
||||
or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- A path to a *directory* containing model weights saved using
|
||||
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
- A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
|
||||
this case, `from_tf` should be set to `True` and a configuration object should be provided as
|
||||
`config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
|
||||
PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
- A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
|
||||
`./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
|
||||
`True`.
|
||||
"""
|
||||
super().__init__(model_dir=model_dir, *args, **kwargs)
|
||||
|
||||
self.config = Config.from_file(
|
||||
osp.join(model_dir, ModelFile.CONFIGURATION))
|
||||
|
||||
# assign default value
|
||||
cfg.batch_size = self.config.model.model_cfg.batch_size
|
||||
cfg.target_fps = self.config.model.model_cfg.target_fps
|
||||
cfg.max_frames = self.config.model.model_cfg.max_frames
|
||||
cfg.latent_hei = self.config.model.model_cfg.latent_hei
|
||||
cfg.latent_wid = self.config.model.model_cfg.latent_wid
|
||||
cfg.model_path = osp.join(model_dir,
|
||||
self.config.model.model_args.ckpt_unet)
|
||||
|
||||
self.device = torch.device(
|
||||
'cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
if 'seed' in self.config.model.model_args.keys():
|
||||
cfg.seed = self.config.model.model_args.seed
|
||||
else:
|
||||
cfg.seed = random.randint(0, 99999)
|
||||
setup_seed(cfg.seed)
|
||||
|
||||
# transform
|
||||
vid_trans = data.Compose([
|
||||
data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])),
|
||||
data.Resize(cfg.vit_resolution),
|
||||
data.ToTensor(),
|
||||
data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)
|
||||
])
|
||||
self.vid_trans = vid_trans
|
||||
|
||||
cfg.embedder.pretrained = osp.join(
|
||||
model_dir, self.config.model.model_args.ckpt_clip)
|
||||
clip_encoder = FrozenOpenCLIPVisualEmbedder(**cfg.embedder)
|
||||
clip_encoder.model.to(self.device)
|
||||
self.clip_encoder = clip_encoder
|
||||
logger.info(f'Build encoder with {cfg.embedder.type}')
|
||||
|
||||
# [unet]
|
||||
generator = Img2VidSDUNet(**cfg.UNet)
|
||||
generator = generator.to(self.device)
|
||||
generator.eval()
|
||||
load_dict = torch.load(cfg.model_path, map_location='cpu')
|
||||
ret = generator.load_state_dict(load_dict['state_dict'], strict=True)
|
||||
self.generator = generator
|
||||
logger.info('Load model {} path {}, with local status {}'.format(
|
||||
cfg.UNet.type, cfg.model_path, ret))
|
||||
|
||||
# [diffusion]
|
||||
betas = beta_schedule(
|
||||
'linear_sd',
|
||||
cfg.num_timesteps,
|
||||
init_beta=0.00085,
|
||||
last_beta=0.0120)
|
||||
diffusion = GaussianDiffusion(
|
||||
betas=betas,
|
||||
mean_type=cfg.mean_type,
|
||||
var_type=cfg.var_type,
|
||||
loss_type=cfg.loss_type,
|
||||
rescale_timesteps=False,
|
||||
noise_strength=getattr(cfg, 'noise_strength', 0))
|
||||
self.diffusion = diffusion
|
||||
logger.info('Build diffusion with type of GaussianDiffusion')
|
||||
|
||||
# [auotoencoder]
|
||||
cfg.auto_encoder.pretrained = osp.join(
|
||||
model_dir, self.config.model.model_args.ckpt_autoencoder)
|
||||
autoencoder = AutoencoderKL(**cfg.auto_encoder)
|
||||
autoencoder.eval()
|
||||
for param in autoencoder.parameters():
|
||||
param.requires_grad = False
|
||||
autoencoder.to(self.device)
|
||||
self.autoencoder = autoencoder
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
zero_feature = torch.zeros(1, 1, cfg.UNet.input_dim).to(self.device)
|
||||
self.zero_feature = zero_feature
|
||||
self.fps_tensor = torch.tensor([cfg.target_fps],
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
self.cfg = cfg
|
||||
|
||||
def forward(self, input: Dict[str, Any]):
|
||||
r"""
|
||||
The entry function of image to video task.
|
||||
1. Using diffusion model to generate the video's latent representation.
|
||||
2. Using autoencoder to decode the video's latent representation to visual space.
|
||||
|
||||
Args:
|
||||
input (`Dict[Str, Any]`):
|
||||
The input of the task
|
||||
Returns:
|
||||
A generated video (as pytorch tensor).
|
||||
"""
|
||||
|
||||
vit_frame = input['vit_frame']
|
||||
cfg = self.cfg
|
||||
|
||||
img_embedding = self.clip_encoder(vit_frame).unsqueeze(1)
|
||||
|
||||
noise = self.build_noise()
|
||||
zero_feature = copy(self.zero_feature)
|
||||
with torch.no_grad():
|
||||
with amp.autocast(enabled=cfg.use_fp16):
|
||||
model_kwargs = [{
|
||||
'y': img_embedding,
|
||||
'fps': self.fps_tensor
|
||||
}, {
|
||||
'y': zero_feature.repeat(cfg.batch_size, 1, 1),
|
||||
'fps': self.fps_tensor
|
||||
}]
|
||||
gen_video = self.diffusion.ddim_sample_loop(
|
||||
noise=noise,
|
||||
model=self.generator,
|
||||
model_kwargs=model_kwargs,
|
||||
guide_scale=cfg.guide_scale,
|
||||
ddim_timesteps=cfg.ddim_timesteps,
|
||||
eta=0.0)
|
||||
|
||||
gen_video = 1. / cfg.scale_factor * gen_video
|
||||
gen_video = rearrange(gen_video, 'b c f h w -> (b f) c h w')
|
||||
chunk_size = min(cfg.decoder_bs, gen_video.shape[0])
|
||||
gen_video_list = torch.chunk(
|
||||
gen_video, gen_video.shape[0] // chunk_size, dim=0)
|
||||
decode_generator = []
|
||||
for vd_data in gen_video_list:
|
||||
gen_frames = self.autoencoder.decode(vd_data)
|
||||
decode_generator.append(gen_frames)
|
||||
|
||||
gen_video = torch.cat(decode_generator, dim=0)
|
||||
gen_video = rearrange(
|
||||
gen_video, '(b f) c h w -> b c f h w', b=cfg.batch_size)
|
||||
|
||||
return gen_video.type(torch.float32).cpu()
|
||||
|
||||
def build_noise(self):
|
||||
cfg = self.cfg
|
||||
noise = torch.randn(
|
||||
[1, 4, cfg.max_frames, cfg.latent_hei,
|
||||
cfg.latent_wid]).to(self.device)
|
||||
if cfg.noise_strength > 0:
|
||||
b, c, f, *_ = noise.shape
|
||||
offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device)
|
||||
noise = noise + cfg.noise_strength * offset_noise
|
||||
return noise.contiguous()
|
||||
5
modelscope/models/multi_modal/image_to_video/modules/__init__.py
Executable file
5
modelscope/models/multi_modal/image_to_video/modules/__init__.py
Executable file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .autoencoder import *
|
||||
from .embedder import *
|
||||
from .unet_i2v import *
|
||||
573
modelscope/models/multi_modal/image_to_video/modules/autoencoder.py
Executable file
573
modelscope/models/multi_modal/image_to_video/modules/autoencoder.py
Executable file
@@ -0,0 +1,573 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import collections
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(
|
||||
self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(
|
||||
self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1, 2, 3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar
|
||||
+ torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1)
|
||||
k = k.reshape(b, c, h * w)
|
||||
w_ = torch.bmm(q, k)
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1)
|
||||
h_ = torch.bmm(v, w_)
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(
|
||||
x, scale_factor=2.0, mode='nearest')
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
use_linear_attn=False,
|
||||
attn_type='vanilla',
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
use_linear_attn=False,
|
||||
attn_type='vanilla',
|
||||
**ignorekwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2**(self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
logging.info('Working with z of shape {} = {} dimensions.'.format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
class AutoencoderKL(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
embed_dim,
|
||||
pretrained=None,
|
||||
ignore_keys=[],
|
||||
image_key='image',
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
ema_decay=None,
|
||||
learn_logvar=False,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.learn_logvar = learn_logvar
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
assert ddconfig['double_z']
|
||||
self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
|
||||
2 * embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
|
||||
ddconfig['z_channels'], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
|
||||
self.use_ema = ema_decay is not None
|
||||
|
||||
if pretrained is not None:
|
||||
self.init_from_ckpt(pretrained, ignore_keys=ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
sd_new = collections.OrderedDict()
|
||||
for k in keys:
|
||||
if k.find('first_stage_model') >= 0:
|
||||
k_new = k.split('first_stage_model.')[-1]
|
||||
sd_new[k_new] = sd[k]
|
||||
self.load_state_dict(sd_new, strict=True)
|
||||
logging.info(f'Restored from {path}')
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1,
|
||||
2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log['reconstructions'] = xrec
|
||||
if log_ema or self.use_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, posterior_ema = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec_ema.shape[1] > 3
|
||||
xrec_ema = self.to_rgb(xrec_ema)
|
||||
log['samples_ema'] = self.decode(
|
||||
torch.randn_like(posterior_ema.sample()))
|
||||
log['reconstructions_ema'] = xrec_ema
|
||||
log['inputs'] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == 'segmentation'
|
||||
if not hasattr(self, 'colorize'):
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
82
modelscope/models/multi_modal/image_to_video/modules/embedder.py
Executable file
82
modelscope/models/multi_modal/image_to_video/modules/embedder.py
Executable file
@@ -0,0 +1,82 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import open_clip
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
class FrozenOpenCLIPVisualEmbedder(nn.Module):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = ['last', 'penultimate']
|
||||
|
||||
def __init__(self,
|
||||
pretrained,
|
||||
vit_resolution=(224, 224),
|
||||
arch='ViT-H-14',
|
||||
device='cuda',
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer='last',
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
arch, device=torch.device('cpu'), pretrained=pretrained)
|
||||
|
||||
del model.transformer
|
||||
self.model = model
|
||||
data_white = np.ones(
|
||||
(vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8) * 255
|
||||
self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0)
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == 'last':
|
||||
self.layer_idx = 0
|
||||
elif self.layer == 'penultimate':
|
||||
self.layer_idx = 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, image):
|
||||
z = self.model.encode_image(image.to(self.device))
|
||||
return z
|
||||
|
||||
def encode_with_transformer(self, text):
|
||||
x = self.model.token_embedding(text)
|
||||
x = x + self.model.positional_embedding
|
||||
x = x.permute(1, 0, 2)
|
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
x = x.permute(1, 0, 2)
|
||||
x = self.model.ln_final(x)
|
||||
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
|
||||
):
|
||||
x = checkpoint(r, x, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
1504
modelscope/models/multi_modal/image_to_video/modules/unet_i2v.py
Normal file
1504
modelscope/models/multi_modal/image_to_video/modules/unet_i2v.py
Normal file
File diff suppressed because it is too large
Load Diff
2
modelscope/models/multi_modal/image_to_video/utils/__init__.py
Executable file
2
modelscope/models/multi_modal/image_to_video/utils/__init__.py
Executable file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
161
modelscope/models/multi_modal/image_to_video/utils/config.py
Executable file
161
modelscope/models/multi_modal/image_to_video/utils/config.py
Executable file
@@ -0,0 +1,161 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
from easydict import EasyDict
|
||||
|
||||
cfg = EasyDict(__name__='Config: VideoLDM Decoder')
|
||||
|
||||
# ---------------------------work dir--------------------------
|
||||
cfg.work_dir = 'workspace/'
|
||||
|
||||
# ---------------------------Global Variable-----------------------------------
|
||||
cfg.resolution = [448, 256]
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------Dataset Parameter---------------------------------
|
||||
cfg.mean = [0.5, 0.5, 0.5]
|
||||
cfg.std = [0.5, 0.5, 0.5]
|
||||
cfg.max_words = 1000
|
||||
|
||||
# PlaceHolder
|
||||
cfg.vit_out_dim = 1024
|
||||
cfg.vit_resolution = [224, 224]
|
||||
cfg.depth_clamp = 10.0
|
||||
cfg.misc_size = 384
|
||||
cfg.depth_std = 20.0
|
||||
|
||||
cfg.frame_lens = 32
|
||||
cfg.sample_fps = 8
|
||||
|
||||
cfg.batch_sizes = 1
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------Mode Parameters-----------------------------------
|
||||
# Diffusion
|
||||
cfg.schedule = 'cosine'
|
||||
cfg.num_timesteps = 1000
|
||||
cfg.mean_type = 'v'
|
||||
cfg.var_type = 'fixed_small'
|
||||
cfg.loss_type = 'mse'
|
||||
cfg.ddim_timesteps = 50
|
||||
cfg.ddim_eta = 0.0
|
||||
cfg.clamp = 1.0
|
||||
cfg.share_noise = False
|
||||
cfg.use_div_loss = False
|
||||
cfg.noise_strength = 0.1
|
||||
|
||||
# classifier-free guidance
|
||||
cfg.p_zero = 0.1
|
||||
cfg.guide_scale = 3.0
|
||||
|
||||
# clip vision encoder
|
||||
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
|
||||
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
|
||||
|
||||
# Model
|
||||
cfg.scale_factor = 0.18215
|
||||
cfg.use_fp16 = True
|
||||
cfg.temporal_attention = True
|
||||
cfg.decoder_bs = 8
|
||||
|
||||
cfg.UNet = {
|
||||
'type': 'Img2VidSDUNet',
|
||||
'in_dim': 4,
|
||||
'dim': 320,
|
||||
'y_dim': cfg.vit_out_dim,
|
||||
'context_dim': 1024,
|
||||
'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
|
||||
'dim_mult': [1, 2, 4, 4],
|
||||
'num_heads': 8,
|
||||
'head_dim': 64,
|
||||
'num_res_blocks': 2,
|
||||
'attn_scales': [1 / 1, 1 / 2, 1 / 4],
|
||||
'dropout': 0.1,
|
||||
'temporal_attention': cfg.temporal_attention,
|
||||
'temporal_attn_times': 1,
|
||||
'use_checkpoint': False,
|
||||
'use_fps_condition': False,
|
||||
'use_sim_mask': False,
|
||||
'num_tokens': 4,
|
||||
'default_fps': 8,
|
||||
'input_dim': 1024
|
||||
}
|
||||
|
||||
cfg.guidances = []
|
||||
|
||||
# auotoencoder from stabel diffusion
|
||||
cfg.auto_encoder = {
|
||||
'type': 'AutoencoderKL',
|
||||
'ddconfig': {
|
||||
'double_z': True,
|
||||
'z_channels': 4,
|
||||
'resolution': 256,
|
||||
'in_channels': 3,
|
||||
'out_ch': 3,
|
||||
'ch': 128,
|
||||
'ch_mult': [1, 2, 4, 4],
|
||||
'num_res_blocks': 2,
|
||||
'attn_resolutions': [],
|
||||
'dropout': 0.0
|
||||
},
|
||||
'embed_dim': 4,
|
||||
'pretrained': 'v2-1_512-ema-pruned.ckpt'
|
||||
}
|
||||
# clip embedder
|
||||
cfg.embedder = {
|
||||
'type': 'FrozenOpenCLIPVisualEmbedder',
|
||||
'layer': 'penultimate',
|
||||
'vit_resolution': [224, 224],
|
||||
'pretrained': 'open_clip_pytorch_model.bin'
|
||||
}
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------Training Settings---------------------------------
|
||||
# training and optimizer
|
||||
cfg.ema_decay = 0.9999
|
||||
cfg.num_steps = 600000
|
||||
cfg.lr = 5e-5
|
||||
cfg.weight_decay = 0.0
|
||||
cfg.betas = (0.9, 0.999)
|
||||
cfg.eps = 1.0e-8
|
||||
cfg.chunk_size = 16
|
||||
cfg.alpha = 0.7
|
||||
cfg.save_ckp_interval = 1000
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ----------------------------Pretrain Settings---------------------------------
|
||||
# Default: load 2d pretrain
|
||||
cfg.fix_weight = False
|
||||
cfg.load_match = False
|
||||
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
|
||||
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
|
||||
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# -----------------------------Visual-------------------------------------------
|
||||
# Visual videos
|
||||
cfg.viz_interval = 1000
|
||||
cfg.visual_train = {
|
||||
'type': 'VisualVideoTextDuringTrain',
|
||||
}
|
||||
cfg.visual_inference = {
|
||||
'type': 'VisualGeneratedVideos',
|
||||
}
|
||||
cfg.inference_list_path = ''
|
||||
|
||||
# logging
|
||||
cfg.log_interval = 100
|
||||
|
||||
# Default log_dir
|
||||
cfg.log_dir = 'workspace/output_data'
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------Others--------------------------------------------
|
||||
# seed
|
||||
cfg.seed = 8888
|
||||
# -----------------------------------------------------------------------------
|
||||
511
modelscope/models/multi_modal/image_to_video/utils/diffusion.py
Executable file
511
modelscope/models/multi_modal/image_to_video/utils/diffusion.py
Executable file
@@ -0,0 +1,511 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ['GaussianDiffusion', 'beta_schedule']
|
||||
|
||||
|
||||
def _i(tensor, t, x):
|
||||
r"""Index tensor using t and format the output according to x.
|
||||
"""
|
||||
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
|
||||
if tensor.device != x.device:
|
||||
tensor = tensor.to(x.device)
|
||||
return tensor[t].view(shape).to(x)
|
||||
|
||||
|
||||
def fn(u):
|
||||
return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2
|
||||
|
||||
|
||||
def beta_schedule(schedule,
|
||||
num_timesteps=1000,
|
||||
init_beta=None,
|
||||
last_beta=None):
|
||||
if schedule == 'linear':
|
||||
scale = 1000.0 / num_timesteps
|
||||
init_beta = init_beta or scale * 0.0001
|
||||
last_beta = last_beta or scale * 0.02
|
||||
return torch.linspace(
|
||||
init_beta, last_beta, num_timesteps, dtype=torch.float64)
|
||||
elif schedule == 'quadratic':
|
||||
init_beta = init_beta or 0.0015
|
||||
last_beta = last_beta or 0.0195
|
||||
return torch.linspace(
|
||||
init_beta**0.5, last_beta**0.5, num_timesteps,
|
||||
dtype=torch.float64)**2
|
||||
elif schedule == 'cosine':
|
||||
betas = []
|
||||
for step in range(num_timesteps):
|
||||
t1 = step / num_timesteps
|
||||
t2 = (step + 1) / num_timesteps
|
||||
betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
|
||||
return torch.tensor(betas, dtype=torch.float64)
|
||||
else:
|
||||
raise ValueError(f'Unsupported schedule: {schedule}')
|
||||
|
||||
|
||||
class GaussianDiffusion(object):
|
||||
|
||||
def __init__(self,
|
||||
betas,
|
||||
mean_type='eps',
|
||||
var_type='learned_range',
|
||||
loss_type='mse',
|
||||
epsilon=1e-12,
|
||||
rescale_timesteps=False,
|
||||
noise_strength=0.0):
|
||||
# check input
|
||||
if not isinstance(betas, torch.DoubleTensor):
|
||||
betas = torch.tensor(betas, dtype=torch.float64)
|
||||
assert min(betas) > 0 and max(betas) <= 1
|
||||
assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v']
|
||||
assert var_type in [
|
||||
'learned', 'learned_range', 'fixed_large', 'fixed_small'
|
||||
]
|
||||
assert loss_type in [
|
||||
'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1',
|
||||
'charbonnier'
|
||||
]
|
||||
self.betas = betas
|
||||
self.num_timesteps = len(betas)
|
||||
self.mean_type = mean_type
|
||||
self.var_type = var_type
|
||||
self.loss_type = loss_type
|
||||
self.epsilon = epsilon
|
||||
self.rescale_timesteps = rescale_timesteps
|
||||
self.noise_strength = noise_strength
|
||||
|
||||
# alphas
|
||||
alphas = 1 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
self.alphas_cumprod_prev = torch.cat(
|
||||
[alphas.new_ones([1]), self.alphas_cumprod[:-1]])
|
||||
self.alphas_cumprod_next = torch.cat(
|
||||
[self.alphas_cumprod[1:],
|
||||
alphas.new_zeros([1])])
|
||||
|
||||
# q(x_t | x_{t-1})
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
|
||||
- self.alphas_cumprod)
|
||||
self.log_one_minus_alphas_cumprod = torch.log(1.0
|
||||
- self.alphas_cumprod)
|
||||
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
|
||||
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
|
||||
- 1)
|
||||
|
||||
# q(x_{t-1} | x_t, x_0)
|
||||
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
|
||||
1.0 - self.alphas_cumprod)
|
||||
self.posterior_log_variance_clipped = torch.log(
|
||||
self.posterior_variance.clamp(1e-20))
|
||||
self.posterior_mean_coef1 = betas * torch.sqrt(
|
||||
self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (
|
||||
1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
|
||||
1.0 - self.alphas_cumprod)
|
||||
|
||||
def sample_loss(self, x0, noise=None):
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
if self.noise_strength > 0:
|
||||
b, c, f, _, _ = x0.shape
|
||||
offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device)
|
||||
noise = noise + self.noise_strength * offset_noise
|
||||
return noise
|
||||
|
||||
def q_sample(self, x0, t, noise=None):
|
||||
r"""Sample from q(x_t | x_0).
|
||||
"""
|
||||
# noise = torch.randn_like(x0) if noise is None else noise
|
||||
noise = self.sample_loss(x0, noise)
|
||||
return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + (
|
||||
_i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise)
|
||||
|
||||
def q_mean_variance(self, x0, t):
|
||||
r"""Distribution of q(x_t | x_0).
|
||||
"""
|
||||
mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
|
||||
var = _i(1.0 - self.alphas_cumprod, t, x0)
|
||||
log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
|
||||
return mu, var, log_var
|
||||
|
||||
def q_posterior_mean_variance(self, x0, xt, t):
|
||||
r"""Distribution of q(x_{t-1} | x_t, x_0).
|
||||
"""
|
||||
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
|
||||
self.posterior_mean_coef2, t, xt) * xt
|
||||
var = _i(self.posterior_variance, t, xt)
|
||||
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
return mu, var, log_var
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None):
|
||||
r"""Sample from p(x_{t-1} | x_t).
|
||||
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||
"""
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||
clamp, percentile,
|
||||
guide_scale)
|
||||
|
||||
# random sample (with optional conditional function)
|
||||
noise = torch.randn_like(xt)
|
||||
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
||||
if condition_fn is not None:
|
||||
grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
|
||||
mu = mu.float() + var * grad.float()
|
||||
xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
|
||||
return xt_1, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None):
|
||||
r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
|
||||
"""
|
||||
# prepare input
|
||||
b = noise.size(0)
|
||||
xt = noise
|
||||
|
||||
# diffusion process
|
||||
for step in torch.arange(self.num_timesteps).flip(0):
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn, guide_scale)
|
||||
return xt
|
||||
|
||||
def p_mean_variance(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None):
|
||||
r"""Distribution of p(x_{t-1} | x_t).
|
||||
"""
|
||||
# predict distribution
|
||||
if guide_scale is None:
|
||||
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
||||
else:
|
||||
# classifier-free guidance
|
||||
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
|
||||
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
|
||||
y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
|
||||
u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
|
||||
dim = y_out.size(1) if self.var_type.startswith(
|
||||
'fixed') else y_out.size(1) // 2
|
||||
out = torch.cat(
|
||||
[
|
||||
u_out[:, :dim] + guide_scale * # noqa
|
||||
(y_out[:, :dim] - u_out[:, :dim]),
|
||||
y_out[:, dim:]
|
||||
],
|
||||
dim=1)
|
||||
|
||||
# compute variance
|
||||
if self.var_type == 'learned':
|
||||
out, log_var = out.chunk(2, dim=1)
|
||||
var = torch.exp(log_var)
|
||||
elif self.var_type == 'learned_range':
|
||||
out, fraction = out.chunk(2, dim=1)
|
||||
min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
max_log_var = _i(torch.log(self.betas), t, xt)
|
||||
fraction = (fraction + 1) / 2.0
|
||||
log_var = fraction * max_log_var + (1 - fraction) * min_log_var
|
||||
var = torch.exp(log_var)
|
||||
elif self.var_type == 'fixed_large':
|
||||
var = _i(
|
||||
torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
|
||||
xt)
|
||||
log_var = torch.log(var)
|
||||
elif self.var_type == 'fixed_small':
|
||||
var = _i(self.posterior_variance, t, xt)
|
||||
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
|
||||
# compute mean and x0
|
||||
if self.mean_type == 'x_{t-1}':
|
||||
mu = out
|
||||
x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - (
|
||||
_i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
|
||||
xt) * xt)
|
||||
elif self.mean_type == 'x0':
|
||||
x0 = out
|
||||
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||
elif self.mean_type == 'eps':
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out)
|
||||
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||
elif self.mean_type == 'v':
|
||||
x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - (
|
||||
_i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out)
|
||||
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||
|
||||
# restrict the range of x0
|
||||
if percentile is not None:
|
||||
assert percentile > 0 and percentile <= 1
|
||||
s = torch.quantile(
|
||||
x0.flatten(1).abs(), percentile,
|
||||
dim=1).clamp_(1.0).view(-1, 1, 1, 1)
|
||||
x0 = torch.min(s, torch.max(-s, x0)) / s
|
||||
elif clamp is not None:
|
||||
x0 = x0.clamp(-clamp, clamp)
|
||||
return mu, var, log_var, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20,
|
||||
eta=0.0):
|
||||
r"""Sample from p(x_{t-1} | x_t) using DDIM.
|
||||
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||
"""
|
||||
stride = self.num_timesteps // ddim_timesteps
|
||||
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale)
|
||||
if condition_fn is not None:
|
||||
# x0 -> eps
|
||||
alpha = _i(self.alphas_cumprod, t, xt)
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
|
||||
eps = eps - (1 - alpha).sqrt() * condition_fn(
|
||||
xt, self._scale_timesteps(t), **model_kwargs)
|
||||
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps)
|
||||
|
||||
# derive variables
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
|
||||
alphas = _i(self.alphas_cumprod, t, xt)
|
||||
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
||||
sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * # noqa
|
||||
(1 - alphas / alphas_prev))
|
||||
|
||||
# random sample
|
||||
noise = torch.randn_like(xt)
|
||||
direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
|
||||
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
||||
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
|
||||
return xt_1, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20,
|
||||
eta=0.0):
|
||||
# prepare input
|
||||
b = noise.size(0)
|
||||
xt = noise
|
||||
|
||||
# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
|
||||
steps = (1 + torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // ddim_timesteps)).clamp(
|
||||
0, self.num_timesteps - 1).flip(0)
|
||||
for step in steps:
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn, guide_scale,
|
||||
ddim_timesteps, eta)
|
||||
return xt
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_reverse_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20):
|
||||
r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
|
||||
"""
|
||||
stride = self.num_timesteps // ddim_timesteps
|
||||
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale)
|
||||
|
||||
# derive variables
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
|
||||
alphas_next = _i(
|
||||
torch.cat(
|
||||
[self.alphas_cumprod,
|
||||
self.alphas_cumprod.new_zeros([1])]),
|
||||
(t + stride).clamp(0, self.num_timesteps), xt)
|
||||
|
||||
# reverse sample
|
||||
mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
|
||||
return mu, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_reverse_sample_loop(self,
|
||||
x0,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20):
|
||||
# prepare input
|
||||
b = x0.size(0)
|
||||
xt = x0
|
||||
|
||||
# reconstruction steps
|
||||
steps = torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // ddim_timesteps)
|
||||
for step in steps:
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale,
|
||||
ddim_timesteps)
|
||||
return xt
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
plms_timesteps=20):
|
||||
r"""Sample from p(x_{t-1} | x_t) using PLMS.
|
||||
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||
"""
|
||||
stride = self.num_timesteps // plms_timesteps
|
||||
|
||||
# function for compute eps
|
||||
def compute_eps(xt, t):
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||
clamp, percentile, guide_scale)
|
||||
|
||||
# condition
|
||||
if condition_fn is not None:
|
||||
# x0 -> eps
|
||||
alpha = _i(self.alphas_cumprod, t, xt)
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
|
||||
eps = eps - (1 - alpha).sqrt() * condition_fn(
|
||||
xt, self._scale_timesteps(t), **model_kwargs)
|
||||
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps)
|
||||
|
||||
# derive eps
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
|
||||
return eps
|
||||
|
||||
# function for compute x_0 and x_{t-1}
|
||||
def compute_x0(eps, t):
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps)
|
||||
|
||||
# deterministic sample
|
||||
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
||||
direction = torch.sqrt(1 - alphas_prev) * eps
|
||||
xt_1 = torch.sqrt(alphas_prev) * x0 + direction
|
||||
return xt_1, x0
|
||||
|
||||
# PLMS sample
|
||||
eps = compute_eps(xt, t)
|
||||
if len(eps_cache) == 0:
|
||||
# 2nd order pseudo improved Euler
|
||||
xt_1, x0 = compute_x0(eps, t)
|
||||
eps_next = compute_eps(xt_1, (t - stride).clamp(0))
|
||||
eps_prime = (eps + eps_next) / 2.0
|
||||
elif len(eps_cache) == 1:
|
||||
# 2nd order pseudo linear multistep (Adams-Bashforth)
|
||||
eps_prime = (3 * eps - eps_cache[-1]) / 2.0
|
||||
elif len(eps_cache) == 2:
|
||||
# 3nd order pseudo linear multistep (Adams-Bashforth)
|
||||
eps_prime = (23 * eps - 16 * eps_cache[-1]
|
||||
+ 5 * eps_cache[-2]) / 12.0
|
||||
elif len(eps_cache) >= 3:
|
||||
# 4nd order pseudo linear multistep (Adams-Bashforth)
|
||||
eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
|
||||
- 9 * eps_cache[-3]) / 24.0
|
||||
xt_1, x0 = compute_x0(eps_prime, t)
|
||||
return xt_1, x0, eps
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
plms_timesteps=20):
|
||||
# prepare input
|
||||
b = noise.size(0)
|
||||
xt = noise
|
||||
|
||||
# diffusion process
|
||||
steps = (1 + torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // plms_timesteps)).clamp(
|
||||
0, self.num_timesteps - 1).flip(0)
|
||||
eps_cache = []
|
||||
for step in steps:
|
||||
# PLMS sampling step
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn,
|
||||
guide_scale, plms_timesteps,
|
||||
eps_cache)
|
||||
|
||||
# update eps cache
|
||||
eps_cache.append(eps)
|
||||
if len(eps_cache) >= 4:
|
||||
eps_cache.pop(0)
|
||||
return xt
|
||||
|
||||
def _scale_timesteps(self, t):
|
||||
if self.rescale_timesteps:
|
||||
return t.float() * 1000.0 / self.num_timesteps
|
||||
return t
|
||||
14
modelscope/models/multi_modal/image_to_video/utils/seed.py
Executable file
14
modelscope/models/multi_modal/image_to_video/utils/seed.py
Executable file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def fn(u):
|
||||
return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2
|
||||
|
||||
|
||||
def beta_schedule(schedule,
|
||||
num_timesteps=1000,
|
||||
init_beta=None,
|
||||
last_beta=None):
|
||||
'''
|
||||
This code defines a function beta_schedule that generates a sequence of beta values based on the given input
|
||||
parameters. These beta values can be used in video diffusion processes. The function has the following parameters:
|
||||
schedule(str): Determines the type of beta schedule to be generated. It can be 'linear', 'linear_sd',
|
||||
'quadratic', or 'cosine'.
|
||||
num_timesteps(int, optional): The number of timesteps for the generated beta schedule. Default is 1000.
|
||||
init_beta(float, optional): The initial beta value. If not provided, a default value is used based on the
|
||||
chosen schedule.
|
||||
last_beta(float, optional): The final beta value. If not provided, a default value is used based on the
|
||||
chosen schedule.
|
||||
The function returns a PyTorch tensor containing the generated beta values.
|
||||
The beta schedule is determined by the schedule parameter:
|
||||
1.Linear: Generates a linear sequence of beta values betweeninit_betaandlast_beta.
|
||||
2.Linear_sd: Generates a linear sequence of beta values between the square root of init_beta and the square root
|
||||
oflast_beta, and then squares the result.
|
||||
3.Quadratic: Similar to the 'linear_sd' schedule, but with different default values forinit_betaandlast_beta.
|
||||
4.Cosine: Generates a sequence of beta values based on a cosine function, ensuring the values are between 0
|
||||
and 0.999.
|
||||
If an unsupported schedule is provided, a ValueError is raised with a message indicating the issue.
|
||||
'''
|
||||
if schedule == 'linear':
|
||||
scale = 1000.0 / num_timesteps
|
||||
init_beta = init_beta or scale * 0.0001
|
||||
last_beta = last_beta or scale * 0.02
|
||||
return torch.linspace(
|
||||
init_beta, last_beta, num_timesteps, dtype=torch.float64)
|
||||
elif schedule == 'linear_sd':
|
||||
return torch.linspace(
|
||||
init_beta**0.5, last_beta**0.5, num_timesteps,
|
||||
dtype=torch.float64)**2
|
||||
elif schedule == 'quadratic':
|
||||
init_beta = init_beta or 0.0015
|
||||
last_beta = last_beta or 0.0195
|
||||
return torch.linspace(
|
||||
init_beta**0.5, last_beta**0.5, num_timesteps,
|
||||
dtype=torch.float64)**2
|
||||
elif schedule == 'cosine':
|
||||
betas = []
|
||||
for step in range(num_timesteps):
|
||||
t1 = step / num_timesteps
|
||||
t2 = (step + 1) / num_timesteps
|
||||
betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
|
||||
return torch.tensor(betas, dtype=torch.float64)
|
||||
else:
|
||||
raise ValueError(f'Unsupported schedule: {schedule}')
|
||||
404
modelscope/models/multi_modal/image_to_video/utils/transforms.py
Executable file
404
modelscope/models/multi_modal/image_to_video/utils/transforms.py
Executable file
@@ -0,0 +1,404 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
__all__ = [
|
||||
'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2',
|
||||
'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',
|
||||
'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize',
|
||||
'ResizeRandomCrop', 'ExtractResizeRandomCrop', 'ExtractResizeAssignCrop'
|
||||
]
|
||||
|
||||
|
||||
class Compose(object):
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __getitem__(self, index):
|
||||
if isinstance(index, slice):
|
||||
return Compose(self.transforms[index])
|
||||
else:
|
||||
return self.transforms[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.transforms)
|
||||
|
||||
def __call__(self, rgb):
|
||||
for t in self.transforms:
|
||||
rgb = t(rgb)
|
||||
return rgb
|
||||
|
||||
|
||||
class Resize(object):
|
||||
|
||||
def __init__(self, size=256):
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
self.size = size
|
||||
|
||||
def __call__(self, rgb):
|
||||
if isinstance(rgb, list):
|
||||
rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb]
|
||||
else:
|
||||
rgb = rgb.resize(self.size, Image.BILINEAR)
|
||||
return rgb
|
||||
|
||||
|
||||
class Rescale(object):
|
||||
|
||||
def __init__(self, size=256, interpolation=Image.BILINEAR):
|
||||
self.size = size
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, rgb):
|
||||
w, h = rgb[0].size
|
||||
scale = self.size / min(w, h)
|
||||
out_w, out_h = int(round(w * scale)), int(round(h * scale))
|
||||
rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class CenterCrop(object):
|
||||
|
||||
def __init__(self, size=224):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, rgb):
|
||||
w, h = rgb[0].size
|
||||
assert min(w, h) >= self.size
|
||||
x1 = (w - self.size) // 2
|
||||
y1 = (h - self.size) // 2
|
||||
rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class ResizeRandomCrop(object):
|
||||
|
||||
def __init__(self, size=256, size_short=292):
|
||||
self.size = size
|
||||
self.size_short = size_short
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
while min(rgb[0].size) >= 2 * self.size_short:
|
||||
rgb = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in rgb
|
||||
]
|
||||
scale = self.size_short / min(rgb[0].size)
|
||||
rgb = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in rgb
|
||||
]
|
||||
out_w = self.size
|
||||
out_h = self.size
|
||||
w, h = rgb[0].size
|
||||
x1 = random.randint(0, w - out_w)
|
||||
y1 = random.randint(0, h - out_h)
|
||||
|
||||
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class ExtractResizeRandomCrop(object):
|
||||
|
||||
def __init__(self, size=256, size_short=292):
|
||||
self.size = size
|
||||
self.size_short = size_short
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
while min(rgb[0].size) >= 2 * self.size_short:
|
||||
rgb = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in rgb
|
||||
]
|
||||
scale = self.size_short / min(rgb[0].size)
|
||||
rgb = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in rgb
|
||||
]
|
||||
out_w = self.size
|
||||
out_h = self.size
|
||||
w, h = rgb[0].size
|
||||
x1 = random.randint(0, w - out_w)
|
||||
y1 = random.randint(0, h - out_h)
|
||||
|
||||
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
|
||||
wh = [x1, y1, x1 + out_w, y1 + out_h]
|
||||
return rgb, wh
|
||||
|
||||
|
||||
class ExtractResizeAssignCrop(object):
|
||||
|
||||
def __init__(self, size=256, size_short=292):
|
||||
self.size = size
|
||||
self.size_short = size_short
|
||||
|
||||
def __call__(self, rgb, wh):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
while min(rgb[0].size) >= 2 * self.size_short:
|
||||
rgb = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in rgb
|
||||
]
|
||||
scale = self.size_short / min(rgb[0].size)
|
||||
rgb = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in rgb
|
||||
]
|
||||
|
||||
rgb = [u.crop(wh) for u in rgb]
|
||||
rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class CenterCropV2(object):
|
||||
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img):
|
||||
# fast resize
|
||||
while min(img[0].size) >= 2 * self.size:
|
||||
img = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in img
|
||||
]
|
||||
scale = self.size / min(img[0].size)
|
||||
img = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in img
|
||||
]
|
||||
|
||||
# center crop
|
||||
x1 = (img[0].width - self.size) // 2
|
||||
y1 = (img[0].height - self.size) // 2
|
||||
img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
|
||||
return img
|
||||
|
||||
|
||||
class CenterCropWide(object):
|
||||
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img):
|
||||
if isinstance(img, list):
|
||||
scale = min(img[0].size[0] / self.size[0],
|
||||
img[0].size[1] / self.size[1])
|
||||
img = [
|
||||
u.resize((round(u.width // scale), round(u.height // scale)),
|
||||
resample=Image.BOX) for u in img
|
||||
]
|
||||
|
||||
# center crop
|
||||
x1 = (img[0].width - self.size[0]) // 2
|
||||
y1 = (img[0].height - self.size[1]) // 2
|
||||
img = [
|
||||
u.crop((x1, y1, x1 + self.size[0], y1 + self.size[1]))
|
||||
for u in img
|
||||
]
|
||||
return img
|
||||
else:
|
||||
scale = min(img.size[0] / self.size[0], img.size[1] / self.size[1])
|
||||
img = img.resize(
|
||||
(round(img.width // scale), round(img.height // scale)),
|
||||
resample=Image.BOX)
|
||||
x1 = (img.width - self.size[0]) // 2
|
||||
y1 = (img.height - self.size[1]) // 2
|
||||
img = img.crop((x1, y1, x1 + self.size[0], y1 + self.size[1]))
|
||||
return img
|
||||
|
||||
|
||||
class RandomCrop(object):
|
||||
|
||||
def __init__(self, size=224, min_area=0.4):
|
||||
self.size = size
|
||||
self.min_area = min_area
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
w, h = rgb[0].size
|
||||
area = w * h
|
||||
out_w, out_h = float('inf'), float('inf')
|
||||
while out_w > w or out_h > h:
|
||||
target_area = random.uniform(self.min_area, 1.0) * area
|
||||
aspect_ratio = random.uniform(3. / 4., 4. / 3.)
|
||||
out_w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
out_h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
x1 = random.randint(0, w - out_w)
|
||||
y1 = random.randint(0, h - out_h)
|
||||
|
||||
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
|
||||
rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class RandomCropV2(object):
|
||||
|
||||
def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
|
||||
if isinstance(size, (tuple, list)):
|
||||
self.size = size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
self.min_area = min_area
|
||||
self.ratio = ratio
|
||||
|
||||
def _get_params(self, img):
|
||||
width, height = img.size
|
||||
area = height * width
|
||||
|
||||
for _ in range(10):
|
||||
target_area = random.uniform(self.min_area, 1.0) * area
|
||||
log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
|
||||
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
||||
|
||||
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
if 0 < w <= width and 0 < h <= height:
|
||||
i = random.randint(0, height - h)
|
||||
j = random.randint(0, width - w)
|
||||
return i, j, h, w
|
||||
|
||||
# Fallback to central crop
|
||||
in_ratio = float(width) / float(height)
|
||||
if (in_ratio < min(self.ratio)):
|
||||
w = width
|
||||
h = int(round(w / min(self.ratio)))
|
||||
elif (in_ratio > max(self.ratio)):
|
||||
h = height
|
||||
w = int(round(h * max(self.ratio)))
|
||||
else:
|
||||
w = width
|
||||
h = height
|
||||
i = (height - h) // 2
|
||||
j = (width - w) // 2
|
||||
return i, j, h, w
|
||||
|
||||
def __call__(self, rgb):
|
||||
i, j, h, w = self._get_params(rgb[0])
|
||||
rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class RandomHFlip(object):
|
||||
|
||||
def __init__(self, p=0.5):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class GaussianBlur(object):
|
||||
|
||||
def __init__(self, sigmas=[0.1, 2.0], p=0.5):
|
||||
self.sigmas = sigmas
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
sigma = random.uniform(*self.sigmas)
|
||||
rgb = [
|
||||
u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb
|
||||
]
|
||||
return rgb
|
||||
|
||||
|
||||
class ColorJitter(object):
|
||||
|
||||
def __init__(self,
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.1,
|
||||
p=0.5):
|
||||
self.brightness = brightness
|
||||
self.contrast = contrast
|
||||
self.saturation = saturation
|
||||
self.hue = hue
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
brightness, contrast, saturation, hue = self._random_params()
|
||||
transforms = [
|
||||
lambda f: F.adjust_brightness(f, brightness),
|
||||
lambda f: F.adjust_contrast(f, contrast),
|
||||
lambda f: F.adjust_saturation(f, saturation),
|
||||
lambda f: F.adjust_hue(f, hue)
|
||||
]
|
||||
random.shuffle(transforms)
|
||||
for t in transforms:
|
||||
rgb = [t(u) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
def _random_params(self):
|
||||
brightness = random.uniform(
|
||||
max(0, 1 - self.brightness), 1 + self.brightness)
|
||||
contrast = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
|
||||
saturation = random.uniform(
|
||||
max(0, 1 - self.saturation), 1 + self.saturation)
|
||||
hue = random.uniform(-self.hue, self.hue)
|
||||
return brightness, contrast, saturation, hue
|
||||
|
||||
|
||||
class RandomGray(object):
|
||||
|
||||
def __init__(self, p=0.2):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
rgb = [u.convert('L').convert('RGB') for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
|
||||
def __call__(self, rgb):
|
||||
if isinstance(rgb, list):
|
||||
rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0)
|
||||
else:
|
||||
rgb = F.to_tensor(rgb)
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class Normalize(object):
|
||||
|
||||
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, rgb):
|
||||
rgb = rgb.clone()
|
||||
rgb.clamp_(0, 1)
|
||||
if not isinstance(self.mean, torch.Tensor):
|
||||
self.mean = rgb.new_tensor(self.mean).view(-1)
|
||||
if not isinstance(self.std, torch.Tensor):
|
||||
self.std = rgb.new_tensor(self.std).view(-1)
|
||||
if rgb.dim() == 4:
|
||||
rgb.sub_(self.mean.view(1, -1, 1,
|
||||
1)).div_(self.std.view(1, -1, 1, 1))
|
||||
elif rgb.dim() == 3:
|
||||
rgb.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1))
|
||||
return rgb
|
||||
104
modelscope/pipelines/multi_modal/image_to_video_pipeline.py
Normal file
104
modelscope/pipelines/multi_modal/image_to_video_pipeline.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Model, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors.image import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_to_video, module_name=Pipelines.image_to_video_task_pipeline)
|
||||
class ImageToVideoPipeline(Pipeline):
|
||||
r""" Image To Video Pipeline.
|
||||
|
||||
Examples:
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> from modelscope.outputs import OutputKeys
|
||||
|
||||
>>> p = pipeline('image-to-video', 'damo/Image-to-Video')
|
||||
>>> input = 'path_to_image'
|
||||
>>> p(input,)
|
||||
|
||||
>>> {OutputKeys.OUTPUT_VIDEO: path-to-the-generated-video}
|
||||
>>>
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
use `model` to create a kws pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
|
||||
def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
|
||||
img_path = input
|
||||
|
||||
image = LoadImage.convert_to_img(img_path)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
vit_frame = self.model.vid_trans(image)
|
||||
vit_frame = vit_frame.unsqueeze(0)
|
||||
vit_frame = vit_frame.to(self.model.device)
|
||||
|
||||
return {'vit_frame': vit_frame}
|
||||
|
||||
def forward(self, input: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
video = self.model(input)
|
||||
return {'video': video}
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any],
|
||||
**post_params) -> Dict[str, Any]:
|
||||
video = tensor2vid(inputs['video'], self.model.cfg.mean,
|
||||
self.model.cfg.std)
|
||||
output_video_path = post_params.get('output_video', None)
|
||||
temp_video_file = False
|
||||
if output_video_path is None:
|
||||
output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
||||
temp_video_file = True
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
h, w, c = video[0].shape
|
||||
video_writer = cv2.VideoWriter(
|
||||
output_video_path, fourcc, fps=8, frameSize=(w, h))
|
||||
for i in range(len(video)):
|
||||
img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
|
||||
video_writer.write(img)
|
||||
video_writer.release()
|
||||
if temp_video_file:
|
||||
video_file_content = b''
|
||||
with open(output_video_path, 'rb') as f:
|
||||
video_file_content = f.read()
|
||||
os.remove(output_video_path)
|
||||
return {OutputKeys.OUTPUT_VIDEO: video_file_content}
|
||||
else:
|
||||
return {OutputKeys.OUTPUT_VIDEO: output_video_path}
|
||||
|
||||
|
||||
def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
|
||||
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
|
||||
video = video.mul_(std).add_(mean)
|
||||
video.clamp_(0, 1)
|
||||
video = video * 255.0
|
||||
|
||||
images = rearrange(video, 'b c f h w -> b f h w c')[0]
|
||||
images = [(img.numpy()).astype('uint8') for img in images]
|
||||
|
||||
return images
|
||||
@@ -256,6 +256,7 @@ class MultiModalTasks(object):
|
||||
text_to_video_synthesis = 'text-to-video-synthesis'
|
||||
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
|
||||
multimodal_dialogue = 'multimodal-dialogue'
|
||||
image_to_video = 'image-to-video'
|
||||
|
||||
|
||||
class ScienceTasks(object):
|
||||
|
||||
28
tests/pipelines/test_image2video.py
Normal file
28
tests/pipelines/test_image2video.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import DownloadMode, Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class Image2VideoTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = Tasks.image_to_video
|
||||
self.model_id = 'damo/Image-to-Video'
|
||||
self.path = 'https://video-generation-wulanchabu.oss-cn-wulanchabu.aliyuncs.com/baishao/test.jpeg'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_with_model_from_modelhub(self):
|
||||
pipe = pipeline(task=self.task, model=self.model_id)
|
||||
|
||||
output_video_path = pipe(
|
||||
self.path, output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO]
|
||||
print(output_video_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user