mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
add video2video (#486)
* add video2video * fix bugs of pre-commit * update some files * fix video write module * fix max_frames
This commit is contained in:
@@ -221,6 +221,7 @@ class Models(object):
|
||||
videocomposer = 'videocomposer'
|
||||
text_to_360panorama_image = 'text-to-360panorama-image'
|
||||
image_to_video_model = 'image-to-video-model'
|
||||
video_to_video_model = 'video-to-video-model'
|
||||
|
||||
# science models
|
||||
unifold = 'unifold'
|
||||
@@ -547,6 +548,7 @@ class Pipelines(object):
|
||||
multimodal_dialogue = 'multimodal-dialogue'
|
||||
llama2_text_generation_pipeline = 'llama2-text-generation-pipeline'
|
||||
image_to_video_task_pipeline = 'image-to-video-task-pipeline'
|
||||
video_to_video_pipeline = 'video-to-video-pipeline'
|
||||
|
||||
# science tasks
|
||||
protein_structure = 'unifold-protein-structure'
|
||||
|
||||
24
modelscope/models/multi_modal/video_to_video/__init__.py
Normal file
24
modelscope/models/multi_modal/video_to_video/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    # Static analyzers and type checkers see the real import.
    from .video_to_video_model import VideoToVideo
else:
    # At runtime, replace this module with a lazy proxy so the heavy
    # model code is only imported on first attribute access.
    _import_structure = {
        'video_to_video_model': ['VideoToVideo'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
|
||||
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.

# Re-export the submodule contents as this package's public surface.
from .autoencoder import *  # noqa: F401,F403
from .embedder import *  # noqa: F401,F403
from .unet_v2v import *  # noqa: F401,F403
|
||||
@@ -0,0 +1,590 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def nonlinearity(x):
    """Swish/SiLU activation: ``x * sigmoid(x)``."""
    return torch.sigmoid(x) * x
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
    """Build an affine GroupNorm over `in_channels` with a small eps."""
    return torch.nn.GroupNorm(num_groups, in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
@torch.no_grad()
def get_first_stage_encoding(encoder_posterior):
    """Scale a first-stage encoding by the Stable Diffusion latent factor.

    Accepts either a raw latent tensor or a DiagonalGaussianDistribution
    (which is sampled); anything else is rejected.
    """
    scale_factor = 0.18215
    if isinstance(encoder_posterior, torch.Tensor):
        z = encoder_posterior
    elif isinstance(encoder_posterior, DiagonalGaussianDistribution):
        z = encoder_posterior.sample()
    else:
        raise NotImplementedError(
            f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
        )
    return z * scale_factor
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents.

    Args:
        parameters: tensor of shape (B, 2*C, ...); the first half along
            dim=1 is the mean, the second half the log-variance.
        deterministic: if True, variance is forced to zero so ``sample()``
            returns the mean and ``kl``/``nll`` are zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp before exponentiation for numerical stability.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean).to(device=self.parameters.device)

    def sample(self):
        """Draw one reparameterized sample: mean + std * eps."""
        x = self.mean + self.std * torch.randn(
            self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        """KL divergence to `other` (or to N(0, I) when `other` is None),
        summed over all non-batch dimensions."""
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of `sample`, summed over `dims`.

        Fix: `dims` default changed from a mutable list to an equivalent
        tuple (same values; avoids the shared-mutable-default pitfall).
        """
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar
            + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Distribution mode (equals the mean for a Gaussian)."""
        return self.mean
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> swish -> conv) twice,
    with an optional timestep-embedding projection added between the two
    convolutions and a learned shortcut when channel counts differ.
    """

    def __init__(self,
                 *,
                 in_channels,
                 out_channels=None,
                 conv_shortcut=False,
                 dropout,
                 temb_channels=512):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # A learned shortcut is only needed when the channel count changes.
        if in_channels != out_channels:
            if conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0)

    def forward(self, x, temb):
        h = self.conv1(nonlinearity(self.norm1(x)))

        if temb is not None:
            # Broadcast the projected embedding over the spatial dims.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        shortcut = x
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                shortcut = self.conv_shortcut(x)
            else:
                shortcut = self.nin_shortcut(x)

        return shortcut + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over the H*W positions, with
    1x1-conv q/k/v projections and a residual connection."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # Attention weights over spatial positions: w[b, i, j] = <q_i, k_j>.
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        weights = torch.bmm(q, k) * (int(c)**(-0.5))
        weights = torch.nn.functional.softmax(weights, dim=2)

        # Aggregate the values, project back, and add the residual.
        v = v.reshape(b, c, h * w)
        out = torch.bmm(v, weights.permute(0, 2, 1))
        out = out.reshape(b, c, h, w)

        return x + self.proj_out(out)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
    """2x nearest-neighbor spatial upsampling, optionally followed by a
    3x3 convolution."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        out = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode='nearest')
        if not self.with_conv:
            return out
        return self.conv(out)
|
||||
|
||||
|
||||
class Downsample(nn.Module):
    """2x spatial downsampling: a strided 3x3 conv with manual asymmetric
    padding when `with_conv`, else 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Pad one pixel on the right/bottom only, matching the SD encoder.
        padded = torch.nn.functional.pad(
            x, (0, 1, 0, 1), mode='constant', value=0)
        return self.conv(padded)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Convolutional encoder of the first-stage KL autoencoder.

    Maps an image batch (B, in_channels, H, W) to latent moments of shape
    (B, 2*z_channels if double_z else z_channels, H/2^(L-1), W/2^(L-1)),
    where L = len(ch_mult), via stacks of ResnetBlocks with optional
    spatial attention and a Downsample between resolution levels.
    """

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 double_z=True,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignore_kwargs):
        # NOTE(review): `out_ch`, `use_linear_attn` and `attn_type` are
        # accepted (for config compatibility) but unused in this encoder.
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the first stage
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        # Prepend 1 so level i reads ch*in_ch_mult[i] and writes ch*ch_mult[i].
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                # attach attention only at the configured resolutions
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            # halve the resolution after every level except the last
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle: resblock -> attention -> resblock at the lowest resolution
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)

        # end: project to z_channels (doubled when emitting mean+logvar)
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        """Encode `x` to latent moments (see class docstring for shape)."""
        # timestep embedding (unused in the first stage)
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
    """Convolutional decoder of the first-stage KL autoencoder.

    Mirrors Encoder: maps latents (B, z_channels, h, w) back to images
    (B, out_ch, h*2^(L-1), w*2^(L-1)) with L = len(ch_mult), via mid
    blocks followed by per-level ResnetBlock stacks and Upsample layers.
    """

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 give_pre_end=False,
                 tanh_out=False,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignorekwargs):
        # NOTE(review): `in_channels`, `use_linear_attn` and `attn_type`
        # are accepted for config symmetry with Encoder but unused here.
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the first stage
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logger.info('Working with z of shape {} = {} dimensions.'.format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle: resblock -> attention -> resblock at the lowest resolution
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)

        # upsampling (built from the lowest resolution up; each level has
        # num_res_blocks + 1 blocks, one more than the encoder side)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            # prepend so that self.up is indexed in level order
            self.up.insert(0, up)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        """Decode latents `z` to an image (or pre-activation feature map
        when `give_pre_end`; squashed by tanh when `tanh_out`)."""
        self.last_z_shape = z.shape

        # timestep embedding (unused in the first stage)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
|
||||
|
||||
|
||||
class AutoencoderKL(nn.Module):
    """KL-regularized convolutional autoencoder (Stable Diffusion first
    stage): Encoder -> DiagonalGaussianDistribution -> Decoder, with 1x1
    quant/post-quant convs mapping between moments and the embed space.
    """

    def __init__(self,
                 ddconfig,
                 embed_dim,
                 pretrained=None,
                 ignore_keys=[],
                 image_key='image',
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False,
                 **kwargs):
        # NOTE(review): `ignore_keys=[]` is a mutable default (never
        # mutated here, but a tuple would be safer); it is also forwarded
        # to init_from_ckpt, which ignores it — confirm intent upstream.
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig['double_z']
        # 2*z_channels moments (mean+logvar) -> 2*embed_dim, and back.
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
                                          2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig['z_channels'], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            # random projection used by to_rgb for segmentation inputs
            self.register_buffer('colorize',
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None

        if pretrained is not None:
            self.init_from_ckpt(pretrained, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load first-stage weights from a Lightning-style checkpoint,
        stripping the 'first_stage_model.' key prefix.

        NOTE(review): `ignore_keys` is accepted but never applied here —
        keys are filtered only by the 'first_stage_model' substring.
        """
        sd = torch.load(path, map_location='cpu')['state_dict']
        keys = list(sd.keys())
        sd_new = collections.OrderedDict()
        for k in keys:
            if k.find('first_stage_model') >= 0:
                k_new = k.split('first_stage_model.')[-1]
                sd_new[k_new] = sd[k]
        self.load_state_dict(sd_new, strict=True)
        logger.info(f'Restored from {path}')

    def on_train_batch_end(self, *args, **kwargs):
        # NOTE(review): `self.model_ema` is never defined in this class;
        # this would raise AttributeError if use_ema is True — confirm.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode `x` to a DiagonalGaussianDistribution over latents."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent `z` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full round trip; returns (reconstruction, posterior)."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch batch[k] and convert HWC (adding a channel dim if absent)
        to contiguous float NCHW."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1,
                      2).to(memory_format=torch.contiguous_format).float()
        return x

    def get_last_layer(self):
        """Weights of the final decoder conv (used e.g. for loss balancing)."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Build a dict of input/reconstruction/sample tensors for logging.

        NOTE(review): relies on `self.device` and `self.ema_scope`, which
        are not defined in this class (they exist on LightningModule) —
        likely dead code after the port; confirm before use.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log['samples'] = self.decode(torch.randn_like(posterior.sample()))
            log['reconstructions'] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log['samples_ema'] = self.decode(
                        torch.randn_like(posterior_ema.sample()))
                    log['reconstructions_ema'] = xrec_ema
        log['inputs'] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to 3 channels with a
        fixed random 1x1 conv, rescaled to [-1, 1]."""
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize',
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: passes data through unchanged.

    Drop-in replacement for a VAE when diffusion runs directly in
    pixel/feature space. With ``vq_interface=True``, ``quantize`` mimics
    a VQ model's ``(quantized, loss, info)`` return signature.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Fix: initialize nn.Module before setting attributes (the
        # original assigned `vq_interface` first; safe for plain
        # attributes, but super().__init__() first is the convention).
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        """Identity encode."""
        return x

    def decode(self, x, *args, **kwargs):
        """Identity decode."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Identity quantize; mimics the VQ return tuple when requested."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Identity forward."""
        return x
|
||||
@@ -0,0 +1,76 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import open_clip
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(nn.Module):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    # Which hidden state to return: final layer or the one before it.
    LAYERS = ['last', 'penultimate']

    def __init__(self,
                 pretrained,
                 arch='ViT-H-14',
                 device='cuda',
                 max_length=77,
                 freeze=True,
                 layer='penultimate'):
        """Build the text tower of an OpenCLIP model.

        Args:
            pretrained: path/tag passed to open_clip for weights.
            arch: OpenCLIP architecture name.
            device: device that tokenized text is moved to in forward().
            max_length: tokenizer context length (kept for reference).
            freeze: put the model in eval mode and disable grads.
            layer: 'last' or 'penultimate' hidden state to return.
        """
        super().__init__()
        assert layer in self.LAYERS
        model, _, preprocess = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=pretrained)

        # Only the text tower is needed; drop the vision tower to save memory.
        del model.visual
        self.model = model
        self.device = device
        self.max_length = max_length

        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == 'last':
            self.layer_idx = 0
        elif self.layer == 'penultimate':
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Switch to eval mode and stop gradient flow through all params."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize `text` (str or list of str) and return token embeddings."""
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        """Run the CLIP text transformer up to the configured layer."""
        x = self.model.token_embedding(text)
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND (transformer expects seq-first)
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Apply transformer resblocks, stopping `layer_idx` blocks early."""
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
            ):
                # Fix: `checkpoint` was an undefined name in the original
                # (NameError whenever grad checkpointing is enabled).
                from torch.utils.checkpoint import checkpoint
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        """Alias for forward(), matching the embedder interface."""
        return self(text)
|
||||
1530
modelscope/models/multi_modal/video_to_video/modules/unet_v2v.py
Normal file
1530
modelscope/models/multi_modal/video_to_video/modules/unet_v2v.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
171
modelscope/models/multi_modal/video_to_video/utils/config.py
Normal file
171
modelscope/models/multi_modal/video_to_video/utils/config.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
from easydict import EasyDict
|
||||
|
||||
# Global configuration for the VideoLDM (video-to-video) decoder pipeline.
cfg = EasyDict(__name__='Config: VideoLDM Decoder')

# ---------------------------work dir--------------------------
cfg.work_dir = 'workspace/'

# ---------------------------Global Variable-----------------------------------
cfg.resolution = [448, 256]  # output (width, height)
cfg.max_frames = 32  # frames processed per clip
# -----------------------------------------------------------------------------

# ---------------------------Dataset Parameter---------------------------------
cfg.mean = [0.5, 0.5, 0.5]  # per-channel normalization mean
cfg.std = [0.5, 0.5, 0.5]  # per-channel normalization std
cfg.max_words = 1000

# PlaceHolder
cfg.vit_out_dim = 1024
cfg.vit_resolution = [224, 224]
cfg.depth_clamp = 10.0
cfg.misc_size = 384
cfg.depth_std = 20.0

cfg.frame_lens = 32
cfg.sample_fps = 8

cfg.batch_sizes = 1
# -----------------------------------------------------------------------------

# ---------------------------Mode Parameters-----------------------------------
# Diffusion
cfg.schedule = 'cosine'
cfg.num_timesteps = 1000
cfg.mean_type = 'v'  # model predicts v (velocity parameterization)
cfg.var_type = 'fixed_small'
cfg.loss_type = 'mse'
cfg.ddim_timesteps = 50
cfg.ddim_eta = 0.0
cfg.clamp = 1.0
cfg.share_noise = False
cfg.use_div_loss = False
cfg.noise_strength = 0.1

# classifier-free guidance
cfg.p_zero = 0.1  # probability of dropping conditioning during training
cfg.guide_scale = 3.0

# clip vision encoder
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]

# Model
cfg.scale_factor = 0.18215  # SD latent scaling factor
cfg.use_fp16 = True
cfg.temporal_attention = True
cfg.decoder_bs = 8

cfg.UNet = {
    'type': 'Vid2VidSDUNet',
    'in_dim': 4,
    'dim': 320,
    'y_dim': cfg.vit_out_dim,
    'context_dim': 1024,
    # learned-variance models emit mean and variance channels
    'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
    'dim_mult': [1, 2, 4, 4],
    'num_heads': 8,
    'head_dim': 64,
    'num_res_blocks': 2,
    'attn_scales': [1 / 1, 1 / 2, 1 / 4],
    'dropout': 0.1,
    'temporal_attention': cfg.temporal_attention,
    'temporal_attn_times': 1,
    'use_checkpoint': False,
    'use_fps_condition': False,
    'use_sim_mask': False,
    'num_tokens': 4,
    'default_fps': 8,
    'input_dim': 1024
}

cfg.guidances = []

# autoencoder from Stable Diffusion
cfg.auto_encoder = {
    'type': 'AutoencoderKL',
    'ddconfig': {
        'double_z': True,
        'z_channels': 4,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0
    },
    'embed_dim': 4,
    'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
}
# clip embedder
cfg.embedder = {
    'type': 'FrozenOpenCLIPEmbedder',
    'layer': 'penultimate',
    'vit_resolution': [224, 224],
    'pretrained': 'open_clip_pytorch_model.bin'
}
# -----------------------------------------------------------------------------

# ---------------------------Training Settings---------------------------------
# training and optimizer
cfg.ema_decay = 0.9999
cfg.num_steps = 600000
cfg.lr = 5e-5
cfg.weight_decay = 0.0
cfg.betas = (0.9, 0.999)
cfg.eps = 1.0e-8
cfg.chunk_size = 16
cfg.alpha = 0.7
cfg.save_ckp_interval = 1000
# -----------------------------------------------------------------------------

# ----------------------------Pretrain Settings---------------------------------
# Default: load 2d pretrain
cfg.fix_weight = False
cfg.load_match = False
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
# -----------------------------------------------------------------------------

# -----------------------------Visual-------------------------------------------
# Visual videos
cfg.viz_interval = 1000
cfg.visual_train = {
    'type': 'VisualVideoTextDuringTrain',
}
cfg.visual_inference = {
    'type': 'VisualGeneratedVideos',
}
cfg.inference_list_path = ''

# logging
cfg.log_interval = 100

# Default log_dir
cfg.log_dir = 'workspace/output_data'
# -----------------------------------------------------------------------------

# ---------------------------Others--------------------------------------------
# seed
cfg.seed = 8888
cfg.negative_prompt = 'worst quality, normal quality, low quality, low res, blurry, text, \
watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, \
sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting'

cfg.positive_prompt = ', cinematic, High Contrast, highly detailed, unreal engine, \
taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, \
32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, \
hyper sharpness, perfect without deformations, Unreal Engine 5, 4k render'

# -----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,247 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
from .schedules_sdedit import karras_schedule
|
||||
from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun
|
||||
|
||||
__all__ = ['GaussianDiffusion_SDEdit']
|
||||
|
||||
|
||||
def _i(tensor, t, x):
    """Index `tensor` by timesteps `t`, reshaped to broadcast over `x`.

    Returns tensor[t] viewed as (batch, 1, 1, ...) on x's device.
    """
    idx = t.to(tensor.device)
    out = tensor[idx].view((x.size(0), ) + (1, ) * (x.ndim - 1))
    return out.to(x.device)
|
||||
|
||||
|
||||
class GaussianDiffusion_SDEdit(object):
|
||||
|
||||
def __init__(self, sigmas, prediction_type='eps'):
|
||||
assert prediction_type in {'x0', 'eps', 'v'}
|
||||
self.sigmas = sigmas
|
||||
self.alphas = torch.sqrt(1 - sigmas**2)
|
||||
self.num_timesteps = len(sigmas)
|
||||
self.prediction_type = prediction_type
|
||||
|
||||
def diffuse(self, x0, t, noise=None):
|
||||
noise = torch.randn_like(x0) if noise is None else noise
|
||||
xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
|
||||
return xt
|
||||
|
||||
    def denoise(self,
                xt,
                t,
                s,
                model,
                model_kwargs={},
                guide_scale=None,
                guide_rescale=None,
                clamp=None,
                percentile=None):
        """One reverse-diffusion step from timestep `t` toward `s`.

        Args:
            xt: noisy latents at timestep t.
            t: timestep indices (long tensor, per batch element).
            s: target timestep indices; defaults to t - 1 when None.
            model: denoising network called as model(xt, t=t, **kwargs).
            model_kwargs: dict (unguided) or [cond_kwargs, uncond_kwargs]
                when classifier-free guidance is used.
                NOTE(review): mutable default `{}` — never mutated here.
            guide_scale: classifier-free guidance scale (None = off).
            guide_rescale: optional std-matching rescale factor in [0, 1].
            clamp: symmetric clamp applied to the predicted x0.
            percentile: dynamic-threshold percentile in (0, 1] for x0.

        Returns:
            (mu, var, log_var, x0, eps) of the posterior q(x_s | x_t, x0).
        """
        s = t - 1 if s is None else s

        # hyperparams (alphas_s forced to 1 where s < 0, i.e. the final step)
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        alphas_s = _i(self.alphas, s.clamp(0), xt)
        alphas_s[s < 0] = 1.
        sigmas_s = torch.sqrt(1 - alphas_s**2)

        # precompute posterior coefficients: mu = coef1*x0 + coef2*xt
        betas = 1 - (alphas / alphas_s)**2
        coef1 = betas * alphas_s / sigmas**2
        coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
        var = betas * (sigmas_s / sigmas)**2
        log_var = torch.log(var).clamp_(-20, 20)

        # prediction
        if guide_scale is None:
            assert isinstance(model_kwargs, dict)
            out = model(xt, t=t, **model_kwargs)
        else:
            # classifier-free guidance: conditional + unconditional passes
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            y_out = model(xt, t=t, **model_kwargs[0])
            if guide_scale == 1.:
                out = y_out
            else:
                u_out = model(xt, t=t, **model_kwargs[1])
                out = u_out + guide_scale * (y_out - u_out)

            if guide_rescale is not None:
                assert guide_rescale >= 0 and guide_rescale <= 1
                # rescale the guided output's std toward the conditional one
                ratio = (
                    y_out.flatten(1).std(dim=1) /  # noqa
                    (out.flatten(1).std(dim=1) + 1e-12)
                ).view((-1, ) + (1, ) * (y_out.ndim - 1))
                out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0

        # compute x0 from the network output per parameterization
        if self.prediction_type == 'x0':
            x0 = out
        elif self.prediction_type == 'eps':
            x0 = (xt - sigmas * out) / alphas
        elif self.prediction_type == 'v':
            x0 = alphas * xt - sigmas * out
        else:
            raise NotImplementedError(
                f'prediction_type {self.prediction_type} not implemented')

        # restrict the range of x0 (dynamic thresholding or hard clamp)
        if percentile is not None:
            assert percentile > 0 and percentile <= 1
            s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
            s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)

        # recompute eps using the restricted x0
        eps = (xt - alphas * x0) / sigmas

        # compute mu (mean of posterior distribution) using the restricted x0
        mu = coef1 * x0 + coef2 * xt
        return mu, var, log_var, x0, eps
|
||||
|
||||
@torch.no_grad()
def sample(self,
           noise,
           model,
           model_kwargs={},
           condition_fn=None,
           guide_scale=None,
           guide_rescale=None,
           clamp=None,
           percentile=None,
           solver='euler_a',
           steps=20,
           t_max=None,
           t_min=None,
           discretization=None,
           discard_penultimate_step=None,
           return_intermediate=None,
           show_progress=False,
           seed=-1,
           **kwargs):
    """Draw a sample from the diffusion model, starting from `noise`.

    Args:
        noise: starting latent tensor (scaled by the solver to sigma[0]).
        model: denoising network, called as ``model(xt, t=t, **model_kwargs)``.
        model_kwargs: a dict, or a ``[cond_kwargs, uncond_kwargs]`` pair when
            ``guide_scale`` is set (classifier-free guidance).
            NOTE(review): mutable default ``{}`` kept for interface
            compatibility; it is never mutated here.
        condition_fn: accepted but unused in this implementation.
        guide_scale / guide_rescale / clamp / percentile: forwarded to
            ``self.denoise`` to control guidance and x0 range restriction.
        solver: solver name; only 'heun' and 'dpmpp_2m_sde' are registered
            below. NOTE(review): the default 'euler_a' is NOT in the table
            and would raise KeyError -- callers must pass a valid name.
        steps: number of steps (int) or an explicit timestep LongTensor.
        t_max / t_min: inclusive timestep bounds for discretization.
        discretization: 'leading' | 'linspace' | 'trailing' (None -> linspace).
        discard_penultimate_step: drop the second-to-last sigma (defaulted
            per solver family when None).
        return_intermediate: None | 'x0' | 'xt' to also collect per-step
            intermediates.
        show_progress: show a tqdm progress bar in the solver.
        seed: default-filled when negative, but otherwise unused here.

    Returns:
        The final sample, or ``(sample, intermediates)`` when
        ``return_intermediate`` is set.
    """
    # sanity check
    assert isinstance(steps, (int, torch.LongTensor))
    assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
    assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
    assert discretization in (None, 'leading', 'linspace', 'trailing')
    assert discard_penultimate_step in (None, True, False)
    assert return_intermediate in (None, 'x0', 'xt')

    # function of diffusion solver
    solver_fn = {
        'heun': sample_heun,
        'dpmpp_2m_sde': sample_dpmpp_2m_sde
    }[solver]

    # options
    schedule = 'karras' if 'karras' in solver else None
    discretization = discretization or 'linspace'
    seed = seed if seed >= 0 else random.randint(0, 2**31)
    # An explicit timestep tensor is used as-is; never drop steps from it.
    if isinstance(steps, torch.LongTensor):
        discard_penultimate_step = False
    if discard_penultimate_step is None:
        discard_penultimate_step = True if solver in (
            'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
            'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False

    # function for denoising xt to get x0
    intermediates = []

    def model_fn(xt, sigma):
        # denoising: map the solver's sigma back to a (rounded) timestep
        # index, then take the restricted x0 estimate ([-2] of denoise()).
        t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
        x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
                          guide_rescale, clamp, percentile)[-2]

        # collect intermediate outputs
        if return_intermediate == 'xt':
            intermediates.append(xt)
        elif return_intermediate == 'x0':
            intermediates.append(x0)
        return x0

    # get timesteps
    if isinstance(steps, int):
        steps += 1 if discard_penultimate_step else 0
        t_max = self.num_timesteps - 1 if t_max is None else t_max
        t_min = 0 if t_min is None else t_min

        # discretize timesteps (descending from t_max to t_min)
        if discretization == 'leading':
            steps = torch.arange(t_min, t_max + 1,
                                 (t_max - t_min + 1) / steps).flip(0)
        elif discretization == 'linspace':
            steps = torch.linspace(t_max, t_min, steps)
        elif discretization == 'trailing':
            steps = torch.arange(t_max, t_min - 1,
                                 -((t_max - t_min + 1) / steps))
        else:
            raise NotImplementedError(
                f'{discretization} discretization not implemented')
        steps = steps.clamp_(t_min, t_max)
    steps = torch.as_tensor(
        steps, dtype=torch.float32, device=noise.device)

    # get sigmas; a trailing zero sigma makes the last step a full denoise
    sigmas = self._t_to_sigma(steps)
    sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
    if schedule == 'karras':
        # Rebuild sigmas on the Karras grid, preserving an inf head if any.
        if sigmas[0] == float('inf'):
            sigmas = karras_schedule(
                n=len(steps) - 1,
                sigma_min=sigmas[sigmas > 0].min().item(),
                sigma_max=sigmas[sigmas < float('inf')].max().item(),
                rho=7.).to(sigmas)
            sigmas = torch.cat([
                sigmas.new_tensor([float('inf')]), sigmas,
                sigmas.new_zeros([1])
            ])
        else:
            sigmas = karras_schedule(
                n=len(steps),
                sigma_min=sigmas[sigmas > 0].min().item(),
                sigma_max=sigmas.max().item(),
                rho=7.).to(sigmas)
            sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
    if discard_penultimate_step:
        sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

    # sampling
    x0 = solver_fn(
        noise, model_fn, sigmas, show_progress=show_progress, **kwargs)
    return (x0, intermediates) if return_intermediate is not None else x0
|
||||
|
||||
def _sigma_to_t(self, sigma):
|
||||
if sigma == float('inf'):
|
||||
t = torch.full_like(sigma, len(self.sigmas) - 1)
|
||||
else:
|
||||
log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
|
||||
(1 - self.sigmas**2)).log().to(sigma)
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma - log_sigmas[:, None]
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
|
||||
max=log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
low, high = log_sigmas[low_idx], log_sigmas[high_idx]
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = w.clamp(0, 1)
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.view(sigma.shape)
|
||||
if t.ndim == 0:
|
||||
t = t.unsqueeze(0)
|
||||
return t
|
||||
|
||||
def _t_to_sigma(self, t):
|
||||
t = t.float()
|
||||
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
|
||||
log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
|
||||
(1 - self.sigmas**2)).log().to(t)
|
||||
log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
|
||||
log_sigma[torch.isnan(log_sigma)
|
||||
| torch.isinf(log_sigma)] = float('inf')
|
||||
return log_sigma.exp()
|
||||
@@ -0,0 +1,85 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def betas_to_sigmas(betas):
    """Convert per-step betas to cumulative noise levels (sigmas)."""
    alphas_bar = torch.cumprod(1 - betas, dim=0)
    return torch.sqrt(1 - alphas_bar)
|
||||
|
||||
|
||||
def sigmas_to_betas(sigmas):
    """Invert betas_to_sigmas: recover per-step betas from cumulative sigmas."""
    alphas_bar = 1 - sigmas**2
    # Ratio of consecutive cumulative alphas gives the per-step alpha.
    step_alphas = torch.cat([alphas_bar[:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1 - step_alphas
|
||||
|
||||
|
||||
def logsnrs_to_sigmas(logsnrs):
    """Map log-SNR values to sigmas via sigma^2 = sigmoid(-logsnr)."""
    sigma_sq = torch.sigmoid(-logsnrs)
    return sigma_sq.sqrt()
|
||||
|
||||
|
||||
def sigmas_to_logsnrs(sigmas):
    """Inverse of logsnrs_to_sigmas: logsnr = log(sigma^2 / (1 - sigma^2))."""
    s2 = sigmas**2
    return torch.log(s2 / (1 - s2))
|
||||
|
||||
|
||||
def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
|
||||
t_min = math.atan(math.exp(-0.5 * logsnr_min))
|
||||
t_max = math.atan(math.exp(-0.5 * logsnr_max))
|
||||
t = torch.linspace(1, 0, n)
|
||||
logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
|
||||
return logsnrs
|
||||
|
||||
|
||||
def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
    """Cosine log-SNR schedule shifted by the resolution scale factor
    (shift of 2*log(1/scale), following shifted-cosine schedules)."""
    base = _logsnr_cosine(n, logsnr_min, logsnr_max)
    return base + 2 * math.log(1 / scale)
|
||||
|
||||
|
||||
def _logsnr_cosine_interp(n,
                          logsnr_min=-15,
                          logsnr_max=15,
                          scale_min=2,
                          scale_max=4):
    """Linearly interpolate between two shifted cosine log-SNR schedules,
    weighting scale_min more heavily at the start."""
    weights = torch.linspace(1, 0, n)
    lo = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
    hi = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
    return weights * lo + (1 - weights) * hi
|
||||
|
||||
|
||||
def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras et al. (2022) sigma schedule (rho-spaced, ascending), remapped
    into a variance-preserving range via sigma / sqrt(1 + sigma^2)."""
    ramp = torch.linspace(1, 0, n)
    lo = sigma_min**(1 / rho)
    hi = sigma_max**(1 / rho)
    raw = (hi + ramp * (lo - hi))**rho
    return torch.sqrt(raw**2 / (1 + raw**2))
|
||||
|
||||
|
||||
def logsnr_cosine_interp_schedule(n,
                                  logsnr_min=-15,
                                  logsnr_max=15,
                                  scale_min=2,
                                  scale_max=4):
    """Sigmas for the interpolated shifted-cosine log-SNR schedule."""
    logsnrs = _logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min,
                                    scale_max)
    return logsnrs_to_sigmas(logsnrs)
|
||||
|
||||
|
||||
def noise_schedule(schedule='logsnr_cosine_interp',
                   n=1000,
                   zero_terminal_snr=False,
                   **kwargs):
    """Build a named sigma schedule of length n.

    With zero_terminal_snr, sigmas are linearly rescaled so the largest
    sigma becomes exactly 1.0 (the terminal step is pure noise).
    """
    builders = {'logsnr_cosine_interp': logsnr_cosine_interp_schedule}
    sigmas = builders[schedule](n, **kwargs)

    # post-processing: stretch [min, max] onto [min, 1.0]
    if zero_terminal_snr and sigmas.max() != 1.0:
        lo = sigmas.min()
        scale = (1.0 - lo) / (sigmas.max() - lo)
        sigmas = lo + scale * (sigmas - lo)
    return sigmas
|
||||
14
modelscope/models/multi_modal/video_to_video/utils/seed.py
Normal file
14
modelscope/models/multi_modal/video_to_video/utils/seed.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def setup_seed(seed):
    """Seed every RNG in use (python, numpy, torch CPU and CUDA) so runs
    are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels trade some speed for reproducibility.
    torch.backends.cudnn.deterministic = True
|
||||
@@ -0,0 +1,194 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch
|
||||
import torchsde
|
||||
from tqdm.auto import trange
|
||||
|
||||
|
||||
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """
    Split an ancestral sampling step into the noise level to step down to
    (sigma_down) and the amount of fresh noise to inject (sigma_up).
    """
    if not eta:
        # Fully deterministic step: no noise is added.
        return sigma_to, 0.
    var_ratio = (sigma_from**2 - sigma_to**2) / sigma_from**2
    sigma_up = min(sigma_to, eta * (sigma_to**2 * var_ratio)**0.5)
    # Deterministic part keeps total variance equal to sigma_to**2.
    sigma_down = (sigma_to**2 - sigma_up**2)**0.5
    return sigma_down, sigma_up
|
||||
|
||||
|
||||
def get_scalings(sigma):
    """Karras preconditioning scalings: (output scale c_out, input scale
    c_in) for a denoiser evaluated at noise level sigma."""
    return -sigma, 1 / (sigma**2 + 1.)**0.5
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_heun(noise,
                model,
                sigmas,
                s_churn=0.,
                s_tmin=0.,
                s_tmax=float('inf'),
                s_noise=1.,
                show_progress=True):
    """
    Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    Args:
        noise: starting noise tensor; scaled by sigmas[0] to form x.
        model: denoiser called as ``model(x_scaled, sigma)`` -> x0 estimate.
        sigmas: decreasing noise levels; may start at inf and end at 0.
        s_churn / s_tmin / s_tmax / s_noise: stochastic "churn" controls --
            temporarily raise sigma by injecting noise before each step.
        show_progress: show a tqdm progress bar.
    """
    x = noise * sigmas[0]
    for i in trange(len(sigmas) - 1, disable=not show_progress):
        gamma = 0.
        if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'):
            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
        eps = torch.randn_like(x) * s_noise
        # Churn: bump sigma up to sigma_hat and add the matching noise.
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
        if sigmas[i] == float('inf'):
            # Euler method (first step from infinite sigma uses raw noise)
            denoised = model(noise, sigma_hat)
            x = denoised + sigmas[i + 1] * (gamma + 1) * noise
        else:
            _, c_in = get_scalings(sigma_hat)
            denoised = model(x * c_in, sigma_hat)
            d = (x - denoised) / sigma_hat
            dt = sigmas[i + 1] - sigma_hat
            if sigmas[i + 1] == 0:
                # Euler method (no second evaluation possible at sigma 0)
                x = x + d * dt
            else:
                # Heun's method: average the slopes at both endpoints.
                x_2 = x + d * dt
                _, c_in = get_scalings(sigmas[i + 1])
                denoised_2 = model(x_2 * c_in, sigmas[i + 1])
                d_2 = (x_2 - denoised_2) / sigmas[i + 1]
                d_prime = (d + d_2) / 2
                x = x + d_prime * dt
    return x
|
||||
|
||||
|
||||
class BatchedBrownianTree:
    """
    A wrapper around torchsde.BrownianTree that enables batches of entropy.

    With a list of seeds, each batch item gets its own independently
    seeded tree; a single int (or no) seed yields one shared tree.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        # BrownianTree requires t0 < t1; remember the sign so increments
        # can be flipped back when queried in the opposite direction.
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()
        self.batched = True
        try:
            # A sequence of seeds -> one tree per batch item (len() works).
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            # len() raised: a single int seed -> one tree, unbatched output.
            seed = [seed]
            self.batched = False
        self.trees = [
            torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs)
            for s in seed
        ]

    @staticmethod
    def sort(a, b):
        # Return (lo, hi, sign) where sign records whether inputs swapped.
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        # Combined sign restores the caller's requested direction.
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (
            self.sign * sign)
        return w if self.batched else w[0]
|
||||
|
||||
|
||||
class BrownianTreeNoiseSampler:
    """
    A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): tensor whose shape, device and dtype to use for the
            generated samples.
        sigma_min (float): low end of the valid noise-level interval.
        sigma_max (float): high end of the valid noise-level interval.
        seed (int or List[int]): random seed; a list gives one tree per
            batch item, each with its own seed.
        transform (callable): maps a sigma to the sampler's internal
            timestep (identity by default).
    """

    def __init__(self,
                 x,
                 sigma_min,
                 sigma_max,
                 seed=None,
                 transform=lambda x: x):
        self.transform = transform
        lo = self.transform(torch.as_tensor(sigma_min))
        hi = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, lo, hi, seed)

    def __call__(self, sigma, sigma_next):
        start = self.transform(torch.as_tensor(sigma))
        end = self.transform(torch.as_tensor(sigma_next))
        # Normalize the Brownian increment to unit variance per step.
        return self.tree(start, end) / (end - start).abs().sqrt()
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_dpmpp_2m_sde(noise,
                        model,
                        sigmas,
                        eta=1.,
                        s_noise=1.,
                        solver_type='midpoint',
                        show_progress=True):
    """
    DPM-Solver++ (2M) SDE.

    Args:
        noise: starting noise tensor; scaled by sigmas[0] to form x.
        model: denoiser called as ``model(x_scaled, sigma)`` -> x0 estimate.
        sigmas: decreasing noise levels; may start at inf, must end at 0.
        eta: stochasticity strength (0 gives the deterministic ODE solver).
        s_noise: multiplier on the injected noise.
        solver_type: 'midpoint' or 'heun' second-order correction.
        show_progress: show a tqdm progress bar.
    """
    assert solver_type in {'heun', 'midpoint'}

    x = noise * sigmas[0]
    # Brownian-tree noise gives time-consistent, reproducible randomness.
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[
        sigmas < float('inf')].max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
    old_denoised = None
    h_last = None

    for i in trange(len(sigmas) - 1, disable=not show_progress):
        if sigmas[i] == float('inf'):
            # Euler method (step down from infinite sigma using raw noise)
            denoised = model(noise, sigmas[i])
            x = denoised + sigmas[i + 1] * noise
        else:
            _, c_in = get_scalings(sigmas[i])
            denoised = model(x * c_in, sigmas[i])
            if sigmas[i + 1] == 0:
                # Denoising step: final output is the model's x0 estimate.
                x = denoised
            else:
                # DPM-Solver++(2M) SDE in log-sigma time t = -log(sigma).
                t, s = -sigmas[i].log(), -sigmas[i + 1].log()
                h = s - t
                eta_h = eta * h

                x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
                    (-h - eta_h).expm1().neg() * denoised

                # Second-order correction uses the previous x0 estimate.
                if old_denoised is not None:
                    r = h_last / h
                    if solver_type == 'heun':
                        x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
                            (1 / r) * (denoised - old_denoised)
                    elif solver_type == 'midpoint':
                        x = x + 0.5 * (-h - eta_h).expm1().neg() * \
                            (1 / r) * (denoised - old_denoised)

                # Stochastic part of the SDE step.
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[
                    i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

            old_denoised = denoised
            h_last = h
    return x
|
||||
404
modelscope/models/multi_modal/video_to_video/utils/transforms.py
Normal file
404
modelscope/models/multi_modal/video_to_video/utils/transforms.py
Normal file
@@ -0,0 +1,404 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
__all__ = [
|
||||
'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2',
|
||||
'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',
|
||||
'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize',
|
||||
'ResizeRandomCrop', 'ExtractResizeRandomCrop', 'ExtractResizeAssignCrop'
|
||||
]
|
||||
|
||||
|
||||
class Compose(object):
    """Chain several transforms; supports indexing, slicing and len()."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __getitem__(self, index):
        # Slicing yields a new Compose; a plain index yields the transform.
        if isinstance(index, slice):
            return Compose(self.transforms[index])
        return self.transforms[index]

    def __len__(self):
        return len(self.transforms)

    def __call__(self, rgb):
        out = rgb
        for transform in self.transforms:
            out = transform(out)
        return out
|
||||
|
||||
|
||||
class Resize(object):
    """Bilinear-resize a PIL image, or a list of frames, to a fixed size."""

    def __init__(self, size=256):
        # An int means a square target.
        self.size = (size, size) if isinstance(size, int) else size

    def __call__(self, rgb):
        if isinstance(rgb, list):
            return [frame.resize(self.size, Image.BILINEAR) for frame in rgb]
        return rgb.resize(self.size, Image.BILINEAR)
|
||||
|
||||
|
||||
class Rescale(object):
    """Scale frames so the shorter side equals `size`, keeping aspect."""

    def __init__(self, size=256, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, rgb):
        w, h = rgb[0].size
        factor = self.size / min(w, h)
        target = (int(round(w * factor)), int(round(h * factor)))
        return [frame.resize(target, self.interpolation) for frame in rgb]
|
||||
|
||||
|
||||
class CenterCrop(object):
    """Crop a centered square of side `size` from every frame."""

    def __init__(self, size=224):
        self.size = size

    def __call__(self, rgb):
        w, h = rgb[0].size
        # Frames must already be at least `size` on the short side.
        assert min(w, h) >= self.size
        left = (w - self.size) // 2
        top = (h - self.size) // 2
        box = (left, top, left + self.size, top + self.size)
        return [frame.crop(box) for frame in rgb]
|
||||
|
||||
|
||||
class ResizeRandomCrop(object):
    """Downscale so the short side is `size_short`, then take one random
    `size` x `size` crop shared by all frames."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):
        # Cheap coarse halving with a box filter before the final bicubic
        # resize; the same geometry is applied to every frame.
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in rgb
            ]
        factor = self.size_short / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(factor * frame.width), round(factor * frame.height)),
                resample=Image.BICUBIC) for frame in rgb
        ]
        w, h = rgb[0].size
        left = random.randint(0, w - self.size)
        top = random.randint(0, h - self.size)
        box = (left, top, left + self.size, top + self.size)
        return [frame.crop(box) for frame in rgb]
|
||||
|
||||
|
||||
class ExtractResizeRandomCrop(object):
    """Like ResizeRandomCrop, but also returns the crop box used so it can
    be re-applied elsewhere (see ExtractResizeAssignCrop)."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):
        # Coarse halving then bicubic resize to `size_short` short side.
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in rgb
            ]
        factor = self.size_short / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(factor * frame.width), round(factor * frame.height)),
                resample=Image.BICUBIC) for frame in rgb
        ]
        w, h = rgb[0].size
        left = random.randint(0, w - self.size)
        top = random.randint(0, h - self.size)
        box = [left, top, left + self.size, top + self.size]
        cropped = [frame.crop(tuple(box)) for frame in rgb]
        return cropped, box
|
||||
|
||||
|
||||
class ExtractResizeAssignCrop(object):
    """Apply a crop box computed elsewhere (e.g. by
    ExtractResizeRandomCrop) after the same downscaling, then resize the
    crops to `size` x `size`."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb, wh):
        # Reproduce the ExtractResizeRandomCrop scaling so `wh` lines up.
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in rgb
            ]
        factor = self.size_short / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(factor * frame.width), round(factor * frame.height)),
                resample=Image.BICUBIC) for frame in rgb
        ]
        rgb = [frame.crop(wh) for frame in rgb]
        return [
            frame.resize((self.size, self.size), Image.BILINEAR)
            for frame in rgb
        ]
|
||||
|
||||
|
||||
class CenterCropV2(object):
    """Rescale so the short side equals `size`, then center-crop a square."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        # fast resize: halve with a box filter until close to target
        while min(img[0].size) >= 2 * self.size:
            img = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in img
            ]
        factor = self.size / min(img[0].size)
        img = [
            frame.resize(
                (round(factor * frame.width), round(factor * frame.height)),
                resample=Image.BICUBIC) for frame in img
        ]

        # center crop
        left = (img[0].width - self.size) // 2
        top = (img[0].height - self.size) // 2
        box = (left, top, left + self.size, top + self.size)
        return [frame.crop(box) for frame in img]
|
||||
|
||||
|
||||
class CenterCropWide(object):
    """Center-crop to a (width, height) `size`, rescaling to cover it
    first. Accepts a single PIL image or a list of frames."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        if isinstance(img, list):
            factor = min(img[0].size[0] / self.size[0],
                         img[0].size[1] / self.size[1])
            # NOTE(review): `//` followed by round() mirrors the original;
            # plain `/` may have been intended -- confirm before changing.
            img = [
                frame.resize((round(frame.width // factor),
                              round(frame.height // factor)),
                             resample=Image.BOX) for frame in img
            ]

            # center crop
            left = (img[0].width - self.size[0]) // 2
            top = (img[0].height - self.size[1]) // 2
            box = (left, top, left + self.size[0], top + self.size[1])
            return [frame.crop(box) for frame in img]
        else:
            factor = min(img.size[0] / self.size[0],
                         img.size[1] / self.size[1])
            img = img.resize(
                (round(img.width // factor), round(img.height // factor)),
                resample=Image.BOX)
            left = (img.width - self.size[0]) // 2
            top = (img.height - self.size[1]) // 2
            return img.crop((left, top, left + self.size[0],
                             top + self.size[1]))
|
||||
|
||||
|
||||
class RandomCrop(object):
    """Random area/aspect crop (>= `min_area` of the frame) resized to
    `size` x `size`; the same crop is applied to every frame."""

    def __init__(self, size=224, min_area=0.4):
        self.size = size
        self.min_area = min_area

    def __call__(self, rgb):
        w, h = rgb[0].size
        total = w * h
        # Re-sample crop geometry until it fits inside the frame.
        out_w = out_h = float('inf')
        while out_w > w or out_h > h:
            crop_area = random.uniform(self.min_area, 1.0) * total
            ratio = random.uniform(3. / 4., 4. / 3.)
            out_w = int(round(math.sqrt(crop_area * ratio)))
            out_h = int(round(math.sqrt(crop_area / ratio)))
        left = random.randint(0, w - out_w)
        top = random.randint(0, h - out_h)

        rgb = [
            frame.crop((left, top, left + out_w, top + out_h))
            for frame in rgb
        ]
        return [
            frame.resize((self.size, self.size), Image.BILINEAR)
            for frame in rgb
        ]
|
||||
|
||||
|
||||
class RandomCropV2(object):
    """torchvision-style random resized crop, applied consistently to a
    list of frames, with a central-crop fallback after 10 attempts."""

    def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
        self.size = size if isinstance(size, (tuple, list)) else (size, size)
        self.min_area = min_area
        self.ratio = ratio

    def _get_params(self, img):
        """Pick (top, left, height, width) for a valid random crop."""
        width, height = img.size
        area = height * width

        for _ in range(10):
            target_area = random.uniform(self.min_area, 1.0) * area
            log_lo, log_hi = (math.log(self.ratio[0]),
                              math.log(self.ratio[1]))
            aspect = math.exp(random.uniform(log_lo, log_hi))

            w = int(round(math.sqrt(target_area * aspect)))
            h = int(round(math.sqrt(target_area / aspect)))

            if 0 < w <= width and 0 < h <= height:
                top = random.randint(0, height - h)
                left = random.randint(0, width - w)
                return top, left, h, w

        # Fallback: center crop at the closest legal aspect ratio.
        in_ratio = float(width) / float(height)
        if in_ratio < min(self.ratio):
            w = width
            h = int(round(w / min(self.ratio)))
        elif in_ratio > max(self.ratio):
            h = height
            w = int(round(h * max(self.ratio)))
        else:
            w, h = width, height
        return (height - h) // 2, (width - w) // 2, h, w

    def __call__(self, rgb):
        # One set of parameters, shared across all frames.
        i, j, h, w = self._get_params(rgb[0])
        return [F.resized_crop(frame, i, j, h, w, self.size) for frame in rgb]
|
||||
|
||||
|
||||
class RandomHFlip(object):
    """Horizontally flip all frames with probability `p`."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        return [frame.transpose(Image.FLIP_LEFT_RIGHT) for frame in rgb]
|
||||
|
||||
|
||||
class GaussianBlur(object):
    """With probability `p`, blur all frames with one shared random sigma
    drawn uniformly from `sigmas`."""

    def __init__(self, sigmas=[0.1, 2.0], p=0.5):
        # NOTE(review): mutable default kept for interface compatibility;
        # it is only read, never mutated.
        self.sigmas = sigmas
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        radius = random.uniform(*self.sigmas)
        return [
            frame.filter(ImageFilter.GaussianBlur(radius=radius))
            for frame in rgb
        ]
|
||||
|
||||
|
||||
class ColorJitter(object):
    """Photometric jitter (brightness/contrast/saturation/hue) applied in
    a random order with parameters shared across all frames, triggered
    with probability `p`."""

    def __init__(self,
                 brightness=0.4,
                 contrast=0.4,
                 saturation=0.4,
                 hue=0.1,
                 p=0.5):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue
        self.p = p

    def __call__(self, rgb):
        if random.random() < self.p:
            brightness, contrast, saturation, hue = self._random_params()
            ops = [
                lambda f: F.adjust_brightness(f, brightness),
                lambda f: F.adjust_contrast(f, contrast),
                lambda f: F.adjust_saturation(f, saturation),
                lambda f: F.adjust_hue(f, hue),
            ]
            random.shuffle(ops)
            for op in ops:
                rgb = [op(frame) for frame in rgb]

        return rgb

    def _random_params(self):
        """Draw one jitter factor per attribute around the identity."""
        brightness = random.uniform(
            max(0, 1 - self.brightness), 1 + self.brightness)
        contrast = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
        saturation = random.uniform(
            max(0, 1 - self.saturation), 1 + self.saturation)
        hue = random.uniform(-self.hue, self.hue)
        return brightness, contrast, saturation, hue
|
||||
|
||||
|
||||
class RandomGray(object):
    """With probability `p`, convert frames to grayscale while keeping a
    3-channel RGB layout (L -> RGB round trip)."""

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, rgb):
        if random.random() < self.p:
            rgb = [frame.convert('L').convert('RGB') for frame in rgb]
        return rgb
|
||||
|
||||
|
||||
class ToTensor(object):
    """Convert a PIL image, or a list of frames, to float tensors in
    [0, 1]; a list is stacked along a new leading frame dimension."""

    def __call__(self, rgb):
        if isinstance(rgb, list):
            return torch.stack([F.to_tensor(frame) for frame in rgb], dim=0)
        return F.to_tensor(rgb)
|
||||
|
||||
|
||||
class Normalize(object):
    """Clamp to [0, 1] and channel-wise normalize with the given mean/std.

    Works on (C, H, W) tensors and (F, C, H, W) frame stacks; the input
    tensor is cloned, not modified.
    """

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, rgb):
        out = rgb.clone()
        out.clamp_(0, 1)
        # Lazily cache mean/std as tensors on first use.
        # NOTE(review): this mutates self, so sharing one instance across
        # threads is not safe -- confirm single-threaded use.
        if not isinstance(self.mean, torch.Tensor):
            self.mean = out.new_tensor(self.mean).view(-1)
        if not isinstance(self.std, torch.Tensor):
            self.std = out.new_tensor(self.std).view(-1)
        if out.dim() == 4:
            out.sub_(self.mean.view(1, -1, 1,
                                    1)).div_(self.std.view(1, -1, 1, 1))
        elif out.dim() == 3:
            out.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1))
        return out
|
||||
227
modelscope/models/multi_modal/video_to_video/video_to_video_model.py
Executable file
227
modelscope/models/multi_modal/video_to_video/video_to_video_model.py
Executable file
@@ -0,0 +1,227 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
from copy import copy
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn.functional as F
|
||||
|
||||
import modelscope.models.multi_modal.video_to_video.utils.transforms as data
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.multi_modal.video_to_video.modules import *
|
||||
from modelscope.models.multi_modal.video_to_video.modules import (
|
||||
AutoencoderKL, FrozenOpenCLIPEmbedder, Vid2VidSDUNet,
|
||||
get_first_stage_encoding)
|
||||
from modelscope.models.multi_modal.video_to_video.utils.config import cfg
|
||||
from modelscope.models.multi_modal.video_to_video.utils.diffusion_sdedit import \
|
||||
GaussianDiffusion_SDEdit
|
||||
from modelscope.models.multi_modal.video_to_video.utils.schedules_sdedit import \
|
||||
noise_schedule
|
||||
from modelscope.models.multi_modal.video_to_video.utils.seed import setup_seed
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
__all__ = ['VideoToVideo']
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.video_to_video, module_name=Models.video_to_video_model)
|
||||
class VideoToVideo(TorchModel):
|
||||
r"""
|
||||
Video2Video aims to solve the task of generating super-resolution videos based on input
|
||||
video and text, which is a video generation basic model developed by Alibaba Cloud.
|
||||
|
||||
Paper link: https://arxiv.org/abs/2306.02018
|
||||
|
||||
Attributes:
|
||||
diffusion: diffusion model for DDIM.
|
||||
autoencoder: decode the latent representation of input video into visual space.
|
||||
clip_encoder: encode the text into text embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
    """Load every sub-network of the video-to-video pipeline.

    Reads the configuration from `model_dir`, then builds and loads: the
    frozen OpenCLIP text encoder, the Vid2Vid UNet generator, the SDEdit
    Gaussian diffusion wrapper, and the KL autoencoder. Also precomputes
    the text embedding of the configured negative prompt.

    Args:
        model_dir (str or os.PathLike): local directory containing the
            configuration file and the checkpoints it references
            (ckpt_unet, ckpt_clip, ckpt_autoencoder).
    """
    super().__init__(model_dir=model_dir, *args, **kwargs)

    self.config = Config.from_file(
        osp.join(model_dir, ModelFile.CONFIGURATION))

    cfg.solver_mode = self.config.model.model_args.solver_mode

    # assign default values from the model configuration
    cfg.batch_size = self.config.model.model_cfg.batch_size
    cfg.target_fps = self.config.model.model_cfg.target_fps
    cfg.max_frames = self.config.model.model_cfg.max_frames
    cfg.latent_hei = self.config.model.model_cfg.latent_hei
    cfg.latent_wid = self.config.model.model_cfg.latent_wid
    cfg.model_path = osp.join(model_dir,
                              self.config.model.model_args.ckpt_unet)

    self.device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Use the configured seed when present, otherwise a random one.
    if 'seed' in self.config.model.model_args.keys():
        cfg.seed = self.config.model.model_args.seed
    else:
        cfg.seed = random.randint(0, 99999)
    setup_seed(cfg.seed)

    # transform: to tensor + channel-wise normalization
    vid_trans = data.Compose(
        [data.ToTensor(),
         data.Normalize(mean=cfg.mean, std=cfg.std)])
    self.vid_trans = vid_trans

    # [clip text encoder]
    cfg.embedder.pretrained = osp.join(
        model_dir, self.config.model.model_args.ckpt_clip)
    clip_encoder = FrozenOpenCLIPEmbedder(
        pretrained=cfg.embedder.pretrained)
    clip_encoder.model.to(self.device)
    self.clip_encoder = clip_encoder
    logger.info(f'Build encoder with {cfg.embedder.type}')

    # [unet]
    generator = Vid2VidSDUNet()
    generator = generator.to(self.device)
    generator.eval()
    # Load on CPU first to avoid a GPU memory spike during deserialization.
    load_dict = torch.load(cfg.model_path, map_location='cpu')
    ret = generator.load_state_dict(load_dict['state_dict'], strict=True)
    self.generator = generator
    logger.info('Load model {} path {}, with local status {}'.format(
        cfg.UNet.type, cfg.model_path, ret))

    # [diffusion]
    sigmas = noise_schedule(
        schedule='logsnr_cosine_interp',
        n=1000,
        zero_terminal_snr=True,
        scale_min=2.0,
        scale_max=4.0)
    diffusion = GaussianDiffusion_SDEdit(
        sigmas=sigmas, prediction_type='v')
    self.diffusion = diffusion
    logger.info('Build diffusion with type of GaussianDiffusion_SDEdit')

    # [autoencoder] -- frozen: used only for encode/decode
    cfg.auto_encoder.pretrained = osp.join(
        model_dir, self.config.model.model_args.ckpt_autoencoder)
    autoencoder = AutoencoderKL(**cfg.auto_encoder)
    autoencoder.eval()
    for param in autoencoder.parameters():
        param.requires_grad = False
    autoencoder.to(self.device)
    self.autoencoder = autoencoder
    torch.cuda.empty_cache()

    # Precompute the negative-prompt embedding once; it is reused on
    # every forward pass for classifier-free guidance.
    negative_prompt = cfg.negative_prompt
    negative_y = clip_encoder(negative_prompt).detach()
    self.negative_y = negative_y

    positive_prompt = cfg.positive_prompt
    self.positive_prompt = positive_prompt

    self.cfg = cfg
|
||||
|
||||
def forward(self, input: Dict[str, Any]):
    r"""The entry function of the video-to-video task.

    1. Encode the input frames into latent space with the autoencoder.
    2. Partially noise the latents and denoise them with the diffusion
       model (SDEdit-style), conditioned on the text embedding ``y`` with
       classifier-free guidance against ``self.negative_y``.
    3. Decode the generated latents back to visual space.

    Args:
        input (`Dict[str, Any]`): must contain
            - 'video_data': frame tensor of shape (f, c, h, w)
            - 'y': text embedding produced by the CLIP encoder
    Returns:
        A generated video as a float32 pytorch tensor on CPU, shaped
        (b, c, f, h, w).
    """

    video_data = input['video_data']
    y = input['y']
    cfg = self.cfg

    # The generator operates at a fixed 720x1280 resolution.
    video_data = F.interpolate(
        video_data, size=(720, 1280), mode='bilinear')
    video_data = video_data.unsqueeze(0)  # (f, c, h, w) -> (1, f, c, h, w)
    video_data = video_data.to(self.device)

    batch_size, frames_num, _, _, _ = video_data.shape
    video_data = rearrange(video_data, 'b f c h w -> (b f) c h w')

    # Encode frames two at a time to bound peak GPU memory. Guard the
    # chunk count so a single-frame input does not request 0 chunks.
    video_data_list = torch.chunk(
        video_data, max(video_data.shape[0] // 2, 1), dim=0)
    with torch.no_grad():
        decode_data = []
        for vd_data in video_data_list:
            encoder_posterior = self.autoencoder.encode(vd_data)
            tmp = get_first_stage_encoding(encoder_posterior).detach()
            decode_data.append(tmp)
        video_data_feature = torch.cat(decode_data, dim=0)
        video_data_feature = rearrange(
            video_data_feature, '(b f) c h w -> b c f h w', b=batch_size)

    with amp.autocast(enabled=True):
        # SDEdit: noise the latents up to an intermediate level (600 of
        # 1000) so the sampler keeps the structure of the input video.
        total_noise_levels = 600
        t = torch.randint(
            total_noise_levels - 1,
            total_noise_levels, (1, ),
            dtype=torch.long).to(self.device)

        noise = torch.randn_like(video_data_feature)
        noised_lr = self.diffusion.diffuse(video_data_feature, t, noise)
        # Positive / negative conditioning for classifier-free guidance.
        model_kwargs = [{'y': y}, {'y': self.negative_y}]

        gen_vid = self.diffusion.sample(
            noise=noised_lr,
            model=self.generator,
            model_kwargs=model_kwargs,
            guide_scale=7.5,
            guide_rescale=0.2,
            solver='dpmpp_2m_sde' if cfg.solver_mode == 'fast' else 'heun',
            steps=30 if cfg.solver_mode == 'fast' else 50,
            t_max=total_noise_levels - 1,
            t_min=0,
            discretization='trailing')

        # Undo the latent scaling used by SD-style VAEs before decoding.
        scale_factor = 0.18215
        vid_tensor_feature = 1. / scale_factor * gen_vid

        vid_tensor_feature = rearrange(vid_tensor_feature,
                                       'b c f h w -> (b f) c h w')
        vid_tensor_feature_list = torch.chunk(
            vid_tensor_feature,
            max(vid_tensor_feature.shape[0] // 2, 1),
            dim=0)
        decode_data = []
        for vd_data in vid_tensor_feature_list:
            tmp = self.autoencoder.decode(vd_data)
            decode_data.append(tmp)
        vid_tensor_gen = torch.cat(decode_data, dim=0)

    # Bug fix: use the locally computed batch_size (as the encode path
    # does) instead of cfg.batch_size, so the reshape always matches the
    # tensor actually produced above.
    gen_video = rearrange(
        vid_tensor_gen, '(b f) c h w -> b c f h w', b=batch_size)

    return gen_video.type(torch.float32).cpu()
|
||||
140
modelscope/pipelines/multi_modal/video_to_video_pipeline.py
Normal file
140
modelscope/pipelines/multi_modal/video_to_video_pipeline.py
Normal file
@@ -0,0 +1,140 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import shutil
import tempfile
from typing import Any, Dict, Optional

import cv2
import torch
from einops import rearrange

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.video_to_video, module_name=Pipelines.video_to_video_pipeline)
class VideoToVideoPipeline(Pipeline):
    r""" Video To Video Pipeline, generating super-resolution videos based on input
    video and text

    Examples:
    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.outputs import OutputKeys

    >>> # YOUR_VIDEO_PATH: your video url or local position in low resolution
    >>> # INPUT_TEXT: when we do video super-resolution, we will add the text content
    >>> # into results
    >>> # output_video_path: path-to-the-generated-video

    >>> p = pipeline('video-to-video', 'damo/Video-to-Video')
    >>> input = {"video_path":YOUR_VIDEO_PATH, "text": INPUT_TEXT}
    >>> output_video_path = p(input,output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO]

    """

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a kws pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)

    def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
        """Decode up to ``cfg.max_frames`` frames from the input video and
        encode the (text + positive prompt) caption with CLIP.

        Returns:
            dict with 'video_data' (transformed frame tensor) and 'y'
            (text embedding).
        """
        vid_path = input['video_path']
        if 'text' in input.keys():
            text = input['text']
        else:
            text = ''

        # The positive prompt is always appended to the user caption.
        caption = text + self.model.positive_prompt
        y = self.model.clip_encoder(caption).detach()

        max_frames = self.model.cfg.max_frames

        capture = cv2.VideoCapture(vid_path)
        _fps = capture.get(cv2.CAP_PROP_FPS)
        # sample_fps equals the source fps, so stride is currently always 1;
        # kept as a hook for future frame subsampling.
        sample_fps = _fps
        _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)
        stride = round(_fps / sample_fps)
        start_frame = 0

        pointer = 0
        frame_list = []
        while len(frame_list) < max_frames:
            ret, frame = capture.read()
            pointer += 1
            if (not ret) or (frame is None):
                break
            if pointer < start_frame:
                continue
            if pointer >= _total_frame_num + 1:
                break
            if (pointer - start_frame) % stride == 0:
                frame = LoadImage.convert_to_img(frame)
                frame_list.append(frame)
        capture.release()

        video_data = self.model.vid_trans(frame_list)

        return {'video_data': video_data, 'y': y}

    def forward(self, input: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        """Run the underlying video-to-video model."""
        video = self.model(input)
        return {'video': video}

    def postprocess(self, inputs: Dict[str, Any],
                    **post_params) -> Dict[str, Any]:
        """Render the generated tensor to an mp4 file via ffmpeg.

        Returns:
            If `output_video` was given in post_params, its path; otherwise
            the raw bytes of a temporary mp4 file.
        """
        video = tensor2vid(inputs['video'], self.model.cfg.mean,
                           self.model.cfg.std)
        output_video_path = post_params.get('output_video', None)
        temp_video_file = False
        if output_video_path is None:
            # mkstemp (instead of NamedTemporaryFile(...).name) keeps the
            # path reserved until ffmpeg overwrites it — no GC race.
            fd, output_video_path = tempfile.mkstemp(suffix='.mp4')
            os.close(fd)
            temp_video_file = True

        temp_dir = tempfile.mkdtemp()
        for fid, frame in enumerate(video):
            tpth = os.path.join(temp_dir, '%06d.png' % (fid + 1))
            # Frames are RGB; cv2 expects BGR, hence the channel flip.
            # (Bug fix: the JPEG-quality flag is meaningless for PNG output.)
            cv2.imwrite(tpth, frame[:, :, ::-1])

        cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8.0 -i {temp_dir}/%06d.png \
            -vcodec libx264 -crf 17 -pix_fmt yuv420p {output_video_path}'

        status = os.system(cmd)
        if status != 0:
            logger.error('Save Video Error with {}'.format(status))
        shutil.rmtree(temp_dir, ignore_errors=True)

        if temp_video_file:
            video_file_content = b''
            with open(output_video_path, 'rb') as f:
                video_file_content = f.read()
            os.remove(output_video_path)
            return {OutputKeys.OUTPUT_VIDEO: video_file_content}
        else:
            return {OutputKeys.OUTPUT_VIDEO: output_video_path}
|
||||
|
||||
|
||||
def tensor2vid(video, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Convert a normalized video tensor into a list of uint8 RGB frames.

    Args:
        video: tensor of shape (b, c, f, h, w), normalized with
            ``mean`` / ``std``; must reside on CPU for the numpy conversion.
        mean: per-channel mean used during normalization.
        std: per-channel std used during normalization.

    Returns:
        A list of f numpy arrays of shape (h, w, c), dtype uint8, taken
        from the first batch element.
    """
    # Tuple defaults replace the previous mutable-list defaults.
    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)

    # De-normalize out-of-place so the caller's tensor is not mutated
    # (the previous in-place mul_/add_/clamp_ clobbered the input).
    video = video * std + mean
    video = video.clamp(0, 1) * 255.0

    # (b, c, f, h, w) -> (b, f, h, w, c); pure-torch equivalent of the
    # einops rearrange previously used.
    images = video.permute(0, 2, 3, 4, 1)[0]
    images = [(img.numpy()).astype('uint8') for img in images]

    return images
|
||||
@@ -257,6 +257,7 @@ class MultiModalTasks(object):
|
||||
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
|
||||
multimodal_dialogue = 'multimodal-dialogue'
|
||||
image_to_video = 'image-to-video'
|
||||
video_to_video = 'video-to-video'
|
||||
|
||||
|
||||
class ScienceTasks(object):
|
||||
|
||||
32
tests/pipelines/test_video2video.py
Normal file
32
tests/pipelines/test_video2video.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class Video2VideoTest(unittest.TestCase):
    """End-to-end test for the video-to-video super-resolution pipeline."""

    def setUp(self) -> None:
        # Task, model id and sample input shared by all test cases.
        self.task = Tasks.video_to_video
        self.model_id = 'damo/Video-to-Video'
        self.path = 'https://video-generation-wulanchabu.oss-cn-wulanchabu.aliyuncs.com/baishao/test.mp4'

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        video_pipeline = pipeline(task=self.task, model=self.model_id)
        pipeline_input = {
            'video_path': self.path,
            'text': 'A panda is surfing on the sea'
        }

        result = video_pipeline(pipeline_input, output_video='./output.mp4')
        output_video_path = result[OutputKeys.OUTPUT_VIDEO]
        print(output_video_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user