mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
VideoComposer: Compositional Video Synthesis with Motion Controllability (#431)
* VideoComposer: Compositional Video Synthesis with Motion Controllability * videocomposer pipeline * pre commit * delete xformers
This commit is contained in:
@@ -218,6 +218,7 @@ class Models(object):
|
||||
mplug_owl = 'mplug-owl'
|
||||
clip_interrogator = 'clip-interrogator'
|
||||
stable_diffusion = 'stable-diffusion'
|
||||
videocomposer = 'videocomposer'
|
||||
text_to_360panorama_image = 'text-to-360panorama-image'
|
||||
|
||||
# science models
|
||||
@@ -525,6 +526,7 @@ class Pipelines(object):
|
||||
multi_modal_similarity = 'multi-modal-similarity'
|
||||
text_to_image_synthesis = 'text-to-image-synthesis'
|
||||
video_multi_modal_embedding = 'video-multi-modal-embedding'
|
||||
videocomposer = 'videocomposer'
|
||||
image_text_retrieval = 'image-text-retrieval'
|
||||
ofa_ocr_recognition = 'ofa-ocr-recognition'
|
||||
ofa_asr = 'ofa-asr'
|
||||
|
||||
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
|
||||
from .efficient_diffusion_tuning import EfficientStableDiffusion
|
||||
from .mplug_owl import MplugOwlForConditionalGeneration
|
||||
from .clip_interrogator import CLIP_Interrogator
|
||||
from .videocomposer import VideoComposer
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
@@ -42,6 +43,7 @@ else:
|
||||
'efficient_diffusion_tuning': ['EfficientStableDiffusion'],
|
||||
'mplug_owl': ['MplugOwlForConditionalGeneration'],
|
||||
'clip_interrogator': ['CLIP_Interrogator'],
|
||||
'videocomposer': ['VideoComposer'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
@@ -19,6 +19,9 @@ from modelscope.models.multi_modal.video_synthesis.diffusion import (
|
||||
from modelscope.models.multi_modal.video_synthesis.unet_sd import UNetSD
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = ['TextToVideoSynthesis']
|
||||
|
||||
|
||||
23
modelscope/models/multi_modal/videocomposer/__init__.py
Normal file
23
modelscope/models/multi_modal/videocomposer/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from .videocomposer_model import VideoComposer
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'videocomposer_model': ['VideoComposer'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from tools.annotator.util import HWC3
|
||||
|
||||
|
||||
class CannyDetector:
|
||||
|
||||
def __call__(self,
|
||||
img,
|
||||
low_threshold=None,
|
||||
high_threshold=None,
|
||||
random_threshold=True):
|
||||
|
||||
# Convert to numpy
|
||||
if isinstance(img, torch.Tensor): # (h, w, c)
|
||||
img = img.cpu().numpy()
|
||||
img_np = cv2.convertScaleAbs((img * 255.))
|
||||
elif isinstance(img, np.ndarray): # (h, w, c)
|
||||
img_np = img # we assume values are in the range from 0 to 255.
|
||||
else:
|
||||
assert False
|
||||
|
||||
# Select the threshold
|
||||
if (low_threshold is None) and (high_threshold is None):
|
||||
median_intensity = np.median(img_np)
|
||||
if random_threshold is False:
|
||||
low_threshold = int(max(0, (1 - 0.33) * median_intensity))
|
||||
high_threshold = int(min(255, (1 + 0.33) * median_intensity))
|
||||
else:
|
||||
random_canny = np.random.uniform(0.1, 0.4)
|
||||
# Might try other values
|
||||
low_threshold = int(
|
||||
max(0, (1 - random_canny) * median_intensity))
|
||||
high_threshold = 2 * low_threshold
|
||||
|
||||
# Detect canny edge
|
||||
canny_edge = cv2.Canny(img_np, low_threshold, high_threshold)
|
||||
|
||||
canny_condition = torch.from_numpy(
|
||||
canny_edge.copy()).unsqueeze(dim=-1).float().cuda() / 255.0
|
||||
return canny_condition
|
||||
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .palette import *
|
||||
@@ -0,0 +1,155 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
r"""Modified from ``https://github.com/sergeyk/rayleigh''.
|
||||
"""
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
from skimage.color import hsv2rgb, lab2rgb, rgb2lab
|
||||
from skimage.io import imsave
|
||||
from sklearn.metrics import euclidean_distances
|
||||
|
||||
__all__ = ['Palette']
|
||||
|
||||
|
||||
def rgb2hex(rgb):
|
||||
return '#%02x%02x%02x' % tuple([int(round(255.0 * u)) for u in rgb])
|
||||
|
||||
|
||||
def hex2rgb(hex):
|
||||
rgb = hex.strip('#')
|
||||
fn = lambda u: round(int(u, 16) / 255.0, 5) # noqa
|
||||
return fn(rgb[:2]), fn(rgb[2:4]), fn(rgb[4:6])
|
||||
|
||||
|
||||
class Palette(object):
|
||||
r"""Create a color palette (codebook) in the form of a 2D grid of colors.
|
||||
Further, the rightmost column has num_hues gradations from black to white.
|
||||
|
||||
Parameters:
|
||||
num_hues: number of colors with full lightness and saturation, in the middle.
|
||||
num_sat: number of rows above middle row that show the same hues with decreasing saturation.
|
||||
"""
|
||||
|
||||
def __init__(self, num_hues=11, num_sat=5, num_light=4):
|
||||
n = num_sat + 2 * num_light
|
||||
|
||||
# hues
|
||||
if num_hues == 8:
|
||||
hues = np.tile(
|
||||
np.array([0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]),
|
||||
(n, 1))
|
||||
elif num_hues == 9:
|
||||
hues = np.tile(
|
||||
np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]),
|
||||
(n, 1))
|
||||
elif num_hues == 10:
|
||||
hues = np.tile(
|
||||
np.array(
|
||||
[0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76,
|
||||
0.87]), (n, 1))
|
||||
elif num_hues == 11:
|
||||
hues = np.tile(
|
||||
np.array([
|
||||
0.0, 0.0833, 0.166, 0.25, 0.333, 0.5, 0.56333, 0.666, 0.73,
|
||||
0.803, 0.916
|
||||
]), (n, 1))
|
||||
else:
|
||||
hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (n, 1))
|
||||
|
||||
# saturations
|
||||
sats = np.hstack((
|
||||
np.linspace(0, 1, num_sat + 2)[1:-1],
|
||||
1,
|
||||
[1] * num_light,
|
||||
[0.4] * # noqa
|
||||
(num_light - 1)))
|
||||
sats = np.tile(np.atleast_2d(sats).T, (1, num_hues))
|
||||
|
||||
# lights
|
||||
lights = np.hstack(
|
||||
([1] * num_sat, 1, np.linspace(1, 0.2, num_light + 2)[1:-1],
|
||||
np.linspace(1, 0.2, num_light + 2)[1:-2]))
|
||||
lights = np.tile(np.atleast_2d(lights).T, (1, num_hues))
|
||||
|
||||
# colors
|
||||
rgb = hsv2rgb(np.dstack([hues, sats, lights]))
|
||||
gray = np.tile(
|
||||
np.linspace(1, 0, n)[:, np.newaxis, np.newaxis], (1, 1, 3))
|
||||
self.thumbnail = np.hstack([rgb, gray])
|
||||
|
||||
# flatten
|
||||
rgb = rgb.T.reshape(3, -1).T
|
||||
gray = gray.T.reshape(3, -1).T
|
||||
self.rgb = np.vstack((rgb, gray))
|
||||
self.lab = rgb2lab(self.rgb[np.newaxis, :, :]).squeeze()
|
||||
self.hex = [rgb2hex(u) for u in self.rgb]
|
||||
self.lab_dists = euclidean_distances(self.lab, squared=True)
|
||||
|
||||
def histogram(self, rgb_img, sigma=20):
|
||||
# compute histogram
|
||||
lab = rgb2lab(rgb_img).reshape((-1, 3))
|
||||
min_ind = np.argmin(
|
||||
euclidean_distances(lab, self.lab, squared=True), axis=1)
|
||||
hist = 1.0 * np.bincount(
|
||||
min_ind, minlength=self.lab.shape[0]) / lab.shape[0]
|
||||
|
||||
# smooth histogram
|
||||
if sigma > 0:
|
||||
weight = np.exp(-self.lab_dists / (2.0 * sigma**2))
|
||||
weight = weight / weight.sum(1)[:, np.newaxis]
|
||||
hist = (weight * hist).sum(1)
|
||||
hist[hist < 1e-5] = 0
|
||||
return hist
|
||||
|
||||
def get_palette_image(self, hist, percentile=90, width=200, height=50):
|
||||
# curate histogram
|
||||
ind = np.argsort(-hist)
|
||||
ind = ind[hist[ind] > np.percentile(hist, percentile)]
|
||||
hist = hist[ind] / hist[ind].sum()
|
||||
|
||||
# draw palette
|
||||
nums = np.array(hist * width, dtype=int)
|
||||
array = np.vstack([
|
||||
np.tile(np.array(u), (v, 1)) for u, v in zip(self.rgb[ind], nums)
|
||||
])
|
||||
array = np.tile(array[np.newaxis, :, :], (height, 1, 1))
|
||||
if array.shape[1] < width:
|
||||
array = np.concatenate(
|
||||
[array, np.zeros((height, width - array.shape[1], 3))], axis=1)
|
||||
return array
|
||||
|
||||
def quantize_image(self, rgb_img):
|
||||
lab = rgb2lab(rgb_img).reshape((-1, 3))
|
||||
min_ind = np.argmin(
|
||||
euclidean_distances(lab, self.lab, squared=True), axis=1)
|
||||
quantized_lab = self.lab[min_ind]
|
||||
img = lab2rgb(quantized_lab.reshape(rgb_img.shape))
|
||||
return img
|
||||
|
||||
def export(self, dirname):
|
||||
if not osp.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
# save thumbnail
|
||||
imsave(osp.join(dirname, 'palette.png'), self.thumbnail)
|
||||
|
||||
# save html
|
||||
with open(osp.join(dirname, 'palette.html'), 'w') as f:
|
||||
html = '''
|
||||
<style>
|
||||
span {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
margin: 2px;
|
||||
padding: 0px;
|
||||
display: inline-block;
|
||||
}
|
||||
</style>
|
||||
'''
|
||||
for row in self.thumbnail:
|
||||
for col in row:
|
||||
html += '<a id="{0}"><span style="background-color: {0}" /></a>\n'.format(
|
||||
rgb2hex(col))
|
||||
html += '<br />\n'
|
||||
f.write(html)
|
||||
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .pidinet import *
|
||||
from .sketch_simplification import *
|
||||
@@ -0,0 +1,940 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
r"""Modified from ``https://github.com/zhuoinoulu/pidinet''.
|
||||
Image augmentation: T.Compose([
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]).
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.multi_modal.videocomposer.utils.utils import \
|
||||
DOWNLOAD_TO_CACHE
|
||||
|
||||
__all__ = [
|
||||
'PiDiNet', 'pidinet_bsd_tiny', 'pidinet_bsd_small', 'pidinet_bsd',
|
||||
'pidinet_nyud', 'pidinet_multicue'
|
||||
]
|
||||
|
||||
CONFIGS = {
|
||||
'baseline': {
|
||||
'layer0': 'cv',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cv',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cv',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cv',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'c-v15': {
|
||||
'layer0': 'cd',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cv',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cv',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cv',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'a-v15': {
|
||||
'layer0': 'ad',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cv',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cv',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cv',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'r-v15': {
|
||||
'layer0': 'rd',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cv',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cv',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cv',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'cvvv4': {
|
||||
'layer0': 'cd',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cd',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cd',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cd',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'avvv4': {
|
||||
'layer0': 'ad',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'ad',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'ad',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'ad',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'rvvv4': {
|
||||
'layer0': 'rd',
|
||||
'layer1': 'cv',
|
||||
'layer2': 'cv',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'rd',
|
||||
'layer5': 'cv',
|
||||
'layer6': 'cv',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'rd',
|
||||
'layer9': 'cv',
|
||||
'layer10': 'cv',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'rd',
|
||||
'layer13': 'cv',
|
||||
'layer14': 'cv',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'cccv4': {
|
||||
'layer0': 'cd',
|
||||
'layer1': 'cd',
|
||||
'layer2': 'cd',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cd',
|
||||
'layer5': 'cd',
|
||||
'layer6': 'cd',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cd',
|
||||
'layer9': 'cd',
|
||||
'layer10': 'cd',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cd',
|
||||
'layer13': 'cd',
|
||||
'layer14': 'cd',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'aaav4': {
|
||||
'layer0': 'ad',
|
||||
'layer1': 'ad',
|
||||
'layer2': 'ad',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'ad',
|
||||
'layer5': 'ad',
|
||||
'layer6': 'ad',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'ad',
|
||||
'layer9': 'ad',
|
||||
'layer10': 'ad',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'ad',
|
||||
'layer13': 'ad',
|
||||
'layer14': 'ad',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'rrrv4': {
|
||||
'layer0': 'rd',
|
||||
'layer1': 'rd',
|
||||
'layer2': 'rd',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'rd',
|
||||
'layer5': 'rd',
|
||||
'layer6': 'rd',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'rd',
|
||||
'layer9': 'rd',
|
||||
'layer10': 'rd',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'rd',
|
||||
'layer13': 'rd',
|
||||
'layer14': 'rd',
|
||||
'layer15': 'cv',
|
||||
},
|
||||
'c16': {
|
||||
'layer0': 'cd',
|
||||
'layer1': 'cd',
|
||||
'layer2': 'cd',
|
||||
'layer3': 'cd',
|
||||
'layer4': 'cd',
|
||||
'layer5': 'cd',
|
||||
'layer6': 'cd',
|
||||
'layer7': 'cd',
|
||||
'layer8': 'cd',
|
||||
'layer9': 'cd',
|
||||
'layer10': 'cd',
|
||||
'layer11': 'cd',
|
||||
'layer12': 'cd',
|
||||
'layer13': 'cd',
|
||||
'layer14': 'cd',
|
||||
'layer15': 'cd',
|
||||
},
|
||||
'a16': {
|
||||
'layer0': 'ad',
|
||||
'layer1': 'ad',
|
||||
'layer2': 'ad',
|
||||
'layer3': 'ad',
|
||||
'layer4': 'ad',
|
||||
'layer5': 'ad',
|
||||
'layer6': 'ad',
|
||||
'layer7': 'ad',
|
||||
'layer8': 'ad',
|
||||
'layer9': 'ad',
|
||||
'layer10': 'ad',
|
||||
'layer11': 'ad',
|
||||
'layer12': 'ad',
|
||||
'layer13': 'ad',
|
||||
'layer14': 'ad',
|
||||
'layer15': 'ad',
|
||||
},
|
||||
'r16': {
|
||||
'layer0': 'rd',
|
||||
'layer1': 'rd',
|
||||
'layer2': 'rd',
|
||||
'layer3': 'rd',
|
||||
'layer4': 'rd',
|
||||
'layer5': 'rd',
|
||||
'layer6': 'rd',
|
||||
'layer7': 'rd',
|
||||
'layer8': 'rd',
|
||||
'layer9': 'rd',
|
||||
'layer10': 'rd',
|
||||
'layer11': 'rd',
|
||||
'layer12': 'rd',
|
||||
'layer13': 'rd',
|
||||
'layer14': 'rd',
|
||||
'layer15': 'rd',
|
||||
},
|
||||
'carv4': {
|
||||
'layer0': 'cd',
|
||||
'layer1': 'ad',
|
||||
'layer2': 'rd',
|
||||
'layer3': 'cv',
|
||||
'layer4': 'cd',
|
||||
'layer5': 'ad',
|
||||
'layer6': 'rd',
|
||||
'layer7': 'cv',
|
||||
'layer8': 'cd',
|
||||
'layer9': 'ad',
|
||||
'layer10': 'rd',
|
||||
'layer11': 'cv',
|
||||
'layer12': 'cd',
|
||||
'layer13': 'ad',
|
||||
'layer14': 'rd',
|
||||
'layer15': 'cv'
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_conv_func(op_type):
|
||||
assert op_type in ['cv', 'cd', 'ad',
|
||||
'rd'], 'unknown op type: %s' % str(op_type)
|
||||
if op_type == 'cv':
|
||||
return F.conv2d
|
||||
if op_type == 'cd':
|
||||
|
||||
def func(x,
|
||||
weights,
|
||||
bias=None,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1):
|
||||
assert dilation in [1,
|
||||
2], 'dilation for cd_conv should be in 1 or 2'
|
||||
assert weights.size(2) == 3 and weights.size(
|
||||
3) == 3, 'kernel size for cd_conv should be 3x3'
|
||||
assert padding == dilation, 'padding for cd_conv set wrong'
|
||||
|
||||
weights_c = weights.sum(dim=[2, 3], keepdim=True)
|
||||
yc = F.conv2d(
|
||||
x, weights_c, stride=stride, padding=0, groups=groups)
|
||||
y = F.conv2d(
|
||||
x,
|
||||
weights,
|
||||
bias,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups)
|
||||
return y - yc
|
||||
|
||||
return func
|
||||
elif op_type == 'ad':
|
||||
|
||||
def func(x,
|
||||
weights,
|
||||
bias=None,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1):
|
||||
assert dilation in [1,
|
||||
2], 'dilation for ad_conv should be in 1 or 2'
|
||||
assert weights.size(2) == 3 and weights.size(
|
||||
3) == 3, 'kernel size for ad_conv should be 3x3'
|
||||
assert padding == dilation, 'padding for ad_conv set wrong'
|
||||
|
||||
shape = weights.shape
|
||||
weights = weights.view(shape[0], shape[1], -1)
|
||||
weights_conv = (weights
|
||||
- weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(
|
||||
shape) # clock-wise
|
||||
y = F.conv2d(
|
||||
x,
|
||||
weights_conv,
|
||||
bias,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups)
|
||||
return y
|
||||
|
||||
return func
|
||||
elif op_type == 'rd':
|
||||
|
||||
def func(x,
|
||||
weights,
|
||||
bias=None,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1):
|
||||
assert dilation in [1,
|
||||
2], 'dilation for rd_conv should be in 1 or 2'
|
||||
assert weights.size(2) == 3 and weights.size(
|
||||
3) == 3, 'kernel size for rd_conv should be 3x3'
|
||||
padding = 2 * dilation
|
||||
|
||||
shape = weights.shape
|
||||
if weights.is_cuda:
|
||||
buffer = torch.cuda.FloatTensor(shape[0], shape[1],
|
||||
5 * 5).fill_(0)
|
||||
else:
|
||||
buffer = torch.zeros(shape[0], shape[1], 5 * 5)
|
||||
weights = weights.view(shape[0], shape[1], -1)
|
||||
buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
|
||||
buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
|
||||
buffer[:, :, 12] = 0
|
||||
buffer = buffer.view(shape[0], shape[1], 5, 5)
|
||||
y = F.conv2d(
|
||||
x,
|
||||
buffer,
|
||||
bias,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups)
|
||||
return y
|
||||
|
||||
return func
|
||||
else:
|
||||
print('impossible to be here unless you force that', flush=True)
|
||||
return None
|
||||
|
||||
|
||||
def config_model(model):
|
||||
model_options = list(CONFIGS.keys())
|
||||
assert model in model_options, \
|
||||
'unrecognized model, please choose from %s' % str(model_options)
|
||||
|
||||
pdcs = []
|
||||
for i in range(16):
|
||||
layer_name = 'layer%d' % i
|
||||
op = CONFIGS[model][layer_name]
|
||||
pdcs.append(create_conv_func(op))
|
||||
return pdcs
|
||||
|
||||
|
||||
def config_model_converted(model):
|
||||
model_options = list(CONFIGS.keys())
|
||||
assert model in model_options, \
|
||||
'unrecognized model, please choose from %s' % str(model_options)
|
||||
|
||||
pdcs = []
|
||||
for i in range(16):
|
||||
layer_name = 'layer%d' % i
|
||||
op = CONFIGS[model][layer_name]
|
||||
pdcs.append(op)
|
||||
return pdcs
|
||||
|
||||
|
||||
def convert_pdc(op, weight):
|
||||
if op == 'cv':
|
||||
return weight
|
||||
elif op == 'cd':
|
||||
shape = weight.shape
|
||||
weight_c = weight.sum(dim=[2, 3])
|
||||
weight = weight.view(shape[0], shape[1], -1)
|
||||
weight[:, :, 4] = weight[:, :, 4] - weight_c
|
||||
weight = weight.view(shape)
|
||||
return weight
|
||||
elif op == 'ad':
|
||||
shape = weight.shape
|
||||
weight = weight.view(shape[0], shape[1], -1)
|
||||
weight_conv = (weight
|
||||
- weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape)
|
||||
return weight_conv
|
||||
elif op == 'rd':
|
||||
shape = weight.shape
|
||||
buffer = torch.zeros(shape[0], shape[1], 5 * 5, device=weight.device)
|
||||
weight = weight.view(shape[0], shape[1], -1)
|
||||
buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weight[:, :, 1:]
|
||||
buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weight[:, :, 1:]
|
||||
buffer = buffer.view(shape[0], shape[1], 5, 5)
|
||||
return buffer
|
||||
raise ValueError('wrong op {}'.format(str(op)))
|
||||
|
||||
|
||||
def convert_pidinet(state_dict, config):
|
||||
pdcs = config_model_converted(config)
|
||||
new_dict = {}
|
||||
for pname, p in state_dict.items():
|
||||
if 'init_block.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[0], p)
|
||||
elif 'block1_1.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[1], p)
|
||||
elif 'block1_2.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[2], p)
|
||||
elif 'block1_3.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[3], p)
|
||||
elif 'block2_1.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[4], p)
|
||||
elif 'block2_2.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[5], p)
|
||||
elif 'block2_3.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[6], p)
|
||||
elif 'block2_4.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[7], p)
|
||||
elif 'block3_1.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[8], p)
|
||||
elif 'block3_2.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[9], p)
|
||||
elif 'block3_3.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[10], p)
|
||||
elif 'block3_4.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[11], p)
|
||||
elif 'block4_1.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[12], p)
|
||||
elif 'block4_2.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[13], p)
|
||||
elif 'block4_3.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[14], p)
|
||||
elif 'block4_4.conv1.weight' in pname:
|
||||
new_dict[pname] = convert_pdc(pdcs[15], p)
|
||||
else:
|
||||
new_dict[pname] = p
|
||||
return new_dict
|
||||
|
||||
|
||||
class Conv2d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
pdc,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=False):
|
||||
super(Conv2d, self).__init__()
|
||||
if in_channels % groups != 0:
|
||||
raise ValueError('in_channels must be divisible by groups')
|
||||
if out_channels % groups != 0:
|
||||
raise ValueError('out_channels must be divisible by groups')
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.groups = groups
|
||||
self.weight = nn.Parameter(
|
||||
torch.Tensor(out_channels, in_channels // groups, kernel_size,
|
||||
kernel_size))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.reset_parameters()
|
||||
self.pdc = pdc
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
if self.bias is not None:
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(self.bias, -bound, bound)
|
||||
|
||||
def forward(self, input):
|
||||
return self.pdc(input, self.weight, self.bias, self.stride,
|
||||
self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class CSAM(nn.Module):
|
||||
r"""
|
||||
Compact Spatial Attention Module
|
||||
"""
|
||||
|
||||
def __init__(self, channels):
|
||||
super(CSAM, self).__init__()
|
||||
|
||||
mid_channels = 4
|
||||
self.relu1 = nn.ReLU()
|
||||
self.conv1 = nn.Conv2d(
|
||||
channels, mid_channels, kernel_size=1, padding=0)
|
||||
self.conv2 = nn.Conv2d(
|
||||
mid_channels, 1, kernel_size=3, padding=1, bias=False)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
nn.init.constant_(self.conv1.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.relu1(x)
|
||||
y = self.conv1(y)
|
||||
y = self.conv2(y)
|
||||
y = self.sigmoid(y)
|
||||
|
||||
return x * y
|
||||
|
||||
|
||||
class CDCM(nn.Module):
|
||||
r"""
|
||||
Compact Dilation Convolution based Module
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(CDCM, self).__init__()
|
||||
|
||||
self.relu1 = nn.ReLU()
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=1, padding=0)
|
||||
self.conv2_1 = nn.Conv2d(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
dilation=5,
|
||||
padding=5,
|
||||
bias=False)
|
||||
self.conv2_2 = nn.Conv2d(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
dilation=7,
|
||||
padding=7,
|
||||
bias=False)
|
||||
self.conv2_3 = nn.Conv2d(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
dilation=9,
|
||||
padding=9,
|
||||
bias=False)
|
||||
self.conv2_4 = nn.Conv2d(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
dilation=11,
|
||||
padding=11,
|
||||
bias=False)
|
||||
nn.init.constant_(self.conv1.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu1(x)
|
||||
x = self.conv1(x)
|
||||
x1 = self.conv2_1(x)
|
||||
x2 = self.conv2_2(x)
|
||||
x3 = self.conv2_3(x)
|
||||
x4 = self.conv2_4(x)
|
||||
return x1 + x2 + x3 + x4
|
||||
|
||||
|
||||
class MapReduce(nn.Module):
|
||||
r"""
|
||||
Reduce feature maps into a single edge map
|
||||
"""
|
||||
|
||||
def __init__(self, channels):
|
||||
super(MapReduce, self).__init__()
|
||||
self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
|
||||
nn.init.constant_(self.conv.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class PDCBlock(nn.Module):
|
||||
|
||||
def __init__(self, pdc, inplane, ouplane, stride=1):
|
||||
super(PDCBlock, self).__init__()
|
||||
self.stride = stride
|
||||
|
||||
self.stride = stride
|
||||
if self.stride > 1:
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.shortcut = nn.Conv2d(
|
||||
inplane, ouplane, kernel_size=1, padding=0)
|
||||
self.conv1 = Conv2d(
|
||||
pdc,
|
||||
inplane,
|
||||
inplane,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
groups=inplane,
|
||||
bias=False)
|
||||
self.relu2 = nn.ReLU()
|
||||
self.conv2 = nn.Conv2d(
|
||||
inplane, ouplane, kernel_size=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride > 1:
|
||||
x = self.pool(x)
|
||||
y = self.conv1(x)
|
||||
y = self.relu2(y)
|
||||
y = self.conv2(y)
|
||||
if self.stride > 1:
|
||||
x = self.shortcut(x)
|
||||
y = y + x
|
||||
return y
|
||||
|
||||
|
||||
class PDCBlock_converted(nn.Module):
|
||||
r"""
|
||||
CPDC, APDC can be converted to vanilla 3x3 convolution
|
||||
RPDC can be converted to vanilla 5x5 convolution
|
||||
"""
|
||||
|
||||
def __init__(self, pdc, inplane, ouplane, stride=1):
|
||||
super(PDCBlock_converted, self).__init__()
|
||||
self.stride = stride
|
||||
|
||||
if self.stride > 1:
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.shortcut = nn.Conv2d(
|
||||
inplane, ouplane, kernel_size=1, padding=0)
|
||||
if pdc == 'rd':
|
||||
self.conv1 = nn.Conv2d(
|
||||
inplane,
|
||||
inplane,
|
||||
kernel_size=5,
|
||||
padding=2,
|
||||
groups=inplane,
|
||||
bias=False)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(
|
||||
inplane,
|
||||
inplane,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
groups=inplane,
|
||||
bias=False)
|
||||
self.relu2 = nn.ReLU()
|
||||
self.conv2 = nn.Conv2d(
|
||||
inplane, ouplane, kernel_size=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride > 1:
|
||||
x = self.pool(x)
|
||||
y = self.conv1(x)
|
||||
y = self.relu2(y)
|
||||
y = self.conv2(y)
|
||||
if self.stride > 1:
|
||||
x = self.shortcut(x)
|
||||
y = y + x
|
||||
return y
|
||||
|
||||
|
||||
class PiDiNet(nn.Module):
|
||||
|
||||
def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
|
||||
super(PiDiNet, self).__init__()
|
||||
self.sa = sa
|
||||
if dil is not None:
|
||||
assert isinstance(dil, int), 'dil should be an int'
|
||||
self.dil = dil
|
||||
|
||||
self.fuseplanes = []
|
||||
|
||||
self.inplane = inplane
|
||||
if convert:
|
||||
if pdcs[0] == 'rd':
|
||||
init_kernel_size = 5
|
||||
init_padding = 2
|
||||
else:
|
||||
init_kernel_size = 3
|
||||
init_padding = 1
|
||||
self.init_block = nn.Conv2d(
|
||||
3,
|
||||
self.inplane,
|
||||
kernel_size=init_kernel_size,
|
||||
padding=init_padding,
|
||||
bias=False)
|
||||
block_class = PDCBlock_converted
|
||||
else:
|
||||
self.init_block = Conv2d(
|
||||
pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
|
||||
block_class = PDCBlock
|
||||
|
||||
self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
|
||||
self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
|
||||
self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
|
||||
self.fuseplanes.append(self.inplane) # C
|
||||
|
||||
inplane = self.inplane
|
||||
self.inplane = self.inplane * 2
|
||||
self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
|
||||
self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
|
||||
self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
|
||||
self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
|
||||
self.fuseplanes.append(self.inplane) # 2C
|
||||
|
||||
inplane = self.inplane
|
||||
self.inplane = self.inplane * 2
|
||||
self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
|
||||
self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
|
||||
self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
|
||||
self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
|
||||
self.fuseplanes.append(self.inplane) # 4C
|
||||
|
||||
self.block4_1 = block_class(
|
||||
pdcs[12], self.inplane, self.inplane, stride=2)
|
||||
self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
|
||||
self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
|
||||
self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
|
||||
self.fuseplanes.append(self.inplane) # 4C
|
||||
|
||||
self.conv_reduces = nn.ModuleList()
|
||||
if self.sa and self.dil is not None:
|
||||
self.attentions = nn.ModuleList()
|
||||
self.dilations = nn.ModuleList()
|
||||
for i in range(4):
|
||||
self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
|
||||
self.attentions.append(CSAM(self.dil))
|
||||
self.conv_reduces.append(MapReduce(self.dil))
|
||||
elif self.sa:
|
||||
self.attentions = nn.ModuleList()
|
||||
for i in range(4):
|
||||
self.attentions.append(CSAM(self.fuseplanes[i]))
|
||||
self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
|
||||
elif self.dil is not None:
|
||||
self.dilations = nn.ModuleList()
|
||||
for i in range(4):
|
||||
self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
|
||||
self.conv_reduces.append(MapReduce(self.dil))
|
||||
else:
|
||||
for i in range(4):
|
||||
self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
|
||||
|
||||
self.classifier = nn.Conv2d(4, 1, kernel_size=1)
|
||||
nn.init.constant_(self.classifier.weight, 0.25)
|
||||
nn.init.constant_(self.classifier.bias, 0)
|
||||
|
||||
def get_weights(self):
|
||||
conv_weights = []
|
||||
bn_weights = []
|
||||
relu_weights = []
|
||||
for pname, p in self.named_parameters():
|
||||
if 'bn' in pname:
|
||||
bn_weights.append(p)
|
||||
elif 'relu' in pname:
|
||||
relu_weights.append(p)
|
||||
else:
|
||||
conv_weights.append(p)
|
||||
|
||||
return conv_weights, bn_weights, relu_weights
|
||||
|
||||
def forward(self, x):
|
||||
H, W = x.size()[2:]
|
||||
|
||||
x = self.init_block(x)
|
||||
|
||||
x1 = self.block1_1(x)
|
||||
x1 = self.block1_2(x1)
|
||||
x1 = self.block1_3(x1)
|
||||
|
||||
x2 = self.block2_1(x1)
|
||||
x2 = self.block2_2(x2)
|
||||
x2 = self.block2_3(x2)
|
||||
x2 = self.block2_4(x2)
|
||||
|
||||
x3 = self.block3_1(x2)
|
||||
x3 = self.block3_2(x3)
|
||||
x3 = self.block3_3(x3)
|
||||
x3 = self.block3_4(x3)
|
||||
|
||||
x4 = self.block4_1(x3)
|
||||
x4 = self.block4_2(x4)
|
||||
x4 = self.block4_3(x4)
|
||||
x4 = self.block4_4(x4)
|
||||
|
||||
x_fuses = []
|
||||
if self.sa and self.dil is not None:
|
||||
for i, xi in enumerate([x1, x2, x3, x4]):
|
||||
x_fuses.append(self.attentions[i](self.dilations[i](xi)))
|
||||
elif self.sa:
|
||||
for i, xi in enumerate([x1, x2, x3, x4]):
|
||||
x_fuses.append(self.attentions[i](xi))
|
||||
elif self.dil is not None:
|
||||
for i, xi in enumerate([x1, x2, x3, x4]):
|
||||
x_fuses.append(self.dilations[i](xi))
|
||||
else:
|
||||
x_fuses = [x1, x2, x3, x4]
|
||||
|
||||
e1 = self.conv_reduces[0](x_fuses[0])
|
||||
e1 = F.interpolate(e1, (H, W), mode='bilinear', align_corners=False)
|
||||
|
||||
e2 = self.conv_reduces[1](x_fuses[1])
|
||||
e2 = F.interpolate(e2, (H, W), mode='bilinear', align_corners=False)
|
||||
|
||||
e3 = self.conv_reduces[2](x_fuses[2])
|
||||
e3 = F.interpolate(e3, (H, W), mode='bilinear', align_corners=False)
|
||||
|
||||
e4 = self.conv_reduces[3](x_fuses[3])
|
||||
e4 = F.interpolate(e4, (H, W), mode='bilinear', align_corners=False)
|
||||
|
||||
outputs = [e1, e2, e3, e4]
|
||||
output = self.classifier(torch.cat(outputs, dim=1))
|
||||
|
||||
outputs.append(output)
|
||||
outputs = [torch.sigmoid(r) for r in outputs]
|
||||
return outputs[-1]
|
||||
|
||||
|
||||
def pidinet_bsd_tiny(pretrained=False, vanilla_cnn=True):
|
||||
pdcs = config_model_converted('carv4') if vanilla_cnn else config_model(
|
||||
'carv4')
|
||||
model = PiDiNet(20, pdcs, dil=8, sa=True, convert=vanilla_cnn)
|
||||
if pretrained:
|
||||
state = torch.load(
|
||||
DOWNLOAD_TO_CACHE('models/pidinet/table5_pidinet-tiny.pth'),
|
||||
map_location='cpu')['state_dict']
|
||||
if vanilla_cnn:
|
||||
state = convert_pidinet(state, 'carv4')
|
||||
state = {
|
||||
k[len('module.'):] if k.startswith('module.') else k: v
|
||||
for k, v in state.items()
|
||||
}
|
||||
model.load_state_dict(state)
|
||||
return model
|
||||
|
||||
|
||||
def pidinet_bsd_small(pretrained=False, vanilla_cnn=True):
|
||||
pdcs = config_model_converted('carv4') if vanilla_cnn else config_model(
|
||||
'carv4')
|
||||
model = PiDiNet(30, pdcs, dil=12, sa=True, convert=vanilla_cnn)
|
||||
if pretrained:
|
||||
state = torch.load(
|
||||
DOWNLOAD_TO_CACHE('models/pidinet/table5_pidinet-small.pth'),
|
||||
map_location='cpu')['state_dict']
|
||||
if vanilla_cnn:
|
||||
state = convert_pidinet(state, 'carv4')
|
||||
state = {
|
||||
k[len('module.'):] if k.startswith('module.') else k: v
|
||||
for k, v in state.items()
|
||||
}
|
||||
model.load_state_dict(state)
|
||||
return model
|
||||
|
||||
|
||||
def pidinet_bsd(model_dir, pretrained=False, vanilla_cnn=True):
|
||||
pdcs = config_model_converted('carv4') if vanilla_cnn else config_model(
|
||||
'carv4')
|
||||
model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn)
|
||||
if pretrained:
|
||||
state = torch.load(
|
||||
os.path.join(model_dir, 'table5_pidinet.pth'),
|
||||
map_location='cpu')['state_dict']
|
||||
if vanilla_cnn:
|
||||
state = convert_pidinet(state, 'carv4')
|
||||
state = {
|
||||
k[len('module.'):] if k.startswith('module.') else k: v
|
||||
for k, v in state.items()
|
||||
}
|
||||
model.load_state_dict(state)
|
||||
return model
|
||||
|
||||
|
||||
def pidinet_nyud(pretrained=False, vanilla_cnn=True):
|
||||
pdcs = config_model_converted('carv4') if vanilla_cnn else config_model(
|
||||
'carv4')
|
||||
model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn)
|
||||
if pretrained:
|
||||
state = torch.load(
|
||||
DOWNLOAD_TO_CACHE('models/pidinet/table6_pidinet.pth'),
|
||||
map_location='cpu')['state_dict']
|
||||
if vanilla_cnn:
|
||||
state = convert_pidinet(state, 'carv4')
|
||||
state = {
|
||||
k[len('module.'):] if k.startswith('module.') else k: v
|
||||
for k, v in state.items()
|
||||
}
|
||||
model.load_state_dict(state)
|
||||
return model
|
||||
|
||||
|
||||
def pidinet_multicue(pretrained=False, vanilla_cnn=True):
|
||||
pdcs = config_model_converted('carv4') if vanilla_cnn else config_model(
|
||||
'carv4')
|
||||
model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn)
|
||||
if pretrained:
|
||||
state = torch.load(
|
||||
DOWNLOAD_TO_CACHE('models/pidinet/table7_pidinet.pth'),
|
||||
map_location='cpu')['state_dict']
|
||||
if vanilla_cnn:
|
||||
state = convert_pidinet(state, 'carv4')
|
||||
state = {
|
||||
k[len('module.'):] if k.startswith('module.') else k: v
|
||||
for k, v in state.items()
|
||||
}
|
||||
model.load_state_dict(state)
|
||||
return model
|
||||
@@ -0,0 +1,110 @@
|
||||
r"""PyTorch re-implementation adapted from the Lua code in ``https://github.com/bobbens/sketch_simplification''.
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.multi_modal.videocomposer.utils.utils import \
|
||||
DOWNLOAD_TO_CACHE
|
||||
|
||||
__all__ = [
|
||||
'SketchSimplification', 'sketch_simplification_gan',
|
||||
'sketch_simplification_mse', 'sketch_to_pencil_v1', 'sketch_to_pencil_v2'
|
||||
]
|
||||
|
||||
|
||||
class SketchSimplification(nn.Module):
|
||||
r"""NOTE:
|
||||
1. Input image should has only one gray channel.
|
||||
2. Input image size should be divisible by 8.
|
||||
3. Sketch in the input/output image is in dark color while background in light color.
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std):
|
||||
assert isinstance(mean, float) and isinstance(std, float)
|
||||
super(SketchSimplification, self).__init__()
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
# layers
|
||||
self.layers = nn.Sequential(
|
||||
nn.Conv2d(1, 48, 5, 2, 2), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(48, 128, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 128, 3, 2, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(256, 256, 3, 2, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(256, 512, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(512, 1024, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(1024, 1024, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(1024, 1024, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(1024, 1024, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(1024, 512, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(512, 256, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.ConvTranspose2d(256, 256, 4, 2, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(256, 128, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.ConvTranspose2d(128, 128, 4, 2, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 48, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.ConvTranspose2d(48, 48, 4, 2, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(48, 24, 3, 1, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(24, 1, 3, 1, 1), nn.Sigmoid())
|
||||
|
||||
def forward(self, x):
|
||||
r"""x: [B, 1, H, W] within range [0, 1]. Sketch pixels in dark color.
|
||||
"""
|
||||
x = (x - self.mean) / self.std
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
def sketch_simplification_gan(model_dir, pretrained=False):
|
||||
model = SketchSimplification(
|
||||
mean=0.9664114577640158, std=0.0858381272736797)
|
||||
if pretrained:
|
||||
model.load_state_dict(
|
||||
torch.load(
|
||||
os.path.join(model_dir, 'sketch_simplification_gan.pth'),
|
||||
map_location='cpu'))
|
||||
return model
|
||||
|
||||
|
||||
def sketch_simplification_mse(pretrained=False):
|
||||
model = SketchSimplification(
|
||||
mean=0.9664423107454593, std=0.08583666033640507)
|
||||
if pretrained:
|
||||
model.load_state_dict(
|
||||
torch.load(
|
||||
DOWNLOAD_TO_CACHE(
|
||||
'models/sketch_simplification/sketch_simplification_mse.pth'
|
||||
),
|
||||
map_location='cpu'))
|
||||
return model
|
||||
|
||||
|
||||
def sketch_to_pencil_v1(pretrained=False):
|
||||
model = SketchSimplification(
|
||||
mean=0.9817833515894078, std=0.0925009022585048)
|
||||
if pretrained:
|
||||
model.load_state_dict(
|
||||
torch.load(
|
||||
DOWNLOAD_TO_CACHE(
|
||||
'models/sketch_simplification/sketch_to_pencil_v1.pth'),
|
||||
map_location='cpu'))
|
||||
return model
|
||||
|
||||
|
||||
def sketch_to_pencil_v2(pretrained=False):
|
||||
model = SketchSimplification(
|
||||
mean=0.9851298627337799, std=0.07418377454883571)
|
||||
if pretrained:
|
||||
model.load_state_dict(
|
||||
torch.load(
|
||||
DOWNLOAD_TO_CACHE(
|
||||
'models/sketch_simplification/sketch_to_pencil_v2.pth'),
|
||||
map_location='cpu'))
|
||||
return model
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
|
||||
|
||||
|
||||
def HWC3(x):
|
||||
assert x.dtype == np.uint8
|
||||
if x.ndim == 2:
|
||||
x = x[:, :, None]
|
||||
assert x.ndim == 3
|
||||
H, W, C = x.shape
|
||||
assert C == 1 or C == 3 or C == 4
|
||||
if C == 3:
|
||||
return x
|
||||
if C == 1:
|
||||
return np.concatenate([x, x, x], axis=2)
|
||||
if C == 4:
|
||||
color = x[:, :, 0:3].astype(np.float32)
|
||||
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
||||
y = color * alpha + 255.0 * (1.0 - alpha)
|
||||
y = y.clip(0, 255).astype(np.uint8)
|
||||
return y
|
||||
|
||||
|
||||
def resize_image(input_image, resolution):
|
||||
H, W, C = input_image.shape
|
||||
H = float(H)
|
||||
W = float(W)
|
||||
k = float(resolution) / min(H, W)
|
||||
H *= k
|
||||
W *= k
|
||||
H = int(np.round(H / 64.0)) * 64
|
||||
W = int(np.round(W / 64.0)) * 64
|
||||
img = cv2.resize(
|
||||
input_image, (W, H),
|
||||
interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
||||
return img
|
||||
650
modelscope/models/multi_modal/videocomposer/autoencoder.py
Normal file
650
modelscope/models/multi_modal/videocomposer/autoencoder.py
Normal file
@@ -0,0 +1,650 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['AutoencoderKL']
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(
|
||||
self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(
|
||||
self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1, 2, 3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar
|
||||
+ torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(
|
||||
v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class AttnBlock(nn.Module): # noqa
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(
|
||||
v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(
|
||||
x, scale_factor=2.0, mode='nearest')
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module): # noqa
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
use_linear_attn=False,
|
||||
attn_type='vanilla',
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
use_linear_attn=False,
|
||||
attn_type='vanilla',
|
||||
**ignorekwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2**(self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print('Working with z of shape {} = {} dimensions.'.format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
class AutoencoderKL(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key='image',
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
ema_decay=None,
|
||||
learn_logvar=False):
|
||||
super().__init__()
|
||||
self.learn_logvar = learn_logvar
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
assert ddconfig['double_z']
|
||||
self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
|
||||
2 * embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
|
||||
ddconfig['z_channels'], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
self.use_ema = ema_decay is not None
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
for key in keys:
|
||||
print(key, sd[key].shape)
|
||||
import collections
|
||||
sd_new = collections.OrderedDict()
|
||||
for k in keys:
|
||||
if k.find('first_stage_model') >= 0:
|
||||
k_new = k.split('first_stage_model.')[-1]
|
||||
sd_new[k_new] = sd[k]
|
||||
self.load_state_dict(sd_new, strict=True)
|
||||
print(f'Restored from {path}')
|
||||
|
||||
def init_from_ckpt2(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
|
||||
first_stage_model
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print('Deleting key {} from state_dict.'.format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f'Restored from {path}')
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1,
|
||||
2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log['reconstructions'] = xrec
|
||||
if log_ema or self.use_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, posterior_ema = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec_ema.shape[1] > 3
|
||||
xrec_ema = self.to_rgb(xrec_ema)
|
||||
log['samples_ema'] = self.decode(
|
||||
torch.randn_like(posterior_ema.sample()))
|
||||
log['reconstructions_ema'] = xrec_ema
|
||||
log['inputs'] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == 'segmentation'
|
||||
if not hasattr(self, 'colorize'):
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
143
modelscope/models/multi_modal/videocomposer/clip.py
Normal file
143
modelscope/models/multi_modal/videocomposer/clip.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import numpy as np
|
||||
import open_clip
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(nn.Module):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = ['last', 'penultimate']
|
||||
|
||||
def __init__(self,
|
||||
arch='ViT-H-14',
|
||||
pretrained='laion2b_s32b_b79k',
|
||||
device='cuda',
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer='last'):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
model, _, _ = open_clip.create_model_and_transforms(
|
||||
arch, device=torch.device('cpu'), pretrained=pretrained)
|
||||
del model.visual
|
||||
self.model = model
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == 'last':
|
||||
self.layer_idx = 0
|
||||
elif self.layer == 'penultimate':
|
||||
self.layer_idx = 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
tokens = open_clip.tokenize(text)
|
||||
z = self.encode_with_transformer(tokens.to(self.device))
|
||||
return z
|
||||
|
||||
def encode_with_transformer(self, text):
|
||||
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
||||
x = x + self.model.positional_embedding
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.model.ln_final(x)
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
|
||||
):
|
||||
x = checkpoint(r, x, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenOpenCLIPVisualEmbedder(nn.Module):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = ['last', 'penultimate']
|
||||
|
||||
def __init__(self,
|
||||
arch='ViT-H-14',
|
||||
pretrained='laion2b_s32b_b79k',
|
||||
device='cuda',
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer='last',
|
||||
input_shape=(224, 224, 3)):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
arch, device=torch.device('cpu'), pretrained=pretrained)
|
||||
del model.transformer
|
||||
self.model = model
|
||||
data_white = np.ones(input_shape, dtype=np.uint8) * 255
|
||||
self.black_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0)
|
||||
self.preprocess = preprocess
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length # 77
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer # 'penultimate'
|
||||
if self.layer == 'last':
|
||||
self.layer_idx = 0
|
||||
elif self.layer == 'penultimate':
|
||||
self.layer_idx = 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, image):
|
||||
# tokens = open_clip.tokenize(text)
|
||||
z = self.model.encode_image(image.to(self.device))
|
||||
return z
|
||||
|
||||
def encode_with_transformer(self, text):
|
||||
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
||||
x = x + self.model.positional_embedding
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.model.ln_final(x)
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
|
||||
):
|
||||
x = checkpoint(r, x, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
156
modelscope/models/multi_modal/videocomposer/config.py
Normal file
156
modelscope/models/multi_modal/videocomposer/config.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
from easydict import EasyDict
|
||||
|
||||
cfg = EasyDict(__name__='Config: VideoComposer')
|
||||
|
||||
pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
|
||||
gpus_per_machine = torch.cuda.device_count()
|
||||
world_size = pmi_world_size * gpus_per_machine
|
||||
|
||||
cfg.video_compositions = [
|
||||
'text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image',
|
||||
'single_sketch'
|
||||
]
|
||||
|
||||
# dataset
|
||||
cfg.root_dir = 'webvid10m/'
|
||||
|
||||
cfg.alpha = 0.7
|
||||
|
||||
cfg.misc_size = 384
|
||||
cfg.depth_std = 20.0
|
||||
cfg.depth_clamp = 10.0
|
||||
cfg.hist_sigma = 10.0
|
||||
|
||||
cfg.use_image_dataset = False
|
||||
cfg.alpha_img = 0.7
|
||||
|
||||
cfg.resolution = 256
|
||||
cfg.mean = [0.5, 0.5, 0.5]
|
||||
cfg.std = [0.5, 0.5, 0.5]
|
||||
|
||||
# sketch
|
||||
cfg.sketch_mean = [0.485, 0.456, 0.406]
|
||||
cfg.sketch_std = [0.229, 0.224, 0.225]
|
||||
|
||||
# dataloader
|
||||
cfg.max_words = 1000
|
||||
|
||||
cfg.frame_lens = [
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
]
|
||||
cfg.feature_framerates = [
|
||||
4,
|
||||
]
|
||||
cfg.feature_framerate = 4
|
||||
cfg.batch_sizes = {
|
||||
str(1): 1,
|
||||
str(4): 1,
|
||||
str(8): 1,
|
||||
str(16): 1,
|
||||
}
|
||||
|
||||
cfg.chunk_size = 64
|
||||
cfg.num_workers = 8
|
||||
cfg.prefetch_factor = 2
|
||||
cfg.seed = 8888
|
||||
|
||||
# diffusion
|
||||
cfg.num_timesteps = 1000
|
||||
cfg.mean_type = 'eps'
|
||||
cfg.var_type = 'fixed_small'
|
||||
cfg.loss_type = 'mse'
|
||||
cfg.ddim_timesteps = 50
|
||||
cfg.ddim_eta = 0.0
|
||||
cfg.clamp = 1.0
|
||||
cfg.share_noise = False
|
||||
cfg.use_div_loss = False
|
||||
|
||||
# classifier-free guidance
|
||||
cfg.p_zero = 0.9
|
||||
cfg.guide_scale = 6.0
|
||||
|
||||
# stabel diffusion
|
||||
cfg.sd_checkpoint = 'v2-1_512-ema-pruned.ckpt'
|
||||
|
||||
# clip vision encoder
|
||||
cfg.vit_image_size = 336
|
||||
cfg.vit_patch_size = 14
|
||||
cfg.vit_dim = 1024
|
||||
cfg.vit_out_dim = 768
|
||||
cfg.vit_heads = 16
|
||||
cfg.vit_layers = 24
|
||||
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
|
||||
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
|
||||
cfg.clip_checkpoint = 'open_clip_pytorch_model.bin'
|
||||
cfg.mvs_visual = False
|
||||
|
||||
# unet
|
||||
cfg.unet_in_dim = 4
|
||||
cfg.unet_concat_dim = 8
|
||||
cfg.unet_y_dim = cfg.vit_out_dim
|
||||
cfg.unet_context_dim = 1024
|
||||
cfg.unet_out_dim = 8 if cfg.var_type.startswith('learned') else 4
|
||||
cfg.unet_dim = 320
|
||||
cfg.unet_dim_mult = [1, 2, 4, 4]
|
||||
cfg.unet_res_blocks = 2
|
||||
cfg.unet_num_heads = 8
|
||||
cfg.unet_head_dim = 64
|
||||
cfg.unet_attn_scales = [1 / 1, 1 / 2, 1 / 4]
|
||||
cfg.unet_dropout = 0.1
|
||||
cfg.misc_dropout = 0.5
|
||||
cfg.p_all_zero = 0.1
|
||||
cfg.p_all_keep = 0.1
|
||||
cfg.temporal_conv = False
|
||||
cfg.temporal_attn_times = 1
|
||||
cfg.temporal_attention = True
|
||||
|
||||
cfg.use_fps_condition = False
|
||||
cfg.use_sim_mask = False
|
||||
|
||||
# Default: load 2d pretrain
|
||||
cfg.pretrained = False
|
||||
cfg.fix_weight = False
|
||||
|
||||
# Default resume
|
||||
cfg.resume = True
|
||||
cfg.resume_step = 148000
|
||||
cfg.resume_check_dir = '.'
|
||||
cfg.resume_checkpoint = os.path.join(
|
||||
cfg.resume_check_dir,
|
||||
f'step_{cfg.resume_step}/non_ema_{cfg.resume_step}.pth')
|
||||
cfg.resume_optimizer = False
|
||||
if cfg.resume_optimizer:
|
||||
cfg.resume_optimizer = os.path.join(
|
||||
cfg.resume_check_dir, f'optimizer_step_{cfg.resume_step}.pt')
|
||||
|
||||
# acceleration
|
||||
cfg.use_ema = True
|
||||
# for debug, no ema
|
||||
if world_size < 2:
|
||||
cfg.use_ema = False
|
||||
cfg.load_from = None
|
||||
|
||||
cfg.use_checkpoint = True
|
||||
cfg.use_sharded_ddp = False
|
||||
cfg.use_fsdp = False
|
||||
cfg.use_fp16 = True
|
||||
|
||||
# training
|
||||
cfg.ema_decay = 0.9999
|
||||
cfg.viz_interval = 1000
|
||||
cfg.save_ckp_interval = 1000
|
||||
|
||||
# logging
|
||||
cfg.log_interval = 100
|
||||
composition_strings = '_'.join(cfg.video_compositions)
|
||||
# Default log_dir
|
||||
cfg.log_dir = 'outputs/'
|
||||
@@ -0,0 +1,2 @@
|
||||
ENABLE: true
|
||||
DATASET: webvid10m
|
||||
@@ -0,0 +1,20 @@
|
||||
TASK_TYPE: MULTI_TASK
|
||||
ENABLE: true
|
||||
DATASET: webvid10m
|
||||
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
|
||||
batch_sizes: {
|
||||
"1": 1,
|
||||
"4": 1,
|
||||
"8": 1,
|
||||
"16": 1,
|
||||
}
|
||||
vit_image_size: 224
|
||||
network_name: UNetSD_temporal
|
||||
resume: true
|
||||
resume_step: 228000
|
||||
num_workers: 1
|
||||
mvs_visual: False
|
||||
chunk_size: 1
|
||||
resume_checkpoint: "model_weights/non_ema_228000.pth"
|
||||
log_dir: 'outputs'
|
||||
num_steps: 1
|
||||
@@ -0,0 +1,23 @@
|
||||
TASK_TYPE: SINGLE_TASK
|
||||
read_image: True # You NEED Open It
|
||||
ENABLE: true
|
||||
DATASET: webvid10m
|
||||
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
|
||||
guidances: ['y', 'local_image', 'motion'] # You NEED Open It
|
||||
batch_sizes: {
|
||||
"1": 1,
|
||||
"4": 1,
|
||||
"8": 1,
|
||||
"16": 1,
|
||||
}
|
||||
vit_image_size: 224
|
||||
network_name: UNetSD_temporal
|
||||
resume: true
|
||||
resume_step: 228000
|
||||
seed: 182
|
||||
num_workers: 0
|
||||
mvs_visual: False
|
||||
chunk_size: 1
|
||||
resume_checkpoint: "model_weights/non_ema_228000.pth"
|
||||
log_dir: 'outputs'
|
||||
num_steps: 1
|
||||
@@ -0,0 +1,24 @@
|
||||
TASK_TYPE: SINGLE_TASK
|
||||
read_image: True # You NEED Open It
|
||||
read_style: True
|
||||
ENABLE: true
|
||||
DATASET: webvid10m
|
||||
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
|
||||
guidances: ['y', 'local_image', 'image', 'motion'] # You NEED Open It
|
||||
batch_sizes: {
|
||||
"1": 1,
|
||||
"4": 1,
|
||||
"8": 1,
|
||||
"16": 1,
|
||||
}
|
||||
vit_image_size: 224
|
||||
network_name: UNetSD_temporal
|
||||
resume: true
|
||||
resume_step: 228000
|
||||
seed: 182
|
||||
num_workers: 0
|
||||
mvs_visual: False
|
||||
chunk_size: 1
|
||||
resume_checkpoint: "model_weights/non_ema_228000.pth"
|
||||
log_dir: 'outputs'
|
||||
num_steps: 1
|
||||
@@ -0,0 +1,26 @@
|
||||
TASK_TYPE: SINGLE_TASK
|
||||
read_image: False # You NEED Open It
|
||||
read_style: True
|
||||
read_sketch: True
|
||||
save_origin_video: False
|
||||
ENABLE: true
|
||||
DATASET: webvid10m
|
||||
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
|
||||
guidances: ['y', 'image', 'single_sketch'] # You NEED Open It
|
||||
batch_sizes: {
|
||||
"1": 1,
|
||||
"4": 1,
|
||||
"8": 1,
|
||||
"16": 1,
|
||||
}
|
||||
vit_image_size: 224
|
||||
network_name: UNetSD_temporal
|
||||
resume: true
|
||||
resume_step: 228000
|
||||
seed: 182
|
||||
num_workers: 0
|
||||
mvs_visual: False
|
||||
chunk_size: 1
|
||||
resume_checkpoint: "model_weights/non_ema_228000.pth"
|
||||
log_dir: 'outputs'
|
||||
num_steps: 1
|
||||
@@ -0,0 +1,26 @@
|
||||
TASK_TYPE: SINGLE_TASK
|
||||
read_image: False # You NEED Open It
|
||||
read_style: False
|
||||
read_sketch: True
|
||||
save_origin_video: False
|
||||
ENABLE: true
|
||||
DATASET: webvid10m
|
||||
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
|
||||
guidances: ['y', 'single_sketch'] # You NEED Open It
|
||||
batch_sizes: {
|
||||
"1": 1,
|
||||
"4": 1,
|
||||
"8": 1,
|
||||
"16": 1,
|
||||
}
|
||||
vit_image_size: 224
|
||||
network_name: UNetSD_temporal
|
||||
resume: true
|
||||
resume_step: 228000
|
||||
seed: 182
|
||||
num_workers: 0
|
||||
mvs_visual: False
|
||||
chunk_size: 1
|
||||
resume_checkpoint: "model_weights/non_ema_228000.pth"
|
||||
log_dir: 'outputs'
|
||||
num_steps: 1
|
||||
@@ -0,0 +1,26 @@
|
||||
TASK_TYPE: SINGLE_TASK
|
||||
read_image: False # You NEED Open It
|
||||
read_style: False
|
||||
read_sketch: False
|
||||
save_origin_video: True
|
||||
ENABLE: true
|
||||
DATASET: webvid10m
|
||||
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
|
||||
guidances: ['y', 'depth'] # You NEED Open It
|
||||
batch_sizes: {
|
||||
"1": 1,
|
||||
"4": 1,
|
||||
"8": 1,
|
||||
"16": 1,
|
||||
}
|
||||
vit_image_size: 224
|
||||
network_name: UNetSD_temporal
|
||||
resume: true
|
||||
resume_step: 228000
|
||||
seed: 182
|
||||
num_workers: 0
|
||||
mvs_visual: False
|
||||
chunk_size: 1
|
||||
resume_checkpoint: "model_weights/non_ema_228000.pth"
|
||||
log_dir: 'outputs'
|
||||
num_steps: 1
|
||||
@@ -0,0 +1,26 @@
|
||||
TASK_TYPE: SINGLE_TASK
|
||||
read_image: False
|
||||
read_style: True
|
||||
read_sketch: False
|
||||
save_origin_video: True
|
||||
ENABLE: true
|
||||
DATASET: webvid10m
|
||||
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
|
||||
guidances: ['y', 'image', 'depth'] # You NEED Open It
|
||||
batch_sizes: {
|
||||
"1": 1,
|
||||
"4": 1,
|
||||
"8": 1,
|
||||
"16": 1,
|
||||
}
|
||||
vit_image_size: 224
|
||||
network_name: UNetSD_temporal
|
||||
resume: true
|
||||
resume_step: 228000
|
||||
seed: 182
|
||||
num_workers: 0
|
||||
mvs_visual: False
|
||||
chunk_size: 1
|
||||
resume_checkpoint: "model_weights/non_ema_141000_no_watermark.pth"
|
||||
log_dir: 'outputs'
|
||||
num_steps: 1
|
||||
@@ -0,0 +1,21 @@
|
||||
TASK_TYPE: VideoComposer_Inference
|
||||
ENABLE: true
|
||||
DATASET: webvid10m
|
||||
video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
|
||||
batch_sizes: {
|
||||
"1": 1,
|
||||
"4": 1,
|
||||
"8": 1,
|
||||
"16": 1,
|
||||
}
|
||||
vit_image_size: 224
|
||||
network_name: UNetSD_temporal
|
||||
resume: true
|
||||
resume_step: 141000
|
||||
seed: 14
|
||||
num_workers: 1
|
||||
mvs_visual: True
|
||||
chunk_size: 1
|
||||
resume_checkpoint: "model_weights/non_ema_141000_no_watermark.pth"
|
||||
log_dir: 'outputs'
|
||||
num_steps: 1
|
||||
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .samplers import *
|
||||
from .tokenizers import *
|
||||
from .transforms import *
|
||||
158
modelscope/models/multi_modal/videocomposer/data/samplers.py
Normal file
158
modelscope/models/multi_modal/videocomposer/data/samplers.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os.path as osp
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
from modelscope.models.multi_modal.videocomposer.ops.distributed import (
|
||||
get_rank, get_world_size, shared_random_seed)
|
||||
from modelscope.models.multi_modal.videocomposer.ops.utils import (ceil_divide,
|
||||
read)
|
||||
|
||||
__all__ = ['BatchSampler', 'GroupSampler', 'ImgGroupSampler']
|
||||
|
||||
|
||||
class BatchSampler(Sampler):
|
||||
r"""An infinite batch sampler.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset_size,
|
||||
batch_size,
|
||||
num_replicas=None,
|
||||
rank=None,
|
||||
shuffle=False,
|
||||
seed=None):
|
||||
self.dataset_size = dataset_size
|
||||
self.batch_size = batch_size
|
||||
self.num_replicas = num_replicas or get_world_size()
|
||||
self.rank = rank or get_rank()
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed or shared_random_seed()
|
||||
self.rng = np.random.default_rng(self.seed + self.rank)
|
||||
self.batches_per_rank = ceil_divide(
|
||||
dataset_size, self.num_replicas * self.batch_size)
|
||||
self.samples_per_rank = self.batches_per_rank * self.batch_size
|
||||
|
||||
# rank indices
|
||||
indices = self.rng.permutation(
|
||||
self.samples_per_rank) if shuffle else np.arange(
|
||||
self.samples_per_rank)
|
||||
indices = indices * self.num_replicas + self.rank
|
||||
indices = indices[indices < dataset_size]
|
||||
self.indices = indices
|
||||
|
||||
def __iter__(self):
|
||||
start = 0
|
||||
while True:
|
||||
batch = [
|
||||
self.indices[i % len(self.indices)]
|
||||
for i in range(start, start + self.batch_size)
|
||||
]
|
||||
if self.shuffle and (start + self.batch_size) > len(self.indices):
|
||||
self.rng.shuffle(self.indices)
|
||||
start = (start + self.batch_size) % len(self.indices)
|
||||
yield batch
|
||||
|
||||
|
||||
class GroupSampler(Sampler):
|
||||
|
||||
def __init__(self,
|
||||
group_file,
|
||||
batch_size,
|
||||
alpha=0.7,
|
||||
update_interval=5000,
|
||||
seed=8888):
|
||||
self.group_file = group_file
|
||||
self.group_folder = osp.join(osp.dirname(group_file), 'groups')
|
||||
self.batch_size = batch_size
|
||||
self.alpha = alpha
|
||||
self.update_interval = update_interval
|
||||
self.seed = seed
|
||||
self.rng = np.random.default_rng(seed)
|
||||
|
||||
def __iter__(self):
|
||||
while True:
|
||||
# keep groups up-to-date
|
||||
self.update_groups()
|
||||
|
||||
# collect items
|
||||
items = self.sample()
|
||||
while len(items) < self.batch_size:
|
||||
items += self.sample()
|
||||
|
||||
# sample a batch
|
||||
batch = self.rng.choice(
|
||||
items,
|
||||
self.batch_size,
|
||||
replace=False if len(items) >= self.batch_size else True)
|
||||
yield [u.strip().split(',') for u in batch]
|
||||
|
||||
def update_groups(self):
|
||||
if not hasattr(self, '_step'):
|
||||
self._step = 0
|
||||
if self._step % self.update_interval == 0:
|
||||
self.groups = json.loads(read(self.group_file))
|
||||
self._step += 1
|
||||
|
||||
def sample(self):
|
||||
scales = np.array(
|
||||
[float(next(iter(u)).split(':')[-1]) for u in self.groups])
|
||||
p = scales**self.alpha / (scales**self.alpha).sum()
|
||||
group = self.rng.choice(self.groups, p=p)
|
||||
list_file = osp.join(self.group_folder,
|
||||
self.rng.choice(next(iter(group.values()))))
|
||||
return read(list_file).strip().split('\n')
|
||||
|
||||
|
||||
class ImgGroupSampler(Sampler):
|
||||
|
||||
def __init__(self,
|
||||
group_file,
|
||||
batch_size,
|
||||
alpha=0.7,
|
||||
update_interval=5000,
|
||||
seed=8888):
|
||||
self.group_file = group_file
|
||||
self.group_folder = osp.join(osp.dirname(group_file), 'groups')
|
||||
self.batch_size = batch_size
|
||||
self.alpha = alpha
|
||||
self.update_interval = update_interval
|
||||
self.seed = seed
|
||||
self.rng = np.random.default_rng(seed)
|
||||
|
||||
def __iter__(self):
|
||||
while True:
|
||||
# keep groups up-to-date
|
||||
self.update_groups()
|
||||
|
||||
# collect items
|
||||
items = self.sample()
|
||||
while len(items) < self.batch_size:
|
||||
items += self.sample()
|
||||
|
||||
# sample a batch
|
||||
batch = self.rng.choice(
|
||||
items,
|
||||
self.batch_size,
|
||||
replace=False if len(items) >= self.batch_size else True)
|
||||
yield [u.strip().split(',', 1) for u in batch]
|
||||
|
||||
def update_groups(self):
|
||||
if not hasattr(self, '_step'):
|
||||
self._step = 0
|
||||
if self._step % self.update_interval == 0:
|
||||
self.groups = json.loads(read(self.group_file))
|
||||
|
||||
self._step += 1
|
||||
|
||||
def sample(self):
|
||||
scales = np.array(
|
||||
[float(next(iter(u)).split(':')[-1]) for u in self.groups])
|
||||
p = scales**self.alpha / (scales**self.alpha).sum()
|
||||
group = self.rng.choice(self.groups, p=p)
|
||||
list_file = osp.join(self.group_folder,
|
||||
self.rng.choice(next(iter(group.values()))))
|
||||
return read(list_file).strip().split('\n')
|
||||
184
modelscope/models/multi_modal/videocomposer/data/tokenizers.py
Normal file
184
modelscope/models/multi_modal/videocomposer/data/tokenizers.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import gzip
|
||||
import html
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
import torch
|
||||
from tokenizers import BertWordPieceTokenizer, CharBPETokenizer
|
||||
|
||||
__all__ = ['CLIPTokenizer']
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def default_bpe():
|
||||
root = os.path.realpath(__file__)
|
||||
root = '/'.join(root.split('/')[:-1])
|
||||
return os.path.join(root, 'bpe_simple_vocab_16e6.txt.gz')
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord('!'),
|
||||
ord('~') + 1)) + list(range(
|
||||
ord('¡'),
|
||||
ord('¬') + 1)) + list(range(ord('®'),
|
||||
ord('ÿ') + 1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
class SimpleTokenizer(object):
|
||||
|
||||
def __init__(self, bpe_path: str = default_bpe()):
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
|
||||
merges = merges[1:49152 - 256 - 2 + 1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
vocab = list(bytes_to_unicode().values())
|
||||
vocab = vocab + [v + '</w>' for v in vocab]
|
||||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {
|
||||
'<|startoftext|>': '<|startoftext|>',
|
||||
'<|endoftext|>': '<|endoftext|>'
|
||||
}
|
||||
self.pat = re.compile(
|
||||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||
re.IGNORECASE)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token[:-1]) + (token[-1] + '</w>', )
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token + '</w>'
|
||||
|
||||
while True:
|
||||
bigram = min(
|
||||
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except Exception as e:
|
||||
new_word.extend(word[i:])
|
||||
print(e)
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[
|
||||
i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b]
|
||||
for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token]
|
||||
for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode(
|
||||
'utf-8', errors='replace').replace('</w>', ' ')
|
||||
return text
|
||||
|
||||
|
||||
class CLIPTokenizer(object):
|
||||
|
||||
def __init__(self, length=77):
|
||||
self.length = length
|
||||
|
||||
# init tokenizer
|
||||
self.tokenizer = SimpleTokenizer(bpe_path=default_bpe())
|
||||
self.sos_token = self.tokenizer.encoder['<|startoftext|>']
|
||||
self.eos_token = self.tokenizer.encoder['<|endoftext|>']
|
||||
self.vocab_size = len(self.tokenizer.encoder)
|
||||
|
||||
def __call__(self, sequence):
|
||||
if isinstance(sequence, str):
|
||||
return torch.LongTensor(self._tokenizer(sequence))
|
||||
elif isinstance(sequence, list):
|
||||
return torch.LongTensor([self._tokenizer(u) for u in sequence])
|
||||
else:
|
||||
raise TypeError(
|
||||
f'Expected the "sequence" to be a string or a list, but got {type(sequence)}'
|
||||
)
|
||||
|
||||
def _tokenizer(self, text):
|
||||
tokens = self.tokenizer.encode(text)[:self.length - 2]
|
||||
tokens = [self.sos_token] + tokens + [self.eos_token]
|
||||
tokens = tokens + [0] * (self.length - len(tokens))
|
||||
return tokens
|
||||
400
modelscope/models/multi_modal/videocomposer/data/transforms.py
Normal file
400
modelscope/models/multi_modal/videocomposer/data/transforms.py
Normal file
@@ -0,0 +1,400 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
import torchvision.transforms.functional as TF
|
||||
from PIL import Image, ImageFilter
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
__all__ = [
|
||||
'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2', 'RandomCrop',
|
||||
'RandomCropV2', 'RandomHFlip', 'GaussianBlur', 'ColorJitter', 'RandomGray',
|
||||
'ToTensor', 'Normalize', 'ResizeRandomCrop', 'ExtractResizeRandomCrop',
|
||||
'ExtractResizeAssignCrop'
|
||||
]
|
||||
|
||||
|
||||
def random_resize(img, size):
|
||||
img = [
|
||||
TF.resize(
|
||||
u,
|
||||
size,
|
||||
interpolation=random.choice([
|
||||
InterpolationMode.BILINEAR, InterpolationMode.BICUBIC,
|
||||
InterpolationMode.LANCZOS
|
||||
])) for u in img
|
||||
]
|
||||
return img
|
||||
|
||||
|
||||
class CenterCropV3(object):
|
||||
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img):
|
||||
# fast resize
|
||||
while min(img.size) >= 2 * self.size:
|
||||
img = img.resize((img.width // 2, img.height // 2),
|
||||
resample=Image.BOX)
|
||||
scale = self.size / min(img.size)
|
||||
img = img.resize((round(scale * img.width), round(scale * img.height)),
|
||||
resample=Image.BICUBIC)
|
||||
|
||||
# center crop
|
||||
x1 = (img.width - self.size) // 2
|
||||
y1 = (img.height - self.size) // 2
|
||||
img = img.crop((x1, y1, x1 + self.size, y1 + self.size))
|
||||
return img
|
||||
|
||||
|
||||
class Compose(object):
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __getitem__(self, index):
|
||||
if isinstance(index, slice):
|
||||
return Compose(self.transforms[index])
|
||||
else:
|
||||
return self.transforms[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.transforms)
|
||||
|
||||
def __call__(self, rgb):
|
||||
for t in self.transforms:
|
||||
rgb = t(rgb)
|
||||
return rgb
|
||||
|
||||
|
||||
class Resize(object):
|
||||
|
||||
def __init__(self, size=256):
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
self.size = size
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class Rescale(object):
|
||||
|
||||
def __init__(self, size=256, interpolation=Image.BILINEAR):
|
||||
self.size = size
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, rgb):
|
||||
w, h = rgb[0].size
|
||||
scale = self.size / min(w, h)
|
||||
out_w, out_h = int(round(w * scale)), int(round(h * scale))
|
||||
rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class CenterCrop(object):
|
||||
|
||||
def __init__(self, size=224):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, rgb):
|
||||
w, h = rgb[0].size
|
||||
assert min(w, h) >= self.size
|
||||
x1 = (w - self.size) // 2
|
||||
y1 = (h - self.size) // 2
|
||||
rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class ResizeRandomCrop(object):
|
||||
|
||||
def __init__(self, size=256, size_short=292):
|
||||
self.size = size
|
||||
# self.min_area = min_area
|
||||
self.size_short = size_short
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
while min(rgb[0].size) >= 2 * self.size_short:
|
||||
rgb = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in rgb
|
||||
]
|
||||
scale = self.size_short / min(rgb[0].size)
|
||||
rgb = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in rgb
|
||||
]
|
||||
out_w = self.size
|
||||
out_h = self.size
|
||||
w, h = rgb[0].size # (518, 292)
|
||||
x1 = random.randint(0, w - out_w)
|
||||
y1 = random.randint(0, h - out_h)
|
||||
|
||||
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class ExtractResizeRandomCrop(object):
|
||||
|
||||
def __init__(self, size=256, size_short=292):
|
||||
self.size = size
|
||||
self.size_short = size_short
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
while min(rgb[0].size) >= 2 * self.size_short:
|
||||
rgb = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in rgb
|
||||
]
|
||||
scale = self.size_short / min(rgb[0].size)
|
||||
rgb = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in rgb
|
||||
]
|
||||
out_w = self.size
|
||||
out_h = self.size
|
||||
w, h = rgb[0].size # (518, 292)
|
||||
x1 = random.randint(0, w - out_w)
|
||||
y1 = random.randint(0, h - out_h)
|
||||
|
||||
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
|
||||
wh = [x1, y1, x1 + out_w, y1 + out_h]
|
||||
|
||||
return rgb, wh
|
||||
|
||||
|
||||
class ExtractResizeAssignCrop(object):
|
||||
|
||||
def __init__(self, size=256, size_short=292):
|
||||
self.size = size
|
||||
self.size_short = size_short
|
||||
|
||||
def __call__(self, rgb, wh):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
while min(rgb[0].size) >= 2 * self.size_short:
|
||||
rgb = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in rgb
|
||||
]
|
||||
scale = self.size_short / min(rgb[0].size)
|
||||
rgb = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in rgb
|
||||
]
|
||||
|
||||
rgb = [u.crop(wh) for u in rgb]
|
||||
rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class CenterCropV2(object):
|
||||
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img):
|
||||
# fast resize
|
||||
while min(img[0].size) >= 2 * self.size:
|
||||
img = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in img
|
||||
]
|
||||
scale = self.size / min(img[0].size)
|
||||
img = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in img
|
||||
]
|
||||
|
||||
# center crop
|
||||
x1 = (img[0].width - self.size) // 2
|
||||
y1 = (img[0].height - self.size) // 2
|
||||
img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
|
||||
return img
|
||||
|
||||
|
||||
class RandomCrop(object):
|
||||
|
||||
def __init__(self, size=224, min_area=0.4):
|
||||
self.size = size
|
||||
self.min_area = min_area
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
w, h = rgb[0].size
|
||||
area = w * h
|
||||
out_w, out_h = float('inf'), float('inf')
|
||||
while out_w > w or out_h > h:
|
||||
target_area = random.uniform(self.min_area, 1.0) * area
|
||||
aspect_ratio = random.uniform(3. / 4., 4. / 3.)
|
||||
out_w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
out_h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
x1 = random.randint(0, w - out_w)
|
||||
y1 = random.randint(0, h - out_h)
|
||||
|
||||
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
|
||||
rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class RandomCropV2(object):
|
||||
|
||||
def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
|
||||
if isinstance(size, (tuple, list)):
|
||||
self.size = size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
self.min_area = min_area
|
||||
self.ratio = ratio
|
||||
|
||||
def _get_params(self, img):
|
||||
width, height = img.size
|
||||
area = height * width
|
||||
|
||||
for _ in range(10):
|
||||
target_area = random.uniform(self.min_area, 1.0) * area
|
||||
log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
|
||||
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
||||
|
||||
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
if 0 < w <= width and 0 < h <= height:
|
||||
i = random.randint(0, height - h)
|
||||
j = random.randint(0, width - w)
|
||||
return i, j, h, w
|
||||
|
||||
# Fallback to central crop
|
||||
in_ratio = float(width) / float(height)
|
||||
if (in_ratio < min(self.ratio)):
|
||||
w = width
|
||||
h = int(round(w / min(self.ratio)))
|
||||
elif (in_ratio > max(self.ratio)):
|
||||
h = height
|
||||
w = int(round(h * max(self.ratio)))
|
||||
else: # whole image
|
||||
w = width
|
||||
h = height
|
||||
i = (height - h) // 2
|
||||
j = (width - w) // 2
|
||||
return i, j, h, w
|
||||
|
||||
def __call__(self, rgb):
|
||||
i, j, h, w = self._get_params(rgb[0])
|
||||
rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class RandomHFlip(object):
|
||||
|
||||
def __init__(self, p=0.5):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class GaussianBlur(object):
|
||||
|
||||
def __init__(self, sigmas=[0.1, 2.0], p=0.5):
|
||||
self.sigmas = sigmas
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
sigma = random.uniform(*self.sigmas)
|
||||
rgb = [
|
||||
u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb
|
||||
]
|
||||
return rgb
|
||||
|
||||
|
||||
class ColorJitter(object):
|
||||
|
||||
def __init__(self,
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.1,
|
||||
p=0.5):
|
||||
self.brightness = brightness
|
||||
self.contrast = contrast
|
||||
self.saturation = saturation
|
||||
self.hue = hue
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
brightness, contrast, saturation, hue = self._random_params()
|
||||
transforms = [
|
||||
lambda f: F.adjust_brightness(f, brightness),
|
||||
lambda f: F.adjust_contrast(f, contrast),
|
||||
lambda f: F.adjust_saturation(f, saturation),
|
||||
lambda f: F.adjust_hue(f, hue)
|
||||
]
|
||||
random.shuffle(transforms)
|
||||
for t in transforms:
|
||||
rgb = [t(u) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
def _random_params(self):
|
||||
brightness = random.uniform(
|
||||
max(0, 1 - self.brightness), 1 + self.brightness)
|
||||
contrast = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
|
||||
saturation = random.uniform(
|
||||
max(0, 1 - self.saturation), 1 + self.saturation)
|
||||
hue = random.uniform(-self.hue, self.hue)
|
||||
return brightness, contrast, saturation, hue
|
||||
|
||||
|
||||
class RandomGray(object):
|
||||
|
||||
def __init__(self, p=0.2):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
rgb = [u.convert('L').convert('RGB') for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
|
||||
def __call__(self, rgb):
|
||||
rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0)
|
||||
return rgb
|
||||
|
||||
|
||||
class Normalize(object):
|
||||
|
||||
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, rgb):
|
||||
rgb = rgb.clone()
|
||||
rgb.clamp_(0, 1)
|
||||
if not isinstance(self.mean, torch.Tensor):
|
||||
self.mean = rgb.new_tensor(self.mean).view(-1)
|
||||
if not isinstance(self.std, torch.Tensor):
|
||||
self.std = rgb.new_tensor(self.std).view(-1)
|
||||
rgb.sub_(self.mean.view(1, -1, 1, 1)).div_(self.std.view(1, -1, 1, 1))
|
||||
return rgb
|
||||
1514
modelscope/models/multi_modal/videocomposer/diffusion.py
Normal file
1514
modelscope/models/multi_modal/videocomposer/diffusion.py
Normal file
File diff suppressed because it is too large
Load Diff
1697
modelscope/models/multi_modal/videocomposer/dpm_solver.py
Normal file
1697
modelscope/models/multi_modal/videocomposer/dpm_solver.py
Normal file
File diff suppressed because it is too large
Load Diff
120
modelscope/models/multi_modal/videocomposer/mha_flash.py
Normal file
120
modelscope/models/multi_modal/videocomposer/mha_flash.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
|
||||
|
||||
class FlashAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
context_dim=None,
|
||||
num_heads=None,
|
||||
head_dim=None,
|
||||
batch_size=4):
|
||||
# consider head_dim first, then num_heads
|
||||
num_heads = dim // head_dim if head_dim else num_heads
|
||||
head_dim = dim // num_heads
|
||||
assert num_heads * head_dim == dim
|
||||
super(FlashAttentionBlock, self).__init__()
|
||||
self.dim = dim
|
||||
self.context_dim = context_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.scale = math.pow(head_dim, -0.25)
|
||||
|
||||
# layers
|
||||
self.norm = nn.GroupNorm(32, dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
if context_dim is not None:
|
||||
self.context_kv = nn.Linear(context_dim, dim * 2)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
if self.head_dim <= 128 and (self.head_dim % 8) == 0:
|
||||
self.flash_attn = FlashAttention(
|
||||
softmax_scale=None, attention_dropout=0.0)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
|
||||
def _init_weight(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=0.15)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
module.weight.data.normal_(mean=0.0, std=0.15)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def forward(self, x, context=None):
|
||||
r"""x: [B, C, H, W].
|
||||
context: [B, L, C] or None.
|
||||
"""
|
||||
identity = x
|
||||
b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
x = self.norm(x)
|
||||
q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
|
||||
if context is not None:
|
||||
ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
|
||||
d).permute(0, 2, 3,
|
||||
1).chunk(
|
||||
2, dim=1)
|
||||
k = torch.cat([ck, k], dim=-1)
|
||||
v = torch.cat([cv, v], dim=-1)
|
||||
cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device)
|
||||
q = torch.cat([q, cq], dim=-1)
|
||||
|
||||
qkv = torch.cat([q, k, v], dim=1)
|
||||
origin_dtype = qkv.dtype
|
||||
qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n,
|
||||
d).half().contiguous()
|
||||
out, _ = self.flash_attn(qkv)
|
||||
out.to(origin_dtype)
|
||||
|
||||
if context is not None:
|
||||
out = out[:, :-4, :, :]
|
||||
out = out.permute(0, 2, 3, 1).reshape(b, c, h, w)
|
||||
|
||||
# output
|
||||
x = self.proj(out)
|
||||
return x + identity
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
batch_size = 8
|
||||
flash_net = FlashAttentionBlock(
|
||||
dim=1280,
|
||||
context_dim=512,
|
||||
num_heads=None,
|
||||
head_dim=64,
|
||||
batch_size=batch_size).cuda()
|
||||
|
||||
x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda()
|
||||
context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda()
|
||||
# context = None
|
||||
flash_net.eval()
|
||||
|
||||
with amp.autocast(enabled=True):
|
||||
# warm up
|
||||
for i in range(5):
|
||||
y = flash_net(x, context)
|
||||
torch.cuda.synchronize()
|
||||
s1 = time.time()
|
||||
for i in range(10):
|
||||
y = flash_net(x, context)
|
||||
torch.cuda.synchronize()
|
||||
s2 = time.time()
|
||||
|
||||
print(f'Average cost time {(s2-s1)*1000/10} ms')
|
||||
@@ -0,0 +1,2 @@
|
||||
from .clip import *
|
||||
from .midas import *
|
||||
460
modelscope/models/multi_modal/videocomposer/models/clip.py
Normal file
460
modelscope/models/multi_modal/videocomposer/models/clip.py
Normal file
@@ -0,0 +1,460 @@
|
||||
import math
|
||||
import os.path as osp
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import modelscope.models.multi_modal.videocomposer.ops as ops
|
||||
|
||||
__all__ = [
|
||||
'CLIP', 'clip_vit_b_32', 'clip_vit_b_16', 'clip_vit_l_14',
|
||||
'clip_vit_l_14_336px', 'clip_vit_h_16'
|
||||
]
|
||||
|
||||
|
||||
def DOWNLOAD_TO_CACHE(oss_key,
|
||||
file_or_dirname=None,
|
||||
cache_dir=osp.join(
|
||||
'/'.join(osp.abspath(__file__).split('/')[:-2]),
|
||||
'model_weights')):
|
||||
r"""Download OSS [file or folder] to the cache folder.
|
||||
Only the 0th process on each node will run the downloading.
|
||||
Barrier all processes until the downloading is completed.
|
||||
"""
|
||||
# source and target paths
|
||||
base_path = osp.join(cache_dir, file_or_dirname or osp.basename(oss_key))
|
||||
|
||||
return base_path
|
||||
|
||||
|
||||
def to_fp16(m):
|
||||
if isinstance(m, (nn.Linear, nn.Conv2d)):
|
||||
m.weight.data = m.weight.data.half()
|
||||
if m.bias is not None:
|
||||
m.bias.data = m.bias.data.half()
|
||||
elif hasattr(m, 'head'):
|
||||
p = getattr(m, 'head')
|
||||
p.data = p.data.half()
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
r"""Subclass of nn.LayerNorm to handle fp16.
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
return super(LayerNorm, self).forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
|
||||
assert dim % num_heads == 0
|
||||
super(SelfAttention, self).__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = 1.0 / math.sqrt(self.head_dim)
|
||||
|
||||
# layers
|
||||
self.to_qkv = nn.Linear(dim, dim * 3)
|
||||
self.attn_dropout = nn.Dropout(attn_dropout)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_dropout = nn.Dropout(proj_dropout)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
r"""x: [B, L, C].
|
||||
mask: [*, L, L].
|
||||
"""
|
||||
b, l, _, n = *x.size(), self.num_heads
|
||||
|
||||
# compute query, key, and value
|
||||
q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1)
|
||||
q = q.reshape(l, b * n, -1).transpose(0, 1)
|
||||
k = k.reshape(l, b * n, -1).transpose(0, 1)
|
||||
v = v.reshape(l, b * n, -1).transpose(0, 1)
|
||||
|
||||
# compute attention
|
||||
attn = self.scale * torch.bmm(q, k.transpose(1, 2))
|
||||
if mask is not None:
|
||||
attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf'))
|
||||
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
||||
attn = self.attn_dropout(attn)
|
||||
|
||||
# gather context
|
||||
x = torch.bmm(attn, v)
|
||||
x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
x = self.proj_dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
|
||||
super(AttentionBlock, self).__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
# layers
|
||||
self.norm1 = LayerNorm(dim)
|
||||
self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout)
|
||||
self.norm2 = LayerNorm(dim)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim),
|
||||
nn.Dropout(proj_dropout))
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x = x + self.attn(self.norm1(x), mask)
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
image_size=224,
|
||||
patch_size=16,
|
||||
dim=768,
|
||||
out_dim=512,
|
||||
num_heads=12,
|
||||
num_layers=12,
|
||||
attn_dropout=0.0,
|
||||
proj_dropout=0.0,
|
||||
embedding_dropout=0.0):
|
||||
assert image_size % patch_size == 0
|
||||
super(VisionTransformer, self).__init__()
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.num_patches = (image_size // patch_size)**2
|
||||
|
||||
# embeddings
|
||||
gain = 1.0 / math.sqrt(dim)
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
3, dim, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
||||
self.pos_embedding = nn.Parameter(
|
||||
gain * torch.randn(1, self.num_patches + 1, dim))
|
||||
self.dropout = nn.Dropout(embedding_dropout)
|
||||
|
||||
# transformer
|
||||
self.pre_norm = LayerNorm(dim)
|
||||
self.transformer = nn.Sequential(*[
|
||||
AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
self.post_norm = LayerNorm(dim)
|
||||
|
||||
# head
|
||||
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
||||
|
||||
def forward(self, x):
|
||||
b, dtype = x.size(0), self.head.dtype
|
||||
x = x.type(dtype)
|
||||
|
||||
# patch-embedding
|
||||
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c]
|
||||
x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x],
|
||||
dim=1)
|
||||
x = self.dropout(x + self.pos_embedding.type(dtype))
|
||||
x = self.pre_norm(x)
|
||||
|
||||
# transformer
|
||||
x = self.transformer(x)
|
||||
|
||||
# head
|
||||
x = self.post_norm(x)
|
||||
x = torch.mm(x[:, 0, :], self.head)
|
||||
return x
|
||||
|
||||
def fp16(self):
|
||||
return self.apply(to_fp16)
|
||||
|
||||
|
||||
class TextTransformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
text_len,
|
||||
dim=512,
|
||||
out_dim=512,
|
||||
num_heads=8,
|
||||
num_layers=12,
|
||||
attn_dropout=0.0,
|
||||
proj_dropout=0.0,
|
||||
embedding_dropout=0.0):
|
||||
super(TextTransformer, self).__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.text_len = text_len
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
|
||||
# embeddings
|
||||
self.token_embedding = nn.Embedding(vocab_size, dim)
|
||||
self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim))
|
||||
self.dropout = nn.Dropout(embedding_dropout)
|
||||
|
||||
# transformer
|
||||
self.transformer = nn.ModuleList([
|
||||
AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
self.norm = LayerNorm(dim)
|
||||
|
||||
# head
|
||||
gain = 1.0 / math.sqrt(dim)
|
||||
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
||||
|
||||
# causal attention mask
|
||||
self.register_buffer('attn_mask',
|
||||
torch.tril(torch.ones(1, text_len, text_len)))
|
||||
|
||||
def forward(self, x):
|
||||
eot, dtype = x.argmax(dim=-1), self.head.dtype
|
||||
|
||||
# embeddings
|
||||
x = self.dropout(
|
||||
self.token_embedding(x).type(dtype)
|
||||
+ self.pos_embedding.type(dtype))
|
||||
|
||||
# transformer
|
||||
for block in self.transformer:
|
||||
x = block(x, self.attn_mask)
|
||||
|
||||
# head
|
||||
x = self.norm(x)
|
||||
x = torch.mm(x[torch.arange(x.size(0)), eot], self.head)
|
||||
return x
|
||||
|
||||
def fp16(self):
|
||||
return self.apply(to_fp16)
|
||||
|
||||
|
||||
class CLIP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
embed_dim=512,
|
||||
image_size=224,
|
||||
patch_size=16,
|
||||
vision_dim=768,
|
||||
vision_heads=12,
|
||||
vision_layers=12,
|
||||
vocab_size=49408,
|
||||
text_len=77,
|
||||
text_dim=512,
|
||||
text_heads=8,
|
||||
text_layers=12,
|
||||
attn_dropout=0.0,
|
||||
proj_dropout=0.0,
|
||||
embedding_dropout=0.0):
|
||||
super(CLIP, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.vision_dim = vision_dim
|
||||
self.vision_heads = vision_heads
|
||||
self.vision_layers = vision_layers
|
||||
self.vocab_size = vocab_size
|
||||
self.text_len = text_len
|
||||
self.text_dim = text_dim
|
||||
self.text_heads = text_heads
|
||||
self.text_layers = text_layers
|
||||
|
||||
# models
|
||||
self.visual = VisionTransformer(
|
||||
image_size=image_size,
|
||||
patch_size=patch_size,
|
||||
dim=vision_dim,
|
||||
out_dim=embed_dim,
|
||||
num_heads=vision_heads,
|
||||
num_layers=vision_layers,
|
||||
attn_dropout=attn_dropout,
|
||||
proj_dropout=proj_dropout,
|
||||
embedding_dropout=embedding_dropout)
|
||||
self.textual = TextTransformer(
|
||||
vocab_size=vocab_size,
|
||||
text_len=text_len,
|
||||
dim=text_dim,
|
||||
out_dim=embed_dim,
|
||||
num_heads=text_heads,
|
||||
num_layers=text_layers,
|
||||
attn_dropout=attn_dropout,
|
||||
proj_dropout=proj_dropout,
|
||||
embedding_dropout=embedding_dropout)
|
||||
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
||||
|
||||
def forward(self, imgs, txt_tokens):
|
||||
r"""imgs: [B, C, H, W] of torch.float32.
|
||||
txt_tokens: [B, T] of torch.long.
|
||||
"""
|
||||
xi = self.visual(imgs)
|
||||
xt = self.textual(txt_tokens)
|
||||
|
||||
# normalize features
|
||||
xi = F.normalize(xi, p=2, dim=1)
|
||||
xt = F.normalize(xt, p=2, dim=1)
|
||||
|
||||
# gather features from all ranks
|
||||
full_xi = ops.diff_all_gather(xi)
|
||||
full_xt = ops.diff_all_gather(xt)
|
||||
|
||||
# logits
|
||||
scale = self.log_scale.exp()
|
||||
logits_i2t = scale * torch.mm(xi, full_xt.t())
|
||||
logits_t2i = scale * torch.mm(xt, full_xi.t())
|
||||
|
||||
# labels
|
||||
labels = torch.arange(
|
||||
len(xi) * ops.get_rank(),
|
||||
len(xi) * (ops.get_rank() + 1),
|
||||
dtype=torch.long,
|
||||
device=xi.device)
|
||||
return logits_i2t, logits_t2i, labels
|
||||
|
||||
def init_weights(self):
|
||||
# embeddings
|
||||
nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
|
||||
nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1)
|
||||
|
||||
# attentions
|
||||
for modality in ['visual', 'textual']:
|
||||
dim = self.vision_dim if modality == 'visual' else 'textual'
|
||||
transformer = getattr(self, modality).transformer
|
||||
proj_gain = (1.0 / math.sqrt(dim)) * (
|
||||
1.0 / math.sqrt(2 * transformer.num_layers))
|
||||
attn_gain = 1.0 / math.sqrt(dim)
|
||||
mlp_gain = 1.0 / math.sqrt(2.0 * dim)
|
||||
for block in transformer.layers:
|
||||
nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
|
||||
nn.init.normal_(block.attn.proj.weight, std=proj_gain)
|
||||
nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
|
||||
nn.init.normal_(block.mlp[2].weight, std=proj_gain)
|
||||
|
||||
def param_groups(self):
|
||||
groups = [{
|
||||
'params': [
|
||||
p for n, p in self.named_parameters()
|
||||
if 'norm' in n or n.endswith('bias')
|
||||
],
|
||||
'weight_decay':
|
||||
0.0
|
||||
}, {
|
||||
'params': [
|
||||
p for n, p in self.named_parameters()
|
||||
if not ('norm' in n or n.endswith('bias'))
|
||||
]
|
||||
}]
|
||||
return groups
|
||||
|
||||
def fp16(self):
|
||||
return self.apply(to_fp16)
|
||||
|
||||
|
||||
def _clip(name, pretrained=False, **kwargs):
|
||||
model = CLIP(**kwargs)
|
||||
if pretrained:
|
||||
model.load_state_dict(
|
||||
torch.load(
|
||||
DOWNLOAD_TO_CACHE(f'models/clip/{name}.pth'),
|
||||
map_location='cpu'))
|
||||
return model
|
||||
|
||||
|
||||
def clip_vit_b_32(pretrained=False, **kwargs):
|
||||
cfg = dict(
|
||||
embed_dim=512,
|
||||
image_size=224,
|
||||
patch_size=32,
|
||||
vision_dim=768,
|
||||
vision_heads=12,
|
||||
vision_layers=12,
|
||||
vocab_size=49408,
|
||||
text_len=77,
|
||||
text_dim=512,
|
||||
text_heads=8,
|
||||
text_layers=12)
|
||||
cfg.update(**kwargs)
|
||||
return _clip('openai-clip-vit-base-32', pretrained, **cfg)
|
||||
|
||||
|
||||
def clip_vit_b_16(pretrained=False, **kwargs):
|
||||
cfg = dict(
|
||||
embed_dim=512,
|
||||
image_size=224,
|
||||
patch_size=32,
|
||||
vision_dim=768,
|
||||
vision_heads=12,
|
||||
vision_layers=12,
|
||||
vocab_size=49408,
|
||||
text_len=77,
|
||||
text_dim=512,
|
||||
text_heads=8,
|
||||
text_layers=12)
|
||||
cfg.update(**kwargs)
|
||||
return _clip('openai-clip-vit-base-16', pretrained, **cfg)
|
||||
|
||||
|
||||
def clip_vit_l_14(pretrained=False, **kwargs):
|
||||
cfg = dict(
|
||||
embed_dim=768,
|
||||
image_size=224,
|
||||
patch_size=14,
|
||||
vision_dim=1024,
|
||||
vision_heads=16,
|
||||
vision_layers=24,
|
||||
vocab_size=49408,
|
||||
text_len=77,
|
||||
text_dim=768,
|
||||
text_heads=12,
|
||||
text_layers=12)
|
||||
cfg.update(**kwargs)
|
||||
return _clip('openai-clip-vit-large-14', pretrained, **cfg)
|
||||
|
||||
|
||||
def clip_vit_l_14_336px(pretrained=False, **kwargs):
|
||||
cfg = dict(
|
||||
embed_dim=768,
|
||||
image_size=336,
|
||||
patch_size=14,
|
||||
vision_dim=1024,
|
||||
vision_heads=16,
|
||||
vision_layers=24,
|
||||
vocab_size=49408,
|
||||
text_len=77,
|
||||
text_dim=768,
|
||||
text_heads=12,
|
||||
text_layers=12)
|
||||
cfg.update(**kwargs)
|
||||
return _clip('openai-clip-vit-large-14-336px', pretrained, **cfg)
|
||||
|
||||
|
||||
def clip_vit_h_16(pretrained=False, **kwargs):
|
||||
assert not pretrained, 'pretrained model for openai-clip-vit-huge-16 is not available!'
|
||||
cfg = dict(
|
||||
embed_dim=1024,
|
||||
image_size=256,
|
||||
patch_size=16,
|
||||
vision_dim=1280,
|
||||
vision_heads=16,
|
||||
vision_layers=32,
|
||||
vocab_size=49408,
|
||||
text_len=77,
|
||||
text_dim=1024,
|
||||
text_heads=16,
|
||||
text_layers=24)
|
||||
cfg.update(**kwargs)
|
||||
return _clip('openai-clip-vit-huge-16', pretrained, **cfg)
|
||||
320
modelscope/models/multi_modal/videocomposer/models/midas.py
Normal file
320
modelscope/models/multi_modal/videocomposer/models/midas.py
Normal file
@@ -0,0 +1,320 @@
|
||||
r"""A much cleaner re-implementation of ``https://github.com/isl-org/MiDaS''.
|
||||
Image augmentation: T.Compose([
|
||||
Resize(
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=32,
|
||||
interpolation=cv2.INTER_CUBIC),
|
||||
T.ToTensor(),
|
||||
T.Normalize(
|
||||
mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5])]).
|
||||
Fast inference:
|
||||
model = model.to(memory_format=torch.channels_last).half()
|
||||
input = input.to(memory_format=torch.channels_last).half()
|
||||
output = model(input)
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['MiDaS', 'midas_v3']
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads):
|
||||
assert dim % num_heads == 0
|
||||
super(SelfAttention, self).__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = 1.0 / math.sqrt(self.head_dim)
|
||||
|
||||
# layers
|
||||
self.to_qkv = nn.Linear(dim, dim * 3)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
b, l, c, n, d = *x.size(), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q, k, v = self.to_qkv(x).view(b, l, n * 3, d).chunk(3, dim=2)
|
||||
|
||||
# compute attention
|
||||
attn = self.scale * torch.einsum('binc,bjnc->bnij', q, k)
|
||||
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
||||
|
||||
# gather context
|
||||
x = torch.einsum('bnij,bjnc->binc', attn, v)
|
||||
x = x.reshape(b, l, c)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads):
|
||||
super(AttentionBlock, self).__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
# layers
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.attn = SelfAttention(dim, num_heads)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.attn(self.norm1(x))
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
image_size=384,
|
||||
patch_size=16,
|
||||
dim=1024,
|
||||
out_dim=1000,
|
||||
num_heads=16,
|
||||
num_layers=24):
|
||||
assert image_size % patch_size == 0
|
||||
super(VisionTransformer, self).__init__()
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.num_patches = (image_size // patch_size)**2
|
||||
|
||||
# embeddings
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
3, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.cls_embedding = nn.Parameter(torch.zeros(1, 1, dim))
|
||||
self.pos_embedding = nn.Parameter(
|
||||
torch.empty(1, self.num_patches + 1, dim).normal_(std=0.02))
|
||||
|
||||
# blocks
|
||||
self.blocks = nn.Sequential(
|
||||
*[AttentionBlock(dim, num_heads) for _ in range(num_layers)])
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
# head
|
||||
self.head = nn.Linear(dim, out_dim)
|
||||
|
||||
def forward(self, x):
|
||||
b = x.size(0)
|
||||
|
||||
# embeddings
|
||||
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
|
||||
x = torch.cat([self.cls_embedding.repeat(b, 1, 1), x], dim=1)
|
||||
x = x + self.pos_embedding
|
||||
|
||||
# blocks
|
||||
x = self.blocks(x)
|
||||
x = self.norm(x)
|
||||
|
||||
# head
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
nn.ReLU(inplace=False), # NOTE: avoid modifying the input
|
||||
nn.Conv2d(dim, dim, 3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(dim, dim, 3, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.residual(x)
|
||||
|
||||
|
||||
class FusionBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim):
|
||||
super(FusionBlock, self).__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.layer1 = ResidualBlock(dim)
|
||||
self.layer2 = ResidualBlock(dim)
|
||||
self.conv_out = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
def forward(self, *xs):
|
||||
assert len(xs) in (1, 2), 'invalid number of inputs'
|
||||
if len(xs) == 1:
|
||||
x = self.layer2(xs[0])
|
||||
else:
|
||||
x = self.layer2(xs[0] + self.layer1(xs[1]))
|
||||
x = F.interpolate(
|
||||
x, scale_factor=2, mode='bilinear', align_corners=True)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class MiDaS(nn.Module):
|
||||
r"""MiDaS v3.0 DPT-Large from ``https://github.com/isl-org/MiDaS''.
|
||||
Monocular depth estimation using dense prediction transformers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
image_size=384,
|
||||
patch_size=16,
|
||||
dim=1024,
|
||||
neck_dims=[256, 512, 1024, 1024],
|
||||
fusion_dim=256,
|
||||
num_heads=16,
|
||||
num_layers=24):
|
||||
assert image_size % patch_size == 0
|
||||
assert num_layers % 4 == 0
|
||||
super(MiDaS, self).__init__()
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.dim = dim
|
||||
self.neck_dims = neck_dims
|
||||
self.fusion_dim = fusion_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.num_patches = (image_size // patch_size)**2
|
||||
|
||||
# embeddings
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
3, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.cls_embedding = nn.Parameter(torch.zeros(1, 1, dim))
|
||||
self.pos_embedding = nn.Parameter(
|
||||
torch.empty(1, self.num_patches + 1, dim).normal_(std=0.02))
|
||||
|
||||
# blocks
|
||||
stride = num_layers // 4
|
||||
self.blocks = nn.Sequential(
|
||||
*[AttentionBlock(dim, num_heads) for _ in range(num_layers)])
|
||||
self.slices = [slice(i * stride, (i + 1) * stride) for i in range(4)]
|
||||
|
||||
# stage1 (4x)
|
||||
self.fc1 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(dim, neck_dims[0], 1),
|
||||
nn.ConvTranspose2d(neck_dims[0], neck_dims[0], 4, stride=4),
|
||||
nn.Conv2d(neck_dims[0], fusion_dim, 3, padding=1, bias=False))
|
||||
self.fusion1 = FusionBlock(fusion_dim)
|
||||
|
||||
# stage2 (8x)
|
||||
self.fc2 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(dim, neck_dims[1], 1),
|
||||
nn.ConvTranspose2d(neck_dims[1], neck_dims[1], 2, stride=2),
|
||||
nn.Conv2d(neck_dims[1], fusion_dim, 3, padding=1, bias=False))
|
||||
self.fusion2 = FusionBlock(fusion_dim)
|
||||
|
||||
# stage3 (16x)
|
||||
self.fc3 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
|
||||
self.conv3 = nn.Sequential(
|
||||
nn.Conv2d(dim, neck_dims[2], 1),
|
||||
nn.Conv2d(neck_dims[2], fusion_dim, 3, padding=1, bias=False))
|
||||
self.fusion3 = FusionBlock(fusion_dim)
|
||||
|
||||
# stage4 (32x)
|
||||
self.fc4 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
|
||||
self.conv4 = nn.Sequential(
|
||||
nn.Conv2d(dim, neck_dims[3], 1),
|
||||
nn.Conv2d(neck_dims[3], neck_dims[3], 3, stride=2, padding=1),
|
||||
nn.Conv2d(neck_dims[3], fusion_dim, 3, padding=1, bias=False))
|
||||
self.fusion4 = FusionBlock(fusion_dim)
|
||||
|
||||
# head
|
||||
self.head = nn.Sequential(
|
||||
nn.Conv2d(fusion_dim, fusion_dim // 2, 3, padding=1),
|
||||
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
|
||||
nn.Conv2d(fusion_dim // 2, 32, 3, padding=1),
|
||||
nn.ReLU(inplace=True), nn.ConvTranspose2d(32, 1, 1),
|
||||
nn.ReLU(inplace=True))
|
||||
|
||||
def forward(self, x):
|
||||
b, _, h, w, p = *x.size(), self.patch_size
|
||||
assert h % p == 0 and w % p == 0, f'Image size ({w}, {h}) is not divisible by patch size ({p}, {p})'
|
||||
hp, wp, grid = h // p, w // p, self.image_size // p
|
||||
|
||||
# embeddings
|
||||
pos_embedding = torch.cat([
|
||||
self.pos_embedding[:, :1],
|
||||
F.interpolate(
|
||||
self.pos_embedding[:, 1:].reshape(1, grid, grid, -1).permute(
|
||||
0, 3, 1, 2),
|
||||
size=(hp, wp),
|
||||
mode='bilinear',
|
||||
align_corners=False).permute(0, 2, 3, 1).reshape(
|
||||
1, hp * wp, -1)
|
||||
],
|
||||
dim=1) # noqa
|
||||
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
|
||||
x = torch.cat([self.cls_embedding.repeat(b, 1, 1), x], dim=1)
|
||||
x = x + pos_embedding
|
||||
|
||||
# stage1
|
||||
x = self.blocks[self.slices[0]](x)
|
||||
x1 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
|
||||
x1 = self.fc1(x1).permute(0, 2, 1).unflatten(2, (hp, wp))
|
||||
x1 = self.conv1(x1)
|
||||
|
||||
# stage2
|
||||
x = self.blocks[self.slices[1]](x)
|
||||
x2 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
|
||||
x2 = self.fc2(x2).permute(0, 2, 1).unflatten(2, (hp, wp))
|
||||
x2 = self.conv2(x2)
|
||||
|
||||
# stage3
|
||||
x = self.blocks[self.slices[2]](x)
|
||||
x3 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
|
||||
x3 = self.fc3(x3).permute(0, 2, 1).unflatten(2, (hp, wp))
|
||||
x3 = self.conv3(x3)
|
||||
|
||||
# stage4
|
||||
x = self.blocks[self.slices[3]](x)
|
||||
x4 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
|
||||
x4 = self.fc4(x4).permute(0, 2, 1).unflatten(2, (hp, wp))
|
||||
x4 = self.conv4(x4)
|
||||
|
||||
# fusion
|
||||
x4 = self.fusion4(x4)
|
||||
x3 = self.fusion3(x4, x3)
|
||||
x2 = self.fusion2(x3, x2)
|
||||
x1 = self.fusion1(x2, x1)
|
||||
|
||||
# head
|
||||
x = self.head(x1)
|
||||
return x
|
||||
|
||||
|
||||
def midas_v3(model_dir, pretrained=False, **kwargs):
|
||||
cfg = dict(
|
||||
image_size=384,
|
||||
patch_size=16,
|
||||
dim=1024,
|
||||
neck_dims=[256, 512, 1024, 1024],
|
||||
fusion_dim=256,
|
||||
num_heads=16,
|
||||
num_layers=24)
|
||||
cfg.update(**kwargs)
|
||||
model = MiDaS(**cfg)
|
||||
if pretrained:
|
||||
model.load_state_dict(
|
||||
torch.load(
|
||||
os.path.join(model_dir, 'midas_v3_dpt_large.pth'),
|
||||
map_location='cpu'))
|
||||
return model
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .degration import *
|
||||
from .distributed import *
|
||||
from .losses import *
|
||||
from .random_mask import *
|
||||
from .utils import *
|
||||
998
modelscope/models/multi_modal/videocomposer/ops/degration.py
Normal file
998
modelscope/models/multi_modal/videocomposer/ops/degration.py
Normal file
@@ -0,0 +1,998 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
import scipy.stats as stats
|
||||
import torch
|
||||
from scipy import ndimage
|
||||
from scipy.interpolate import interp2d
|
||||
from scipy.linalg import orth
|
||||
from torchvision.utils import make_grid
|
||||
|
||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
|
||||
|
||||
__all__ = ['degradation_bsrgan_light', 'degradation_bsrgan']
|
||||
|
||||
|
||||
# get uint8 image of size HxWxn_channles (RGB)
|
||||
def imread_uint(path, n_channels=3):
|
||||
# input: path
|
||||
# output: HxWx3(RGB or GGG), or HxWx1 (G)
|
||||
if n_channels == 1:
|
||||
img = cv2.imread(path, 0)
|
||||
img = np.expand_dims(img, axis=2)
|
||||
elif n_channels == 3:
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
if img.ndim == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
||||
else:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
def uint2single(img):
|
||||
return np.float32(img / 255.)
|
||||
|
||||
|
||||
def single2uint(img):
|
||||
return np.uint8((img.clip(0, 1) * 255.).round())
|
||||
|
||||
|
||||
def uint162single(img):
|
||||
return np.float32(img / 65535.)
|
||||
|
||||
|
||||
def single2uint16(img):
|
||||
return np.uint16((img.clip(0, 1) * 65535.).round())
|
||||
|
||||
|
||||
def rgb2ycbcr(img, only_y=True):
|
||||
'''same as matlab rgb2ycbcr
|
||||
only_y: only return Y channel
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
'''
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.
|
||||
# convert
|
||||
if only_y:
|
||||
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
||||
else:
|
||||
rlt = np.matmul(img,
|
||||
[[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
|
||||
[24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def ycbcr2rgb(img):
|
||||
'''same as matlab ycbcr2rgb
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
'''
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.
|
||||
# convert
|
||||
rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
|
||||
[0, -0.00153632, 0.00791071],
|
||||
[0.00625893, -0.00318811, 0]]) * 255.0 + [
|
||||
-222.921, 135.576, -276.836
|
||||
] # noqa
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def bgr2ycbcr(img, only_y=True):
|
||||
'''bgr version of rgb2ycbcr
|
||||
only_y: only return Y channel
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
'''
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.
|
||||
# convert
|
||||
if only_y:
|
||||
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
||||
else:
|
||||
rlt = np.matmul(img,
|
||||
[[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
|
||||
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def channel_convert(in_c, tar_type, img_list):
|
||||
# conversion among BGR, gray and y
|
||||
if in_c == 3 and tar_type == 'gray':
|
||||
gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
|
||||
return [np.expand_dims(img, axis=2) for img in gray_list]
|
||||
elif in_c == 3 and tar_type == 'y':
|
||||
y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
|
||||
return [np.expand_dims(img, axis=2) for img in y_list]
|
||||
elif in_c == 1 and tar_type == 'RGB':
|
||||
return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
|
||||
else:
|
||||
return img_list
|
||||
|
||||
|
||||
# PSNR
|
||||
def calculate_psnr(img1, img2, border=0):
|
||||
# img1 and img2 have range [0, 255]
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError('Input images must have the same dimensions.')
|
||||
h, w = img1.shape[:2]
|
||||
img1 = img1[border:h - border, border:w - border]
|
||||
img2 = img2[border:h - border, border:w - border]
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
mse = np.mean((img1 - img2)**2)
|
||||
if mse == 0:
|
||||
return float('inf')
|
||||
return 20 * math.log10(255.0 / math.sqrt(mse))
|
||||
|
||||
|
||||
# SSIM
|
||||
def calculate_ssim(img1, img2, border=0):
|
||||
'''calculate SSIM
|
||||
the same outputs as MATLAB's
|
||||
img1, img2: [0, 255]
|
||||
'''
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError('Input images must have the same dimensions.')
|
||||
h, w = img1.shape[:2]
|
||||
img1 = img1[border:h - border, border:w - border]
|
||||
img2 = img2[border:h - border, border:w - border]
|
||||
|
||||
if img1.ndim == 2:
|
||||
return ssim(img1, img2)
|
||||
elif img1.ndim == 3:
|
||||
if img1.shape[2] == 3:
|
||||
ssims = []
|
||||
for i in range(3):
|
||||
ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
|
||||
return np.array(ssims).mean()
|
||||
elif img1.shape[2] == 1:
|
||||
return ssim(np.squeeze(img1), np.squeeze(img2))
|
||||
else:
|
||||
raise ValueError('Wrong input image dimensions.')
|
||||
|
||||
|
||||
def ssim(img1, img2):
|
||||
C1 = (0.01 * 255)**2
|
||||
C2 = (0.03 * 255)**2
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
kernel = cv2.getGaussianKernel(11, 1.5)
|
||||
window = np.outer(kernel, kernel.transpose())
|
||||
|
||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
|
||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||
mu1_sq = mu1**2
|
||||
mu2_sq = mu2**2
|
||||
mu1_mu2 = mu1 * mu2
|
||||
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) # noqa
|
||||
* (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) # noqa
|
||||
* # noqa
|
||||
(sigma1_sq + sigma2_sq + C2)) # noqa
|
||||
return ssim_map.mean()
|
||||
|
||||
|
||||
# matlab 'imresize' function, now only support 'bicubic'
|
||||
def cubic(x):
|
||||
absx = torch.abs(x)
|
||||
absx2 = absx**2
|
||||
absx3 = absx**3
|
||||
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + \
|
||||
(-0.5 * absx3 + 2.5 * absx2 - 4*absx + 2) * (((absx > 1) * (absx <= 2)).type_as(absx)) # noqa
|
||||
|
||||
|
||||
def calculate_weights_indices(in_length, out_length, scale, kernel,
|
||||
kernel_width, antialiasing):
|
||||
if (scale < 1) and (antialiasing):
|
||||
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
|
||||
kernel_width = kernel_width / scale
|
||||
|
||||
# Output-space coordinates
|
||||
x = torch.linspace(1, out_length, out_length)
|
||||
|
||||
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
||||
# in output space maps to 0.5 in input space, and 0.5+scale in output
|
||||
# space maps to 1.5 in input space.
|
||||
u = x / scale + 0.5 * (1 - 1 / scale)
|
||||
|
||||
# What is the left-most pixel that can be involved in the computation?
|
||||
left = torch.floor(u - kernel_width / 2)
|
||||
|
||||
# What is the maximum number of pixels that can be involved in the
|
||||
# computation? Note: it's OK to use an extra pixel here; if the
|
||||
# corresponding weights are all zero, it will be eliminated at the end
|
||||
# of this function.
|
||||
P = math.ceil(kernel_width) + 2
|
||||
|
||||
# The indices of the input pixels involved in computing the k-th output
|
||||
# pixel are in row k of the indices matrix.
|
||||
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
|
||||
0, P - 1, P).view(1, P).expand(out_length, P)
|
||||
|
||||
# The weights used to compute the k-th output pixel are in row k of the
|
||||
# weights matrix.
|
||||
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
|
||||
# apply cubic kernel
|
||||
if (scale < 1) and (antialiasing):
|
||||
weights = scale * cubic(distance_to_center * scale)
|
||||
else:
|
||||
weights = cubic(distance_to_center)
|
||||
# Normalize the weights matrix so that each row sums to 1.
|
||||
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
||||
weights = weights / weights_sum.expand(out_length, P)
|
||||
|
||||
# If a column in weights is all zero, get rid of it. only consider the first and last column.
|
||||
weights_zero_tmp = torch.sum((weights == 0), 0)
|
||||
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
||||
indices = indices.narrow(1, 1, P - 2)
|
||||
weights = weights.narrow(1, 1, P - 2)
|
||||
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
||||
indices = indices.narrow(1, 0, P - 2)
|
||||
weights = weights.narrow(1, 0, P - 2)
|
||||
weights = weights.contiguous()
|
||||
indices = indices.contiguous()
|
||||
sym_len_s = -indices.min() + 1
|
||||
sym_len_e = indices.max() - in_length
|
||||
indices = indices + sym_len_s - 1
|
||||
return weights, indices, int(sym_len_s), int(sym_len_e)
|
||||
|
||||
|
||||
# imresize for tensor image [0, 1]
|
||||
def imresize(img, scale, antialiasing=True):
|
||||
# Now the scale should be the same for H and W
|
||||
# input: img: pytorch tensor, CHW or HW [0,1]
|
||||
# output: CHW or HW [0,1] w/o round
|
||||
need_squeeze = True if img.dim() == 2 else False
|
||||
if need_squeeze:
|
||||
img.unsqueeze_(0)
|
||||
in_C, in_H, in_W = img.size()
|
||||
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W
|
||||
* scale)
|
||||
kernel_width = 4
|
||||
kernel = 'cubic'
|
||||
|
||||
# Return the desired dimension order for performing the resize. The
|
||||
# strategy is to perform the resize first along the dimension with the
|
||||
# smallest scale factor.
|
||||
# Now we do not support this.
|
||||
|
||||
# get weights and indices
|
||||
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
||||
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
||||
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
||||
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
||||
# process H dimension
|
||||
# symmetric copying
|
||||
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
|
||||
img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
|
||||
|
||||
sym_patch = img[:, :sym_len_Hs, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = img[:, -sym_len_He:, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
||||
|
||||
out_1 = torch.FloatTensor(in_C, out_H, in_W)
|
||||
kernel_width = weights_H.size(1)
|
||||
for i in range(out_H):
|
||||
idx = int(indices_H[i][0])
|
||||
for j in range(out_C):
|
||||
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(
|
||||
0, 1).mv(weights_H[i])
|
||||
|
||||
# process W dimension
|
||||
# symmetric copying
|
||||
out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
|
||||
out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
|
||||
|
||||
sym_patch = out_1[:, :, :sym_len_Ws]
|
||||
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
||||
out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = out_1[:, :, -sym_len_We:]
|
||||
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
||||
out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
||||
|
||||
out_2 = torch.FloatTensor(in_C, out_H, out_W)
|
||||
kernel_width = weights_W.size(1)
|
||||
for i in range(out_W):
|
||||
idx = int(indices_W[i][0])
|
||||
for j in range(out_C):
|
||||
out_2[j, :, i] = out_1_aug[j, :,
|
||||
idx:idx + kernel_width].mv(weights_W[i])
|
||||
if need_squeeze:
|
||||
out_2.squeeze_()
|
||||
return out_2
|
||||
|
||||
|
||||
# imresize for numpy image [0, 1]
|
||||
def imresize_np(img, scale, antialiasing=True):
|
||||
# Now the scale should be the same for H and W
|
||||
# input: img: Numpy, HWC or HW [0,1]
|
||||
# output: HWC or HW [0,1] w/o round
|
||||
img = torch.from_numpy(img)
|
||||
need_squeeze = True if img.dim() == 2 else False
|
||||
if need_squeeze:
|
||||
img.unsqueeze_(2)
|
||||
|
||||
in_H, in_W, in_C = img.size()
|
||||
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W
|
||||
* scale)
|
||||
kernel_width = 4
|
||||
kernel = 'cubic'
|
||||
|
||||
# Return the desired dimension order for performing the resize. The
|
||||
# strategy is to perform the resize first along the dimension with the
|
||||
# smallest scale factor.
|
||||
# Now we do not support this.
|
||||
|
||||
# get weights and indices
|
||||
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
||||
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
||||
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
||||
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
||||
# process H dimension
|
||||
# symmetric copying
|
||||
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
|
||||
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
|
||||
|
||||
sym_patch = img[:sym_len_Hs, :, :]
|
||||
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
||||
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = img[-sym_len_He:, :, :]
|
||||
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
||||
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
||||
|
||||
out_1 = torch.FloatTensor(out_H, in_W, in_C)
|
||||
kernel_width = weights_H.size(1)
|
||||
for i in range(out_H):
|
||||
idx = int(indices_H[i][0])
|
||||
for j in range(out_C):
|
||||
out_1[i, :, j] = img_aug[idx:idx + kernel_width, :,
|
||||
j].transpose(0, 1).mv(weights_H[i])
|
||||
|
||||
# process W dimension
|
||||
# symmetric copying
|
||||
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
|
||||
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
|
||||
|
||||
sym_patch = out_1[:, :sym_len_Ws, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = out_1[:, -sym_len_We:, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
||||
|
||||
out_2 = torch.FloatTensor(out_H, out_W, in_C)
|
||||
kernel_width = weights_W.size(1)
|
||||
for i in range(out_W):
|
||||
idx = int(indices_W[i][0])
|
||||
for j in range(out_C):
|
||||
out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width,
|
||||
j].mv(weights_W[i])
|
||||
if need_squeeze:
|
||||
out_2.squeeze_()
|
||||
|
||||
return out_2.numpy()
|
||||
|
||||
|
||||
def modcrop_np(img, sf):
|
||||
'''
|
||||
Args:
|
||||
img: numpy image, WxH or WxHxC
|
||||
sf: scale factor
|
||||
Return:
|
||||
cropped image
|
||||
'''
|
||||
w, h = img.shape[:2]
|
||||
im = np.copy(img)
|
||||
return im[:w - w % sf, :h - h % sf, ...]
|
||||
|
||||
|
||||
def analytic_kernel(k):
|
||||
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
||||
k_size = k.shape[0]
|
||||
# Calculate the big kernels size
|
||||
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
||||
# Loop over the small kernel to fill the big one
|
||||
for r in range(k_size):
|
||||
for c in range(k_size):
|
||||
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
|
||||
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
|
||||
crop = k_size // 2
|
||||
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
||||
# Normalize to 1
|
||||
return cropped_big_k / cropped_big_k.sum()
|
||||
|
||||
|
||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
||||
""" generate an anisotropic Gaussian kernel
|
||||
Args:
|
||||
ksize : e.g., 15, kernel size
|
||||
theta : [0, pi], rotation angle range
|
||||
l1 : [0.1,50], scaling of eigenvalues
|
||||
l2 : [0.1,l1], scaling of eigenvalues
|
||||
If l1 = l2, will get an isotropic Gaussian kernel.
|
||||
Returns:
|
||||
k : kernel
|
||||
"""
|
||||
|
||||
v = np.dot(
|
||||
np.array([[np.cos(theta), -np.sin(theta)],
|
||||
[np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
|
||||
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
||||
D = np.array([[l1, 0], [0, l2]])
|
||||
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
||||
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
||||
return k
|
||||
|
||||
|
||||
def gm_blur_kernel(mean, cov, size=15):
|
||||
center = size / 2.0 + 0.5
|
||||
k = np.zeros([size, size])
|
||||
for y in range(size):
|
||||
for x in range(size):
|
||||
cy = y - center + 1
|
||||
cx = x - center + 1
|
||||
k[y, x] = stats.multivariate_normal.pdf([cx, cy],
|
||||
mean=mean,
|
||||
cov=cov)
|
||||
|
||||
k = k / np.sum(k)
|
||||
return k
|
||||
|
||||
|
||||
def shift_pixel(x, sf, upper_left=True):
|
||||
"""shift pixel for super-resolution with different scale factors
|
||||
Args:
|
||||
x: WxHxC or WxH
|
||||
sf: scale factor
|
||||
upper_left: shift direction
|
||||
"""
|
||||
h, w = x.shape[:2]
|
||||
shift = (sf - 1) * 0.5
|
||||
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
||||
if upper_left:
|
||||
x1 = xv + shift
|
||||
y1 = yv + shift
|
||||
else:
|
||||
x1 = xv - shift
|
||||
y1 = yv - shift
|
||||
|
||||
x1 = np.clip(x1, 0, w - 1)
|
||||
y1 = np.clip(y1, 0, h - 1)
|
||||
|
||||
if x.ndim == 2:
|
||||
x = interp2d(xv, yv, x)(x1, y1)
|
||||
if x.ndim == 3:
|
||||
for i in range(x.shape[-1]):
|
||||
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def blur(x, k):
|
||||
'''
|
||||
x: image, NxcxHxW
|
||||
k: kernel, Nx1xhxw
|
||||
'''
|
||||
n, c = x.shape[:2]
|
||||
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
|
||||
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
|
||||
k = k.repeat(1, c, 1, 1)
|
||||
k = k.view(-1, 1, k.shape[2], k.shape[3])
|
||||
x = x.view(1, -1, x.shape[2], x.shape[3])
|
||||
x = torch.nn.functional.conv2d(
|
||||
x, k, bias=None, stride=1, padding=0, groups=n * c)
|
||||
x = x.view(n, c, x.shape[2], x.shape[3])
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def gen_kernel(
|
||||
k_size=np.array([15, 15]),
|
||||
scale_factor=np.array([4, 4]),
|
||||
min_var=0.6,
|
||||
max_var=10.,
|
||||
noise_level=0):
|
||||
""""
|
||||
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
||||
# Kai Zhang
|
||||
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
||||
# max_var = 2.5 * sf
|
||||
"""
|
||||
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
||||
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
||||
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
||||
theta = np.random.rand() * np.pi # random theta
|
||||
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
||||
|
||||
# Set COV matrix using Lambdas and Theta
|
||||
LAMBDA = np.diag([lambda_1, lambda_2])
|
||||
Q = np.array([[np.cos(theta), -np.sin(theta)],
|
||||
[np.sin(theta), np.cos(theta)]])
|
||||
SIGMA = Q @ LAMBDA @ Q.T
|
||||
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
||||
|
||||
# Set expectation position (shifting kernel for aligned image)
|
||||
MU = k_size // 2 - 0.5 * (scale_factor - 1)
|
||||
MU = MU[None, None, :, None]
|
||||
|
||||
# Create meshgrid for Gaussian
|
||||
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
||||
Z = np.stack([X, Y], 2)[:, :, :, None]
|
||||
|
||||
# Calcualte Gaussian for every pixel of the kernel
|
||||
ZZ = Z - MU
|
||||
ZZ_t = ZZ.transpose(0, 1, 3, 2)
|
||||
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
||||
|
||||
# shift the kernel so it will be centered
|
||||
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
||||
|
||||
# Normalize the kernel and return
|
||||
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
||||
kernel = raw_kernel / np.sum(raw_kernel)
|
||||
return kernel
|
||||
|
||||
|
||||
def fspecial_gaussian(hsize, sigma):
|
||||
hsize = [hsize, hsize]
|
||||
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
|
||||
std = sigma
|
||||
[x, y] = np.meshgrid(
|
||||
np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
|
||||
arg = -(x * x + y * y) / (2 * std * std)
|
||||
h = np.exp(arg)
|
||||
h[h < scipy.finfo(float).eps * h.max()] = 0
|
||||
sumh = h.sum()
|
||||
if sumh != 0:
|
||||
h = h / sumh
|
||||
return h
|
||||
|
||||
|
||||
def fspecial_laplacian(alpha):
|
||||
alpha = max([0, min([alpha, 1])])
|
||||
h1 = alpha / (alpha + 1)
|
||||
h2 = (1 - alpha) / (alpha + 1)
|
||||
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
|
||||
h = np.array(h)
|
||||
return h
|
||||
|
||||
|
||||
def fspecial(filter_type, *args, **kwargs):
|
||||
if filter_type == 'gaussian':
|
||||
return fspecial_gaussian(*args, **kwargs)
|
||||
if filter_type == 'laplacian':
|
||||
return fspecial_laplacian(*args, **kwargs)
|
||||
|
||||
|
||||
def bicubic_degradation(x, sf=3):
|
||||
'''
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
bicubicly downsampled LR image
|
||||
'''
|
||||
x = imresize_np(x, scale=1 / sf)
|
||||
return x
|
||||
|
||||
|
||||
def srmd_degradation(x, k, sf=3):
|
||||
''' blur + bicubic downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2018learning,
|
||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={3262--3271},
|
||||
year={2018}
|
||||
}
|
||||
'''
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
return x
|
||||
|
||||
|
||||
def dpsr_degradation(x, k, sf=3):
|
||||
''' bicubic downsampling + blur
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2019deep,
|
||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={1671--1681},
|
||||
year={2019}
|
||||
}
|
||||
'''
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
return x
|
||||
|
||||
|
||||
def classical_degradation(x, k, sf=3):
|
||||
''' blur + downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]/[0, 255]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
'''
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
st = 0
|
||||
return x[st::sf, st::sf, ...]
|
||||
|
||||
|
||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
|
||||
"""USM sharpening. borrowed from real-ESRGAN
|
||||
Input image: I; Blurry image: B.
|
||||
1. K = I + weight * (I - B)
|
||||
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
||||
3. Blur mask:
|
||||
4. Out = Mask * K + (1 - Mask) * I
|
||||
Args:
|
||||
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
||||
weight (float): Sharp weight. Default: 1.
|
||||
radius (float): Kernel size of Gaussian blur. Default: 50.
|
||||
threshold (int):
|
||||
"""
|
||||
if radius % 2 == 0:
|
||||
radius += 1
|
||||
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
||||
residual = img - blur
|
||||
mask = np.abs(residual) * 255 > threshold
|
||||
mask = mask.astype('float32')
|
||||
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
||||
|
||||
K = img + weight * residual
|
||||
K = np.clip(K, 0, 1)
|
||||
return soft_mask * K + (1 - soft_mask) * img
|
||||
|
||||
|
||||
def add_blur_1(img, sf=4):
|
||||
wd2 = 4.0 + sf
|
||||
wd = 2.0 + 0.2 * sf
|
||||
|
||||
wd2 = wd2 / 4
|
||||
wd = wd / 4
|
||||
|
||||
if random.random() < 0.5:
|
||||
l1 = wd2 * random.random()
|
||||
l2 = wd2 * random.random()
|
||||
k = anisotropic_Gaussian(
|
||||
ksize=random.randint(2, 11) + 3,
|
||||
theta=random.random() * np.pi,
|
||||
l1=l1,
|
||||
l2=l2)
|
||||
else:
|
||||
k = fspecial('gaussian',
|
||||
random.randint(2, 4) + 3, wd * random.random())
|
||||
img = ndimage.filters.convolve(
|
||||
img, np.expand_dims(k, axis=2), mode='mirror')
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def add_resize(img, sf=4):
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.8:
|
||||
sf1 = random.uniform(1, 2)
|
||||
elif rnum < 0.7:
|
||||
sf1 = random.uniform(0.5 / sf, 1)
|
||||
else:
|
||||
sf1 = 1.0
|
||||
img = cv2.resize(
|
||||
img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.6:
|
||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
|
||||
np.float32)
|
||||
elif rnum < 0.4:
|
||||
img = img + np.random.normal(0, noise_level / 255.0,
|
||||
(*img.shape[:2], 1)).astype(np.float32)
|
||||
else:
|
||||
L = noise_level2 / 255.
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(
|
||||
L**2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
rnum = random.random()
|
||||
if rnum > 0.6:
|
||||
img += img * np.random.normal(0, noise_level / 255.0,
|
||||
img.shape).astype(np.float32)
|
||||
elif rnum < 0.4:
|
||||
img += img * np.random.normal(0, noise_level / 255.0,
|
||||
(*img.shape[:2], 1)).astype(np.float32)
|
||||
else:
|
||||
L = noise_level2 / 255.
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img += img * np.random.multivariate_normal(
|
||||
[0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_Poisson_noise(img):
|
||||
img = np.clip((img * 255.0).round(), 0, 255) / 255.
|
||||
vals = 10**(2 * random.random() + 2.0)
|
||||
if random.random() < 0.5:
|
||||
img = np.random.poisson(img * vals).astype(np.float32) / vals
|
||||
else:
|
||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
|
||||
noise_gray = np.random.poisson(img_gray * vals).astype(
|
||||
np.float32) / vals - img_gray
|
||||
img += noise_gray[:, :, np.newaxis]
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_JPEG_noise(img):
|
||||
quality_factor = random.randint(80, 95)
|
||||
img = cv2.cvtColor(single2uint(img), cv2.COLOR_RGB2BGR)
|
||||
result, encimg = cv2.imencode(
|
||||
'.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
||||
img = cv2.imdecode(encimg, 1)
|
||||
img = cv2.cvtColor(uint2single(img), cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
|
||||
h, w = lq.shape[:2]
|
||||
rnd_h = random.randint(0, h - lq_patchsize)
|
||||
rnd_w = random.randint(0, w - lq_patchsize)
|
||||
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
|
||||
|
||||
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
|
||||
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf,
|
||||
rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
|
||||
return lq, hq
|
||||
|
||||
|
||||
def degradation_bsrgan_light(image, sf=4, isp_model=None):
|
||||
"""
|
||||
This is the variant of the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
image = uint2single(image)
|
||||
_, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
|
||||
h1, w1 = image.shape[:2]
|
||||
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]
|
||||
h, w = image.shape[:2]
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob:
|
||||
if np.random.rand() < 0.5:
|
||||
image = cv2.resize(
|
||||
image,
|
||||
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
image = imresize_np(image, 1 / 2, True)
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2:
|
||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[
|
||||
idx2], shuffle_order[idx1]
|
||||
|
||||
for i in shuffle_order:
|
||||
|
||||
if i == 0:
|
||||
image = add_blur_1(image, sf=sf)
|
||||
elif i == 2:
|
||||
a, b = image.shape[1], image.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.8:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
image = cv2.resize(
|
||||
image, (int(1 / sf1 * image.shape[1]),
|
||||
int(1 / sf1 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum()
|
||||
image = ndimage.filters.convolve(
|
||||
image, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||
image = image[0::sf, 0::sf, ...]
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
image = cv2.resize(
|
||||
image, (int(1 / sf * a), int(1 / sf * b)),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
image = add_JPEG_noise(image)
|
||||
|
||||
# add final JPEG compression noise
|
||||
image = add_JPEG_noise(image)
|
||||
image = single2uint(image)
|
||||
return image
|
||||
|
||||
|
||||
def add_blur_2(img, sf=4):
|
||||
wd2 = 4.0 + sf
|
||||
wd = 2.0 + 0.2 * sf
|
||||
if random.random() < 0.5:
|
||||
l1 = wd2 * random.random()
|
||||
l2 = wd2 * random.random()
|
||||
k = anisotropic_Gaussian(
|
||||
ksize=2 * random.randint(2, 11) + 3,
|
||||
theta=random.random() * np.pi,
|
||||
l1=l1,
|
||||
l2=l2)
|
||||
else:
|
||||
k = fspecial('gaussian', 2 * random.randint(2, 11) + 3,
|
||||
wd * random.random())
|
||||
img = ndimage.filters.convolve(
|
||||
img, np.expand_dims(k, axis=2), mode='mirror')
|
||||
return img
|
||||
|
||||
|
||||
def degradation_bsrgan(image, sf=4, isp_model=None):
|
||||
"""
|
||||
This is the variant of the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
image = uint2single(image)
|
||||
_, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
|
||||
h1, w1 = image.shape[:2]
|
||||
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]
|
||||
h, w = image.shape[:2]
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob:
|
||||
if np.random.rand() < 0.5:
|
||||
image = cv2.resize(
|
||||
image,
|
||||
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
image = imresize_np(image, 1 / 2, True)
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2:
|
||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[
|
||||
idx2], shuffle_order[idx1]
|
||||
|
||||
for i in shuffle_order:
|
||||
|
||||
if i == 0:
|
||||
image = add_blur_2(image, sf=sf)
|
||||
elif i == 1:
|
||||
image = add_blur_2(image, sf=sf)
|
||||
elif i == 2:
|
||||
a, b = image.shape[1], image.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.75:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
image = cv2.resize(
|
||||
image, (int(1 / sf1 * image.shape[1]),
|
||||
int(1 / sf1 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum()
|
||||
image = ndimage.filters.convolve(
|
||||
image, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||
image = image[0::sf, 0::sf, ...]
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
image = cv2.resize(
|
||||
image, (int(1 / sf * a), int(1 / sf * b)),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
image = add_JPEG_noise(image)
|
||||
|
||||
# add final JPEG compression noise
|
||||
image = add_JPEG_noise(image)
|
||||
image = single2uint(image)
|
||||
return image
|
||||
460
modelscope/models/multi_modal/videocomposer/ops/distributed.py
Normal file
460
modelscope/models/multi_modal/videocomposer/ops/distributed.py
Normal file
@@ -0,0 +1,460 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import functools
|
||||
import pickle
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
__all__ = [
|
||||
'is_dist_initialized', 'get_world_size', 'get_rank', 'new_group',
|
||||
'destroy_process_group', 'barrier', 'broadcast', 'all_reduce', 'reduce',
|
||||
'gather', 'all_gather', 'reduce_dict', 'get_global_gloo_group',
|
||||
'generalized_all_gather', 'generalized_gather', 'scatter',
|
||||
'reduce_scatter', 'send', 'recv', 'isend', 'irecv', 'shared_random_seed',
|
||||
'diff_all_gather', 'diff_all_reduce', 'diff_scatter', 'diff_copy',
|
||||
'spherical_kmeans', 'sinkhorn'
|
||||
]
|
||||
|
||||
|
||||
def is_dist_initialized():
|
||||
return dist.is_available() and dist.is_initialized()
|
||||
|
||||
|
||||
def get_world_size(group=None):
|
||||
return dist.get_world_size(group) if is_dist_initialized() else 1
|
||||
|
||||
|
||||
def get_rank(group=None):
|
||||
return dist.get_rank(group) if is_dist_initialized() else 0
|
||||
|
||||
|
||||
def new_group(ranks=None, **kwargs):
|
||||
if is_dist_initialized():
|
||||
return dist.new_group(ranks, **kwargs)
|
||||
return None
|
||||
|
||||
|
||||
def destroy_process_group():
|
||||
if is_dist_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def barrier(group=None, **kwargs):
|
||||
if get_world_size(group) > 1:
|
||||
dist.barrier(group, **kwargs)
|
||||
|
||||
|
||||
def broadcast(tensor, src, group=None, **kwargs):
|
||||
if get_world_size(group) > 1:
|
||||
return dist.broadcast(tensor, src, group, **kwargs)
|
||||
|
||||
|
||||
def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs):
|
||||
if get_world_size(group) > 1:
|
||||
return dist.all_reduce(tensor, op, group, **kwargs)
|
||||
|
||||
|
||||
def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs):
|
||||
if get_world_size(group) > 1:
|
||||
return dist.reduce(tensor, dst, op, group, **kwargs)
|
||||
|
||||
|
||||
def gather(tensor, dst=0, group=None, **kwargs):
|
||||
rank = get_rank()
|
||||
world_size = get_world_size(group)
|
||||
if world_size == 1:
|
||||
return [tensor]
|
||||
tensor_list = [torch.empty_like(tensor)
|
||||
for _ in range(world_size)] if rank == dst else None
|
||||
dist.gather(tensor, tensor_list, dst, group, **kwargs)
|
||||
return tensor_list
|
||||
|
||||
|
||||
def all_gather(tensor, uniform_size=True, group=None, **kwargs):
|
||||
world_size = get_world_size(group)
|
||||
if world_size == 1:
|
||||
return [tensor]
|
||||
assert tensor.is_contiguous(
|
||||
), 'ops.all_gather requires the tensor to be contiguous()'
|
||||
|
||||
if uniform_size:
|
||||
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
||||
dist.all_gather(tensor_list, tensor, group, **kwargs)
|
||||
return tensor_list
|
||||
else:
|
||||
# collect tensor shapes across GPUs
|
||||
shape = tuple(tensor.shape)
|
||||
shape_list = generalized_all_gather(shape, group)
|
||||
|
||||
# flatten the tensor
|
||||
tensor = tensor.reshape(-1)
|
||||
size = int(np.prod(shape))
|
||||
size_list = [int(np.prod(u)) for u in shape_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# pad to maximum size
|
||||
if size != max_size:
|
||||
padding = tensor.new_zeros(max_size - size)
|
||||
tensor = torch.cat([tensor, padding], dim=0)
|
||||
|
||||
# all_gather
|
||||
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
||||
dist.all_gather(tensor_list, tensor, group, **kwargs)
|
||||
|
||||
# reshape tensors
|
||||
tensor_list = [
|
||||
t[:n].view(s)
|
||||
for t, n, s in zip(tensor_list, size_list, shape_list)
|
||||
]
|
||||
return tensor_list
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def reduce_dict(input_dict, group=None, reduction='mean', **kwargs):
|
||||
assert reduction in ['mean', 'sum']
|
||||
world_size = get_world_size(group)
|
||||
if world_size == 1:
|
||||
return input_dict
|
||||
|
||||
# ensure that the orders of keys are consistent across processes
|
||||
if isinstance(input_dict, OrderedDict):
|
||||
keys = list(input_dict.keys)
|
||||
else:
|
||||
keys = sorted(input_dict.keys())
|
||||
vals = [input_dict[key] for key in keys]
|
||||
vals = torch.stack(vals, dim=0)
|
||||
dist.reduce(vals, dst=0, group=group, **kwargs)
|
||||
if dist.get_rank(group) == 0 and reduction == 'mean':
|
||||
vals /= world_size
|
||||
dist.broadcast(vals, src=0, group=group, **kwargs)
|
||||
reduced_dict = type(input_dict)([(key, val)
|
||||
for key, val in zip(keys, vals)])
|
||||
return reduced_dict
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_global_gloo_group():
|
||||
backend = dist.get_backend()
|
||||
assert backend in ['gloo', 'nccl']
|
||||
if backend == 'nccl':
|
||||
return dist.new_group(backend='gloo')
|
||||
else:
|
||||
return dist.group.WORLD
|
||||
|
||||
|
||||
def _serialize_to_tensor(data, group):
|
||||
backend = dist.get_backend(group)
|
||||
assert backend in ['gloo', 'nccl']
|
||||
device = torch.device('cpu' if backend == 'gloo' else 'cuda')
|
||||
|
||||
buffer = pickle.dumps(data)
|
||||
if len(buffer) > 1024**3:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
'Rank {} trying to all-gather {:.2f} GB of data on device'
|
||||
'{}'.format(get_rank(),
|
||||
len(buffer) / (1024**3), device))
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to(device=device)
|
||||
return tensor
|
||||
|
||||
|
||||
def _pad_to_largest_tensor(tensor, group):
|
||||
world_size = dist.get_world_size(group=group)
|
||||
assert world_size >= 1, \
|
||||
'gather/all_gather must be called from ranks within' \
|
||||
'the give group!'
|
||||
local_size = torch.tensor([tensor.numel()],
|
||||
dtype=torch.int64,
|
||||
device=tensor.device)
|
||||
size_list = [
|
||||
torch.zeros([1], dtype=torch.int64, device=tensor.device)
|
||||
for _ in range(world_size)
|
||||
]
|
||||
|
||||
# gather tensors and compute the maximum size
|
||||
dist.all_gather(size_list, local_size, group=group)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# pad tensors to the same size
|
||||
if local_size != max_size:
|
||||
padding = torch.zeros((max_size - local_size, ),
|
||||
dtype=torch.uint8,
|
||||
device=tensor.device)
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
return size_list, tensor
|
||||
|
||||
|
||||
def generalized_all_gather(data, group=None):
|
||||
if get_world_size(group) == 1:
|
||||
return [data]
|
||||
if group is None:
|
||||
group = get_global_gloo_group()
|
||||
|
||||
tensor = _serialize_to_tensor(data, group)
|
||||
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving tensors from all ranks
|
||||
tensor_list = [
|
||||
torch.empty((max_size, ), dtype=torch.uint8, device=tensor.device)
|
||||
for _ in size_list
|
||||
]
|
||||
dist.all_gather(tensor_list, tensor, group=group)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
return data_list
|
||||
|
||||
|
||||
def generalized_gather(data, dst=0, group=None):
|
||||
world_size = get_world_size(group)
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
if group is None:
|
||||
group = get_global_gloo_group()
|
||||
rank = dist.get_rank()
|
||||
|
||||
tensor = _serialize_to_tensor(data, group)
|
||||
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
||||
|
||||
# receiving tensors from all ranks to dst
|
||||
if rank == dst:
|
||||
max_size = max(size_list)
|
||||
tensor_list = [
|
||||
torch.empty((max_size, ), dtype=torch.uint8, device=tensor.device)
|
||||
for _ in size_list
|
||||
]
|
||||
dist.gather(tensor, tensor_list, dst=dst, group=group)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
return data_list
|
||||
else:
|
||||
dist.gather(tensor, [], dst=dst, group=group)
|
||||
return []
|
||||
|
||||
|
||||
def scatter(data, scatter_list=None, src=0, group=None, **kwargs):
|
||||
r"""NOTE: only supports CPU tensor communication.
|
||||
"""
|
||||
if get_world_size(group) > 1:
|
||||
return dist.scatter(data, scatter_list, src, group, **kwargs)
|
||||
|
||||
|
||||
def reduce_scatter(output,
|
||||
input_list,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=None,
|
||||
**kwargs):
|
||||
if get_world_size(group) > 1:
|
||||
return dist.reduce_scatter(output, input_list, op, group, **kwargs)
|
||||
|
||||
|
||||
def send(tensor, dst, group=None, **kwargs):
|
||||
if get_world_size(group) > 1:
|
||||
assert tensor.is_contiguous(
|
||||
), 'ops.send requires the tensor to be contiguous()'
|
||||
return dist.send(tensor, dst, group, **kwargs)
|
||||
|
||||
|
||||
def recv(tensor, src=None, group=None, **kwargs):
|
||||
if get_world_size(group) > 1:
|
||||
assert tensor.is_contiguous(
|
||||
), 'ops.recv requires the tensor to be contiguous()'
|
||||
return dist.recv(tensor, src, group, **kwargs)
|
||||
|
||||
|
||||
def isend(tensor, dst, group=None, **kwargs):
|
||||
if get_world_size(group) > 1:
|
||||
assert tensor.is_contiguous(
|
||||
), 'ops.isend requires the tensor to be contiguous()'
|
||||
return dist.isend(tensor, dst, group, **kwargs)
|
||||
|
||||
|
||||
def irecv(tensor, src=None, group=None, **kwargs):
|
||||
if get_world_size(group) > 1:
|
||||
assert tensor.is_contiguous(
|
||||
), 'ops.irecv requires the tensor to be contiguous()'
|
||||
return dist.irecv(tensor, src, group, **kwargs)
|
||||
|
||||
|
||||
def shared_random_seed(group=None):
|
||||
seed = np.random.randint(2**31)
|
||||
all_seeds = generalized_all_gather(seed, group)
|
||||
return all_seeds[0]
|
||||
|
||||
|
||||
def _all_gather(x):
|
||||
if not (dist.is_available()
|
||||
and dist.is_initialized()) or dist.get_world_size() == 1:
|
||||
return x
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
tensors = [torch.empty_like(x) for _ in range(world_size)]
|
||||
tensors[rank] = x
|
||||
dist.all_gather(tensors, x)
|
||||
return torch.cat(tensors, dim=0).contiguous()
|
||||
|
||||
|
||||
def _all_reduce(x):
|
||||
if not (dist.is_available()
|
||||
and dist.is_initialized()) or dist.get_world_size() == 1:
|
||||
return x
|
||||
dist.all_reduce(x)
|
||||
return x
|
||||
|
||||
|
||||
def _split(x):
|
||||
if not (dist.is_available()
|
||||
and dist.is_initialized()) or dist.get_world_size() == 1:
|
||||
return x
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
return x.chunk(world_size, dim=0)[rank].contiguous()
|
||||
|
||||
|
||||
class DiffAllGather(Function):
|
||||
r"""Differentiable all-gather.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input):
|
||||
return _all_gather(input)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
return _all_gather(input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _split(grad_output)
|
||||
|
||||
|
||||
class DiffAllReduce(Function):
|
||||
r"""Differentiable all-reducd.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input):
|
||||
return _all_reduce(input)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
return _all_reduce(input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
|
||||
|
||||
class DiffScatter(Function):
|
||||
r"""Differentiable scatter.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(ctx, input):
|
||||
return _split(input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _all_gather(grad_output)
|
||||
|
||||
|
||||
class DiffCopy(Function):
|
||||
r"""Differentiable copy that reduces all gradients during backward.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input):
|
||||
return input
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
return input
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _all_reduce(grad_output)
|
||||
|
||||
|
||||
diff_all_gather = DiffAllGather.apply
|
||||
diff_all_reduce = DiffAllReduce.apply
|
||||
diff_scatter = DiffScatter.apply
|
||||
diff_copy = DiffCopy.apply
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def spherical_kmeans(feats, num_clusters, num_iters=10):
|
||||
k, n, c = num_clusters, *feats.size()
|
||||
ones = feats.new_ones(n, dtype=torch.long)
|
||||
|
||||
# distributed settings
|
||||
world_size = get_world_size()
|
||||
|
||||
# init clusters
|
||||
rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))]
|
||||
clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k]
|
||||
|
||||
# variables
|
||||
new_clusters = feats.new_zeros(k, c)
|
||||
counts = feats.new_zeros(k, dtype=torch.long)
|
||||
|
||||
# iterative Expectation-Maximization
|
||||
for step in range(num_iters + 1):
|
||||
# Expectation step
|
||||
simmat = torch.mm(feats, clusters.t())
|
||||
scores, assigns = simmat.max(dim=1)
|
||||
if step == num_iters:
|
||||
break
|
||||
|
||||
# Maximization step
|
||||
new_clusters.zero_().scatter_add_(0,
|
||||
assigns.unsqueeze(1).repeat(1, c),
|
||||
feats)
|
||||
all_reduce(new_clusters)
|
||||
|
||||
counts.zero_()
|
||||
counts.index_add_(0, assigns, ones)
|
||||
all_reduce(counts)
|
||||
|
||||
mask = (counts > 0)
|
||||
clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1)
|
||||
clusters = F.normalize(clusters, p=2, dim=1)
|
||||
return clusters, assigns, scores
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sinkhorn(Q, eps=0.5, num_iters=3):
|
||||
# normalize Q
|
||||
Q = torch.exp(Q / eps).t()
|
||||
sum_Q = Q.sum()
|
||||
all_reduce(sum_Q)
|
||||
Q /= sum_Q
|
||||
|
||||
# variables
|
||||
n, m = Q.size()
|
||||
u = Q.new_zeros(n)
|
||||
r = Q.new_ones(n) / n
|
||||
c = Q.new_ones(m) / (m * get_world_size())
|
||||
|
||||
# iterative update
|
||||
cur_sum = Q.sum(dim=1)
|
||||
all_reduce(cur_sum)
|
||||
for i in range(num_iters):
|
||||
u = cur_sum
|
||||
Q *= (r / u).unsqueeze(1)
|
||||
Q *= (c / Q.sum(dim=0)).unsqueeze(0)
|
||||
cur_sum = Q.sum(dim=1)
|
||||
all_reduce(cur_sum)
|
||||
return (Q / Q.sum(dim=0, keepdim=True)).t().float()
|
||||
37
modelscope/models/multi_modal/videocomposer/ops/losses.py
Normal file
37
modelscope/models/multi_modal/videocomposer/ops/losses.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood']
|
||||
|
||||
|
||||
def kl_divergence(mu1, logvar1, mu2, logvar2):
|
||||
return 0.5 * (
|
||||
-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + # noqa
|
||||
((mu1 - mu2)**2) * torch.exp(-logvar2))
|
||||
|
||||
|
||||
def standard_normal_cdf(x):
|
||||
r"""A fast approximation of the cumulative distribution function of the standard normal.
|
||||
"""
|
||||
return 0.5 * (1.0 + torch.tanh(
|
||||
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
def discretized_gaussian_log_likelihood(x0, mean, log_scale):
|
||||
assert x0.shape == mean.shape == log_scale.shape
|
||||
cx = x0 - mean
|
||||
inv_stdv = torch.exp(-log_scale)
|
||||
cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
|
||||
cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
|
||||
log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
|
||||
log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
log_probs = torch.where(
|
||||
x0 < -0.999, log_cdf_plus,
|
||||
torch.where(x0 > 0.999, log_one_minus_cdf_min,
|
||||
torch.log(cdf_delta.clamp(min=1e-12))))
|
||||
assert log_probs.shape == x0.shape
|
||||
return log_probs
|
||||
@@ -0,0 +1,81 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['make_irregular_mask', 'make_rectangle_mask', 'make_uncrop']
|
||||
|
||||
|
||||
def make_irregular_mask(w,
|
||||
h,
|
||||
max_angle=4,
|
||||
max_length=200,
|
||||
max_width=100,
|
||||
min_strokes=1,
|
||||
max_strokes=5,
|
||||
mode='line'):
|
||||
# initialize mask
|
||||
assert mode in ['line', 'circle', 'square']
|
||||
mask = np.zeros((h, w), np.float32)
|
||||
|
||||
# draw strokes
|
||||
num_strokes = np.random.randint(min_strokes, max_strokes + 1)
|
||||
for i in range(num_strokes):
|
||||
x1 = np.random.randint(w)
|
||||
y1 = np.random.randint(h)
|
||||
for j in range(1 + np.random.randint(5)):
|
||||
angle = 0.01 + np.random.randint(max_angle)
|
||||
if i % 2 == 0:
|
||||
angle = 2 * 3.1415926 - angle
|
||||
length = 10 + np.random.randint(max_length)
|
||||
radius = 5 + np.random.randint(max_width)
|
||||
x2 = np.clip((x1 + length * np.sin(angle)).astype(np.int32), 0, w)
|
||||
y2 = np.clip((y1 + length * np.cos(angle)).astype(np.int32), 0, h)
|
||||
if mode == 'line':
|
||||
cv2.line(mask, (x1, y1), (x2, y2), 1.0, radius)
|
||||
elif mode == 'circle':
|
||||
cv2.circle(
|
||||
mask, (x1, y1), radius=radius, color=1.0, thickness=-1)
|
||||
elif mode == 'square':
|
||||
radius = radius // 2
|
||||
mask[y1 - radius:y1 + radius, x1 - radius:x1 + radius] = 1
|
||||
x1, y1 = x2, y2
|
||||
return mask
|
||||
|
||||
|
||||
def make_rectangle_mask(w,
|
||||
h,
|
||||
margin=10,
|
||||
min_size=30,
|
||||
max_size=150,
|
||||
min_strokes=1,
|
||||
max_strokes=4):
|
||||
# initialize mask
|
||||
mask = np.zeros((h, w), np.float32)
|
||||
|
||||
# draw rectangles
|
||||
num_strokes = np.random.randint(min_strokes, max_strokes + 1)
|
||||
for i in range(num_strokes):
|
||||
box_w = np.random.randint(min_size, max_size)
|
||||
box_h = np.random.randint(min_size, max_size)
|
||||
x1 = np.random.randint(margin, w - margin - box_w + 1)
|
||||
y1 = np.random.randint(margin, h - margin - box_h + 1)
|
||||
mask[y1:y1 + box_h, x1:x1 + box_w] = 1
|
||||
return mask
|
||||
|
||||
|
||||
def make_uncrop(w, h):
|
||||
# initialize mask
|
||||
mask = np.zeros((h, w), np.float32)
|
||||
|
||||
# randomly halve the image
|
||||
side = np.random.choice([0, 1, 2, 3])
|
||||
if side == 0:
|
||||
mask[:h // 2, :] = 1
|
||||
elif side == 1:
|
||||
mask[h // 2:, :] = 1
|
||||
elif side == 2:
|
||||
mask[:, :w // 2] = 1
|
||||
elif side == 3:
|
||||
mask[:, w // 2:] = 1
|
||||
return mask
|
||||
1037
modelscope/models/multi_modal/videocomposer/ops/utils.py
Normal file
1037
modelscope/models/multi_modal/videocomposer/ops/utils.py
Normal file
File diff suppressed because it is too large
Load Diff
2102
modelscope/models/multi_modal/videocomposer/unet_sd.py
Normal file
2102
modelscope/models/multi_modal/videocomposer/unet_sd.py
Normal file
File diff suppressed because it is too large
Load Diff
273
modelscope/models/multi_modal/videocomposer/utils/config.py
Normal file
273
modelscope/models/multi_modal/videocomposer/utils/config.py
Normal file
@@ -0,0 +1,273 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import os
|
||||
|
||||
import json
|
||||
import yaml
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
print('Seed: ', seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
class Config(object):
|
||||
|
||||
def __init__(self, load=True, cfg_dict=None, cfg_level=None):
|
||||
self._level = 'cfg' + ('.'
|
||||
+ cfg_level if cfg_level is not None else '')
|
||||
if load:
|
||||
self.args = self._parse_args()
|
||||
logger.info('Loading config from {}.'.format(self.args.cfg_file))
|
||||
self.need_initialization = True
|
||||
cfg_base = self._initialize_cfg()
|
||||
cfg_dict = self._load_yaml(self.args)
|
||||
cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict)
|
||||
cfg_dict = self._update_from_args(cfg_dict)
|
||||
self.cfg_dict = cfg_dict
|
||||
self._update_dict(cfg_dict)
|
||||
|
||||
def _parse_args(self):
|
||||
parser = argparse.ArgumentParser(
|
||||
description=
|
||||
'Argparser for configuring [code base name to think of] codebase')
|
||||
parser.add_argument(
|
||||
'--cfg',
|
||||
dest='cfg_file',
|
||||
help='Path to the configuration file',
|
||||
default=
|
||||
'./modelscope/models/multi_modal/videocomposer/configs/exp06_text_depths_vs_style.yaml'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--init_method',
|
||||
help='Initialization method, includes TCP or shared file-system',
|
||||
default='tcp://localhost:9999',
|
||||
type=str,
|
||||
)
|
||||
parser.add_argument(
|
||||
'--seed',
|
||||
type=int,
|
||||
default=8888,
|
||||
help='Need to explore for different videos')
|
||||
parser.add_argument(
|
||||
'--debug',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Into debug information')
|
||||
parser.add_argument(
|
||||
'--input_video',
|
||||
default='demo_video/video_8800.mp4',
|
||||
help='input video for full task, or motion vector of input videos',
|
||||
type=str,
|
||||
),
|
||||
parser.add_argument(
|
||||
'--image_path', default='', help='Single Image Input', type=str)
|
||||
parser.add_argument(
|
||||
'--sketch_path', default='', help='Single Sketch Input', type=str)
|
||||
parser.add_argument(
|
||||
'--style_image', help='Single Sketch Input', type=str)
|
||||
parser.add_argument(
|
||||
'--input_text_desc',
|
||||
default=
|
||||
'A colorful and beautiful fish swimming in a small glass bowl with \
|
||||
multicolored piece of stone, Macro Video',
|
||||
type=str,
|
||||
),
|
||||
parser.add_argument(
|
||||
'opts',
|
||||
help='other configurations',
|
||||
default=None,
|
||||
nargs=argparse.REMAINDER)
|
||||
return parser.parse_args()
|
||||
|
||||
def _path_join(self, path_list):
|
||||
path = ''
|
||||
for p in path_list:
|
||||
path += p + '/'
|
||||
return path[:-1]
|
||||
|
||||
def _update_from_args(self, cfg_dict):
|
||||
args = self.args
|
||||
for var in vars(args):
|
||||
cfg_dict[var] = getattr(args, var)
|
||||
return cfg_dict
|
||||
|
||||
def _initialize_cfg(self):
|
||||
if self.need_initialization:
|
||||
self.need_initialization = False
|
||||
if os.path.exists(
|
||||
'./modelscope/models/multi_modal/videocomposer/configs/base.yaml'
|
||||
):
|
||||
with open(
|
||||
'./modelscope/models/multi_modal/videocomposer/configs/base.yaml',
|
||||
'r') as f:
|
||||
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
||||
else:
|
||||
with open(
|
||||
'./modelscope/models/multi_modal/videocomposer/configs/base.yaml',
|
||||
'r') as f:
|
||||
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
||||
return cfg
|
||||
|
||||
def _load_yaml(self, args, file_name=''):
|
||||
assert args.cfg_file is not None
|
||||
if not file_name == '': # reading from base file
|
||||
with open(file_name, 'r') as f:
|
||||
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
||||
else:
|
||||
if os.getcwd().split('/')[-1] == args.cfg_file.split('/')[0]:
|
||||
args.cfg_file = args.cfg_file.replace(
|
||||
os.getcwd().split('/')[-1], './')
|
||||
try:
|
||||
with open(args.cfg_file, 'r') as f:
|
||||
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
||||
file_name = args.cfg_file
|
||||
except Exception as e:
|
||||
args.cfg_file = os.path.realpath(__file__).split(
|
||||
'/')[-3] + '/' + args.cfg_file
|
||||
with open(args.cfg_file, 'r') as f:
|
||||
cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
|
||||
file_name = args.cfg_file
|
||||
print(e)
|
||||
|
||||
if '_BASE_RUN' not in cfg.keys() and '_BASE_MODEL' not in cfg.keys(
|
||||
) and '_BASE' not in cfg.keys():
|
||||
return cfg
|
||||
|
||||
if '_BASE' in cfg.keys():
|
||||
if cfg['_BASE'][1] == '.':
|
||||
prev_count = cfg['_BASE'].count('..')
|
||||
cfg_base_file = self._path_join(
|
||||
file_name.split('/')[:(-1 - cfg['_BASE'].count('..'))]
|
||||
+ cfg['_BASE'].split('/')[prev_count:])
|
||||
else:
|
||||
cfg_base_file = cfg['_BASE'].replace(
|
||||
'./',
|
||||
args.cfg_file.replace(args.cfg_file.split('/')[-1], ''))
|
||||
cfg_base = self._load_yaml(args, cfg_base_file)
|
||||
cfg = self._merge_cfg_from_base(cfg_base, cfg)
|
||||
else:
|
||||
if '_BASE_RUN' in cfg.keys():
|
||||
if cfg['_BASE_RUN'][1] == '.':
|
||||
prev_count = cfg['_BASE_RUN'].count('..')
|
||||
cfg_base_file = self._path_join(
|
||||
file_name.split('/')[:(-1 - prev_count)]
|
||||
+ cfg['_BASE_RUN'].split('/')[prev_count:])
|
||||
else:
|
||||
cfg_base_file = cfg['_BASE_RUN'].replace(
|
||||
'./',
|
||||
args.cfg_file.replace(
|
||||
args.cfg_file.split('/')[-1], ''))
|
||||
cfg_base = self._load_yaml(args, cfg_base_file)
|
||||
cfg = self._merge_cfg_from_base(
|
||||
cfg_base, cfg, preserve_base=True)
|
||||
if '_BASE_MODEL' in cfg.keys():
|
||||
if cfg['_BASE_MODEL'][1] == '.':
|
||||
prev_count = cfg['_BASE_MODEL'].count('..')
|
||||
cfg_base_file = self._path_join(
|
||||
file_name.split('/')[:(
|
||||
-1 - cfg['_BASE_MODEL'].count('..'))]
|
||||
+ cfg['_BASE_MODEL'].split('/')[prev_count:])
|
||||
else:
|
||||
cfg_base_file = cfg['_BASE_MODEL'].replace(
|
||||
'./',
|
||||
args.cfg_file.replace(
|
||||
args.cfg_file.split('/')[-1], ''))
|
||||
cfg_base = self._load_yaml(args, cfg_base_file)
|
||||
cfg = self._merge_cfg_from_base(cfg_base, cfg)
|
||||
cfg = self._merge_cfg_from_command(args, cfg)
|
||||
return cfg
|
||||
|
||||
def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False):
|
||||
for k, v in cfg_new.items():
|
||||
if k in cfg_base.keys():
|
||||
if isinstance(v, dict):
|
||||
self._merge_cfg_from_base(cfg_base[k], v)
|
||||
else:
|
||||
cfg_base[k] = v
|
||||
else:
|
||||
if 'BASE' not in k or preserve_base:
|
||||
cfg_base[k] = v
|
||||
return cfg_base
|
||||
|
||||
def _merge_cfg_from_command(self, args, cfg):
|
||||
assert len(
|
||||
args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
|
||||
args.opts, len(args.opts))
|
||||
keys = args.opts[0::2]
|
||||
vals = args.opts[1::2]
|
||||
|
||||
# maximum supported depth 3
|
||||
for idx, key in enumerate(keys):
|
||||
key_split = key.split('.')
|
||||
assert len(
|
||||
key_split
|
||||
) <= 4, 'Key depth error. \nMaximum depth: 3\n Get depth: {}'.format(
|
||||
len(key_split))
|
||||
assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format(
|
||||
key_split[0])
|
||||
if len(key_split) == 2:
|
||||
assert key_split[1] in cfg[
|
||||
key_split[0]].keys(), 'Non-existant key: {}.'.format(key)
|
||||
elif len(key_split) == 3:
|
||||
assert key_split[1] in cfg[
|
||||
key_split[0]].keys(), 'Non-existant key: {}.'.format(key)
|
||||
assert key_split[2] in cfg[key_split[0]][
|
||||
key_split[1]].keys(), 'Non-existant key: {}.'.format(key)
|
||||
elif len(key_split) == 4:
|
||||
assert key_split[1] in cfg[
|
||||
key_split[0]].keys(), 'Non-existant key: {}.'.format(key)
|
||||
assert key_split[2] in cfg[key_split[0]][
|
||||
key_split[1]].keys(), 'Non-existant key: {}.'.format(key)
|
||||
assert key_split[3] in cfg[key_split[0]][key_split[1]][
|
||||
key_split[2]].keys(), 'Non-existant key: {}.'.format(key)
|
||||
if len(key_split) == 1:
|
||||
cfg[key_split[0]] = vals[idx]
|
||||
elif len(key_split) == 2:
|
||||
cfg[key_split[0]][key_split[1]] = vals[idx]
|
||||
elif len(key_split) == 3:
|
||||
cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx]
|
||||
elif len(key_split) == 4:
|
||||
cfg[key_split[0]][key_split[1]][key_split[2]][
|
||||
key_split[3]] = vals[idx]
|
||||
return cfg
|
||||
|
||||
def _update_dict(self, cfg_dict):
|
||||
|
||||
def recur(key, elem):
|
||||
if type(elem) is dict:
|
||||
return key, Config(load=False, cfg_dict=elem, cfg_level=key)
|
||||
else:
|
||||
if type(elem) is str and elem[1:3] == 'e-':
|
||||
elem = float(elem)
|
||||
return key, elem
|
||||
|
||||
dic = dict(recur(k, v) for k, v in cfg_dict.items())
|
||||
self.__dict__.update(dic)
|
||||
|
||||
def get_args(self):
|
||||
return self.args
|
||||
|
||||
def __repr__(self):
|
||||
return '{}\n'.format(self.dump())
|
||||
|
||||
def dump(self):
|
||||
return json.dumps(self.cfg_dict, indent=2)
|
||||
|
||||
def deep_copy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# debug
|
||||
cfg = Config(load=True)
|
||||
print(cfg.DATA)
|
||||
297
modelscope/models/multi_modal/videocomposer/utils/distributed.py
Normal file
297
modelscope/models/multi_modal/videocomposer/utils/distributed.py
Normal file
@@ -0,0 +1,297 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
"""Distributed helpers."""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
_LOCAL_PROCESS_GROUP = None
|
||||
|
||||
|
||||
def all_gather(tensors):
|
||||
"""
|
||||
All gathers the provided tensors from all processes across machines.
|
||||
Args:
|
||||
tensors (list): tensors to perform all gather across all processes in
|
||||
all machines.
|
||||
"""
|
||||
|
||||
gather_list = []
|
||||
output_tensor = []
|
||||
world_size = dist.get_world_size()
|
||||
for tensor in tensors:
|
||||
tensor_placeholder = [
|
||||
torch.ones_like(tensor) for _ in range(world_size)
|
||||
]
|
||||
dist.all_gather(tensor_placeholder, tensor, async_op=False)
|
||||
gather_list.append(tensor_placeholder)
|
||||
for gathered_tensor in gather_list:
|
||||
output_tensor.append(torch.cat(gathered_tensor, dim=0))
|
||||
return output_tensor
|
||||
|
||||
|
||||
def all_reduce(tensors, average=True):
|
||||
"""
|
||||
All reduce the provided tensors from all processes across machines.
|
||||
Args:
|
||||
tensors (list): tensors to perform all reduce across all processes in
|
||||
all machines.
|
||||
average (bool): scales the reduced tensor by the number of overall
|
||||
processes across all machines.
|
||||
"""
|
||||
|
||||
for tensor in tensors:
|
||||
dist.all_reduce(tensor, async_op=False)
|
||||
if average:
|
||||
world_size = dist.get_world_size()
|
||||
for tensor in tensors:
|
||||
tensor.mul_(1.0 / world_size)
|
||||
return tensors
|
||||
|
||||
|
||||
def init_process_group(
|
||||
local_rank,
|
||||
local_world_size,
|
||||
shard_id,
|
||||
num_shards,
|
||||
init_method,
|
||||
dist_backend='nccl',
|
||||
):
|
||||
"""
|
||||
Initializes the default process group.
|
||||
Args:
|
||||
local_rank (int): the rank on the current local machine.
|
||||
local_world_size (int): the world size (number of processes running) on
|
||||
the current local machine.
|
||||
shard_id (int): the shard index (machine rank) of the current machine.
|
||||
num_shards (int): number of shards for distributed training.
|
||||
init_method (string): supporting three different methods for
|
||||
initializing process groups:
|
||||
"file": use shared file system to initialize the groups across
|
||||
different processes.
|
||||
"tcp": use tcp address to initialize the groups across different
|
||||
dist_backend (string): backend to use for distributed training. Options
|
||||
includes gloo, mpi and nccl, the details can be found here:
|
||||
https://pytorch.org/docs/stable/distributed.html
|
||||
"""
|
||||
# Sets the GPU to use.
|
||||
torch.cuda.set_device(local_rank)
|
||||
# Initialize the process group.
|
||||
proc_rank = local_rank + shard_id * local_world_size
|
||||
world_size = local_world_size * num_shards
|
||||
dist.init_process_group(
|
||||
backend=dist_backend,
|
||||
init_method=init_method,
|
||||
world_size=world_size,
|
||||
rank=proc_rank,
|
||||
)
|
||||
|
||||
|
||||
def is_master_proc(num_gpus=8):
|
||||
"""
|
||||
Determines if the current process is the master process.
|
||||
"""
|
||||
if torch.distributed.is_initialized():
|
||||
return dist.get_rank() % num_gpus == 0
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
"""
|
||||
Get the size of the world.
|
||||
"""
|
||||
if not dist.is_available():
|
||||
return 1
|
||||
if not dist.is_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
"""
|
||||
Get the rank of the current process.
|
||||
"""
|
||||
if not dist.is_available():
|
||||
return 0
|
||||
if not dist.is_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def synchronize():
|
||||
"""
|
||||
Helper function to synchronize (barrier) among all processes when
|
||||
using distributed training
|
||||
"""
|
||||
if not dist.is_available():
|
||||
return
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
world_size = dist.get_world_size()
|
||||
if world_size == 1:
|
||||
return
|
||||
dist.barrier()
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def _get_global_gloo_group():
|
||||
"""
|
||||
Return a process group based on gloo backend, containing all the ranks
|
||||
The result is cached.
|
||||
Returns:
|
||||
(group): pytorch dist group.
|
||||
"""
|
||||
if dist.get_backend() == 'nccl':
|
||||
return dist.new_group(backend='gloo')
|
||||
else:
|
||||
return dist.group.WORLD
|
||||
|
||||
|
||||
def _serialize_to_tensor(data, group):
|
||||
"""
|
||||
Seriialize the tensor to ByteTensor. Note that only `gloo` and `nccl`
|
||||
backend is supported.
|
||||
Args:
|
||||
data (data): data to be serialized.
|
||||
group (group): pytorch dist group.
|
||||
Returns:
|
||||
tensor (ByteTensor): tensor that serialized.
|
||||
"""
|
||||
|
||||
backend = dist.get_backend(group)
|
||||
assert backend in ['gloo', 'nccl']
|
||||
device = torch.device('cpu' if backend == 'gloo' else 'cuda')
|
||||
|
||||
buffer = pickle.dumps(data)
|
||||
if len(buffer) > 1024**3:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
'Rank {} trying to all-gather {:.2f} GB of data on device {}'.
|
||||
format(get_rank(),
|
||||
len(buffer) / (1024**3), device))
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to(device=device)
|
||||
return tensor
|
||||
|
||||
|
||||
def _pad_to_largest_tensor(tensor, group):
|
||||
"""
|
||||
Padding all the tensors from different GPUs to the largest ones.
|
||||
Args:
|
||||
tensor (tensor): tensor to pad.
|
||||
group (group): pytorch dist group.
|
||||
Returns:
|
||||
list[int]: size of the tensor, on each rank
|
||||
Tensor: padded tensor that has the max size
|
||||
"""
|
||||
world_size = dist.get_world_size(group=group)
|
||||
assert (
|
||||
world_size >= 1
|
||||
), 'comm.gather/all_gather must be called from ranks within the given group!'
|
||||
local_size = torch.tensor([tensor.numel()],
|
||||
dtype=torch.int64,
|
||||
device=tensor.device)
|
||||
size_list = [
|
||||
torch.zeros([1], dtype=torch.int64, device=tensor.device)
|
||||
for _ in range(world_size)
|
||||
]
|
||||
dist.all_gather(size_list, local_size, group=group)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
|
||||
max_size = max(size_list)
|
||||
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
if local_size != max_size:
|
||||
padding = torch.zeros((max_size - local_size, ),
|
||||
dtype=torch.uint8,
|
||||
device=tensor.device)
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
return size_list, tensor
|
||||
|
||||
|
||||
def all_gather_unaligned(data, group=None):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors).
|
||||
|
||||
Args:
|
||||
data: any picklable object
|
||||
group: a torch process group. By default, will use a group which
|
||||
contains all ranks on gloo backend.
|
||||
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
if get_world_size() == 1:
|
||||
return [data]
|
||||
if group is None:
|
||||
group = _get_global_gloo_group()
|
||||
if dist.get_world_size(group) == 1:
|
||||
return [data]
|
||||
|
||||
tensor = _serialize_to_tensor(data, group)
|
||||
|
||||
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
tensor_list = [
|
||||
torch.empty((max_size, ), dtype=torch.uint8, device=tensor.device)
|
||||
for _ in size_list
|
||||
]
|
||||
dist.all_gather(tensor_list, tensor, group=group)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def init_distributed_training(cfg):
|
||||
"""
|
||||
Initialize variables needed for distributed training.
|
||||
"""
|
||||
if cfg.NUM_GPUS <= 1:
|
||||
return
|
||||
num_gpus_per_machine = cfg.NUM_GPUS
|
||||
num_machines = dist.get_world_size() // num_gpus_per_machine
|
||||
for i in range(num_machines):
|
||||
ranks_on_i = list(
|
||||
range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
|
||||
pg = dist.new_group(ranks_on_i)
|
||||
if i == cfg.SHARD_ID:
|
||||
global _LOCAL_PROCESS_GROUP
|
||||
_LOCAL_PROCESS_GROUP = pg
|
||||
|
||||
|
||||
def get_local_size() -> int:
|
||||
"""
|
||||
Returns:
|
||||
The size of the per-machine process group,
|
||||
i.e. the number of processes per machine.
|
||||
"""
|
||||
if not dist.is_available():
|
||||
return 1
|
||||
if not dist.is_initialized():
|
||||
return 1
|
||||
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
|
||||
|
||||
|
||||
def get_local_rank() -> int:
|
||||
"""
|
||||
Returns:
|
||||
The rank of the current process within the local (per-machine) process group.
|
||||
"""
|
||||
if not dist.is_available():
|
||||
return 0
|
||||
if not dist.is_initialized():
|
||||
return 0
|
||||
assert _LOCAL_PROCESS_GROUP is not None
|
||||
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
|
||||
955
modelscope/models/multi_modal/videocomposer/utils/utils.py
Normal file
955
modelscope/models/multi_modal/videocomposer/utils/utils.py
Normal file
@@ -0,0 +1,955 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import copy
|
||||
import glob
|
||||
import gzip
|
||||
import hashlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
from multiprocessing.pool import ThreadPool as Pool
|
||||
|
||||
import imageio
|
||||
import json
|
||||
import numpy as np
|
||||
import oss2 as oss
|
||||
import requests
|
||||
import skvideo.io
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.utils as tvutils
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
__all__ = [
|
||||
'parse_oss_url', 'parse_bucket', 'read', 'read_image', 'read_gzip',
|
||||
'ceil_divide', 'to_device', 'put_object', 'put_torch_object',
|
||||
'put_object_from_file', 'get_object', 'get_object_to_file', 'rand_name',
|
||||
'save_image', 'save_video', 'save_video_vs_conditions',
|
||||
'save_video_multiple_conditions_with_data',
|
||||
'save_video_multiple_conditions', 'download_video_to_file',
|
||||
'save_video_grid_mp4', 'save_caps', 'ema', 'parallel', 'exists',
|
||||
'download', 'unzip', 'load_state_dict', 'inverse_indices',
|
||||
'detect_duplicates', 'read_tfs', 'md5', 'rope', 'format_state',
|
||||
'breakup_grid', 'huggingface_tokenizer', 'huggingface_model'
|
||||
]
|
||||
|
||||
TFS_CLIENT = None
|
||||
|
||||
|
||||
def DOWNLOAD_TO_CACHE(oss_key,
|
||||
file_or_dirname=None,
|
||||
cache_dir=osp.join(
|
||||
'/'.join(osp.abspath(__file__).split('/')[:-2]),
|
||||
'model_weights')):
|
||||
r"""Download OSS [file or folder] to the cache folder.
|
||||
Only the 0th process on each node will run the downloading.
|
||||
Barrier all processes until the downloading is completed.
|
||||
"""
|
||||
# source and target paths
|
||||
base_path = osp.join(cache_dir, file_or_dirname or osp.basename(oss_key))
|
||||
|
||||
return base_path
|
||||
|
||||
|
||||
def find_free_port():
|
||||
"""https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number"""
|
||||
import socket
|
||||
from contextlib import closing
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
||||
s.bind(('', 0))
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
return str(s.getsockname()[1])
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def parse_oss_url(path):
|
||||
if path.startswith('oss://'):
|
||||
path = path[len('oss://'):]
|
||||
|
||||
# configs
|
||||
configs = {
|
||||
'endpoint': os.getenv('OSS_ENDPOINT', None),
|
||||
'accessKeyID': os.getenv('OSS_ACCESS_KEY_ID', None),
|
||||
'accessKeySecret': os.getenv('OSS_ACCESS_KEY_SECRET', None),
|
||||
'securityToken': os.getenv('OSS_SECURITY_TOKEN', None)
|
||||
}
|
||||
bucket, path = path.split('/', maxsplit=1)
|
||||
if '?' in bucket:
|
||||
bucket, config = bucket.split('?', maxsplit=1)
|
||||
for pair in config.split('&'):
|
||||
k, v = pair.split('=', maxsplit=1)
|
||||
configs[k] = v
|
||||
|
||||
# session
|
||||
session = parse_oss_url._sessions.setdefault(f'{bucket}@{os.getpid()}',
|
||||
oss.Session())
|
||||
|
||||
# bucket
|
||||
bucket = oss.Bucket(
|
||||
auth=oss.Auth(configs['accessKeyID'], configs['accessKeySecret']),
|
||||
endpoint=configs['endpoint'],
|
||||
bucket_name=bucket,
|
||||
session=session)
|
||||
return bucket, path
|
||||
|
||||
|
||||
parse_oss_url._sessions = {}
|
||||
|
||||
|
||||
def parse_bucket(url):
|
||||
return parse_oss_url(osp.join(url, '_placeholder'))[0]
|
||||
|
||||
|
||||
def read(filename, mode='r', retry=5):
|
||||
assert mode in ['r', 'rb']
|
||||
exception = None
|
||||
for _ in range(retry):
|
||||
try:
|
||||
if filename.startswith('oss://'):
|
||||
bucket, path = parse_oss_url(filename)
|
||||
content = bucket.get_object(path).read()
|
||||
if mode == 'r':
|
||||
content = content.decode('utf-8')
|
||||
elif filename.startswith('http'):
|
||||
content = requests.get(filename).content
|
||||
if mode == 'r':
|
||||
content = content.decode('utf-8')
|
||||
else:
|
||||
with open(filename, mode=mode) as f:
|
||||
content = f.read()
|
||||
return content
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
else:
|
||||
raise exception
|
||||
|
||||
|
||||
def read_image(filename, retry=5):
|
||||
exception = None
|
||||
for _ in range(retry):
|
||||
try:
|
||||
return Image.open(BytesIO(read(filename, mode='rb', retry=retry)))
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
else:
|
||||
raise exception
|
||||
|
||||
|
||||
def download_video_to_file(filename, local_file, retry=5):
|
||||
exception = None
|
||||
for _ in range(retry):
|
||||
try:
|
||||
bucket, path = parse_oss_url(filename)
|
||||
bucket.get_object_to_file(path, local_file)
|
||||
break
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
else:
|
||||
raise exception
|
||||
|
||||
|
||||
def read_gzip(filename, retry=5):
|
||||
exception = None
|
||||
for _ in range(retry):
|
||||
try:
|
||||
remove = False
|
||||
if filename.startswith('oss://'):
|
||||
bucket, path = parse_oss_url(filename)
|
||||
filename = rand_name(suffix=osp.splitext(filename)[1])
|
||||
bucket.get_object_to_file(path, filename)
|
||||
remove = True
|
||||
with gzip.open(filename) as f:
|
||||
content = f.read()
|
||||
if remove:
|
||||
os.remove(filename)
|
||||
return content
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
else:
|
||||
raise exception
|
||||
|
||||
|
||||
def ceil_divide(a, b):
|
||||
return int(math.ceil(a / b))
|
||||
|
||||
|
||||
def to_device(batch, device, non_blocking=False):
|
||||
if isinstance(batch, (list, tuple)):
|
||||
return type(batch)([to_device(u, device, non_blocking) for u in batch])
|
||||
elif isinstance(batch, dict):
|
||||
return type(batch)([(k, to_device(v, device, non_blocking))
|
||||
for k, v in batch.items()])
|
||||
elif isinstance(batch, torch.Tensor) and batch.device != device:
|
||||
batch = batch.to(device, non_blocking=non_blocking)
|
||||
return batch
|
||||
|
||||
|
||||
def put_object(bucket, oss_key, data, retry=5):
|
||||
exception = None
|
||||
for _ in range(retry):
|
||||
try:
|
||||
return bucket.put_object(oss_key, data)
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
else:
|
||||
print(
|
||||
f'put_object to {oss_key} failed with error: {exception}',
|
||||
flush=True)
|
||||
|
||||
|
||||
def put_torch_object(bucket, oss_key, data, retry=5):
|
||||
exception = None
|
||||
for _ in range(retry):
|
||||
try:
|
||||
buffer = BytesIO()
|
||||
torch.save(data, buffer)
|
||||
return bucket.put_object(oss_key, buffer.getvalue())
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
else:
|
||||
print(
|
||||
f'put_torch_object to {oss_key} failed with error: {exception}',
|
||||
flush=True)
|
||||
|
||||
|
||||
def put_object_from_file(bucket, oss_key, filename, retry=5):
|
||||
exception = None
|
||||
for _ in range(retry):
|
||||
try:
|
||||
return bucket.put_object_from_file(oss_key, filename)
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
else:
|
||||
print(
|
||||
f'put_object_from_file to {oss_key} failed with error: {exception}',
|
||||
flush=True)
|
||||
|
||||
|
||||
def get_object(bucket, oss_key, retry=5):
|
||||
exception = None
|
||||
for _ in range(retry):
|
||||
try:
|
||||
return bucket.get_object(oss_key).read()
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
else:
|
||||
print(
|
||||
f'get_object from {oss_key} failed with error: {exception}',
|
||||
flush=True)
|
||||
|
||||
|
||||
def get_object_to_file(bucket, oss_key, filename, retry=5):
|
||||
exception = None
|
||||
for _ in range(retry):
|
||||
try:
|
||||
return bucket.get_object_to_file(oss_key, filename)
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
else:
|
||||
print(
|
||||
f'get_object_to_file from {oss_key} failed with error: {exception}',
|
||||
flush=True)
|
||||
|
||||
|
||||
def rand_name(length=8, suffix=''):
|
||||
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
|
||||
if suffix:
|
||||
if not suffix.startswith('.'):
|
||||
suffix = '.' + suffix
|
||||
name += suffix
|
||||
return name
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def save_image(bucket,
|
||||
oss_key,
|
||||
tensor,
|
||||
nrow=8,
|
||||
normalize=True,
|
||||
range=(-1, 1),
|
||||
retry=5):
|
||||
filename = rand_name(suffix='.jpg')
|
||||
for _ in [None] * retry:
|
||||
try:
|
||||
tvutils.save_image(
|
||||
tensor, filename, nrow=nrow, normalize=normalize, range=range)
|
||||
bucket.put_object_from_file(oss_key, filename)
|
||||
exception = None
|
||||
break
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
|
||||
# remove temporary file
|
||||
if osp.exists(filename):
|
||||
os.remove(filename)
|
||||
if exception is not None:
|
||||
print(
|
||||
'save image to {} failed, error: {}'.format(oss_key, exception),
|
||||
flush=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):
|
||||
tensor = tensor.permute(1, 2, 3, 0)
|
||||
images = tensor.unbind(dim=0)
|
||||
images = [(image.numpy() * 255).astype('uint8') for image in images]
|
||||
imageio.mimwrite(path, images, fps=8)
|
||||
return images
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def save_video(bucket,
|
||||
oss_key,
|
||||
tensor,
|
||||
mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5],
|
||||
nrow=8,
|
||||
retry=5):
|
||||
mean = torch.tensor(mean, device=tensor.device).view(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=tensor.device).view(1, -1, 1, 1, 1)
|
||||
tensor = tensor.mul_(std).add_(mean)
|
||||
tensor.clamp_(0, 1)
|
||||
|
||||
filename = rand_name(suffix='.gif')
|
||||
for _ in [None] * retry:
|
||||
try:
|
||||
one_gif = rearrange(
|
||||
tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
|
||||
video_tensor_to_gif(one_gif, filename)
|
||||
bucket.put_object_from_file(oss_key, filename)
|
||||
exception = None
|
||||
break
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
|
||||
# remove temporary file
|
||||
if osp.exists(filename):
|
||||
os.remove(filename)
|
||||
if exception is not None:
|
||||
print(
|
||||
'save video to {} failed, error: {}'.format(oss_key, exception),
|
||||
flush=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def save_video_multiple_conditions(oss_key,
|
||||
video_tensor,
|
||||
model_kwargs,
|
||||
source_imgs,
|
||||
palette,
|
||||
mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5],
|
||||
nrow=8,
|
||||
retry=5,
|
||||
save_origin_video=True,
|
||||
bucket=None):
|
||||
mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
|
||||
video_tensor = video_tensor.mul_(std).add_(mean)
|
||||
try:
|
||||
video_tensor.clamp_(0, 1)
|
||||
except Exception as e:
|
||||
video_tensor = video_tensor.float().clamp_(0, 1)
|
||||
print(e)
|
||||
video_tensor = video_tensor.cpu()
|
||||
|
||||
b, c, n, h, w = video_tensor.shape
|
||||
source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
|
||||
source_imgs = source_imgs.cpu()
|
||||
|
||||
model_kwargs_channel3 = {}
|
||||
for key, conditions in model_kwargs[0].items():
|
||||
if conditions.shape[-1] == 1024:
|
||||
# Skip for style embeding
|
||||
continue
|
||||
if len(conditions.shape) == 3:
|
||||
conditions_np = conditions.cpu().numpy()
|
||||
conditions = []
|
||||
for i in conditions_np:
|
||||
vis_i = []
|
||||
for j in i:
|
||||
vis_i.append(
|
||||
palette.get_palette_image(
|
||||
j, percentile=90, width=256, height=256))
|
||||
conditions.append(np.stack(vis_i))
|
||||
conditions = torch.from_numpy(np.stack(conditions))
|
||||
conditions = rearrange(conditions, 'b n h w c -> b c n h w')
|
||||
else:
|
||||
if conditions.size(1) == 1:
|
||||
conditions = torch.cat([conditions, conditions, conditions],
|
||||
dim=1)
|
||||
conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
|
||||
if conditions.size(1) == 2:
|
||||
conditions = torch.cat([conditions, conditions[:, :1, ]],
|
||||
dim=1)
|
||||
conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
|
||||
elif conditions.size(1) == 3:
|
||||
conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
|
||||
elif conditions.size(1) == 4:
|
||||
color = ((conditions[:, 0:3] + 1.) / 2.)
|
||||
alpha = conditions[:, 3:4]
|
||||
conditions = color * alpha + 1.0 * (1.0 - alpha)
|
||||
conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
|
||||
model_kwargs_channel3[key] = conditions.cpu(
|
||||
) if conditions.is_cuda else conditions
|
||||
|
||||
filename = oss_key
|
||||
for _ in [None] * retry:
|
||||
try:
|
||||
vid_gif = rearrange(
|
||||
video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
|
||||
cons_list = [
|
||||
rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
|
||||
for _, con in model_kwargs_channel3.items()
|
||||
]
|
||||
source_imgs = rearrange(
|
||||
source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
|
||||
|
||||
if save_origin_video:
|
||||
vid_gif = torch.cat(
|
||||
[
|
||||
source_imgs,
|
||||
] + cons_list + [
|
||||
vid_gif,
|
||||
], dim=3)
|
||||
else:
|
||||
vid_gif = torch.cat(
|
||||
cons_list + [
|
||||
vid_gif,
|
||||
], dim=3)
|
||||
|
||||
video_tensor_to_gif(vid_gif, filename)
|
||||
exception = None
|
||||
break
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
if exception is not None:
|
||||
logging.info('save video to {} failed, error: {}'.format(
|
||||
oss_key, exception))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def save_video_multiple_conditions_with_data(bucket,
|
||||
video_save_key,
|
||||
gt_video_save_key,
|
||||
vis_oss_key,
|
||||
video_tensor,
|
||||
model_kwargs,
|
||||
source_imgs,
|
||||
palette,
|
||||
mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5],
|
||||
nrow=8,
|
||||
retry=5):
|
||||
mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
|
||||
video_tensor = video_tensor.mul_(std).add_(mean)
|
||||
video_tensor.clamp_(0, 1)
|
||||
|
||||
b, c, n, h, w = video_tensor.shape
|
||||
source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
|
||||
source_imgs = source_imgs.cpu()
|
||||
|
||||
model_kwargs_channel3 = {}
|
||||
for key, conditions in model_kwargs[0].items():
|
||||
if len(conditions.shape) == 3:
|
||||
conditions_np = conditions.cpu().numpy()
|
||||
conditions = []
|
||||
for i in conditions_np:
|
||||
vis_i = []
|
||||
for j in i:
|
||||
vis_i.append(
|
||||
palette.get_palette_image(
|
||||
j, percentile=90, width=256, height=256))
|
||||
conditions.append(np.stack(vis_i))
|
||||
conditions = torch.from_numpy(np.stack(conditions))
|
||||
conditions = rearrange(conditions, 'b n h w c -> b c n h w')
|
||||
else:
|
||||
if conditions.size(1) == 1:
|
||||
conditions = torch.cat([conditions, conditions, conditions],
|
||||
dim=1)
|
||||
conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
|
||||
if conditions.size(1) == 2:
|
||||
conditions = torch.cat([conditions, conditions[:, :1, ]],
|
||||
dim=1)
|
||||
conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
|
||||
elif conditions.size(1) == 3:
|
||||
conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
|
||||
elif conditions.size(1) == 4:
|
||||
color = ((conditions[:, 0:3] + 1.) / 2.)
|
||||
alpha = conditions[:, 3:4]
|
||||
conditions = color * alpha + 1.0 * (1.0 - alpha)
|
||||
conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
|
||||
model_kwargs_channel3[key] = conditions.cpu(
|
||||
) if conditions.is_cuda else conditions
|
||||
|
||||
copy_video_tensor = video_tensor.clone()
|
||||
copy_source_imgs = source_imgs.clone()
|
||||
|
||||
filename = rand_name(suffix='.gif')
|
||||
for _ in [None] * retry:
|
||||
try:
|
||||
vid_gif = rearrange(
|
||||
video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
|
||||
cons_list = [
|
||||
rearrange(con, '(i j) c f h w -> c f (i h) (j w)', j=nrow)
|
||||
for _, con in model_kwargs_channel3.items()
|
||||
]
|
||||
source_imgs = rearrange(
|
||||
source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
|
||||
vid_gif = torch.cat(
|
||||
[
|
||||
source_imgs,
|
||||
] + cons_list + [
|
||||
vid_gif,
|
||||
], dim=3)
|
||||
|
||||
video_tensor_to_gif(vid_gif, filename)
|
||||
bucket.put_object_from_file(vis_oss_key, filename)
|
||||
exception = None
|
||||
break
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
|
||||
# remove temporary file
|
||||
if osp.exists(filename):
|
||||
os.remove(filename)
|
||||
|
||||
filename_pred = rand_name(suffix='.pkl')
|
||||
for _ in [None] * retry:
|
||||
try:
|
||||
copy_video_np = (copy_video_tensor.numpy() * 255).astype('uint8')
|
||||
pickle.dump(copy_video_np, open(filename_pred, 'wb'))
|
||||
bucket.put_object_from_file(video_save_key, filename_pred)
|
||||
break
|
||||
except Exception as e:
|
||||
print('error! ', video_save_key, e)
|
||||
continue
|
||||
|
||||
# remove temporary file
|
||||
if osp.exists(filename_pred):
|
||||
os.remove(filename_pred)
|
||||
|
||||
filename_gt = rand_name(suffix='.pkl')
|
||||
for _ in [None] * retry:
|
||||
try:
|
||||
copy_source_np = (copy_source_imgs.numpy() * 255).astype('uint8')
|
||||
pickle.dump(copy_source_np, open(filename_gt, 'wb'))
|
||||
bucket.put_object_from_file(gt_video_save_key, filename_gt)
|
||||
break
|
||||
except Exception as e:
|
||||
print('error! ', gt_video_save_key, e)
|
||||
continue
|
||||
|
||||
# remove temporary file
|
||||
if osp.exists(filename_gt):
|
||||
os.remove(filename_gt)
|
||||
|
||||
if exception is not None:
|
||||
print(
|
||||
'save video to {} failed, error: {}'.format(
|
||||
vis_oss_key, exception),
|
||||
flush=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def save_video_vs_conditions(bucket,
|
||||
oss_key,
|
||||
video_tensor,
|
||||
conditions,
|
||||
source_imgs,
|
||||
mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5],
|
||||
nrow=8,
|
||||
retry=5):
|
||||
mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
|
||||
video_tensor = video_tensor.mul_(std).add_(mean)
|
||||
video_tensor.clamp_(0, 1)
|
||||
|
||||
b, c, n, h, w = video_tensor.shape
|
||||
source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
|
||||
source_imgs = source_imgs.cpu()
|
||||
|
||||
if conditions.size(1) == 1:
|
||||
conditions = torch.cat([conditions, conditions, conditions], dim=1)
|
||||
conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
|
||||
|
||||
filename = rand_name(suffix='.gif')
|
||||
for _ in [None] * retry:
|
||||
try:
|
||||
vid_gif = rearrange(
|
||||
video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
|
||||
con_gif = rearrange(
|
||||
conditions, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
|
||||
source_imgs = rearrange(
|
||||
source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
|
||||
vid_gif = torch.cat([vid_gif, con_gif, source_imgs], dim=2)
|
||||
|
||||
video_tensor_to_gif(vid_gif, filename)
|
||||
bucket.put_object_from_file(oss_key, filename)
|
||||
exception = None
|
||||
break
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
|
||||
# remove temporary file
|
||||
if osp.exists(filename):
|
||||
os.remove(filename)
|
||||
if exception is not None:
|
||||
print(
|
||||
'save video to {} failed, error: {}'.format(oss_key, exception),
|
||||
flush=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def save_video_grid_mp4(bucket,
|
||||
oss_key,
|
||||
tensor,
|
||||
mean=[0.5, 0.5, 0.5],
|
||||
std=[0.5, 0.5, 0.5],
|
||||
nrow=None,
|
||||
fps=5,
|
||||
retry=5):
|
||||
mean = torch.tensor(mean, device=tensor.device).view(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=tensor.device).view(1, -1, 1, 1, 1)
|
||||
tensor = tensor.mul_(std).add_(mean)
|
||||
tensor.clamp_(0, 1)
|
||||
b, c, t, h, w = tensor.shape
|
||||
tensor = tensor.permute(0, 2, 3, 4, 1)
|
||||
tensor = (tensor.cpu().numpy() * 255).astype('uint8')
|
||||
|
||||
filename = rand_name(suffix='.mp4')
|
||||
for _ in [None] * retry:
|
||||
try:
|
||||
if nrow is None:
|
||||
nrow = math.ceil(math.sqrt(b))
|
||||
ncol = math.ceil(b / nrow)
|
||||
padding = 1
|
||||
video_grid = np.zeros((t, (padding + h) * nrow + padding,
|
||||
(padding + w) * ncol + padding, c),
|
||||
dtype='uint8')
|
||||
for i in range(b):
|
||||
r = i // ncol
|
||||
c_ = i % ncol
|
||||
|
||||
start_r = (padding + h) * r
|
||||
start_c = (padding + w) * c_
|
||||
video_grid[:, start_r:start_r + h,
|
||||
start_c:start_c + w] = tensor[i]
|
||||
skvideo.io.vwrite(filename, video_grid, inputdict={'-r': str(fps)})
|
||||
|
||||
bucket.put_object_from_file(oss_key, filename)
|
||||
exception = None
|
||||
break
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
|
||||
# remove temporary file
|
||||
if osp.exists(filename):
|
||||
os.remove(filename)
|
||||
if exception is not None:
|
||||
print(
|
||||
'save video to {} failed, error: {}'.format(oss_key, exception),
|
||||
flush=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def save_text(bucket, oss_key, tensor, nrow=8, retry=5):
|
||||
len = tensor.shape[0]
|
||||
num_per_row = int(len / nrow)
|
||||
assert (len == nrow * num_per_row)
|
||||
texts = ''
|
||||
for i in range(nrow):
|
||||
for j in range(num_per_row):
|
||||
text = dec_bytes2obj(tensor[i * num_per_row + j])
|
||||
texts += text + '\n'
|
||||
texts += '\n'
|
||||
|
||||
for _ in [None] * retry:
|
||||
try:
|
||||
bucket.put_object(oss_key, texts)
|
||||
exception = None
|
||||
break
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
if exception is not None:
|
||||
print(
|
||||
'save video to {} failed, error: {}'.format(oss_key, exception),
|
||||
flush=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def save_caps(bucket, oss_key, caps, retry=5):
|
||||
texts = ''
|
||||
for cap in caps:
|
||||
texts += cap
|
||||
texts += '\n'
|
||||
|
||||
for _ in [None] * retry:
|
||||
try:
|
||||
bucket.put_object(oss_key, texts)
|
||||
exception = None
|
||||
break
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
if exception is not None:
|
||||
print(
|
||||
'save video to {} failed, error: {}'.format(oss_key, exception),
|
||||
flush=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ema(net_ema, net, beta, copy_buffer=False):
|
||||
assert 0.0 <= beta <= 1.0
|
||||
for p_ema, p in zip(net_ema.parameters(), net.parameters()):
|
||||
p_ema.copy_(p.lerp(p_ema, beta))
|
||||
if copy_buffer:
|
||||
for b_ema, b in zip(net_ema.buffers(), net.buffers()):
|
||||
b_ema.copy_(b)
|
||||
|
||||
|
||||
def parallel(func, args_list, num_workers=32, timeout=None):
|
||||
assert isinstance(args_list, list)
|
||||
if not isinstance(args_list[0], tuple):
|
||||
args_list = [(args, ) for args in args_list]
|
||||
if num_workers == 0:
|
||||
return [func(*args) for args in args_list]
|
||||
with Pool(processes=num_workers) as pool:
|
||||
results = [pool.apply_async(func, args) for args in args_list]
|
||||
results = [res.get(timeout=timeout) for res in results]
|
||||
return results
|
||||
|
||||
|
||||
def exists(filename):
|
||||
if filename.startswith('oss://'):
|
||||
bucket, path = parse_oss_url(filename)
|
||||
return bucket.object_exists(path)
|
||||
else:
|
||||
return osp.exists(filename)
|
||||
|
||||
|
||||
def download(url, filename=None, replace=False, quiet=False):
|
||||
if filename is None:
|
||||
filename = osp.basename(url)
|
||||
if not osp.exists(filename) or replace:
|
||||
try:
|
||||
if url.startswith('oss://'):
|
||||
bucket, oss_key = parse_oss_url(url)
|
||||
bucket.get_object_to_file(oss_key, filename)
|
||||
else:
|
||||
urllib.request.urlretrieve(url, filename)
|
||||
if not quiet:
|
||||
print(f'Downloaded {url} to {filename}', flush=True)
|
||||
except Exception as e:
|
||||
raise ValueError(f'Downloading {filename} failed with error {e}')
|
||||
return osp.abspath(filename)
|
||||
|
||||
|
||||
def unzip(filename, dst_dir=None):
|
||||
if dst_dir is None:
|
||||
dst_dir = osp.dirname(filename)
|
||||
with zipfile.ZipFile(filename, 'r') as zip_ref:
|
||||
zip_ref.extractall(dst_dir)
|
||||
|
||||
|
||||
def load_state_dict(module, state_dict, drop_prefix=''):
|
||||
# find incompatible key-vals
|
||||
src, dst = state_dict, module.state_dict()
|
||||
if drop_prefix:
|
||||
src = type(src)([
|
||||
(k[len(drop_prefix):] if k.startswith(drop_prefix) else k, v)
|
||||
for k, v in src.items()
|
||||
])
|
||||
missing = [k for k in dst if k not in src]
|
||||
unexpected = [k for k in src if k not in dst]
|
||||
unmatched = [
|
||||
k for k in src.keys() & dst.keys() if src[k].shape != dst[k].shape
|
||||
]
|
||||
|
||||
# keep only compatible key-vals
|
||||
incompatible = set(unexpected + unmatched)
|
||||
src = type(src)([(k, v) for k, v in src.items() if k not in incompatible])
|
||||
module.load_state_dict(src, strict=False)
|
||||
|
||||
# report incompatible key-vals
|
||||
if len(missing) != 0:
|
||||
print(' Missing: ' + ', '.join(missing), flush=True)
|
||||
if len(unexpected) != 0:
|
||||
print(' Unexpected: ' + ', '.join(unexpected), flush=True)
|
||||
if len(unmatched) != 0:
|
||||
print(' Shape unmatched: ' + ', '.join(unmatched), flush=True)
|
||||
|
||||
|
||||
def inverse_indices(indices):
|
||||
r"""Inverse map of indices.
|
||||
E.g., if A[indices] == B, then B[inv_indices] == A.
|
||||
"""
|
||||
inv_indices = torch.empty_like(indices)
|
||||
inv_indices[indices] = torch.arange(len(indices)).to(indices)
|
||||
return inv_indices
|
||||
|
||||
|
||||
def detect_duplicates(feats, thr=0.9):
|
||||
assert feats.ndim == 2
|
||||
|
||||
# compute simmat
|
||||
feats = F.normalize(feats, p=2, dim=1)
|
||||
simmat = torch.mm(feats, feats.T)
|
||||
simmat.triu_(1)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# detect duplicates
|
||||
mask = ~simmat.gt(thr).any(dim=0)
|
||||
return torch.where(mask)[0]
|
||||
|
||||
|
||||
class TFSClient(object):
|
||||
|
||||
def __init__(self,
|
||||
host='restful-store.vip.tbsite.net:3800',
|
||||
app_key='5354c9fae75f5'):
|
||||
self.host = host
|
||||
self.app_key = app_key
|
||||
|
||||
# candidate servers
|
||||
self.servers = [
|
||||
u for u in read(f'http://{host}/url.list').strip().split('\n')[1:]
|
||||
if ':' in u
|
||||
]
|
||||
assert len(self.servers) >= 1
|
||||
self.__server_id = -1
|
||||
|
||||
@property
|
||||
def server(self):
|
||||
self.__server_id = (self.__server_id + 1) % len(self.servers)
|
||||
return self.servers[self.__server_id]
|
||||
|
||||
def read(self, tfs):
|
||||
tfs = osp.basename(tfs)
|
||||
meta = json.loads(
|
||||
read(
|
||||
f'http://{self.server}/v1/{self.app_key}/metadata/{tfs}?force=0'
|
||||
))
|
||||
img = Image.open(
|
||||
BytesIO(
|
||||
read(
|
||||
f'http://{self.server}/v1/{self.app_key}/{tfs}?offset=0&size={meta["SIZE"]}',
|
||||
'rb')))
|
||||
return img
|
||||
|
||||
|
||||
def read_tfs(tfs, retry=5):
|
||||
exception = None
|
||||
for _ in range(retry):
|
||||
try:
|
||||
global TFS_CLIENT
|
||||
if TFS_CLIENT is None:
|
||||
TFS_CLIENT = TFSClient()
|
||||
return TFS_CLIENT.read(tfs)
|
||||
except Exception as e:
|
||||
exception = e
|
||||
continue
|
||||
else:
|
||||
raise exception
|
||||
|
||||
|
||||
def md5(filename):
|
||||
with open(filename, 'rb') as f:
|
||||
return hashlib.md5(f.read()).hexdigest()
|
||||
|
||||
|
||||
def rope(x):
|
||||
r"""Apply rotary position embedding on x of shape [B, *(spatial dimensions), C].
|
||||
"""
|
||||
# reshape
|
||||
shape = x.shape
|
||||
x = x.view(x.size(0), -1, x.size(-1))
|
||||
l, c = x.shape[-2:]
|
||||
assert c % 2 == 0
|
||||
half = c // 2
|
||||
|
||||
# apply rotary position embedding on x
|
||||
sinusoid = torch.outer(
|
||||
torch.arange(l).to(x),
|
||||
torch.pow(10000, -torch.arange(half).to(x).div(half)))
|
||||
sin, cos = torch.sin(sinusoid), torch.cos(sinusoid)
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
|
||||
|
||||
# reshape back
|
||||
return x.view(shape)
|
||||
|
||||
|
||||
def format_state(state, filename=None):
|
||||
r"""For comparing/aligning state_dict.
|
||||
"""
|
||||
content = '\n'.join([f'{k}\t{tuple(v.shape)}' for k, v in state.items()])
|
||||
if filename:
|
||||
with open(filename, 'w') as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def breakup_grid(img, grid_size):
|
||||
r"""The inverse operator of ``torchvision.utils.make_grid``.
|
||||
"""
|
||||
# params
|
||||
nrow = img.height // grid_size
|
||||
ncol = img.width // grid_size
|
||||
wrow = wcol = 2
|
||||
|
||||
# collect grids
|
||||
grids = []
|
||||
for i in range(nrow):
|
||||
for j in range(ncol):
|
||||
x1 = j * grid_size + (j + 1) * wcol
|
||||
y1 = i * grid_size + (i + 1) * wrow
|
||||
grids.append(img.crop((x1, y1, x1 + grid_size, y1 + grid_size)))
|
||||
return grids
|
||||
|
||||
|
||||
def huggingface_tokenizer(name='google/mt5-xxl', **kwargs):
|
||||
from transformers import AutoTokenizer
|
||||
return AutoTokenizer.from_pretrained(
|
||||
DOWNLOAD_TO_CACHE(f'huggingface/tokenizers/{name}', name), **kwargs)
|
||||
|
||||
|
||||
def huggingface_model(name='google/mt5-xxl', model_type='AutoModel', **kwargs):
|
||||
import transformers
|
||||
return getattr(transformers, model_type).from_pretrained(
|
||||
DOWNLOAD_TO_CACHE(f'huggingface/models/{name}', name), **kwargs)
|
||||
@@ -0,0 +1,480 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from copy import copy, deepcopy
|
||||
from os import path as osp
|
||||
from typing import Any, Dict
|
||||
|
||||
import open_clip
|
||||
import pynvml
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
import modelscope.models.multi_modal.videocomposer.models as models
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.multi_modal.videocomposer.annotator.sketch import (
|
||||
pidinet_bsd, sketch_simplification_gan)
|
||||
from modelscope.models.multi_modal.videocomposer.autoencoder import \
|
||||
AutoencoderKL
|
||||
from modelscope.models.multi_modal.videocomposer.clip import (
|
||||
FrozenOpenCLIPEmbedder, FrozenOpenCLIPVisualEmbedder)
|
||||
from modelscope.models.multi_modal.videocomposer.diffusion import (
|
||||
GaussianDiffusion, beta_schedule)
|
||||
from modelscope.models.multi_modal.videocomposer.ops.utils import (
|
||||
get_first_stage_encoding, make_masked_images, prepare_model_kwargs,
|
||||
save_with_model_kwargs)
|
||||
from modelscope.models.multi_modal.videocomposer.unet_sd import UNetSD_temporal
|
||||
from modelscope.models.multi_modal.videocomposer.utils.config import Config
|
||||
from modelscope.models.multi_modal.videocomposer.utils.utils import (
|
||||
find_free_port, setup_seed, to_device)
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.preprocessors.image import load_image
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .config import cfg
|
||||
|
||||
__all__ = ['VideoComposer']
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.text_to_video_synthesis, module_name=Models.videocomposer)
|
||||
class VideoComposer(TorchModel):
|
||||
r"""
|
||||
task for video composer.
|
||||
|
||||
Attributes:
|
||||
sd_model: denosing model using in this task.
|
||||
diffusion: diffusion model for DDIM.
|
||||
autoencoder: decode the latent representation into visual space with VQGAN.
|
||||
clip_encoder: encode the text into text embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
r"""
|
||||
Args:
|
||||
model_dir (`str` or `os.PathLike`)
|
||||
Can be either:
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on modelscope
|
||||
or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`,
|
||||
or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- A path to a *directory* containing model weights saved using
|
||||
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
- A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
|
||||
this case, `from_tf` should be set to `True` and a configuration object should be provided as
|
||||
`config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
|
||||
PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
- A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
|
||||
`./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
|
||||
`True`.
|
||||
"""
|
||||
super().__init__(model_dir=model_dir, *args, **kwargs)
|
||||
self.device = torch.device('cuda') if torch.cuda.is_available() \
|
||||
else torch.device('cpu')
|
||||
clip_checkpoint = kwargs.pop('clip_checkpoint',
|
||||
'open_clip_pytorch_model.bin')
|
||||
sd_checkpoint = kwargs.pop('sd_checkpoint', 'v2-1_512-ema-pruned.ckpt')
|
||||
_cfg = Config(load=True)
|
||||
cfg.update(_cfg.cfg_dict)
|
||||
# rank-wise params
|
||||
l1 = len(cfg.frame_lens)
|
||||
l2 = len(cfg.feature_framerates)
|
||||
cfg.max_frames = cfg.frame_lens[0 % (l1 * l2) // l2]
|
||||
cfg.batch_size = cfg.batch_sizes[str(cfg.max_frames)]
|
||||
# Copy update input parameter to current task
|
||||
self.cfg = cfg
|
||||
if 'MASTER_ADDR' not in os.environ:
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = find_free_port()
|
||||
self.cfg.pmi_rank = int(os.getenv('RANK', 0))
|
||||
self.cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
|
||||
setup_seed(self.cfg.seed)
|
||||
self.read_image = kwargs.pop('read_image', False)
|
||||
self.read_style = kwargs.pop('read_style', True)
|
||||
self.read_sketch = kwargs.pop('read_sketch', False)
|
||||
self.save_origin_video = kwargs.pop('save_origin_video', True)
|
||||
self.video_compositions = kwargs.pop('video_compositions', [
|
||||
'text', 'mask', 'depthmap', 'sketch', 'motion', 'image',
|
||||
'local_image', 'single_sketch'
|
||||
])
|
||||
self.viz_num = self.cfg.batch_size
|
||||
self.clip_encoder = FrozenOpenCLIPEmbedder(
|
||||
layer='penultimate',
|
||||
pretrained=os.path.join(model_dir, clip_checkpoint))
|
||||
self.clip_encoder = self.clip_encoder.to(self.device)
|
||||
self.clip_encoder_visual = FrozenOpenCLIPVisualEmbedder(
|
||||
layer='penultimate',
|
||||
pretrained=os.path.join(model_dir, clip_checkpoint))
|
||||
self.clip_encoder_visual.model.to(self.device)
|
||||
ddconfig = {
|
||||
'double_z': True,
|
||||
'z_channels': 4,
|
||||
'resolution': 256,
|
||||
'in_channels': 3,
|
||||
'out_ch': 3,
|
||||
'ch': 128,
|
||||
'ch_mult': [1, 2, 4, 4],
|
||||
'num_res_blocks': 2,
|
||||
'attn_resolutions': [],
|
||||
'dropout': 0.0
|
||||
}
|
||||
self.autoencoder = AutoencoderKL(
|
||||
ddconfig, 4, ckpt_path=os.path.join(model_dir, sd_checkpoint))
|
||||
self.zero_y = self.clip_encoder('').detach()
|
||||
black_image_feature = self.clip_encoder_visual(
|
||||
self.clip_encoder_visual.black_image).unsqueeze(1)
|
||||
black_image_feature = torch.zeros_like(black_image_feature)
|
||||
self.autoencoder.eval()
|
||||
for param in self.autoencoder.parameters():
|
||||
param.requires_grad = False
|
||||
self.autoencoder.cuda()
|
||||
self.model = UNetSD_temporal(
|
||||
cfg=self.cfg,
|
||||
in_dim=self.cfg.unet_in_dim,
|
||||
concat_dim=self.cfg.unet_concat_dim,
|
||||
dim=self.cfg.unet_dim,
|
||||
y_dim=self.cfg.unet_y_dim,
|
||||
context_dim=self.cfg.unet_context_dim,
|
||||
out_dim=self.cfg.unet_out_dim,
|
||||
dim_mult=self.cfg.unet_dim_mult,
|
||||
num_heads=self.cfg.unet_num_heads,
|
||||
head_dim=self.cfg.unet_head_dim,
|
||||
num_res_blocks=self.cfg.unet_res_blocks,
|
||||
attn_scales=self.cfg.unet_attn_scales,
|
||||
dropout=self.cfg.unet_dropout,
|
||||
temporal_attention=self.cfg.temporal_attention,
|
||||
temporal_attn_times=self.cfg.temporal_attn_times,
|
||||
use_checkpoint=self.cfg.use_checkpoint,
|
||||
use_fps_condition=self.cfg.use_fps_condition,
|
||||
use_sim_mask=self.cfg.use_sim_mask,
|
||||
video_compositions=self.cfg.video_compositions,
|
||||
misc_dropout=self.cfg.misc_dropout,
|
||||
p_all_zero=self.cfg.p_all_zero,
|
||||
p_all_keep=self.cfg.p_all_zero,
|
||||
zero_y=self.zero_y,
|
||||
black_image_feature=black_image_feature,
|
||||
).to(self.device)
|
||||
|
||||
# Load checkpoint
|
||||
if self.cfg.resume and self.cfg.resume_checkpoint:
|
||||
if hasattr(self.cfg, 'text_to_video_pretrain'
|
||||
) and self.cfg.text_to_video_pretrain:
|
||||
checkpoint_name = cfg.resume_checkpoint.split('/')[-1]
|
||||
ss = torch.load(
|
||||
os.path.join(self.model_dir, cfg.resume_checkpoint))
|
||||
ss = {
|
||||
key: p
|
||||
for key, p in ss.items() if 'input_blocks.0.0' not in key
|
||||
}
|
||||
self.model.load_state_dict(ss, strict=False)
|
||||
else:
|
||||
checkpoint_name = cfg.resume_checkpoint.split('/')[-1]
|
||||
self.model.load_state_dict(
|
||||
torch.load(
|
||||
os.path.join(self.model_dir, checkpoint_name),
|
||||
map_location='cpu'),
|
||||
strict=False)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
raise ValueError(
|
||||
f'The checkpoint file {self.cfg.resume_checkpoint} is wrong ')
|
||||
|
||||
# diffusion
|
||||
betas = beta_schedule(
|
||||
'linear_sd',
|
||||
self.cfg.num_timesteps,
|
||||
init_beta=0.00085,
|
||||
last_beta=0.0120)
|
||||
self.diffusion = GaussianDiffusion(
|
||||
betas=betas,
|
||||
mean_type=self.cfg.mean_type,
|
||||
var_type=self.cfg.var_type,
|
||||
loss_type=self.cfg.loss_type,
|
||||
rescale_timesteps=False)
|
||||
|
||||
def forward(self, input: Dict[str, Any]):
|
||||
frame_in = None
|
||||
if self.read_image:
|
||||
image_key = input['style_image']
|
||||
frame = load_image(image_key)
|
||||
frame_in = misc_transforms([frame])
|
||||
|
||||
frame_sketch = None
|
||||
if self.read_sketch:
|
||||
sketch_key = self.cfg.sketch_path
|
||||
frame_sketch = load_image(sketch_key)
|
||||
frame_sketch = misc_transforms([frame_sketch])
|
||||
|
||||
frame_style = None
|
||||
if self.read_style:
|
||||
frame_style = load_image(input['style_image'])
|
||||
|
||||
# Generators for various conditions
|
||||
if 'depthmap' in self.video_compositions:
|
||||
midas = models.midas_v3(
|
||||
pretrained=True,
|
||||
model_dir=self.model_dir).eval().requires_grad_(False).to(
|
||||
memory_format=torch.channels_last).half().to(self.device)
|
||||
if 'canny' in self.video_compositions:
|
||||
canny_detector = CannyDetector()
|
||||
if 'sketch' in self.video_compositions:
|
||||
pidinet = pidinet_bsd(
|
||||
self.model_dir, pretrained=True,
|
||||
vanilla_cnn=True).eval().requires_grad_(False).to(self.device)
|
||||
cleaner = sketch_simplification_gan(
|
||||
self.model_dir,
|
||||
pretrained=True).eval().requires_grad_(False).to(self.device)
|
||||
pidi_mean = torch.tensor(self.cfg.sketch_mean).view(
|
||||
1, -1, 1, 1).to(self.device)
|
||||
pidi_std = torch.tensor(self.cfg.sketch_std).view(1, -1, 1, 1).to(
|
||||
self.device)
|
||||
# Placeholder for color inference
|
||||
palette = None
|
||||
|
||||
self.model.eval()
|
||||
caps = input['cap_txt']
|
||||
if self.cfg.max_frames == 1 and self.cfg.use_image_dataset:
|
||||
ref_imgs = input['ref_frame']
|
||||
video_data = input['video_data']
|
||||
misc_data = input['misc_data']
|
||||
mask = input['mask']
|
||||
mv_data = input['mv_data']
|
||||
fps = torch.tensor(
|
||||
[self.cfg.feature_framerate] * self.cfg.batch_size,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
else:
|
||||
ref_imgs = input['ref_frame']
|
||||
video_data = input['video_data']
|
||||
misc_data = input['misc_data']
|
||||
mask = input['mask']
|
||||
mv_data = input['mv_data']
|
||||
# add fps test
|
||||
fps = torch.tensor(
|
||||
[self.cfg.feature_framerate] * self.cfg.batch_size,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
|
||||
# save for visualization
|
||||
misc_backups = copy(misc_data)
|
||||
misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w')
|
||||
mv_data_video = []
|
||||
if 'motion' in self.cfg.video_compositions:
|
||||
mv_data_video = rearrange(mv_data, 'b f c h w -> b c f h w')
|
||||
|
||||
# mask images
|
||||
masked_video = []
|
||||
if 'mask' in self.cfg.video_compositions:
|
||||
masked_video = make_masked_images(
|
||||
misc_data.sub(0.5).div_(0.5), mask)
|
||||
masked_video = rearrange(masked_video, 'b f c h w -> b c f h w')
|
||||
|
||||
# Single Image
|
||||
image_local = []
|
||||
if 'local_image' in self.cfg.video_compositions:
|
||||
frames_num = misc_data.shape[1]
|
||||
bs_vd_local = misc_data.shape[0]
|
||||
if self.cfg.read_image:
|
||||
image_local = frame_in.unsqueeze(0).repeat(
|
||||
bs_vd_local, frames_num, 1, 1, 1).cuda()
|
||||
else:
|
||||
image_local = misc_data[:, :1].clone().repeat(
|
||||
1, frames_num, 1, 1, 1)
|
||||
image_local = rearrange(
|
||||
image_local, 'b f c h w -> b c f h w', b=bs_vd_local)
|
||||
|
||||
# encode the video_data
|
||||
bs_vd = video_data.shape[0]
|
||||
video_data = rearrange(video_data, 'b f c h w -> (b f) c h w')
|
||||
misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w')
|
||||
|
||||
video_data_list = torch.chunk(
|
||||
video_data, video_data.shape[0] // self.cfg.chunk_size, dim=0)
|
||||
misc_data_list = torch.chunk(
|
||||
misc_data, misc_data.shape[0] // self.cfg.chunk_size, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
decode_data = []
|
||||
for vd_data in video_data_list:
|
||||
encoder_posterior = self.autoencoder.encode(vd_data)
|
||||
tmp = get_first_stage_encoding(encoder_posterior).detach()
|
||||
decode_data.append(tmp)
|
||||
video_data = torch.cat(decode_data, dim=0)
|
||||
video_data = rearrange(
|
||||
video_data, '(b f) c h w -> b c f h w', b=bs_vd)
|
||||
|
||||
depth_data = []
|
||||
if 'depthmap' in self.cfg.video_compositions:
|
||||
for misc_imgs in misc_data_list:
|
||||
depth = midas(
|
||||
misc_imgs.sub(0.5).div_(0.5).to(
|
||||
memory_format=torch.channels_last).half())
|
||||
depth = (depth / self.cfg.depth_std).clamp_(
|
||||
0, self.cfg.depth_clamp)
|
||||
depth_data.append(depth)
|
||||
depth_data = torch.cat(depth_data, dim=0)
|
||||
depth_data = rearrange(
|
||||
depth_data, '(b f) c h w -> b c f h w', b=bs_vd)
|
||||
|
||||
canny_data = []
|
||||
if 'canny' in self.cfg.video_compositions:
|
||||
for misc_imgs in misc_data_list:
|
||||
misc_imgs = rearrange(misc_imgs.clone(),
|
||||
'k c h w -> k h w c')
|
||||
canny_condition = torch.stack(
|
||||
[canny_detector(misc_img) for misc_img in misc_imgs])
|
||||
canny_condition = rearrange(canny_condition,
|
||||
'k h w c-> k c h w')
|
||||
canny_data.append(canny_condition)
|
||||
canny_data = torch.cat(canny_data, dim=0)
|
||||
canny_data = rearrange(
|
||||
canny_data, '(b f) c h w -> b c f h w', b=bs_vd)
|
||||
|
||||
sketch_data = []
|
||||
if 'sketch' in self.cfg.video_compositions:
|
||||
sketch_list = misc_data_list
|
||||
if self.cfg.read_sketch:
|
||||
sketch_repeat = frame_sketch.repeat(frames_num, 1, 1,
|
||||
1).cuda()
|
||||
sketch_list = [sketch_repeat]
|
||||
|
||||
for misc_imgs in sketch_list:
|
||||
sketch = pidinet(misc_imgs.sub(pidi_mean).div_(pidi_std))
|
||||
sketch = 1.0 - cleaner(1.0 - sketch)
|
||||
sketch_data.append(sketch)
|
||||
sketch_data = torch.cat(sketch_data, dim=0)
|
||||
sketch_data = rearrange(
|
||||
sketch_data, '(b f) c h w -> b c f h w', b=bs_vd)
|
||||
|
||||
single_sketch_data = []
|
||||
if 'single_sketch' in self.cfg.video_compositions:
|
||||
single_sketch_data = sketch_data.clone()[:, :, :1].repeat(
|
||||
1, 1, frames_num, 1, 1)
|
||||
|
||||
# preprocess for input text descripts
|
||||
y = self.clip_encoder(caps).detach()
|
||||
y0 = y.clone()
|
||||
|
||||
y_visual = []
|
||||
if 'image' in self.cfg.video_compositions:
|
||||
with torch.no_grad():
|
||||
if self.cfg.read_style:
|
||||
y_visual = self.clip_encoder_visual(
|
||||
self.clip_encoder_visual.preprocess(
|
||||
frame_style).unsqueeze(0).cuda()).unsqueeze(0)
|
||||
y_visual0 = y_visual.clone()
|
||||
else:
|
||||
ref_imgs = ref_imgs.squeeze(1)
|
||||
y_visual = self.clip_encoder_visual(ref_imgs).unsqueeze(1)
|
||||
y_visual0 = y_visual.clone()
|
||||
|
||||
with torch.no_grad():
|
||||
# Log memory
|
||||
pynvml.nvmlInit()
|
||||
# Sample images (DDIM)
|
||||
with amp.autocast(enabled=self.cfg.use_fp16):
|
||||
if self.cfg.share_noise:
|
||||
b, c, f, h, w = video_data.shape
|
||||
noise = torch.randn((self.viz_num, c, h, w),
|
||||
device=self.device)
|
||||
noise = noise.repeat_interleave(repeats=f, dim=0)
|
||||
noise = rearrange(
|
||||
noise, '(b f) c h w->b c f h w', b=self.viz_num)
|
||||
noise = noise.contiguous()
|
||||
else:
|
||||
noise = torch.randn_like(video_data[:self.viz_num])
|
||||
|
||||
full_model_kwargs = [{
|
||||
'y':
|
||||
y0[:self.viz_num],
|
||||
'local_image':
|
||||
None
|
||||
if len(image_local) == 0 else image_local[:self.viz_num],
|
||||
'image':
|
||||
None if len(y_visual) == 0 else y_visual0[:self.viz_num],
|
||||
'depth':
|
||||
None
|
||||
if len(depth_data) == 0 else depth_data[:self.viz_num],
|
||||
'canny':
|
||||
None
|
||||
if len(canny_data) == 0 else canny_data[:self.viz_num],
|
||||
'sketch':
|
||||
None
|
||||
if len(sketch_data) == 0 else sketch_data[:self.viz_num],
|
||||
'masked':
|
||||
None
|
||||
if len(masked_video) == 0 else masked_video[:self.viz_num],
|
||||
'motion':
|
||||
None if len(mv_data_video) == 0 else
|
||||
mv_data_video[:self.viz_num],
|
||||
'single_sketch':
|
||||
None if len(single_sketch_data) == 0 else
|
||||
single_sketch_data[:self.viz_num],
|
||||
'fps':
|
||||
fps[:self.viz_num]
|
||||
}, {
|
||||
'y':
|
||||
self.zero_y.repeat(self.viz_num, 1, 1)
|
||||
if not self.cfg.use_fps_condition else
|
||||
torch.zeros_like(y0)[:self.viz_num],
|
||||
'local_image':
|
||||
None
|
||||
if len(image_local) == 0 else image_local[:self.viz_num],
|
||||
'image':
|
||||
None if len(y_visual) == 0 else torch.zeros_like(
|
||||
y_visual0[:self.viz_num]),
|
||||
'depth':
|
||||
None
|
||||
if len(depth_data) == 0 else depth_data[:self.viz_num],
|
||||
'canny':
|
||||
None
|
||||
if len(canny_data) == 0 else canny_data[:self.viz_num],
|
||||
'sketch':
|
||||
None
|
||||
if len(sketch_data) == 0 else sketch_data[:self.viz_num],
|
||||
'masked':
|
||||
None
|
||||
if len(masked_video) == 0 else masked_video[:self.viz_num],
|
||||
'motion':
|
||||
None if len(mv_data_video) == 0 else
|
||||
mv_data_video[:self.viz_num],
|
||||
'single_sketch':
|
||||
None if len(single_sketch_data) == 0 else
|
||||
single_sketch_data[:self.viz_num],
|
||||
'fps':
|
||||
fps[:self.viz_num]
|
||||
}]
|
||||
|
||||
# Save generated videos
|
||||
partial_keys = self.cfg.guidances
|
||||
noise_motion = noise.clone()
|
||||
model_kwargs = prepare_model_kwargs(
|
||||
partial_keys=partial_keys,
|
||||
full_model_kwargs=full_model_kwargs,
|
||||
use_fps_condition=self.cfg.use_fps_condition)
|
||||
video_output = self.diffusion.ddim_sample_loop(
|
||||
noise=noise_motion,
|
||||
model=self.model.eval(),
|
||||
model_kwargs=model_kwargs,
|
||||
guide_scale=9.0,
|
||||
ddim_timesteps=self.cfg.ddim_timesteps,
|
||||
eta=0.0)
|
||||
|
||||
save_with_model_kwargs(
|
||||
model_kwargs=model_kwargs,
|
||||
video_data=video_output,
|
||||
autoencoder=self.autoencoder,
|
||||
ori_video=misc_backups,
|
||||
viz_num=self.viz_num,
|
||||
step=0,
|
||||
caps=caps,
|
||||
palette=palette,
|
||||
cfg=self.cfg)
|
||||
|
||||
return {
|
||||
'video': video_output.type(torch.float32).cpu(),
|
||||
'video_path': self.cfg
|
||||
}
|
||||
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
|
||||
from .soonet_video_temporal_grounding_pipeline import SOONetVideoTemporalGroundingPipeline
|
||||
from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline
|
||||
from .multimodal_dialogue_pipeline import MultimodalDialoguePipeline
|
||||
from .videocomposer_pipeline import VideoComposerPipeline
|
||||
else:
|
||||
_import_structure = {
|
||||
'image_captioning_pipeline': ['ImageCaptioningPipeline'],
|
||||
@@ -46,7 +47,8 @@ else:
|
||||
'soonet_video_temporal_grounding_pipeline':
|
||||
['SOONetVideoTemporalGroundingPipeline'],
|
||||
'text_to_video_synthesis_pipeline': ['TextToVideoSynthesisPipeline'],
|
||||
'multimodal_dialogue_pipeline': ['MultimodalDialoguePipeline']
|
||||
'multimodal_dialogue_pipeline': ['MultimodalDialoguePipeline'],
|
||||
'videocomposer_pipeline': ['VideoComposerPipeline']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
382
modelscope/pipelines/multi_modal/videocomposer_pipeline.py
Normal file
382
modelscope/pipelines/multi_modal/videocomposer_pipeline.py
Normal file
@@ -0,0 +1,382 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any, Dict
|
||||
|
||||
import cv2
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from mvextractor.videocap import VideoCap
|
||||
from PIL import Image
|
||||
|
||||
import modelscope.models.multi_modal.videocomposer.data as data
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.multi_modal.videocomposer.data.transforms import (
|
||||
CenterCropV3, random_resize)
|
||||
from modelscope.models.multi_modal.videocomposer.ops.random_mask import (
|
||||
make_irregular_mask, make_rectangle_mask, make_uncrop)
|
||||
from modelscope.models.multi_modal.videocomposer.utils.utils import rand_name
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.device import device_placement
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.text_to_video_synthesis, module_name=Pipelines.videocomposer)
|
||||
class VideoComposerPipeline(Pipeline):
|
||||
r""" Video Composer Pipeline.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> from modelscope.utils.constant import Tasks
|
||||
>>> pipe = pipeline(
|
||||
task=Tasks.text_to_video_synthesis,
|
||||
model='buptwq/videocomposer',
|
||||
model_revision='v1.0.1')
|
||||
>>> inputs = {'Video:FILE': 'path/input_video.mp4',
|
||||
'Image:FILE': 'path/input_image.png',
|
||||
'text': 'the text description'}
|
||||
>>> output = pipe(inputs)
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
use `model` to create a videocomposer pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
self.log_dir = kwargs.pop('log_dir', './video_outputs')
|
||||
if not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
self.feature_framerate = kwargs.pop('feature_framerate', 4)
|
||||
self.frame_lens = kwargs.pop('frame_lens', [
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
])
|
||||
self.feature_framerates = kwargs.pop('feature_framerates', [
|
||||
4,
|
||||
])
|
||||
self.batch_sizes = kwargs.pop('batch_sizes', {
|
||||
'1': 1,
|
||||
'4': 1,
|
||||
'8': 1,
|
||||
'16': 1,
|
||||
})
|
||||
l1 = len(self.frame_lens)
|
||||
l2 = len(self.feature_framerates)
|
||||
self.max_frames = self.frame_lens[0 % (l1 * l2) // l2]
|
||||
self.batch_size = self.batch_sizes[str(self.max_frames)]
|
||||
self.resolution = kwargs.pop('resolution', 256)
|
||||
self.image_resolution = kwargs.pop('image_resolution', 256)
|
||||
self.mean = kwargs.pop('mean', [0.5, 0.5, 0.5])
|
||||
self.std = kwargs.pop('std', [0.5, 0.5, 0.5])
|
||||
self.vit_image_size = kwargs.pop('vit_image_size', 224)
|
||||
self.vit_mean = kwargs.pop('vit_mean',
|
||||
[0.48145466, 0.4578275, 0.40821073])
|
||||
self.vit_std = kwargs.pop('vit_std',
|
||||
[0.26862954, 0.26130258, 0.27577711])
|
||||
self.misc_size = kwargs.pop('kwargs.pop', 384)
|
||||
self.visual_mv = kwargs.pop('visual_mv', False)
|
||||
self.max_words = kwargs.pop('max_words', 1000)
|
||||
self.mvs_visual = kwargs.pop('mvs_visual', False)
|
||||
|
||||
self.infer_trans = data.Compose([
|
||||
data.CenterCropV2(size=self.resolution),
|
||||
data.ToTensor(),
|
||||
data.Normalize(mean=self.mean, std=self.std)
|
||||
])
|
||||
|
||||
self.misc_transforms = data.Compose([
|
||||
T.Lambda(partial(random_resize, size=self.misc_size)),
|
||||
data.CenterCropV2(self.misc_size),
|
||||
data.ToTensor()
|
||||
])
|
||||
|
||||
self.mv_transforms = data.Compose(
|
||||
[T.Resize(size=self.resolution),
|
||||
T.CenterCrop(self.resolution)])
|
||||
|
||||
self.vit_transforms = T.Compose([
|
||||
CenterCropV3(self.vit_image_size),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=self.vit_mean, std=self.vit_std)
|
||||
])
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
video_key = input['Video:FILE']
|
||||
cap_txt = input['text']
|
||||
style_image = input['Image:FILE']
|
||||
|
||||
total_frames = None
|
||||
|
||||
feature_framerate = self.feature_framerate
|
||||
if os.path.exists(video_key):
|
||||
try:
|
||||
ref_frame, vit_image, video_data, misc_data, mv_data = self.video_data_preprocess(
|
||||
video_key, self.feature_framerate, total_frames,
|
||||
self.mvs_visual)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
'{} get frames failed... with error: {}'.format(
|
||||
video_key, e),
|
||||
flush=True)
|
||||
|
||||
ref_frame = torch.zeros(3, self.vit_image_size,
|
||||
self.vit_image_size)
|
||||
video_data = torch.zeros(self.max_frames, 3,
|
||||
self.image_resolution,
|
||||
self.image_resolution)
|
||||
misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
|
||||
self.misc_size)
|
||||
|
||||
mv_data = torch.zeros(self.max_frames, 2,
|
||||
self.image_resolution,
|
||||
self.image_resolution)
|
||||
else:
|
||||
logger.info(
|
||||
'The video path does not exist or no video dir provided!')
|
||||
ref_frame = torch.zeros(3, self.vit_image_size,
|
||||
self.vit_image_size)
|
||||
_ = torch.zeros(3, self.vit_image_size, self.vit_image_size)
|
||||
video_data = torch.zeros(self.max_frames, 3, self.image_resolution,
|
||||
self.image_resolution)
|
||||
misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
|
||||
self.misc_size)
|
||||
mv_data = torch.zeros(self.max_frames, 2, self.image_resolution,
|
||||
self.image_resolution)
|
||||
|
||||
# inpainting mask
|
||||
p = random.random()
|
||||
if p < 0.7:
|
||||
mask = make_irregular_mask(512, 512)
|
||||
elif p < 0.9:
|
||||
mask = make_rectangle_mask(512, 512)
|
||||
else:
|
||||
mask = make_uncrop(512, 512)
|
||||
mask = torch.from_numpy(
|
||||
cv2.resize(
|
||||
mask, (self.misc_size, self.misc_size),
|
||||
interpolation=cv2.INTER_NEAREST)).unsqueeze(0).float()
|
||||
|
||||
mask = mask.unsqueeze(0).repeat_interleave(
|
||||
repeats=self.max_frames, dim=0)
|
||||
video_input = {
|
||||
'ref_frame': ref_frame.unsqueeze(0),
|
||||
'cap_txt': cap_txt,
|
||||
'video_data': video_data.unsqueeze(0),
|
||||
'misc_data': misc_data.unsqueeze(0),
|
||||
'feature_framerate': feature_framerate,
|
||||
'mask': mask.unsqueeze(0),
|
||||
'mv_data': mv_data.unsqueeze(0),
|
||||
'style_image': style_image
|
||||
}
|
||||
return video_input
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return self.model(input)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any],
|
||||
**post_params) -> Dict[str, Any]:
|
||||
output_video_path = post_params.get('output_video', None)
|
||||
temp_video_file = False
|
||||
if output_video_path is not None:
|
||||
output_video_path = tempfile.NamedTemporaryFile(suffix='.gif').name
|
||||
temp_video_file = True
|
||||
|
||||
if temp_video_file:
|
||||
return {OutputKeys.OUTPUT_VIDEO: inputs['video_path']}
|
||||
else:
|
||||
return {OutputKeys.OUTPUT_VIDEO: inputs['video']}
|
||||
|
||||
def video_data_preprocess(self, video_key, feature_framerate, total_frames,
|
||||
visual_mv):
|
||||
|
||||
filename = video_key
|
||||
for _ in range(5):
|
||||
try:
|
||||
frame_types, frames, mvs, mvs_visual = self.extract_motion_vectors(
|
||||
input_video=filename,
|
||||
fps=feature_framerate,
|
||||
visual_mv=visual_mv)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'{} read video frames and motion vectors failed with error: {}'
|
||||
.format(video_key, e),
|
||||
flush=True)
|
||||
|
||||
total_frames = len(frame_types)
|
||||
start_indexs = np.where((np.array(frame_types) == 'I') & (
|
||||
total_frames - np.arange(total_frames) >= self.max_frames))[0]
|
||||
start_index = np.random.choice(start_indexs)
|
||||
indices = np.arange(start_index, start_index + self.max_frames)
|
||||
|
||||
# note frames are in BGR mode, need to trans to RGB mode
|
||||
frames = [Image.fromarray(frames[i][:, :, ::-1]) for i in indices]
|
||||
mvs = [torch.from_numpy(mvs[i].transpose((2, 0, 1))) for i in indices]
|
||||
mvs = torch.stack(mvs)
|
||||
|
||||
if visual_mv:
|
||||
images = [(mvs_visual[i][:, :, ::-1]).astype('uint8')
|
||||
for i in indices]
|
||||
path = self.log_dir + '/visual_mv/' + video_key.split(
|
||||
'/')[-1] + '.gif'
|
||||
if not os.path.exists(self.log_dir + '/visual_mv/'):
|
||||
os.makedirs(self.log_dir + '/visual_mv/', exist_ok=True)
|
||||
logger.info('save motion vectors visualization to :', path)
|
||||
imageio.mimwrite(path, images, fps=8)
|
||||
|
||||
have_frames = len(frames) > 0
|
||||
middle_indix = int(len(frames) / 2)
|
||||
if have_frames:
|
||||
ref_frame = frames[middle_indix]
|
||||
vit_image = self.vit_transforms(ref_frame)
|
||||
misc_imgs_np = self.misc_transforms[:2](frames)
|
||||
misc_imgs = self.misc_transforms[2:](misc_imgs_np)
|
||||
frames = self.infer_trans(frames)
|
||||
mvs = self.mv_transforms(mvs)
|
||||
else:
|
||||
vit_image = torch.zeros(3, self.vit_image_size,
|
||||
self.vit_image_size)
|
||||
|
||||
video_data = torch.zeros(self.max_frames, 3, self.image_resolution,
|
||||
self.image_resolution)
|
||||
mv_data = torch.zeros(self.max_frames, 2, self.image_resolution,
|
||||
self.image_resolution)
|
||||
misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
|
||||
self.misc_size)
|
||||
if have_frames:
|
||||
video_data[:len(frames), ...] = frames
|
||||
misc_data[:len(frames), ...] = misc_imgs
|
||||
mv_data[:len(frames), ...] = mvs
|
||||
|
||||
ref_frame = vit_image
|
||||
|
||||
del frames
|
||||
del misc_imgs
|
||||
del mvs
|
||||
|
||||
return ref_frame, vit_image, video_data, misc_data, mv_data
|
||||
|
||||
def extract_motion_vectors(self,
|
||||
input_video,
|
||||
fps=4,
|
||||
dump=False,
|
||||
verbose=False,
|
||||
visual_mv=False):
|
||||
|
||||
if dump:
|
||||
now = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
|
||||
for child in ['frames', 'motion_vectors']:
|
||||
os.makedirs(os.path.join(f'out-{now}', child), exist_ok=True)
|
||||
temp = rand_name()
|
||||
tmp_video = os.path.join(
|
||||
input_video.split('/')[0], f'{temp}' + input_video.split('/')[-1])
|
||||
videocapture = cv2.VideoCapture(input_video)
|
||||
frames_num = videocapture.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
fps_video = videocapture.get(cv2.CAP_PROP_FPS)
|
||||
# check if enough frames
|
||||
if frames_num / fps_video * fps > 16:
|
||||
fps = max(fps, 1)
|
||||
else:
|
||||
fps = int(16 / (frames_num / fps_video)) + 1
|
||||
ffmpeg_cmd = f'ffmpeg -threads 8 -loglevel error -i {input_video} -filter:v \
|
||||
fps={fps} -c:v mpeg4 -f rawvideo {tmp_video}'
|
||||
|
||||
if os.path.exists(tmp_video):
|
||||
os.remove(tmp_video)
|
||||
|
||||
subprocess.run(args=ffmpeg_cmd, shell=True, timeout=120)
|
||||
|
||||
cap = VideoCap()
|
||||
# open the video file
|
||||
ret = cap.open(tmp_video)
|
||||
if not ret:
|
||||
raise RuntimeError(f'Could not open {tmp_video}')
|
||||
|
||||
step = 0
|
||||
times = []
|
||||
|
||||
frame_types = []
|
||||
frames = []
|
||||
mvs = []
|
||||
mvs_visual = []
|
||||
# continuously read and display video frames and motion vectors
|
||||
while True:
|
||||
if verbose:
|
||||
logger.info('Frame: ', step, end=' ')
|
||||
|
||||
tstart = time.perf_counter()
|
||||
|
||||
# read next video frame and corresponding motion vectors
|
||||
ret, frame, motion_vectors, frame_type, timestamp = cap.read()
|
||||
|
||||
tend = time.perf_counter()
|
||||
telapsed = tend - tstart
|
||||
times.append(telapsed)
|
||||
|
||||
# if there is an error reading the frame
|
||||
if not ret:
|
||||
if verbose:
|
||||
logger.warning('No frame read. Stopping.')
|
||||
break
|
||||
|
||||
frame_save = np.zeros(frame.copy().shape, dtype=np.uint8)
|
||||
if visual_mv:
|
||||
frame_save = draw_motion_vectors(frame_save, motion_vectors)
|
||||
|
||||
# store motion vectors, frames, etc. in output directory
|
||||
dump = False
|
||||
if frame.shape[1] >= frame.shape[0]:
|
||||
w_half = (frame.shape[1] - frame.shape[0]) // 2
|
||||
if dump:
|
||||
cv2.imwrite(
|
||||
os.path.join('./mv_visual/', f'frame-{step}.jpg'),
|
||||
frame_save[:, w_half:-w_half])
|
||||
mvs_visual.append(frame_save[:, w_half:-w_half])
|
||||
else:
|
||||
h_half = (frame.shape[0] - frame.shape[1]) // 2
|
||||
if dump:
|
||||
cv2.imwrite(
|
||||
os.path.join('./mv_visual/', f'frame-{step}.jpg'),
|
||||
frame_save[h_half:-h_half, :])
|
||||
mvs_visual.append(frame_save[h_half:-h_half, :])
|
||||
|
||||
h, w = frame.shape[:2]
|
||||
mv = np.zeros((h, w, 2))
|
||||
position = motion_vectors[:, 5:7].clip((0, 0), (w - 1, h - 1))
|
||||
mv[position[:, 1],
|
||||
position[:,
|
||||
0]] = motion_vectors[:, 0:
|
||||
1] * motion_vectors[:, 7:
|
||||
9] / motion_vectors[:,
|
||||
9:]
|
||||
|
||||
step += 1
|
||||
frame_types.append(frame_type)
|
||||
frames.append(frame)
|
||||
mvs.append(mv)
|
||||
if verbose:
|
||||
logger.info('average dt: ', np.mean(times))
|
||||
cap.release()
|
||||
|
||||
if os.path.exists(tmp_video):
|
||||
os.remove(tmp_video)
|
||||
|
||||
return frame_types, frames, mvs, mvs_visual
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os.path as osp
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
38
tests/pipelines/test_videocomposer.py
Normal file
38
tests/pipelines/test_videocomposer.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import DownloadMode, Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class VideoDeinterlaceTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = Tasks.text_to_video_synthesis
|
||||
self.model_id = 'buptwq/videocomposer'
|
||||
self.model_revision = 'v1.0.1'
|
||||
self.dataset_id = 'buptwq/videocomposer-depths-style'
|
||||
self.text = 'A glittering and translucent fish swimming in a \
|
||||
small glass bowl with multicolored piece of stone, like a glass fish'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_pipeline(self):
|
||||
pipe = pipeline(
|
||||
task=Tasks.text_to_video_synthesis,
|
||||
model=self.model_id,
|
||||
model_revision=self.model_revision)
|
||||
ds = MsDataset.load(
|
||||
self.dataset_id,
|
||||
split='train',
|
||||
download_mode=DownloadMode.FORCE_REDOWNLOAD)
|
||||
inputs = next(iter(ds))
|
||||
inputs.update({'text': self.text})
|
||||
_ = pipe(inputs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user