add video2video (#486)

* add video2video

* fix bugs of pre-commit

* update some files

* fix video write module

* fix max_frames
This commit is contained in:
Kang
2023-08-21 18:44:14 +08:00
committed by GitHub
parent 5d0f85a9ba
commit 040698e201
17 changed files with 3744 additions and 0 deletions

View File

@@ -221,6 +221,7 @@ class Models(object):
videocomposer = 'videocomposer'
text_to_360panorama_image = 'text-to-360panorama-image'
image_to_video_model = 'image-to-video-model'
video_to_video_model = 'video-to-video-model'
# science models
unifold = 'unifold'
@@ -547,6 +548,7 @@ class Pipelines(object):
multimodal_dialogue = 'multimodal-dialogue'
llama2_text_generation_pipeline = 'llama2-text-generation-pipeline'
image_to_video_task_pipeline = 'image-to-video-task-pipeline'
video_to_video_pipeline = 'video-to-video-pipeline'
# science tasks
protein_structure = 'unifold-protein-structure'

View File

@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
# Lazily expose this package's public API: under TYPE_CHECKING the real
# imports are visible to static analyzers, while at runtime the module is
# replaced by a LazyImportModule that defers the actual import until first
# attribute access.
if TYPE_CHECKING:
    from .video_to_video_model import VideoToVideo
else:
    _import_structure = {
        'video_to_video_model': ['VideoToVideo'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .autoencoder import *
from .embedder import *
from .unet_v2v import *

View File

@@ -0,0 +1,590 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import collections
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.utils.logger import get_logger
logger = get_logger()
def nonlinearity(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""
    return torch.sigmoid(x) * x
def Normalize(in_channels, num_groups=32):
    """Build an affine GroupNorm over ``in_channels`` with eps=1e-6."""
    return torch.nn.GroupNorm(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True)
@torch.no_grad()
def get_first_stage_encoding(encoder_posterior):
    """Turn a first-stage encoder output into a scaled latent.

    Accepts either a DiagonalGaussianDistribution (sampled) or a raw latent
    tensor, and multiplies by the Stable Diffusion latent scale 0.18215.
    """
    scale_factor = 0.18215
    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
        latent = encoder_posterior.sample()
    elif isinstance(encoder_posterior, torch.Tensor):
        latent = encoder_posterior
    else:
        raise NotImplementedError(
            f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
        )
    return scale_factor * latent
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents.

    ``parameters`` packs mean and log-variance along dim=1; the log-variance
    is clamped to [-30, 20] for numerical stability. With
    ``deterministic=True`` the distribution collapses onto its mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        self.mean = mean
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if deterministic:
            # Zero-variance: sample() degenerates to the mean.
            zero = torch.zeros_like(
                self.mean).to(device=self.parameters.device)
            self.var = zero
            self.std = zero

    def sample(self):
        """Draw one reparameterized sample: mean + std * N(0, I)."""
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """KL divergence to ``other`` (standard normal when None), summed
        over all non-batch dims."""
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        """Negative log-likelihood of ``sample``, summed over ``dims``."""
        if self.deterministic:
            return torch.Tensor([0.])
        log_two_pi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            log_two_pi + self.logvar
            + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Most likely value: the mean."""
        return self.mean
class ResnetBlock(nn.Module):
    """Residual block: two 3x3 convs with GroupNorm + swish, optional
    timestep-embedding injection, and a learned shortcut when the channel
    count changes.

    Args:
        in_channels: channels of the input feature map.
        out_channels: output channels; defaults to ``in_channels``.
        conv_shortcut: if True (and channels change), use a 3x3 conv on the
            skip path instead of a 1x1 conv.
        dropout: dropout probability applied before the second conv.
        temb_channels: size of the timestep embedding; 0 disables injection.
    """

    def __init__(self,
                 *,
                 in_channels,
                 out_channels=None,
                 conv_shortcut=False,
                 dropout,
                 temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            # Projects the timestep embedding to per-channel biases.
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            # Channel counts differ: the skip path needs its own projection.
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0)

    def forward(self, x, temb):
        """Apply the block. ``temb`` may be None (the autoencoder builds
        blocks with temb_channels=0 and passes temb=None)."""
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        if temb is not None:
            # Add the projected embedding as a per-channel bias (broadcast
            # over the spatial dims).
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over a (B, C, H, W) feature map.

    q/k/v and the output projection are 1x1 convs; attention is computed
    over the flattened H*W positions and added back residually.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention: w_[b, i, j] = softmax_j(q_i . k_j / sqrt(c))
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # w_[b, j, i]: weight of value j for query i
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)
        # residual connection
        return x + h_
class Upsample(nn.Module):
    """2x nearest-neighbour spatial upsampling, optionally followed by a
    3x3 convolution (``with_conv=True``)."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode='nearest')
        return self.conv(upsampled) if self.with_conv else upsampled
class Downsample(nn.Module):
    """2x spatial downsampling: a stride-2 3x3 conv with asymmetric padding
    (``with_conv=True``) or 2x2 average pooling otherwise."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # Torch convs have no asymmetric padding; we pad manually in
            # forward (right/bottom only) before the stride-2 conv.
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        padded = torch.nn.functional.pad(
            x, (0, 1, 0, 1), mode='constant', value=0)
        return self.conv(padded)
class Encoder(nn.Module):
    """VAE encoder: maps images (B, in_channels, H, W) down through
    len(ch_mult) resolution levels to latent moments of
    2*z_channels channels when ``double_z`` is True (mean + logvar).

    Args:
        ch: base channel count.
        out_ch: accepted for config symmetry with Decoder; not used here.
        ch_mult: per-level channel multipliers; its length sets the number
            of resolution levels.
        num_res_blocks: residual blocks per level.
        attn_resolutions: spatial resolutions at which attention is inserted.
        dropout: dropout rate inside residual blocks.
        resamp_with_conv: downsample via strided conv (True) or avg-pool.
        in_channels: input image channels.
        resolution: input spatial resolution.
        z_channels: latent channels.
        double_z: emit 2*z_channels (mean and logvar) instead of z_channels.
        use_linear_attn / attn_type: accepted but not used in this variant.
    """

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 double_z=True,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                # halve spatial resolution after every level but the last
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle: resblock -> attention -> resblock at the lowest resolution
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        # timestep embedding (unused: temb_ch == 0)
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class Decoder(nn.Module):
    """VAE decoder: maps latents (B, z_channels, h, w) back to images,
    mirroring Encoder with upsampling levels driven by ``ch_mult``.

    Args:
        give_pre_end: return features before the final norm/act/conv.
        tanh_out: apply tanh to the final output.
        (the remaining args mirror Encoder; ``use_linear_attn`` and
        ``attn_type`` are accepted but not used in this variant)
    """

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 give_pre_end=False,
                 tanh_out=False,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logger.info('Working with z of shape {} = {} dimensions.'.format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle: resblock -> attention -> resblock at the lowest resolution
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)

        # upsampling (built from lowest to highest resolution)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to keep self.up in level order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        self.last_z_shape = z.shape

        # timestep embedding (unused: temb_ch == 0)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
class AutoencoderKL(nn.Module):
    """KL-regularized image autoencoder (Stable Diffusion first stage).

    ``encode`` returns a DiagonalGaussianDistribution over latents and
    ``decode`` maps latents back to image space. Weights can be restored
    from a full SD checkpoint via ``init_from_ckpt``, which keeps only the
    'first_stage_model.*' entries (prefix stripped).
    """

    def __init__(self,
                 ddconfig,
                 embed_dim,
                 pretrained=None,
                 ignore_keys=[],
                 image_key='image',
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False,
                 **kwargs):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig['double_z']  # encoder must emit mean + logvar
        # 1x1 convs mapping encoder moments <-> the embedding latent space.
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
                                          2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig['z_channels'], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            # random projection used by to_rgb() for segmentation inputs
            assert type(colorize_nlabels) == int
            self.register_buffer('colorize',
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.use_ema = ema_decay is not None
        if pretrained is not None:
            self.init_from_ckpt(pretrained, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Restore weights from a full SD checkpoint, keeping only keys that
        contain 'first_stage_model' (prefix stripped).

        NOTE(review): ``ignore_keys`` is accepted but never applied here —
        confirm whether key filtering was intended.
        """
        sd = torch.load(path, map_location='cpu')['state_dict']
        keys = list(sd.keys())
        sd_new = collections.OrderedDict()
        for k in keys:
            if k.find('first_stage_model') >= 0:
                k_new = k.split('first_stage_model.')[-1]
                sd_new[k_new] = sd[k]
        self.load_state_dict(sd_new, strict=True)
        logger.info(f'Restored from {path}')

    def on_train_batch_end(self, *args, **kwargs):
        # NOTE(review): self.model_ema is not defined anywhere in this class;
        # it is presumably attached externally when EMA is enabled — verify
        # before relying on this hook.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode images into a diagonal Gaussian posterior over latents."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode latents ``z`` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full autoencode pass; returns (reconstruction, posterior)."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch batch[k] and convert BHWC (adding a trailing channel dim
        for 3-D inputs) to contiguous float BCHW."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1,
                      2).to(memory_format=torch.contiguous_format).float()
        return x

    def get_last_layer(self):
        """Return the weight of the final decoder convolution."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Build a dict of inputs / reconstructions / random decodes for
        visualization.

        NOTE(review): uses ``self.device`` and ``self.ema_scope``, which are
        not defined in this class (Lightning-style leftovers) — confirm
        instances provide them before calling.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log['samples'] = self.decode(torch.randn_like(posterior.sample()))
            log['reconstructions'] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log['samples_ema'] = self.decode(
                        torch.randn_like(posterior_ema.sample()))
                    log['reconstructions_ema'] = xrec_ema
        log['inputs'] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to 3 channels via a
        fixed random projection, rescaled to [-1, 1]."""
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize',
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
class IdentityFirstStage(torch.nn.Module):
    """No-op stand-in for a first-stage autoencoder: encode/decode/forward
    return their input unchanged.

    Args:
        vq_interface: when True, ``quantize`` mimics a VQ model's return
            signature ``(x, None, [None, None, None])``.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # FIX: call super().__init__() before assigning attributes —
        # nn.Module's custom __setattr__ expects an initialized module, and
        # assigning beforehand is fragile across torch versions.
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        """Identity."""
        return x

    def decode(self, x, *args, **kwargs):
        """Identity."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Identity; with vq_interface, match a VQ quantizer's return tuple."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Identity."""
        return x

View File

@@ -0,0 +1,76 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import numpy as np
import open_clip
import torch
import torch.nn as nn
import torchvision.transforms as T
class FrozenOpenCLIPEmbedder(nn.Module):
    """
    Uses the OpenCLIP transformer encoder for text.

    The vision tower is deleted after loading; only the text transformer is
    kept. ``layer`` selects which transformer block's output is returned:
    'last' (final block) or 'penultimate' (one block earlier).

    Args:
        pretrained: OpenCLIP checkpoint path/tag passed to
            open_clip.create_model_and_transforms.
        arch: OpenCLIP architecture name.
        device: device tokens are moved to in forward().
        max_length: tokenizer context length (stored for reference).
        freeze: if True, switch to eval mode and disable gradients.
        layer: 'last' or 'penultimate'.
    """
    LAYERS = ['last', 'penultimate']

    def __init__(self,
                 pretrained,
                 arch='ViT-H-14',
                 device='cuda',
                 max_length=77,
                 freeze=True,
                 layer='penultimate'):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=pretrained)
        del model.visual  # text-only usage: free the vision tower
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        # layer_idx counts blocks to skip from the end.
        if self.layer == 'last':
            self.layer_idx = 0
        elif self.layer == 'penultimate':
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Put the model in eval mode and stop gradient updates."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return per-token embeddings."""
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND (transformer is seq-first)
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run the transformer blocks, stopping ``layer_idx`` blocks before
        the end (0 for 'last', 1 for 'penultimate')."""
        # BUGFIX: `checkpoint` was an undefined name in this module and
        # raised NameError whenever grad checkpointing was enabled.
        from torch.utils.checkpoint import checkpoint
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
            ):
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        """Alias for forward()."""
        return self(text)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

View File

@@ -0,0 +1,171 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import os.path as osp
from datetime import datetime
import torch
from easydict import EasyDict
# Configuration for the VideoLDM-based video-to-video pipeline, collected in
# an EasyDict so entries can be accessed as attributes (cfg.xxx).
cfg = EasyDict(__name__='Config: VideoLDM Decoder')

# ---------------------------work dir--------------------------
cfg.work_dir = 'workspace/'

# ---------------------------Global Variable-----------------------------------
cfg.resolution = [448, 256]  # output frame size
cfg.max_frames = 32  # clip length in frames

# -----------------------------------------------------------------------------
# ---------------------------Dataset Parameter---------------------------------
cfg.mean = [0.5, 0.5, 0.5]
cfg.std = [0.5, 0.5, 0.5]
cfg.max_words = 1000

# Placeholder values for the vision-backbone conditioning inputs.
cfg.vit_out_dim = 1024
cfg.vit_resolution = [224, 224]
cfg.depth_clamp = 10.0
cfg.misc_size = 384
cfg.depth_std = 20.0

cfg.frame_lens = 32
cfg.sample_fps = 8
cfg.batch_sizes = 1

# -----------------------------------------------------------------------------
# ---------------------------Model Parameters-----------------------------------
# Diffusion
cfg.schedule = 'cosine'
cfg.num_timesteps = 1000
cfg.mean_type = 'v'  # network predicts v (velocity parameterization)
cfg.var_type = 'fixed_small'
cfg.loss_type = 'mse'
cfg.ddim_timesteps = 50
cfg.ddim_eta = 0.0
cfg.clamp = 1.0
cfg.share_noise = False
cfg.use_div_loss = False
cfg.noise_strength = 0.1

# classifier-free guidance
cfg.p_zero = 0.1  # probability of dropping the condition during training
cfg.guide_scale = 3.0

# CLIP vision encoder normalization constants
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]

# Model
cfg.scale_factor = 0.18215  # SD latent scale factor
cfg.use_fp16 = True
cfg.temporal_attention = True
cfg.decoder_bs = 8

cfg.UNet = {
    'type': 'Vid2VidSDUNet',
    'in_dim': 4,
    'dim': 320,
    'y_dim': cfg.vit_out_dim,
    'context_dim': 1024,
    # learned-variance models emit mean and variance (8 channels)
    'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
    'dim_mult': [1, 2, 4, 4],
    'num_heads': 8,
    'head_dim': 64,
    'num_res_blocks': 2,
    'attn_scales': [1 / 1, 1 / 2, 1 / 4],
    'dropout': 0.1,
    'temporal_attention': cfg.temporal_attention,
    'temporal_attn_times': 1,
    'use_checkpoint': False,
    'use_fps_condition': False,
    'use_sim_mask': False,
    'num_tokens': 4,
    'default_fps': 8,
    'input_dim': 1024
}
cfg.guidances = []

# autoencoder from Stable Diffusion
cfg.auto_encoder = {
    'type': 'AutoencoderKL',
    'ddconfig': {
        'double_z': True,
        'z_channels': 4,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0
    },
    'embed_dim': 4,
    'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
}

# CLIP text embedder
cfg.embedder = {
    'type': 'FrozenOpenCLIPEmbedder',
    'layer': 'penultimate',
    'vit_resolution': [224, 224],
    'pretrained': 'open_clip_pytorch_model.bin'
}

# -----------------------------------------------------------------------------
# ---------------------------Training Settings---------------------------------
# training and optimizer
cfg.ema_decay = 0.9999
cfg.num_steps = 600000
cfg.lr = 5e-5
cfg.weight_decay = 0.0
cfg.betas = (0.9, 0.999)
cfg.eps = 1.0e-8
cfg.chunk_size = 16
cfg.alpha = 0.7
cfg.save_ckp_interval = 1000

# -----------------------------------------------------------------------------
# ----------------------------Pretrain Settings---------------------------------
# Default: load 2d pretrain
cfg.fix_weight = False
cfg.load_match = False
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'

# -----------------------------------------------------------------------------
# -----------------------------Visual-------------------------------------------
# Visualization of videos during training / inference
cfg.viz_interval = 1000
cfg.visual_train = {
    'type': 'VisualVideoTextDuringTrain',
}
cfg.visual_inference = {
    'type': 'VisualGeneratedVideos',
}
cfg.inference_list_path = ''

# logging
cfg.log_interval = 100

# Default log_dir
cfg.log_dir = 'workspace/output_data'

# -----------------------------------------------------------------------------
# ---------------------------Others--------------------------------------------
# seed
cfg.seed = 8888

# Prompts appended around the user prompt at sampling time.
cfg.negative_prompt = 'worst quality, normal quality, low quality, low res, blurry, text, \
watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, \
sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting'
cfg.positive_prompt = ', cinematic, High Contrast, highly detailed, unreal engine, \
taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, \
32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, \
hyper sharpness, perfect without deformations, Unreal Engine 5, 4k render'
# -----------------------------------------------------------------------------

View File

@@ -0,0 +1,247 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
import torch
from .schedules_sdedit import karras_schedule
from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun
__all__ = ['GaussianDiffusion_SDEdit']
def _i(tensor, t, x):
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
return tensor[t.to(tensor.device)].view(shape).to(x.device)
class GaussianDiffusion_SDEdit(object):
    """Gaussian diffusion wrapper used for SDEdit-style video-to-video
    sampling.

    Args:
        sigmas: 1-D tensor of per-timestep noise levels; alphas are derived
            as sqrt(1 - sigma^2) (variance-preserving parameterization).
        prediction_type: what the network predicts: 'x0', 'eps' or 'v'.
    """

    def __init__(self, sigmas, prediction_type='eps'):
        assert prediction_type in {'x0', 'eps', 'v'}
        self.sigmas = sigmas
        self.alphas = torch.sqrt(1 - sigmas**2)
        self.num_timesteps = len(sigmas)
        self.prediction_type = prediction_type

    def diffuse(self, x0, t, noise=None):
        """Forward process: x_t = alpha_t * x0 + sigma_t * noise."""
        noise = torch.randn_like(x0) if noise is None else noise
        xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
        return xt

    def denoise(self,
                xt,
                t,
                s,
                model,
                model_kwargs={},
                guide_scale=None,
                guide_rescale=None,
                clamp=None,
                percentile=None):
        """One reverse step from timestep t to s (defaults to t-1).

        Returns:
            (mu, var, log_var, x0, eps): posterior mean/variance plus the
            range-restricted x0 prediction and the eps recomputed from it.
        """
        s = t - 1 if s is None else s

        # hyperparams; s < 0 means "step to clean data" (alpha = 1)
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        alphas_s = _i(self.alphas, s.clamp(0), xt)
        alphas_s[s < 0] = 1.
        sigmas_s = torch.sqrt(1 - alphas_s**2)

        # precompute variables
        betas = 1 - (alphas / alphas_s)**2
        coef1 = betas * alphas_s / sigmas**2
        coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
        var = betas * (sigmas_s / sigmas)**2
        log_var = torch.log(var).clamp_(-20, 20)

        # prediction
        if guide_scale is None:
            assert isinstance(model_kwargs, dict)
            out = model(xt, t=t, **model_kwargs)
        else:
            # classifier-free guidance:
            # model_kwargs[0] = conditional kwargs, model_kwargs[1] = unconditional
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            y_out = model(xt, t=t, **model_kwargs[0])
            if guide_scale == 1.:
                out = y_out
            else:
                u_out = model(xt, t=t, **model_kwargs[1])
                out = u_out + guide_scale * (y_out - u_out)

                # guide_rescale: rescale the guided output's std toward the
                # conditional output's std to counter over-saturation at
                # large guidance scales
                if guide_rescale is not None:
                    assert guide_rescale >= 0 and guide_rescale <= 1
                    ratio = (
                        y_out.flatten(1).std(dim=1) /  # noqa
                        (out.flatten(1).std(dim=1) + 1e-12)
                    ).view((-1, ) + (1, ) * (y_out.ndim - 1))
                    out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0

        # compute x0 from the network output per parameterization
        if self.prediction_type == 'x0':
            x0 = out
        elif self.prediction_type == 'eps':
            x0 = (xt - sigmas * out) / alphas
        elif self.prediction_type == 'v':
            x0 = alphas * xt - sigmas * out
        else:
            raise NotImplementedError(
                f'prediction_type {self.prediction_type} not implemented')

        # restrict the range of x0
        if percentile is not None:
            # dynamic thresholding: clamp to the per-sample percentile of |x0|
            assert percentile > 0 and percentile <= 1
            s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
            s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)

        # recompute eps using the restricted x0
        eps = (xt - alphas * x0) / sigmas

        # compute mu (mean of posterior distribution) using the restricted x0
        mu = coef1 * x0 + coef2 * xt
        return mu, var, log_var, x0, eps

    @torch.no_grad()
    def sample(self,
               noise,
               model,
               model_kwargs={},
               condition_fn=None,
               guide_scale=None,
               guide_rescale=None,
               clamp=None,
               percentile=None,
               solver='euler_a',
               steps=20,
               t_max=None,
               t_min=None,
               discretization=None,
               discard_penultimate_step=None,
               return_intermediate=None,
               show_progress=False,
               seed=-1,
               **kwargs):
        """Sample x0 from ``noise`` with the chosen ODE/SDE solver.

        ``steps`` may be an int (timesteps are discretized here) or a
        LongTensor of explicit timesteps.

        NOTE(review): ``condition_fn`` and the computed ``seed`` are not
        used anywhere below — presumably intended for the solver; confirm.
        """
        # sanity check
        assert isinstance(steps, (int, torch.LongTensor))
        assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
        assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
        assert discretization in (None, 'leading', 'linspace', 'trailing')
        assert discard_penultimate_step in (None, True, False)
        assert return_intermediate in (None, 'x0', 'xt')

        # function of diffusion solver
        solver_fn = {
            'heun': sample_heun,
            'dpmpp_2m_sde': sample_dpmpp_2m_sde
        }[solver]

        # options
        schedule = 'karras' if 'karras' in solver else None
        discretization = discretization or 'linspace'
        seed = seed if seed >= 0 else random.randint(0, 2**31)
        if isinstance(steps, torch.LongTensor):
            discard_penultimate_step = False
        if discard_penultimate_step is None:
            # solvers known to benefit from dropping the penultimate step
            discard_penultimate_step = True if solver in (
                'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
                'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False

        # function for denoising xt to get x0
        intermediates = []

        def model_fn(xt, sigma):
            # denoising: map continuous sigma back to a discrete timestep,
            # then take the x0 output (index -2) of denoise()
            t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
            x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
                              guide_rescale, clamp, percentile)[-2]

            # collect intermediate outputs
            if return_intermediate == 'xt':
                intermediates.append(xt)
            elif return_intermediate == 'x0':
                intermediates.append(x0)
            return x0

        # get timesteps
        if isinstance(steps, int):
            steps += 1 if discard_penultimate_step else 0
            t_max = self.num_timesteps - 1 if t_max is None else t_max
            t_min = 0 if t_min is None else t_min

            # discretize timesteps
            if discretization == 'leading':
                steps = torch.arange(t_min, t_max + 1,
                                     (t_max - t_min + 1) / steps).flip(0)
            elif discretization == 'linspace':
                steps = torch.linspace(t_max, t_min, steps)
            elif discretization == 'trailing':
                steps = torch.arange(t_max, t_min - 1,
                                     -((t_max - t_min + 1) / steps))
            else:
                raise NotImplementedError(
                    f'{discretization} discretization not implemented')
            steps = steps.clamp_(t_min, t_max)
        steps = torch.as_tensor(
            steps, dtype=torch.float32, device=noise.device)

        # get sigmas (appending a terminal 0)
        sigmas = self._t_to_sigma(steps)
        sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if schedule == 'karras':
            if sigmas[0] == float('inf'):
                # keep the leading inf, rebuild the finite part as Karras
                sigmas = karras_schedule(
                    n=len(steps) - 1,
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas[sigmas < float('inf')].max().item(),
                    rho=7.).to(sigmas)
                sigmas = torch.cat([
                    sigmas.new_tensor([float('inf')]), sigmas,
                    sigmas.new_zeros([1])
                ])
            else:
                sigmas = karras_schedule(
                    n=len(steps),
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas.max().item(),
                    rho=7.).to(sigmas)
                sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if discard_penultimate_step:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        # sampling
        x0 = solver_fn(
            noise, model_fn, sigmas, show_progress=show_progress, **kwargs)
        return (x0, intermediates) if return_intermediate is not None else x0

    def _sigma_to_t(self, sigma):
        """Map a continuous sigma to a (fractional) timestep by linear
        interpolation in log-sigma space; inf maps to the last timestep."""
        if sigma == float('inf'):
            t = torch.full_like(sigma, len(self.sigmas) - 1)
        else:
            log_sigmas = torch.sqrt(self.sigmas**2 /  # noqa
                                    (1 - self.sigmas**2)).log().to(sigma)
            log_sigma = sigma.log()
            dists = log_sigma - log_sigmas[:, None]
            # index of the last schedule entry not exceeding log_sigma
            low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
                max=log_sigmas.shape[0] - 2)
            high_idx = low_idx + 1
            low, high = log_sigmas[low_idx], log_sigmas[high_idx]
            w = (low - log_sigma) / (low - high)
            w = w.clamp(0, 1)
            t = (1 - w) * low_idx + w * high_idx
            t = t.view(sigma.shape)
        if t.ndim == 0:
            t = t.unsqueeze(0)
        return t

    def _t_to_sigma(self, t):
        """Map (fractional) timesteps to sigmas by linear interpolation in
        log-sigma space; non-finite results become inf."""
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigmas = torch.sqrt(self.sigmas**2 /  # noqa
                                (1 - self.sigmas**2)).log().to(t)
        log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
        log_sigma[torch.isnan(log_sigma)
                  | torch.isinf(log_sigma)] = float('inf')
        return log_sigma.exp()

View File

@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import torch
def betas_to_sigmas(betas):
    """Convert per-step betas to cumulative noise levels:
    sigma_t = sqrt(1 - prod(1 - beta))."""
    alphas_cumprod = torch.cumprod(1 - betas, dim=0)
    return torch.sqrt(1 - alphas_cumprod)
def sigmas_to_betas(sigmas):
    """Invert betas_to_sigmas: recover per-step betas from cumulative
    sigmas via ratios of consecutive squared alphas."""
    square_alphas = 1 - sigmas**2
    ratios = torch.cat(
        [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
    return 1 - ratios
def logsnrs_to_sigmas(logsnrs):
    """sigma = sqrt(sigmoid(-logsnr))."""
    return torch.sigmoid(-logsnrs).sqrt()
def sigmas_to_logsnrs(sigmas):
    """logsnr = log(sigma^2 / (1 - sigma^2)); inverse of
    logsnrs_to_sigmas."""
    square_sigmas = sigmas**2
    return torch.log(square_sigmas / (1 - square_sigmas))
def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
t_min = math.atan(math.exp(-0.5 * logsnr_min))
t_max = math.atan(math.exp(-0.5 * logsnr_max))
t = torch.linspace(1, 0, n)
logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
return logsnrs
def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
    """Cosine log-SNR schedule shifted for a resolution scale factor:
    adds 2*log(1/scale) to every value."""
    base = _logsnr_cosine(n, logsnr_min, logsnr_max)
    return base + 2 * math.log(1 / scale)
def _logsnr_cosine_interp(n,
                          logsnr_min=-15,
                          logsnr_max=15,
                          scale_min=2,
                          scale_max=4):
    """Interpolate between two shifted cosine log-SNR schedules
    (scale_min and scale_max), weighting scale_min more at the start."""
    t = torch.linspace(1, 0, n)
    schedule_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max,
                                          scale_min)
    schedule_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max,
                                          scale_max)
    return t * schedule_min + (1 - t) * schedule_max
def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras et al. (2022) sigma schedule, remapped into (0, 1) as
    sigma / sqrt(1 + sigma^2) (variance-preserving form).

    Returns n increasing values from ~sigma_min to ~sigma_max (remapped).
    """
    ramp = torch.linspace(1, 0, n)
    min_inv_rho = sigma_min**(1 / rho)
    max_inv_rho = sigma_max**(1 / rho)
    raw = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
    return torch.sqrt(raw**2 / (1 + raw**2))
def logsnr_cosine_interp_schedule(n,
                                  logsnr_min=-15,
                                  logsnr_max=15,
                                  scale_min=2,
                                  scale_max=4):
    """Sigma schedule derived from the interpolated cosine log-SNR
    schedule."""
    logsnrs = _logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min,
                                    scale_max)
    return logsnrs_to_sigmas(logsnrs)
def noise_schedule(schedule='logsnr_cosine_interp',
                   n=1000,
                   zero_terminal_snr=False,
                   **kwargs):
    """Build an n-step sigma schedule by name; optionally rescale so the
    largest sigma is exactly 1 (zero terminal SNR)."""
    builders = {
        'logsnr_cosine_interp': logsnr_cosine_interp_schedule
    }
    sigmas = builders[schedule](n, **kwargs)

    # post-processing: affine-rescale so max(sigmas) == 1
    if zero_terminal_snr and sigmas.max() != 1.0:
        scale = (1.0 - sigmas.min()) / (sigmas.max() - sigmas.min())
        sigmas = sigmas.min() + scale * (sigmas - sigmas.min())
    return sigmas

View File

@@ -0,0 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
import numpy as np
import torch
def setup_seed(seed):
    """Seed all RNGs (torch CPU + every CUDA device, numpy, random) and
    force deterministic cuDNN kernels for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

View File

@@ -0,0 +1,194 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torchsde
from tqdm.auto import trange
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """
    Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step.

    Returns (sigma_down, sigma_up) with sigma_down^2 + sigma_up^2 =
    sigma_to^2; with eta=0 the step is fully deterministic.
    """
    if not eta:
        return sigma_to, 0.
    sigma_up = min(
        sigma_to,
        eta * (
            sigma_to**2 *  # noqa
            (sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5)
    sigma_down = (sigma_to**2 - sigma_up**2)**0.5
    return sigma_down, sigma_up
def get_scalings(sigma):
    """Karras-style preconditioning scalings for the denoiser input/output:
    returns (c_out, c_in) with c_out = -sigma and c_in = 1/sqrt(sigma^2 + 1)."""
    return -sigma, 1 / (sigma**2 + 1.**2)**0.5
@torch.no_grad()
def sample_heun(noise,
                model,
                sigmas,
                s_churn=0.,
                s_tmin=0.,
                s_tmax=float('inf'),
                s_noise=1.,
                show_progress=True):
    """
    Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    Args:
        noise: starting noise; the initial sample is noise * sigmas[0].
        model: denoiser called as model(x * c_in, sigma) -> predicted x0.
        sigmas: decreasing noise-level schedule (last entry typically 0).
        s_churn / s_tmin / s_tmax / s_noise: stochastic "churn" controls;
            a positive gamma temporarily raises the noise level each step.
        show_progress: display a tqdm progress bar.
    """
    x = noise * sigmas[0]
    for i in trange(len(sigmas) - 1, disable=not show_progress):
        gamma = 0.
        if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'):
            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # inject noise so x matches the raised level sigma_hat
            x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
        if sigmas[i] == float('inf'):
            # Euler method
            denoised = model(noise, sigma_hat)
            x = denoised + sigmas[i + 1] * (gamma + 1) * noise
        else:
            _, c_in = get_scalings(sigma_hat)
            denoised = model(x * c_in, sigma_hat)
            d = (x - denoised) / sigma_hat  # dx/dsigma
            dt = sigmas[i + 1] - sigma_hat
            if sigmas[i + 1] == 0:
                # Euler method
                x = x + d * dt
            else:
                # Heun's method: average the slopes at sigma_hat and the
                # next sigma (2nd-order correction)
                x_2 = x + d * dt
                _, c_in = get_scalings(sigmas[i + 1])
                denoised_2 = model(x_2 * c_in, sigmas[i + 1])
                d_2 = (x_2 - denoised_2) / sigmas[i + 1]
                d_prime = (d + d_2) / 2
                x = x + d_prime * dt
    return x
class BatchedBrownianTree:
    """
    A wrapper around torchsde.BrownianTree that enables batches of entropy.

    If ``seed`` is a list, one tree per batch item is created (each with its
    own entropy) and ``__call__`` returns the stacked increments; otherwise a
    single tree is used and a single increment is returned.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        # Pop (not get) so a caller-supplied 'w0' is not also forwarded via
        # **kwargs below, which would raise a duplicate-keyword TypeError.
        w0 = kwargs.pop('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()
        self.batched = True
        try:
            # A sequence of seeds means one tree per batch item.
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            # ``seed`` is a scalar: fall back to a single, non-batched tree.
            seed = [seed]
            self.batched = False
        self.trees = [
            torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs)
            for s in seed
        ]

    @staticmethod
    def sort(a, b):
        """Return (min, max, sign), where sign is -1 if inputs were swapped."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (
            self.sign * sign)
        return w if self.batched else w[0]
class BrownianTreeNoiseSampler:
    """
    A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self,
                 x,
                 sigma_min,
                 sigma_max,
                 seed=None,
                 transform=lambda x: x):
        self.transform = transform
        lo = self.transform(torch.as_tensor(sigma_min))
        hi = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, lo, hi, seed)

    def __call__(self, sigma, sigma_next):
        # Normalize the Brownian increment by sqrt of the interval length so
        # the returned noise has unit variance.
        lo = self.transform(torch.as_tensor(sigma))
        hi = self.transform(torch.as_tensor(sigma_next))
        return self.tree(lo, hi) / (hi - lo).abs().sqrt()
@torch.no_grad()
def sample_dpmpp_2m_sde(noise,
                        model,
                        sigmas,
                        eta=1.,
                        s_noise=1.,
                        solver_type='midpoint',
                        show_progress=True):
    """
    DPM-Solver++ (2M) SDE.

    Args:
        noise: initial noise; sampling starts at ``noise * sigmas[0]``.
        model: denoiser called as ``model(x_scaled, sigma)``.
        sigmas: decreasing noise schedule; may start at ``inf`` and end at 0.
        eta: stochasticity strength (0 gives the deterministic ODE variant).
        s_noise: scale applied to the injected Brownian noise.
        solver_type: 'midpoint' or 'heun' second-order correction.
        show_progress: whether to show the tqdm progress bar.

    Returns:
        The sample after traversing the whole schedule.
    """
    assert solver_type in {'heun', 'midpoint'}
    x = noise * sigmas[0]
    # Brownian-tree sampler over the finite, positive part of the schedule
    # gives consistent noise regardless of step count.
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[
        sigmas < float('inf')].max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
    old_denoised = None  # previous model output, used by the 2nd-order term
    h_last = None  # previous log-SNR step size
    for i in trange(len(sigmas) - 1, disable=not show_progress):
        if sigmas[i] == float('inf'):
            # Euler method
            # NOTE(review): feeds ``noise`` (not ``x``) to the model —
            # presumably the bootstrap for an infinite initial sigma; confirm.
            denoised = model(noise, sigmas[i])
            x = denoised + sigmas[i + 1] * noise
        else:
            _, c_in = get_scalings(sigmas[i])
            denoised = model(x * c_in, sigmas[i])
            if sigmas[i + 1] == 0:
                # Denoising step
                x = denoised
            else:
                # DPM-Solver++(2M) SDE update in log-SNR time t = -log(sigma)
                t, s = -sigmas[i].log(), -sigmas[i + 1].log()
                h = s - t
                eta_h = eta * h
                x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
                    (-h - eta_h).expm1().neg() * denoised
                if old_denoised is not None:
                    # Second-order correction using the previous denoised
                    # output; r is the ratio of consecutive step sizes.
                    r = h_last / h
                    if solver_type == 'heun':
                        x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
                            (1 / r) * (denoised - old_denoised)
                    elif solver_type == 'midpoint':
                        x = x + 0.5 * (-h - eta_h).expm1().neg() * \
                            (1 / r) * (denoised - old_denoised)
                # Stochastic (SDE) noise injection scaled by eta.
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[
                    i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
            old_denoised = denoised
            h_last = h
    return x

View File

@@ -0,0 +1,404 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import random
import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image, ImageFilter
__all__ = [
'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2',
'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',
'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize',
'ResizeRandomCrop', 'ExtractResizeRandomCrop', 'ExtractResizeAssignCrop'
]
class Compose(object):
    """Chain several transforms; supports indexing, slicing and len()."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __getitem__(self, index):
        # Slicing yields a new Compose; a single index yields that transform.
        if isinstance(index, slice):
            return Compose(self.transforms[index])
        return self.transforms[index]

    def __len__(self):
        return len(self.transforms)

    def __call__(self, rgb):
        for transform in self.transforms:
            rgb = transform(rgb)
        return rgb
class Resize(object):
    """Resize a PIL image (or a list of images) with bilinear interpolation."""

    def __init__(self, size=256):
        # An int means a square target.
        self.size = (size, size) if isinstance(size, int) else size

    def __call__(self, rgb):
        if isinstance(rgb, list):
            return [u.resize(self.size, Image.BILINEAR) for u in rgb]
        return rgb.resize(self.size, Image.BILINEAR)
class Rescale(object):
    """Scale a list of images so the shorter side equals ``size``."""

    def __init__(self, size=256, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, rgb):
        w, h = rgb[0].size
        factor = self.size / min(w, h)
        new_size = (int(round(w * factor)), int(round(h * factor)))
        return [u.resize(new_size, self.interpolation) for u in rgb]
class CenterCrop(object):
    """Crop a centered square of side ``size`` from each image in a list."""

    def __init__(self, size=224):
        self.size = size

    def __call__(self, rgb):
        w, h = rgb[0].size
        assert min(w, h) >= self.size
        left = (w - self.size) // 2
        top = (h - self.size) // 2
        box = (left, top, left + self.size, top + self.size)
        return [u.crop(box) for u in rgb]
class ResizeRandomCrop(object):
    """Downscale so the short side is ``size_short``, then random-crop ``size``."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):
        # Halve repeatedly with a cheap BOX filter while far above the target,
        # then do a single BICUBIC resize to the exact short side.
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
                for u in rgb
            ]
        factor = self.size_short / min(rgb[0].size)
        rgb = [
            u.resize((round(factor * u.width), round(factor * u.height)),
                     resample=Image.BICUBIC) for u in rgb
        ]
        # Random square crop shared across all frames.
        w, h = rgb[0].size
        left = random.randint(0, w - self.size)
        top = random.randint(0, h - self.size)
        box = (left, top, left + self.size, top + self.size)
        return [u.crop(box) for u in rgb]
class ExtractResizeRandomCrop(object):
    """Like ResizeRandomCrop, but also returns the crop box that was used."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):
        # Fast halving followed by one exact BICUBIC resize (see
        # ResizeRandomCrop for rationale).
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
                for u in rgb
            ]
        factor = self.size_short / min(rgb[0].size)
        rgb = [
            u.resize((round(factor * u.width), round(factor * u.height)),
                     resample=Image.BICUBIC) for u in rgb
        ]
        w, h = rgb[0].size
        left = random.randint(0, w - self.size)
        top = random.randint(0, h - self.size)
        box = [left, top, left + self.size, top + self.size]
        rgb = [u.crop(tuple(box)) for u in rgb]
        return rgb, box
class ExtractResizeAssignCrop(object):
    """Resize, crop with a caller-provided box, then resize to ``size``."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb, wh):
        # Same resize strategy as ResizeRandomCrop, but the crop box ``wh``
        # comes from the caller (e.g. from ExtractResizeRandomCrop).
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
                for u in rgb
            ]
        factor = self.size_short / min(rgb[0].size)
        rgb = [
            u.resize((round(factor * u.width), round(factor * u.height)),
                     resample=Image.BICUBIC) for u in rgb
        ]
        cropped = [u.crop(wh) for u in rgb]
        return [
            u.resize((self.size, self.size), Image.BILINEAR) for u in cropped
        ]
class CenterCropV2(object):
    """Center crop: halve with BOX, finish with BICUBIC, then crop a square."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        # fast resize: cheap halvings first, one exact resize at the end
        while min(img[0].size) >= 2 * self.size:
            img = [
                u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
                for u in img
            ]
        factor = self.size / min(img[0].size)
        img = [
            u.resize((round(factor * u.width), round(factor * u.height)),
                     resample=Image.BICUBIC) for u in img
        ]
        # center crop
        left = (img[0].width - self.size) // 2
        top = (img[0].height - self.size) // 2
        box = (left, top, left + self.size, top + self.size)
        return [u.crop(box) for u in img]
class CenterCropWide(object):
    """Center crop to a (width, height) target; accepts a list or one image."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        tw, th = self.size[0], self.size[1]
        if isinstance(img, list):
            factor = min(img[0].size[0] / tw, img[0].size[1] / th)
            img = [
                u.resize((round(u.width // factor), round(u.height // factor)),
                         resample=Image.BOX) for u in img
            ]
            # center crop
            left = (img[0].width - tw) // 2
            top = (img[0].height - th) // 2
            return [u.crop((left, top, left + tw, top + th)) for u in img]
        factor = min(img.size[0] / tw, img.size[1] / th)
        img = img.resize(
            (round(img.width // factor), round(img.height // factor)),
            resample=Image.BOX)
        left = (img.width - tw) // 2
        top = (img.height - th) // 2
        return img.crop((left, top, left + tw, top + th))
class RandomCrop(object):
    """Random area/aspect crop, then resize to a square of side ``size``."""

    def __init__(self, size=224, min_area=0.4):
        self.size = size
        self.min_area = min_area

    def __call__(self, rgb):
        # The same crop box is applied to every image in the list.
        w, h = rgb[0].size
        area = w * h
        out_w, out_h = float('inf'), float('inf')
        # Resample until the proposed crop fits inside the image.
        while out_w > w or out_h > h:
            target_area = random.uniform(self.min_area, 1.0) * area
            aspect = random.uniform(3. / 4., 4. / 3.)
            out_w = int(round(math.sqrt(target_area * aspect)))
            out_h = int(round(math.sqrt(target_area / aspect)))
        left = random.randint(0, w - out_w)
        top = random.randint(0, h - out_h)
        rgb = [u.crop((left, top, left + out_w, top + out_h)) for u in rgb]
        return [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
class RandomCropV2(object):
    """torchvision-style random resized crop applied to a list of images."""

    def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
        self.size = size if isinstance(size, (tuple, list)) else (size, size)
        self.min_area = min_area
        self.ratio = ratio

    def _get_params(self, img):
        """Sample (top, left, height, width); fall back to a central crop."""
        width, height = img.size
        area = height * width
        for _ in range(10):
            target_area = random.uniform(self.min_area, 1.0) * area
            log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
            aspect = math.exp(random.uniform(*log_ratio))
            w = int(round(math.sqrt(target_area * aspect)))
            h = int(round(math.sqrt(target_area / aspect)))
            if 0 < w <= width and 0 < h <= height:
                top = random.randint(0, height - h)
                left = random.randint(0, width - w)
                return top, left, h, w
        # Fallback: central crop clamped to the allowed aspect-ratio range.
        in_ratio = float(width) / float(height)
        if (in_ratio < min(self.ratio)):
            w = width
            h = int(round(w / min(self.ratio)))
        elif (in_ratio > max(self.ratio)):
            h = height
            w = int(round(h * max(self.ratio)))
        else:
            w = width
            h = height
        return (height - h) // 2, (width - w) // 2, h, w

    def __call__(self, rgb):
        top, left, h, w = self._get_params(rgb[0])
        return [F.resized_crop(u, top, left, h, w, self.size) for u in rgb]
class RandomHFlip(object):
    """Horizontally flip every image in the list with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        return [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb]
class GaussianBlur(object):
    """Apply PIL Gaussian blur with a randomly sampled radius.

    Args:
        sigmas: (low, high) range the blur radius is uniformly sampled from.
        p: probability of applying the blur at all.
    """

    # Tuple default avoids the shared-mutable-default-argument pitfall.
    def __init__(self, sigmas=(0.1, 2.0), p=0.5):
        self.sigmas = sigmas
        self.p = p

    def __call__(self, rgb):
        if random.random() < self.p:
            sigma = random.uniform(*self.sigmas)
            rgb = [
                u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb
            ]
        return rgb
class ColorJitter(object):
    """Randomly jitter brightness/contrast/saturation/hue with probability p.

    The four adjustments are applied in a random order, each with a factor
    sampled once per call and shared across all images in the list.
    """

    def __init__(self,
                 brightness=0.4,
                 contrast=0.4,
                 saturation=0.4,
                 hue=0.1,
                 p=0.5):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue
        self.p = p

    def __call__(self, rgb):
        if random.random() < self.p:
            brightness, contrast, saturation, hue = self._random_params()
            ops = [
                lambda f: F.adjust_brightness(f, brightness),
                lambda f: F.adjust_contrast(f, contrast),
                lambda f: F.adjust_saturation(f, saturation),
                lambda f: F.adjust_hue(f, hue),
            ]
            random.shuffle(ops)
            for op in ops:
                rgb = [op(u) for u in rgb]
        return rgb

    def _random_params(self):
        """Sample one multiplicative factor per property plus a hue offset."""

        def around_one(amount):
            return random.uniform(max(0, 1 - amount), 1 + amount)

        return (around_one(self.brightness), around_one(self.contrast),
                around_one(self.saturation),
                random.uniform(-self.hue, self.hue))
class RandomGray(object):
    """Convert images to grayscale (kept 3-channel RGB) with probability p."""

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        return [u.convert('L').convert('RGB') for u in rgb]
class ToTensor(object):
    """Convert a PIL image (or a list, stacked on dim 0) to a float tensor."""

    def __call__(self, rgb):
        if isinstance(rgb, list):
            return torch.stack([F.to_tensor(u) for u in rgb], dim=0)
        return F.to_tensor(rgb)
class Normalize(object):
    """Clamp to [0, 1], then normalize channels with the given mean/std.

    The mean/std lists are lazily converted to tensors on the first call and
    cached on the instance. Supports (C, H, W) and (N, C, H, W) inputs; the
    input tensor itself is left untouched (a clone is normalized).
    """

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, rgb):
        out = rgb.clone()
        out.clamp_(0, 1)
        if not isinstance(self.mean, torch.Tensor):
            self.mean = out.new_tensor(self.mean).view(-1)
        if not isinstance(self.std, torch.Tensor):
            self.std = out.new_tensor(self.std).view(-1)
        if out.dim() == 4:
            out.sub_(self.mean.view(1, -1, 1,
                                    1)).div_(self.std.view(1, -1, 1, 1))
        elif out.dim() == 3:
            out.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1))
        return out

View File

@@ -0,0 +1,227 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import random
from copy import copy
from typing import Any, Dict
import torch
import torch.cuda.amp as amp
import torch.nn.functional as F
import modelscope.models.multi_modal.video_to_video.utils.transforms as data
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.video_to_video.modules import *
from modelscope.models.multi_modal.video_to_video.modules import (
AutoencoderKL, FrozenOpenCLIPEmbedder, Vid2VidSDUNet,
get_first_stage_encoding)
from modelscope.models.multi_modal.video_to_video.utils.config import cfg
from modelscope.models.multi_modal.video_to_video.utils.diffusion_sdedit import \
GaussianDiffusion_SDEdit
from modelscope.models.multi_modal.video_to_video.utils.schedules_sdedit import \
noise_schedule
from modelscope.models.multi_modal.video_to_video.utils.seed import setup_seed
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
__all__ = ['VideoToVideo']
logger = get_logger()
@MODELS.register_module(
    Tasks.video_to_video, module_name=Models.video_to_video_model)
class VideoToVideo(TorchModel):
    r"""
    Video2Video aims to solve the task of generating super-resolution videos based on input
    video and text, which is a video generation basic model developed by Alibaba Cloud.
    Paper link: https://arxiv.org/abs/2306.02018
    Attributes:
        diffusion: diffusion model for DDIM.
        autoencoder: decode the latent representation of input video into visual space.
        clip_encoder: encode the text into text embedding.
    """

    def __init__(self, model_dir, *args, **kwargs):
        r"""
        Build the video-to-video model stack from a local model directory.

        Loads the pipeline configuration, seeds all RNGs, and instantiates the
        four components used by :meth:`forward`: the frozen OpenCLIP text
        encoder, the Vid2Vid UNet generator, the SDEdit Gaussian-diffusion
        wrapper and the frozen KL autoencoder.

        Args:
            model_dir (`str` or `os.PathLike`):
                Path to a local directory containing the configuration file
                plus the UNet / CLIP / autoencoder checkpoints referenced by
                ``model.model_args`` in that configuration.
        """
        super().__init__(model_dir=model_dir, *args, **kwargs)
        self.config = Config.from_file(
            osp.join(model_dir, ModelFile.CONFIGURATION))
        cfg.solver_mode = self.config.model.model_args.solver_mode
        # Copy per-model settings from the configuration onto the shared
        # module-level cfg object.
        cfg.batch_size = self.config.model.model_cfg.batch_size
        cfg.target_fps = self.config.model.model_cfg.target_fps
        cfg.max_frames = self.config.model.model_cfg.max_frames
        cfg.latent_hei = self.config.model.model_cfg.latent_hei
        cfg.latent_wid = self.config.model.model_cfg.latent_wid
        cfg.model_path = osp.join(model_dir,
                                  self.config.model.model_args.ckpt_unet)
        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')
        # Use the configured seed if present (reproducible runs), otherwise a
        # random one.
        if 'seed' in self.config.model.model_args.keys():
            cfg.seed = self.config.model.model_args.seed
        else:
            cfg.seed = random.randint(0, 99999)
        setup_seed(cfg.seed)
        # transform: to-tensor + channel normalization applied to input frames
        vid_trans = data.Compose(
            [data.ToTensor(),
             data.Normalize(mean=cfg.mean, std=cfg.std)])
        self.vid_trans = vid_trans
        # [clip] frozen OpenCLIP text encoder
        cfg.embedder.pretrained = osp.join(
            model_dir, self.config.model.model_args.ckpt_clip)
        clip_encoder = FrozenOpenCLIPEmbedder(
            pretrained=cfg.embedder.pretrained)
        clip_encoder.model.to(self.device)
        self.clip_encoder = clip_encoder
        logger.info(f'Build encoder with {cfg.embedder.type}')
        # [unet]
        generator = Vid2VidSDUNet()
        generator = generator.to(self.device)
        generator.eval()
        load_dict = torch.load(cfg.model_path, map_location='cpu')
        ret = generator.load_state_dict(load_dict['state_dict'], strict=True)
        self.generator = generator
        logger.info('Load model {} path {}, with local status {}'.format(
            cfg.UNet.type, cfg.model_path, ret))
        # [diffusion]
        sigmas = noise_schedule(
            schedule='logsnr_cosine_interp',
            n=1000,
            zero_terminal_snr=True,
            scale_min=2.0,
            scale_max=4.0)
        diffusion = GaussianDiffusion_SDEdit(
            sigmas=sigmas, prediction_type='v')
        self.diffusion = diffusion
        logger.info('Build diffusion with type of GaussianDiffusion_SDEdit')
        # [autoencoder] frozen KL autoencoder used only for encode/decode
        cfg.auto_encoder.pretrained = osp.join(
            model_dir, self.config.model.model_args.ckpt_autoencoder)
        autoencoder = AutoencoderKL(**cfg.auto_encoder)
        autoencoder.eval()
        for param in autoencoder.parameters():
            param.requires_grad = False
        autoencoder.to(self.device)
        self.autoencoder = autoencoder
        torch.cuda.empty_cache()
        # Pre-compute the negative-prompt embedding once for classifier-free
        # guidance; the positive prompt is appended to user text per request.
        negative_prompt = cfg.negative_prompt
        negative_y = clip_encoder(negative_prompt).detach()
        self.negative_y = negative_y
        positive_prompt = cfg.positive_prompt
        self.positive_prompt = positive_prompt
        self.cfg = cfg

    def forward(self, input: Dict[str, Any]):
        r"""
        The entry function of video to video task.
        1. Using CLIP to encode text into embeddings.
        2. Using diffusion model to generate the video's latent representation.
        3. Using autoencoder to decode the video's latent representation to visual space.
        Args:
            input (`Dict[Str, Any]`):
                The input of the task, with keys ``'video_data'`` (normalized
                frame tensor from :attr:`vid_trans`) and ``'y'`` (text
                embedding from :attr:`clip_encoder`).
        Returns:
            A generated video (as pytorch tensor).
        """
        video_data = input['video_data']
        y = input['y']
        cfg = self.cfg
        # Upsample input frames to the working resolution before encoding.
        video_data = F.interpolate(
            video_data, size=(720, 1280), mode='bilinear')
        video_data = video_data.unsqueeze(0)
        video_data = video_data.to(self.device)
        batch_size, frames_num, _, _, _ = video_data.shape
        video_data = rearrange(video_data, 'b f c h w -> (b f) c h w')
        # Encode frames in chunks of 2 to bound autoencoder memory usage.
        # NOTE(review): assumes an even total frame count — confirm upstream.
        video_data_list = torch.chunk(
            video_data, video_data.shape[0] // 2, dim=0)
        with torch.no_grad():
            decode_data = []
            for vd_data in video_data_list:
                encoder_posterior = self.autoencoder.encode(vd_data)
                tmp = get_first_stage_encoding(encoder_posterior).detach()
                decode_data.append(tmp)
            video_data_feature = torch.cat(decode_data, dim=0)
        video_data_feature = rearrange(
            video_data_feature, '(b f) c h w -> b c f h w', b=batch_size)
        with amp.autocast(enabled=True):
            # SDEdit: diffuse the encoded video up to an intermediate noise
            # level, then denoise it guided by the text embedding.
            total_noise_levels = 600
            t = torch.randint(
                total_noise_levels - 1,
                total_noise_levels, (1, ),
                dtype=torch.long).to(self.device)
            noise = torch.randn_like(video_data_feature)
            noised_lr = self.diffusion.diffuse(video_data_feature, t, noise)
            # Positive / negative conditioning pair for classifier-free
            # guidance.
            model_kwargs = [{'y': y}, {'y': self.negative_y}]
            gen_vid = self.diffusion.sample(
                noise=noised_lr,
                model=self.generator,
                model_kwargs=model_kwargs,
                guide_scale=7.5,
                guide_rescale=0.2,
                solver='dpmpp_2m_sde' if cfg.solver_mode == 'fast' else 'heun',
                steps=30 if cfg.solver_mode == 'fast' else 50,
                t_max=total_noise_levels - 1,
                t_min=0,
                discretization='trailing')
            # Undo the latent scale factor and decode chunk-wise, mirroring
            # the encode pass above.
            scale_factor = 0.18215
            vid_tensor_feature = 1. / scale_factor * gen_vid
            vid_tensor_feature = rearrange(vid_tensor_feature,
                                           'b c f h w -> (b f) c h w')
            vid_tensor_feature_list = torch.chunk(
                vid_tensor_feature, vid_tensor_feature.shape[0] // 2, dim=0)
            decode_data = []
            for vd_data in vid_tensor_feature_list:
                tmp = self.autoencoder.decode(vd_data)
                decode_data.append(tmp)
            vid_tensor_gen = torch.cat(decode_data, dim=0)
            gen_video = rearrange(
                vid_tensor_gen, '(b f) c h w -> b c f h w', b=cfg.batch_size)
            return gen_video.type(torch.float32).cpu()

View File

@@ -0,0 +1,140 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
from typing import Any, Dict, Optional
import cv2
import torch
from einops import rearrange
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
    Tasks.video_to_video, module_name=Pipelines.video_to_video_pipeline)
class VideoToVideoPipeline(Pipeline):
    r""" Video To Video Pipeline, generating super-resolution videos based on input
    video and text

    Examples:

    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.outputs import OutputKeys
    >>> # YOUR_VIDEO_PATH: your video url or local position in low resolution
    >>> # INPUT_TEXT: when we do video super-resolution, we will add the text content
    >>> # into results
    >>> # output_video_path: path-to-the-generated-video
    >>> p = pipeline('video-to-video', 'damo/Video-to-Video')
    >>> input = {"video_path":YOUR_VIDEO_PATH, "text": INPUT_TEXT}
    >>> output_video_path = p(input,output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO]
    """

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a video-to-video pipeline for prediction

        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)

    def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
        """Read up to ``cfg.max_frames`` frames and encode the caption.

        Returns:
            Dict with 'video_data' (normalized frame tensor) and 'y' (CLIP
            embedding of the user text plus the model's positive prompt).
        """
        vid_path = input['video_path']
        text = input['text'] if 'text' in input.keys() else ''
        caption = text + self.model.positive_prompt
        y = self.model.clip_encoder(caption).detach()
        max_frames = self.model.cfg.max_frames
        capture = cv2.VideoCapture(vid_path)
        _fps = capture.get(cv2.CAP_PROP_FPS)
        sample_fps = _fps
        _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)
        # stride is 1 while sample_fps mirrors the source fps; the formula is
        # kept for the general case of subsampling to a lower target fps.
        stride = round(_fps / sample_fps)
        start_frame = 0
        pointer = 0
        frame_list = []
        while len(frame_list) < max_frames:
            ret, frame = capture.read()
            pointer += 1
            if (not ret) or (frame is None):
                break
            if pointer < start_frame:
                continue
            if pointer >= _total_frame_num + 1:
                break
            if (pointer - start_frame) % stride == 0:
                frame = LoadImage.convert_to_img(frame)
                frame_list.append(frame)
        capture.release()
        video_data = self.model.vid_trans(frame_list)
        return {'video_data': video_data, 'y': y}

    def forward(self, input: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        """Run the generation model; returns ``{'video': tensor}``."""
        video = self.model(input)
        return {'video': video}

    def postprocess(self, inputs: Dict[str, Any],
                    **post_params) -> Dict[str, Any]:
        """Write generated frames to disk and mux them into an mp4 via ffmpeg.

        If ``output_video`` is not supplied in ``post_params``, a temporary
        file is used and the raw mp4 bytes are returned instead of a path.
        """
        video = tensor2vid(inputs['video'], self.model.cfg.mean,
                           self.model.cfg.std)
        output_video_path = post_params.get('output_video', None)
        temp_video_file = False
        if output_video_path is None:
            output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
            temp_video_file = True
        temp_dir = tempfile.mkdtemp()
        for fid, frame in enumerate(video):
            tpth = os.path.join(temp_dir, '%06d.png' % (fid + 1))
            # Frames are RGB; OpenCV expects BGR, hence the channel reversal.
            # PNG is lossless, so no encoder quality flag is needed (the
            # previous IMWRITE_JPEG_QUALITY flag was meaningless for .png).
            cv2.imwrite(tpth, frame[:, :, ::-1])
        cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8.0 -i {temp_dir}/%06d.png \
            -vcodec libx264 -crf 17 -pix_fmt yuv420p {output_video_path}'
        status = os.system(cmd)
        if status != 0:
            # A failed mux is an error, not routine progress information.
            logger.error('Save Video Error with {}'.format(status))
        os.system(f'rm -rf {temp_dir}')
        if temp_video_file:
            video_file_content = b''
            with open(output_video_path, 'rb') as f:
                video_file_content = f.read()
            os.remove(output_video_path)
            return {OutputKeys.OUTPUT_VIDEO: video_file_content}
        else:
            return {OutputKeys.OUTPUT_VIDEO: output_video_path}
def tensor2vid(video, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """De-normalize a (b, c, f, h, w) video tensor into uint8 frames.

    Args:
        video: tensor of shape (batch, channel, frame, height, width) on the
            CPU; only the first batch item is converted.
        mean / std: per-channel normalization to invert.

    Returns:
        List of (height, width, channel) uint8 numpy arrays in [0, 255].
    """
    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
    # Out-of-place mul/add so the caller's tensor is not mutated (the previous
    # mul_/add_ modified ``video`` in place).
    video = video.mul(std).add(mean).clamp_(0, 1) * 255.0
    # (b, c, f, h, w) -> (f, h, w, c), keeping only the first batch item.
    images = video.permute(0, 2, 3, 4, 1)[0]
    return [img.numpy().astype('uint8') for img in images]

View File

@@ -257,6 +257,7 @@ class MultiModalTasks(object):
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
multimodal_dialogue = 'multimodal-dialogue'
image_to_video = 'image-to-video'
video_to_video = 'video-to-video'
class ScienceTasks(object):

View File

@@ -0,0 +1,32 @@
import sys
import unittest
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class Video2VideoTest(unittest.TestCase):

    def setUp(self) -> None:
        """Configure the task, hub model id and a sample input video URL."""
        self.task = Tasks.video_to_video
        self.model_id = 'damo/Video-to-Video'
        self.path = 'https://video-generation-wulanchabu.oss-cn-wulanchabu.aliyuncs.com/baishao/test.mp4'

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        """Run the hub model end to end and print the resulting video path."""
        pipe = pipeline(task=self.task, model=self.model_id)
        result = pipe(
            {
                'video_path': self.path,
                'text': 'A panda is surfing on the sea'
            },
            output_video='./output.mp4')
        print(result[OutputKeys.OUTPUT_VIDEO])


if __name__ == '__main__':
    unittest.main()