VideoComposer: Compositional Video Synthesis with Motion Controllability (#431)

* VideoComposer: Compositional Video Synthesis with Motion Controllability

* videocomposer pipeline

* pre commit

* delete xformers
This commit is contained in:
Wang Qiang
2023-08-15 12:01:03 +08:00
committed by GitHub
parent 18d33a4825
commit ee8afd2d62
50 changed files with 14483 additions and 1 deletion

View File

@@ -218,6 +218,7 @@ class Models(object):
mplug_owl = 'mplug-owl'
clip_interrogator = 'clip-interrogator'
stable_diffusion = 'stable-diffusion'
videocomposer = 'videocomposer'
text_to_360panorama_image = 'text-to-360panorama-image'
# science models
@@ -525,6 +526,7 @@ class Pipelines(object):
multi_modal_similarity = 'multi-modal-similarity'
text_to_image_synthesis = 'text-to-image-synthesis'
video_multi_modal_embedding = 'video-multi-modal-embedding'
videocomposer = 'videocomposer'
image_text_retrieval = 'image-text-retrieval'
ofa_ocr_recognition = 'ofa-ocr-recognition'
ofa_asr = 'ofa-asr'

View File

@@ -22,6 +22,7 @@ if TYPE_CHECKING:
from .efficient_diffusion_tuning import EfficientStableDiffusion
from .mplug_owl import MplugOwlForConditionalGeneration
from .clip_interrogator import CLIP_Interrogator
from .videocomposer import VideoComposer
else:
_import_structure = {
@@ -42,6 +43,7 @@ else:
'efficient_diffusion_tuning': ['EfficientStableDiffusion'],
'mplug_owl': ['MplugOwlForConditionalGeneration'],
'clip_interrogator': ['CLIP_Interrogator'],
'videocomposer': ['VideoComposer'],
}
import sys

View File

@@ -19,6 +19,9 @@ from modelscope.models.multi_modal.video_synthesis.diffusion import (
from modelscope.models.multi_modal.video_synthesis.unet_sd import UNetSD
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
__all__ = ['TextToVideoSynthesis']

View File

@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    # Static type checkers see the real import.
    from .videocomposer_model import VideoComposer
else:
    # Maps submodule name -> public names it provides; the submodule is
    # imported lazily on first attribute access.
    _import_structure = {
        'videocomposer_model': ['VideoComposer'],
    }
    import sys

    # Replace this package module with a lazy proxy so importing the
    # package does not pull in heavy model dependencies eagerly.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

View File

@@ -0,0 +1,44 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np
import torch
from tools.annotator.util import HWC3
class CannyDetector:
    """Canny edge detector returning a (h, w, 1) float CUDA tensor in [0, 1].

    Missing hysteresis thresholds are derived from the median image
    intensity; with ``random_threshold`` the lower one is randomized
    (useful as data augmentation).
    """

    def __call__(self,
                 img,
                 low_threshold=None,
                 high_threshold=None,
                 random_threshold=True):
        """Detect edges.

        Args:
            img: (h, w, c) image - a torch tensor with values in [0, 1] or a
                numpy array assumed to already be in the 0..255 range.
            low_threshold / high_threshold: optional Canny hysteresis
                thresholds; any missing one is derived from the median
                intensity (previously a single provided threshold left the
                other as ``None``, which crashes ``cv2.Canny``).
            random_threshold: when deriving thresholds, draw the lower one
                from [0.6, 0.9] x median instead of the fixed 0.67/1.33 rule.

        Returns:
            (h, w, 1) float tensor on CUDA with values in {0, 1}.
        """
        # Convert to a numpy array in the 0..255 range
        if isinstance(img, torch.Tensor):  # (h, w, c)
            img = img.cpu().numpy()
            img_np = cv2.convertScaleAbs((img * 255.))
        elif isinstance(img, np.ndarray):  # (h, w, c)
            img_np = img  # we assume values are in the range from 0 to 255.
        else:
            raise TypeError(
                'img must be a torch.Tensor or np.ndarray, got %s'
                % type(img))

        # Derive any missing threshold from the median intensity.
        if (low_threshold is None) or (high_threshold is None):
            median_intensity = np.median(img_np)
            if random_threshold is False:
                # Classic "auto Canny": 0.67x / 1.33x the median.
                if low_threshold is None:
                    low_threshold = int(max(0, (1 - 0.33) * median_intensity))
                if high_threshold is None:
                    high_threshold = int(
                        min(255, (1 + 0.33) * median_intensity))
            else:
                random_canny = np.random.uniform(0.1, 0.4)
                # Might try other values
                if low_threshold is None:
                    low_threshold = int(
                        max(0, (1 - random_canny) * median_intensity))
                if high_threshold is None:
                    high_threshold = 2 * low_threshold

        # Detect canny edge
        canny_edge = cv2.Canny(img_np, low_threshold, high_threshold)

        # Binary edge map as a float tensor.
        # NOTE(review): hard-codes .cuda(); requires a CUDA device.
        canny_condition = torch.from_numpy(
            canny_edge.copy()).unsqueeze(dim=-1).float().cuda() / 255.0

        return canny_condition

View File

@@ -0,0 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .palette import *

View File

@@ -0,0 +1,155 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
r"""Modified from ``https://github.com/sergeyk/rayleigh''.
"""
import os
import os.path as osp
import numpy as np
from skimage.color import hsv2rgb, lab2rgb, rgb2lab
from skimage.io import imsave
from sklearn.metrics import euclidean_distances
__all__ = ['Palette']
def rgb2hex(rgb):
    """Convert an RGB triple with components in [0, 1] to '#rrggbb'."""
    r, g, b = (int(round(255.0 * channel)) for channel in rgb)
    return '#%02x%02x%02x' % (r, g, b)
def hex2rgb(hex):
    """Convert a '#rrggbb' hex string to an RGB triple in [0, 1]."""
    digits = hex.strip('#')

    def channel(pair):
        # Two hex digits -> fraction of 255, rounded to 5 decimals.
        return round(int(pair, 16) / 255.0, 5)

    return channel(digits[:2]), channel(digits[2:4]), channel(digits[4:6])
class Palette(object):
    r"""Create a color palette (codebook) in the form of a 2D grid of colors.
    Further, the rightmost column has num_hues gradations from black to white.

    Parameters:
        num_hues: number of colors with full lightness and saturation, in the middle.
        num_sat: number of rows above middle row that show the same hues with decreasing saturation.
        num_light: rows above/below showing the hues with varying lightness.
    """

    def __init__(self, num_hues=11, num_sat=5, num_light=4):
        # Total number of grid rows.
        n = num_sat + 2 * num_light

        # hues
        # Hand-tuned hue sets for common palette sizes; otherwise sample the
        # hue circle uniformly. Result shape (n, num_hues): one hue/column.
        if num_hues == 8:
            hues = np.tile(
                np.array([0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]),
                (n, 1))
        elif num_hues == 9:
            hues = np.tile(
                np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]),
                (n, 1))
        elif num_hues == 10:
            hues = np.tile(
                np.array(
                    [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76,
                     0.87]), (n, 1))
        elif num_hues == 11:
            hues = np.tile(
                np.array([
                    0.0, 0.0833, 0.166, 0.25, 0.333, 0.5, 0.56333, 0.666, 0.73,
                    0.803, 0.916
                ]), (n, 1))
        else:
            hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (n, 1))

        # saturations (one value per row, tiled across hue columns)
        sats = np.hstack((
            np.linspace(0, 1, num_sat + 2)[1:-1],
            1,
            [1] * num_light,
            [0.4] *  # noqa
            (num_light - 1)))
        sats = np.tile(np.atleast_2d(sats).T, (1, num_hues))

        # lights (one value per row, tiled across hue columns)
        lights = np.hstack(
            ([1] * num_sat, 1, np.linspace(1, 0.2, num_light + 2)[1:-1],
             np.linspace(1, 0.2, num_light + 2)[1:-2]))
        lights = np.tile(np.atleast_2d(lights).T, (1, num_hues))

        # colors: HSV grid -> RGB, plus an extra grayscale column.
        rgb = hsv2rgb(np.dstack([hues, sats, lights]))
        gray = np.tile(
            np.linspace(1, 0, n)[:, np.newaxis, np.newaxis], (1, 1, 3))
        self.thumbnail = np.hstack([rgb, gray])

        # flatten to a (num_colors, 3) codebook; precompute the CIELAB
        # representation and pairwise squared distances (used to smooth
        # histograms in Lab space).
        rgb = rgb.T.reshape(3, -1).T
        gray = gray.T.reshape(3, -1).T
        self.rgb = np.vstack((rgb, gray))
        self.lab = rgb2lab(self.rgb[np.newaxis, :, :]).squeeze()
        self.hex = [rgb2hex(u) for u in self.rgb]
        self.lab_dists = euclidean_distances(self.lab, squared=True)

    def histogram(self, rgb_img, sigma=20):
        """Palette histogram of an RGB image; Gaussian-smoothed in Lab
        space with bandwidth ``sigma`` (tiny bins are zeroed)."""
        # compute histogram: nearest palette entry per pixel, in Lab space
        lab = rgb2lab(rgb_img).reshape((-1, 3))
        min_ind = np.argmin(
            euclidean_distances(lab, self.lab, squared=True), axis=1)
        hist = 1.0 * np.bincount(
            min_ind, minlength=self.lab.shape[0]) / lab.shape[0]

        # smooth histogram
        if sigma > 0:
            weight = np.exp(-self.lab_dists / (2.0 * sigma**2))
            weight = weight / weight.sum(1)[:, np.newaxis]
            hist = (weight * hist).sum(1)
            hist[hist < 1e-5] = 0
        return hist

    def get_palette_image(self, hist, percentile=90, width=200, height=50):
        """Render the dominant palette colors (above ``percentile`` of the
        histogram) as a (height, width, 3) strip, widths ~ frequency."""
        # curate histogram: keep top bins and renormalize
        ind = np.argsort(-hist)
        ind = ind[hist[ind] > np.percentile(hist, percentile)]
        hist = hist[ind] / hist[ind].sum()

        # draw palette
        nums = np.array(hist * width, dtype=int)
        array = np.vstack([
            np.tile(np.array(u), (v, 1)) for u, v in zip(self.rgb[ind], nums)
        ])
        array = np.tile(array[np.newaxis, :, :], (height, 1, 1))
        # pad with black if integer rounding left the strip short of `width`
        if array.shape[1] < width:
            array = np.concatenate(
                [array, np.zeros((height, width - array.shape[1], 3))], axis=1)
        return array

    def quantize_image(self, rgb_img):
        """Replace every pixel by its nearest palette color (Lab metric)."""
        lab = rgb2lab(rgb_img).reshape((-1, 3))
        min_ind = np.argmin(
            euclidean_distances(lab, self.lab, squared=True), axis=1)
        quantized_lab = self.lab[min_ind]
        img = lab2rgb(quantized_lab.reshape(rgb_img.shape))
        return img

    def export(self, dirname):
        """Write ``palette.png`` (thumbnail) and ``palette.html`` (swatch
        page) into ``dirname``, creating the directory if needed."""
        if not osp.exists(dirname):
            os.makedirs(dirname)

        # save thumbnail
        imsave(osp.join(dirname, 'palette.png'), self.thumbnail)

        # save html
        with open(osp.join(dirname, 'palette.html'), 'w') as f:
            html = '''
            <style>
                span {
                    width: 20px;
                    height: 20px;
                    margin: 2px;
                    padding: 0px;
                    display: inline-block;
                }
            </style>
            '''
            for row in self.thumbnail:
                for col in row:
                    html += '<a id="{0}"><span style="background-color: {0}" /></a>\n'.format(
                        rgb2hex(col))
                html += '<br />\n'
            f.write(html)

View File

@@ -0,0 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .pidinet import *
from .sketch_simplification import *

View File

@@ -0,0 +1,940 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
r"""Modified from ``https://github.com/zhuoinoulu/pidinet''.
Image augmentation: T.Compose([
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]).
"""
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.multi_modal.videocomposer.utils.utils import \
DOWNLOAD_TO_CACHE
__all__ = [
'PiDiNet', 'pidinet_bsd_tiny', 'pidinet_bsd_small', 'pidinet_bsd',
'pidinet_nyud', 'pidinet_multicue'
]
def _make_config(pattern):
    """Expand a per-layer op pattern (cycled) into the 16-layer config dict.

    Args:
        pattern: sequence of op-type strings ('cv', 'cd', 'ad', 'rd') that
            is repeated cyclically over the 16 layers.

    Returns:
        dict mapping 'layer0'..'layer15' to an op-type string.
    """
    return {'layer%d' % i: pattern[i % len(pattern)] for i in range(16)}


# Named layer configurations for PiDiNet. Each maps 'layer0'..'layer15' to a
# conv op type: 'cv' vanilla conv, 'cd' central difference, 'ad' angular
# difference, 'rd' radial difference (see create_conv_func).
CONFIGS = {
    'baseline': _make_config(['cv']),
    # one difference conv at layer0, vanilla elsewhere
    'c-v15': _make_config(['cd'] + ['cv'] * 15),
    'a-v15': _make_config(['ad'] + ['cv'] * 15),
    'r-v15': _make_config(['rd'] + ['cv'] * 15),
    # one difference conv at the start of each 4-layer stage
    'cvvv4': _make_config(['cd', 'cv', 'cv', 'cv']),
    'avvv4': _make_config(['ad', 'cv', 'cv', 'cv']),
    'rvvv4': _make_config(['rd', 'cv', 'cv', 'cv']),
    # three difference convs per stage, vanilla at the stage end
    'cccv4': _make_config(['cd', 'cd', 'cd', 'cv']),
    'aaav4': _make_config(['ad', 'ad', 'ad', 'cv']),
    'rrrv4': _make_config(['rd', 'rd', 'rd', 'cv']),
    # all layers use the same difference conv
    'c16': _make_config(['cd']),
    'a16': _make_config(['ad']),
    'r16': _make_config(['rd']),
    # the paper's default: cd/ad/rd/cv repeated per stage
    'carv4': _make_config(['cd', 'ad', 'rd', 'cv']),
}
def create_conv_func(op_type):
    """Return a conv2d-like functional implementing a pixel-difference conv.

    Args:
        op_type: 'cv' (vanilla conv), 'cd' (central difference),
            'ad' (angular difference) or 'rd' (radial difference).

    Returns:
        A callable with the signature of ``F.conv2d``.

    Raises:
        AssertionError: if ``op_type`` is unknown.
    """
    assert op_type in ['cv', 'cd', 'ad',
                       'rd'], 'unknown op type: %s' % str(op_type)
    if op_type == 'cv':
        return F.conv2d

    if op_type == 'cd':

        def func(x,
                 weights,
                 bias=None,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1):
            assert dilation in [1,
                                2], 'dilation for cd_conv should be in 1 or 2'
            assert weights.size(2) == 3 and weights.size(
                3) == 3, 'kernel size for cd_conv should be 3x3'
            assert padding == dilation, 'padding for cd_conv set wrong'
            # Central difference: subtract a 1x1 conv carrying the spatially
            # summed kernel mass from the plain 3x3 response.
            weights_c = weights.sum(dim=[2, 3], keepdim=True)
            yc = F.conv2d(
                x, weights_c, stride=stride, padding=0, groups=groups)
            y = F.conv2d(
                x,
                weights,
                bias,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups)
            return y - yc

        return func

    if op_type == 'ad':

        def func(x,
                 weights,
                 bias=None,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1):
            assert dilation in [1,
                                2], 'dilation for ad_conv should be in 1 or 2'
            assert weights.size(2) == 3 and weights.size(
                3) == 3, 'kernel size for ad_conv should be 3x3'
            assert padding == dilation, 'padding for ad_conv set wrong'
            shape = weights.shape
            weights = weights.view(shape[0], shape[1], -1)
            # Angular difference: subtract the clock-wise rotated kernel.
            weights_conv = (
                weights
                - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape)
            y = F.conv2d(
                x,
                weights_conv,
                bias,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups)
            return y

        return func

    # op_type == 'rd' (guaranteed by the assert above; the original code
    # carried an unreachable print/return-None branch here).

    def func(x,
             weights,
             bias=None,
             stride=1,
             padding=0,
             dilation=1,
             groups=1):
        assert dilation in [1,
                            2], 'dilation for rd_conv should be in 1 or 2'
        assert weights.size(2) == 3 and weights.size(
            3) == 3, 'kernel size for rd_conv should be 3x3'
        padding = 2 * dilation  # effective kernel becomes 5x5
        shape = weights.shape
        weights = weights.view(shape[0], shape[1], -1)
        # Scatter the 8 non-anchor taps of the 3x3 kernel onto the outer
        # ring (+w) and the inner ring (-w) of a 5x5 kernel; center stays 0.
        # new_zeros keeps device AND dtype consistent with `weights`
        # (the original forced a float32 torch.cuda.FloatTensor on GPU).
        buffer = weights.new_zeros(shape[0], shape[1], 5 * 5)
        buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
        buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
        buffer = buffer.view(shape[0], shape[1], 5, 5)
        y = F.conv2d(
            x,
            buffer,
            bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups)
        return y

    return func
def config_model(model):
    """Return the 16 per-layer conv functionals for a named config."""
    options = list(CONFIGS.keys())
    assert model in options, \
        'unrecognized model, please choose from %s' % str(options)
    cfg = CONFIGS[model]
    return [create_conv_func(cfg['layer%d' % idx]) for idx in range(16)]
def config_model_converted(model):
    """Return the 16 per-layer op-type strings for a named config."""
    options = list(CONFIGS.keys())
    assert model in options, \
        'unrecognized model, please choose from %s' % str(options)
    cfg = CONFIGS[model]
    return [cfg['layer%d' % idx] for idx in range(16)]
def convert_pdc(op, weight):
    """Convert a pixel-difference conv kernel into an equivalent vanilla
    conv kernel ('rd' expands 3x3 -> 5x5).

    Note: for 'cd' the input tensor's storage is updated in place, matching
    the original implementation.
    """
    if op == 'cv':
        return weight
    if op == 'cd':
        shape = weight.shape
        total = weight.sum(dim=[2, 3])
        flat = weight.view(shape[0], shape[1], -1)
        # fold the subtracted kernel mass into the center tap
        flat[:, :, 4] = flat[:, :, 4] - total
        return flat.view(shape)
    if op == 'ad':
        shape = weight.shape
        flat = weight.view(shape[0], shape[1], -1)
        # subtract the clock-wise rotated kernel
        return (flat - flat[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape)
    if op == 'rd':
        shape = weight.shape
        flat = weight.view(shape[0], shape[1], -1)
        buf = torch.zeros(shape[0], shape[1], 5 * 5, device=weight.device)
        # outer ring gets +w, inner ring gets -w; center stays zero
        buf[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = flat[:, :, 1:]
        buf[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -flat[:, :, 1:]
        return buf.view(shape[0], shape[1], 5, 5)
    raise ValueError('wrong op {}'.format(str(op)))
def convert_pidinet(state_dict, config):
    """Convert a PiDiNet checkpoint trained with pixel-difference convs
    into weights for the vanilla-conv ('converted') network.

    Only the PDC kernels (the init block plus each residual block's conv1)
    are transformed via ``convert_pdc``; every other tensor passes through
    unchanged. Replaces the original 16-branch elif chain with a generated
    weight-name -> PDC-index map.

    Args:
        state_dict: checkpoint mapping of parameter names to tensors.
        config: config name understood by ``config_model_converted``.

    Returns:
        A new dict with converted kernels.
    """
    pdcs = config_model_converted(config)
    # weight-name fragment -> index into `pdcs`
    # (index 0 = init block; then stage 1 has 3 blocks, stages 2-4 have 4)
    targets = {'init_block.weight': 0}
    index = 1
    for stage, num_blocks in ((1, 3), (2, 4), (3, 4), (4, 4)):
        for sub in range(1, num_blocks + 1):
            targets['block%d_%d.conv1.weight' % (stage, sub)] = index
            index += 1
    new_dict = {}
    for pname, p in state_dict.items():
        for fragment, idx in targets.items():
            if fragment in pname:
                new_dict[pname] = convert_pdc(pdcs[idx], p)
                break
        else:
            new_dict[pname] = p
    return new_dict
class Conv2d(nn.Module):
    """2D convolution with a pluggable convolution functional.

    Mirrors ``nn.Conv2d``'s parameterization (square kernel only), but the
    actual convolution is performed by ``pdc`` - a callable with the
    signature of ``F.conv2d``, e.g. produced by ``create_conv_func``.
    """

    def __init__(self,
                 pdc,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=False):
        super(Conv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        # Weight tensor shaped exactly like nn.Conv2d's.
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, kernel_size,
                         kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            # keep 'bias' registered (as None) for state_dict consistency
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.pdc = pdc

    def reset_parameters(self):
        """Kaiming-uniform weight init; uniform fan-in bound for bias."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the pluggable conv functional with this module's params."""
        return self.pdc(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
class CSAM(nn.Module):
    r"""Compact Spatial Attention Module.

    Rescales the input by a learned single-channel spatial gate in (0, 1).
    """

    def __init__(self, channels):
        super(CSAM, self).__init__()
        hidden = 4  # small bottleneck for the attention branch
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(channels, hidden, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            hidden, 1, kernel_size=3, padding=1, bias=False)
        self.sigmoid = nn.Sigmoid()
        nn.init.constant_(self.conv1.bias, 0)

    def forward(self, x):
        gate = self.sigmoid(self.conv2(self.conv1(self.relu1(x))))
        return x * gate
class CDCM(nn.Module):
    r"""Compact Dilation Convolution based Module.

    Projects the input with a 1x1 conv and sums four parallel 3x3 convs
    with dilations 5/7/9/11 (padding == dilation keeps the spatial size).
    """

    def __init__(self, in_channels, out_channels):
        super(CDCM, self).__init__()
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, padding=0)
        for branch, d in enumerate((5, 7, 9, 11), start=1):
            setattr(
                self, 'conv2_%d' % branch,
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    dilation=d,
                    padding=d,
                    bias=False))
        nn.init.constant_(self.conv1.bias, 0)

    def forward(self, x):
        x = self.conv1(self.relu1(x))
        branches = [
            self.conv2_1(x), self.conv2_2(x), self.conv2_3(x), self.conv2_4(x)
        ]
        return branches[0] + branches[1] + branches[2] + branches[3]
class MapReduce(nn.Module):
    r"""Reduce a feature map to a single-channel edge map via 1x1 conv."""

    def __init__(self, channels):
        super(MapReduce, self).__init__()
        self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
        nn.init.constant_(self.conv.bias, 0)

    def forward(self, x):
        edge_map = self.conv(x)
        return edge_map
class PDCBlock(nn.Module):
    """Residual block built around a depthwise pixel-difference conv.

    Structure: optional 2x2 max-pool (when ``stride > 1``) -> depthwise PDC
    3x3 conv -> ReLU -> pointwise 1x1 conv, added to the input (projected by
    a 1x1 conv when downsampling changes the channel count).
    """

    def __init__(self, pdc, inplane, ouplane, stride=1):
        super(PDCBlock, self).__init__()
        # fix: `self.stride` was assigned twice in the original
        self.stride = stride
        if self.stride > 1:
            # downsample and match channels for the residual connection
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.shortcut = nn.Conv2d(
                inplane, ouplane, kernel_size=1, padding=0)
        self.conv1 = Conv2d(
            pdc,
            inplane,
            inplane,
            kernel_size=3,
            padding=1,
            groups=inplane,
            bias=False)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            inplane, ouplane, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        if self.stride > 1:
            x = self.pool(x)
        y = self.conv2(self.relu2(self.conv1(x)))
        if self.stride > 1:
            x = self.shortcut(x)
        return y + x
class PDCBlock_converted(nn.Module):
    r"""PDCBlock with the difference conv replaced by a vanilla conv.

    CPDC and APDC convert to a vanilla 3x3 convolution; RPDC converts to a
    vanilla 5x5 convolution.
    """

    def __init__(self, pdc, inplane, ouplane, stride=1):
        super(PDCBlock_converted, self).__init__()
        self.stride = stride
        if self.stride > 1:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.shortcut = nn.Conv2d(
                inplane, ouplane, kernel_size=1, padding=0)
        # converted RPDC has a 5x5 receptive field, others 3x3
        kernel, pad = (5, 2) if pdc == 'rd' else (3, 1)
        self.conv1 = nn.Conv2d(
            inplane,
            inplane,
            kernel_size=kernel,
            padding=pad,
            groups=inplane,
            bias=False)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            inplane, ouplane, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        if self.stride > 1:
            x = self.pool(x)
        y = self.conv2(self.relu2(self.conv1(x)))
        if self.stride > 1:
            x = self.shortcut(x)
        return y + x
class PiDiNet(nn.Module):
    """Pixel Difference Network for edge detection.

    Four stages of PDC blocks with channel widths C / 2C / 4C / 4C (stages
    2-4 downsample by 2). Each stage's feature map is optionally refined by
    CDCM (when ``dil`` is set) and/or CSAM (when ``sa``), reduced to a
    single-channel side output, upsampled to the input size, and fused by a
    1x1 conv into the final edge map.

    Args:
        inplane: base channel count C.
        pdcs: 16 per-layer conv ops; callables from ``config_model``, or
            op-type strings from ``config_model_converted`` when ``convert``
            is True.
        dil: CDCM output channels, or None to disable CDCM.
        sa: whether to add CSAM spatial attention.
        convert: build the vanilla-conv variant (for converted checkpoints).
    """

    def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
        super(PiDiNet, self).__init__()
        self.sa = sa
        if dil is not None:
            assert isinstance(dil, int), 'dil should be an int'
        self.dil = dil

        # channel widths of the four stage outputs, filled in below
        self.fuseplanes = []

        self.inplane = inplane
        if convert:
            # converted RPDC stem needs a 5x5 kernel; others use 3x3
            if pdcs[0] == 'rd':
                init_kernel_size = 5
                init_padding = 2
            else:
                init_kernel_size = 3
                init_padding = 1
            self.init_block = nn.Conv2d(
                3,
                self.inplane,
                kernel_size=init_kernel_size,
                padding=init_padding,
                bias=False)
            block_class = PDCBlock_converted
        else:
            self.init_block = Conv2d(
                pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
            block_class = PDCBlock

        # stage 1: C channels, full resolution
        self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
        self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
        self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane)  # C

        # stage 2: 2C channels, 1/2 resolution
        inplane = self.inplane
        self.inplane = self.inplane * 2
        self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
        self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
        self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
        self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane)  # 2C

        # stage 3: 4C channels, 1/4 resolution
        inplane = self.inplane
        self.inplane = self.inplane * 2
        self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
        self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
        self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
        self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane)  # 4C

        # stage 4: keeps 4C channels, 1/8 resolution
        self.block4_1 = block_class(
            pdcs[12], self.inplane, self.inplane, stride=2)
        self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
        self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
        self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane)  # 4C

        # per-stage refinement (CDCM/CSAM) and single-channel reducers
        self.conv_reduces = nn.ModuleList()
        if self.sa and self.dil is not None:
            self.attentions = nn.ModuleList()
            self.dilations = nn.ModuleList()
            for i in range(4):
                self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
                self.attentions.append(CSAM(self.dil))
                self.conv_reduces.append(MapReduce(self.dil))
        elif self.sa:
            self.attentions = nn.ModuleList()
            for i in range(4):
                self.attentions.append(CSAM(self.fuseplanes[i]))
                self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
        elif self.dil is not None:
            self.dilations = nn.ModuleList()
            for i in range(4):
                self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
                self.conv_reduces.append(MapReduce(self.dil))
        else:
            for i in range(4):
                self.conv_reduces.append(MapReduce(self.fuseplanes[i]))

        # fuse the four side outputs into the final edge map
        self.classifier = nn.Conv2d(4, 1, kernel_size=1)
        nn.init.constant_(self.classifier.weight, 0.25)
        nn.init.constant_(self.classifier.bias, 0)

    def get_weights(self):
        """Partition parameters into (conv, bn, relu) groups by name, for
        per-group optimizer settings."""
        conv_weights = []
        bn_weights = []
        relu_weights = []
        for pname, p in self.named_parameters():
            if 'bn' in pname:
                bn_weights.append(p)
            elif 'relu' in pname:
                relu_weights.append(p)
            else:
                conv_weights.append(p)

        return conv_weights, bn_weights, relu_weights

    def forward(self, x):
        """Return the fused edge probability map, same spatial size as x."""
        H, W = x.size()[2:]

        x = self.init_block(x)

        x1 = self.block1_1(x)
        x1 = self.block1_2(x1)
        x1 = self.block1_3(x1)

        x2 = self.block2_1(x1)
        x2 = self.block2_2(x2)
        x2 = self.block2_3(x2)
        x2 = self.block2_4(x2)

        x3 = self.block3_1(x2)
        x3 = self.block3_2(x3)
        x3 = self.block3_3(x3)
        x3 = self.block3_4(x3)

        x4 = self.block4_1(x3)
        x4 = self.block4_2(x4)
        x4 = self.block4_3(x4)
        x4 = self.block4_4(x4)

        # optional per-stage refinement before reduction
        x_fuses = []
        if self.sa and self.dil is not None:
            for i, xi in enumerate([x1, x2, x3, x4]):
                x_fuses.append(self.attentions[i](self.dilations[i](xi)))
        elif self.sa:
            for i, xi in enumerate([x1, x2, x3, x4]):
                x_fuses.append(self.attentions[i](xi))
        elif self.dil is not None:
            for i, xi in enumerate([x1, x2, x3, x4]):
                x_fuses.append(self.dilations[i](xi))
        else:
            x_fuses = [x1, x2, x3, x4]

        # side outputs, upsampled back to the input resolution
        e1 = self.conv_reduces[0](x_fuses[0])
        e1 = F.interpolate(e1, (H, W), mode='bilinear', align_corners=False)

        e2 = self.conv_reduces[1](x_fuses[1])
        e2 = F.interpolate(e2, (H, W), mode='bilinear', align_corners=False)

        e3 = self.conv_reduces[2](x_fuses[2])
        e3 = F.interpolate(e3, (H, W), mode='bilinear', align_corners=False)

        e4 = self.conv_reduces[3](x_fuses[3])
        e4 = F.interpolate(e4, (H, W), mode='bilinear', align_corners=False)

        outputs = [e1, e2, e3, e4]

        output = self.classifier(torch.cat(outputs, dim=1))
        outputs.append(output)
        outputs = [torch.sigmoid(r) for r in outputs]
        # only the fused map is returned; side outputs are discarded
        return outputs[-1]
def pidinet_bsd_tiny(pretrained=False, vanilla_cnn=True):
    """PiDiNet-tiny trained on BSD (C=20, CDCM dil=8, CSAM enabled)."""
    config = 'carv4'
    pdcs = config_model_converted(config) if vanilla_cnn else config_model(
        config)
    model = PiDiNet(20, pdcs, dil=8, sa=True, convert=vanilla_cnn)
    if pretrained:
        state = torch.load(
            DOWNLOAD_TO_CACHE('models/pidinet/table5_pidinet-tiny.pth'),
            map_location='cpu')['state_dict']
        if vanilla_cnn:
            state = convert_pidinet(state, config)
        # strip the DataParallel 'module.' prefix from checkpoint keys
        cleaned = {}
        for name, tensor in state.items():
            if name.startswith('module.'):
                name = name[len('module.'):]
            cleaned[name] = tensor
        model.load_state_dict(cleaned)
    return model
def pidinet_bsd_small(pretrained=False, vanilla_cnn=True):
    """PiDiNet-small trained on BSD (C=30, CDCM dil=12, CSAM enabled)."""
    config = 'carv4'
    pdcs = config_model_converted(config) if vanilla_cnn else config_model(
        config)
    model = PiDiNet(30, pdcs, dil=12, sa=True, convert=vanilla_cnn)
    if pretrained:
        state = torch.load(
            DOWNLOAD_TO_CACHE('models/pidinet/table5_pidinet-small.pth'),
            map_location='cpu')['state_dict']
        if vanilla_cnn:
            state = convert_pidinet(state, config)
        # strip the DataParallel 'module.' prefix from checkpoint keys
        cleaned = {}
        for name, tensor in state.items():
            if name.startswith('module.'):
                name = name[len('module.'):]
            cleaned[name] = tensor
        model.load_state_dict(cleaned)
    return model
def pidinet_bsd(model_dir, pretrained=False, vanilla_cnn=True):
    """Full PiDiNet trained on BSD (C=60, CDCM dil=24, CSAM enabled);
    weights are loaded from ``model_dir``."""
    config = 'carv4'
    pdcs = config_model_converted(config) if vanilla_cnn else config_model(
        config)
    model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn)
    if pretrained:
        state = torch.load(
            os.path.join(model_dir, 'table5_pidinet.pth'),
            map_location='cpu')['state_dict']
        if vanilla_cnn:
            state = convert_pidinet(state, config)
        # strip the DataParallel 'module.' prefix from checkpoint keys
        cleaned = {}
        for name, tensor in state.items():
            if name.startswith('module.'):
                name = name[len('module.'):]
            cleaned[name] = tensor
        model.load_state_dict(cleaned)
    return model
def pidinet_nyud(pretrained=False, vanilla_cnn=True):
    """Full PiDiNet trained on NYUD (C=60, CDCM dil=24, CSAM enabled)."""
    config = 'carv4'
    pdcs = config_model_converted(config) if vanilla_cnn else config_model(
        config)
    model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn)
    if pretrained:
        state = torch.load(
            DOWNLOAD_TO_CACHE('models/pidinet/table6_pidinet.pth'),
            map_location='cpu')['state_dict']
        if vanilla_cnn:
            state = convert_pidinet(state, config)
        # strip the DataParallel 'module.' prefix from checkpoint keys
        cleaned = {}
        for name, tensor in state.items():
            if name.startswith('module.'):
                name = name[len('module.'):]
            cleaned[name] = tensor
        model.load_state_dict(cleaned)
    return model
def pidinet_multicue(pretrained=False, vanilla_cnn=True):
    """Full PiDiNet trained on Multicue (C=60, CDCM dil=24, CSAM enabled)."""
    config = 'carv4'
    pdcs = config_model_converted(config) if vanilla_cnn else config_model(
        config)
    model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn)
    if pretrained:
        state = torch.load(
            DOWNLOAD_TO_CACHE('models/pidinet/table7_pidinet.pth'),
            map_location='cpu')['state_dict']
        if vanilla_cnn:
            state = convert_pidinet(state, config)
        # strip the DataParallel 'module.' prefix from checkpoint keys
        cleaned = {}
        for name, tensor in state.items():
            if name.startswith('module.'):
                name = name[len('module.'):]
            cleaned[name] = tensor
        model.load_state_dict(cleaned)
    return model

View File

@@ -0,0 +1,110 @@
r"""PyTorch re-implementation adapted from the Lua code in ``https://github.com/bobbens/sketch_simplification''.
"""
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.multi_modal.videocomposer.utils.utils import \
DOWNLOAD_TO_CACHE
__all__ = [
'SketchSimplification', 'sketch_simplification_gan',
'sketch_simplification_mse', 'sketch_to_pencil_v1', 'sketch_to_pencil_v2'
]
class SketchSimplification(nn.Module):
    r"""Fully-convolutional sketch simplification network.

    NOTE:
        1. Input image should have only one gray channel.
        2. Input image size should be divisible by 8.
        3. Sketch in the input/output image is in dark color while the
           background is in light color.
    """

    def __init__(self, mean, std):
        assert isinstance(mean, float) and isinstance(std, float)
        super(SketchSimplification, self).__init__()
        # grayscale normalization statistics
        self.mean = mean
        self.std = std

        # (in, out, kernel, stride, transposed) for every conv layer, in
        # order; each conv is followed by an inplace ReLU except the last,
        # which is followed by a Sigmoid.
        specs = [
            (1, 48, 5, 2, False), (48, 128, 3, 1, False),
            (128, 128, 3, 1, False), (128, 128, 3, 2, False),
            (128, 256, 3, 1, False), (256, 256, 3, 1, False),
            (256, 256, 3, 2, False), (256, 512, 3, 1, False),
            (512, 1024, 3, 1, False), (1024, 1024, 3, 1, False),
            (1024, 1024, 3, 1, False), (1024, 1024, 3, 1, False),
            (1024, 512, 3, 1, False), (512, 256, 3, 1, False),
            (256, 256, 4, 2, True), (256, 256, 3, 1, False),
            (256, 128, 3, 1, False), (128, 128, 4, 2, True),
            (128, 128, 3, 1, False), (128, 48, 3, 1, False),
            (48, 48, 4, 2, True), (48, 24, 3, 1, False),
            (24, 1, 3, 1, False),
        ]
        modules = []
        for idx, (cin, cout, kernel, stride, transposed) in enumerate(specs):
            conv_cls = nn.ConvTranspose2d if transposed else nn.Conv2d
            # only the 5x5 stem uses padding 2; 3x3/4x4 layers use 1
            pad = 2 if kernel == 5 else 1
            modules.append(conv_cls(cin, cout, kernel, stride, pad))
            if idx == len(specs) - 1:
                modules.append(nn.Sigmoid())
            else:
                modules.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        r"""x: [B, 1, H, W] within range [0, 1]. Sketch pixels in dark
        color."""
        return self.layers((x - self.mean) / self.std)
def sketch_simplification_gan(model_dir, pretrained=False):
    """GAN-trained sketch simplification model; weights from ``model_dir``."""
    net = SketchSimplification(
        mean=0.9664114577640158, std=0.0858381272736797)
    if not pretrained:
        return net
    weights_path = os.path.join(model_dir, 'sketch_simplification_gan.pth')
    net.load_state_dict(torch.load(weights_path, map_location='cpu'))
    return net
def sketch_simplification_mse(pretrained=False):
    """MSE-trained sketch simplification model; weights from the cache."""
    net = SketchSimplification(
        mean=0.9664423107454593, std=0.08583666033640507)
    if not pretrained:
        return net
    weights_path = DOWNLOAD_TO_CACHE(
        'models/sketch_simplification/sketch_simplification_mse.pth')
    net.load_state_dict(torch.load(weights_path, map_location='cpu'))
    return net
def sketch_to_pencil_v1(pretrained=False):
    """Sketch-to-pencil model v1; weights from the cache."""
    net = SketchSimplification(
        mean=0.9817833515894078, std=0.0925009022585048)
    if not pretrained:
        return net
    weights_path = DOWNLOAD_TO_CACHE(
        'models/sketch_simplification/sketch_to_pencil_v1.pth')
    net.load_state_dict(torch.load(weights_path, map_location='cpu'))
    return net
def sketch_to_pencil_v2(pretrained=False, model_dir=None):
    """Build the sketch-to-pencil (v2) network.

    Args:
        pretrained: When True, load the pretrained weights (on CPU).
        model_dir: Optional directory containing
            ``sketch_to_pencil_v2.pth``. When None, fall back to the legacy
            ``DOWNLOAD_TO_CACHE`` lookup, kept for backward compatibility.

    Returns:
        A ``SketchSimplification`` module.
    """
    model = SketchSimplification(
        mean=0.9851298627337799, std=0.07418377454883571)
    if pretrained:
        if model_dir is not None:
            # Consistent with sketch_simplification_gan: resolve the
            # checkpoint inside the caller-provided model directory.
            path = os.path.join(model_dir, 'sketch_to_pencil_v2.pth')
        else:
            path = DOWNLOAD_TO_CACHE(
                'models/sketch_simplification/sketch_to_pencil_v2.pth')
        model.load_state_dict(torch.load(path, map_location='cpu'))
    return model

View File

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import cv2
import numpy as np
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
def HWC3(x):
    """Convert an image array to 3-channel uint8 HWC format.

    Grayscale (2-D or single-channel) inputs are replicated across three
    channels; RGBA inputs are alpha-composited over a white background;
    RGB inputs pass through unchanged.

    Args:
        x: ``np.uint8`` array of shape [H, W], [H, W, 1], [H, W, 3] or
            [H, W, 4].

    Returns:
        ``np.uint8`` array of shape [H, W, 3].
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels == 1 or channels == 3 or channels == 4
    if channels == 3:
        return x
    if channels == 1:
        return np.repeat(x, 3, axis=2)
    # RGBA: blend onto a white background using the alpha channel.
    rgb = x[:, :, 0:3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
def resize_image(input_image, resolution):
    """Resize so the shorter side equals ``resolution``, snapped to /64.

    Both output dimensions are rounded to the nearest multiple of 64.
    Upscaling uses Lanczos interpolation, downscaling uses area
    interpolation.

    Args:
        input_image: HWC image array.
        resolution: Target size of the shorter side before rounding.

    Returns:
        The resized image array.
    """
    h, w, _ = input_image.shape
    scale = float(resolution) / min(float(h), float(w))
    target_h = int(np.round(float(h) * scale / 64.0)) * 64
    target_w = int(np.round(float(w) * scale / 64.0)) * 64
    interp = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    return cv2.resize(input_image, (target_w, target_h), interpolation=interp)

View File

@@ -0,0 +1,650 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['AutoencoderKL']
def nonlinearity(x):
    """Swish activation, ``x * sigmoid(x)`` (a.k.a. SiLU)."""
    return torch.sigmoid(x) * x
def Normalize(in_channels, num_groups=32):
    """GroupNorm with the VAE's standard settings (eps=1e-6, affine)."""
    return torch.nn.GroupNorm(
        num_channels=in_channels,
        num_groups=num_groups,
        eps=1e-6,
        affine=True)
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian posterior parameterized by packed (mean, logvar).

    ``parameters`` concatenates mean and log-variance along dim 1, as
    produced by the encoder's ``quant_conv``. When ``deterministic`` is
    True the distribution degenerates to a point mass at the mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        self.mean = mean
        # Clamp log-variance so exp() stays numerically safe.
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if deterministic:
            # Point mass: zero spread (one shared zero tensor, as before).
            zeros = torch.zeros_like(self.mean).to(
                device=self.parameters.device)
            self.var = zeros
            self.std = zeros

    def sample(self):
        """Draw one reparameterized sample, mean + std * eps."""
        noise = torch.randn(self.mean.shape).to(
            device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """KL divergence to *other* (or to the standard normal)."""
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        """Negative log-likelihood of *sample* under this Gaussian."""
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar
            + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Most likely value (the mean)."""
        return self.mean
class Downsample(nn.Module):
    """2x spatial downsampling, via strided conv or average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # PyTorch convs cannot express the asymmetric (right/bottom)
            # padding the reference implementation needs, so the padding is
            # applied manually in forward() and the conv runs with padding=0.
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        padded = torch.nn.functional.pad(
            x, (0, 1, 0, 1), mode='constant', value=0)
        return self.conv(padded)
class ResnetBlock(nn.Module):
    """Residual block with GroupNorm + swish and optional timestep embedding.

    When ``in_channels != out_channels`` the skip connection is projected
    with either a 3x3 conv (``conv_shortcut=True``) or a 1x1 conv.
    Submodule attribute names are kept stable for checkpoint compatibility.
    """

    def __init__(self,
                 *,
                 in_channels,
                 out_channels=None,
                 conv_shortcut=False,
                 dropout,
                 temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = (in_channels
                             if out_channels is None else out_channels)
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels,
            self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             self.out_channels)
        self.norm2 = Normalize(self.out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            self.out_channels,
            self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1)
        if self.in_channels != self.out_channels:
            if conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels,
                    self.out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels,
                    self.out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0)

    def forward(self, x, temb):
        h = self.conv1(nonlinearity(self.norm1(x)))
        if temb is not None:
            # Broadcast the projected embedding over the spatial dims.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h
class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions (non-local block)."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        # 1x1 convs act as per-position linear projections.
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # Attention weights: softmax over keys of scaled dot products.
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).permute(0, 2, 1)  # b, hw, c
        k = k.reshape(b, c, h * w)  # b, c, hw
        attn = torch.bmm(q, k)  # attn[b, i, j] = sum_c q[b, i, c] k[b, c, j]
        attn = attn * (int(c)**(-0.5))
        attn = torch.nn.functional.softmax(attn, dim=2)

        # Weighted sum of values.
        v = v.reshape(b, c, h * w)
        attn = attn.permute(0, 2, 1)  # b, hw(key), hw(query)
        out = torch.bmm(v, attn)  # out[b, c, j] = sum_i v[b, c, i] attn[b, i, j]
        out = out.reshape(b, c, h, w)
        return x + self.proj_out(out)
class AttnBlock(nn.Module):  # noqa
    """Single-head spatial self-attention block.

    NOTE(review): this is a byte-for-byte duplicate of the ``AttnBlock``
    defined above; because it shares the name, this later definition is
    the one actually used. Consider removing one copy.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        # 1x1 convs act as per-position linear projections.
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw  w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))  # scale by 1/sqrt(channels)
        w_ = torch.nn.functional.softmax(w_, dim=2)
        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(
            v, w_)  # b, c,hw (hw of q)  h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)
        h_ = self.proj_out(h_)
        # Residual connection around the attention operation.
        return x + h_
class Upsample(nn.Module):
    """2x nearest-neighbor upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode='nearest')
        return self.conv(upsampled) if self.with_conv else upsampled
class Downsample(nn.Module):  # noqa
    """2x spatial downsampling via strided conv or average pooling.

    NOTE(review): exact duplicate of the ``Downsample`` defined above; this
    later definition shadows the earlier one. Consider removing one copy.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)  # pad right/bottom by one pixel
            x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
class Encoder(nn.Module):
    """Convolutional VAE encoder (LDM architecture).

    Downsamples the input through ``len(ch_mult)`` resolution levels of
    ResnetBlocks (with attention at resolutions listed in
    ``attn_resolutions``), applies a middle block with attention, and
    projects to ``z_channels`` channels (doubled when ``double_z`` so that
    both mean and log-variance are produced).
    """

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 double_z=True,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignore_kwargs):
        # NOTE(review): use_linear_attn and attn_type are accepted for
        # config compatibility but unused here (AttnBlock is always used).
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        curr_res = resolution
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            # Input width of this level = output width of previous level.
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                # Halve the spatial resolution between levels (not after
                # the last level).
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        """Encode images [B, in_channels, H, W] to latent moments of shape
        [B, (2*)z_channels, H / 2**(levels-1), W / 2**(levels-1)]."""
        # timestep embedding
        temb = None  # the autoencoder is unconditional on time
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class Decoder(nn.Module):
    """Convolutional VAE decoder (LDM architecture).

    Mirrors ``Encoder``: a middle block with attention followed by
    ``len(ch_mult)`` upsampling levels of ResnetBlocks
    (``num_res_blocks + 1`` per level), ending in a conv to ``out_ch``
    channels.
    """

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 give_pre_end=False,
                 tanh_out=False,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignorekwargs):
        # NOTE(review): use_linear_attn and attn_type are accepted for
        # config compatibility but unused here (AttnBlock is always used).
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end  # return features before norm/conv
        self.tanh_out = tanh_out  # squash outputs into [-1, 1]
        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print('Working with z of shape {} = {} dimensions.'.format(
            self.z_shape, np.prod(self.z_shape)))
        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                # Double the spatial resolution between levels (not after
                # the highest-resolution level).
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        """Decode latents *z* [B, z_channels, h, w] back to image space."""
        self.last_z_shape = z.shape
        # timestep embedding
        temb = None  # the autoencoder is unconditional on time
        # z to block_in
        h = self.conv_in(z)
        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
        # end
        if self.give_pre_end:
            return h
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
class AutoencoderKL(nn.Module):
    """KL-regularized image autoencoder (the Stable Diffusion first stage).

    Wraps an ``Encoder``/``Decoder`` pair with 1x1 convs mapping between
    the encoder moments and a latent of ``embed_dim`` channels.
    """

    def __init__(self,
                 ddconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=(),
                 image_key='image',
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False):
        """
        Args:
            ddconfig: dict of Encoder/Decoder kwargs; must set
                ``double_z=True`` so the encoder emits mean and logvar.
            embed_dim: latent channel count.
            ckpt_path: optional checkpoint to restore from at init.
            ignore_keys: state-dict key prefixes to drop when restoring.
            image_key: batch key holding the input images.
            colorize_nlabels: if set, register a random projection used by
                ``to_rgb`` for segmentation inputs.
            monitor: metric name (kept for trainer compatibility).
            ema_decay: when not None, enables the EMA code paths.
            learn_logvar: kept for compatibility with the training code.
        """
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # The encoder must output both mean and logvar (2 * z_channels).
        assert ddconfig['double_z']
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
                                          2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig['z_channels'], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer('colorize',
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.use_ema = ema_decay is not None
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Restore weights from a checkpoint that nests this model under a
        ``first_stage_model.`` prefix (e.g. a full LDM checkpoint).

        Fixes vs. the original: the ``ignore_keys`` argument is now
        honored (it used to be silently ignored) and the per-key debug
        print of every tensor shape was removed.
        """
        sd = torch.load(path, map_location='cpu')['state_dict']
        import collections
        sd_new = collections.OrderedDict()
        for k, v in sd.items():
            if k.find('first_stage_model') >= 0:
                k_new = k.split('first_stage_model.')[-1]
                if any(k_new.startswith(ik) for ik in ignore_keys):
                    print('Deleting key {} from state_dict.'.format(k_new))
                    continue
                sd_new[k_new] = v
        self.load_state_dict(sd_new, strict=True)
        print(f'Restored from {path}')

    def init_from_ckpt2(self, path, ignore_keys=()):
        """Restore weights from a flat state_dict, dropping ``ignore_keys``.

        Fixes vs. the original: a stray bare name ``first_stage_model``
        (which raised NameError on every call) was removed, and a key
        matching two ignore prefixes is no longer deleted twice
        (previously a KeyError).
        """
        sd = torch.load(path, map_location='cpu')['state_dict']
        for k in list(sd.keys()):
            for ik in ignore_keys:
                if k.startswith(ik):
                    print('Deleting key {} from state_dict.'.format(k))
                    del sd[k]
                    break
        self.load_state_dict(sd, strict=False)
        print(f'Restored from {path}')

    def on_train_batch_end(self, *args, **kwargs):
        """Trainer hook: update EMA weights after each batch.

        NOTE(review): ``self.model_ema`` is never created in this class;
        the training wrapper is presumably expected to attach it before
        this hook fires — verify before enabling EMA.
        """
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode images to a ``DiagonalGaussianDistribution`` over latents."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        return DiagonalGaussianDistribution(moments)

    def decode(self, z):
        """Decode latents [B, embed_dim, h, w] back to image space."""
        z = self.post_quant_conv(z)
        return self.decoder(z)

    def forward(self, input, sample_posterior=True):
        """Full autoencode pass.

        Returns:
            Tuple of (reconstruction, posterior distribution).
        """
        posterior = self.encode(input)
        z = posterior.sample() if sample_posterior else posterior.mode()
        return self.decode(z), posterior

    def get_input(self, batch, k):
        """Fetch images from *batch* under key *k* as NCHW float tensors."""
        x = batch[k]
        if len(x.shape) == 3:
            # Add a trailing channel axis for grayscale batches.
            x = x[..., None]
        return x.permute(0, 3, 1,
                         2).to(memory_format=torch.contiguous_format).float()

    def get_last_layer(self):
        """Weight of the final decoder conv (used for adaptive GAN loss)."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Collect input / reconstruction / sample images for visualization.

        NOTE(review): relies on ``self.device`` and ``self.ema_scope``,
        which are not defined on this plain nn.Module — presumably
        provided by a Lightning-style wrapper; verify before calling.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log['samples'] = self.decode(torch.randn_like(posterior.sample()))
            log['reconstructions'] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log['samples_ema'] = self.decode(
                        torch.randn_like(posterior_ema.sample()))
                    log['reconstructions_ema'] = xrec_ema
        log['inputs'] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to RGB for display."""
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize',
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        # Rescale to [-1, 1] for visualization.
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: encode/decode/quantize all return input unchanged.

    Used when the diffusion model operates directly in pixel space but the
    surrounding code still expects a first-stage interface.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Fix: initialize nn.Module BEFORE assigning attributes — setting
        # attributes on an uninitialized Module relies on fragile
        # Module.__setattr__ behavior.
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        """Identity."""
        return x

    def decode(self, x, *args, **kwargs):
        """Identity."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Identity; mimics a VQ (quant, emb_loss, info) return signature
        when ``vq_interface`` is True."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Identity."""
        return x

View File

@@ -0,0 +1,143 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import open_clip
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.checkpoint import checkpoint
class FrozenOpenCLIPEmbedder(nn.Module):
    """Frozen OpenCLIP transformer for encoding text into token embeddings.

    Returns per-token features; ``layer`` selects whether the last or
    penultimate transformer block output is used.
    """
    LAYERS = ['last', 'penultimate']

    def __init__(self,
                 arch='ViT-H-14',
                 pretrained='laion2b_s32b_b79k',
                 device='cuda',
                 max_length=77,
                 freeze=True,
                 layer='last'):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=pretrained)
        # The visual tower is not needed for text encoding; free its memory.
        del model.visual
        self.model = model
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == 'last':
            self.layer_idx = 0
        elif self.layer == 'penultimate':
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Switch to eval mode and stop gradients for all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize *text* (str or list of str) and return token embeddings."""
        tokens = open_clip.tokenize(text)
        return self.encode_with_transformer(tokens.to(self.device))

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run transformer blocks, stopping ``layer_idx`` blocks early."""
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if (self.model.transformer.grad_checkpointing
                    and not torch.jit.is_scripting()):
                # Fix: ``checkpoint`` was an undefined name here (NameError
                # whenever gradient checkpointing was enabled); it is now
                # imported from torch.utils.checkpoint at module level.
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPVisualEmbedder(nn.Module):
    """Frozen OpenCLIP visual tower for encoding images.

    Fix: the original docstring wrongly claimed this encodes text. The
    class also keeps a preprocessed all-white placeholder image
    (``black_image``) for dropping the image condition.
    """
    LAYERS = ['last', 'penultimate']

    def __init__(self,
                 arch='ViT-H-14',
                 pretrained='laion2b_s32b_b79k',
                 device='cuda',
                 max_length=77,
                 freeze=True,
                 layer='last',
                 input_shape=(224, 224, 3)):
        super().__init__()
        assert layer in self.LAYERS
        model, _, preprocess = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=pretrained)
        # The text transformer is unused for image encoding; free its memory.
        del model.transformer
        self.model = model
        # NOTE(review): despite the attribute name, the placeholder image
        # is all-white, not black.
        data_white = np.ones(input_shape, dtype=np.uint8) * 255
        self.black_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0)
        self.preprocess = preprocess
        self.device = device
        self.max_length = max_length  # 77
        if freeze:
            self.freeze()
        self.layer = layer  # 'penultimate'
        if self.layer == 'last':
            self.layer_idx = 0
        elif self.layer == 'penultimate':
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Switch to eval mode and stop gradients for all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image):
        """Return the CLIP image embedding of *image* [B, 3, H, W]."""
        return self.model.encode_image(image.to(self.device))

    def encode_with_transformer(self, text):
        # NOTE(review): dead code copied from the text embedder —
        # ``self.model.transformer`` is deleted in __init__, so this would
        # fail if called. Kept only to preserve the public interface.
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        # NOTE(review): dead code (see encode_with_transformer).
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if (self.model.transformer.grad_checkpointing
                    and not torch.jit.is_scripting()):
                # Fix: ``checkpoint`` was an undefined name here; it is now
                # imported from torch.utils.checkpoint at module level.
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        # Parameter is named ``text`` for interface compatibility, but it
        # actually receives an image tensor (forwarded to forward()).
        return self(text)

View File

@@ -0,0 +1,156 @@
import logging
import os
import os.path as osp
from datetime import datetime
import torch
from easydict import EasyDict
cfg = EasyDict(__name__='Config: VideoComposer')
pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
gpus_per_machine = torch.cuda.device_count()
world_size = pmi_world_size * gpus_per_machine
cfg.video_compositions = [
'text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image',
'single_sketch'
]
# dataset
cfg.root_dir = 'webvid10m/'
cfg.alpha = 0.7
cfg.misc_size = 384
cfg.depth_std = 20.0
cfg.depth_clamp = 10.0
cfg.hist_sigma = 10.0
cfg.use_image_dataset = False
cfg.alpha_img = 0.7
cfg.resolution = 256
cfg.mean = [0.5, 0.5, 0.5]
cfg.std = [0.5, 0.5, 0.5]
# sketch
cfg.sketch_mean = [0.485, 0.456, 0.406]
cfg.sketch_std = [0.229, 0.224, 0.225]
# dataloader
cfg.max_words = 1000
cfg.frame_lens = [
16,
16,
16,
16,
]
cfg.feature_framerates = [
4,
]
cfg.feature_framerate = 4
cfg.batch_sizes = {
str(1): 1,
str(4): 1,
str(8): 1,
str(16): 1,
}
cfg.chunk_size = 64
cfg.num_workers = 8
cfg.prefetch_factor = 2
cfg.seed = 8888
# diffusion
cfg.num_timesteps = 1000
cfg.mean_type = 'eps'
cfg.var_type = 'fixed_small'
cfg.loss_type = 'mse'
cfg.ddim_timesteps = 50
cfg.ddim_eta = 0.0
cfg.clamp = 1.0
cfg.share_noise = False
cfg.use_div_loss = False
# classifier-free guidance
cfg.p_zero = 0.9
cfg.guide_scale = 6.0
# stable diffusion
cfg.sd_checkpoint = 'v2-1_512-ema-pruned.ckpt'
# clip vision encoder
cfg.vit_image_size = 336
cfg.vit_patch_size = 14
cfg.vit_dim = 1024
cfg.vit_out_dim = 768
cfg.vit_heads = 16
cfg.vit_layers = 24
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
cfg.clip_checkpoint = 'open_clip_pytorch_model.bin'
cfg.mvs_visual = False
# unet
cfg.unet_in_dim = 4
cfg.unet_concat_dim = 8
cfg.unet_y_dim = cfg.vit_out_dim
cfg.unet_context_dim = 1024
cfg.unet_out_dim = 8 if cfg.var_type.startswith('learned') else 4
cfg.unet_dim = 320
cfg.unet_dim_mult = [1, 2, 4, 4]
cfg.unet_res_blocks = 2
cfg.unet_num_heads = 8
cfg.unet_head_dim = 64
cfg.unet_attn_scales = [1 / 1, 1 / 2, 1 / 4]
cfg.unet_dropout = 0.1
cfg.misc_dropout = 0.5
cfg.p_all_zero = 0.1
cfg.p_all_keep = 0.1
cfg.temporal_conv = False
cfg.temporal_attn_times = 1
cfg.temporal_attention = True
cfg.use_fps_condition = False
cfg.use_sim_mask = False
# Default: load 2d pretrain
cfg.pretrained = False
cfg.fix_weight = False
# Default resume
cfg.resume = True
cfg.resume_step = 148000
cfg.resume_check_dir = '.'
cfg.resume_checkpoint = os.path.join(
cfg.resume_check_dir,
f'step_{cfg.resume_step}/non_ema_{cfg.resume_step}.pth')
cfg.resume_optimizer = False
if cfg.resume_optimizer:
cfg.resume_optimizer = os.path.join(
cfg.resume_check_dir, f'optimizer_step_{cfg.resume_step}.pt')
# acceleration
cfg.use_ema = True
# for debug, no ema
if world_size < 2:
cfg.use_ema = False
cfg.load_from = None
cfg.use_checkpoint = True
cfg.use_sharded_ddp = False
cfg.use_fsdp = False
cfg.use_fp16 = True
# training
cfg.ema_decay = 0.9999
cfg.viz_interval = 1000
cfg.save_ckp_interval = 1000
# logging
cfg.log_interval = 100
composition_strings = '_'.join(cfg.video_compositions)
# Default log_dir
cfg.log_dir = 'outputs/'

View File

@@ -0,0 +1,2 @@
ENABLE: true
DATASET: webvid10m

View File

@@ -0,0 +1,20 @@
TASK_TYPE: MULTI_TASK
ENABLE: true
DATASET: webvid10m
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
batch_sizes: {
"1": 1,
"4": 1,
"8": 1,
"16": 1,
}
vit_image_size: 224
network_name: UNetSD_temporal
resume: true
resume_step: 228000
num_workers: 1
mvs_visual: False
chunk_size: 1
resume_checkpoint: "model_weights/non_ema_228000.pth"
log_dir: 'outputs'
num_steps: 1

View File

@@ -0,0 +1,23 @@
TASK_TYPE: SINGLE_TASK
read_image: True # must be enabled so the reference image is read for this task
ENABLE: true
DATASET: webvid10m
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
guidances: ['y', 'local_image', 'motion'] # conditions that must be enabled for this task
batch_sizes: {
"1": 1,
"4": 1,
"8": 1,
"16": 1,
}
vit_image_size: 224
network_name: UNetSD_temporal
resume: true
resume_step: 228000
seed: 182
num_workers: 0
mvs_visual: False
chunk_size: 1
resume_checkpoint: "model_weights/non_ema_228000.pth"
log_dir: 'outputs'
num_steps: 1

View File

@@ -0,0 +1,24 @@
TASK_TYPE: SINGLE_TASK
read_image: True # must be enabled so the reference image is read for this task
read_style: True
ENABLE: true
DATASET: webvid10m
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
guidances: ['y', 'local_image', 'image', 'motion'] # conditions that must be enabled for this task
batch_sizes: {
"1": 1,
"4": 1,
"8": 1,
"16": 1,
}
vit_image_size: 224
network_name: UNetSD_temporal
resume: true
resume_step: 228000
seed: 182
num_workers: 0
mvs_visual: False
chunk_size: 1
resume_checkpoint: "model_weights/non_ema_228000.pth"
log_dir: 'outputs'
num_steps: 1

View File

@@ -0,0 +1,26 @@
TASK_TYPE: SINGLE_TASK
read_image: False # set to True to read a reference image for this task
read_style: True
read_sketch: True
save_origin_video: False
ENABLE: true
DATASET: webvid10m
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
guidances: ['y', 'image', 'single_sketch'] # conditions that must be enabled for this task
batch_sizes: {
"1": 1,
"4": 1,
"8": 1,
"16": 1,
}
vit_image_size: 224
network_name: UNetSD_temporal
resume: true
resume_step: 228000
seed: 182
num_workers: 0
mvs_visual: False
chunk_size: 1
resume_checkpoint: "model_weights/non_ema_228000.pth"
log_dir: 'outputs'
num_steps: 1

View File

@@ -0,0 +1,26 @@
TASK_TYPE: SINGLE_TASK
read_image: False # set to True to read a reference image for this task
read_style: False
read_sketch: True
save_origin_video: False
ENABLE: true
DATASET: webvid10m
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
guidances: ['y', 'single_sketch'] # conditions that must be enabled for this task
batch_sizes: {
"1": 1,
"4": 1,
"8": 1,
"16": 1,
}
vit_image_size: 224
network_name: UNetSD_temporal
resume: true
resume_step: 228000
seed: 182
num_workers: 0
mvs_visual: False
chunk_size: 1
resume_checkpoint: "model_weights/non_ema_228000.pth"
log_dir: 'outputs'
num_steps: 1

View File

@@ -0,0 +1,26 @@
TASK_TYPE: SINGLE_TASK
read_image: False # set to True to read a reference image for this task
read_style: False
read_sketch: False
save_origin_video: True
ENABLE: true
DATASET: webvid10m
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
guidances: ['y', 'depth'] # conditions that must be enabled for this task
batch_sizes: {
"1": 1,
"4": 1,
"8": 1,
"16": 1,
}
vit_image_size: 224
network_name: UNetSD_temporal
resume: true
resume_step: 228000
seed: 182
num_workers: 0
mvs_visual: False
chunk_size: 1
resume_checkpoint: "model_weights/non_ema_228000.pth"
log_dir: 'outputs'
num_steps: 1

View File

@@ -0,0 +1,26 @@
TASK_TYPE: SINGLE_TASK
read_image: False
read_style: True
read_sketch: False
save_origin_video: True
ENABLE: true
DATASET: webvid10m
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
guidances: ['y', 'image', 'depth'] # conditions that must be enabled for this task
batch_sizes: {
"1": 1,
"4": 1,
"8": 1,
"16": 1,
}
vit_image_size: 224
network_name: UNetSD_temporal
resume: true
resume_step: 228000
seed: 182
num_workers: 0
mvs_visual: False
chunk_size: 1
resume_checkpoint: "model_weights/non_ema_141000_no_watermark.pth"
log_dir: 'outputs'
num_steps: 1

View File

@@ -0,0 +1,21 @@
TASK_TYPE: VideoComposer_Inference
ENABLE: true
DATASET: webvid10m
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
batch_sizes: {
"1": 1,
"4": 1,
"8": 1,
"16": 1,
}
vit_image_size: 224
network_name: UNetSD_temporal
resume: true
resume_step: 141000
seed: 14
num_workers: 1
mvs_visual: True
chunk_size: 1
resume_checkpoint: "model_weights/non_ema_141000_no_watermark.pth"
log_dir: 'outputs'
num_steps: 1

View File

@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .samplers import *
from .tokenizers import *
from .transforms import *

View File

@@ -0,0 +1,158 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import json
import numpy as np
from torch.utils.data.sampler import Sampler
from modelscope.models.multi_modal.videocomposer.ops.distributed import (
get_rank, get_world_size, shared_random_seed)
from modelscope.models.multi_modal.videocomposer.ops.utils import (ceil_divide,
read)
__all__ = ['BatchSampler', 'GroupSampler', 'ImgGroupSampler']
class BatchSampler(Sampler):
    r"""An infinite batch sampler that shards indices across replicas.

    Each rank owns every ``num_replicas``-th global index (offset by its
    rank, clipped to the dataset size) and cycles over them forever,
    reshuffling between passes when ``shuffle`` is enabled.
    """

    def __init__(self,
                 dataset_size,
                 batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=False,
                 seed=None):
        """
        Args:
            dataset_size: Number of samples in the dataset.
            batch_size: Indices yielded per batch.
            num_replicas: World size; queried from the distributed ops
                helper when None.
            rank: This process' rank; queried when None.
            shuffle: Shuffle the per-rank indices (and reshuffle on wrap).
            seed: RNG seed; a shared random seed is generated when None.
        """
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.num_replicas = (num_replicas if num_replicas is not None
                             else get_world_size())
        # Fix: ``rank or get_rank()`` treated the valid rank 0 as missing.
        self.rank = rank if rank is not None else get_rank()
        self.shuffle = shuffle
        # Fix: ``seed or shared_random_seed()`` treated seed=0 as missing.
        self.seed = seed if seed is not None else shared_random_seed()
        self.rng = np.random.default_rng(self.seed + self.rank)
        # ceil(dataset_size / (num_replicas * batch_size)), inlined to keep
        # this simple computation free of the ops.utils dependency.
        self.batches_per_rank = -(-dataset_size
                                  // (self.num_replicas * self.batch_size))
        self.samples_per_rank = self.batches_per_rank * self.batch_size
        # rank indices: every num_replicas-th global index for this rank,
        # clipped to the dataset size.
        indices = self.rng.permutation(
            self.samples_per_rank) if shuffle else np.arange(
                self.samples_per_rank)
        indices = indices * self.num_replicas + self.rank
        indices = indices[indices < dataset_size]
        self.indices = indices

    def __iter__(self):
        """Yield index batches forever, wrapping around the rank's indices."""
        start = 0
        while True:
            batch = [
                self.indices[i % len(self.indices)]
                for i in range(start, start + self.batch_size)
            ]
            if self.shuffle and (start + self.batch_size) > len(self.indices):
                # Reshuffle once a full pass over the indices completes.
                self.rng.shuffle(self.indices)
            start = (start + self.batch_size) % len(self.indices)
            yield batch
class GroupSampler(Sampler):
    """Infinite sampler drawing batches from scale-weighted sample groups.

    ``group_file`` is a JSON list of single-entry dicts; each key appears
    to encode a sampling scale after a ':' (e.g. ``"name:scale"``) and
    each value lists files under the sibling ``groups/`` folder whose
    lines are ','-separated sample records — assumed from usage here;
    TODO confirm against the dataset-preparation code.
    """

    def __init__(self,
                 group_file,
                 batch_size,
                 alpha=0.7,
                 update_interval=5000,
                 seed=8888):
        # Path layout: <dir>/<group_file> plus <dir>/groups/<list files>.
        self.group_file = group_file
        self.group_folder = osp.join(osp.dirname(group_file), 'groups')
        self.batch_size = batch_size
        self.alpha = alpha  # exponent tempering the scale distribution
        self.update_interval = update_interval  # reload groups every N draws
        self.seed = seed
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        """Yield batches of parsed (','-split) records forever."""
        while True:
            # keep groups up-to-date
            self.update_groups()
            # collect items
            items = self.sample()
            while len(items) < self.batch_size:
                items += self.sample()
            # sample a batch (with replacement only if too few items)
            batch = self.rng.choice(
                items,
                self.batch_size,
                replace=False if len(items) >= self.batch_size else True)
            yield [u.strip().split(',') for u in batch]

    def update_groups(self):
        """(Re)load the group index every ``update_interval`` iterations."""
        if not hasattr(self, '_step'):
            self._step = 0
        if self._step % self.update_interval == 0:
            self.groups = json.loads(read(self.group_file))
        self._step += 1

    def sample(self):
        """Pick one group with probability proportional to scale**alpha and
        return the lines of one of its randomly chosen list files."""
        scales = np.array(
            [float(next(iter(u)).split(':')[-1]) for u in self.groups])
        p = scales**self.alpha / (scales**self.alpha).sum()
        group = self.rng.choice(self.groups, p=p)
        list_file = osp.join(self.group_folder,
                             self.rng.choice(next(iter(group.values()))))
        return read(list_file).strip().split('\n')
class ImgGroupSampler(Sampler):
    """Infinite sampler for image groups; scale-weighted like GroupSampler.

    NOTE(review): near-duplicate of ``GroupSampler`` above — the only
    difference is that records are split on the FIRST comma only
    (``split(',', 1)``), presumably because image captions may themselves
    contain commas; verify. Consider factoring the shared logic.
    """

    def __init__(self,
                 group_file,
                 batch_size,
                 alpha=0.7,
                 update_interval=5000,
                 seed=8888):
        # Path layout: <dir>/<group_file> plus <dir>/groups/<list files>.
        self.group_file = group_file
        self.group_folder = osp.join(osp.dirname(group_file), 'groups')
        self.batch_size = batch_size
        self.alpha = alpha  # exponent tempering the scale distribution
        self.update_interval = update_interval  # reload groups every N draws
        self.seed = seed
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        """Yield batches of parsed records (split on the first comma)."""
        while True:
            # keep groups up-to-date
            self.update_groups()
            # collect items
            items = self.sample()
            while len(items) < self.batch_size:
                items += self.sample()
            # sample a batch (with replacement only if too few items)
            batch = self.rng.choice(
                items,
                self.batch_size,
                replace=False if len(items) >= self.batch_size else True)
            yield [u.strip().split(',', 1) for u in batch]

    def update_groups(self):
        """(Re)load the group index every ``update_interval`` iterations."""
        if not hasattr(self, '_step'):
            self._step = 0
        if self._step % self.update_interval == 0:
            self.groups = json.loads(read(self.group_file))
        self._step += 1

    def sample(self):
        """Pick one group with probability proportional to scale**alpha and
        return the lines of one of its randomly chosen list files."""
        scales = np.array(
            [float(next(iter(u)).split(':')[-1]) for u in self.groups])
        p = scales**self.alpha / (scales**self.alpha).sum()
        group = self.rng.choice(self.groups, p=p)
        list_file = osp.join(self.group_folder,
                             self.rng.choice(next(iter(group.values()))))
        return read(list_file).strip().split('\n')

View File

@@ -0,0 +1,184 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
import torch
from tokenizers import BertWordPieceTokenizer, CharBPETokenizer
__all__ = ['CLIPTokenizer']
@lru_cache()
def default_bpe():
    """Return the path of the BPE vocabulary archive bundled next to this file."""
    here = '/'.join(os.path.realpath(__file__).split('/')[:-1])
    return os.path.join(here, 'bpe_simple_vocab_16e6.txt.gz')
@lru_cache()
def bytes_to_unicode():
    """Build a reversible byte -> printable-unicode-character mapping.

    BPE operates on unicode strings, so every possible byte value must map
    to some character.  Printable ASCII/Latin-1 bytes map to themselves;
    the remaining (whitespace/control) bytes are remapped to code points
    starting at 256 so the BPE code never sees characters it chokes on.
    """
    printable = (
        list(range(ord('!'), ord('~') + 1))
        + list(range(ord('¡'), ord('¬') + 1))
        + list(range(ord('®'), ord('ÿ') + 1)))
    codes = printable[:]
    shift = 0
    for byte in range(2**8):
        if byte not in printable:
            printable.append(byte)
            codes.append(2**8 + shift)
            shift += 1
    return {b: chr(c) for b, c in zip(printable, codes)}
def get_pairs(word):
    """Return the set of adjacent symbol bigrams in *word*.

    *word* is a non-empty tuple of variable-length string symbols.
    """
    prev = word[0]
    bigrams = set()
    for sym in word[1:]:
        bigrams.add((prev, sym))
        prev = sym
    return bigrams
def basic_clean(text):
    """Fix mojibake via ftfy and unescape (possibly doubly) HTML-escaped entities."""
    fixed = ftfy.fix_text(text)
    unescaped = html.unescape(html.unescape(fixed))
    return unescaped.strip()
def whitespace_clean(text):
    """Collapse runs of whitespace into single spaces and trim the ends."""
    return re.sub(r'\s+', ' ', text).strip()
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer used by CLIP (merge table loaded from a
    gzipped vocabulary file; total vocab size 49152)."""

    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
        # drop the header line and the trailing rows reserved for the 256*2
        # byte tokens and the two special tokens
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        # vocab: single bytes, end-of-word bytes, merged symbols, specials
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # lower rank == earlier (higher-priority) merge
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            '<|startoftext|>': '<|startoftext|>',
            '<|endoftext|>': '<|endoftext|>'
        }
        # word/contraction/number/punctuation splitter (regex module syntax)
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE)

    def bpe(self, token):
        """Apply BPE merges to a single pre-split token and return the
        space-joined symbol string (results are memoized in self.cache)."""
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>', )
        pairs = get_pairs(word)
        if not pairs:
            return token + '</w>'
        while True:
            # merge the highest-priority (lowest-rank) bigram present
            bigram = min(
                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # fix: was `except Exception as e` with a stray print(e);
                    # str.index raises ValueError when `first` no longer occurs
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[
                        i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Encode `text` into a list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # map raw bytes to their printable-unicode stand-ins first
            token = ''.join(self.byte_encoder[b]
                            for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token]
                              for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Decode a list of token ids back into text (lossy on whitespace)."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            'utf-8', errors='replace').replace('</w>', ' ')
        return text
class CLIPTokenizer(object):
    """Tokenize strings into fixed-length CLIP token-id tensors.

    Sequences are wrapped in SOS/EOS tokens, truncated to `length`, and
    zero-padded on the right.
    """

    def __init__(self, length=77):
        self.length = length
        # init tokenizer and cache the special-token ids
        self.tokenizer = SimpleTokenizer(bpe_path=default_bpe())
        self.sos_token = self.tokenizer.encoder['<|startoftext|>']
        self.eos_token = self.tokenizer.encoder['<|endoftext|>']
        self.vocab_size = len(self.tokenizer.encoder)

    def __call__(self, sequence):
        if isinstance(sequence, str):
            return torch.LongTensor(self._tokenizer(sequence))
        if isinstance(sequence, list):
            return torch.LongTensor([self._tokenizer(u) for u in sequence])
        raise TypeError(
            f'Expected the "sequence" to be a string or a list, but got {type(sequence)}'
        )

    def _tokenizer(self, text):
        # reserve two slots for SOS/EOS, then pad with zeros to `length`
        body = self.tokenizer.encode(text)[:self.length - 2]
        ids = [self.sos_token, *body, self.eos_token]
        ids += [0] * (self.length - len(ids))
        return ids

View File

@@ -0,0 +1,400 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import random
import numpy as np
import torch
import torchvision.transforms.functional as F
import torchvision.transforms.functional as TF
from PIL import Image, ImageFilter
from torchvision.transforms.functional import InterpolationMode
__all__ = [
'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2', 'RandomCrop',
'RandomCropV2', 'RandomHFlip', 'GaussianBlur', 'ColorJitter', 'RandomGray',
'ToTensor', 'Normalize', 'ResizeRandomCrop', 'ExtractResizeRandomCrop',
'ExtractResizeAssignCrop'
]
def random_resize(img, size):
    """Resize each frame in *img* to *size*, choosing the interpolation
    filter independently at random per frame."""
    filters = [
        InterpolationMode.BILINEAR, InterpolationMode.BICUBIC,
        InterpolationMode.LANCZOS
    ]
    return [
        TF.resize(frame, size, interpolation=random.choice(filters))
        for frame in img
    ]
class CenterCropV3(object):
    """Center-crop a single PIL image to a `size` x `size` square.

    Large images are first repeatedly halved with a cheap box filter, then
    brought to the exact scale with bicubic resampling.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        # fast pre-shrink while the short side is at least twice the target
        while min(img.size) >= 2 * self.size:
            img = img.resize((img.width // 2, img.height // 2),
                             resample=Image.BOX)
        ratio = self.size / min(img.size)
        img = img.resize(
            (round(ratio * img.width), round(ratio * img.height)),
            resample=Image.BICUBIC)
        # center crop
        left = (img.width - self.size) // 2
        top = (img.height - self.size) // 2
        return img.crop((left, top, left + self.size, top + self.size))
class Compose(object):
    """Chain transforms into a pipeline; slicing yields a sub-pipeline."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __getitem__(self, index):
        if isinstance(index, slice):
            return Compose(self.transforms[index])
        return self.transforms[index]

    def __len__(self):
        return len(self.transforms)

    def __call__(self, rgb):
        for transform in self.transforms:
            rgb = transform(rgb)
        return rgb
class Resize(object):
    """Resize every frame to a fixed (w, h) with bilinear filtering."""

    def __init__(self, size=256):
        self.size = (size, size) if isinstance(size, int) else size

    def __call__(self, rgb):
        return [frame.resize(self.size, Image.BILINEAR) for frame in rgb]
class Rescale(object):
    """Scale frames so the shorter side equals `size`, keeping aspect ratio."""

    def __init__(self, size=256, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, rgb):
        w, h = rgb[0].size
        ratio = self.size / min(w, h)
        new_size = (int(round(w * ratio)), int(round(h * ratio)))
        return [frame.resize(new_size, self.interpolation) for frame in rgb]
class CenterCrop(object):
def __init__(self, size=224):
self.size = size
def __call__(self, rgb):
w, h = rgb[0].size
assert min(w, h) >= self.size
x1 = (w - self.size) // 2
y1 = (h - self.size) // 2
rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb]
return rgb
class ResizeRandomCrop(object):
    """Resize frames so the short side equals `size_short`, then take one
    random `size` x `size` crop shared by all frames."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        # self.min_area = min_area
        self.size_short = size_short

    def __call__(self, rgb):
        # fast pre-shrink with a box filter, consistent across frames
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in rgb
            ]
        ratio = self.size_short / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(ratio * frame.width), round(ratio * frame.height)),
                resample=Image.BICUBIC) for frame in rgb
        ]
        # one random crop box applied to every frame
        w, h = rgb[0].size
        x1 = random.randint(0, w - self.size)
        y1 = random.randint(0, h - self.size)
        box = (x1, y1, x1 + self.size, y1 + self.size)
        return [frame.crop(box) for frame in rgb]
class ExtractResizeRandomCrop(object):
    """Like ResizeRandomCrop, but also returns the crop box so another
    stream can be cropped identically (see ExtractResizeAssignCrop)."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):
        # fast pre-shrink with a box filter, consistent across frames
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in rgb
            ]
        ratio = self.size_short / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(ratio * frame.width), round(ratio * frame.height)),
                resample=Image.BICUBIC) for frame in rgb
        ]
        # one random crop box applied to every frame
        w, h = rgb[0].size
        x1 = random.randint(0, w - self.size)
        y1 = random.randint(0, h - self.size)
        rgb = [
            frame.crop((x1, y1, x1 + self.size, y1 + self.size))
            for frame in rgb
        ]
        wh = [x1, y1, x1 + self.size, y1 + self.size]
        return rgb, wh
class ExtractResizeAssignCrop(object):
    """Resize frames to `size_short` short side, crop with a GIVEN box `wh`
    (typically produced by ExtractResizeRandomCrop), then resize to square."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb, wh):
        # fast pre-shrink with a box filter, consistent across frames
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in rgb
            ]
        ratio = self.size_short / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(ratio * frame.width), round(ratio * frame.height)),
                resample=Image.BICUBIC) for frame in rgb
        ]
        cropped = [frame.crop(wh) for frame in rgb]
        return [
            frame.resize((self.size, self.size), Image.BILINEAR)
            for frame in cropped
        ]
class CenterCropV2(object):
    """Center-crop a LIST of frames to a `size` x `size` square, using the
    same fast box-filter pre-shrink as CenterCropV3."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        # fast pre-shrink, applied consistently to every frame
        while min(img[0].size) >= 2 * self.size:
            img = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in img
            ]
        ratio = self.size / min(img[0].size)
        img = [
            frame.resize(
                (round(ratio * frame.width), round(ratio * frame.height)),
                resample=Image.BICUBIC) for frame in img
        ]
        # center crop, same box for all frames
        left = (img[0].width - self.size) // 2
        top = (img[0].height - self.size) // 2
        box = (left, top, left + self.size, top + self.size)
        return [frame.crop(box) for frame in img]
class RandomCrop(object):
    """Random-area crop (shared box across frames) followed by a resize to
    a `size` x `size` square."""

    def __init__(self, size=224, min_area=0.4):
        self.size = size
        self.min_area = min_area

    def __call__(self, rgb):
        w, h = rgb[0].size
        area = w * h
        # rejection-sample (area, aspect-ratio) until the crop fits
        out_w, out_h = float('inf'), float('inf')
        while out_w > w or out_h > h:
            target_area = random.uniform(self.min_area, 1.0) * area
            aspect_ratio = random.uniform(3. / 4., 4. / 3.)
            out_w = int(round(math.sqrt(target_area * aspect_ratio)))
            out_h = int(round(math.sqrt(target_area / aspect_ratio)))
        x1 = random.randint(0, w - out_w)
        y1 = random.randint(0, h - out_h)
        box = (x1, y1, x1 + out_w, y1 + out_h)
        cropped = [frame.crop(box) for frame in rgb]
        return [
            frame.resize((self.size, self.size), Image.BILINEAR)
            for frame in cropped
        ]
class RandomCropV2(object):
    """torchvision-style RandomResizedCrop over a list of frames: sample a
    crop with random area/aspect ratio, fall back to a central crop, then
    resize all frames with the same parameters."""

    def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
        if isinstance(size, (tuple, list)):
            self.size = size
        else:
            self.size = (size, size)
        self.min_area = min_area
        self.ratio = ratio

    def _get_params(self, img):
        # try up to 10 random (area, log-uniform aspect ratio) proposals
        width, height = img.size
        area = height * width
        for _ in range(10):
            target_area = random.uniform(self.min_area, 1.0) * area
            log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w
        # Fallback to central crop, clamping to the nearest allowed ratio
        in_ratio = float(width) / float(height)
        if (in_ratio < min(self.ratio)):
            w = width
            h = int(round(w / min(self.ratio)))
        elif (in_ratio > max(self.ratio)):
            h = height
            w = int(round(h * max(self.ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def __call__(self, rgb):
        # sample crop params from the first frame, apply to all frames
        i, j, h, w = self._get_params(rgb[0])
        rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb]
        return rgb
class RandomHFlip(object):
    """Horizontally mirror all frames with probability `p` (all-or-none)."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        return [frame.transpose(Image.FLIP_LEFT_RIGHT) for frame in rgb]
class GaussianBlur(object):
    """Blur all frames with probability `p`; the radius is drawn uniformly
    from the `sigmas` interval (one draw shared by all frames)."""

    def __init__(self, sigmas=(0.1, 2.0), p=0.5):
        # fix: the default was a mutable list shared across instances
        self.sigmas = sigmas
        self.p = p

    def __call__(self, rgb):
        if random.random() < self.p:
            sigma = random.uniform(*self.sigmas)
            rgb = [
                u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb
            ]
        return rgb
class ColorJitter(object):
    """Randomly jitter brightness/contrast/saturation/hue of every frame,
    applying the four adjustments (same factors for all frames) in a
    random order, with probability `p`."""

    def __init__(self,
                 brightness=0.4,
                 contrast=0.4,
                 saturation=0.4,
                 hue=0.1,
                 p=0.5):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        b, c, s, h = self._random_params()
        ops = [
            lambda f: F.adjust_brightness(f, b),
            lambda f: F.adjust_contrast(f, c),
            lambda f: F.adjust_saturation(f, s),
            lambda f: F.adjust_hue(f, h)
        ]
        random.shuffle(ops)
        for op in ops:
            rgb = [op(frame) for frame in rgb]
        return rgb

    def _random_params(self):
        # factors are uniform in [max(0, 1-x), 1+x]; hue in [-hue, hue]
        b = random.uniform(max(0, 1 - self.brightness), 1 + self.brightness)
        c = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
        s = random.uniform(max(0, 1 - self.saturation), 1 + self.saturation)
        h = random.uniform(-self.hue, self.hue)
        return b, c, s, h
class RandomGray(object):
    """With probability `p`, convert all frames to grayscale (kept as
    3-channel RGB so downstream shapes are unchanged)."""

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        return [frame.convert('L').convert('RGB') for frame in rgb]
class ToTensor(object):
    """Stack a list of PIL frames into one [T, C, H, W] float tensor."""

    def __call__(self, rgb):
        frames = [F.to_tensor(frame) for frame in rgb]
        return torch.stack(frames, dim=0)
class Normalize(object):
    """Clamp a float tensor to [0, 1] and normalize per channel (dim 1).

    Args:
        mean, std: per-channel statistics (defaults are the ImageNet values).
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # fix: defaults were mutable lists shared across instances
        self.mean = mean
        self.std = std

    def __call__(self, rgb):
        rgb = rgb.clone()
        rgb.clamp_(0, 1)
        # lazily cache the statistics as tensors on first use (device/dtype
        # taken from the first input via new_tensor)
        if not isinstance(self.mean, torch.Tensor):
            self.mean = rgb.new_tensor(self.mean).view(-1)
        if not isinstance(self.std, torch.Tensor):
            self.std = rgb.new_tensor(self.std).view(-1)
        rgb.sub_(self.mean.view(1, -1, 1, 1)).div_(self.std.view(1, -1, 1, 1))
        return rgb

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,120 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import random
import time
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from flash_attn.flash_attention import FlashAttention
class FlashAttentionBlock(nn.Module):
    """Spatial self-attention over [B, C, H, W] feature maps backed by
    FlashAttention, with optional cross-attention context [B, L, context_dim].
    """

    def __init__(self,
                 dim,
                 context_dim=None,
                 num_heads=None,
                 head_dim=None,
                 batch_size=4):
        # consider head_dim first, then num_heads
        num_heads = dim // head_dim if head_dim else num_heads
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(FlashAttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = math.pow(head_dim, -0.25)
        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)
        # NOTE(review): flash_attn is only constructed for supported head
        # dims (<= 128 and divisible by 8), but forward() uses it
        # unconditionally; other head dims fail with AttributeError.
        if self.head_dim <= 128 and (self.head_dim % 8) == 0:
            self.flash_attn = FlashAttention(
                softmax_scale=None, attention_dropout=0.0)
        # zero out the last layer params so the block starts as identity
        nn.init.zeros_(self.proj.weight)

    def _init_weight(self, module):
        # NOTE(review): not wired into __init__ via self.apply; appears unused
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.15)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Conv2d):
            module.weight.data.normal_(mean=0.0, std=0.15)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x, context=None):
        r"""x: [B, C, H, W].
        context: [B, L, C] or None.
        """
        identity = x
        b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
        # compute query, key, value: each [B, n, d, H*W]
        x = self.norm(x)
        q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
        if context is not None:
            # prepend context-derived keys/values along the token dim
            ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
                                                      d).permute(0, 2, 3,
                                                                 1).chunk(
                                                                     2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)
            # pad q with 4 zero tokens so q/k/v share one packed sequence
            # length -- presumably matches the 4 context tokens; confirm
            cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device)
            q = torch.cat([q, cq], dim=-1)
        qkv = torch.cat([q, k, v], dim=1)
        origin_dtype = qkv.dtype
        qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n,
                                              d).half().contiguous()
        out, _ = self.flash_attn(qkv)
        # fix: Tensor.to() is not in-place -- the cast back to the original
        # dtype was previously discarded
        out = out.to(origin_dtype)
        if context is not None:
            # drop the 4 padded query positions
            out = out[:, :-4, :, :]
        out = out.permute(0, 2, 3, 1).reshape(b, c, h, w)
        # output projection + residual
        x = self.proj(out)
        return x + identity
if __name__ == '__main__':
    # Micro-benchmark: average forward latency of FlashAttentionBlock on GPU
    # under autocast (requires CUDA and the flash_attn package).
    batch_size = 8
    flash_net = FlashAttentionBlock(
        dim=1280,
        context_dim=512,
        num_heads=None,
        head_dim=64,
        batch_size=batch_size).cuda()
    x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda()
    context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda()
    # context = None
    flash_net.eval()
    with amp.autocast(enabled=True):
        # warm up
        for i in range(5):
            y = flash_net(x, context)
        torch.cuda.synchronize()
        s1 = time.time()
        for i in range(10):
            y = flash_net(x, context)
        # synchronize so the timer measures completed kernels, not launches
        torch.cuda.synchronize()
        s2 = time.time()
        print(f'Average cost time {(s2-s1)*1000/10} ms')

View File

@@ -0,0 +1,2 @@
from .clip import *
from .midas import *

View File

@@ -0,0 +1,460 @@
import math
import os.path as osp
import torch
import torch.nn as nn
import torch.nn.functional as F
import modelscope.models.multi_modal.videocomposer.ops as ops
__all__ = [
'CLIP', 'clip_vit_b_32', 'clip_vit_b_16', 'clip_vit_l_14',
'clip_vit_l_14_336px', 'clip_vit_h_16'
]
def DOWNLOAD_TO_CACHE(oss_key,
                      file_or_dirname=None,
                      cache_dir=osp.join(
                          '/'.join(osp.abspath(__file__).split('/')[:-2]),
                          'model_weights')):
    r"""Resolve the local cache path for an OSS key.

    NOTE(review): despite the name (and the original docstring), this does
    not download anything or synchronize processes -- it only joins
    `cache_dir` with `file_or_dirname` (or the basename of `oss_key`).
    """
    target = file_or_dirname if file_or_dirname else osp.basename(oss_key)
    return osp.join(cache_dir, target)
def to_fp16(m):
    """In-place cast to float16 for Linear/Conv2d weights and `head` params.

    Meant to be used with Module.apply; modules without those attributes
    are left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        m.weight.data = m.weight.data.half()
        if m.bias is not None:
            m.bias.data = m.bias.data.half()
    elif hasattr(m, 'head'):
        head = getattr(m, 'head')
        head.data = head.data.half()
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    def forward(self, x):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class LayerNorm(nn.LayerNorm):
    r"""LayerNorm that computes in float32 and casts back to the input
    dtype, so fp16 activations stay numerically stable."""

    def forward(self, x):
        out = super(LayerNorm, self).forward(x.float())
        return out.type_as(x)
class SelfAttention(nn.Module):
    """Multi-head self-attention used by the CLIP transformer blocks."""

    def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
        assert dim % num_heads == 0
        super(SelfAttention, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        # layers
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_dropout = nn.Dropout(proj_dropout)

    def forward(self, x, mask=None):
        r"""x: [B, L, C].
        mask: [*, L, L]; positions where mask == 0 are masked out.
        """
        b, l, _, n = *x.size(), self.num_heads
        # compute query, key, and value
        # project on the [L, B, C] layout so the reshape below splits heads
        # as [b*n, l, head_dim]
        q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1)
        q = q.reshape(l, b * n, -1).transpose(0, 1)
        k = k.reshape(l, b * n, -1).transpose(0, 1)
        v = v.reshape(l, b * n, -1).transpose(0, 1)
        # compute attention; softmax in float32 for numerical stability
        attn = self.scale * torch.bmm(q, k.transpose(1, 2))
        if mask is not None:
            attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf'))
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        attn = self.attn_dropout(attn)
        # gather context and merge the heads back into [B, L, C]
        x = torch.bmm(attn, v)
        x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1)
        # output
        x = self.proj(x)
        x = self.proj_dropout(x)
        return x
class AttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention then a QuickGELU MLP,
    both with residual connections."""

    def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
        super(AttentionBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        # layers
        self.norm1 = LayerNorm(dim)
        self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout)
        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim),
            nn.Dropout(proj_dropout))

    def forward(self, x, mask=None):
        attn_out = self.attn(self.norm1(x), mask)
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
class VisionTransformer(nn.Module):
    """CLIP image encoder: a ViT with a learned class token and a linear
    projection head into the shared embedding space."""

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0):
        assert image_size % patch_size == 0
        super(VisionTransformer, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.dim = dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_patches = (image_size // patch_size)**2
        # embeddings (scaled-gaussian init, gain = 1/sqrt(dim))
        gain = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(
            3, dim, kernel_size=patch_size, stride=patch_size, bias=False)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(
            gain * torch.randn(1, self.num_patches + 1, dim))
        self.dropout = nn.Dropout(embedding_dropout)
        # transformer
        self.pre_norm = LayerNorm(dim)
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
            for _ in range(num_layers)
        ])
        self.post_norm = LayerNorm(dim)
        # head: projection matrix from transformer width to embed space
        self.head = nn.Parameter(gain * torch.randn(dim, out_dim))

    def forward(self, x):
        """x: [B, 3, H, W] -> [B, out_dim] (embedding of the class token)."""
        b, dtype = x.size(0), self.head.dtype
        x = x.type(dtype)
        # patch-embedding
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)  # [b, n, c]
        x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x],
                      dim=1)
        x = self.dropout(x + self.pos_embedding.type(dtype))
        x = self.pre_norm(x)
        # transformer
        x = self.transformer(x)
        # head: project only the class token (position 0)
        x = self.post_norm(x)
        x = torch.mm(x[:, 0, :], self.head)
        return x

    def fp16(self):
        # cast weights to half precision (see to_fp16)
        return self.apply(to_fp16)
class TextTransformer(nn.Module):
    """CLIP text encoder: causal transformer pooled at the EOT token."""

    def __init__(self,
                 vocab_size,
                 text_len,
                 dim=512,
                 out_dim=512,
                 num_heads=8,
                 num_layers=12,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0):
        super(TextTransformer, self).__init__()
        self.vocab_size = vocab_size
        self.text_len = text_len
        self.dim = dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim))
        self.dropout = nn.Dropout(embedding_dropout)
        # transformer (ModuleList so the causal mask can be passed per block)
        self.transformer = nn.ModuleList([
            AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
            for _ in range(num_layers)
        ])
        self.norm = LayerNorm(dim)
        # head
        gain = 1.0 / math.sqrt(dim)
        self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
        # causal attention mask (lower-triangular ones)
        self.register_buffer('attn_mask',
                             torch.tril(torch.ones(1, text_len, text_len)))

    def forward(self, x):
        """x: [B, text_len] token ids -> [B, out_dim] embedding."""
        # argmax over ids locates the EOT token, which is assumed to have
        # the highest id in the vocabulary -- matches CLIPTokenizer above
        eot, dtype = x.argmax(dim=-1), self.head.dtype
        # embeddings
        x = self.dropout(
            self.token_embedding(x).type(dtype)
            + self.pos_embedding.type(dtype))
        # transformer with causal masking
        for block in self.transformer:
            x = block(x, self.attn_mask)
        # head: project the feature at each sequence's EOT position
        x = self.norm(x)
        x = torch.mm(x[torch.arange(x.size(0)), eot], self.head)
        return x

    def fp16(self):
        # cast weights to half precision (see to_fp16)
        return self.apply(to_fp16)
class CLIP(nn.Module):
    """CLIP dual encoder: a vision transformer and a text transformer whose
    outputs are compared with a learned-temperature contrastive loss.
    """

    def __init__(self,
                 embed_dim=512,
                 image_size=224,
                 patch_size=16,
                 vision_dim=768,
                 vision_heads=12,
                 vision_layers=12,
                 vocab_size=49408,
                 text_len=77,
                 text_dim=512,
                 text_heads=8,
                 text_layers=12,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0):
        super(CLIP, self).__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vocab_size = vocab_size
        self.text_len = text_len
        self.text_dim = text_dim
        self.text_heads = text_heads
        self.text_layers = text_layers
        # models
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout)
        self.textual = TextTransformer(
            vocab_size=vocab_size,
            text_len=text_len,
            dim=text_dim,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout)
        # learned log-temperature, initialized to log(1/0.07) as in CLIP
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_tokens):
        r"""imgs: [B, C, H, W] of torch.float32.
        txt_tokens: [B, T] of torch.long.
        Returns (logits_i2t, logits_t2i, labels) for the contrastive loss.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_tokens)
        # normalize features
        xi = F.normalize(xi, p=2, dim=1)
        xt = F.normalize(xt, p=2, dim=1)
        # gather features from all ranks (ops is a project helper module)
        full_xi = ops.diff_all_gather(xi)
        full_xt = ops.diff_all_gather(xt)
        # temperature-scaled similarity logits
        scale = self.log_scale.exp()
        logits_i2t = scale * torch.mm(xi, full_xt.t())
        logits_t2i = scale * torch.mm(xt, full_xi.t())
        # labels: each local sample matches its own slot in the gathered batch
        labels = torch.arange(
            len(xi) * ops.get_rank(),
            len(xi) * (ops.get_rank() + 1),
            dtype=torch.long,
            device=xi.device)
        return logits_i2t, logits_t2i, labels

    def init_weights(self):
        # embeddings
        nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
        # fix: keyword was misspelled `tsd`, which raised a TypeError
        nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)
        # attentions
        for modality in ['visual', 'textual']:
            # fix: `dim` previously evaluated to the string 'textual' for
            # the text branch; use the text embedding width instead
            dim = self.vision_dim if modality == 'visual' else self.text_dim
            transformer = getattr(self, modality).transformer
            # fix: Sequential/ModuleList have no `.num_layers`/`.layers`;
            # use len() and iterate the container directly
            num_layers = len(transformer)
            proj_gain = (1.0 / math.sqrt(dim)) * (
                1.0 / math.sqrt(2 * num_layers))
            attn_gain = 1.0 / math.sqrt(dim)
            mlp_gain = 1.0 / math.sqrt(2.0 * dim)
            for block in transformer:
                nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
                nn.init.normal_(block.attn.proj.weight, std=proj_gain)
                nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
                nn.init.normal_(block.mlp[2].weight, std=proj_gain)

    def param_groups(self):
        # no weight decay on norms and biases, default decay elsewhere
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay':
            0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups

    def fp16(self):
        # cast weights to half precision (see to_fp16)
        return self.apply(to_fp16)
def _clip(name, pretrained=False, **kwargs):
    """Build a CLIP model and optionally load cached pretrained weights."""
    model = CLIP(**kwargs)
    if pretrained:
        state = torch.load(
            DOWNLOAD_TO_CACHE(f'models/clip/{name}.pth'), map_location='cpu')
        model.load_state_dict(state)
    return model
def clip_vit_b_32(pretrained=False, **kwargs):
    """CLIP ViT-B/32 configuration; kwargs override individual fields."""
    cfg = {
        'embed_dim': 512,
        'image_size': 224,
        'patch_size': 32,
        'vision_dim': 768,
        'vision_heads': 12,
        'vision_layers': 12,
        'vocab_size': 49408,
        'text_len': 77,
        'text_dim': 512,
        'text_heads': 8,
        'text_layers': 12,
    }
    cfg.update(**kwargs)
    return _clip('openai-clip-vit-base-32', pretrained, **cfg)
def clip_vit_b_16(pretrained=False, **kwargs):
    """CLIP ViT-B/16 configuration; kwargs override individual fields."""
    cfg = dict(
        embed_dim=512,
        image_size=224,
        # fix: was 32 (copy-paste from clip_vit_b_32); ViT-B/16 uses 16x16
        # patches, and the pretrained checkpoint cannot load otherwise
        patch_size=16,
        vision_dim=768,
        vision_heads=12,
        vision_layers=12,
        vocab_size=49408,
        text_len=77,
        text_dim=512,
        text_heads=8,
        text_layers=12)
    cfg.update(**kwargs)
    return _clip('openai-clip-vit-base-16', pretrained, **cfg)
def clip_vit_l_14(pretrained=False, **kwargs):
    """CLIP ViT-L/14 configuration; kwargs override individual fields."""
    cfg = {
        'embed_dim': 768,
        'image_size': 224,
        'patch_size': 14,
        'vision_dim': 1024,
        'vision_heads': 16,
        'vision_layers': 24,
        'vocab_size': 49408,
        'text_len': 77,
        'text_dim': 768,
        'text_heads': 12,
        'text_layers': 12,
    }
    cfg.update(**kwargs)
    return _clip('openai-clip-vit-large-14', pretrained, **cfg)
def clip_vit_l_14_336px(pretrained=False, **kwargs):
    """CLIP ViT-L/14 at 336px input; kwargs override individual fields."""
    cfg = {
        'embed_dim': 768,
        'image_size': 336,
        'patch_size': 14,
        'vision_dim': 1024,
        'vision_heads': 16,
        'vision_layers': 24,
        'vocab_size': 49408,
        'text_len': 77,
        'text_dim': 768,
        'text_heads': 12,
        'text_layers': 12,
    }
    cfg.update(**kwargs)
    return _clip('openai-clip-vit-large-14-336px', pretrained, **cfg)
def clip_vit_h_16(pretrained=False, **kwargs):
    """CLIP ViT-H/16 configuration (no public pretrained checkpoint)."""
    assert not pretrained, 'pretrained model for openai-clip-vit-huge-16 is not available!'
    cfg = {
        'embed_dim': 1024,
        'image_size': 256,
        'patch_size': 16,
        'vision_dim': 1280,
        'vision_heads': 16,
        'vision_layers': 32,
        'vocab_size': 49408,
        'text_len': 77,
        'text_dim': 1024,
        'text_heads': 16,
        'text_layers': 24,
    }
    cfg.update(**kwargs)
    return _clip('openai-clip-vit-huge-16', pretrained, **cfg)

View File

@@ -0,0 +1,320 @@
r"""A much cleaner re-implementation of ``https://github.com/isl-org/MiDaS''.
Image augmentation: T.Compose([
Resize(
keep_aspect_ratio=True,
ensure_multiple_of=32,
interpolation=cv2.INTER_CUBIC),
T.ToTensor(),
T.Normalize(
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])]).
Fast inference:
model = model.to(memory_format=torch.channels_last).half()
input = input.to(memory_format=torch.channels_last).half()
output = model(input)
"""
import math
import os
import os.path as osp
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['MiDaS', 'midas_v3']
class SelfAttention(nn.Module):
    """Multi-head self-attention over [B, L, C] sequences (einsum form),
    used by the MiDaS backbone blocks."""

    def __init__(self, dim, num_heads):
        assert dim % num_heads == 0
        super(SelfAttention, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        # layers
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        batch, seq, chans = x.size()
        heads, hdim = self.num_heads, self.head_dim
        # project and split into q/k/v of shape [B, L, N, D]
        q, k, v = self.to_qkv(x).view(batch, seq, heads * 3,
                                      hdim).chunk(3, dim=2)
        # scaled dot-product attention; softmax in float32 for stability
        attn = self.scale * torch.einsum('binc,bjnc->bnij', q, k)
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        # gather context and merge heads back into channels
        ctx = torch.einsum('bnij,bjnc->binc', attn, v)
        ctx = ctx.reshape(batch, seq, chans)
        # output projection
        return self.proj(ctx)
class AttentionBlock(nn.Module):
    """Pre-norm transformer encoder block (attention + GELU MLP), each
    sub-layer wrapped in a residual connection."""

    def __init__(self, dim, num_heads):
        super(AttentionBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        # layers
        self.norm1 = nn.LayerNorm(dim)
        self.attn = SelfAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        return x + self.mlp(self.norm2(x))
class VisionTransformer(nn.Module):
    """Plain ViT classifier matching the MiDaS DPT-Large backbone layout."""

    def __init__(self,
                 image_size=384,
                 patch_size=16,
                 dim=1024,
                 out_dim=1000,
                 num_heads=16,
                 num_layers=24):
        assert image_size % patch_size == 0
        super(VisionTransformer, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.dim = dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_patches = (image_size // patch_size)**2
        # embeddings
        self.patch_embedding = nn.Conv2d(
            3, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_embedding = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embedding = nn.Parameter(
            torch.empty(1, self.num_patches + 1, dim).normal_(std=0.02))
        # blocks
        self.blocks = nn.Sequential(
            *[AttentionBlock(dim, num_heads) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(dim)
        # head
        self.head = nn.Linear(dim, out_dim)

    def forward(self, x):
        batch = x.size(0)
        # embeddings: patches + class token + learned positions
        tokens = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        tokens = torch.cat([self.cls_embedding.repeat(batch, 1, 1), tokens],
                           dim=1)
        tokens = tokens + self.pos_embedding
        # transformer blocks
        tokens = self.blocks(tokens)
        tokens = self.norm(tokens)
        # classification head (applied to every token, as in the original)
        return self.head(tokens)
class ResidualBlock(nn.Module):
    """Pre-activation residual unit: ReLU-Conv-ReLU-Conv added to the input."""

    def __init__(self, dim):
        super(ResidualBlock, self).__init__()
        self.dim = dim
        # layers
        self.residual = nn.Sequential(
            nn.ReLU(inplace=False),  # NOTE: avoid modifying the input
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, 3, padding=1))

    def forward(self, x):
        delta = self.residual(x)
        return x + delta
class FusionBlock(nn.Module):
    """MiDaS refinement step: fuse one or two feature maps through residual
    blocks, upsample 2x bilinearly, then apply a 1x1 projection."""

    def __init__(self, dim):
        super(FusionBlock, self).__init__()
        self.dim = dim
        # layers
        self.layer1 = ResidualBlock(dim)
        self.layer2 = ResidualBlock(dim)
        self.conv_out = nn.Conv2d(dim, dim, 1)

    def forward(self, *xs):
        assert len(xs) in (1, 2), 'invalid number of inputs'
        # the optional second input is refined before being added
        if len(xs) == 2:
            fused = xs[0] + self.layer1(xs[1])
        else:
            fused = xs[0]
        out = self.layer2(fused)
        out = F.interpolate(
            out, scale_factor=2, mode='bilinear', align_corners=True)
        return self.conv_out(out)
class MiDaS(nn.Module):
    r"""MiDaS v3.0 DPT-Large from ``https://github.com/isl-org/MiDaS''.
    Monocular depth estimation using dense prediction transformers.
    """

    def __init__(self,
                 image_size=384,
                 patch_size=16,
                 dim=1024,
                 neck_dims=[256, 512, 1024, 1024],
                 fusion_dim=256,
                 num_heads=16,
                 num_layers=24):
        # Args:
        #   image_size: training resolution; pos-embeddings are sized for it.
        #   patch_size: side of each square image patch.
        #   dim: transformer token width.
        #   neck_dims: channel widths of the 4 reassembly stages.
        #   fusion_dim: channel width shared by all fusion blocks.
        #   num_heads/num_layers: transformer depth/width.
        # NOTE(review): ``neck_dims`` is a mutable default argument; it is
        # only read here, but a tuple default would be safer.
        assert image_size % patch_size == 0
        assert num_layers % 4 == 0
        super(MiDaS, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.dim = dim
        self.neck_dims = neck_dims
        self.fusion_dim = fusion_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_patches = (image_size // patch_size)**2
        # embeddings
        self.patch_embedding = nn.Conv2d(
            3, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_embedding = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embedding = nn.Parameter(
            torch.empty(1, self.num_patches + 1, dim).normal_(std=0.02))
        # blocks; the transformer is split into 4 equal slices so each
        # quarter's activations can feed one reassembly stage below
        stride = num_layers // 4
        self.blocks = nn.Sequential(
            *[AttentionBlock(dim, num_heads) for _ in range(num_layers)])
        self.slices = [slice(i * stride, (i + 1) * stride) for i in range(4)]
        # stage1 (4x): transposed conv upsamples the token grid 4x
        self.fc1 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim, neck_dims[0], 1),
            nn.ConvTranspose2d(neck_dims[0], neck_dims[0], 4, stride=4),
            nn.Conv2d(neck_dims[0], fusion_dim, 3, padding=1, bias=False))
        self.fusion1 = FusionBlock(fusion_dim)
        # stage2 (8x): transposed conv upsamples 2x
        self.fc2 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(dim, neck_dims[1], 1),
            nn.ConvTranspose2d(neck_dims[1], neck_dims[1], 2, stride=2),
            nn.Conv2d(neck_dims[1], fusion_dim, 3, padding=1, bias=False))
        self.fusion2 = FusionBlock(fusion_dim)
        # stage3 (16x): keeps the native token-grid resolution
        self.fc3 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
        self.conv3 = nn.Sequential(
            nn.Conv2d(dim, neck_dims[2], 1),
            nn.Conv2d(neck_dims[2], fusion_dim, 3, padding=1, bias=False))
        self.fusion3 = FusionBlock(fusion_dim)
        # stage4 (32x): strided conv downsamples 2x
        self.fc4 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
        self.conv4 = nn.Sequential(
            nn.Conv2d(dim, neck_dims[3], 1),
            nn.Conv2d(neck_dims[3], neck_dims[3], 3, stride=2, padding=1),
            nn.Conv2d(neck_dims[3], fusion_dim, 3, padding=1, bias=False))
        self.fusion4 = FusionBlock(fusion_dim)
        # head: single-channel prediction with a final 2x upsample
        self.head = nn.Sequential(
            nn.Conv2d(fusion_dim, fusion_dim // 2, 3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(fusion_dim // 2, 32, 3, padding=1),
            nn.ReLU(inplace=True), nn.ConvTranspose2d(32, 1, 1),
            nn.ReLU(inplace=True))

    def forward(self, x):
        # Predict a single-channel (non-negative, final ReLU) map for a
        # (B, 3, H, W) batch; H and W must be multiples of ``patch_size``.
        b, _, h, w, p = *x.size(), self.patch_size
        assert h % p == 0 and w % p == 0, f'Image size ({w}, {h}) is not divisible by patch size ({p}, {p})'
        hp, wp, grid = h // p, w // p, self.image_size // p
        # embeddings: keep the [CLS] entry, bilinearly resize the grid part
        # of the position embedding to the actual (hp, wp) token grid
        pos_embedding = torch.cat([
            self.pos_embedding[:, :1],
            F.interpolate(
                self.pos_embedding[:, 1:].reshape(1, grid, grid, -1).permute(
                    0, 3, 1, 2),
                size=(hp, wp),
                mode='bilinear',
                align_corners=False).permute(0, 2, 3, 1).reshape(
                    1, hp * wp, -1)
        ],
                                  dim=1)  # noqa
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        x = torch.cat([self.cls_embedding.repeat(b, 1, 1), x], dim=1)
        x = x + pos_embedding
        # stage1: first quarter of blocks; each stage concatenates the patch
        # tokens with the broadcast [CLS] token (readout) before projecting
        x = self.blocks[self.slices[0]](x)
        x1 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
        x1 = self.fc1(x1).permute(0, 2, 1).unflatten(2, (hp, wp))
        x1 = self.conv1(x1)
        # stage2
        x = self.blocks[self.slices[1]](x)
        x2 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
        x2 = self.fc2(x2).permute(0, 2, 1).unflatten(2, (hp, wp))
        x2 = self.conv2(x2)
        # stage3
        x = self.blocks[self.slices[2]](x)
        x3 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
        x3 = self.fc3(x3).permute(0, 2, 1).unflatten(2, (hp, wp))
        x3 = self.conv3(x3)
        # stage4
        x = self.blocks[self.slices[3]](x)
        x4 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
        x4 = self.fc4(x4).permute(0, 2, 1).unflatten(2, (hp, wp))
        x4 = self.conv4(x4)
        # fusion: deepest stage first, progressively merged with shallower
        # stages while upsampling 2x at every step
        x4 = self.fusion4(x4)
        x3 = self.fusion3(x4, x3)
        x2 = self.fusion2(x3, x2)
        x1 = self.fusion1(x2, x1)
        # head
        x = self.head(x1)
        return x
def midas_v3(model_dir, pretrained=False, **kwargs):
    """Build a MiDaS v3 DPT-Large model, optionally loading local weights.

    Args:
        model_dir: directory containing ``midas_v3_dpt_large.pth``.
        pretrained: when True, load that checkpoint (CPU map location).
        **kwargs: overrides for the default architecture configuration.
    """
    config = {
        'image_size': 384,
        'patch_size': 16,
        'dim': 1024,
        'neck_dims': [256, 512, 1024, 1024],
        'fusion_dim': 256,
        'num_heads': 16,
        'num_layers': 24,
    }
    config.update(**kwargs)
    model = MiDaS(**config)
    if pretrained:
        checkpoint = os.path.join(model_dir, 'midas_v3_dpt_large.pth')
        model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
    return model

View File

@@ -0,0 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .degration import *
from .distributed import *
from .losses import *
from .random_mask import *
from .utils import *

View File

@@ -0,0 +1,998 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import random
from datetime import datetime
import numpy as np
import scipy
import scipy.stats as stats
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth
from torchvision.utils import make_grid
# Allow duplicate OpenMP runtimes to coexist in this process -- presumably a
# workaround for the torch/scipy libomp clash on some platforms (TODO confirm).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
# Public API of this module; every other function is an internal helper.
__all__ = ['degradation_bsrgan_light', 'degradation_bsrgan']
# get uint8 image of size HxWxn_channles (RGB)
def imread_uint(path, n_channels=3):
    # Read an image file as a uint8 HxWxC array in RGB channel order.
    # input: path
    # output: HxWx3(RGB or GGG), or HxWx1 (G)
    # NOTE(review): ``cv2`` does not appear in this module's visible import
    # block -- confirm it is imported elsewhere or add ``import cv2``.
    if n_channels == 1:
        img = cv2.imread(path, 0)  # flag 0: grayscale read
        img = np.expand_dims(img, axis=2)
    elif n_channels == 3:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img.ndim == 2:
            # single-channel source: replicate gray into 3 channels (GGG)
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        else:
            # OpenCV decodes BGR; convert to RGB
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
def uint2single(img):
    """Map a uint8 image in [0, 255] to float32 in [0, 1]."""
    return np.asarray(img / 255., dtype=np.float32)
def single2uint(img):
    """Map a float image in [0, 1] to uint8 in [0, 255] (clip then round)."""
    scaled = img.clip(0, 1) * 255.
    return scaled.round().astype(np.uint8)
def uint162single(img):
    """Map a uint16 image in [0, 65535] to float32 in [0, 1]."""
    return np.asarray(img / 65535., dtype=np.float32)
def single2uint16(img):
    """Map a float image in [0, 1] to uint16 in [0, 65535] (clip then round)."""
    scaled = img.clip(0, 1) * 65535.
    return scaled.round().astype(np.uint16)
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output has the same dtype/range convention as the input.
    '''
    in_img_type = img.dtype
    # Work on a float copy. Fix: the original discarded the result of
    # ``img.astype`` (a no-op statement) and then scaled the *caller's*
    # float array in place via ``img *= 255.``.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (ITU-R BT.601 coefficients, MATLAB scaling)
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img,
                        [[65.481, -37.797, 112.0],
                         [128.553, -74.203, -93.786],
                         [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output has the same dtype/range convention as the input.
    '''
    in_img_type = img.dtype
    # Work on a float copy. Fix: the original discarded the result of
    # ``img.astype`` (a no-op statement) and then scaled the *caller's*
    # float array in place via ``img *= 255.``.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (inverse BT.601 transform, MATLAB scaling)
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
                          [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [
                              -222.921, 135.576, -276.836
                          ]  # noqa
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output has the same dtype/range convention as the input.
    '''
    in_img_type = img.dtype
    # Work on a float copy. Fix: the original discarded the result of
    # ``img.astype`` (a no-op statement) and then scaled the *caller's*
    # float array in place via ``img *= 255.``.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (BT.601 coefficients with the B and R columns swapped)
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img,
                        [[24.966, 112.0, -18.214],
                         [128.553, -74.203, -93.786],
                         [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def channel_convert(in_c, tar_type, img_list):
    # conversion among BGR, gray and y
    # Args:
    #   in_c: number of input channels (1 or 3)
    #   tar_type: target type, 'gray' / 'y' / 'RGB'
    #   img_list: list of HxWxC numpy images
    # Returns the converted list; unsupported combinations pass through
    # unchanged.
    if in_c == 3 and tar_type == 'gray':  # BGR -> single-channel gray
        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in gray_list]
    elif in_c == 3 and tar_type == 'y':  # BGR -> luma (Y) channel
        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in y_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray -> 3-channel
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list
# PSNR
def calculate_psnr(img1, img2, border=0):
    """PSNR between two [0, 255] images, ignoring a ``border``-pixel frame."""
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    crop1 = img1[border:h - border, border:w - border].astype(np.float64)
    crop2 = img2[border:h - border, border:w - border].astype(np.float64)
    mse = np.mean((crop1 - crop2)**2)
    if mse == 0:
        # identical images: infinite PSNR
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
# SSIM
def calculate_ssim(img1, img2, border=0):
'''calculate SSIM
the same outputs as MATLAB's
img1, img2: [0, 255]
'''
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2]
img1 = img1[border:h - border, border:w - border]
img2 = img2[border:h - border, border:w - border]
if img1.ndim == 2:
return ssim(img1, img2)
elif img1.ndim == 3:
if img1.shape[2] == 3:
ssims = []
for i in range(3):
ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
return np.array(ssims).mean()
elif img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
else:
raise ValueError('Wrong input image dimensions.')
def ssim(img1, img2):
    # Single-channel SSIM between two [0, 255] images: 11x11 Gaussian
    # window (sigma 1.5), C1/C2 stabilizers for L = 255, mean taken over
    # the valid (border-cropped) SSIM map.
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())
    # local means; [5:-5] drops the invalid filter border
    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    # local (co)variances via E[x^2] - E[x]^2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
    ssim_map = ((2 * mu1_mu2 + C1)  # noqa
                * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1)  # noqa
                                         *  # noqa
                                         (sigma1_sq + sigma2_sq + C2))  # noqa
    return ssim_map.mean()
# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    """Bicubic interpolation kernel (a = -0.5), as used by MATLAB imresize."""
    a = torch.abs(x)
    a2 = a * a
    a3 = a2 * a
    # piecewise polynomial: |x| <= 1 and 1 < |x| <= 2 segments
    near = (1.5 * a3 - 2.5 * a2 + 1) * (a <= 1).type_as(a)
    far = (-0.5 * a3 + 2.5 * a2 - 4 * a + 2) * (
        ((a > 1) * (a <= 2)).type_as(a))
    return near + far
def calculate_weights_indices(in_length, out_length, scale, kernel,
                              kernel_width, antialiasing):
    """Compute per-output-pixel contribution weights and source indices for
    1-D MATLAB-style bicubic resampling.

    Returns ``(weights, indices, sym_len_s, sym_len_e)`` where the last two
    are the symmetric-padding lengths required at the start/end of the
    input. ``kernel`` is accepted for signature compatibility only; the
    bicubic ``cubic`` kernel is always used.
    """
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale
    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)
    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)
    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)
    # What is the maximum number of pixels that can be involved in the
    # computation?  Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2
    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
        0, P - 1, P).view(1, P).expand(out_length, P)
    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)
    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    # NOTE(review): with target 0 and only rel_tol, ``math.isclose`` reduces
    # to an exact ``x == 0`` test; ``abs_tol`` may have been intended.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # symmetric padding lengths implied by out-of-range indices
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
# imresize for tensor image [0, 1]
def imresize(img, scale, antialiasing=True):
    """MATLAB-compatible bicubic resize for a torch tensor in [0, 1].

    Args:
        img: CHW or HW float tensor.
        scale: single scale factor applied to both H and W.
        antialiasing: widen the kernel when downscaling (MATLAB default).
    Returns:
        resized CHW or HW tensor; values are neither rounded nor clipped.
    """
    # Now the scale should be the same for H and W
    # input: img: pytorch tensor, CHW or HW [0,1]
    # output: CHW or HW [0,1] w/o round
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        # Fix: use the out-of-place unsqueeze -- the original
        # ``img.unsqueeze_(0)`` permanently turned the *caller's* 2-D
        # tensor into a 3-D one as a side effect.
        img = img.unsqueeze(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W
                                                                   * scale)
    kernel_width = 4
    kernel = 'cubic'
    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.
    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(
                0, 1).mv(weights_H[i])
    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :,
                                       idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2
# imresize for numpy image [0, 1]
def imresize_np(img, scale, antialiasing=True):
    # MATLAB-compatible bicubic resize for a numpy image; mirrors
    # ``imresize`` above but works on HWC/HW layout and returns numpy.
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        # safe here: unsqueeze_ mutates the local tensor wrapper only,
        # not the caller's numpy array
        img.unsqueeze_(2)
    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W
                                                                   * scale)
    kernel_width = 4
    kernel = 'cubic'
    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.
    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :,
                                     j].transpose(0, 1).mv(weights_H[i])
    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width,
                                       j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2.numpy()
def modcrop_np(img, sf):
    '''Crop a numpy image so both spatial sides are divisible by ``sf``.
    Args:
        img: numpy image, HxW or HxWxC
        sf: scale factor
    Return:
        cropped copy of the image
    '''
    rows, cols = img.shape[:2]
    return np.copy(img)[:rows - rows % sf, :cols - cols % sf, ...]
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    k_size = k.shape[0]
    # big enough to hold every stride-2 shifted copy of k
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # superpose k scaled by each of its own entries
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # crop the low-magnitude edges to speed up subsequent SR
    crop = k_size // 2
    cropped = big_k[crop:-crop, crop:-crop]
    # renormalize to unit sum
    return cropped / cropped.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0,  pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """
    # rotate the x-axis by theta to get the principal direction
    direction = np.dot(
        np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
    # orthogonal eigenvector basis and eigenvalue matrix of the covariance
    V = np.array([[direction[0], direction[1]],
                  [direction[1], -direction[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
def gm_blur_kernel(mean, cov, size=15):
    """Sample a ``size`` x ``size`` Gaussian-density kernel, normalized to 1."""
    center = size / 2.0 + 0.5
    kernel = np.zeros([size, size])
    for row in range(size):
        for col in range(size):
            # coordinates relative to the kernel center
            offset_y = row - center + 1
            offset_x = col - center + 1
            kernel[row, col] = stats.multivariate_normal.pdf(
                [offset_x, offset_y], mean=mean, cov=cov)
    return kernel / np.sum(kernel)
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    Returns the bilinearly resampled (shifted) image; 3-D inputs are
    updated channel-by-channel in place and returned.
    """
    # Fix: ``scipy.interpolate.interp2d`` was removed in SciPy 1.14; a
    # degree-1 RectBivariateSpline evaluated on the shifted grid is the
    # documented equivalent for gridded linear interpolation.
    from scipy.interpolate import RectBivariateSpline

    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift
    # keep sample coordinates inside the image
    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    def _resample(plane):
        # bilinear interpolation on the regular (yv, xv) grid
        return RectBivariateSpline(yv, xv, plane, kx=1, ky=1)(y1, x1)

    if x.ndim == 2:
        x = _resample(x)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = _resample(x[:, :, i])
    return x
def blur(x, k):
    '''
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    Depth-wise blur: each sample's kernel is applied independently to each
    of its channels, with replicate padding so the output keeps HxW.
    '''
    n, c = x.shape[:2]
    ph, pw = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    # Fix: F.pad takes (left, right, top, bottom); the original passed
    # (ph, pw, ph, pw), padding width by the *height* margin (and vice
    # versa), which produced wrong output shapes for non-square kernels.
    x = torch.nn.functional.pad(x, pad=(pw, pw, ph, ph), mode='replicate')
    # one kernel per (sample, channel) group
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(
        x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])
    return x
def gen_kernel(
        k_size=np.array([15, 15]),
        scale_factor=np.array([4, 4]),
        min_var=0.6,
        max_var=10.,
        noise_level=0):
    """"
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    # random eigenvalues (lambda) and rotation (theta) for the covariance
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
    # covariance from the eigendecomposition Q diag(lambda) Q^T
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
    # mean position, shifted so the kernel aligns with the downsampled grid
    MU = k_size // 2 - 0.5 * (scale_factor - 1)
    MU = MU[None, None, :, None]
    # evaluate the (unnormalized) Gaussian density at every kernel pixel
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
    # normalize to unit sum and return
    return raw_kernel / np.sum(raw_kernel)
def fspecial_gaussian(hsize, sigma):
    """MATLAB-style ``fspecial('gaussian')``: normalized 2-D Gaussian kernel.

    Args:
        hsize: kernel side length.
        sigma: Gaussian standard deviation.
    """
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(
        np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # zero the numerically negligible tail. Fix: ``scipy.finfo`` was a
    # deprecated alias removed from SciPy; ``np.finfo`` is correct.
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h
def fspecial_laplacian(alpha):
    """MATLAB-style ``fspecial('laplacian', alpha)`` 3x3 kernel."""
    alpha = max([0, min([alpha, 1])])  # clamp alpha to [0, 1]
    corner = alpha / (alpha + 1)
    edge = (1 - alpha) / (alpha + 1)
    return np.array([[corner, edge, corner],
                     [edge, -4 / (alpha + 1), edge],
                     [corner, edge, corner]])
def fspecial(filter_type, *args, **kwargs):
    """Dispatch to a MATLAB-style ``fspecial`` kernel builder.

    Supported: 'gaussian' and 'laplacian'; other types return None.
    """
    builders = {
        'gaussian': fspecial_gaussian,
        'laplacian': fspecial_laplacian,
    }
    if filter_type in builders:
        return builders[filter_type](*args, **kwargs)
def bicubic_degradation(x, sf=3):
    '''
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image
    '''
    # MATLAB-equivalent bicubic downscale by 1/sf
    return imresize_np(x, scale=1 / sf)
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # 'wrap' boundary matches the periodic assumption of the SRMD model.
    # Fix: ``scipy.ndimage.filters`` has been removed from SciPy; the
    # public ``scipy.ndimage.convolve`` is the drop-in replacement.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    x = bicubic_degradation(x, sf=sf)
    return x
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # Fix: ``scipy.ndimage.filters`` has been removed from SciPy; the
    # public ``scipy.ndimage.convolve`` is the drop-in replacement.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    '''
    # Fix: ``scipy.ndimage.filters`` has been removed from SciPy; the
    # public ``scipy.ndimage.convolve`` is the drop-in replacement.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    st = 0  # sampling offset within each sf x sf cell
    return x[st::sf, st::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int): residual magnitude (on the [0, 255] scale) above
            which a pixel is sharpened. Default: 10.
    """
    # Gaussian kernel size must be odd
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    # soften the binary mask so the sharpening blends smoothly
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
    K = img + weight * residual
    K = np.clip(K, 0, 1)
    return soft_mask * K + (1 - soft_mask) * img
def add_blur_1(img, sf=4):
    """Blur with a random (an)isotropic Gaussian kernel (the 'light'
    variant: kernel widths are a quarter of those in ``add_blur_2``)."""
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    wd2 = wd2 / 4
    wd = wd / 4
    if random.random() < 0.5:
        # anisotropic kernel with random orientation and eigenvalues
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(
            ksize=random.randint(2, 11) + 3,
            theta=random.random() * np.pi,
            l1=l1,
            l2=l2)
    else:
        # isotropic Gaussian kernel
        k = fspecial('gaussian',
                     random.randint(2, 4) + 3, wd * random.random())
    # Fix: ``scipy.ndimage.filters`` has been removed from SciPy; the
    # public ``scipy.ndimage.convolve`` is the drop-in replacement.
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
    return img
def add_resize(img, sf=4):
    # Randomly rescale the image: 20% chance upscale (1-2x), 70% chance
    # downscale (down to 0.5/sf), otherwise keep the original size.
    # Interpolation is a random cv2 flag among {1, 2, 3}.
    rnum = np.random.rand()
    if rnum > 0.8:
        sf1 = random.uniform(1, 2)
    elif rnum < 0.7:
        sf1 = random.uniform(0.5 / sf, 1)
    else:
        sf1 = 1.0
    img = cv2.resize(
        img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
        interpolation=random.choice([1, 2, 3]))
    img = np.clip(img, 0.0, 1.0)
    return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add random Gaussian noise to a float image in [0, 1].

    One of three modes is chosen at random: per-channel noise, grayscale
    noise shared across channels, or channel-correlated noise with a
    random covariance. The result is clipped back to [0, 1].
    """
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:
        # independent noise on every channel
        noise = np.random.normal(0, noise_level / 255.0, img.shape)
        img = img + noise.astype(np.float32)
    elif rnum < 0.4:
        # one noise plane broadcast over all channels
        noise = np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1))
        img = img + noise.astype(np.float32)
    else:
        # correlated color noise with a random positive covariance
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        noise = np.random.multivariate_normal([0, 0, 0],
                                              np.abs(L**2 * conv),
                                              img.shape[:2])
        img = img + noise.astype(np.float32)
    return np.clip(img, 0.0, 1.0)
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add multiplicative (speckle) noise to a float image; output in [0, 1].

    The noise mode mirrors ``add_Gaussian_noise`` but scales with the
    pixel intensity (``img + img * noise``).
    """
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        # per-channel multiplicative noise
        noise = np.random.normal(0, noise_level / 255.0, img.shape)
        img += img * noise.astype(np.float32)
    elif rnum < 0.4:
        # one noise plane broadcast over all channels
        noise = np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1))
        img += img * noise.astype(np.float32)
    else:
        # correlated color noise with a random positive covariance
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        noise = np.random.multivariate_normal([0, 0, 0],
                                              np.abs(L**2 * conv),
                                              img.shape[:2])
        img += img * noise.astype(np.float32)
    return np.clip(img, 0.0, 1.0)
def add_Poisson_noise(img):
    """Add Poisson (shot) noise to a float image in [0, 1].

    Half the time the noise is sampled per channel; otherwise it is
    sampled on the luminance only and added to every channel.
    """
    # quantize to 8-bit levels first so counts are well defined
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    # photon-count scale in [100, 10000]
    vals = 10**(2 * random.random() + 2.0)
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
        noise_gray = np.random.poisson(
            img_gray * vals).astype(np.float32) / vals - img_gray
        img = img + noise_gray[:, :, np.newaxis]
    return np.clip(img, 0.0, 1.0)
def add_JPEG_noise(img):
    # Round-trip the image through in-memory JPEG at a random quality in
    # [80, 95] to simulate compression artifacts; input and output are
    # RGB float images in [0, 1].
    quality_factor = random.randint(80, 95)
    img = cv2.cvtColor(single2uint(img), cv2.COLOR_RGB2BGR)
    result, encimg = cv2.imencode(
        '.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    img = cv2.imdecode(encimg, 1)  # flag 1: decode as 3-channel color
    img = cv2.cvtColor(uint2single(img), cv2.COLOR_BGR2RGB)
    return img
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Crop an aligned random patch pair.

    Takes an ``lq_patchsize`` square from *lq* and the corresponding
    ``lq_patchsize * sf`` square from *hq*; returns ``(lq, hq)`` crops.
    """
    h, w = lq.shape[:2]
    top = random.randint(0, h - lq_patchsize)
    left = random.randint(0, w - lq_patchsize)
    lq = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]
    top_hq, left_hq = int(top * sf), int(left * sf)
    hq = hq[top_hq:top_hq + lq_patchsize * sf,
            left_hq:left_hq + lq_patchsize * sf, :]
    return lq, hq
def degradation_bsrgan_light(image, sf=4, isp_model=None):
    """
    This is the variant of the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    sf: scale factor
    isp_model: camera ISP model
    Returns
        img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
        hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    image = uint2single(image)
    _, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    h1, w1 = image.shape[:2]
    # crop so both sides are divisible by sf
    # NOTE(review): w1/h1 look swapped here (axis 0 is height); kept as in
    # the reference implementation.
    image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]
    h, w = image.shape[:2]
    # for x4: optionally pre-downscale by 2 and run the pipeline at x2
    if sf == 4 and random.random() < scale2_prob:
        if np.random.rand() < 0.5:
            image = cv2.resize(
                image,
                (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                interpolation=random.choice([1, 2, 3]))
        else:
            image = imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2
    # random degradation order, but stage 2 must precede stage 3 because
    # stage 3 reuses the pre-downsample size (a, b) recorded in stage 2
    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[
            idx2], shuffle_order[idx1]
    for i in shuffle_order:
        if i == 0:
            image = add_blur_1(image, sf=sf)
        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.8:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(
                    image, (int(1 / sf1 * image.shape[1]),
                            int(1 / sf1 * image.shape[0])),
                    interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()
                # Fix: ``scipy.ndimage.filters`` was removed from SciPy;
                # use the public ``scipy.ndimage.convolve``.
                image = ndimage.convolve(
                    image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]
            image = np.clip(image, 0.0, 1.0)
        elif i == 3:
            # downsample3
            image = cv2.resize(
                image, (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)
        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)
    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = single2uint(image)
    return image
def add_blur_2(img, sf=4):
    """Blur with a random (an)isotropic Gaussian kernel (full-strength
    variant: odd kernel sizes in [7, 25])."""
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        # anisotropic kernel with random orientation and eigenvalues
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(
            ksize=2 * random.randint(2, 11) + 3,
            theta=random.random() * np.pi,
            l1=l1,
            l2=l2)
    else:
        # isotropic Gaussian kernel
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3,
                     wd * random.random())
    # Fix: ``scipy.ndimage.filters`` has been removed from SciPy; the
    # public ``scipy.ndimage.convolve`` is the drop-in replacement.
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
    return img
def degradation_bsrgan(image, sf=4, isp_model=None):
    """
    This is the variant of the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    sf: scale factor
    isp_model: camera ISP model
    Returns
        img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
        hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    image = uint2single(image)
    _, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    h1, w1 = image.shape[:2]
    # crop so both sides are divisible by sf
    # NOTE(review): w1/h1 look swapped here (axis 0 is height); kept as in
    # the reference implementation.
    image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]
    h, w = image.shape[:2]
    # for x4: optionally pre-downscale by 2 and run the pipeline at x2
    if sf == 4 and random.random() < scale2_prob:
        if np.random.rand() < 0.5:
            image = cv2.resize(
                image,
                (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                interpolation=random.choice([1, 2, 3]))
        else:
            image = imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2
    # random degradation order, but stage 2 must precede stage 3 because
    # stage 3 reuses the pre-downsample size (a, b) recorded in stage 2
    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[
            idx2], shuffle_order[idx1]
    for i in shuffle_order:
        if i == 0:
            image = add_blur_2(image, sf=sf)
        elif i == 1:
            image = add_blur_2(image, sf=sf)
        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(
                    image, (int(1 / sf1 * image.shape[1]),
                            int(1 / sf1 * image.shape[0])),
                    interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()
                # Fix: ``scipy.ndimage.filters`` was removed from SciPy;
                # use the public ``scipy.ndimage.convolve``.
                image = ndimage.convolve(
                    image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]
            image = np.clip(image, 0.0, 1.0)
        elif i == 3:
            # downsample3
            image = cv2.resize(
                image, (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)
        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)
    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = single2uint(image)
    return image

View File

@@ -0,0 +1,460 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import functools
import pickle
from collections import OrderedDict
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.autograd import Function
__all__ = [
'is_dist_initialized', 'get_world_size', 'get_rank', 'new_group',
'destroy_process_group', 'barrier', 'broadcast', 'all_reduce', 'reduce',
'gather', 'all_gather', 'reduce_dict', 'get_global_gloo_group',
'generalized_all_gather', 'generalized_gather', 'scatter',
'reduce_scatter', 'send', 'recv', 'isend', 'irecv', 'shared_random_seed',
'diff_all_gather', 'diff_all_reduce', 'diff_scatter', 'diff_copy',
'spherical_kmeans', 'sinkhorn'
]
def is_dist_initialized():
    """Return True when torch.distributed is both available and initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
def get_world_size(group=None):
    """World size of ``group``; falls back to 1 when dist is inactive."""
    if not is_dist_initialized():
        return 1
    return dist.get_world_size(group)
def get_rank(group=None):
    """Rank of this process in ``group``; 0 when dist is inactive."""
    if not is_dist_initialized():
        return 0
    return dist.get_rank(group)
def new_group(ranks=None, **kwargs):
    """Create a process group over ``ranks``; returns None when dist is inactive."""
    if not is_dist_initialized():
        return None
    return dist.new_group(ranks, **kwargs)
def destroy_process_group():
    """Tear down the default process group when one is active."""
    if not is_dist_initialized():
        return
    dist.destroy_process_group()
def barrier(group=None, **kwargs):
    """Block until every rank in ``group`` arrives; no-op for a single process."""
    if get_world_size(group) <= 1:
        return
    dist.barrier(group, **kwargs)
def broadcast(tensor, src, group=None, **kwargs):
    """Broadcast ``tensor`` in place from rank ``src``; no-op single-process."""
    if get_world_size(group) <= 1:
        return None
    return dist.broadcast(tensor, src, group, **kwargs)
def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs):
    """In-place all-reduce of ``tensor`` across ``group``; no-op single-process."""
    if get_world_size(group) <= 1:
        return None
    return dist.all_reduce(tensor, op, group, **kwargs)
def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs):
    """Reduce ``tensor`` onto rank ``dst``; no-op for a single process."""
    if get_world_size(group) <= 1:
        return None
    return dist.reduce(tensor, dst, op, group, **kwargs)
def gather(tensor, dst=0, group=None, **kwargs):
    """Gather ``tensor`` from every rank onto rank ``dst``.

    Returns the list of gathered tensors (rank order) on ``dst`` and ``None``
    on all other ranks; with a single process, returns ``[tensor]``.
    """
    # NOTE(review): global rank is compared against ``dst`` — assumes dst is
    # a global rank even when ``group`` is a sub-group; confirm with callers.
    rank = get_rank()
    world_size = get_world_size(group)
    if world_size == 1:
        return [tensor]
    # only the destination rank allocates receive buffers
    tensor_list = [torch.empty_like(tensor)
                   for _ in range(world_size)] if rank == dst else None
    dist.gather(tensor, tensor_list, dst, group, **kwargs)
    return tensor_list
def all_gather(tensor, uniform_size=True, group=None, **kwargs):
    """Gather ``tensor`` from every rank into a list (in rank order).

    With ``uniform_size=True`` every rank must contribute a tensor of the
    same shape. Otherwise shapes are exchanged first, tensors are flattened
    and zero-padded to the largest size, gathered, then reshaped back.
    """
    world_size = get_world_size(group)
    if world_size == 1:
        return [tensor]
    assert tensor.is_contiguous(
    ), 'ops.all_gather requires the tensor to be contiguous()'
    if uniform_size:
        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_list, tensor, group, **kwargs)
        return tensor_list
    else:
        # collect tensor shapes across GPUs
        shape = tuple(tensor.shape)
        shape_list = generalized_all_gather(shape, group)
        # flatten the tensor
        tensor = tensor.reshape(-1)
        size = int(np.prod(shape))
        size_list = [int(np.prod(u)) for u in shape_list]
        max_size = max(size_list)
        # pad to maximum size so dist.all_gather sees equal-sized buffers
        if size != max_size:
            padding = tensor.new_zeros(max_size - size)
            tensor = torch.cat([tensor, padding], dim=0)
        # all_gather
        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_list, tensor, group, **kwargs)
        # strip the padding and restore each rank's original shape
        tensor_list = [
            t[:n].view(s)
            for t, n, s in zip(tensor_list, size_list, shape_list)
        ]
        return tensor_list
@torch.no_grad()
def reduce_dict(input_dict, group=None, reduction='mean', **kwargs):
    """Reduce the values of ``input_dict`` across all ranks.

    Args:
        input_dict (dict): maps keys to scalar tensors; every rank must pass
            the same key set.
        group: process group to reduce over (default: the world).
        reduction (str): 'mean' divides the reduced values by world size on
            rank 0 before broadcasting; 'sum' keeps the raw sum.

    Returns:
        A dict of the same type with reduced values on every rank. With a
        single process the input dict is returned unchanged.
    """
    assert reduction in ['mean', 'sum']
    world_size = get_world_size(group)
    if world_size == 1:
        return input_dict
    # ensure that the orders of keys are consistent across processes
    if isinstance(input_dict, OrderedDict):
        # BUG FIX: the original listed the bound method ``input_dict.keys``
        # (no call), which raised "TypeError: ... object is not iterable".
        keys = list(input_dict.keys())
    else:
        keys = sorted(input_dict.keys())
    vals = [input_dict[key] for key in keys]
    vals = torch.stack(vals, dim=0)
    dist.reduce(vals, dst=0, group=group, **kwargs)
    if dist.get_rank(group) == 0 and reduction == 'mean':
        vals /= world_size
    # rank 0 holds the (optionally averaged) result; share it with everyone
    dist.broadcast(vals, src=0, group=group, **kwargs)
    reduced_dict = type(input_dict)([(key, val)
                                     for key, val in zip(keys, vals)])
    return reduced_dict
@functools.lru_cache()
def get_global_gloo_group():
    """Return a gloo-backend group spanning all ranks (result is cached).

    Gloo supports the CPU byte-tensor collectives used by the generalized
    (picklable-object) helpers below, which nccl does not.
    """
    backend = dist.get_backend()
    assert backend in ['gloo', 'nccl']
    if backend == 'nccl':
        return dist.new_group(backend='gloo')
    else:
        return dist.group.WORLD
def _serialize_to_tensor(data, group):
    """Pickle ``data`` into a 1-D uint8 tensor on the backend's device.

    Args:
        data: any picklable object.
        group: process group whose backend ('gloo' -> cpu, 'nccl' -> cuda)
            determines the target device.

    Returns:
        torch.ByteTensor holding the pickled payload.
    """
    backend = dist.get_backend(group)
    assert backend in ['gloo', 'nccl']
    device = torch.device('cpu' if backend == 'gloo' else 'cuda')
    buffer = pickle.dumps(data)
    if len(buffer) > 1024**3:
        # BUG FIX: ``logging`` is not imported at module level, so this
        # warning branch raised NameError; import it locally instead.
        import logging
        logger = logging.getLogger(__name__)
        logger.warning(
            'Rank {} trying to all-gather {:.2f} GB of data on device'
            '{}'.format(get_rank(),
                        len(buffer) / (1024**3), device))
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor
def _pad_to_largest_tensor(tensor, group):
    """Zero-pad a 1-D uint8 ``tensor`` to the max length across ``group``.

    Returns:
        (size_list, tensor): the original per-rank lengths (rank order) and
        this rank's tensor padded to ``max(size_list)``, so a plain
        ``dist.all_gather`` can be used on equal-sized buffers.
    """
    world_size = dist.get_world_size(group=group)
    assert world_size >= 1, \
        'gather/all_gather must be called from ranks within' \
        'the give group!'
    local_size = torch.tensor([tensor.numel()],
                              dtype=torch.int64,
                              device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)
    ]
    # gather tensors and compute the maximum size
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    # pad tensors to the same size
    if local_size != max_size:
        padding = torch.zeros((max_size - local_size, ),
                              dtype=torch.uint8,
                              device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor
def generalized_all_gather(data, group=None):
    """All-gather arbitrary picklable ``data`` (not only tensors).

    Objects are pickled into byte tensors, padded to a common length,
    gathered over a gloo group, then unpickled. Returns ``[data]`` when
    running single-process.
    """
    if get_world_size(group) == 1:
        return [data]
    if group is None:
        group = get_global_gloo_group()
    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)
    # receiving tensors from all ranks
    tensor_list = [
        torch.empty((max_size, ), dtype=torch.uint8, device=tensor.device)
        for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        # strip the padding before unpickling
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))
    return data_list
def generalized_gather(data, dst=0, group=None):
    """Gather arbitrary picklable ``data`` onto rank ``dst``.

    Returns the list of objects (rank order) on ``dst`` and ``[]`` on all
    other ranks; ``[data]`` when running single-process.
    """
    world_size = get_world_size(group)
    if world_size == 1:
        return [data]
    if group is None:
        group = get_global_gloo_group()
    rank = dist.get_rank()
    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    # receiving tensors from all ranks to dst
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [
            torch.empty((max_size, ), dtype=torch.uint8, device=tensor.device)
            for _ in size_list
        ]
        dist.gather(tensor, tensor_list, dst=dst, group=group)
        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        # non-destination ranks only contribute their tensor
        dist.gather(tensor, [], dst=dst, group=group)
        return []
def scatter(data, scatter_list=None, src=0, group=None, **kwargs):
    r"""Scatter ``scatter_list`` from rank ``src`` into ``data``.

    NOTE: only supports CPU tensor communication. No-op single-process.
    """
    if get_world_size(group) <= 1:
        return None
    return dist.scatter(data, scatter_list, src, group, **kwargs)
def reduce_scatter(output,
                   input_list,
                   op=dist.ReduceOp.SUM,
                   group=None,
                   **kwargs):
    """Reduce ``input_list`` across ranks and scatter one chunk into ``output``."""
    if get_world_size(group) <= 1:
        return None
    return dist.reduce_scatter(output, input_list, op, group, **kwargs)
def send(tensor, dst, group=None, **kwargs):
    """Blocking point-to-point send of ``tensor`` to rank ``dst``."""
    if get_world_size(group) <= 1:
        return None
    assert tensor.is_contiguous(
    ), 'ops.send requires the tensor to be contiguous()'
    return dist.send(tensor, dst, group, **kwargs)
def recv(tensor, src=None, group=None, **kwargs):
    """Blocking point-to-point receive into ``tensor`` from rank ``src``."""
    if get_world_size(group) <= 1:
        return None
    assert tensor.is_contiguous(
    ), 'ops.recv requires the tensor to be contiguous()'
    return dist.recv(tensor, src, group, **kwargs)
def isend(tensor, dst, group=None, **kwargs):
    """Non-blocking send of ``tensor`` to rank ``dst``; returns a work handle."""
    if get_world_size(group) <= 1:
        return None
    assert tensor.is_contiguous(
    ), 'ops.isend requires the tensor to be contiguous()'
    return dist.isend(tensor, dst, group, **kwargs)
def irecv(tensor, src=None, group=None, **kwargs):
    """Non-blocking receive into ``tensor`` from ``src``; returns a work handle."""
    if get_world_size(group) <= 1:
        return None
    assert tensor.is_contiguous(
    ), 'ops.irecv requires the tensor to be contiguous()'
    return dist.irecv(tensor, src, group, **kwargs)
def shared_random_seed(group=None):
    """Draw a random seed on every rank and return rank 0's draw everywhere."""
    local_seed = np.random.randint(2**31)
    return generalized_all_gather(local_seed, group)[0]
def _all_gather(x):
if not (dist.is_available()
and dist.is_initialized()) or dist.get_world_size() == 1:
return x
rank = dist.get_rank()
world_size = dist.get_world_size()
tensors = [torch.empty_like(x) for _ in range(world_size)]
tensors[rank] = x
dist.all_gather(tensors, x)
return torch.cat(tensors, dim=0).contiguous()
def _all_reduce(x):
if not (dist.is_available()
and dist.is_initialized()) or dist.get_world_size() == 1:
return x
dist.all_reduce(x)
return x
def _split(x):
if not (dist.is_available()
and dist.is_initialized()) or dist.get_world_size() == 1:
return x
rank = dist.get_rank()
world_size = dist.get_world_size()
return x.chunk(world_size, dim=0)[rank].contiguous()
class DiffAllGather(Function):
    r"""All-gather with a gradient: forward gathers along dim 0, backward
    routes each rank its own slice of the incoming gradient."""

    @staticmethod
    def symbolic(graph, tensor):
        return _all_gather(tensor)

    @staticmethod
    def forward(ctx, tensor):
        return _all_gather(tensor)

    @staticmethod
    def backward(ctx, grad):
        return _split(grad)
class DiffAllReduce(Function):
    r"""All-reduce with a gradient: forward sums across ranks, backward
    passes the gradient through unchanged."""

    @staticmethod
    def symbolic(graph, tensor):
        return _all_reduce(tensor)

    @staticmethod
    def forward(ctx, tensor):
        return _all_reduce(tensor)

    @staticmethod
    def backward(ctx, grad):
        return grad
class DiffScatter(Function):
    r"""Differentiable scatter: forward keeps this rank's dim-0 chunk,
    backward all-gathers the gradients."""

    @staticmethod
    def symbolic(graph, input):
        return _split(input)

    @staticmethod
    def forward(ctx, input):
        # BUG FIX: the original class defined only ``symbolic`` and
        # ``backward``; without ``forward``, ``DiffScatter.apply`` raised
        # at runtime. Forward mirrors ``symbolic``, as in the sibling
        # autograd Functions in this module.
        return _split(input)

    @staticmethod
    def backward(ctx, grad_output):
        return _all_gather(grad_output)
class DiffCopy(Function):
    r"""Identity in the forward pass; all-reduces gradients in backward."""

    @staticmethod
    def symbolic(graph, tensor):
        return tensor

    @staticmethod
    def forward(ctx, tensor):
        return tensor

    @staticmethod
    def backward(ctx, grad):
        return _all_reduce(grad)
# Functional aliases: autograd-aware collectives exposed in __all__.
diff_all_gather = DiffAllGather.apply
diff_all_reduce = DiffAllReduce.apply
diff_scatter = DiffScatter.apply
diff_copy = DiffCopy.apply
@torch.no_grad()
def spherical_kmeans(feats, num_clusters, num_iters=10):
    """Distributed spherical k-means over per-rank feature shards.

    Args:
        feats: (n, c) feature tensor on this rank; dot products are used as
            similarity, so features are presumably L2-normalized — TODO
            confirm with callers.
        num_clusters: number of centroids k.
        num_iters: EM iterations.

    Returns:
        (clusters, assigns, scores): (k, c) centroids, per-sample cluster
        ids, and per-sample maximum similarity.
    """
    k, n, c = num_clusters, *feats.size()
    ones = feats.new_ones(n, dtype=torch.long)
    # distributed settings
    world_size = get_world_size()
    # init clusters: each rank contributes ceil(k / world_size) random
    # samples; the gathered pool is truncated to exactly k centroids
    rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))]
    clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k]
    # variables
    new_clusters = feats.new_zeros(k, c)
    counts = feats.new_zeros(k, dtype=torch.long)
    # iterative Expectation-Maximization
    for step in range(num_iters + 1):
        # Expectation step: assign each sample to its most similar centroid
        simmat = torch.mm(feats, clusters.t())
        scores, assigns = simmat.max(dim=1)
        if step == num_iters:
            # extra pass only refreshes assignments for the final centroids
            break
        # Maximization step: recompute centroids from global assignments
        new_clusters.zero_().scatter_add_(0,
                                          assigns.unsqueeze(1).repeat(1, c),
                                          feats)
        all_reduce(new_clusters)
        counts.zero_()
        counts.index_add_(0, assigns, ones)
        all_reduce(counts)
        # empty clusters keep their previous centroid
        mask = (counts > 0)
        clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1)
        clusters = F.normalize(clusters, p=2, dim=1)
    return clusters, assigns, scores
@torch.no_grad()
def sinkhorn(Q, eps=0.5, num_iters=3):
    """Sinkhorn-Knopp normalization of assignment logits across all ranks.

    Args:
        Q: (m, n) logits on this rank (samples x prototypes); the transpose
            is normalized so rows/columns approach the uniform marginals.
        eps: temperature of the initial exponentiation.
        num_iters: row/column scaling iterations.

    Returns:
        (m, n) float tensor of column-normalized soft assignments.
    """
    # normalize Q
    Q = torch.exp(Q / eps).t()
    sum_Q = Q.sum()
    all_reduce(sum_Q)
    Q /= sum_Q
    # variables: target marginals — rows uniform over prototypes, columns
    # uniform over the global number of samples (m per rank * world size)
    n, m = Q.size()
    u = Q.new_zeros(n)
    r = Q.new_ones(n) / n
    c = Q.new_ones(m) / (m * get_world_size())
    # iterative update: alternate row scaling (global) and column scaling
    cur_sum = Q.sum(dim=1)
    all_reduce(cur_sum)
    for i in range(num_iters):
        u = cur_sum
        Q *= (r / u).unsqueeze(1)
        Q *= (c / Q.sum(dim=0)).unsqueeze(0)
        cur_sum = Q.sum(dim=1)
        all_reduce(cur_sum)
    return (Q / Q.sum(dim=0, keepdim=True)).t().float()

View File

@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import torch
__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood']
def kl_divergence(mu1, logvar1, mu2, logvar2):
    """Elementwise KL(N(mu1, e^logvar1) || N(mu2, e^logvar2))."""
    var_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_diff = ((mu1 - mu2)**2) * torch.exp(-logvar2)
    return 0.5 * (logvar2 - logvar1 - 1.0 + var_ratio + scaled_sq_diff)
def standard_normal_cdf(x):
    r"""Fast tanh-based approximation of the standard normal CDF."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * (1.0 + torch.tanh(inner))
def discretized_gaussian_log_likelihood(x0, mean, log_scale):
    """Log-likelihood of ``x0`` under a Gaussian discretized to 1/255 bins.

    ``x0`` is expected in [-1, 1]; the outermost bins absorb the full
    distribution tails.
    """
    assert x0.shape == mean.shape == log_scale.shape
    centered = x0 - mean
    inv_stdv = torch.exp(-log_scale)
    # CDF at the upper/lower edge of each (1/255)-wide bin
    cdf_plus = standard_normal_cdf(inv_stdv * (centered + 1.0 / 255.0))
    cdf_min = standard_normal_cdf(inv_stdv * (centered - 1.0 / 255.0))
    log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
    log_bin_prob = torch.log((cdf_plus - cdf_min).clamp(min=1e-12))
    # edge bins take the whole left/right tail respectively
    log_probs = torch.where(
        x0 < -0.999, log_cdf_plus,
        torch.where(x0 > 0.999, log_one_minus_cdf_min, log_bin_prob))
    assert log_probs.shape == x0.shape
    return log_probs

View File

@@ -0,0 +1,81 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np
__all__ = ['make_irregular_mask', 'make_rectangle_mask', 'make_uncrop']
def make_irregular_mask(w,
                        h,
                        max_angle=4,
                        max_length=200,
                        max_width=100,
                        min_strokes=1,
                        max_strokes=5,
                        mode='line'):
    """Binary (h, w) float32 mask with random free-form strokes set to 1.

    Each stroke is a chain of up to 5 random segments starting at a random
    point; ``mode`` selects how segments are rendered: thick lines, filled
    circles, or axis-aligned squares.
    """
    # initialize mask
    assert mode in ['line', 'circle', 'square']
    mask = np.zeros((h, w), np.float32)
    # draw strokes
    num_strokes = np.random.randint(min_strokes, max_strokes + 1)
    for i in range(num_strokes):
        x1 = np.random.randint(w)
        y1 = np.random.randint(h)
        for j in range(1 + np.random.randint(5)):
            angle = 0.01 + np.random.randint(max_angle)
            if i % 2 == 0:
                # mirror direction on alternate strokes
                angle = 2 * 3.1415926 - angle
            length = 10 + np.random.randint(max_length)
            radius = 5 + np.random.randint(max_width)
            # segment endpoint, clipped into the image
            x2 = np.clip((x1 + length * np.sin(angle)).astype(np.int32), 0, w)
            y2 = np.clip((y1 + length * np.cos(angle)).astype(np.int32), 0, h)
            if mode == 'line':
                cv2.line(mask, (x1, y1), (x2, y2), 1.0, radius)
            elif mode == 'circle':
                cv2.circle(
                    mask, (x1, y1), radius=radius, color=1.0, thickness=-1)
            elif mode == 'square':
                radius = radius // 2
                mask[y1 - radius:y1 + radius, x1 - radius:x1 + radius] = 1
            # chain the next segment from this endpoint
            x1, y1 = x2, y2
    return mask
def make_rectangle_mask(w,
                        h,
                        margin=10,
                        min_size=30,
                        max_size=150,
                        min_strokes=1,
                        max_strokes=4):
    """Binary (h, w) float32 mask with random filled rectangles set to 1.

    Rectangles are sampled fully inside a ``margin``-wide border.
    """
    mask = np.zeros((h, w), np.float32)
    num_boxes = np.random.randint(min_strokes, max_strokes + 1)
    for _ in range(num_boxes):
        box_w = np.random.randint(min_size, max_size)
        box_h = np.random.randint(min_size, max_size)
        left = np.random.randint(margin, w - margin - box_w + 1)
        top = np.random.randint(margin, h - margin - box_h + 1)
        mask[top:top + box_h, left:left + box_w] = 1
    return mask
def make_uncrop(w, h):
    """Mask one randomly chosen half (top/bottom/left/right) of an h-by-w image."""
    mask = np.zeros((h, w), np.float32)
    side = np.random.choice([0, 1, 2, 3])
    half_h, half_w = h // 2, w // 2
    region = {
        0: (slice(None, half_h), slice(None)),
        1: (slice(half_h, None), slice(None)),
        2: (slice(None), slice(None, half_w)),
        3: (slice(None), slice(half_w, None)),
    }[int(side)]
    mask[region] = 1
    return mask

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,273 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import copy
import os
import json
import yaml
from modelscope.utils.logger import get_logger
logger = get_logger()
def setup_seed(seed):
    """Seed the python, numpy and torch (CPU + all CUDA) RNGs.

    Args:
        seed (int): seed value applied to every RNG.
    """
    # BUG FIX: this module does not import random/numpy/torch at top level,
    # so the original function raised NameError when called; import locally.
    import random

    import numpy as np
    import torch

    print('Seed: ', seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
class Config(object):
    """Hierarchical experiment configuration loaded from YAML.

    Loads a YAML file (optionally chained through ``_BASE`` / ``_BASE_RUN`` /
    ``_BASE_MODEL`` parent files), merges command-line overrides, and exposes
    every key as an attribute; nested dicts become nested ``Config`` objects.
    """

    def __init__(self, load=True, cfg_dict=None, cfg_level=None):
        """With ``load=True`` parse CLI args and read YAML; with ``load=False``
        wrap an existing ``cfg_dict`` (used for nested levels)."""
        # dotted name of this (possibly nested) level, e.g. 'cfg.DATA'
        self._level = 'cfg' + ('.'
                               + cfg_level if cfg_level is not None else '')
        if load:
            self.args = self._parse_args()
            logger.info('Loading config from {}.'.format(self.args.cfg_file))
            self.need_initialization = True
            cfg_base = self._initialize_cfg()
            cfg_dict = self._load_yaml(self.args)
            cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict)
            cfg_dict = self._update_from_args(cfg_dict)
            self.cfg_dict = cfg_dict
        self._update_dict(cfg_dict)

    def _parse_args(self):
        """Define and parse the command-line interface; trailing positional
        ``opts`` pairs are later applied as config overrides."""
        parser = argparse.ArgumentParser(
            description=
            'Argparser for configuring [code base name to think of] codebase')
        parser.add_argument(
            '--cfg',
            dest='cfg_file',
            help='Path to the configuration file',
            default=
            './modelscope/models/multi_modal/videocomposer/configs/exp06_text_depths_vs_style.yaml'
        )
        parser.add_argument(
            '--init_method',
            help='Initialization method, includes TCP or shared file-system',
            default='tcp://localhost:9999',
            type=str,
        )
        parser.add_argument(
            '--seed',
            type=int,
            default=8888,
            help='Need to explore for different videos')
        parser.add_argument(
            '--debug',
            action='store_true',
            default=False,
            help='Into debug information')
        parser.add_argument(
            '--input_video',
            default='demo_video/video_8800.mp4',
            help='input video for full task, or motion vector of input videos',
            type=str,
        ),
        parser.add_argument(
            '--image_path', default='', help='Single Image Input', type=str)
        parser.add_argument(
            '--sketch_path', default='', help='Single Sketch Input', type=str)
        parser.add_argument(
            '--style_image', help='Single Sketch Input', type=str)
        parser.add_argument(
            '--input_text_desc',
            default=
            'A colorful and beautiful fish swimming in a small glass bowl with \
            multicolored piece of stone, Macro Video',
            type=str,
        ),
        parser.add_argument(
            'opts',
            help='other configurations',
            default=None,
            nargs=argparse.REMAINDER)
        return parser.parse_args()

    def _path_join(self, path_list):
        """Join path components with '/' (no trailing slash)."""
        path = ''
        for p in path_list:
            path += p + '/'
        return path[:-1]

    def _update_from_args(self, cfg_dict):
        """Copy every parsed CLI attribute into ``cfg_dict`` (CLI wins)."""
        args = self.args
        for var in vars(args):
            cfg_dict[var] = getattr(args, var)
        return cfg_dict

    def _initialize_cfg(self):
        """Load the shared base.yaml once per instance.

        NOTE(review): both branches of the ``os.path.exists`` check open the
        same relative path, and ``cfg`` is unbound (NameError) when
        ``need_initialization`` is already False — confirm intent.
        """
        if self.need_initialization:
            self.need_initialization = False
            if os.path.exists(
                    './modelscope/models/multi_modal/videocomposer/configs/base.yaml'
            ):
                with open(
                        './modelscope/models/multi_modal/videocomposer/configs/base.yaml',
                        'r') as f:
                    cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
            else:
                with open(
                        './modelscope/models/multi_modal/videocomposer/configs/base.yaml',
                        'r') as f:
                    cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
        return cfg

    def _load_yaml(self, args, file_name=''):
        """Recursively load a YAML file, following ``_BASE`` / ``_BASE_RUN`` /
        ``_BASE_MODEL`` references (relative '..'-style or './'-style paths),
        then apply command-line ``opts`` overrides."""
        assert args.cfg_file is not None
        if not file_name == '':  # reading from base file
            with open(file_name, 'r') as f:
                cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
        else:
            # normalize cfg_file when launched from inside the project dir
            if os.getcwd().split('/')[-1] == args.cfg_file.split('/')[0]:
                args.cfg_file = args.cfg_file.replace(
                    os.getcwd().split('/')[-1], './')
            try:
                with open(args.cfg_file, 'r') as f:
                    cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
                file_name = args.cfg_file
            except Exception as e:
                # fall back to a path relative to this source file
                args.cfg_file = os.path.realpath(__file__).split(
                    '/')[-3] + '/' + args.cfg_file
                with open(args.cfg_file, 'r') as f:
                    cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
                file_name = args.cfg_file
                print(e)
        if '_BASE_RUN' not in cfg.keys() and '_BASE_MODEL' not in cfg.keys(
        ) and '_BASE' not in cfg.keys():
            return cfg
        if '_BASE' in cfg.keys():
            if cfg['_BASE'][1] == '.':
                # '..'-relative: climb up from the current file's directory
                prev_count = cfg['_BASE'].count('..')
                cfg_base_file = self._path_join(
                    file_name.split('/')[:(-1 - cfg['_BASE'].count('..'))]
                    + cfg['_BASE'].split('/')[prev_count:])
            else:
                cfg_base_file = cfg['_BASE'].replace(
                    './',
                    args.cfg_file.replace(args.cfg_file.split('/')[-1], ''))
            cfg_base = self._load_yaml(args, cfg_base_file)
            cfg = self._merge_cfg_from_base(cfg_base, cfg)
        else:
            if '_BASE_RUN' in cfg.keys():
                if cfg['_BASE_RUN'][1] == '.':
                    prev_count = cfg['_BASE_RUN'].count('..')
                    cfg_base_file = self._path_join(
                        file_name.split('/')[:(-1 - prev_count)]
                        + cfg['_BASE_RUN'].split('/')[prev_count:])
                else:
                    cfg_base_file = cfg['_BASE_RUN'].replace(
                        './',
                        args.cfg_file.replace(
                            args.cfg_file.split('/')[-1], ''))
                cfg_base = self._load_yaml(args, cfg_base_file)
                # preserve_base keeps the _BASE_* keys for the MODEL pass
                cfg = self._merge_cfg_from_base(
                    cfg_base, cfg, preserve_base=True)
            if '_BASE_MODEL' in cfg.keys():
                if cfg['_BASE_MODEL'][1] == '.':
                    prev_count = cfg['_BASE_MODEL'].count('..')
                    cfg_base_file = self._path_join(
                        file_name.split('/')[:(
                            -1 - cfg['_BASE_MODEL'].count('..'))]
                        + cfg['_BASE_MODEL'].split('/')[prev_count:])
                else:
                    cfg_base_file = cfg['_BASE_MODEL'].replace(
                        './',
                        args.cfg_file.replace(
                            args.cfg_file.split('/')[-1], ''))
                cfg_base = self._load_yaml(args, cfg_base_file)
                cfg = self._merge_cfg_from_base(cfg_base, cfg)
        cfg = self._merge_cfg_from_command(args, cfg)
        return cfg

    def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False):
        """Recursively overlay ``cfg_new`` on ``cfg_base`` (new values win);
        '_BASE*' bookkeeping keys are dropped unless ``preserve_base``."""
        for k, v in cfg_new.items():
            if k in cfg_base.keys():
                if isinstance(v, dict):
                    self._merge_cfg_from_base(cfg_base[k], v)
                else:
                    cfg_base[k] = v
            else:
                if 'BASE' not in k or preserve_base:
                    cfg_base[k] = v
        return cfg_base

    def _merge_cfg_from_command(self, args, cfg):
        """Apply positional ``opts`` as alternating key/value overrides.

        Keys use dotted paths up to depth 4 (e.g. ``A.B.C.D``); each segment
        must already exist in ``cfg``.
        """
        assert len(
            args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
                args.opts, len(args.opts))
        keys = args.opts[0::2]
        vals = args.opts[1::2]
        # maximum supported depth 3
        for idx, key in enumerate(keys):
            key_split = key.split('.')
            assert len(
                key_split
            ) <= 4, 'Key depth error. \nMaximum depth: 3\n Get depth: {}'.format(
                len(key_split))
            assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format(
                key_split[0])
            if len(key_split) == 2:
                assert key_split[1] in cfg[
                    key_split[0]].keys(), 'Non-existant key: {}.'.format(key)
            elif len(key_split) == 3:
                assert key_split[1] in cfg[
                    key_split[0]].keys(), 'Non-existant key: {}.'.format(key)
                assert key_split[2] in cfg[key_split[0]][
                    key_split[1]].keys(), 'Non-existant key: {}.'.format(key)
            elif len(key_split) == 4:
                assert key_split[1] in cfg[
                    key_split[0]].keys(), 'Non-existant key: {}.'.format(key)
                assert key_split[2] in cfg[key_split[0]][
                    key_split[1]].keys(), 'Non-existant key: {}.'.format(key)
                assert key_split[3] in cfg[key_split[0]][key_split[1]][
                    key_split[2]].keys(), 'Non-existant key: {}.'.format(key)
            if len(key_split) == 1:
                cfg[key_split[0]] = vals[idx]
            elif len(key_split) == 2:
                cfg[key_split[0]][key_split[1]] = vals[idx]
            elif len(key_split) == 3:
                cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx]
            elif len(key_split) == 4:
                cfg[key_split[0]][key_split[1]][key_split[2]][
                    key_split[3]] = vals[idx]
        return cfg

    def _update_dict(self, cfg_dict):
        """Expose ``cfg_dict`` entries as attributes; nested dicts become
        nested ``Config`` objects, and strings like '1e-4' become floats."""

        def recur(key, elem):
            if type(elem) is dict:
                return key, Config(load=False, cfg_dict=elem, cfg_level=key)
            else:
                # YAML may parse scientific notation as a string
                if type(elem) is str and elem[1:3] == 'e-':
                    elem = float(elem)
                return key, elem

        dic = dict(recur(k, v) for k, v in cfg_dict.items())
        self.__dict__.update(dic)

    def get_args(self):
        """Return the parsed argparse namespace."""
        return self.args

    def __repr__(self):
        return '{}\n'.format(self.dump())

    def dump(self):
        """Serialize the raw config dict as pretty-printed JSON."""
        return json.dumps(self.cfg_dict, indent=2)

    def deep_copy(self):
        """Return an independent deep copy of this Config."""
        return copy.deepcopy(self)
if __name__ == '__main__':
    # debug entry point: parse CLI args, load the YAML config, dump one field
    cfg = Config(load=True)
    print(cfg.DATA)

View File

@@ -0,0 +1,297 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Distributed helpers."""
import functools
import logging
import pickle
import torch
import torch.distributed as dist
_LOCAL_PROCESS_GROUP = None
def all_gather(tensors):
    """
    All gathers the provided tensors from all processes across machines.
    Args:
        tensors (list): tensors to perform all gather across all processes in
        all machines.
    Returns:
        list of tensors, each the dim-0 concatenation of one input tensor's
        copies from every rank (rank order).
    """
    gather_list = []
    output_tensor = []
    world_size = dist.get_world_size()
    for tensor in tensors:
        tensor_placeholder = [
            torch.ones_like(tensor) for _ in range(world_size)
        ]
        dist.all_gather(tensor_placeholder, tensor, async_op=False)
        gather_list.append(tensor_placeholder)
    for gathered_tensor in gather_list:
        # concatenate the per-rank copies along the batch dimension
        output_tensor.append(torch.cat(gathered_tensor, dim=0))
    return output_tensor
def all_reduce(tensors, average=True):
    """
    All reduce the provided tensors from all processes across machines.
    Args:
        tensors (list): tensors to perform all reduce across all processes in
        all machines. Reduction happens in place; the caller's tensors are
        mutated.
        average (bool): scales the reduced tensor by the number of overall
        processes across all machines.
    """
    for tensor in tensors:
        dist.all_reduce(tensor, async_op=False)
    if average:
        world_size = dist.get_world_size()
        for tensor in tensors:
            tensor.mul_(1.0 / world_size)
    return tensors
def init_process_group(
    local_rank,
    local_world_size,
    shard_id,
    num_shards,
    init_method,
    dist_backend='nccl',
):
    """
    Initializes the default process group.
    Args:
        local_rank (int): the rank on the current local machine.
        local_world_size (int): the world size (number of processes running) on
        the current local machine.
        shard_id (int): the shard index (machine rank) of the current machine.
        num_shards (int): number of shards for distributed training.
        init_method (string): supporting three different methods for
            initializing process groups:
            "file": use shared file system to initialize the groups across
            different processes.
            "tcp": use tcp address to initialize the groups across different
            processes.
        dist_backend (string): backend to use for distributed training. Options
            includes gloo, mpi and nccl, the details can be found here:
            https://pytorch.org/docs/stable/distributed.html
    """
    # Sets the GPU to use.
    torch.cuda.set_device(local_rank)
    # Initialize the process group.
    # global rank = position of this machine's block + local offset
    proc_rank = local_rank + shard_id * local_world_size
    world_size = local_world_size * num_shards
    dist.init_process_group(
        backend=dist_backend,
        init_method=init_method,
        world_size=world_size,
        rank=proc_rank,
    )
def is_master_proc(num_gpus=8):
    """
    Determines if the current process is the master process.
    """
    if not torch.distributed.is_initialized():
        # single-process runs are always "master"
        return True
    return dist.get_rank() % num_gpus == 0
def get_world_size():
    """
    Get the size of the world.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
def get_rank():
    """
    Get the rank of the current process.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training. No-op when distributed is inactive or
    the world has a single process.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.
    Returns:
        (group): pytorch dist group.
    """
    # gloo supports the CPU byte-tensor collectives used below; nccl does not
    if dist.get_backend() == 'nccl':
        return dist.new_group(backend='gloo')
    else:
        return dist.group.WORLD
def _serialize_to_tensor(data, group):
    """
    Serialize arbitrary data to a ByteTensor. Note that only `gloo` and `nccl`
    backend is supported.
    Args:
        data (data): data to be serialized.
        group (group): pytorch dist group.
    Returns:
        tensor (ByteTensor): tensor that serialized.
    """
    backend = dist.get_backend(group)
    assert backend in ['gloo', 'nccl']
    # gloo communicates over CPU, nccl over CUDA
    device = torch.device('cpu' if backend == 'gloo' else 'cuda')
    buffer = pickle.dumps(data)
    if len(buffer) > 1024**3:
        logger = logging.getLogger(__name__)
        logger.warning(
            'Rank {} trying to all-gather {:.2f} GB of data on device {}'.
            format(get_rank(),
                   len(buffer) / (1024**3), device))
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor
def _pad_to_largest_tensor(tensor, group):
    """
    Padding all the tensors from different GPUs to the largest ones.
    Args:
        tensor (tensor): tensor to pad.
        group (group): pytorch dist group.
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), 'comm.gather/all_gather must be called from ranks within the given group!'
    local_size = torch.tensor([tensor.numel()],
                              dtype=torch.int64,
                              device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)
    ]
    # exchange lengths so every rank knows the maximum
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_size != max_size:
        padding = torch.zeros((max_size - local_size, ),
                              dtype=torch.uint8,
                              device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor
def all_gather_unaligned(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
        contains all ranks on gloo backend.
    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]
    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)
    # receiving Tensor from all ranks
    tensor_list = [
        torch.empty((max_size, ), dtype=torch.uint8, device=tensor.device)
        for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        # strip padding before unpickling
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))
    return data_list
def init_distributed_training(cfg):
    """
    Initialize variables needed for distributed training.
    Builds one sub-group per machine and stores this machine's group in the
    module-level _LOCAL_PROCESS_GROUP. NOTE(review): assumes the default
    process group is already initialized — confirm with callers.
    """
    if cfg.NUM_GPUS <= 1:
        return
    num_gpus_per_machine = cfg.NUM_GPUS
    num_machines = dist.get_world_size() // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        # every rank must participate in creating every group
        pg = dist.new_group(ranks_on_i)
        if i == cfg.SHARD_ID:
            global _LOCAL_PROCESS_GROUP
            _LOCAL_PROCESS_GROUP = pg
def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)

View File

@@ -0,0 +1,955 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import base64
import binascii
import copy
import glob
import gzip
import hashlib
import logging
import math
import os
import os.path as osp
import pickle
import random
import sys
import time
import urllib.request
import zipfile
from io import BytesIO
from multiprocessing.pool import ThreadPool as Pool
import imageio
import json
import numpy as np
import oss2 as oss
import requests
import skvideo.io
import torch
import torch.nn.functional as F
import torchvision.utils as tvutils
from einops import rearrange
from PIL import Image
__all__ = [
'parse_oss_url', 'parse_bucket', 'read', 'read_image', 'read_gzip',
'ceil_divide', 'to_device', 'put_object', 'put_torch_object',
'put_object_from_file', 'get_object', 'get_object_to_file', 'rand_name',
'save_image', 'save_video', 'save_video_vs_conditions',
'save_video_multiple_conditions_with_data',
'save_video_multiple_conditions', 'download_video_to_file',
'save_video_grid_mp4', 'save_caps', 'ema', 'parallel', 'exists',
'download', 'unzip', 'load_state_dict', 'inverse_indices',
'detect_duplicates', 'read_tfs', 'md5', 'rope', 'format_state',
'breakup_grid', 'huggingface_tokenizer', 'huggingface_model'
]
TFS_CLIENT = None
def DOWNLOAD_TO_CACHE(oss_key,
                      file_or_dirname=None,
                      cache_dir=osp.join(
                          '/'.join(osp.abspath(__file__).split('/')[:-2]),
                          'model_weights')):
    r"""Resolve the local cache path for an OSS [file or folder].

    Only the 0th process on each node will run the downloading.
    Barrier all processes until the downloading is completed.
    """
    # target path: an explicit name wins, else the key's basename
    target_name = file_or_dirname if file_or_dirname else osp.basename(oss_key)
    return osp.join(cache_dir, target_name)
def find_free_port():
    """Ask the OS for an unused TCP port and return it as a string.

    https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
    """
    import socket
    from contextlib import closing
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.bind(('', 0))
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return str(sock.getsockname()[1])
def setup_seed(seed):
    """Seed python/numpy/torch RNGs and force deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
def parse_oss_url(path):
    """Parse an ``oss://bucket/key`` URL into an (``oss2.Bucket``, key) pair.

    Credentials/endpoint default to the ``OSS_*`` environment variables but
    can be overridden inline via query parameters attached to the bucket
    part, e.g. ``oss://bucket?endpoint=...&accessKeyID=...``.

    NOTE(review): ``securityToken`` is read from the environment/query but
    never passed to ``oss.Auth`` (that would require ``oss2.StsAuth``) —
    confirm whether STS access was intended.
    """
    if path.startswith('oss://'):
        path = path[len('oss://'):]
    # configs
    configs = {
        'endpoint': os.getenv('OSS_ENDPOINT', None),
        'accessKeyID': os.getenv('OSS_ACCESS_KEY_ID', None),
        'accessKeySecret': os.getenv('OSS_ACCESS_KEY_SECRET', None),
        'securityToken': os.getenv('OSS_SECURITY_TOKEN', None)
    }
    bucket, path = path.split('/', maxsplit=1)
    # optional per-URL overrides: bucket?key1=val1&key2=val2
    if '?' in bucket:
        bucket, config = bucket.split('?', maxsplit=1)
        for pair in config.split('&'):
            k, v = pair.split('=', maxsplit=1)
            configs[k] = v
    # session: one per (bucket, pid) so forked workers don't share sockets
    session = parse_oss_url._sessions.setdefault(f'{bucket}@{os.getpid()}',
                                                 oss.Session())
    # bucket
    bucket = oss.Bucket(
        auth=oss.Auth(configs['accessKeyID'], configs['accessKeySecret']),
        endpoint=configs['endpoint'],
        bucket_name=bucket,
        session=session)
    return bucket, path
# per-process session cache keyed by '<bucket>@<pid>'
parse_oss_url._sessions = {}
def parse_bucket(url):
    """Return only the ``oss2.Bucket`` for *url*; the key part is discarded
    (a placeholder key is appended so parse_oss_url can split the URL)."""
    bucket, _ = parse_oss_url(osp.join(url, '_placeholder'))
    return bucket
def read(filename, mode='r', retry=5):
    """Read the full content of a local file, an ``oss://`` object or an
    HTTP(S) URL.

    Args:
        filename: local path, ``oss://bucket/key`` or ``http(s)://...``.
        mode: 'r' for utf-8 text or 'rb' for raw bytes.
        retry: attempts before the last exception is re-raised.
    """
    assert mode in ['r', 'rb']
    last_err = None
    for _ in range(retry):
        try:
            if filename.startswith('oss://'):
                bucket, key = parse_oss_url(filename)
                data = bucket.get_object(key).read()
            elif filename.startswith('http'):
                data = requests.get(filename).content
            else:
                with open(filename, mode=mode) as f:
                    return f.read()
            # remote reads always return bytes; decode only in text mode
            return data.decode('utf-8') if mode == 'r' else data
        except Exception as e:
            last_err = e
    raise last_err
def read_image(filename, retry=5):
    """Load an image (local path, ``oss://`` or http URL) as a PIL image.

    Retries up to *retry* times and re-raises the last exception on total
    failure.
    """
    last_err = None
    for _ in range(retry):
        try:
            raw = read(filename, mode='rb', retry=retry)
            return Image.open(BytesIO(raw))
        except Exception as e:
            last_err = e
    raise last_err
def download_video_to_file(filename, local_file, retry=5):
    """Fetch an ``oss://`` video object into *local_file*.

    Retries up to *retry* times; re-raises the last exception when every
    attempt fails. Returns ``None``.
    """
    last_err = None
    for _ in range(retry):
        try:
            bucket, key = parse_oss_url(filename)
            bucket.get_object_to_file(key, local_file)
            return
        except Exception as e:
            last_err = e
    raise last_err
def read_gzip(filename, retry=5):
    """Read and decompress a gzip file (local path or ``oss://`` object).

    OSS objects are first downloaded to a randomly-named temporary file,
    which is removed after a successful decompression. Retries up to
    *retry* times, then re-raises the last exception.
    """
    last_err = None
    for _ in range(retry):
        try:
            is_temp = False
            if filename.startswith('oss://'):
                bucket, key = parse_oss_url(filename)
                # note: rebinding `filename` means later retries reuse the
                # already-downloaded temp file
                filename = rand_name(suffix=osp.splitext(filename)[1])
                bucket.get_object_to_file(key, filename)
                is_temp = True
            with gzip.open(filename) as f:
                payload = f.read()
            if is_temp:
                os.remove(filename)
            return payload
        except Exception as e:
            last_err = e
    raise last_err
def ceil_divide(a, b):
    """Return ``ceil(a / b)`` for integers.

    Uses negated floor division instead of ``math.ceil(a / b)``: the float
    round-trip in the latter silently loses precision for large integers,
    and ``math`` was not even imported in this module originally.
    """
    return -(-a // b)
def to_device(batch, device, non_blocking=False):
    """Recursively move tensors nested in lists/tuples/dicts to *device*.

    Non-tensor leaves and tensors already on *device* are returned
    unchanged; container types are preserved.
    """
    if isinstance(batch, torch.Tensor):
        if batch.device != device:
            batch = batch.to(device, non_blocking=non_blocking)
        return batch
    if isinstance(batch, (list, tuple)):
        return type(batch)(
            [to_device(item, device, non_blocking) for item in batch])
    if isinstance(batch, dict):
        return type(batch)([(key, to_device(value, device, non_blocking))
                            for key, value in batch.items()])
    return batch
def put_object(bucket, oss_key, data, retry=5):
    """Upload *data* under *oss_key*, retrying up to *retry* times.

    Best-effort: when every attempt fails the error is printed and ``None``
    is returned instead of raising.
    """
    last_err = None
    for _ in range(retry):
        try:
            return bucket.put_object(oss_key, data)
        except Exception as e:
            last_err = e
    print(f'put_object to {oss_key} failed with error: {last_err}', flush=True)
def put_torch_object(bucket, oss_key, data, retry=5):
    """Serialize *data* with ``torch.save`` and upload it under *oss_key*.

    Best-effort: prints the last error and returns ``None`` if all *retry*
    attempts fail.
    """
    last_err = None
    for _ in range(retry):
        try:
            buf = BytesIO()
            torch.save(data, buf)
            return bucket.put_object(oss_key, buf.getvalue())
        except Exception as e:
            last_err = e
    print(
        f'put_torch_object to {oss_key} failed with error: {last_err}',
        flush=True)
def put_object_from_file(bucket, oss_key, filename, retry=5):
    """Upload the local file *filename* under *oss_key*.

    Best-effort with *retry* attempts; on total failure the error is
    printed and ``None`` is returned.
    """
    last_err = None
    for _ in range(retry):
        try:
            return bucket.put_object_from_file(oss_key, filename)
        except Exception as e:
            last_err = e
    print(
        f'put_object_from_file to {oss_key} failed with error: {last_err}',
        flush=True)
def get_object(bucket, oss_key, retry=5):
    """Download the object at *oss_key* and return its raw bytes.

    Best-effort with *retry* attempts; on total failure the error is
    printed and ``None`` is returned.
    """
    last_err = None
    for _ in range(retry):
        try:
            return bucket.get_object(oss_key).read()
        except Exception as e:
            last_err = e
    print(f'get_object from {oss_key} failed with error: {last_err}',
          flush=True)
def get_object_to_file(bucket, oss_key, filename, retry=5):
    """Download the object at *oss_key* into the local file *filename*.

    Best-effort with *retry* attempts; on total failure the error is
    printed and ``None`` is returned.
    """
    last_err = None
    for _ in range(retry):
        try:
            return bucket.get_object_to_file(oss_key, filename)
        except Exception as e:
            last_err = e
    print(
        f'get_object_to_file from {oss_key} failed with error: {last_err}',
        flush=True)
def rand_name(length=8, suffix=''):
    """Return a random hexadecimal name of ``2 * length`` characters.

    Args:
        length: number of random bytes (each byte yields two hex digits).
        suffix: optional extension; a leading dot is added when missing.
    """
    name = os.urandom(length).hex()
    if suffix and not suffix.startswith('.'):
        suffix = '.' + suffix
    return name + suffix
@torch.no_grad()
def save_image(bucket,
               oss_key,
               tensor,
               nrow=8,
               normalize=True,
               range=(-1, 1),
               retry=5):
    """Render an image grid to a temporary JPEG and upload it to OSS.

    Args:
        bucket: oss2 bucket handle.
        oss_key: destination object key.
        tensor: image batch accepted by ``torchvision.utils.save_image``.
        nrow: images per grid row.
        normalize / range: forwarded to torchvision. ``range`` shadows the
            builtin but is kept for backward compatibility; NOTE(review):
            recent torchvision versions renamed this kwarg to
            ``value_range`` — confirm the pinned torchvision still accepts
            ``range``.
        retry: upload attempts; best-effort — failure is printed, never
            raised.
    """
    filename = rand_name(suffix='.jpg')
    # fix: previously `exception` was unbound (NameError) when retry <= 0
    exception = None
    for _ in [None] * retry:
        try:
            tvutils.save_image(
                tensor, filename, nrow=nrow, normalize=normalize, range=range)
            bucket.put_object_from_file(oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)
    if exception is not None:
        print(
            'save image to {} failed, error: {}'.format(oss_key, exception),
            flush=True)
@torch.no_grad()
def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):
    """Write a (C, F, H, W) float tensor in [0, 1] to an animated file.

    ``duration``, ``loop`` and ``optimize`` are accepted for backward
    compatibility but are currently ignored; frames are written via imageio
    at a fixed 8 fps.

    Returns:
        The list of uint8 (H, W, C) frames that were written.
    """
    frames = [
        (frame.numpy() * 255).astype('uint8')
        for frame in tensor.permute(1, 2, 3, 0).unbind(dim=0)
    ]
    imageio.mimwrite(path, frames, fps=8)
    return frames
@torch.no_grad()
def save_video(bucket,
               oss_key,
               tensor,
               mean=[0.5, 0.5, 0.5],
               std=[0.5, 0.5, 0.5],
               nrow=8,
               retry=5):
    """Denormalize a (B, C, F, H, W) video batch, tile it into a grid GIF
    and upload it to OSS.

    Args:
        bucket: oss2 bucket handle.
        oss_key: destination object key.
        tensor: normalized video batch; B must be divisible by ``nrow``.
        mean / std: per-channel statistics used to undo normalization.
        nrow: number of grid rows when tiling the batch.
        retry: upload attempts; best-effort — failures are printed, not
            raised.

    NOTE: ``tensor`` is denormalized IN PLACE (``mul_``/``add_``), so the
    caller's tensor is modified.
    """
    mean = torch.tensor(mean, device=tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=tensor.device).view(1, -1, 1, 1, 1)
    tensor = tensor.mul_(std).add_(mean)
    tensor.clamp_(0, 1)
    filename = rand_name(suffix='.gif')
    for _ in [None] * retry:
        try:
            # tile the batch into an (i x j) grid of frames
            one_gif = rearrange(
                tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            video_tensor_to_gif(one_gif, filename)
            bucket.put_object_from_file(oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)
    if exception is not None:
        print(
            'save video to {} failed, error: {}'.format(oss_key, exception),
            flush=True)
@torch.no_grad()
def save_video_multiple_conditions(oss_key,
                                   video_tensor,
                                   model_kwargs,
                                   source_imgs,
                                   palette,
                                   mean=[0.5, 0.5, 0.5],
                                   std=[0.5, 0.5, 0.5],
                                   nrow=8,
                                   retry=5,
                                   save_origin_video=True,
                                   bucket=None):
    """Render source video, every conditioning signal and the generated
    video side by side (stacked along height) into one grid GIF.

    Despite the parameter name, ``oss_key`` is used directly as the LOCAL
    output path and ``bucket`` is never referenced in this function — it is
    accepted only for signature compatibility with the other save helpers.

    Args:
        oss_key: local file path the GIF is written to.
        video_tensor: (B, C, F, H, W) generated batch, normalized with
            *mean*/*std*; denormalized IN PLACE here.
        model_kwargs: sequence whose first element maps condition names to
            tensors (depth, sketch, mask, ...).
        source_imgs: ground-truth video batch, resized to match the output.
        palette: helper used to colorize 3-D (scalar-valued) conditions.
        nrow: grid rows for tiling; B must be divisible by it.
        retry: rendering attempts; failures are logged, not raised.
        save_origin_video: prepend the source video column when True.
    """
    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
    video_tensor = video_tensor.mul_(std).add_(mean)
    try:
        video_tensor.clamp_(0, 1)
    except Exception as e:
        # some dtypes (e.g. half on CPU) reject clamp_; retry in float32
        video_tensor = video_tensor.float().clamp_(0, 1)
        print(e)
    video_tensor = video_tensor.cpu()
    b, c, n, h, w = video_tensor.shape
    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
    source_imgs = source_imgs.cpu()
    model_kwargs_channel3 = {}
    for key, conditions in model_kwargs[0].items():
        if conditions.shape[-1] == 1024:
            # Skip for style embedding
            continue
        if len(conditions.shape) == 3:
            # scalar-valued conditions (e.g. depth): colorize via palette
            conditions_np = conditions.cpu().numpy()
            conditions = []
            for i in conditions_np:
                vis_i = []
                for j in i:
                    vis_i.append(
                        palette.get_palette_image(
                            j, percentile=90, width=256, height=256))
                conditions.append(np.stack(vis_i))
            conditions = torch.from_numpy(np.stack(conditions))
            conditions = rearrange(conditions, 'b n h w c -> b c n h w')
        else:
            # expand 1/2-channel maps to 3 channels; composite RGBA on white
            if conditions.size(1) == 1:
                conditions = torch.cat([conditions, conditions, conditions],
                                       dim=1)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            if conditions.size(1) == 2:
                conditions = torch.cat([conditions, conditions[:, :1, ]],
                                       dim=1)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            elif conditions.size(1) == 3:
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            elif conditions.size(1) == 4:
                color = ((conditions[:, 0:3] + 1.) / 2.)
                alpha = conditions[:, 3:4]
                conditions = color * alpha + 1.0 * (1.0 - alpha)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        model_kwargs_channel3[key] = conditions.cpu(
        ) if conditions.is_cuda else conditions
    filename = oss_key
    for _ in [None] * retry:
        try:
            vid_gif = rearrange(
                video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            cons_list = [
                rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
                for _, con in model_kwargs_channel3.items()
            ]
            source_imgs = rearrange(
                source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            if save_origin_video:
                vid_gif = torch.cat(
                    [
                        source_imgs,
                    ] + cons_list + [
                        vid_gif,
                    ], dim=3)
            else:
                vid_gif = torch.cat(
                    cons_list + [
                        vid_gif,
                    ], dim=3)
            video_tensor_to_gif(vid_gif, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        logging.info('save video to {} failed, error: {}'.format(
            oss_key, exception))
@torch.no_grad()
def save_video_multiple_conditions_with_data(bucket,
                                             video_save_key,
                                             gt_video_save_key,
                                             vis_oss_key,
                                             video_tensor,
                                             model_kwargs,
                                             source_imgs,
                                             palette,
                                             mean=[0.5, 0.5, 0.5],
                                             std=[0.5, 0.5, 0.5],
                                             nrow=8,
                                             retry=5):
    """Upload a visualization GIF plus raw prediction/ground-truth arrays.

    Three objects are uploaded: a side-by-side GIF (source video, each
    condition, prediction) under *vis_oss_key*, the predicted frames as a
    pickled uint8 array under *video_save_key*, and the source frames under
    *gt_video_save_key*.

    Args:
        bucket: oss2 bucket handle.
        video_save_key / gt_video_save_key / vis_oss_key: destination keys.
        video_tensor: (B, C, F, H, W) prediction batch normalized with
            *mean*/*std*; denormalized IN PLACE here.
        model_kwargs: sequence whose first element maps condition names to
            tensors.
        source_imgs: ground-truth video batch.
        palette: helper used to colorize 3-D (scalar-valued) conditions.
        nrow: grid rows for tiling; B must be divisible by it.
        retry: attempts per upload; failures are printed, not raised.
    """
    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
    video_tensor = video_tensor.mul_(std).add_(mean)
    video_tensor.clamp_(0, 1)
    b, c, n, h, w = video_tensor.shape
    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
    source_imgs = source_imgs.cpu()
    model_kwargs_channel3 = {}
    for key, conditions in model_kwargs[0].items():
        if len(conditions.shape) == 3:
            # scalar-valued conditions (e.g. depth): colorize via palette
            conditions_np = conditions.cpu().numpy()
            conditions = []
            for i in conditions_np:
                vis_i = []
                for j in i:
                    vis_i.append(
                        palette.get_palette_image(
                            j, percentile=90, width=256, height=256))
                conditions.append(np.stack(vis_i))
            conditions = torch.from_numpy(np.stack(conditions))
            conditions = rearrange(conditions, 'b n h w c -> b c n h w')
        else:
            # expand 1/2-channel maps to 3 channels; composite RGBA on white
            if conditions.size(1) == 1:
                conditions = torch.cat([conditions, conditions, conditions],
                                       dim=1)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            if conditions.size(1) == 2:
                conditions = torch.cat([conditions, conditions[:, :1, ]],
                                       dim=1)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            elif conditions.size(1) == 3:
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            elif conditions.size(1) == 4:
                color = ((conditions[:, 0:3] + 1.) / 2.)
                alpha = conditions[:, 3:4]
                conditions = color * alpha + 1.0 * (1.0 - alpha)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        model_kwargs_channel3[key] = conditions.cpu(
        ) if conditions.is_cuda else conditions
    copy_video_tensor = video_tensor.clone()
    copy_source_imgs = source_imgs.clone()
    filename = rand_name(suffix='.gif')
    # guard against `exception` being unbound when retry <= 0
    exception = None
    for _ in [None] * retry:
        try:
            vid_gif = rearrange(
                video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            cons_list = [
                # BUGFIX: was `j=nrow`, inconsistent with every other grid
                # rearrange in this module; the resulting widths mismatched
                # vid_gif/source_imgs unless the batch happened to be square.
                rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
                for _, con in model_kwargs_channel3.items()
            ]
            source_imgs = rearrange(
                source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            vid_gif = torch.cat(
                [
                    source_imgs,
                ] + cons_list + [
                    vid_gif,
                ], dim=3)
            video_tensor_to_gif(vid_gif, filename)
            bucket.put_object_from_file(vis_oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)
    filename_pred = rand_name(suffix='.pkl')
    for _ in [None] * retry:
        try:
            copy_video_np = (copy_video_tensor.numpy() * 255).astype('uint8')
            # context manager closes the handle even when the upload fails
            with open(filename_pred, 'wb') as f:
                pickle.dump(copy_video_np, f)
            bucket.put_object_from_file(video_save_key, filename_pred)
            break
        except Exception as e:
            print('error! ', video_save_key, e)
            continue
    # remove temporary file
    if osp.exists(filename_pred):
        os.remove(filename_pred)
    filename_gt = rand_name(suffix='.pkl')
    for _ in [None] * retry:
        try:
            copy_source_np = (copy_source_imgs.numpy() * 255).astype('uint8')
            with open(filename_gt, 'wb') as f:
                pickle.dump(copy_source_np, f)
            bucket.put_object_from_file(gt_video_save_key, filename_gt)
            break
        except Exception as e:
            print('error! ', gt_video_save_key, e)
            continue
    # remove temporary file
    if osp.exists(filename_gt):
        os.remove(filename_gt)
    if exception is not None:
        print(
            'save video to {} failed, error: {}'.format(
                vis_oss_key, exception),
            flush=True)
@torch.no_grad()
def save_video_vs_conditions(bucket,
                             oss_key,
                             video_tensor,
                             conditions,
                             source_imgs,
                             mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5],
                             nrow=8,
                             retry=5):
    """Stack generated video, a single condition and the source video
    vertically (frame axis, dim=2) into one grid GIF and upload it.

    Args:
        bucket: oss2 bucket handle.
        oss_key: destination object key.
        video_tensor: (B, C, F, H, W) generated batch, normalized with
            *mean*/*std*; denormalized IN PLACE here.
        conditions: single condition batch; 1-channel maps are replicated
            to 3 channels and resized to the output resolution.
        source_imgs: ground-truth video batch.
        nrow: grid rows for tiling; B must be divisible by it.
        retry: attempts; best-effort — failures are printed, not raised.
    """
    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
    video_tensor = video_tensor.mul_(std).add_(mean)
    video_tensor.clamp_(0, 1)
    b, c, n, h, w = video_tensor.shape
    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
    source_imgs = source_imgs.cpu()
    if conditions.size(1) == 1:
        conditions = torch.cat([conditions, conditions, conditions], dim=1)
        conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
    filename = rand_name(suffix='.gif')
    for _ in [None] * retry:
        try:
            vid_gif = rearrange(
                video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            con_gif = rearrange(
                conditions, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            source_imgs = rearrange(
                source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            # concatenate along the frame axis: clips play one after another
            vid_gif = torch.cat([vid_gif, con_gif, source_imgs], dim=2)
            video_tensor_to_gif(vid_gif, filename)
            bucket.put_object_from_file(oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)
    if exception is not None:
        print(
            'save video to {} failed, error: {}'.format(oss_key, exception),
            flush=True)
@torch.no_grad()
def save_video_grid_mp4(bucket,
                        oss_key,
                        tensor,
                        mean=[0.5, 0.5, 0.5],
                        std=[0.5, 0.5, 0.5],
                        nrow=None,
                        fps=5,
                        retry=5):
    """Tile a (B, C, T, H, W) video batch into an approximately square grid
    and upload it as an MP4 via scikit-video.

    Args:
        bucket: oss2 bucket handle.
        oss_key: destination object key.
        tensor: normalized video batch; denormalized IN PLACE here.
        mean / std: per-channel statistics used to undo normalization.
        nrow: grid rows; defaults to ``ceil(sqrt(B))``.
        fps: output frame rate passed to ffmpeg.
        retry: attempts; best-effort — failures are printed, not raised.
    """
    mean = torch.tensor(mean, device=tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=tensor.device).view(1, -1, 1, 1, 1)
    tensor = tensor.mul_(std).add_(mean)
    tensor.clamp_(0, 1)
    b, c, t, h, w = tensor.shape
    tensor = tensor.permute(0, 2, 3, 4, 1)
    tensor = (tensor.cpu().numpy() * 255).astype('uint8')
    filename = rand_name(suffix='.mp4')
    for _ in [None] * retry:
        try:
            if nrow is None:
                nrow = math.ceil(math.sqrt(b))
            ncol = math.ceil(b / nrow)
            padding = 1
            # black canvas with `padding` pixels between (and around) tiles
            video_grid = np.zeros((t, (padding + h) * nrow + padding,
                                   (padding + w) * ncol + padding, c),
                                  dtype='uint8')
            for i in range(b):
                r = i // ncol
                c_ = i % ncol
                # NOTE(review): tile origins start at 0, so the top/left
                # border has no padding while the bottom/right does —
                # presumably cosmetic only; confirm if symmetric borders
                # were intended.
                start_r = (padding + h) * r
                start_c = (padding + w) * c_
                video_grid[:, start_r:start_r + h,
                           start_c:start_c + w] = tensor[i]
            skvideo.io.vwrite(filename, video_grid, inputdict={'-r': str(fps)})
            bucket.put_object_from_file(oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)
    if exception is not None:
        print(
            'save video to {} failed, error: {}'.format(oss_key, exception),
            flush=True)
@torch.no_grad()
def save_text(bucket, oss_key, tensor, nrow=8, retry=5):
    """Decode a batch of byte-encoded captions and upload them as one text
    object, with a blank line between each of the *nrow* groups.

    Args:
        bucket: oss2 bucket handle.
        oss_key: destination object key.
        tensor: batch of encoded texts; its length must be divisible by
            *nrow*.
        nrow: number of caption groups.
        retry: upload attempts; best-effort — failure is printed, not
            raised.

    NOTE(review): relies on ``dec_bytes2obj``, which is neither defined nor
    imported in this module — calling this function raises NameError as-is;
    confirm where the intended decoder lives.
    """
    total = tensor.shape[0]  # renamed from `len`, which shadowed the builtin
    num_per_row = total // nrow
    assert total == nrow * num_per_row
    texts = ''
    for i in range(nrow):
        for j in range(num_per_row):
            text = dec_bytes2obj(tensor[i * num_per_row + j])
            texts += text + '\n'
        texts += '\n'
    exception = None
    for _ in [None] * retry:
        try:
            bucket.put_object(oss_key, texts)
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        print(
            'save video to {} failed, error: {}'.format(oss_key, exception),
            flush=True)
@torch.no_grad()
def save_caps(bucket, oss_key, caps, retry=5):
    """Join captions with trailing newlines and upload the result as a
    single text object.

    Best-effort: when every attempt fails the last error is printed, not
    raised.
    """
    payload = ''.join(cap + '\n' for cap in caps)
    exception = None
    for _ in [None] * retry:
        try:
            bucket.put_object(oss_key, payload)
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        print(
            'save video to {} failed, error: {}'.format(oss_key, exception),
            flush=True)
@torch.no_grad()
def ema(net_ema, net, beta, copy_buffer=False):
    """Update *net_ema* in place as an exponential moving average of *net*.

    Each EMA parameter becomes ``beta * p_ema + (1 - beta) * p``. When
    ``copy_buffer`` is set, buffers (e.g. BatchNorm running statistics) are
    copied verbatim rather than averaged.
    """
    assert 0.0 <= beta <= 1.0
    for target, source in zip(net_ema.parameters(), net.parameters()):
        target.copy_(source.lerp(target, beta))
    if copy_buffer:
        for target, source in zip(net_ema.buffers(), net.buffers()):
            target.copy_(source)
def parallel(func, args_list, num_workers=32, timeout=None):
    """Apply *func* over *args_list* with a thread pool, preserving order.

    Args:
        func: callable applied to each argument tuple.
        args_list: list of argument tuples; bare items are wrapped into
            1-tuples.
        num_workers: pool size; 0 runs everything serially in this thread.
        timeout: per-result timeout in seconds forwarded to
            ``AsyncResult.get``.

    Returns:
        Results in input order. An empty *args_list* returns ``[]``
        (previously raised IndexError on the ``args_list[0]`` probe).
    """
    assert isinstance(args_list, list)
    if not args_list:
        return []
    if not isinstance(args_list[0], tuple):
        args_list = [(args, ) for args in args_list]
    if num_workers == 0:
        return [func(*args) for args in args_list]
    with Pool(processes=num_workers) as pool:
        futures = [pool.apply_async(func, args) for args in args_list]
        return [fut.get(timeout=timeout) for fut in futures]
def exists(filename):
    """Return True when *filename* exists, either as a local path or as an
    ``oss://`` object."""
    if not filename.startswith('oss://'):
        return osp.exists(filename)
    bucket, key = parse_oss_url(filename)
    return bucket.object_exists(key)
def download(url, filename=None, replace=False, quiet=False):
    """Download *url* (``oss://`` or any urllib-supported scheme) to
    *filename* and return the absolute local path.

    Args:
        url: source location.
        filename: local target; defaults to the URL's basename.
        replace: re-download even when the target already exists.
        quiet: suppress the success message.

    Raises:
        ValueError: wrapping the underlying error when the download fails.
    """
    if filename is None:
        filename = osp.basename(url)
    if not osp.exists(filename) or replace:
        try:
            if url.startswith('oss://'):
                bucket, oss_key = parse_oss_url(url)
                bucket.get_object_to_file(oss_key, filename)
            else:
                urllib.request.urlretrieve(url, filename)
            if not quiet:
                # report the actual source/target instead of the previous
                # literal '(unknown)' placeholders
                print(f'Downloaded {url} to {filename}', flush=True)
        except Exception as e:
            raise ValueError(f'Downloading {url} failed with error {e}')
    return osp.abspath(filename)
def unzip(filename, dst_dir=None):
    """Extract a zip archive into *dst_dir*, defaulting to the archive's
    own directory."""
    target = osp.dirname(filename) if dst_dir is None else dst_dir
    with zipfile.ZipFile(filename, 'r') as archive:
        archive.extractall(target)
def load_state_dict(module, state_dict, drop_prefix=''):
    """Load *state_dict* into *module*, silently dropping incompatible
    entries and reporting every discrepancy.

    Keys are optionally stripped of *drop_prefix* first. Entries unknown to
    the module or whose tensor shape differs are removed before loading
    with ``strict=False``. Missing, unexpected and shape-mismatched keys
    are printed for inspection.
    """
    src, dst = state_dict, module.state_dict()
    if drop_prefix:
        src = type(src)([(k[len(drop_prefix):] if k.startswith(drop_prefix)
                          else k, v) for k, v in src.items()])
    # classify incompatibilities
    missing = [k for k in dst if k not in src]
    unexpected = [k for k in src if k not in dst]
    unmatched = [
        k for k in src.keys() & dst.keys() if src[k].shape != dst[k].shape
    ]
    # keep only compatible key-vals
    incompatible = set(unexpected) | set(unmatched)
    filtered = type(src)([(k, v) for k, v in src.items()
                          if k not in incompatible])
    module.load_state_dict(filtered, strict=False)
    # report incompatible key-vals
    if missing:
        print(' Missing: ' + ', '.join(missing), flush=True)
    if unexpected:
        print(' Unexpected: ' + ', '.join(unexpected), flush=True)
    if unmatched:
        print(' Shape unmatched: ' + ', '.join(unmatched), flush=True)
def inverse_indices(indices):
    r"""Return the permutation that undoes *indices*.

    If ``A[indices] == B`` then ``B[inverse_indices(indices)] == A``.
    """
    positions = torch.arange(len(indices)).to(indices)
    inverse = torch.empty_like(indices)
    inverse[indices] = positions
    return inverse
def detect_duplicates(feats, thr=0.9):
    """Return indices of rows in *feats* that are not near-duplicates of an
    earlier row.

    Rows are L2-normalized, pairwise cosine similarity is computed, and a
    row is dropped when its similarity with ANY earlier row exceeds *thr*
    (the first row of each duplicate group is always kept).

    Args:
        feats: 2-D tensor of shape (num_items, dim), CPU or CUDA.
        thr: cosine-similarity threshold above which two rows count as
            duplicates.

    Returns:
        1-D LongTensor of kept row indices.
    """
    assert feats.ndim == 2
    # compute simmat
    feats = F.normalize(feats, p=2, dim=1)
    simmat = torch.mm(feats, feats.T)
    simmat.triu_(1)  # keep only i < j pairs
    # fix: only synchronize when the data actually lives on a GPU;
    # torch.cuda.synchronize() crashes on CPU-only builds
    if feats.is_cuda:
        torch.cuda.synchronize()
    # detect duplicates
    mask = ~simmat.gt(thr).any(dim=0)
    return torch.where(mask)[0]
class TFSClient(object):
    """Minimal client for the TFS restful image store.

    On construction the candidate server list is fetched from the service;
    subsequent requests rotate through the servers round-robin.
    """

    def __init__(self,
                 host='restful-store.vip.tbsite.net:3800',
                 app_key='5354c9fae75f5'):
        self.host = host
        self.app_key = app_key
        # candidate servers (first line of url.list is a header, skip it)
        entries = read(f'http://{host}/url.list').strip().split('\n')[1:]
        self.servers = [entry for entry in entries if ':' in entry]
        assert len(self.servers) >= 1
        self._cursor = -1

    @property
    def server(self):
        """Next server address, cycling round-robin on every access."""
        self._cursor = (self._cursor + 1) % len(self.servers)
        return self.servers[self._cursor]

    def read(self, tfs):
        """Fetch the image stored under TFS key *tfs* as a PIL image."""
        tfs = osp.basename(tfs)
        meta = json.loads(
            read(
                f'http://{self.server}/v1/{self.app_key}/metadata/{tfs}?force=0'
            ))
        raw = read(
            f'http://{self.server}/v1/{self.app_key}/{tfs}?offset=0&size={meta["SIZE"]}',
            'rb')
        return Image.open(BytesIO(raw))
def read_tfs(tfs, retry=5):
    """Read an image from TFS, lazily creating the shared module-level
    client on first use.

    Retries up to *retry* times and re-raises the last exception.
    """
    global TFS_CLIENT
    last_err = None
    for _ in range(retry):
        try:
            if TFS_CLIENT is None:
                TFS_CLIENT = TFSClient()
            return TFS_CLIENT.read(tfs)
        except Exception as e:
            last_err = e
    raise last_err
def md5(filename):
    """Return the hex MD5 digest of a file.

    Hashes in 1 MiB chunks so arbitrarily large files don't have to fit in
    memory (the previous version read the whole file at once). Note that
    ``hashlib`` must be imported at module level.
    """
    digest = hashlib.md5()
    with open(filename, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    return digest.hexdigest()
def rope(x):
    r"""Apply rotary position embedding (RoPE) to ``x`` of shape
    [B, *spatial, C]; the channel dimension C must be even.

    All spatial dimensions are flattened into one sequence axis for the
    rotation and restored afterwards.
    """
    original_shape = x.shape
    flat = x.view(x.size(0), -1, x.size(-1))
    seq_len, channels = flat.shape[-2:]
    assert channels % 2 == 0
    half = channels // 2
    # per-pair frequencies 10000^(-i/half), one rotation angle per position
    freqs = torch.pow(10000, -torch.arange(half).to(flat).div(half))
    angles = torch.outer(torch.arange(seq_len).to(flat), freqs)
    sin = torch.sin(angles)
    cos = torch.cos(angles)
    first, second = flat.chunk(2, dim=-1)
    rotated = torch.cat(
        [first * cos - second * sin, second * cos + first * sin], dim=-1)
    return rotated.view(original_shape)
def format_state(state, filename=None):
    r"""Render a state_dict as ``name\tshape`` lines for comparing/aligning
    checkpoints.

    Args:
        state: mapping from parameter name to tensor.
        filename: optional path; when given, the listing is also written
            there.

    Returns:
        The formatted listing as a string (previously the content was
        built but discarded when no filename was given).
    """
    content = '\n'.join([f'{k}\t{tuple(v.shape)}' for k, v in state.items()])
    if filename:
        with open(filename, 'w') as f:
            f.write(content)
    return content
def breakup_grid(img, grid_size):
    r"""Inverse of ``torchvision.utils.make_grid``: cut a grid image back
    into its individual ``grid_size`` x ``grid_size`` tiles, row-major,
    assuming the default 2-pixel padding between and around tiles.
    """
    rows = img.height // grid_size
    cols = img.width // grid_size
    pad_y = pad_x = 2
    tiles = []
    for row in range(rows):
        top = row * grid_size + (row + 1) * pad_y
        for col in range(cols):
            left = col * grid_size + (col + 1) * pad_x
            tiles.append(
                img.crop((left, top, left + grid_size, top + grid_size)))
    return tiles
def huggingface_tokenizer(name='google/mt5-xxl', **kwargs):
    """Load a HuggingFace tokenizer from the local model-weights cache."""
    from transformers import AutoTokenizer
    cache_path = DOWNLOAD_TO_CACHE(f'huggingface/tokenizers/{name}', name)
    return AutoTokenizer.from_pretrained(cache_path, **kwargs)
def huggingface_model(name='google/mt5-xxl', model_type='AutoModel', **kwargs):
    """Load a HuggingFace model of class *model_type* from the local
    model-weights cache."""
    import transformers
    model_cls = getattr(transformers, model_type)
    cache_path = DOWNLOAD_TO_CACHE(f'huggingface/models/{name}', name)
    return model_cls.from_pretrained(cache_path, **kwargs)

View File

@@ -0,0 +1,480 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import copy, deepcopy
from os import path as osp
from typing import Any, Dict
import open_clip
import pynvml
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from einops import rearrange
import modelscope.models.multi_modal.videocomposer.models as models
from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.videocomposer.annotator.sketch import (
pidinet_bsd, sketch_simplification_gan)
from modelscope.models.multi_modal.videocomposer.autoencoder import \
AutoencoderKL
from modelscope.models.multi_modal.videocomposer.clip import (
FrozenOpenCLIPEmbedder, FrozenOpenCLIPVisualEmbedder)
from modelscope.models.multi_modal.videocomposer.diffusion import (
GaussianDiffusion, beta_schedule)
from modelscope.models.multi_modal.videocomposer.ops.utils import (
get_first_stage_encoding, make_masked_images, prepare_model_kwargs,
save_with_model_kwargs)
from modelscope.models.multi_modal.videocomposer.unet_sd import UNetSD_temporal
from modelscope.models.multi_modal.videocomposer.utils.config import Config
from modelscope.models.multi_modal.videocomposer.utils.utils import (
find_free_port, setup_seed, to_device)
from modelscope.outputs import OutputKeys
from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModelFile, Tasks
from .config import cfg
__all__ = ['VideoComposer']
@MODELS.register_module(
    Tasks.text_to_video_synthesis, module_name=Models.videocomposer)
class VideoComposer(TorchModel):
    r"""Compositional text-to-video synthesis model (VideoComposer).

    Attributes:
        model: temporal UNet (``UNetSD_temporal``) used for denoising.
        diffusion: Gaussian diffusion wrapper providing the sampling loop.
        autoencoder: KL autoencoder decoding latents back to pixel space.
        clip_encoder: frozen OpenCLIP text encoder for text embeddings.
        clip_encoder_visual: frozen OpenCLIP image encoder for style images.
    """
    def __init__(self, model_dir, *args, **kwargs):
        r"""Build the full VideoComposer pipeline from *model_dir*.

        Args:
            model_dir (`str` or `os.PathLike`): either a modelscope model
                id (e.g. ``damo/xxx``) or a local directory containing the
                pretrained weights.
            **kwargs: optional overrides, all popped before being forwarded:
                ``clip_checkpoint``, ``sd_checkpoint`` (weight file names),
                ``read_image`` / ``read_style`` / ``read_sketch`` /
                ``save_origin_video`` (inference flags) and
                ``video_compositions`` (list of enabled conditions).
        """
        super().__init__(model_dir=model_dir, *args, **kwargs)
        # prefer GPU; NOTE(review): some setup below (.cuda()) still assumes
        # CUDA unconditionally, see the autoencoder section.
        self.device = torch.device('cuda') if torch.cuda.is_available() \
            else torch.device('cpu')
        clip_checkpoint = kwargs.pop('clip_checkpoint',
                                     'open_clip_pytorch_model.bin')
        sd_checkpoint = kwargs.pop('sd_checkpoint', 'v2-1_512-ema-pruned.ckpt')
        # merge the packaged config into the module-level cfg singleton
        _cfg = Config(load=True)
        cfg.update(_cfg.cfg_dict)
        # rank-wise params: index 0 is used since inference is single-rank
        l1 = len(cfg.frame_lens)
        l2 = len(cfg.feature_framerates)
        cfg.max_frames = cfg.frame_lens[0 % (l1 * l2) // l2]
        cfg.batch_size = cfg.batch_sizes[str(cfg.max_frames)]
        # Copy update input parameter to current task
        self.cfg = cfg
        # minimal single-process distributed environment
        if 'MASTER_ADDR' not in os.environ:
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = find_free_port()
        self.cfg.pmi_rank = int(os.getenv('RANK', 0))
        self.cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
        setup_seed(self.cfg.seed)
        # inference-time input flags
        self.read_image = kwargs.pop('read_image', False)
        self.read_style = kwargs.pop('read_style', True)
        self.read_sketch = kwargs.pop('read_sketch', False)
        self.save_origin_video = kwargs.pop('save_origin_video', True)
        self.video_compositions = kwargs.pop('video_compositions', [
            'text', 'mask', 'depthmap', 'sketch', 'motion', 'image',
            'local_image', 'single_sketch'
        ])
        self.viz_num = self.cfg.batch_size
        # frozen OpenCLIP text and image encoders
        self.clip_encoder = FrozenOpenCLIPEmbedder(
            layer='penultimate',
            pretrained=os.path.join(model_dir, clip_checkpoint))
        self.clip_encoder = self.clip_encoder.to(self.device)
        self.clip_encoder_visual = FrozenOpenCLIPVisualEmbedder(
            layer='penultimate',
            pretrained=os.path.join(model_dir, clip_checkpoint))
        self.clip_encoder_visual.model.to(self.device)
        # KL autoencoder configuration (Stable Diffusion v2 layout)
        ddconfig = {
            'double_z': True,
            'z_channels': 4,
            'resolution': 256,
            'in_channels': 3,
            'out_ch': 3,
            'ch': 128,
            'ch_mult': [1, 2, 4, 4],
            'num_res_blocks': 2,
            'attn_resolutions': [],
            'dropout': 0.0
        }
        self.autoencoder = AutoencoderKL(
            ddconfig, 4, ckpt_path=os.path.join(model_dir, sd_checkpoint))
        # embedding of the empty prompt, used for classifier-free guidance
        self.zero_y = self.clip_encoder('').detach()
        # placeholder "no style image" feature (kept all-zero on purpose)
        black_image_feature = self.clip_encoder_visual(
            self.clip_encoder_visual.black_image).unsqueeze(1)
        black_image_feature = torch.zeros_like(black_image_feature)
        # freeze the autoencoder; it is only used for encode/decode
        self.autoencoder.eval()
        for param in self.autoencoder.parameters():
            param.requires_grad = False
        # NOTE(review): unconditional .cuda() will fail on CPU-only hosts
        # even though self.device may be CPU — confirm CPU support is
        # intended.
        self.autoencoder.cuda()
        # temporal UNet denoiser conditioned on all enabled compositions
        self.model = UNetSD_temporal(
            cfg=self.cfg,
            in_dim=self.cfg.unet_in_dim,
            concat_dim=self.cfg.unet_concat_dim,
            dim=self.cfg.unet_dim,
            y_dim=self.cfg.unet_y_dim,
            context_dim=self.cfg.unet_context_dim,
            out_dim=self.cfg.unet_out_dim,
            dim_mult=self.cfg.unet_dim_mult,
            num_heads=self.cfg.unet_num_heads,
            head_dim=self.cfg.unet_head_dim,
            num_res_blocks=self.cfg.unet_res_blocks,
            attn_scales=self.cfg.unet_attn_scales,
            dropout=self.cfg.unet_dropout,
            temporal_attention=self.cfg.temporal_attention,
            temporal_attn_times=self.cfg.temporal_attn_times,
            use_checkpoint=self.cfg.use_checkpoint,
            use_fps_condition=self.cfg.use_fps_condition,
            use_sim_mask=self.cfg.use_sim_mask,
            video_compositions=self.cfg.video_compositions,
            misc_dropout=self.cfg.misc_dropout,
            p_all_zero=self.cfg.p_all_zero,
            # NOTE(review): p_all_keep is fed cfg.p_all_zero — this looks
            # like a copy-paste slip; confirm whether cfg.p_all_keep was
            # intended.
            p_all_keep=self.cfg.p_all_zero,
            zero_y=self.zero_y,
            black_image_feature=black_image_feature,
        ).to(self.device)
        # Load checkpoint
        if self.cfg.resume and self.cfg.resume_checkpoint:
            if hasattr(self.cfg, 'text_to_video_pretrain'
                       ) and self.cfg.text_to_video_pretrain:
                # text-to-video pretrain weights: drop the first input conv,
                # whose channel count differs from the composed model
                checkpoint_name = cfg.resume_checkpoint.split('/')[-1]
                ss = torch.load(
                    os.path.join(self.model_dir, cfg.resume_checkpoint))
                ss = {
                    key: p
                    for key, p in ss.items() if 'input_blocks.0.0' not in key
                }
                self.model.load_state_dict(ss, strict=False)
            else:
                checkpoint_name = cfg.resume_checkpoint.split('/')[-1]
                self.model.load_state_dict(
                    torch.load(
                        os.path.join(self.model_dir, checkpoint_name),
                        map_location='cpu'),
                    strict=False)
                torch.cuda.empty_cache()
        else:
            # note: this branch also fires when resume flags are simply
            # unset, not only when the checkpoint path is invalid
            raise ValueError(
                f'The checkpoint file {self.cfg.resume_checkpoint} is wrong ')
        # diffusion: linear SD beta schedule + Gaussian diffusion sampler
        betas = beta_schedule(
            'linear_sd',
            self.cfg.num_timesteps,
            init_beta=0.00085,
            last_beta=0.0120)
        self.diffusion = GaussianDiffusion(
            betas=betas,
            mean_type=self.cfg.mean_type,
            var_type=self.cfg.var_type,
            loss_type=self.cfg.loss_type,
            rescale_timesteps=False)
def forward(self, input: Dict[str, Any]):
"""Run compositional video synthesis for one preprocessed batch.

Args:
    input: dict produced by the pipeline's preprocess(); keys read here:
        'cap_txt', 'ref_frame', 'video_data', 'misc_data', 'mask',
        'mv_data', 'style_image'.

Returns:
    dict with 'video' (sampled video as a float32 CPU tensor) and
    'video_path' (NOTE(review): currently set to the whole cfg object,
    not a path — confirm intended).
"""
# Optional single-image conditions loaded from disk, gated by cfg flags.
frame_in = None
if self.read_image:
image_key = input['style_image']
frame = load_image(image_key)
frame_in = misc_transforms([frame])
frame_sketch = None
if self.read_sketch:
sketch_key = self.cfg.sketch_path
frame_sketch = load_image(sketch_key)
frame_sketch = misc_transforms([frame_sketch])
frame_style = None
if self.read_style:
frame_style = load_image(input['style_image'])
# Generators for various conditions
# NOTE: these condition extractors (depth / canny / sketch nets) are
# rebuilt on every forward() call rather than cached on the instance.
if 'depthmap' in self.video_compositions:
midas = models.midas_v3(
pretrained=True,
model_dir=self.model_dir).eval().requires_grad_(False).to(
memory_format=torch.channels_last).half().to(self.device)
if 'canny' in self.video_compositions:
canny_detector = CannyDetector()
if 'sketch' in self.video_compositions:
pidinet = pidinet_bsd(
self.model_dir, pretrained=True,
vanilla_cnn=True).eval().requires_grad_(False).to(self.device)
cleaner = sketch_simplification_gan(
self.model_dir,
pretrained=True).eval().requires_grad_(False).to(self.device)
# per-channel normalization constants for the sketch network input
pidi_mean = torch.tensor(self.cfg.sketch_mean).view(
1, -1, 1, 1).to(self.device)
pidi_std = torch.tensor(self.cfg.sketch_std).view(1, -1, 1, 1).to(
self.device)
# Placeholder for color inference
palette = None
self.model.eval()
caps = input['cap_txt']
# NOTE(review): both branches below unpack the same keys identically;
# presumably kept for parity with the training code path — confirm.
if self.cfg.max_frames == 1 and self.cfg.use_image_dataset:
ref_imgs = input['ref_frame']
video_data = input['video_data']
misc_data = input['misc_data']
mask = input['mask']
mv_data = input['mv_data']
fps = torch.tensor(
[self.cfg.feature_framerate] * self.cfg.batch_size,
dtype=torch.long,
device=self.device)
else:
ref_imgs = input['ref_frame']
video_data = input['video_data']
misc_data = input['misc_data']
mask = input['mask']
mv_data = input['mv_data']
# add fps test
fps = torch.tensor(
[self.cfg.feature_framerate] * self.cfg.batch_size,
dtype=torch.long,
device=self.device)
# save for visualization
misc_backups = copy(misc_data)
misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w')
# Motion-vector condition (channels-first video layout).
mv_data_video = []
if 'motion' in self.cfg.video_compositions:
mv_data_video = rearrange(mv_data, 'b f c h w -> b c f h w')
# mask images
masked_video = []
if 'mask' in self.cfg.video_compositions:
masked_video = make_masked_images(
misc_data.sub(0.5).div_(0.5), mask)
masked_video = rearrange(masked_video, 'b f c h w -> b c f h w')
# Single Image
# NOTE(review): frames_num is only defined inside this branch but is
# also used by the 'sketch' branch below when cfg.read_sketch is set;
# that combination would raise NameError — confirm.
image_local = []
if 'local_image' in self.cfg.video_compositions:
frames_num = misc_data.shape[1]
bs_vd_local = misc_data.shape[0]
if self.cfg.read_image:
image_local = frame_in.unsqueeze(0).repeat(
bs_vd_local, frames_num, 1, 1, 1).cuda()
else:
image_local = misc_data[:, :1].clone().repeat(
1, frames_num, 1, 1, 1)
image_local = rearrange(
image_local, 'b f c h w -> b c f h w', b=bs_vd_local)
# encode the video_data
bs_vd = video_data.shape[0]
video_data = rearrange(video_data, 'b f c h w -> (b f) c h w')
misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w')
# Chunk frames so the VAE / condition nets process cfg.chunk_size at a time.
video_data_list = torch.chunk(
video_data, video_data.shape[0] // self.cfg.chunk_size, dim=0)
misc_data_list = torch.chunk(
misc_data, misc_data.shape[0] // self.cfg.chunk_size, dim=0)
with torch.no_grad():
# Encode video frames to latents with the autoencoder, chunk by chunk.
decode_data = []
for vd_data in video_data_list:
encoder_posterior = self.autoencoder.encode(vd_data)
tmp = get_first_stage_encoding(encoder_posterior).detach()
decode_data.append(tmp)
video_data = torch.cat(decode_data, dim=0)
video_data = rearrange(
video_data, '(b f) c h w -> b c f h w', b=bs_vd)
# Depth-map condition via MiDaS (fp16, channels_last).
depth_data = []
if 'depthmap' in self.cfg.video_compositions:
for misc_imgs in misc_data_list:
depth = midas(
misc_imgs.sub(0.5).div_(0.5).to(
memory_format=torch.channels_last).half())
depth = (depth / self.cfg.depth_std).clamp_(
0, self.cfg.depth_clamp)
depth_data.append(depth)
depth_data = torch.cat(depth_data, dim=0)
depth_data = rearrange(
depth_data, '(b f) c h w -> b c f h w', b=bs_vd)
# Canny-edge condition (per-frame, HWC in/out of the detector).
canny_data = []
if 'canny' in self.cfg.video_compositions:
for misc_imgs in misc_data_list:
misc_imgs = rearrange(misc_imgs.clone(),
'k c h w -> k h w c')
canny_condition = torch.stack(
[canny_detector(misc_img) for misc_img in misc_imgs])
canny_condition = rearrange(canny_condition,
'k h w c-> k c h w')
canny_data.append(canny_condition)
canny_data = torch.cat(canny_data, dim=0)
canny_data = rearrange(
canny_data, '(b f) c h w -> b c f h w', b=bs_vd)
# Sketch condition: PiDiNet edges cleaned by the simplification GAN.
sketch_data = []
if 'sketch' in self.cfg.video_compositions:
sketch_list = misc_data_list
if self.cfg.read_sketch:
sketch_repeat = frame_sketch.repeat(frames_num, 1, 1,
1).cuda()
sketch_list = [sketch_repeat]
for misc_imgs in sketch_list:
sketch = pidinet(misc_imgs.sub(pidi_mean).div_(pidi_std))
sketch = 1.0 - cleaner(1.0 - sketch)
sketch_data.append(sketch)
sketch_data = torch.cat(sketch_data, dim=0)
sketch_data = rearrange(
sketch_data, '(b f) c h w -> b c f h w', b=bs_vd)
# Single-sketch condition: first sketch frame broadcast over time.
single_sketch_data = []
if 'single_sketch' in self.cfg.video_compositions:
single_sketch_data = sketch_data.clone()[:, :, :1].repeat(
1, 1, frames_num, 1, 1)
# preprocess for input text descripts
y = self.clip_encoder(caps).detach()
y0 = y.clone()
# Visual (style-image) CLIP embedding; y_visual0 is only defined here,
# but later uses are guarded by `len(y_visual) == 0` short-circuits.
y_visual = []
if 'image' in self.cfg.video_compositions:
with torch.no_grad():
if self.cfg.read_style:
y_visual = self.clip_encoder_visual(
self.clip_encoder_visual.preprocess(
frame_style).unsqueeze(0).cuda()).unsqueeze(0)
y_visual0 = y_visual.clone()
else:
ref_imgs = ref_imgs.squeeze(1)
y_visual = self.clip_encoder_visual(ref_imgs).unsqueeze(1)
y_visual0 = y_visual.clone()
with torch.no_grad():
# Log memory
# NOTE: only initializes NVML here; no readings are taken in this block.
pynvml.nvmlInit()
# Sample images (DDIM)
with amp.autocast(enabled=self.cfg.use_fp16):
if self.cfg.share_noise:
# One noise map shared across all frames of a sample.
b, c, f, h, w = video_data.shape
noise = torch.randn((self.viz_num, c, h, w),
device=self.device)
noise = noise.repeat_interleave(repeats=f, dim=0)
noise = rearrange(
noise, '(b f) c h w->b c f h w', b=self.viz_num)
noise = noise.contiguous()
else:
noise = torch.randn_like(video_data[:self.viz_num])
# Two kwargs dicts: [0] = conditional, [1] = unconditional
# (for classifier-free guidance); empty conditions become None.
full_model_kwargs = [{
'y':
y0[:self.viz_num],
'local_image':
None
if len(image_local) == 0 else image_local[:self.viz_num],
'image':
None if len(y_visual) == 0 else y_visual0[:self.viz_num],
'depth':
None
if len(depth_data) == 0 else depth_data[:self.viz_num],
'canny':
None
if len(canny_data) == 0 else canny_data[:self.viz_num],
'sketch':
None
if len(sketch_data) == 0 else sketch_data[:self.viz_num],
'masked':
None
if len(masked_video) == 0 else masked_video[:self.viz_num],
'motion':
None if len(mv_data_video) == 0 else
mv_data_video[:self.viz_num],
'single_sketch':
None if len(single_sketch_data) == 0 else
single_sketch_data[:self.viz_num],
'fps':
fps[:self.viz_num]
}, {
'y':
self.zero_y.repeat(self.viz_num, 1, 1)
if not self.cfg.use_fps_condition else
torch.zeros_like(y0)[:self.viz_num],
'local_image':
None
if len(image_local) == 0 else image_local[:self.viz_num],
'image':
None if len(y_visual) == 0 else torch.zeros_like(
y_visual0[:self.viz_num]),
'depth':
None
if len(depth_data) == 0 else depth_data[:self.viz_num],
'canny':
None
if len(canny_data) == 0 else canny_data[:self.viz_num],
'sketch':
None
if len(sketch_data) == 0 else sketch_data[:self.viz_num],
'masked':
None
if len(masked_video) == 0 else masked_video[:self.viz_num],
'motion':
None if len(mv_data_video) == 0 else
mv_data_video[:self.viz_num],
'single_sketch':
None if len(single_sketch_data) == 0 else
single_sketch_data[:self.viz_num],
'fps':
fps[:self.viz_num]
}]
# Save generated videos
# Keep only the guidance conditions requested in cfg.guidances.
partial_keys = self.cfg.guidances
noise_motion = noise.clone()
model_kwargs = prepare_model_kwargs(
partial_keys=partial_keys,
full_model_kwargs=full_model_kwargs,
use_fps_condition=self.cfg.use_fps_condition)
video_output = self.diffusion.ddim_sample_loop(
noise=noise_motion,
model=self.model.eval(),
model_kwargs=model_kwargs,
guide_scale=9.0,
ddim_timesteps=self.cfg.ddim_timesteps,
eta=0.0)
# Decode latents and write visualization artifacts to disk.
save_with_model_kwargs(
model_kwargs=model_kwargs,
video_data=video_output,
autoencoder=self.autoencoder,
ori_video=misc_backups,
viz_num=self.viz_num,
step=0,
caps=caps,
palette=palette,
cfg=self.cfg)
return {
'video': video_output.type(torch.float32).cpu(),
'video_path': self.cfg
}

View File

@@ -22,6 +22,7 @@ if TYPE_CHECKING:
from .soonet_video_temporal_grounding_pipeline import SOONetVideoTemporalGroundingPipeline
from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline
from .multimodal_dialogue_pipeline import MultimodalDialoguePipeline
from .videocomposer_pipeline import VideoComposerPipeline
else:
_import_structure = {
'image_captioning_pipeline': ['ImageCaptioningPipeline'],
@@ -46,7 +47,8 @@ else:
'soonet_video_temporal_grounding_pipeline':
['SOONetVideoTemporalGroundingPipeline'],
'text_to_video_synthesis_pipeline': ['TextToVideoSynthesisPipeline'],
'multimodal_dialogue_pipeline': ['MultimodalDialoguePipeline']
'multimodal_dialogue_pipeline': ['MultimodalDialoguePipeline'],
'videocomposer_pipeline': ['VideoComposerPipeline']
}
import sys

View File

@@ -0,0 +1,382 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
import subprocess
import tempfile
import time
from functools import partial
from typing import Any, Dict
import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as T
from mvextractor.videocap import VideoCap
from PIL import Image
import modelscope.models.multi_modal.videocomposer.data as data
from modelscope.metainfo import Pipelines
from modelscope.models.multi_modal.videocomposer.data.transforms import (
CenterCropV3, random_resize)
from modelscope.models.multi_modal.videocomposer.ops.random_mask import (
make_irregular_mask, make_rectangle_mask, make_uncrop)
from modelscope.models.multi_modal.videocomposer.utils.utils import rand_name
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.device import device_placement
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.text_to_video_synthesis, module_name=Pipelines.videocomposer)
class VideoComposerPipeline(Pipeline):
r""" Video Composer Pipeline.
Examples:
>>> from modelscope.pipelines import pipeline
>>> from modelscope.utils.constant import Tasks
>>> pipe = pipeline(
task=Tasks.text_to_video_synthesis,
model='buptwq/videocomposer',
model_revision='v1.0.1')
>>> inputs = {'Video:FILE': 'path/input_video.mp4',
'Image:FILE': 'path/input_image.png',
'text': 'the text description'}
>>> output = pipe(inputs)
"""
def __init__(self, model: str, **kwargs):
    """Use `model` to create a videocomposer pipeline for prediction.

    Args:
        model: Model id on the modelscope hub (or a local model dir).
        **kwargs: Optional configuration overrides, e.g. `log_dir`,
            `feature_framerate`, `frame_lens`, `resolution`,
            `image_resolution`, `misc_size`, `vit_image_size`,
            `visual_mv`, `max_words`, `mvs_visual`.
    """
    super().__init__(model=model)
    # Directory where generated videos / MV visualizations are written.
    self.log_dir = kwargs.pop('log_dir', './video_outputs')
    if not os.path.exists(self.log_dir):
        os.makedirs(self.log_dir)
    self.feature_framerate = kwargs.pop('feature_framerate', 4)
    self.frame_lens = kwargs.pop('frame_lens', [
        16,
        16,
        16,
        16,
    ])
    self.feature_framerates = kwargs.pop('feature_framerates', [
        4,
    ])
    self.batch_sizes = kwargs.pop('batch_sizes', {
        '1': 1,
        '4': 1,
        '8': 1,
        '16': 1,
    })
    # Single-process inference: index 0 picks the first
    # (frame_len, framerate) combination from the grids above.
    l1 = len(self.frame_lens)
    l2 = len(self.feature_framerates)
    self.max_frames = self.frame_lens[0 % (l1 * l2) // l2]
    self.batch_size = self.batch_sizes[str(self.max_frames)]
    self.resolution = kwargs.pop('resolution', 256)
    self.image_resolution = kwargs.pop('image_resolution', 256)
    self.mean = kwargs.pop('mean', [0.5, 0.5, 0.5])
    self.std = kwargs.pop('std', [0.5, 0.5, 0.5])
    self.vit_image_size = kwargs.pop('vit_image_size', 224)
    # CLIP-ViT preprocessing statistics.
    self.vit_mean = kwargs.pop('vit_mean',
                               [0.48145466, 0.4578275, 0.40821073])
    self.vit_std = kwargs.pop('vit_std',
                              [0.26862954, 0.26130258, 0.27577711])
    # BUG FIX: the lookup key used to be the typo 'kwargs.pop', so a
    # caller-supplied `misc_size` was silently ignored and the value was
    # always 384. The intended key is 'misc_size'.
    self.misc_size = kwargs.pop('misc_size', 384)
    self.visual_mv = kwargs.pop('visual_mv', False)
    self.max_words = kwargs.pop('max_words', 1000)
    self.mvs_visual = kwargs.pop('mvs_visual', False)
    # Transform for the frames fed to the diffusion model.
    self.infer_trans = data.Compose([
        data.CenterCropV2(size=self.resolution),
        data.ToTensor(),
        data.Normalize(mean=self.mean, std=self.std)
    ])
    # Transform for the auxiliary (misc) condition frames.
    self.misc_transforms = data.Compose([
        T.Lambda(partial(random_resize, size=self.misc_size)),
        data.CenterCropV2(self.misc_size),
        data.ToTensor()
    ])
    # Transform for motion-vector tensors (resize + center crop only).
    self.mv_transforms = data.Compose(
        [T.Resize(size=self.resolution),
         T.CenterCrop(self.resolution)])
    # Transform for the CLIP-ViT reference frame.
    self.vit_transforms = T.Compose([
        CenterCropV3(self.vit_image_size),
        T.ToTensor(),
        T.Normalize(mean=self.vit_mean, std=self.vit_std)
    ])
def preprocess(self, input: Input) -> Dict[str, Any]:
    """Decode the input video and build the model's input dict.

    Args:
        input: dict with 'Video:FILE' (video path), 'text' (caption)
            and 'Image:FILE' (style image path).

    Returns:
        Dict with batched tensors: 'ref_frame', 'cap_txt', 'video_data',
        'misc_data', 'feature_framerate', 'mask', 'mv_data',
        'style_image'.
    """
    video_key = input['Video:FILE']
    cap_txt = input['text']
    style_image = input['Image:FILE']
    total_frames = None
    feature_framerate = self.feature_framerate
    if os.path.exists(video_key):
        try:
            ref_frame, vit_image, video_data, misc_data, mv_data = self.video_data_preprocess(
                video_key, self.feature_framerate, total_frames,
                self.mvs_visual)
        except Exception as e:
            # BUG FIX: logging.Logger.info() does not accept print()'s
            # `flush` keyword; passing it raised a TypeError here that
            # masked the original decoding error. Use lazy %-formatting.
            logger.info('%s get frames failed... with error: %s',
                        video_key, e)
            # Fall back to all-zero tensors so inference can proceed.
            ref_frame = torch.zeros(3, self.vit_image_size,
                                    self.vit_image_size)
            video_data = torch.zeros(self.max_frames, 3,
                                     self.image_resolution,
                                     self.image_resolution)
            misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
                                    self.misc_size)
            mv_data = torch.zeros(self.max_frames, 2,
                                  self.image_resolution,
                                  self.image_resolution)
    else:
        logger.info(
            'The video path does not exist or no video dir provided!')
        ref_frame = torch.zeros(3, self.vit_image_size,
                                self.vit_image_size)
        video_data = torch.zeros(self.max_frames, 3, self.image_resolution,
                                 self.image_resolution)
        misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
                                self.misc_size)
        mv_data = torch.zeros(self.max_frames, 2, self.image_resolution,
                              self.image_resolution)
    # inpainting mask: 70% irregular strokes, 20% rectangle, 10% uncrop
    p = random.random()
    if p < 0.7:
        mask = make_irregular_mask(512, 512)
    elif p < 0.9:
        mask = make_rectangle_mask(512, 512)
    else:
        mask = make_uncrop(512, 512)
    mask = torch.from_numpy(
        cv2.resize(
            mask, (self.misc_size, self.misc_size),
            interpolation=cv2.INTER_NEAREST)).unsqueeze(0).float()
    # Broadcast the same mask over every frame of the clip.
    mask = mask.unsqueeze(0).repeat_interleave(
        repeats=self.max_frames, dim=0)
    video_input = {
        'ref_frame': ref_frame.unsqueeze(0),
        'cap_txt': cap_txt,
        'video_data': video_data.unsqueeze(0),
        'misc_data': misc_data.unsqueeze(0),
        'feature_framerate': feature_framerate,
        'mask': mask.unsqueeze(0),
        'mv_data': mv_data.unsqueeze(0),
        'style_image': style_image
    }
    return video_input
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
# Delegate generation to the underlying VideoComposer model.
return self.model(input)
def postprocess(self, inputs: Dict[str, Any],
**post_params) -> Dict[str, Any]:
"""Select which representation of the generated video to return."""
output_video_path = post_params.get('output_video', None)
temp_video_file = False
# NOTE(review): this condition looks inverted — a caller-supplied
# 'output_video' path is immediately overwritten with a throwaway temp
# name, and the 'video_path' branch is taken only when a path WAS
# provided. Confirm intent against the reference implementation.
if output_video_path is not None:
# NOTE(review): NamedTemporaryFile leaves an open file handle behind;
# only its .name is read, and output_video_path is never used again.
output_video_path = tempfile.NamedTemporaryFile(suffix='.gif').name
temp_video_file = True
if temp_video_file:
# 'video_path' comes from the model forward()'s output dict.
return {OutputKeys.OUTPUT_VIDEO: inputs['video_path']}
else:
# Default: return the raw video tensor produced by the model.
return {OutputKeys.OUTPUT_VIDEO: inputs['video']}
def video_data_preprocess(self, video_key, feature_framerate, total_frames,
                          visual_mv):
    """Extract frames + motion vectors and build per-clip tensors.

    Args:
        video_key: path of the input video.
        feature_framerate: target fps for re-encoding / extraction.
        total_frames: unused placeholder (overwritten below).
        visual_mv: whether to save a motion-vector visualization gif.

    Returns:
        Tuple of (ref_frame, vit_image, video_data, misc_data, mv_data).
    """
    filename = video_key
    # mvextractor occasionally fails transiently; retry a few times.
    for _ in range(5):
        try:
            frame_types, frames, mvs, mvs_visual = self.extract_motion_vectors(
                input_video=filename,
                fps=feature_framerate,
                visual_mv=visual_mv)
            break
        except Exception as e:
            # BUG FIX: logging.Logger.error() has no `flush` kwarg (that
            # is a print() argument); it used to raise TypeError here.
            logger.error(
                '%s read video frames and motion vectors failed with error: %s',
                video_key, e)
    else:
        # BUG FIX: previously, exhausting every retry fell through and
        # crashed with a NameError on `frame_types`; fail explicitly.
        raise RuntimeError(
            f'{video_key}: could not extract motion vectors after 5 attempts'
        )
    total_frames = len(frame_types)
    # Candidate windows must start on an I-frame and leave max_frames room.
    start_indexs = np.where((np.array(frame_types) == 'I') & (
        total_frames - np.arange(total_frames) >= self.max_frames))[0]
    start_index = np.random.choice(start_indexs)
    indices = np.arange(start_index, start_index + self.max_frames)
    # note frames are in BGR mode, need to trans to RGB mode
    frames = [Image.fromarray(frames[i][:, :, ::-1]) for i in indices]
    mvs = [torch.from_numpy(mvs[i].transpose((2, 0, 1))) for i in indices]
    mvs = torch.stack(mvs)
    if visual_mv:
        images = [(mvs_visual[i][:, :, ::-1]).astype('uint8')
                  for i in indices]
        path = self.log_dir + '/visual_mv/' + video_key.split(
            '/')[-1] + '.gif'
        if not os.path.exists(self.log_dir + '/visual_mv/'):
            os.makedirs(self.log_dir + '/visual_mv/', exist_ok=True)
        # BUG FIX: logger.info is not print(); the extra positional arg
        # had no placeholder, so the path never appeared in the log.
        logger.info('save motion vectors visualization to: %s', path)
        imageio.mimwrite(path, images, fps=8)
    have_frames = len(frames) > 0
    middle_indix = int(len(frames) / 2)
    if have_frames:
        ref_frame = frames[middle_indix]
        vit_image = self.vit_transforms(ref_frame)
        # misc_transforms is split: [:2] = resize+crop (PIL), [2:] = ToTensor.
        misc_imgs_np = self.misc_transforms[:2](frames)
        misc_imgs = self.misc_transforms[2:](misc_imgs_np)
        frames = self.infer_trans(frames)
        mvs = self.mv_transforms(mvs)
    else:
        vit_image = torch.zeros(3, self.vit_image_size,
                                self.vit_image_size)
    # Zero-initialized outputs; filled below when frames were decoded.
    video_data = torch.zeros(self.max_frames, 3, self.image_resolution,
                             self.image_resolution)
    mv_data = torch.zeros(self.max_frames, 2, self.image_resolution,
                          self.image_resolution)
    misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
                            self.misc_size)
    if have_frames:
        video_data[:len(frames), ...] = frames
        misc_data[:len(frames), ...] = misc_imgs
        mv_data[:len(frames), ...] = mvs
    ref_frame = vit_image
    del frames
    del misc_imgs
    del mvs
    return ref_frame, vit_image, video_data, misc_data, mv_data
def extract_motion_vectors(self,
                           input_video,
                           fps=4,
                           dump=False,
                           verbose=False,
                           visual_mv=False):
    """Re-encode `input_video` to mpeg4 at `fps` and read back frames,
    frame types and motion vectors with mvextractor.

    Args:
        input_video: path of the source video.
        fps: target framerate (raised automatically so the re-encoded
            clip has more than 16 frames).
        dump: create an out-<timestamp> directory tree (legacy debug).
        verbose: log per-frame timing information.
        visual_mv: draw motion vectors into per-frame visualizations.

    Returns:
        Tuple of (frame_types, frames, mvs, mvs_visual).
    """
    if dump:
        # BUG FIX: `datetime` was referenced without being imported
        # anywhere in this module; import it locally for this branch.
        from datetime import datetime
        now = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
        for child in ['frames', 'motion_vectors']:
            os.makedirs(os.path.join(f'out-{now}', child), exist_ok=True)
    # Re-encode next to the source under a random temporary name.
    temp = rand_name()
    tmp_video = os.path.join(
        input_video.split('/')[0], f'{temp}' + input_video.split('/')[-1])
    videocapture = cv2.VideoCapture(input_video)
    frames_num = videocapture.get(cv2.CAP_PROP_FRAME_COUNT)
    fps_video = videocapture.get(cv2.CAP_PROP_FPS)
    # check if enough frames: bump fps until the clip yields > 16 frames
    if frames_num / fps_video * fps > 16:
        fps = max(fps, 1)
    else:
        fps = int(16 / (frames_num / fps_video)) + 1
    ffmpeg_cmd = f'ffmpeg -threads 8 -loglevel error -i {input_video} -filter:v fps={fps} -c:v mpeg4 -f rawvideo {tmp_video}'
    if os.path.exists(tmp_video):
        os.remove(tmp_video)
    # NOTE(review): shell=True with an interpolated path is shell-injection
    # prone if `input_video` is untrusted; consider a list argv instead.
    subprocess.run(args=ffmpeg_cmd, shell=True, timeout=120)
    cap = VideoCap()
    # open the video file
    ret = cap.open(tmp_video)
    if not ret:
        raise RuntimeError(f'Could not open {tmp_video}')
    step = 0
    times = []
    frame_types = []
    frames = []
    mvs = []
    mvs_visual = []
    # continuously read and display video frames and motion vectors
    while True:
        if verbose:
            # BUG FIX: logger.info has no `end` kwarg (print() leftover);
            # it raised TypeError whenever verbose was enabled.
            logger.info('Frame: %d', step)
        tstart = time.perf_counter()
        # read next video frame and corresponding motion vectors
        ret, frame, motion_vectors, frame_type, timestamp = cap.read()
        tend = time.perf_counter()
        telapsed = tend - tstart
        times.append(telapsed)
        # if there is an error reading the frame
        if not ret:
            if verbose:
                logger.warning('No frame read. Stopping.')
            break
        frame_save = np.zeros(frame.copy().shape, dtype=np.uint8)
        if visual_mv:
            # NOTE(review): `draw_motion_vectors` is not defined or
            # imported in this module; this branch raises NameError.
            frame_save = draw_motion_vectors(frame_save, motion_vectors)
        # store motion vectors, frames, etc. in output directory
        # NOTE: per-frame dumping is unconditionally disabled here,
        # overriding the `dump` parameter for the rest of the loop.
        dump = False
        # Center-crop the visualization to a square along the long side.
        if frame.shape[1] >= frame.shape[0]:
            w_half = (frame.shape[1] - frame.shape[0]) // 2
            if dump:
                cv2.imwrite(
                    os.path.join('./mv_visual/', f'frame-{step}.jpg'),
                    frame_save[:, w_half:-w_half])
            mvs_visual.append(frame_save[:, w_half:-w_half])
        else:
            h_half = (frame.shape[0] - frame.shape[1]) // 2
            if dump:
                cv2.imwrite(
                    os.path.join('./mv_visual/', f'frame-{step}.jpg'),
                    frame_save[h_half:-h_half, :])
            mvs_visual.append(frame_save[h_half:-h_half, :])
        # Scatter each motion vector (scaled by source and motion_scale
        # columns of the mvextractor record) into a dense (h, w, 2) map.
        h, w = frame.shape[:2]
        mv = np.zeros((h, w, 2))
        position = motion_vectors[:, 5:7].clip((0, 0), (w - 1, h - 1))
        mv[position[:, 1],
           position[:,
                    0]] = motion_vectors[:, 0:
                                         1] * motion_vectors[:, 7:
                                                             9] / motion_vectors[:,
                                                                                 9:]
        step += 1
        frame_types.append(frame_type)
        frames.append(frame)
        mvs.append(mv)
    if verbose:
        # BUG FIX: the timing value had no %-placeholder and was dropped.
        logger.info('average dt: %s', np.mean(times))
    cap.release()
    if os.path.exists(tmp_video):
        os.remove(tmp_video)
    return frame_types, frames, mvs, mvs_visual

View File

@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import re
from io import BytesIO

View File

@@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import sys
import unittest
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import DownloadMode, Tasks
from modelscope.utils.test_utils import test_level
class VideoDeinterlaceTest(unittest.TestCase):
    """Smoke test for the VideoComposer text-to-video-synthesis pipeline.

    NOTE(review): the class name looks copy-pasted from a deinterlace
    test; kept unchanged so external test selection keeps working.
    """

    def setUp(self) -> None:
        self.task = Tasks.text_to_video_synthesis
        self.model_id = 'buptwq/videocomposer'
        self.model_revision = 'v1.0.1'
        self.dataset_id = 'buptwq/videocomposer-depths-style'
        # Implicit concatenation keeps the prompt identical to the
        # original backslash-continued literal.
        self.text = ('A glittering and translucent fish swimming in a '
                     'small glass bowl with multicolored piece of stone, '
                     'like a glass fish')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_pipeline(self):
        # BUG FIX: MsDataset was used below without ever being imported,
        # so this test crashed with NameError before reaching the model.
        from modelscope.msdatasets import MsDataset
        pipe = pipeline(
            task=Tasks.text_to_video_synthesis,
            model=self.model_id,
            model_revision=self.model_revision)
        ds = MsDataset.load(
            self.dataset_id,
            split='train',
            download_mode=DownloadMode.FORCE_REDOWNLOAD)
        inputs = next(iter(ds))
        inputs.update({'text': self.text})
        _ = pipe(inputs)
# Standard unittest entry point when the file is executed directly.
if __name__ == '__main__':
unittest.main()