From ee8afd2d62d276b306fe9ba8b5c672bbc175da06 Mon Sep 17 00:00:00 2001
From: Wang Qiang <37444407+XDUWQ@users.noreply.github.com>
Date: Tue, 15 Aug 2023 12:01:03 +0800
Subject: [PATCH] VideoComposer: Compositional Video Synthesis with Motion
Controllability (#431)
* VideoComposer: Compositional Video Synthesis with Motion Controllability
* videocomposer pipeline
* pre commit
* delete xformers
---
modelscope/metainfo.py | 2 +
modelscope/models/multi_modal/__init__.py | 2 +
.../text_to_video_synthesis_model.py | 3 +
.../multi_modal/videocomposer/__init__.py | 23 +
.../videocomposer/annotator/__init__.py | 1 +
.../videocomposer/annotator/canny/__init__.py | 44 +
.../annotator/histogram/__init__.py | 3 +
.../annotator/histogram/palette.py | 155 ++
.../annotator/sketch/__init__.py | 4 +
.../videocomposer/annotator/sketch/pidinet.py | 940 ++++++++
.../annotator/sketch/sketch_simplification.py | 110 +
.../videocomposer/annotator/util.py | 42 +
.../multi_modal/videocomposer/autoencoder.py | 650 +++++
.../models/multi_modal/videocomposer/clip.py | 143 ++
.../multi_modal/videocomposer/config.py | 156 ++
.../videocomposer/configs/base.yaml | 2 +
.../configs/exp01_vidcomposer_full.yaml | 20 +
.../configs/exp02_motion_transfer.yaml | 23 +
.../exp02_motion_transfer_vs_style.yaml | 24 +
.../configs/exp03_sketch2video_style.yaml | 26 +
.../configs/exp04_sketch2video_wo_style.yaml | 26 +
.../configs/exp05_text_depths_wo_style.yaml | 26 +
.../configs/exp06_text_depths_vs_style.yaml | 26 +
.../exp10_vidcomposer_no_watermark_full.yaml | 21 +
.../videocomposer/data/__init__.py | 5 +
.../videocomposer/data/samplers.py | 158 ++
.../videocomposer/data/tokenizers.py | 184 ++
.../videocomposer/data/transforms.py | 400 ++++
.../multi_modal/videocomposer/diffusion.py | 1514 ++++++++++++
.../multi_modal/videocomposer/dpm_solver.py | 1697 +++++++++++++
.../multi_modal/videocomposer/mha_flash.py | 120 +
.../videocomposer/models/__init__.py | 2 +
.../multi_modal/videocomposer/models/clip.py | 460 ++++
.../multi_modal/videocomposer/models/midas.py | 320 +++
.../multi_modal/videocomposer/ops/__init__.py | 7 +
.../videocomposer/ops/degration.py | 998 ++++++++
.../videocomposer/ops/distributed.py | 460 ++++
.../multi_modal/videocomposer/ops/losses.py | 37 +
.../videocomposer/ops/random_mask.py | 81 +
.../multi_modal/videocomposer/ops/utils.py | 1037 ++++++++
.../multi_modal/videocomposer/unet_sd.py | 2102 +++++++++++++++++
.../videocomposer/utils/__init__.py | 0
.../multi_modal/videocomposer/utils/config.py | 273 +++
.../videocomposer/utils/distributed.py | 297 +++
.../multi_modal/videocomposer/utils/utils.py | 955 ++++++++
.../videocomposer/videocomposer_model.py | 480 ++++
modelscope/pipelines/multi_modal/__init__.py | 4 +-
.../multi_modal/videocomposer_pipeline.py | 382 +++
modelscope/preprocessors/multi_modal.py | 1 +
tests/pipelines/test_videocomposer.py | 38 +
50 files changed, 14483 insertions(+), 1 deletion(-)
create mode 100644 modelscope/models/multi_modal/videocomposer/__init__.py
create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/__init__.py
create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/canny/__init__.py
create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/histogram/__init__.py
create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/histogram/palette.py
create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/sketch/__init__.py
create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/sketch/pidinet.py
create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/sketch/sketch_simplification.py
create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/util.py
create mode 100644 modelscope/models/multi_modal/videocomposer/autoencoder.py
create mode 100644 modelscope/models/multi_modal/videocomposer/clip.py
create mode 100644 modelscope/models/multi_modal/videocomposer/config.py
create mode 100644 modelscope/models/multi_modal/videocomposer/configs/base.yaml
create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp01_vidcomposer_full.yaml
create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer.yaml
create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer_vs_style.yaml
create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp03_sketch2video_style.yaml
create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp04_sketch2video_wo_style.yaml
create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp05_text_depths_wo_style.yaml
create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp06_text_depths_vs_style.yaml
create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp10_vidcomposer_no_watermark_full.yaml
create mode 100644 modelscope/models/multi_modal/videocomposer/data/__init__.py
create mode 100644 modelscope/models/multi_modal/videocomposer/data/samplers.py
create mode 100644 modelscope/models/multi_modal/videocomposer/data/tokenizers.py
create mode 100644 modelscope/models/multi_modal/videocomposer/data/transforms.py
create mode 100644 modelscope/models/multi_modal/videocomposer/diffusion.py
create mode 100644 modelscope/models/multi_modal/videocomposer/dpm_solver.py
create mode 100644 modelscope/models/multi_modal/videocomposer/mha_flash.py
create mode 100644 modelscope/models/multi_modal/videocomposer/models/__init__.py
create mode 100644 modelscope/models/multi_modal/videocomposer/models/clip.py
create mode 100644 modelscope/models/multi_modal/videocomposer/models/midas.py
create mode 100644 modelscope/models/multi_modal/videocomposer/ops/__init__.py
create mode 100644 modelscope/models/multi_modal/videocomposer/ops/degration.py
create mode 100644 modelscope/models/multi_modal/videocomposer/ops/distributed.py
create mode 100644 modelscope/models/multi_modal/videocomposer/ops/losses.py
create mode 100644 modelscope/models/multi_modal/videocomposer/ops/random_mask.py
create mode 100644 modelscope/models/multi_modal/videocomposer/ops/utils.py
create mode 100644 modelscope/models/multi_modal/videocomposer/unet_sd.py
create mode 100644 modelscope/models/multi_modal/videocomposer/utils/__init__.py
create mode 100644 modelscope/models/multi_modal/videocomposer/utils/config.py
create mode 100644 modelscope/models/multi_modal/videocomposer/utils/distributed.py
create mode 100644 modelscope/models/multi_modal/videocomposer/utils/utils.py
create mode 100644 modelscope/models/multi_modal/videocomposer/videocomposer_model.py
create mode 100644 modelscope/pipelines/multi_modal/videocomposer_pipeline.py
create mode 100644 tests/pipelines/test_videocomposer.py
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index df5e8a75..70fd9c86 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -218,6 +218,7 @@ class Models(object):
mplug_owl = 'mplug-owl'
clip_interrogator = 'clip-interrogator'
stable_diffusion = 'stable-diffusion'
+ videocomposer = 'videocomposer'
text_to_360panorama_image = 'text-to-360panorama-image'
# science models
@@ -525,6 +526,7 @@ class Pipelines(object):
multi_modal_similarity = 'multi-modal-similarity'
text_to_image_synthesis = 'text-to-image-synthesis'
video_multi_modal_embedding = 'video-multi-modal-embedding'
+ videocomposer = 'videocomposer'
image_text_retrieval = 'image-text-retrieval'
ofa_ocr_recognition = 'ofa-ocr-recognition'
ofa_asr = 'ofa-asr'
diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py
index 9fa34baf..31af94b2 100644
--- a/modelscope/models/multi_modal/__init__.py
+++ b/modelscope/models/multi_modal/__init__.py
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
from .efficient_diffusion_tuning import EfficientStableDiffusion
from .mplug_owl import MplugOwlForConditionalGeneration
from .clip_interrogator import CLIP_Interrogator
+ from .videocomposer import VideoComposer
else:
_import_structure = {
@@ -42,6 +43,7 @@ else:
'efficient_diffusion_tuning': ['EfficientStableDiffusion'],
'mplug_owl': ['MplugOwlForConditionalGeneration'],
'clip_interrogator': ['CLIP_Interrogator'],
+ 'videocomposer': ['VideoComposer'],
}
import sys
diff --git a/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py b/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py
index 93a7e3ba..0ec66069 100644
--- a/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py
+++ b/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py
@@ -19,6 +19,9 @@ from modelscope.models.multi_modal.video_synthesis.diffusion import (
from modelscope.models.multi_modal.video_synthesis.unet_sd import UNetSD
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
__all__ = ['TextToVideoSynthesis']
diff --git a/modelscope/models/multi_modal/videocomposer/__init__.py b/modelscope/models/multi_modal/videocomposer/__init__.py
new file mode 100644
index 00000000..ae9102c0
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+
+ from .videocomposer_model import VideoComposer
+
+else:
+ _import_structure = {
+ 'videocomposer_model': ['VideoComposer'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/multi_modal/videocomposer/annotator/__init__.py b/modelscope/models/multi_modal/videocomposer/annotator/__init__.py
new file mode 100644
index 00000000..b937315b
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/annotator/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
diff --git a/modelscope/models/multi_modal/videocomposer/annotator/canny/__init__.py b/modelscope/models/multi_modal/videocomposer/annotator/canny/__init__.py
new file mode 100644
index 00000000..bf653d47
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/annotator/canny/__init__.py
@@ -0,0 +1,44 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import cv2
+import numpy as np
+import torch
+from ..util import HWC3
+
+
+class CannyDetector:
+
+ def __call__(self,
+ img,
+ low_threshold=None,
+ high_threshold=None,
+ random_threshold=True):
+
+ # Convert to numpy
+ if isinstance(img, torch.Tensor): # (h, w, c)
+ img = img.cpu().numpy()
+ img_np = cv2.convertScaleAbs((img * 255.))
+ elif isinstance(img, np.ndarray): # (h, w, c)
+ img_np = img # we assume values are in the range from 0 to 255.
+ else:
+ assert False
+
+ # Select the threshold
+ if (low_threshold is None) and (high_threshold is None):
+ median_intensity = np.median(img_np)
+ if random_threshold is False:
+ low_threshold = int(max(0, (1 - 0.33) * median_intensity))
+ high_threshold = int(min(255, (1 + 0.33) * median_intensity))
+ else:
+ random_canny = np.random.uniform(0.1, 0.4)
+ # Might try other values
+ low_threshold = int(
+ max(0, (1 - random_canny) * median_intensity))
+ high_threshold = 2 * low_threshold
+
+ # Detect canny edge
+ canny_edge = cv2.Canny(img_np, low_threshold, high_threshold)
+
+ canny_condition = torch.from_numpy(
+ canny_edge.copy()).unsqueeze(dim=-1).float().cuda() / 255.0
+ return canny_condition
diff --git a/modelscope/models/multi_modal/videocomposer/annotator/histogram/__init__.py b/modelscope/models/multi_modal/videocomposer/annotator/histogram/__init__.py
new file mode 100644
index 00000000..62f092a1
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/annotator/histogram/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from .palette import *
diff --git a/modelscope/models/multi_modal/videocomposer/annotator/histogram/palette.py b/modelscope/models/multi_modal/videocomposer/annotator/histogram/palette.py
new file mode 100644
index 00000000..d28ba85b
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/annotator/histogram/palette.py
@@ -0,0 +1,155 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+r"""Modified from ``https://github.com/sergeyk/rayleigh''.
+"""
+import os
+import os.path as osp
+
+import numpy as np
+from skimage.color import hsv2rgb, lab2rgb, rgb2lab
+from skimage.io import imsave
+from sklearn.metrics import euclidean_distances
+
+__all__ = ['Palette']
+
+
+def rgb2hex(rgb):
+ return '#%02x%02x%02x' % tuple([int(round(255.0 * u)) for u in rgb])
+
+
+def hex2rgb(hex):
+ rgb = hex.strip('#')
+ fn = lambda u: round(int(u, 16) / 255.0, 5) # noqa
+ return fn(rgb[:2]), fn(rgb[2:4]), fn(rgb[4:6])
+
+
+class Palette(object):
+ r"""Create a color palette (codebook) in the form of a 2D grid of colors.
+ Further, the rightmost column has num_hues gradations from black to white.
+
+ Parameters:
+ num_hues: number of colors with full lightness and saturation, in the middle.
+ num_sat: number of rows above middle row that show the same hues with decreasing saturation.
+ """
+
+ def __init__(self, num_hues=11, num_sat=5, num_light=4):
+ n = num_sat + 2 * num_light
+
+ # hues
+ if num_hues == 8:
+ hues = np.tile(
+ np.array([0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]),
+ (n, 1))
+ elif num_hues == 9:
+ hues = np.tile(
+ np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]),
+ (n, 1))
+ elif num_hues == 10:
+ hues = np.tile(
+ np.array(
+ [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76,
+ 0.87]), (n, 1))
+ elif num_hues == 11:
+ hues = np.tile(
+ np.array([
+ 0.0, 0.0833, 0.166, 0.25, 0.333, 0.5, 0.56333, 0.666, 0.73,
+ 0.803, 0.916
+ ]), (n, 1))
+ else:
+ hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (n, 1))
+
+ # saturations
+ sats = np.hstack((
+ np.linspace(0, 1, num_sat + 2)[1:-1],
+ 1,
+ [1] * num_light,
+ [0.4] * # noqa
+ (num_light - 1)))
+ sats = np.tile(np.atleast_2d(sats).T, (1, num_hues))
+
+ # lights
+ lights = np.hstack(
+ ([1] * num_sat, 1, np.linspace(1, 0.2, num_light + 2)[1:-1],
+ np.linspace(1, 0.2, num_light + 2)[1:-2]))
+ lights = np.tile(np.atleast_2d(lights).T, (1, num_hues))
+
+ # colors
+ rgb = hsv2rgb(np.dstack([hues, sats, lights]))
+ gray = np.tile(
+ np.linspace(1, 0, n)[:, np.newaxis, np.newaxis], (1, 1, 3))
+ self.thumbnail = np.hstack([rgb, gray])
+
+ # flatten
+ rgb = rgb.T.reshape(3, -1).T
+ gray = gray.T.reshape(3, -1).T
+ self.rgb = np.vstack((rgb, gray))
+ self.lab = rgb2lab(self.rgb[np.newaxis, :, :]).squeeze()
+ self.hex = [rgb2hex(u) for u in self.rgb]
+ self.lab_dists = euclidean_distances(self.lab, squared=True)
+
+ def histogram(self, rgb_img, sigma=20):
+ # compute histogram
+ lab = rgb2lab(rgb_img).reshape((-1, 3))
+ min_ind = np.argmin(
+ euclidean_distances(lab, self.lab, squared=True), axis=1)
+ hist = 1.0 * np.bincount(
+ min_ind, minlength=self.lab.shape[0]) / lab.shape[0]
+
+ # smooth histogram
+ if sigma > 0:
+ weight = np.exp(-self.lab_dists / (2.0 * sigma**2))
+ weight = weight / weight.sum(1)[:, np.newaxis]
+ hist = (weight * hist).sum(1)
+ hist[hist < 1e-5] = 0
+ return hist
+
+ def get_palette_image(self, hist, percentile=90, width=200, height=50):
+ # curate histogram
+ ind = np.argsort(-hist)
+ ind = ind[hist[ind] > np.percentile(hist, percentile)]
+ hist = hist[ind] / hist[ind].sum()
+
+ # draw palette
+ nums = np.array(hist * width, dtype=int)
+ array = np.vstack([
+ np.tile(np.array(u), (v, 1)) for u, v in zip(self.rgb[ind], nums)
+ ])
+ array = np.tile(array[np.newaxis, :, :], (height, 1, 1))
+ if array.shape[1] < width:
+ array = np.concatenate(
+ [array, np.zeros((height, width - array.shape[1], 3))], axis=1)
+ return array
+
+ def quantize_image(self, rgb_img):
+ lab = rgb2lab(rgb_img).reshape((-1, 3))
+ min_ind = np.argmin(
+ euclidean_distances(lab, self.lab, squared=True), axis=1)
+ quantized_lab = self.lab[min_ind]
+ img = lab2rgb(quantized_lab.reshape(rgb_img.shape))
+ return img
+
+ def export(self, dirname):
+ if not osp.exists(dirname):
+ os.makedirs(dirname)
+
+ # save thumbnail
+ imsave(osp.join(dirname, 'palette.png'), self.thumbnail)
+
+ # save html
+        with open(osp.join(dirname, 'palette.html'), 'w') as f:
+            html = '<style>td { width: 25px; height: 25px; }</style>\n'
+            html += '<table>\n'
+            for row in self.thumbnail:
+                html += '<tr>\n'
+                for col in row:
+                    html += '<td style="background-color: {}"></td>\n'.format(
+                        rgb2hex(col))
+                html += '</tr>\n'
+            html += '</table>\n'
+            f.write(html)
diff --git a/modelscope/models/multi_modal/videocomposer/annotator/sketch/__init__.py b/modelscope/models/multi_modal/videocomposer/annotator/sketch/__init__.py
new file mode 100644
index 00000000..51e28cc7
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/annotator/sketch/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from .pidinet import *
+from .sketch_simplification import *
diff --git a/modelscope/models/multi_modal/videocomposer/annotator/sketch/pidinet.py b/modelscope/models/multi_modal/videocomposer/annotator/sketch/pidinet.py
new file mode 100644
index 00000000..86c17d5d
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/annotator/sketch/pidinet.py
@@ -0,0 +1,940 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+r"""Modified from ``https://github.com/zhuoinoulu/pidinet''.
+ Image augmentation: T.Compose([
+ T.ToTensor(),
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]).
+"""
+import math
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modelscope.models.multi_modal.videocomposer.utils.utils import \
+ DOWNLOAD_TO_CACHE
+
+__all__ = [
+ 'PiDiNet', 'pidinet_bsd_tiny', 'pidinet_bsd_small', 'pidinet_bsd',
+ 'pidinet_nyud', 'pidinet_multicue'
+]
+
+CONFIGS = {
+ 'baseline': {
+ 'layer0': 'cv',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'c-v15': {
+ 'layer0': 'cd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'a-v15': {
+ 'layer0': 'ad',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'r-v15': {
+ 'layer0': 'rd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'cvvv4': {
+ 'layer0': 'cd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'avvv4': {
+ 'layer0': 'ad',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'ad',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'ad',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'ad',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'rvvv4': {
+ 'layer0': 'rd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'rd',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'rd',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'rd',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'cccv4': {
+ 'layer0': 'cd',
+ 'layer1': 'cd',
+ 'layer2': 'cd',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'cd',
+ 'layer6': 'cd',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'cd',
+ 'layer10': 'cd',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'cd',
+ 'layer14': 'cd',
+ 'layer15': 'cv',
+ },
+ 'aaav4': {
+ 'layer0': 'ad',
+ 'layer1': 'ad',
+ 'layer2': 'ad',
+ 'layer3': 'cv',
+ 'layer4': 'ad',
+ 'layer5': 'ad',
+ 'layer6': 'ad',
+ 'layer7': 'cv',
+ 'layer8': 'ad',
+ 'layer9': 'ad',
+ 'layer10': 'ad',
+ 'layer11': 'cv',
+ 'layer12': 'ad',
+ 'layer13': 'ad',
+ 'layer14': 'ad',
+ 'layer15': 'cv',
+ },
+ 'rrrv4': {
+ 'layer0': 'rd',
+ 'layer1': 'rd',
+ 'layer2': 'rd',
+ 'layer3': 'cv',
+ 'layer4': 'rd',
+ 'layer5': 'rd',
+ 'layer6': 'rd',
+ 'layer7': 'cv',
+ 'layer8': 'rd',
+ 'layer9': 'rd',
+ 'layer10': 'rd',
+ 'layer11': 'cv',
+ 'layer12': 'rd',
+ 'layer13': 'rd',
+ 'layer14': 'rd',
+ 'layer15': 'cv',
+ },
+ 'c16': {
+ 'layer0': 'cd',
+ 'layer1': 'cd',
+ 'layer2': 'cd',
+ 'layer3': 'cd',
+ 'layer4': 'cd',
+ 'layer5': 'cd',
+ 'layer6': 'cd',
+ 'layer7': 'cd',
+ 'layer8': 'cd',
+ 'layer9': 'cd',
+ 'layer10': 'cd',
+ 'layer11': 'cd',
+ 'layer12': 'cd',
+ 'layer13': 'cd',
+ 'layer14': 'cd',
+ 'layer15': 'cd',
+ },
+ 'a16': {
+ 'layer0': 'ad',
+ 'layer1': 'ad',
+ 'layer2': 'ad',
+ 'layer3': 'ad',
+ 'layer4': 'ad',
+ 'layer5': 'ad',
+ 'layer6': 'ad',
+ 'layer7': 'ad',
+ 'layer8': 'ad',
+ 'layer9': 'ad',
+ 'layer10': 'ad',
+ 'layer11': 'ad',
+ 'layer12': 'ad',
+ 'layer13': 'ad',
+ 'layer14': 'ad',
+ 'layer15': 'ad',
+ },
+ 'r16': {
+ 'layer0': 'rd',
+ 'layer1': 'rd',
+ 'layer2': 'rd',
+ 'layer3': 'rd',
+ 'layer4': 'rd',
+ 'layer5': 'rd',
+ 'layer6': 'rd',
+ 'layer7': 'rd',
+ 'layer8': 'rd',
+ 'layer9': 'rd',
+ 'layer10': 'rd',
+ 'layer11': 'rd',
+ 'layer12': 'rd',
+ 'layer13': 'rd',
+ 'layer14': 'rd',
+ 'layer15': 'rd',
+ },
+ 'carv4': {
+ 'layer0': 'cd',
+ 'layer1': 'ad',
+ 'layer2': 'rd',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'ad',
+ 'layer6': 'rd',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'ad',
+ 'layer10': 'rd',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'ad',
+ 'layer14': 'rd',
+ 'layer15': 'cv'
+ }
+}
+
+
+def create_conv_func(op_type):
+ assert op_type in ['cv', 'cd', 'ad',
+ 'rd'], 'unknown op type: %s' % str(op_type)
+ if op_type == 'cv':
+ return F.conv2d
+ if op_type == 'cd':
+
+ def func(x,
+ weights,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1):
+ assert dilation in [1,
+ 2], 'dilation for cd_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(
+ 3) == 3, 'kernel size for cd_conv should be 3x3'
+ assert padding == dilation, 'padding for cd_conv set wrong'
+
+ weights_c = weights.sum(dim=[2, 3], keepdim=True)
+ yc = F.conv2d(
+ x, weights_c, stride=stride, padding=0, groups=groups)
+ y = F.conv2d(
+ x,
+ weights,
+ bias,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups)
+ return y - yc
+
+ return func
+ elif op_type == 'ad':
+
+ def func(x,
+ weights,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1):
+ assert dilation in [1,
+ 2], 'dilation for ad_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(
+ 3) == 3, 'kernel size for ad_conv should be 3x3'
+ assert padding == dilation, 'padding for ad_conv set wrong'
+
+ shape = weights.shape
+ weights = weights.view(shape[0], shape[1], -1)
+ weights_conv = (weights
+ - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(
+ shape) # clock-wise
+ y = F.conv2d(
+ x,
+ weights_conv,
+ bias,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups)
+ return y
+
+ return func
+ elif op_type == 'rd':
+
+ def func(x,
+ weights,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1):
+ assert dilation in [1,
+ 2], 'dilation for rd_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(
+ 3) == 3, 'kernel size for rd_conv should be 3x3'
+ padding = 2 * dilation
+
+ shape = weights.shape
+ if weights.is_cuda:
+ buffer = torch.cuda.FloatTensor(shape[0], shape[1],
+ 5 * 5).fill_(0)
+ else:
+ buffer = torch.zeros(shape[0], shape[1], 5 * 5)
+ weights = weights.view(shape[0], shape[1], -1)
+ buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
+ buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
+ buffer[:, :, 12] = 0
+ buffer = buffer.view(shape[0], shape[1], 5, 5)
+ y = F.conv2d(
+ x,
+ buffer,
+ bias,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups)
+ return y
+
+ return func
+ else:
+ print('impossible to be here unless you force that', flush=True)
+ return None
+
+
+def config_model(model):
+ model_options = list(CONFIGS.keys())
+ assert model in model_options, \
+ 'unrecognized model, please choose from %s' % str(model_options)
+
+ pdcs = []
+ for i in range(16):
+ layer_name = 'layer%d' % i
+ op = CONFIGS[model][layer_name]
+ pdcs.append(create_conv_func(op))
+ return pdcs
+
+
+def config_model_converted(model):
+ model_options = list(CONFIGS.keys())
+ assert model in model_options, \
+ 'unrecognized model, please choose from %s' % str(model_options)
+
+ pdcs = []
+ for i in range(16):
+ layer_name = 'layer%d' % i
+ op = CONFIGS[model][layer_name]
+ pdcs.append(op)
+ return pdcs
+
+
+def convert_pdc(op, weight):
+ if op == 'cv':
+ return weight
+ elif op == 'cd':
+ shape = weight.shape
+ weight_c = weight.sum(dim=[2, 3])
+ weight = weight.view(shape[0], shape[1], -1)
+ weight[:, :, 4] = weight[:, :, 4] - weight_c
+ weight = weight.view(shape)
+ return weight
+ elif op == 'ad':
+ shape = weight.shape
+ weight = weight.view(shape[0], shape[1], -1)
+ weight_conv = (weight
+ - weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape)
+ return weight_conv
+ elif op == 'rd':
+ shape = weight.shape
+ buffer = torch.zeros(shape[0], shape[1], 5 * 5, device=weight.device)
+ weight = weight.view(shape[0], shape[1], -1)
+ buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weight[:, :, 1:]
+ buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weight[:, :, 1:]
+ buffer = buffer.view(shape[0], shape[1], 5, 5)
+ return buffer
+ raise ValueError('wrong op {}'.format(str(op)))
+
+
+def convert_pidinet(state_dict, config):
+ pdcs = config_model_converted(config)
+ new_dict = {}
+ for pname, p in state_dict.items():
+ if 'init_block.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[0], p)
+ elif 'block1_1.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[1], p)
+ elif 'block1_2.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[2], p)
+ elif 'block1_3.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[3], p)
+ elif 'block2_1.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[4], p)
+ elif 'block2_2.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[5], p)
+ elif 'block2_3.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[6], p)
+ elif 'block2_4.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[7], p)
+ elif 'block3_1.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[8], p)
+ elif 'block3_2.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[9], p)
+ elif 'block3_3.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[10], p)
+ elif 'block3_4.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[11], p)
+ elif 'block4_1.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[12], p)
+ elif 'block4_2.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[13], p)
+ elif 'block4_3.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[14], p)
+ elif 'block4_4.conv1.weight' in pname:
+ new_dict[pname] = convert_pdc(pdcs[15], p)
+ else:
+ new_dict[pname] = p
+ return new_dict
+
+
+class Conv2d(nn.Module):
+
+ def __init__(self,
+ pdc,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=False):
+ super(Conv2d, self).__init__()
+ if in_channels % groups != 0:
+ raise ValueError('in_channels must be divisible by groups')
+ if out_channels % groups != 0:
+ raise ValueError('out_channels must be divisible by groups')
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // groups, kernel_size,
+ kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+ self.pdc = pdc
+
+ def reset_parameters(self):
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in)
+ nn.init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, input):
+ return self.pdc(input, self.weight, self.bias, self.stride,
+ self.padding, self.dilation, self.groups)
+
+
+class CSAM(nn.Module):
+ r"""
+ Compact Spatial Attention Module
+ """
+
+ def __init__(self, channels):
+ super(CSAM, self).__init__()
+
+ mid_channels = 4
+ self.relu1 = nn.ReLU()
+ self.conv1 = nn.Conv2d(
+ channels, mid_channels, kernel_size=1, padding=0)
+ self.conv2 = nn.Conv2d(
+ mid_channels, 1, kernel_size=3, padding=1, bias=False)
+ self.sigmoid = nn.Sigmoid()
+ nn.init.constant_(self.conv1.bias, 0)
+
+ def forward(self, x):
+ y = self.relu1(x)
+ y = self.conv1(y)
+ y = self.conv2(y)
+ y = self.sigmoid(y)
+
+ return x * y
+
+
+class CDCM(nn.Module):
+ r"""
+ Compact Dilation Convolution based Module
+ """
+
+ def __init__(self, in_channels, out_channels):
+ super(CDCM, self).__init__()
+
+ self.relu1 = nn.ReLU()
+ self.conv1 = nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, padding=0)
+ self.conv2_1 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ dilation=5,
+ padding=5,
+ bias=False)
+ self.conv2_2 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ dilation=7,
+ padding=7,
+ bias=False)
+ self.conv2_3 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ dilation=9,
+ padding=9,
+ bias=False)
+ self.conv2_4 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ dilation=11,
+ padding=11,
+ bias=False)
+ nn.init.constant_(self.conv1.bias, 0)
+
+ def forward(self, x):
+ x = self.relu1(x)
+ x = self.conv1(x)
+ x1 = self.conv2_1(x)
+ x2 = self.conv2_2(x)
+ x3 = self.conv2_3(x)
+ x4 = self.conv2_4(x)
+ return x1 + x2 + x3 + x4
+
+
+class MapReduce(nn.Module):
+ r"""
+ Reduce feature maps into a single edge map
+ """
+
+ def __init__(self, channels):
+ super(MapReduce, self).__init__()
+ self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
+ nn.init.constant_(self.conv.bias, 0)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class PDCBlock(nn.Module):
+
+ def __init__(self, pdc, inplane, ouplane, stride=1):
+ super(PDCBlock, self).__init__()
+ self.stride = stride
+
+ self.stride = stride
+ if self.stride > 1:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.shortcut = nn.Conv2d(
+ inplane, ouplane, kernel_size=1, padding=0)
+ self.conv1 = Conv2d(
+ pdc,
+ inplane,
+ inplane,
+ kernel_size=3,
+ padding=1,
+ groups=inplane,
+ bias=False)
+ self.relu2 = nn.ReLU()
+ self.conv2 = nn.Conv2d(
+ inplane, ouplane, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, x):
+ if self.stride > 1:
+ x = self.pool(x)
+ y = self.conv1(x)
+ y = self.relu2(y)
+ y = self.conv2(y)
+ if self.stride > 1:
+ x = self.shortcut(x)
+ y = y + x
+ return y
+
+
class PDCBlock_converted(nn.Module):
    r"""
    CPDC, APDC can be converted to vanilla 3x3 convolution
    RPDC can be converted to vanilla 5x5 convolution
    """

    def __init__(self, pdc, inplane, ouplane, stride=1):
        super().__init__()
        self.stride = stride

        if stride > 1:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.shortcut = nn.Conv2d(
                inplane, ouplane, kernel_size=1, padding=0)
        # RPDC ('rd') folds into a 5x5 depthwise conv; all other PDC types
        # fold into a 3x3 depthwise conv.
        kernel, pad = (5, 2) if pdc == 'rd' else (3, 1)
        self.conv1 = nn.Conv2d(
            inplane,
            inplane,
            kernel_size=kernel,
            padding=pad,
            groups=inplane,
            bias=False)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            inplane, ouplane, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        if self.stride > 1:
            x = self.pool(x)
        out = self.conv2(self.relu2(self.conv1(x)))
        if self.stride > 1:
            x = self.shortcut(x)
        return out + x
+
+
class PiDiNet(nn.Module):
    """Pixel Difference Network backbone for edge detection.

    Four stages of PDC blocks (channel widths C, 2C, 4C, 4C), each stage's
    features optionally refined by a CDCM dilation head (``dil``) and/or CSAM
    attention (``sa``), reduced to one channel, upsampled to input size, and
    fused by a 1x1 classifier into a single sigmoid edge map.

    Args:
        inplane: base channel count C.
        pdcs: 16 PDC type strings, one per conv block (index 0 is the stem).
        dil: CDCM output width, or None to disable dilation heads.
        sa: enable CSAM attention heads.
        convert: use PDCs folded into vanilla convs (PDCBlock_converted).
    """

    def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
        super(PiDiNet, self).__init__()
        self.sa = sa
        if dil is not None:
            assert isinstance(dil, int), 'dil should be an int'
        self.dil = dil

        self.fuseplanes = []

        self.inplane = inplane
        if convert:
            # Converted stem: RPDC folds into a 5x5 conv, others into 3x3.
            if pdcs[0] == 'rd':
                init_kernel_size = 5
                init_padding = 2
            else:
                init_kernel_size = 3
                init_padding = 1
            self.init_block = nn.Conv2d(
                3,
                self.inplane,
                kernel_size=init_kernel_size,
                padding=init_padding,
                bias=False)
            block_class = PDCBlock_converted
        else:
            self.init_block = Conv2d(
                pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
            block_class = PDCBlock

        self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
        self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
        self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane)  # C

        inplane = self.inplane
        self.inplane = self.inplane * 2
        self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
        self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
        self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
        self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane)  # 2C

        inplane = self.inplane
        self.inplane = self.inplane * 2
        self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
        self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
        self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
        self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane)  # 4C

        self.block4_1 = block_class(
            pdcs[12], self.inplane, self.inplane, stride=2)
        self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
        self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
        self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane)  # 4C

        self.conv_reduces = nn.ModuleList()
        if self.sa and self.dil is not None:
            self.attentions = nn.ModuleList()
            self.dilations = nn.ModuleList()
            for i in range(4):
                self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
                self.attentions.append(CSAM(self.dil))
                self.conv_reduces.append(MapReduce(self.dil))
        elif self.sa:
            self.attentions = nn.ModuleList()
            for i in range(4):
                self.attentions.append(CSAM(self.fuseplanes[i]))
                self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
        elif self.dil is not None:
            self.dilations = nn.ModuleList()
            for i in range(4):
                self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
                self.conv_reduces.append(MapReduce(self.dil))
        else:
            for i in range(4):
                self.conv_reduces.append(MapReduce(self.fuseplanes[i]))

        # Fuse the four side maps with equal initial weights.
        self.classifier = nn.Conv2d(4, 1, kernel_size=1)
        nn.init.constant_(self.classifier.weight, 0.25)
        nn.init.constant_(self.classifier.bias, 0)

    def get_weights(self):
        """Split parameters into (conv, bn, relu) groups by parameter name."""
        conv_weights = []
        bn_weights = []
        relu_weights = []
        for pname, p in self.named_parameters():
            if 'bn' in pname:
                bn_weights.append(p)
            elif 'relu' in pname:
                relu_weights.append(p)
            else:
                conv_weights.append(p)

        return conv_weights, bn_weights, relu_weights

    def forward(self, x):
        """Return a single-channel sigmoid edge map at the input resolution."""
        H, W = x.size()[2:]

        x = self.init_block(x)

        x1 = self.block1_1(x)
        x1 = self.block1_2(x1)
        x1 = self.block1_3(x1)

        x2 = self.block2_1(x1)
        x2 = self.block2_2(x2)
        x2 = self.block2_3(x2)
        x2 = self.block2_4(x2)

        x3 = self.block3_1(x2)
        x3 = self.block3_2(x3)
        x3 = self.block3_3(x3)
        x3 = self.block3_4(x3)

        x4 = self.block4_1(x3)
        x4 = self.block4_2(x4)
        x4 = self.block4_3(x4)
        x4 = self.block4_4(x4)

        # Optional per-stage refinement before the side outputs.
        x_fuses = []
        if self.sa and self.dil is not None:
            for i, xi in enumerate([x1, x2, x3, x4]):
                x_fuses.append(self.attentions[i](self.dilations[i](xi)))
        elif self.sa:
            for i, xi in enumerate([x1, x2, x3, x4]):
                x_fuses.append(self.attentions[i](xi))
        elif self.dil is not None:
            for i, xi in enumerate([x1, x2, x3, x4]):
                x_fuses.append(self.dilations[i](xi))
        else:
            x_fuses = [x1, x2, x3, x4]

        # Reduce each stage to one channel and upsample to input size.
        e1 = self.conv_reduces[0](x_fuses[0])
        e1 = F.interpolate(e1, (H, W), mode='bilinear', align_corners=False)

        e2 = self.conv_reduces[1](x_fuses[1])
        e2 = F.interpolate(e2, (H, W), mode='bilinear', align_corners=False)

        e3 = self.conv_reduces[2](x_fuses[2])
        e3 = F.interpolate(e3, (H, W), mode='bilinear', align_corners=False)

        e4 = self.conv_reduces[3](x_fuses[3])
        e4 = F.interpolate(e4, (H, W), mode='bilinear', align_corners=False)

        outputs = [e1, e2, e3, e4]
        output = self.classifier(torch.cat(outputs, dim=1))

        # Only the fused map is returned; apply sigmoid to it alone instead
        # of sigmoiding all five maps and discarding four (same result,
        # wasted work removed).
        return torch.sigmoid(output)
+
+
def pidinet_bsd_tiny(pretrained=False, vanilla_cnn=True):
    """Tiny PiDiNet for BSDS edge detection (20 base channels, dil=8)."""
    if vanilla_cnn:
        pdcs = config_model_converted('carv4')
    else:
        pdcs = config_model('carv4')
    model = PiDiNet(20, pdcs, dil=8, sa=True, convert=vanilla_cnn)
    if pretrained:
        ckpt = torch.load(
            DOWNLOAD_TO_CACHE('models/pidinet/table5_pidinet-tiny.pth'),
            map_location='cpu')['state_dict']
        if vanilla_cnn:
            ckpt = convert_pidinet(ckpt, 'carv4')
        # Strip a possible DataParallel 'module.' prefix from every key.
        prefix = 'module.'
        ckpt = {(key[len(prefix):] if key.startswith(prefix) else key): val
                for key, val in ckpt.items()}
        model.load_state_dict(ckpt)
    return model
+
+
def pidinet_bsd_small(pretrained=False, vanilla_cnn=True):
    """Small PiDiNet for BSDS edge detection (30 base channels, dil=12)."""
    if vanilla_cnn:
        pdcs = config_model_converted('carv4')
    else:
        pdcs = config_model('carv4')
    model = PiDiNet(30, pdcs, dil=12, sa=True, convert=vanilla_cnn)
    if pretrained:
        ckpt = torch.load(
            DOWNLOAD_TO_CACHE('models/pidinet/table5_pidinet-small.pth'),
            map_location='cpu')['state_dict']
        if vanilla_cnn:
            ckpt = convert_pidinet(ckpt, 'carv4')
        # Strip a possible DataParallel 'module.' prefix from every key.
        prefix = 'module.'
        ckpt = {(key[len(prefix):] if key.startswith(prefix) else key): val
                for key, val in ckpt.items()}
        model.load_state_dict(ckpt)
    return model
+
+
def pidinet_bsd(model_dir, pretrained=False, vanilla_cnn=True):
    """Full PiDiNet for BSDS edge detection (60 base channels, dil=24);
    weights are loaded from ``model_dir`` when ``pretrained`` is set."""
    if vanilla_cnn:
        pdcs = config_model_converted('carv4')
    else:
        pdcs = config_model('carv4')
    model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn)
    if pretrained:
        ckpt = torch.load(
            os.path.join(model_dir, 'table5_pidinet.pth'),
            map_location='cpu')['state_dict']
        if vanilla_cnn:
            ckpt = convert_pidinet(ckpt, 'carv4')
        # Strip a possible DataParallel 'module.' prefix from every key.
        prefix = 'module.'
        ckpt = {(key[len(prefix):] if key.startswith(prefix) else key): val
                for key, val in ckpt.items()}
        model.load_state_dict(ckpt)
    return model
+
+
def pidinet_nyud(pretrained=False, vanilla_cnn=True):
    """Full PiDiNet trained on NYUD (60 base channels, dil=24)."""
    if vanilla_cnn:
        pdcs = config_model_converted('carv4')
    else:
        pdcs = config_model('carv4')
    model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn)
    if pretrained:
        ckpt = torch.load(
            DOWNLOAD_TO_CACHE('models/pidinet/table6_pidinet.pth'),
            map_location='cpu')['state_dict']
        if vanilla_cnn:
            ckpt = convert_pidinet(ckpt, 'carv4')
        # Strip a possible DataParallel 'module.' prefix from every key.
        prefix = 'module.'
        ckpt = {(key[len(prefix):] if key.startswith(prefix) else key): val
                for key, val in ckpt.items()}
        model.load_state_dict(ckpt)
    return model
+
+
def pidinet_multicue(pretrained=False, vanilla_cnn=True):
    """Full PiDiNet trained on Multicue (60 base channels, dil=24)."""
    if vanilla_cnn:
        pdcs = config_model_converted('carv4')
    else:
        pdcs = config_model('carv4')
    model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn)
    if pretrained:
        ckpt = torch.load(
            DOWNLOAD_TO_CACHE('models/pidinet/table7_pidinet.pth'),
            map_location='cpu')['state_dict']
        if vanilla_cnn:
            ckpt = convert_pidinet(ckpt, 'carv4')
        # Strip a possible DataParallel 'module.' prefix from every key.
        prefix = 'module.'
        ckpt = {(key[len(prefix):] if key.startswith(prefix) else key): val
                for key, val in ckpt.items()}
        model.load_state_dict(ckpt)
    return model
diff --git a/modelscope/models/multi_modal/videocomposer/annotator/sketch/sketch_simplification.py b/modelscope/models/multi_modal/videocomposer/annotator/sketch/sketch_simplification.py
new file mode 100644
index 00000000..2555d593
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/annotator/sketch/sketch_simplification.py
@@ -0,0 +1,110 @@
+r"""PyTorch re-implementation adapted from the Lua code in ``https://github.com/bobbens/sketch_simplification''.
+"""
+import math
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modelscope.models.multi_modal.videocomposer.utils.utils import \
+ DOWNLOAD_TO_CACHE
+
+__all__ = [
+ 'SketchSimplification', 'sketch_simplification_gan',
+ 'sketch_simplification_mse', 'sketch_to_pencil_v1', 'sketch_to_pencil_v2'
+]
+
+
class SketchSimplification(nn.Module):
    r"""Fully convolutional sketch-simplification network.

    Encoder-decoder of plain convs (downsampling by three stride-2 convs,
    upsampling by three transposed convs) ending in a sigmoid, so the output
    is a single-channel image in [0, 1].

    NOTE:
    1. Input image should have only one gray channel.
    2. Input image size should be divisible by 8.
    3. Sketch in the input/output image is in dark color while background in
       light color.

    Args:
        mean: per-dataset gray-level mean used to normalize the input.
        std: per-dataset gray-level std used to normalize the input.
    """

    def __init__(self, mean, std):
        # NOTE(review): validation via assert runs before super().__init__()
        # and is stripped under ``python -O``.
        assert isinstance(mean, float) and isinstance(std, float)
        super(SketchSimplification, self).__init__()
        self.mean = mean
        self.std = std

        # layers
        self.layers = nn.Sequential(
            nn.Conv2d(1, 48, 5, 2, 2), nn.ReLU(inplace=True),
            nn.Conv2d(48, 128, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(512, 1024, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(1024, 512, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 3, 1, 1), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 256, 4, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, 3, 1, 1), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 128, 4, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 48, 3, 1, 1), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(48, 48, 4, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(48, 24, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(24, 1, 3, 1, 1), nn.Sigmoid())

    def forward(self, x):
        r"""x: [B, 1, H, W] within range [0, 1]. Sketch pixels in dark color.

        Returns a tensor of the same shape in [0, 1] (final Sigmoid).
        """
        # Normalize with the dataset statistics this checkpoint was trained on.
        x = (x - self.mean) / self.std
        return self.layers(x)
+
+
def sketch_simplification_gan(model_dir, pretrained=False):
    """Build the GAN-trained sketch-simplification model; load weights from
    ``model_dir`` when ``pretrained`` is set."""
    model = SketchSimplification(
        mean=0.9664114577640158, std=0.0858381272736797)
    if pretrained:
        weights_path = os.path.join(model_dir, 'sketch_simplification_gan.pth')
        state = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(state)
    return model
+
+
def sketch_simplification_mse(pretrained=False):
    """Build the MSE-trained sketch-simplification model; weights come from
    the shared download cache when ``pretrained`` is set."""
    model = SketchSimplification(
        mean=0.9664423107454593, std=0.08583666033640507)
    if pretrained:
        weights_path = DOWNLOAD_TO_CACHE(
            'models/sketch_simplification/sketch_simplification_mse.pth')
        model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    return model
+
+
def sketch_to_pencil_v1(pretrained=False):
    """Build the v1 sketch-to-pencil model; weights come from the shared
    download cache when ``pretrained`` is set."""
    model = SketchSimplification(
        mean=0.9817833515894078, std=0.0925009022585048)
    if pretrained:
        weights_path = DOWNLOAD_TO_CACHE(
            'models/sketch_simplification/sketch_to_pencil_v1.pth')
        model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    return model
+
+
def sketch_to_pencil_v2(pretrained=False):
    """Build the v2 sketch-to-pencil model; weights come from the shared
    download cache when ``pretrained`` is set."""
    model = SketchSimplification(
        mean=0.9851298627337799, std=0.07418377454883571)
    if pretrained:
        weights_path = DOWNLOAD_TO_CACHE(
            'models/sketch_simplification/sketch_to_pencil_v2.pth')
        model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    return model
diff --git a/modelscope/models/multi_modal/videocomposer/annotator/util.py b/modelscope/models/multi_modal/videocomposer/annotator/util.py
new file mode 100644
index 00000000..b02755c9
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/annotator/util.py
@@ -0,0 +1,42 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+
+import cv2
+import numpy as np
+
# Default location for annotator checkpoint files: a `ckpts` directory next
# to this module.
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
+
+
def HWC3(x):
    """Normalize a uint8 image array to 3-channel HWC format.

    2-D and single-channel inputs are replicated across three channels;
    4-channel (RGBA) inputs are alpha-composited onto a white background.
    3-channel inputs are returned unchanged.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)
    if channels == 3:
        return x
    if channels == 1:
        return np.repeat(x, 3, axis=2)
    # channels == 4: blend RGB over white using the alpha channel.
    rgb = x[:, :, 0:3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
+
+
def resize_image(input_image, resolution):
    """Resize an HWC image so its short side is about ``resolution``, then
    snap both dimensions to multiples of 64."""
    h, w, _c = input_image.shape
    scale = float(resolution) / min(float(h), float(w))
    target_h = int(np.round(float(h) * scale / 64.0)) * 64
    target_w = int(np.round(float(w) * scale / 64.0)) * 64
    # Lanczos preserves detail when enlarging; area averaging when shrinking.
    interp = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    return cv2.resize(input_image, (target_w, target_h), interpolation=interp)
diff --git a/modelscope/models/multi_modal/videocomposer/autoencoder.py b/modelscope/models/multi_modal/videocomposer/autoencoder.py
new file mode 100644
index 00000000..99b65f87
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/autoencoder.py
@@ -0,0 +1,650 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ['AutoencoderKL']
+
+
def nonlinearity(x):
    """Swish/SiLU activation: ``x * sigmoid(x)``."""
    return torch.sigmoid(x) * x
+
+
def Normalize(in_channels, num_groups=32):
    """GroupNorm with affine parameters and eps=1e-6 (32 groups by default)."""
    return nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents, parameterized by a tensor whose
    channel dimension packs [mean, logvar]."""

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        self.mean = mean
        # Clamp logvar for numerical stability before exponentiation.
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Degenerate distribution: zero variance, sampling returns mean.
            self.var = self.std = torch.zeros_like(
                self.mean).to(device=self.parameters.device)

    def sample(self):
        """Draw a reparameterized sample mean + std * eps."""
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """KL divergence to `other` (or to the standard normal), summed over
        all non-batch dims."""
        if self.deterministic:
            return torch.Tensor([0.])
        reduce_dims = [1, 2, 3]
        if other is None:
            return 0.5 * torch.sum(
                self.mean.pow(2) + self.var - 1.0 - self.logvar,
                dim=reduce_dims)
        return 0.5 * torch.sum(
            (self.mean - other.mean).pow(2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=reduce_dims)

    def nll(self, sample, dims=[1, 2, 3]):
        """Negative log-likelihood of `sample`, summed over `dims`."""
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + (sample - self.mean).pow(2) / self.var,
            dim=dims)

    def mode(self):
        """The distribution mode (its mean)."""
        return self.mean
+
+
class Downsample(nn.Module):
    """2x spatial downsampling: strided 3x3 conv with manual right/bottom
    padding (``with_conv=True``) or 2x2 average pooling.

    NOTE(review): this class is redefined later in this module (the second
    definition is marked ``# noqa``), so this first definition is shadowed
    and effectively dead code.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            # Pad one pixel right/bottom so the stride-2 conv halves H and W.
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
+
+
class ResnetBlock(nn.Module):
    """Pre-activation residual block with optional timestep-embedding input.

    GroupNorm -> swish -> 3x3 conv, twice (dropout before the second conv);
    the input is added back at the end, projected by a 3x3
    (``conv_shortcut=True``) or 1x1 conv when channel counts differ.

    Args:
        in_channels: input channel count.
        out_channels: output channel count; defaults to ``in_channels``.
        conv_shortcut: use a 3x3 conv instead of a 1x1 for the shortcut
            projection.
        dropout: dropout probability before the second conv.
        temb_channels: width of the timestep embedding projection; 0 skips
            creating it (this autoencoder passes ``temb=None``).
    """

    def __init__(self,
                 *,
                 in_channels,
                 out_channels=None,
                 conv_shortcut=False,
                 dropout,
                 temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            # Shortcut projection is only needed when channels change.
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0)

    def forward(self, x, temb):
        """Apply the block; ``temb`` may be None to skip embedding injection."""
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Inject the (swished) timestep embedding as a per-channel bias.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
+
+
class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions with a residual
    connection (1x1 convs produce q/k/v; softmax over key positions).

    NOTE(review): this class is redefined immediately below (the second
    definition is marked ``# noqa``), so this first definition is shadowed
    and effectively dead code.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw   w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))  # scale by 1/sqrt(channels)
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(
            v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        # Residual connection around the attention operation.
        return x + h_
+
+
class AttnBlock(nn.Module):  # noqa
    """Single-head self-attention over spatial positions with a residual
    connection; byte-identical duplicate of the class above (this second
    definition is the one that takes effect)."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        # 1x1 convs act as per-position linear projections for q/k/v.
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw   w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))  # scale by 1/sqrt(channels)
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(
            v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        # Residual connection around the attention operation.
        return x + h_
+
+
class Upsample(nn.Module):
    """2x nearest-neighbor upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode='nearest')
        return self.conv(upsampled) if self.with_conv else upsampled
+
+
class Downsample(nn.Module):  # noqa
    """2x spatial downsampling: strided 3x3 conv with manual right/bottom
    padding, or 2x2 average pooling when ``with_conv`` is False."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # Torch convs only support symmetric padding, so the asymmetric
            # (right/bottom) padding is applied by hand in forward().
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        padded = torch.nn.functional.pad(
            x, (0, 1, 0, 1), mode='constant', value=0)
        return self.conv(padded)
+
+
class Encoder(nn.Module):
    """Convolutional encoder mapping images to a (2*)z_channels feature map.

    ``num_res_blocks`` ResnetBlocks per resolution level (with AttnBlocks at
    the resolutions listed in ``attn_resolutions``), downsampling between
    levels, then a middle ResnetBlock/AttnBlock/ResnetBlock and an output
    projection.

    Args:
        ch: base channel count.
        out_ch: accepted for config symmetry with Decoder; unused here.
        ch_mult: per-level channel multipliers.
        num_res_blocks: ResnetBlocks per level.
        attn_resolutions: spatial resolutions at which attention is inserted.
        dropout: dropout rate inside ResnetBlocks.
        resamp_with_conv: downsample with a strided conv (True) or avg-pool.
        in_channels: input image channels.
        resolution: input spatial resolution (assumed square).
        z_channels: latent channels; doubled when ``double_z`` to emit
            mean and logvar.
        double_z: emit 2*z_channels when True.
        use_linear_attn, attn_type: accepted but unused in this module.
    """

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 double_z=True,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                # Fix: append one AttnBlock per ResnetBlock. forward() indexes
                # attn[i_block]; the previous per-level append (outside this
                # loop) raised IndexError whenever attention was enabled with
                # num_res_blocks > 1. This matches upstream latent-diffusion.
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        """Encode ``x`` to the latent moment map (mean+logvar if double_z)."""
        # timestep embedding (unused in the autoencoder: always None)
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
+
+
class Decoder(nn.Module):
    """Convolutional decoder mapping latents back to images.

    Mirror of Encoder: input projection, middle
    ResnetBlock/AttnBlock/ResnetBlock, then ``num_res_blocks + 1``
    ResnetBlocks per level with upsampling between levels, and an output
    projection (optionally tanh-bounded).

    Args:
        ch, ch_mult, num_res_blocks, attn_resolutions, dropout,
        resamp_with_conv, resolution, z_channels: as in Encoder.
        out_ch: output image channels.
        in_channels: accepted for config symmetry; unused here.
        give_pre_end: return the feature map before the final norm/act/conv.
        tanh_out: apply tanh to the output.
        use_linear_attn, attn_type: accepted but unused in this module.
    """

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 give_pre_end=False,
                 tanh_out=False,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print('Working with z of shape {} = {} dimensions.'.format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                # Fix: append one AttnBlock per ResnetBlock. forward() indexes
                # attn[i_block]; the previous per-level append (outside this
                # loop) raised IndexError whenever attention was enabled with
                # more than one block. This matches upstream latent-diffusion.
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        """Decode latent ``z`` to an image (or pre-end features)."""
        self.last_z_shape = z.shape

        # timestep embedding (unused in the autoencoder: always None)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
+
+
class AutoencoderKL(nn.Module):
    """KL-regularized convolutional autoencoder (first-stage VAE).

    Encodes images into a diagonal-Gaussian posterior over latents and
    decodes latents back into images.

    Args:
        ddconfig: Encoder/Decoder hyper-parameters; must set ``double_z`` so
            the encoder emits both mean and logvar.
        embed_dim: latent channel width after the quant conv.
        ckpt_path: optional checkpoint restored via ``init_from_ckpt``.
        ignore_keys: key prefixes to drop on restore (only honored by
            ``init_from_ckpt2``).
        image_key: batch-dict key holding the input images.
        colorize_nlabels: registers a random projection for ``to_rgb`` when
            segmentation inputs are used.
        monitor: metric name stored for external training loops.
        ema_decay: enables the EMA code paths when not None.
        learn_logvar: stored flag; not used in this module's visible code.
    """

    def __init__(self,
                 ddconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key='image',
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig['double_z']
        # Map the encoder's (mean, logvar) to 2*embed_dim channels and back.
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
                                          2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig['z_channels'], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer('colorize',
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.use_ema = ema_decay is not None
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Restore weights from a full diffusion checkpoint, keeping only the
        ``first_stage_model.*`` entries.

        ``ignore_keys`` is accepted for interface compatibility but not used
        by this method.
        """
        sd = torch.load(path, map_location='cpu')['state_dict']
        keys = list(sd.keys())
        for key in keys:
            # Debug aid: list every tensor present in the checkpoint.
            print(key, sd[key].shape)
        import collections
        sd_new = collections.OrderedDict()
        for k in keys:
            if k.find('first_stage_model') >= 0:
                k_new = k.split('first_stage_model.')[-1]
                sd_new[k_new] = sd[k]
        self.load_state_dict(sd_new, strict=True)
        print(f'Restored from {path}')

    def init_from_ckpt2(self, path, ignore_keys=list()):
        """Restore weights directly (non-strict), dropping keys that start
        with any prefix in ``ignore_keys``.

        Fix: removed a stray bare ``first_stage_model`` expression that
        raised NameError whenever this method was called.
        """
        sd = torch.load(path, map_location='cpu')['state_dict']
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print('Deleting key {} from state_dict.'.format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f'Restored from {path}')

    def on_train_batch_end(self, *args, **kwargs):
        # NOTE(review): `self.model_ema` is never defined in this module
        # (pytorch-lightning leftover); this raises AttributeError when
        # use_ema is True — confirm EMA wiring before enabling ema_decay.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode images to a DiagonalGaussianDistribution over latents."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode latents ``z`` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Encode then decode; returns (reconstruction, posterior)."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a float NCHW tensor (HWC input assumed)."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1,
                      2).to(memory_format=torch.contiguous_format).float()
        return x

    def get_last_layer(self):
        """Weight of the final decoder conv (used for adaptive loss weights)."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Collect input / reconstruction / sample images for logging.

        NOTE(review): relies on ``self.device`` and ``self.ema_scope``,
        neither of which is defined on this plain nn.Module (lightning
        leftovers); the EMA branch and ``.to(self.device)`` will fail as
        written — confirm before use.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log['samples'] = self.decode(torch.randn_like(posterior.sample()))
            log['reconstructions'] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log['samples_ema'] = self.decode(
                        torch.randn_like(posterior_ema.sample()))
                    log['reconstructions_ema'] = xrec_ema
        log['inputs'] = x
        return log

    def to_rgb(self, x):
        """Project multi-channel segmentation maps to RGB in [-1, 1]."""
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize',
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
+
+
class IdentityFirstStage(torch.nn.Module):
    """No-op first-stage stand-in: encode/decode/forward return the input
    unchanged; with ``vq_interface=True`` quantize() mimics a VQ model's
    (codes, loss, info) return shape."""

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Fix: call Module.__init__ before assigning attributes so every
        # assignment goes through a fully initialized nn.Module.
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        """Identity encode."""
        return x

    def decode(self, x, *args, **kwargs):
        """Identity decode."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Identity quantize; mimic the VQ interface tuple when requested."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
new file mode 100644
index 00000000..f55a5206
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/clip.py
@@ -0,0 +1,144 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import numpy as np
+import open_clip
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from torch.utils.checkpoint import checkpoint
+
+
+class FrozenOpenCLIPEmbedder(nn.Module):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ LAYERS = ['last', 'penultimate']
+
+ def __init__(self,
+ arch='ViT-H-14',
+ pretrained='laion2b_s32b_b79k',
+ device='cuda',
+ max_length=77,
+ freeze=True,
+ layer='last'):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch, device=torch.device('cpu'), pretrained=pretrained)
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == 'last':
+ self.layer_idx = 0
+ elif self.layer == 'penultimate':
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
+ ):
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPVisualEmbedder(nn.Module):
+ """
+    Uses the OpenCLIP visual encoder for images
+ """
+ LAYERS = ['last', 'penultimate']
+
+ def __init__(self,
+ arch='ViT-H-14',
+ pretrained='laion2b_s32b_b79k',
+ device='cuda',
+ max_length=77,
+ freeze=True,
+ layer='last',
+ input_shape=(224, 224, 3)):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, preprocess = open_clip.create_model_and_transforms(
+ arch, device=torch.device('cpu'), pretrained=pretrained)
+ del model.transformer
+ self.model = model
+ data_white = np.ones(input_shape, dtype=np.uint8) * 255
+        self.black_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0)  # NOTE(review): named 'black_image' but built from an all-white array — confirm intent
+ self.preprocess = preprocess
+
+ self.device = device
+ self.max_length = max_length # 77
+ if freeze:
+ self.freeze()
+ self.layer = layer # 'penultimate'
+ if self.layer == 'last':
+ self.layer_idx = 0
+ elif self.layer == 'penultimate':
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, image):
+ # tokens = open_clip.tokenize(text)
+ z = self.model.encode_image(image.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
+ ):
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
diff --git a/modelscope/models/multi_modal/videocomposer/config.py b/modelscope/models/multi_modal/videocomposer/config.py
new file mode 100644
index 00000000..499d8484
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/config.py
@@ -0,0 +1,156 @@
+import logging
+import os
+import os.path as osp
+from datetime import datetime
+
+import torch
+from easydict import EasyDict
+
+cfg = EasyDict(__name__='Config: VideoComposer')
+
+pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
+gpus_per_machine = torch.cuda.device_count()
+world_size = pmi_world_size * gpus_per_machine
+
+cfg.video_compositions = [
+ 'text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image',
+ 'single_sketch'
+]
+
+# dataset
+cfg.root_dir = 'webvid10m/'
+
+cfg.alpha = 0.7
+
+cfg.misc_size = 384
+cfg.depth_std = 20.0
+cfg.depth_clamp = 10.0
+cfg.hist_sigma = 10.0
+
+cfg.use_image_dataset = False
+cfg.alpha_img = 0.7
+
+cfg.resolution = 256
+cfg.mean = [0.5, 0.5, 0.5]
+cfg.std = [0.5, 0.5, 0.5]
+
+# sketch
+cfg.sketch_mean = [0.485, 0.456, 0.406]
+cfg.sketch_std = [0.229, 0.224, 0.225]
+
+# dataloader
+cfg.max_words = 1000
+
+cfg.frame_lens = [
+ 16,
+ 16,
+ 16,
+ 16,
+]
+cfg.feature_framerates = [
+ 4,
+]
+cfg.feature_framerate = 4
+cfg.batch_sizes = {
+ str(1): 1,
+ str(4): 1,
+ str(8): 1,
+ str(16): 1,
+}
+
+cfg.chunk_size = 64
+cfg.num_workers = 8
+cfg.prefetch_factor = 2
+cfg.seed = 8888
+
+# diffusion
+cfg.num_timesteps = 1000
+cfg.mean_type = 'eps'
+cfg.var_type = 'fixed_small'
+cfg.loss_type = 'mse'
+cfg.ddim_timesteps = 50
+cfg.ddim_eta = 0.0
+cfg.clamp = 1.0
+cfg.share_noise = False
+cfg.use_div_loss = False
+
+# classifier-free guidance
+cfg.p_zero = 0.9
+cfg.guide_scale = 6.0
+
+# stable diffusion
+cfg.sd_checkpoint = 'v2-1_512-ema-pruned.ckpt'
+
+# clip vision encoder
+cfg.vit_image_size = 336
+cfg.vit_patch_size = 14
+cfg.vit_dim = 1024
+cfg.vit_out_dim = 768
+cfg.vit_heads = 16
+cfg.vit_layers = 24
+cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
+cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
+cfg.clip_checkpoint = 'open_clip_pytorch_model.bin'
+cfg.mvs_visual = False
+
+# unet
+cfg.unet_in_dim = 4
+cfg.unet_concat_dim = 8
+cfg.unet_y_dim = cfg.vit_out_dim
+cfg.unet_context_dim = 1024
+cfg.unet_out_dim = 8 if cfg.var_type.startswith('learned') else 4
+cfg.unet_dim = 320
+cfg.unet_dim_mult = [1, 2, 4, 4]
+cfg.unet_res_blocks = 2
+cfg.unet_num_heads = 8
+cfg.unet_head_dim = 64
+cfg.unet_attn_scales = [1 / 1, 1 / 2, 1 / 4]
+cfg.unet_dropout = 0.1
+cfg.misc_dropout = 0.5
+cfg.p_all_zero = 0.1
+cfg.p_all_keep = 0.1
+cfg.temporal_conv = False
+cfg.temporal_attn_times = 1
+cfg.temporal_attention = True
+
+cfg.use_fps_condition = False
+cfg.use_sim_mask = False
+
+# Default: load 2d pretrain
+cfg.pretrained = False
+cfg.fix_weight = False
+
+# Default resume
+cfg.resume = True
+cfg.resume_step = 148000
+cfg.resume_check_dir = '.'
+cfg.resume_checkpoint = os.path.join(
+ cfg.resume_check_dir,
+ f'step_{cfg.resume_step}/non_ema_{cfg.resume_step}.pth')
+cfg.resume_optimizer = False
+if cfg.resume_optimizer:
+ cfg.resume_optimizer = os.path.join(
+ cfg.resume_check_dir, f'optimizer_step_{cfg.resume_step}.pt')
+
+# acceleration
+cfg.use_ema = True
+# for debug, no ema
+if world_size < 2:
+ cfg.use_ema = False
+cfg.load_from = None
+
+cfg.use_checkpoint = True
+cfg.use_sharded_ddp = False
+cfg.use_fsdp = False
+cfg.use_fp16 = True
+
+# training
+cfg.ema_decay = 0.9999
+cfg.viz_interval = 1000
+cfg.save_ckp_interval = 1000
+
+# logging
+cfg.log_interval = 100
+composition_strings = '_'.join(cfg.video_compositions)
+# Default log_dir
+cfg.log_dir = 'outputs/'
diff --git a/modelscope/models/multi_modal/videocomposer/configs/base.yaml b/modelscope/models/multi_modal/videocomposer/configs/base.yaml
new file mode 100644
index 00000000..42f756f8
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/configs/base.yaml
@@ -0,0 +1,2 @@
+ENABLE: true
+DATASET: webvid10m
diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp01_vidcomposer_full.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp01_vidcomposer_full.yaml
new file mode 100644
index 00000000..ec312138
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/configs/exp01_vidcomposer_full.yaml
@@ -0,0 +1,20 @@
+TASK_TYPE: MULTI_TASK
+ENABLE: true
+DATASET: webvid10m
+video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
+batch_sizes: {
+ "1": 1,
+ "4": 1,
+ "8": 1,
+ "16": 1,
+}
+vit_image_size: 224
+network_name: UNetSD_temporal
+resume: true
+resume_step: 228000
+num_workers: 1
+mvs_visual: False
+chunk_size: 1
+resume_checkpoint: "model_weights/non_ema_228000.pth"
+log_dir: 'outputs'
+num_steps: 1
diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer.yaml
new file mode 100644
index 00000000..4b756d32
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer.yaml
@@ -0,0 +1,23 @@
+TASK_TYPE: SINGLE_TASK
+read_image: True # must be enabled for this task
+ENABLE: true
+DATASET: webvid10m
+video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
+guidances: ['y', 'local_image', 'motion'] # guidance conditions required for this task
+batch_sizes: {
+ "1": 1,
+ "4": 1,
+ "8": 1,
+ "16": 1,
+}
+vit_image_size: 224
+network_name: UNetSD_temporal
+resume: true
+resume_step: 228000
+seed: 182
+num_workers: 0
+mvs_visual: False
+chunk_size: 1
+resume_checkpoint: "model_weights/non_ema_228000.pth"
+log_dir: 'outputs'
+num_steps: 1
diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer_vs_style.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer_vs_style.yaml
new file mode 100644
index 00000000..7928e7ba
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer_vs_style.yaml
@@ -0,0 +1,24 @@
+TASK_TYPE: SINGLE_TASK
+read_image: True # must be enabled for this task
+read_style: True
+ENABLE: true
+DATASET: webvid10m
+video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
+guidances: ['y', 'local_image', 'image', 'motion'] # guidance conditions required for this task
+batch_sizes: {
+ "1": 1,
+ "4": 1,
+ "8": 1,
+ "16": 1,
+}
+vit_image_size: 224
+network_name: UNetSD_temporal
+resume: true
+resume_step: 228000
+seed: 182
+num_workers: 0
+mvs_visual: False
+chunk_size: 1
+resume_checkpoint: "model_weights/non_ema_228000.pth"
+log_dir: 'outputs'
+num_steps: 1
diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp03_sketch2video_style.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp03_sketch2video_style.yaml
new file mode 100644
index 00000000..fd710ee5
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/configs/exp03_sketch2video_style.yaml
@@ -0,0 +1,26 @@
+TASK_TYPE: SINGLE_TASK
+read_image: False # not used for this task
+read_style: True
+read_sketch: True
+save_origin_video: False
+ENABLE: true
+DATASET: webvid10m
+video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
+guidances: ['y', 'image', 'single_sketch'] # guidance conditions required for this task
+batch_sizes: {
+ "1": 1,
+ "4": 1,
+ "8": 1,
+ "16": 1,
+}
+vit_image_size: 224
+network_name: UNetSD_temporal
+resume: true
+resume_step: 228000
+seed: 182
+num_workers: 0
+mvs_visual: False
+chunk_size: 1
+resume_checkpoint: "model_weights/non_ema_228000.pth"
+log_dir: 'outputs'
+num_steps: 1
diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp04_sketch2video_wo_style.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp04_sketch2video_wo_style.yaml
new file mode 100644
index 00000000..a5cc54bf
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/configs/exp04_sketch2video_wo_style.yaml
@@ -0,0 +1,26 @@
+TASK_TYPE: SINGLE_TASK
+read_image: False # not used for this task
+read_style: False
+read_sketch: True
+save_origin_video: False
+ENABLE: true
+DATASET: webvid10m
+video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
+guidances: ['y', 'single_sketch'] # guidance conditions required for this task
+batch_sizes: {
+ "1": 1,
+ "4": 1,
+ "8": 1,
+ "16": 1,
+}
+vit_image_size: 224
+network_name: UNetSD_temporal
+resume: true
+resume_step: 228000
+seed: 182
+num_workers: 0
+mvs_visual: False
+chunk_size: 1
+resume_checkpoint: "model_weights/non_ema_228000.pth"
+log_dir: 'outputs'
+num_steps: 1
diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp05_text_depths_wo_style.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp05_text_depths_wo_style.yaml
new file mode 100644
index 00000000..29c053b1
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/configs/exp05_text_depths_wo_style.yaml
@@ -0,0 +1,26 @@
+TASK_TYPE: SINGLE_TASK
+read_image: False # not used for this task
+read_style: False
+read_sketch: False
+save_origin_video: True
+ENABLE: true
+DATASET: webvid10m
+video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
+guidances: ['y', 'depth'] # guidance conditions required for this task
+batch_sizes: {
+ "1": 1,
+ "4": 1,
+ "8": 1,
+ "16": 1,
+}
+vit_image_size: 224
+network_name: UNetSD_temporal
+resume: true
+resume_step: 228000
+seed: 182
+num_workers: 0
+mvs_visual: False
+chunk_size: 1
+resume_checkpoint: "model_weights/non_ema_228000.pth"
+log_dir: 'outputs'
+num_steps: 1
diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp06_text_depths_vs_style.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp06_text_depths_vs_style.yaml
new file mode 100644
index 00000000..2732dc5c
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/configs/exp06_text_depths_vs_style.yaml
@@ -0,0 +1,26 @@
+TASK_TYPE: SINGLE_TASK
+read_image: False
+read_style: True
+read_sketch: False
+save_origin_video: True
+ENABLE: true
+DATASET: webvid10m
+video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
+guidances: ['y', 'image', 'depth'] # guidance conditions required for this task
+batch_sizes: {
+ "1": 1,
+ "4": 1,
+ "8": 1,
+ "16": 1,
+}
+vit_image_size: 224
+network_name: UNetSD_temporal
+resume: true
+resume_step: 228000
+seed: 182
+num_workers: 0
+mvs_visual: False
+chunk_size: 1
+resume_checkpoint: "model_weights/non_ema_141000_no_watermark.pth" # NOTE(review): resume_step above is 228000 but this checkpoint is from step 141000 — confirm
+log_dir: 'outputs'
+num_steps: 1
diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp10_vidcomposer_no_watermark_full.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp10_vidcomposer_no_watermark_full.yaml
new file mode 100644
index 00000000..1be311d8
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/configs/exp10_vidcomposer_no_watermark_full.yaml
@@ -0,0 +1,21 @@
+TASK_TYPE: VideoComposer_Inference
+ENABLE: true
+DATASET: webvid10m
+video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch']
+batch_sizes: {
+ "1": 1,
+ "4": 1,
+ "8": 1,
+ "16": 1,
+}
+vit_image_size: 224
+network_name: UNetSD_temporal
+resume: true
+resume_step: 141000
+seed: 14
+num_workers: 1
+mvs_visual: True
+chunk_size: 1
+resume_checkpoint: "model_weights/non_ema_141000_no_watermark.pth"
+log_dir: 'outputs'
+num_steps: 1
diff --git a/modelscope/models/multi_modal/videocomposer/data/__init__.py b/modelscope/models/multi_modal/videocomposer/data/__init__.py
new file mode 100644
index 00000000..ba8d1233
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/data/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from .samplers import *
+from .tokenizers import *
+from .transforms import *
diff --git a/modelscope/models/multi_modal/videocomposer/data/samplers.py b/modelscope/models/multi_modal/videocomposer/data/samplers.py
new file mode 100644
index 00000000..4cadef95
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/data/samplers.py
@@ -0,0 +1,158 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os.path as osp
+
+import json
+import numpy as np
+from torch.utils.data.sampler import Sampler
+
+from modelscope.models.multi_modal.videocomposer.ops.distributed import (
+ get_rank, get_world_size, shared_random_seed)
+from modelscope.models.multi_modal.videocomposer.ops.utils import (ceil_divide,
+ read)
+
+__all__ = ['BatchSampler', 'GroupSampler', 'ImgGroupSampler']
+
+
+class BatchSampler(Sampler):
+ r"""An infinite batch sampler.
+ """
+
+ def __init__(self,
+ dataset_size,
+ batch_size,
+ num_replicas=None,
+ rank=None,
+ shuffle=False,
+ seed=None):
+ self.dataset_size = dataset_size
+ self.batch_size = batch_size
+ self.num_replicas = num_replicas or get_world_size()
+ self.rank = rank or get_rank()
+ self.shuffle = shuffle
+ self.seed = seed or shared_random_seed()
+ self.rng = np.random.default_rng(self.seed + self.rank)
+ self.batches_per_rank = ceil_divide(
+ dataset_size, self.num_replicas * self.batch_size)
+ self.samples_per_rank = self.batches_per_rank * self.batch_size
+
+ # rank indices
+ indices = self.rng.permutation(
+ self.samples_per_rank) if shuffle else np.arange(
+ self.samples_per_rank)
+ indices = indices * self.num_replicas + self.rank
+ indices = indices[indices < dataset_size]
+ self.indices = indices
+
+ def __iter__(self):
+ start = 0
+ while True:
+ batch = [
+ self.indices[i % len(self.indices)]
+ for i in range(start, start + self.batch_size)
+ ]
+ if self.shuffle and (start + self.batch_size) > len(self.indices):
+ self.rng.shuffle(self.indices)
+ start = (start + self.batch_size) % len(self.indices)
+ yield batch
+
+
+class GroupSampler(Sampler):
+
+ def __init__(self,
+ group_file,
+ batch_size,
+ alpha=0.7,
+ update_interval=5000,
+ seed=8888):
+ self.group_file = group_file
+ self.group_folder = osp.join(osp.dirname(group_file), 'groups')
+ self.batch_size = batch_size
+ self.alpha = alpha
+ self.update_interval = update_interval
+ self.seed = seed
+ self.rng = np.random.default_rng(seed)
+
+ def __iter__(self):
+ while True:
+ # keep groups up-to-date
+ self.update_groups()
+
+ # collect items
+ items = self.sample()
+ while len(items) < self.batch_size:
+ items += self.sample()
+
+ # sample a batch
+ batch = self.rng.choice(
+ items,
+ self.batch_size,
+ replace=False if len(items) >= self.batch_size else True)
+ yield [u.strip().split(',') for u in batch]
+
+ def update_groups(self):
+ if not hasattr(self, '_step'):
+ self._step = 0
+ if self._step % self.update_interval == 0:
+ self.groups = json.loads(read(self.group_file))
+ self._step += 1
+
+ def sample(self):
+ scales = np.array(
+ [float(next(iter(u)).split(':')[-1]) for u in self.groups])
+ p = scales**self.alpha / (scales**self.alpha).sum()
+ group = self.rng.choice(self.groups, p=p)
+ list_file = osp.join(self.group_folder,
+ self.rng.choice(next(iter(group.values()))))
+ return read(list_file).strip().split('\n')
+
+
+class ImgGroupSampler(Sampler):
+
+ def __init__(self,
+ group_file,
+ batch_size,
+ alpha=0.7,
+ update_interval=5000,
+ seed=8888):
+ self.group_file = group_file
+ self.group_folder = osp.join(osp.dirname(group_file), 'groups')
+ self.batch_size = batch_size
+ self.alpha = alpha
+ self.update_interval = update_interval
+ self.seed = seed
+ self.rng = np.random.default_rng(seed)
+
+ def __iter__(self):
+ while True:
+ # keep groups up-to-date
+ self.update_groups()
+
+ # collect items
+ items = self.sample()
+ while len(items) < self.batch_size:
+ items += self.sample()
+
+ # sample a batch
+ batch = self.rng.choice(
+ items,
+ self.batch_size,
+ replace=False if len(items) >= self.batch_size else True)
+ yield [u.strip().split(',', 1) for u in batch]
+
+ def update_groups(self):
+ if not hasattr(self, '_step'):
+ self._step = 0
+ if self._step % self.update_interval == 0:
+ self.groups = json.loads(read(self.group_file))
+
+ self._step += 1
+
+ def sample(self):
+ scales = np.array(
+ [float(next(iter(u)).split(':')[-1]) for u in self.groups])
+ p = scales**self.alpha / (scales**self.alpha).sum()
+ group = self.rng.choice(self.groups, p=p)
+ list_file = osp.join(self.group_folder,
+ self.rng.choice(next(iter(group.values()))))
+ return read(list_file).strip().split('\n')
diff --git a/modelscope/models/multi_modal/videocomposer/data/tokenizers.py b/modelscope/models/multi_modal/videocomposer/data/tokenizers.py
new file mode 100644
index 00000000..e3ab506b
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/data/tokenizers.py
@@ -0,0 +1,184 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+import torch
+from tokenizers import BertWordPieceTokenizer, CharBPETokenizer
+
+__all__ = ['CLIPTokenizer']
+
+
+@lru_cache()
+def default_bpe():
+ root = os.path.realpath(__file__)
+ root = '/'.join(root.split('/')[:-1])
+ return os.path.join(root, 'bpe_simple_vocab_16e6.txt.gz')
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord('!'),
+ ord('~') + 1)) + list(range(
+ ord('¡'),
+ ord('¬') + 1)) + list(range(ord('®'),
+ ord('ÿ') + 1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
+ merges = merges[1:49152 - 256 - 2 + 1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + '</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {
+ '<|startoftext|>': '<|startoftext|>',
+ '<|endoftext|>': '<|endoftext|>'
+ }
+ self.pat = re.compile(
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+ re.IGNORECASE)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + '</w>', )
+ pairs = get_pairs(word)
+
+ if not pairs:
+            return token + '</w>'
+
+ while True:
+ bigram = min(
+ pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except Exception as e:
+ new_word.extend(word[i:])
+ print(e)
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[
+ i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b]
+ for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token]
+ for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode(
+            'utf-8', errors='replace').replace('</w>', ' ')
+ return text
+
+
+class CLIPTokenizer(object):
+
+ def __init__(self, length=77):
+ self.length = length
+
+ # init tokenizer
+ self.tokenizer = SimpleTokenizer(bpe_path=default_bpe())
+ self.sos_token = self.tokenizer.encoder['<|startoftext|>']
+ self.eos_token = self.tokenizer.encoder['<|endoftext|>']
+ self.vocab_size = len(self.tokenizer.encoder)
+
+ def __call__(self, sequence):
+ if isinstance(sequence, str):
+ return torch.LongTensor(self._tokenizer(sequence))
+ elif isinstance(sequence, list):
+ return torch.LongTensor([self._tokenizer(u) for u in sequence])
+ else:
+ raise TypeError(
+ f'Expected the "sequence" to be a string or a list, but got {type(sequence)}'
+ )
+
+ def _tokenizer(self, text):
+ tokens = self.tokenizer.encode(text)[:self.length - 2]
+ tokens = [self.sos_token] + tokens + [self.eos_token]
+ tokens = tokens + [0] * (self.length - len(tokens))
+ return tokens
diff --git a/modelscope/models/multi_modal/videocomposer/data/transforms.py b/modelscope/models/multi_modal/videocomposer/data/transforms.py
new file mode 100644
index 00000000..3573fe86
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/data/transforms.py
@@ -0,0 +1,400 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import math
+import random
+
+import numpy as np
+import torch
+import torchvision.transforms.functional as F
+import torchvision.transforms.functional as TF
+from PIL import Image, ImageFilter
+from torchvision.transforms.functional import InterpolationMode
+
+__all__ = [
+ 'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2', 'RandomCrop',
+ 'RandomCropV2', 'RandomHFlip', 'GaussianBlur', 'ColorJitter', 'RandomGray',
+ 'ToTensor', 'Normalize', 'ResizeRandomCrop', 'ExtractResizeRandomCrop',
+ 'ExtractResizeAssignCrop'
+]
+
+
+def random_resize(img, size):
+ img = [
+ TF.resize(
+ u,
+ size,
+ interpolation=random.choice([
+ InterpolationMode.BILINEAR, InterpolationMode.BICUBIC,
+ InterpolationMode.LANCZOS
+ ])) for u in img
+ ]
+ return img
+
+
+class CenterCropV3(object):
+
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, img):
+ # fast resize
+ while min(img.size) >= 2 * self.size:
+ img = img.resize((img.width // 2, img.height // 2),
+ resample=Image.BOX)
+ scale = self.size / min(img.size)
+ img = img.resize((round(scale * img.width), round(scale * img.height)),
+ resample=Image.BICUBIC)
+
+ # center crop
+ x1 = (img.width - self.size) // 2
+ y1 = (img.height - self.size) // 2
+ img = img.crop((x1, y1, x1 + self.size, y1 + self.size))
+ return img
+
+
+class Compose(object):
+
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __getitem__(self, index):
+ if isinstance(index, slice):
+ return Compose(self.transforms[index])
+ else:
+ return self.transforms[index]
+
+ def __len__(self):
+ return len(self.transforms)
+
+ def __call__(self, rgb):
+ for t in self.transforms:
+ rgb = t(rgb)
+ return rgb
+
+
+class Resize(object):
+
+ def __init__(self, size=256):
+ if isinstance(size, int):
+ size = (size, size)
+ self.size = size
+
+ def __call__(self, rgb):
+
+ rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb]
+ return rgb
+
+
+class Rescale(object):
+
+ def __init__(self, size=256, interpolation=Image.BILINEAR):
+ self.size = size
+ self.interpolation = interpolation
+
+ def __call__(self, rgb):
+ w, h = rgb[0].size
+ scale = self.size / min(w, h)
+ out_w, out_h = int(round(w * scale)), int(round(h * scale))
+ rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb]
+ return rgb
+
+
+class CenterCrop(object):
+
+ def __init__(self, size=224):
+ self.size = size
+
+ def __call__(self, rgb):
+ w, h = rgb[0].size
+ assert min(w, h) >= self.size
+ x1 = (w - self.size) // 2
+ y1 = (h - self.size) // 2
+ rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb]
+ return rgb
+
+
+class ResizeRandomCrop(object):
+
+ def __init__(self, size=256, size_short=292):
+ self.size = size
+ # self.min_area = min_area
+ self.size_short = size_short
+
+ def __call__(self, rgb):
+
+ # consistent crop between rgb and m
+ while min(rgb[0].size) >= 2 * self.size_short:
+ rgb = [
+ u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
+ for u in rgb
+ ]
+ scale = self.size_short / min(rgb[0].size)
+ rgb = [
+ u.resize((round(scale * u.width), round(scale * u.height)),
+ resample=Image.BICUBIC) for u in rgb
+ ]
+ out_w = self.size
+ out_h = self.size
+ w, h = rgb[0].size # (518, 292)
+ x1 = random.randint(0, w - out_w)
+ y1 = random.randint(0, h - out_h)
+
+ rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
+
+ return rgb
+
+
+class ExtractResizeRandomCrop(object):
+
+ def __init__(self, size=256, size_short=292):
+ self.size = size
+ self.size_short = size_short
+
+ def __call__(self, rgb):
+
+ # consistent crop between rgb and m
+ while min(rgb[0].size) >= 2 * self.size_short:
+ rgb = [
+ u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
+ for u in rgb
+ ]
+ scale = self.size_short / min(rgb[0].size)
+ rgb = [
+ u.resize((round(scale * u.width), round(scale * u.height)),
+ resample=Image.BICUBIC) for u in rgb
+ ]
+ out_w = self.size
+ out_h = self.size
+ w, h = rgb[0].size # (518, 292)
+ x1 = random.randint(0, w - out_w)
+ y1 = random.randint(0, h - out_h)
+
+ rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
+ wh = [x1, y1, x1 + out_w, y1 + out_h]
+
+ return rgb, wh
+
+
+class ExtractResizeAssignCrop(object):
+
+ def __init__(self, size=256, size_short=292):
+ self.size = size
+ self.size_short = size_short
+
+ def __call__(self, rgb, wh):
+
+ # consistent crop between rgb and m
+ while min(rgb[0].size) >= 2 * self.size_short:
+ rgb = [
+ u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
+ for u in rgb
+ ]
+ scale = self.size_short / min(rgb[0].size)
+ rgb = [
+ u.resize((round(scale * u.width), round(scale * u.height)),
+ resample=Image.BICUBIC) for u in rgb
+ ]
+
+ rgb = [u.crop(wh) for u in rgb]
+ rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
+
+ return rgb
+
+
+class CenterCropV2(object):
+
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, img):
+ # fast resize
+ while min(img[0].size) >= 2 * self.size:
+ img = [
+ u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
+ for u in img
+ ]
+ scale = self.size / min(img[0].size)
+ img = [
+ u.resize((round(scale * u.width), round(scale * u.height)),
+ resample=Image.BICUBIC) for u in img
+ ]
+
+ # center crop
+ x1 = (img[0].width - self.size) // 2
+ y1 = (img[0].height - self.size) // 2
+ img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
+ return img
+
+
+class RandomCrop(object):
+
+ def __init__(self, size=224, min_area=0.4):
+ self.size = size
+ self.min_area = min_area
+
+ def __call__(self, rgb):
+
+ # consistent crop between rgb and m
+ w, h = rgb[0].size
+ area = w * h
+ out_w, out_h = float('inf'), float('inf')
+ while out_w > w or out_h > h:
+ target_area = random.uniform(self.min_area, 1.0) * area
+ aspect_ratio = random.uniform(3. / 4., 4. / 3.)
+ out_w = int(round(math.sqrt(target_area * aspect_ratio)))
+ out_h = int(round(math.sqrt(target_area / aspect_ratio)))
+ x1 = random.randint(0, w - out_w)
+ y1 = random.randint(0, h - out_h)
+
+ rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
+ rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
+
+ return rgb
+
+
+class RandomCropV2(object):
+
+ def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
+ if isinstance(size, (tuple, list)):
+ self.size = size
+ else:
+ self.size = (size, size)
+ self.min_area = min_area
+ self.ratio = ratio
+
+ def _get_params(self, img):
+ width, height = img.size
+ area = height * width
+
+ for _ in range(10):
+ target_area = random.uniform(self.min_area, 1.0) * area
+ log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if 0 < w <= width and 0 < h <= height:
+ i = random.randint(0, height - h)
+ j = random.randint(0, width - w)
+ return i, j, h, w
+
+ # Fallback to central crop
+ in_ratio = float(width) / float(height)
+ if (in_ratio < min(self.ratio)):
+ w = width
+ h = int(round(w / min(self.ratio)))
+ elif (in_ratio > max(self.ratio)):
+ h = height
+ w = int(round(h * max(self.ratio)))
+ else: # whole image
+ w = width
+ h = height
+ i = (height - h) // 2
+ j = (width - w) // 2
+ return i, j, h, w
+
+ def __call__(self, rgb):
+ i, j, h, w = self._get_params(rgb[0])
+ rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb]
+ return rgb
+
+
+class RandomHFlip(object):
+
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, rgb):
+ if random.random() < self.p:
+ rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb]
+ return rgb
+
+
+class GaussianBlur(object):
+
+ def __init__(self, sigmas=[0.1, 2.0], p=0.5):
+ self.sigmas = sigmas
+ self.p = p
+
+ def __call__(self, rgb):
+ if random.random() < self.p:
+ sigma = random.uniform(*self.sigmas)
+ rgb = [
+ u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb
+ ]
+ return rgb
+
+
+class ColorJitter(object):
+
+ def __init__(self,
+ brightness=0.4,
+ contrast=0.4,
+ saturation=0.4,
+ hue=0.1,
+ p=0.5):
+ self.brightness = brightness
+ self.contrast = contrast
+ self.saturation = saturation
+ self.hue = hue
+ self.p = p
+
+ def __call__(self, rgb):
+ if random.random() < self.p:
+ brightness, contrast, saturation, hue = self._random_params()
+ transforms = [
+ lambda f: F.adjust_brightness(f, brightness),
+ lambda f: F.adjust_contrast(f, contrast),
+ lambda f: F.adjust_saturation(f, saturation),
+ lambda f: F.adjust_hue(f, hue)
+ ]
+ random.shuffle(transforms)
+ for t in transforms:
+ rgb = [t(u) for u in rgb]
+
+ return rgb
+
+ def _random_params(self):
+ brightness = random.uniform(
+ max(0, 1 - self.brightness), 1 + self.brightness)
+ contrast = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
+ saturation = random.uniform(
+ max(0, 1 - self.saturation), 1 + self.saturation)
+ hue = random.uniform(-self.hue, self.hue)
+ return brightness, contrast, saturation, hue
+
+
+class RandomGray(object):
+
+ def __init__(self, p=0.2):
+ self.p = p
+
+ def __call__(self, rgb):
+ if random.random() < self.p:
+ rgb = [u.convert('L').convert('RGB') for u in rgb]
+ return rgb
+
+
+class ToTensor(object):
+
+ def __call__(self, rgb):
+ rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0)
+ return rgb
+
+
+class Normalize(object):
+
+ def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, rgb):
+ rgb = rgb.clone()
+ rgb.clamp_(0, 1)
+ if not isinstance(self.mean, torch.Tensor):
+ self.mean = rgb.new_tensor(self.mean).view(-1)
+ if not isinstance(self.std, torch.Tensor):
+ self.std = rgb.new_tensor(self.std).view(-1)
+ rgb.sub_(self.mean.view(1, -1, 1, 1)).div_(self.std.view(1, -1, 1, 1))
+ return rgb
diff --git a/modelscope/models/multi_modal/videocomposer/diffusion.py b/modelscope/models/multi_modal/videocomposer/diffusion.py
new file mode 100644
index 00000000..8b669ca9
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/diffusion.py
@@ -0,0 +1,1514 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import math
+
+import torch
+
+from .dpm_solver import (DPM_Solver, NoiseScheduleVP, model_wrapper,
+ model_wrapper_guided_diffusion)
+from .ops.losses import discretized_gaussian_log_likelihood, kl_divergence
+
+__all__ = ['GaussianDiffusion', 'beta_schedule', 'GaussianDiffusion_style']
+
+
+def _i(tensor, t, x):
+ r"""Index tensor using t and format the output according to x.
+ """
+ shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
+ if tensor.device != x.device:
+ tensor = tensor.to(x.device)
+ return tensor[t].view(shape).to(x)
+
+
def beta_schedule(schedule,
                  num_timesteps=1000,
                  init_beta=None,
                  last_beta=None):
    r"""Build a noise schedule: a 1-D float64 tensor of ``num_timesteps`` betas.

    Args:
        schedule (str): 'linear', 'linear_sd', 'quadratic' or 'cosine'.
        num_timesteps (int, optional): number of diffusion steps. Default 1000.
        init_beta (float, optional): first beta; a per-schedule default is
            used when None.
        last_beta (float, optional): last beta; a per-schedule default is
            used when None.

    Schedules:
        linear:    betas linearly spaced between init_beta and last_beta,
                   with defaults rescaled by 1000 / num_timesteps.
        linear_sd: betas linearly spaced between sqrt(init_beta) and
                   sqrt(last_beta), then squared (the Stable Diffusion
                   'scaled linear' schedule).
        quadratic: same construction as linear_sd with defaults
                   0.0015 / 0.0195.
        cosine:    betas derived from a squared-cosine alpha-bar curve,
                   capped at 0.999.

    Raises:
        ValueError: if ``schedule`` is not one of the four names above.
    """
    if schedule == 'linear':
        scale = 1000.0 / num_timesteps
        init_beta = init_beta or scale * 0.0001
        last_beta = last_beta or scale * 0.02
        return torch.linspace(
            init_beta, last_beta, num_timesteps, dtype=torch.float64)
    elif schedule == 'linear_sd':
        # Fix: this branch previously had no defaults and crashed with
        # "TypeError: unsupported operand ... NoneType ** float" whenever the
        # endpoints were omitted, contradicting the docstring. Default to the
        # Stable Diffusion scaled-linear endpoints (0.00085 / 0.012).
        init_beta = init_beta or 0.00085
        last_beta = last_beta or 0.012
        return torch.linspace(
            init_beta**0.5, last_beta**0.5, num_timesteps,
            dtype=torch.float64)**2
    elif schedule == 'quadratic':
        init_beta = init_beta or 0.0015
        last_beta = last_beta or 0.0195
        return torch.linspace(
            init_beta**0.5, last_beta**0.5, num_timesteps,
            dtype=torch.float64)**2
    elif schedule == 'cosine':
        betas = []
        for step in range(num_timesteps):
            t1 = step / num_timesteps
            t2 = (step + 1) / num_timesteps
            # alpha_bar(u): squared-cosine curve (Nichol & Dhariwal)
            fn = lambda u: math.cos(  # noqa
                (u + 0.008) / 1.008 * math.pi / 2)**2  # noqa
            betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
        return torch.tensor(betas, dtype=torch.float64)
    else:
        raise ValueError(f'Unsupported schedule: {schedule}')
+
+
def load_stable_diffusion_pretrained(state_dict, temporal_attention):
    """Remap a Stable Diffusion UNet checkpoint onto this model's layout.

    Keeps only keys under 'diffusion_model.', strips that prefix, renames the
    three downsample '0.op' keys to 'op', and — when temporal attention layers
    are inserted — shifts the affected middle/output block indices up by one.
    """
    import collections
    remapped = collections.OrderedDict()
    downsample_keys = {
        'input_blocks.3.0.op.weight', 'input_blocks.3.0.op.bias',
        'input_blocks.6.0.op.weight', 'input_blocks.6.0.op.bias',
        'input_blocks.9.0.op.weight', 'input_blocks.9.0.op.bias'
    }
    # index shifts caused by the inserted temporal attention modules
    temporal_shifts = [('middle_block.2', 'middle_block.3'),
                       ('output_blocks.5.2', 'output_blocks.5.3'),
                       ('output_blocks.8.2', 'output_blocks.8.3')]

    for key in list(state_dict.keys()):
        if 'diffusion_model' not in key:
            continue
        new_key = key.split('diffusion_model.')[-1]
        if new_key in downsample_keys:
            new_key = new_key.replace('0.op', 'op')
        if temporal_attention:
            for old, new in temporal_shifts:
                if old in new_key:
                    new_key = new_key.replace(old, new)
        remapped[new_key] = state_dict[key]

    return remapped
+
+
class AddGaussianNoise(object):
    """Add i.i.d. Gaussian noise with the given ``mean`` and ``std`` to a
    tensor. Integer inputs are promoted to float32 for the addition and cast
    back to their original dtype afterwards."""

    def __init__(self, mean=0., std=0.1):
        self.std = std
        self.mean = mean

    def __call__(self, img):
        assert isinstance(img, torch.Tensor)
        original_dtype = img.dtype
        work = img if img.is_floating_point() else img.to(torch.float32)
        noisy = work + self.std * torch.randn_like(work) + self.mean
        if noisy.dtype != original_dtype:
            noisy = noisy.to(original_dtype)
        return noisy

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(
            self.mean, self.std)
+
+
class GaussianDiffusion(object):
    """DDPM-style Gaussian diffusion: forward process q(x_t | x_0), several
    reverse samplers (ancestral, DDIM, PLMS) and training losses.

    All schedule-derived coefficients are precomputed once as float64
    tensors indexed by timestep.
    """

    def __init__(self,
                 betas,
                 mean_type='eps',
                 var_type='learned_range',
                 loss_type='mse',
                 epsilon=1e-12,
                 rescale_timesteps=False):
        # check input
        if not isinstance(betas, torch.DoubleTensor):
            betas = torch.tensor(betas, dtype=torch.float64)
        assert min(betas) > 0 and max(betas) <= 1
        assert mean_type in ['x0', 'x_{t-1}', 'eps']
        assert var_type in [
            'learned', 'learned_range', 'fixed_large', 'fixed_small'
        ]
        assert loss_type in [
            'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1',
            'charbonnier'
        ]
        self.betas = betas
        self.num_timesteps = len(betas)
        self.mean_type = mean_type  # eps
        self.var_type = var_type  # 'fixed_small'
        self.loss_type = loss_type  # mse
        self.epsilon = epsilon  # 1e-12, smoothing term of the charbonnier loss
        self.rescale_timesteps = rescale_timesteps  # False

        # alphas (cumulative products define alpha_bar_t and its neighbours)
        alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat(
            [alphas.new_ones([1]), self.alphas_cumprod[:-1]])
        self.alphas_cumprod_next = torch.cat(
            [self.alphas_cumprod[1:],
             alphas.new_zeros([1])])

        # q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
                                                        - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0
                                                      - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
                                                      - 1)

        # q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
            1.0 - self.alphas_cumprod)
        # clamp avoids log(0): the posterior variance is 0 at t == 0
        self.posterior_log_variance_clipped = torch.log(
            self.posterior_variance.clamp(1e-20))
        self.posterior_mean_coef1 = betas * torch.sqrt(
            self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (
            1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
                1.0 - self.alphas_cumprod)
+
    def q_sample(self, x0, t, noise=None):
        r"""Sample from q(x_t | x_0).

        xt = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise.
        """
        noise = torch.randn_like(x0) if noise is None else noise
        return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \
            _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise  # noqa
+
    def q_mean_variance(self, x0, t):
        r"""Distribution of q(x_t | x_0).

        Returns (mean, variance, log-variance), each broadcastable to x0.
        """
        mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
        var = _i(1.0 - self.alphas_cumprod, t, x0)
        log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
        return mu, var, log_var
+
    def q_posterior_mean_variance(self, x0, xt, t):
        r"""Distribution of q(x_{t-1} | x_t, x_0).

        Returns (mean, variance, clipped log-variance) of the true posterior.
        """
        mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
            self.posterior_mean_coef2, t, xt) * xt
        var = _i(self.posterior_variance, t, xt)
        log_var = _i(self.posterior_log_variance_clipped, t, xt)
        return mu, var, log_var
+
    @torch.no_grad()
    def p_sample(self,
                 xt,
                 t,
                 model,
                 model_kwargs={},
                 clamp=None,
                 percentile=None,
                 condition_fn=None,
                 guide_scale=None):
        r"""Sample from p(x_{t-1} | x_t).
        - condition_fn: for classifier-based guidance (guided-diffusion).
        - guide_scale: for classifier-free guidance (glide/dalle-2).

        Returns (x_{t-1}, predicted x0).
        """
        # predict distribution of p(x_{t-1} | x_t)
        mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                                    clamp, percentile,
                                                    guide_scale)

        # random sample (with optional conditional function)
        noise = torch.randn_like(xt)
        # mask zeroes the noise at t == 0 (the final denoising step)
        mask = t.ne(0).float().view(
            -1,
            *((1, ) *  # noqa
              (xt.ndim - 1)))
        if condition_fn is not None:
            # shift the mean along the classifier gradient
            grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
            mu = mu.float() + var * grad.float()
        xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
        return xt_1, x0
+
    @torch.no_grad()
    def p_sample_loop(self,
                      noise,
                      model,
                      model_kwargs={},
                      clamp=None,
                      percentile=None,
                      condition_fn=None,
                      guide_scale=None):
        r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).

        Runs the full ancestral sampling chain starting from pure noise.
        """
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process: iterate t = T-1, ..., 0
        for step in torch.arange(self.num_timesteps).flip(0):
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
                                  percentile, condition_fn, guide_scale)
        return xt
+
    def p_mean_variance(self,
                        xt,
                        t,
                        model,
                        model_kwargs={},
                        clamp=None,
                        percentile=None,
                        guide_scale=None):
        r"""Distribution of p(x_{t-1} | x_t).

        Returns (mu, var, log_var, x0) where x0 is the model's (possibly
        clamped/thresholded) prediction of the clean sample.
        """
        # predict distribution
        if guide_scale is None:
            out = model(xt, self._scale_timesteps(t), **model_kwargs)
        else:
            # classifier-free guidance
            # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
            u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
            # only the mean channels are guided; learned-variance channels
            # come from the conditional output
            dim = y_out.size(1) if self.var_type.startswith(
                'fixed') else y_out.size(1) // 2
            out = torch.cat(
                [
                    u_out[:, :dim] + guide_scale *  # noqa
                    (y_out[:, :dim] - u_out[:, :dim]),
                    y_out[:, dim:]
                ],
                dim=1)  # noqa

        # compute variance
        if self.var_type == 'learned':
            out, log_var = out.chunk(2, dim=1)
            var = torch.exp(log_var)
        elif self.var_type == 'learned_range':
            # model predicts a fraction in [-1, 1] interpolating between the
            # posterior (min) and beta (max) log-variances
            out, fraction = out.chunk(2, dim=1)
            min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
            max_log_var = _i(torch.log(self.betas), t, xt)
            fraction = (fraction + 1) / 2.0
            log_var = fraction * max_log_var + (1 - fraction) * min_log_var
            var = torch.exp(log_var)
        elif self.var_type == 'fixed_large':
            var = _i(
                torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
                xt)
            log_var = torch.log(var)
        elif self.var_type == 'fixed_small':
            var = _i(self.posterior_variance, t, xt)
            log_var = _i(self.posterior_log_variance_clipped, t, xt)

        # compute mean and x0
        if self.mean_type == 'x_{t-1}':
            mu = out  # x_{t-1}
            # invert the posterior-mean formula to recover x0
            x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \
                _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt  # noqa
        elif self.mean_type == 'x0':
            x0 = out
            mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
        elif self.mean_type == 'eps':
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out  # noqa
            mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)

        # restrict the range of x0
        if percentile is not None:
            # dynamic thresholding: clip |x0| at the given percentile, then
            # rescale into [-1, 1]
            assert percentile > 0 and percentile <= 1  # e.g., 0.995
            s = torch.quantile(
                x0.flatten(1).abs(), percentile,
                dim=1).clamp_(1.0).view(-1, 1, 1, 1)
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)
        return mu, var, log_var, x0
+
    @torch.no_grad()
    def ddim_sample(self,
                    xt,
                    t,
                    model,
                    model_kwargs={},
                    clamp=None,
                    percentile=None,
                    condition_fn=None,
                    guide_scale=None,
                    ddim_timesteps=20,
                    eta=0.0):
        r"""Sample from p(x_{t-1} | x_t) using DDIM.
        - condition_fn: for classifier-based guidance (guided-diffusion).
        - guide_scale: for classifier-free guidance (glide/dalle-2).
        - eta: 0.0 gives the deterministic DDIM update; larger values inject
          noise scaled by sigma.
        """
        stride = self.num_timesteps // ddim_timesteps

        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)
        if condition_fn is not None:
            # x0 -> eps
            alpha = _i(self.alphas_cumprod, t, xt)
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
            eps = eps - (1 - alpha).sqrt() * condition_fn(
                xt, self._scale_timesteps(t), **model_kwargs)

            # eps -> x0
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps  # noqa

        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
            _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
        alphas = _i(self.alphas_cumprod, t, xt)
        alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
        sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) *  # noqa
                                  (1 - alphas / alphas_prev))

        # random sample; the mask disables noise at t == 0
        noise = torch.randn_like(xt)
        direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
        mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
        xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
        return xt_1, x0
+
    @torch.no_grad()
    def ddim_sample_loop(self,
                         noise,
                         model,
                         model_kwargs={},
                         clamp=None,
                         percentile=None,
                         condition_fn=None,
                         guide_scale=None,
                         ddim_timesteps=20,
                         eta=0.0):
        r"""Run the full DDIM sampling chain starting from pure noise."""
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
        # visit steps 1 + k * stride, clipped to the valid range, from high
        # noise to low noise
        steps = (1 + torch.arange(0, self.num_timesteps,
                                  self.num_timesteps // ddim_timesteps)).clamp(
                                      0, self.num_timesteps - 1).flip(0)
        for step in steps:
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
                                     percentile, condition_fn, guide_scale,
                                     ddim_timesteps, eta)
        return xt
+
    @torch.no_grad()
    def ddim_reverse_sample(self,
                            xt,
                            t,
                            model,
                            model_kwargs={},
                            clamp=None,
                            percentile=None,
                            guide_scale=None,
                            ddim_timesteps=20):
        r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).

        One step of DDIM inversion (encoding a sample towards noise).
        """
        stride = self.num_timesteps // ddim_timesteps

        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)

        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
            _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
        # the table is zero-padded so t + stride == num_timesteps maps to
        # alpha_bar = 0
        alphas_next = _i(
            torch.cat(
                [self.alphas_cumprod,
                 self.alphas_cumprod.new_zeros([1])]),
            (t + stride).clamp(0, self.num_timesteps), xt)

        # reverse sample
        mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
        return mu, x0
+
    @torch.no_grad()
    def ddim_reverse_sample_loop(self,
                                 x0,
                                 model,
                                 model_kwargs={},
                                 clamp=None,
                                 percentile=None,
                                 guide_scale=None,
                                 ddim_timesteps=20):
        r"""Deterministically encode x0 into noise via the DDIM reverse ODE."""
        # prepare input
        b = x0.size(0)
        xt = x0

        # reconstruction steps, from low noise to high noise
        steps = torch.arange(0, self.num_timesteps,
                             self.num_timesteps // ddim_timesteps)
        for step in steps:
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
                                             percentile, guide_scale,
                                             ddim_timesteps)
        return xt
+
+ @torch.no_grad()
+ def plms_sample(self,
+ xt,
+ t,
+ model,
+ model_kwargs={},
+ clamp=None,
+ percentile=None,
+ condition_fn=None,
+ guide_scale=None,
+ plms_timesteps=20):
+ r"""Sample from p(x_{t-1} | x_t) using PLMS.
+ - condition_fn: for classifier-based guidance (guided-diffusion).
+ - guide_scale: for classifier-free guidance (glide/dalle-2).
+ """
+ stride = self.num_timesteps // plms_timesteps
+
+ # function for compute eps
+ def compute_eps(xt, t):
+ # predict distribution of p(x_{t-1} | x_t)
+ _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
+ clamp, percentile, guide_scale)
+
+ # condition
+ if condition_fn is not None:
+ # x0 -> eps
+ alpha = _i(self.alphas_cumprod, t, xt)
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa
+ eps = eps - (1 - alpha).sqrt() * condition_fn(
+ xt, self._scale_timesteps(t), **model_kwargs)
+
+ # eps -> x0
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa
+
+ # derive eps
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa
+ return eps
+
+ # function for compute x_0 and x_{t-1}
+ def compute_x0(eps, t):
+ # eps -> x0
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa
+
+ # deterministic sample
+ alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
+ direction = torch.sqrt(1 - alphas_prev) * eps
+ xt_1 = torch.sqrt(alphas_prev) * x0 + direction
+ return xt_1, x0
+
+ # PLMS sample
+ eps = compute_eps(xt, t)
+ if len(eps_cache) == 0:
+ # 2nd order pseudo improved Euler
+ xt_1, x0 = compute_x0(eps, t)
+ eps_next = compute_eps(xt_1, (t - stride).clamp(0))
+ eps_prime = (eps + eps_next) / 2.0
+ elif len(eps_cache) == 1:
+ # 2nd order pseudo linear multistep (Adams-Bashforth)
+ eps_prime = (3 * eps - eps_cache[-1]) / 2.0
+ elif len(eps_cache) == 2:
+ # 3nd order pseudo linear multistep (Adams-Bashforth)
+ eps_prime = (23 * eps - 16 * eps_cache[-1]
+ + 5 * eps_cache[-2]) / 12.0
+ elif len(eps_cache) >= 3:
+ # 4nd order pseudo linear multistep (Adams-Bashforth)
+ eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
+ - 9 * eps_cache[-3]) / 24.0
+ xt_1, x0 = compute_x0(eps_prime, t)
+ return xt_1, x0, eps
+
    @torch.no_grad()
    def plms_sample_loop(self,
                         noise,
                         model,
                         model_kwargs={},
                         clamp=None,
                         percentile=None,
                         condition_fn=None,
                         guide_scale=None,
                         plms_timesteps=20):
        r"""Run the full PLMS sampling chain starting from pure noise."""
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process: steps 1 + k * stride, clipped, high to low noise
        steps = (1 + torch.arange(0, self.num_timesteps,
                                  self.num_timesteps // plms_timesteps)).clamp(
                                      0, self.num_timesteps - 1).flip(0)
        eps_cache = []
        for step in steps:
            # PLMS sampling step
            # NOTE(review): eps_cache is passed as an extra positional
            # argument — plms_sample's signature must accept it.
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
                                          percentile, condition_fn,
                                          guide_scale, plms_timesteps,
                                          eps_cache)

            # update eps cache: keep only the three most recent predictions
            eps_cache.append(eps)
            if len(eps_cache) >= 4:
                eps_cache.pop(0)
        return xt
+
    def loss(self,
             x0,
             t,
             model,
             model_kwargs={},
             noise=None,
             weight=None,
             use_div_loss=False):
        r"""Per-sample training loss at timesteps ``t``.

        Branches on self.loss_type: pure KL (VLB), MSE/L1 (plus a rescaled
        VLB term when the variance is learned), or Charbonnier.
        """
        noise = torch.randn_like(
            x0) if noise is None else noise  # [80, 4, 8, 32, 32]
        xt = self.q_sample(x0, t, noise=noise)

        # compute loss
        if self.loss_type in ['kl', 'rescaled_kl']:
            loss, _ = self.variational_lower_bound(x0, xt, t, model,
                                                   model_kwargs)
            if self.loss_type == 'rescaled_kl':
                loss = loss * self.num_timesteps
        elif self.loss_type in ['mse', 'rescaled_mse', 'l1',
                                'rescaled_l1']:  # self.loss_type: mse
            out = model(xt, self._scale_timesteps(t), **model_kwargs)

            # VLB for variation
            loss_vlb = 0.0
            if self.var_type in ['learned', 'learned_range'
                                 ]:  # self.var_type: 'fixed_small'
                out, var = out.chunk(2, dim=1)
                frozen = torch.cat([
                    out.detach(), var
                ], dim=1)  # learn var without affecting the prediction of mean
                loss_vlb, _ = self.variational_lower_bound(
                    x0, xt, t, model=lambda *args, **kwargs: frozen)
                if self.loss_type.startswith('rescaled_'):
                    loss_vlb = loss_vlb * self.num_timesteps / 1000.0

            # MSE/L1 for x0/eps, target chosen by the parameterization
            target = {
                'eps': noise,
                'x0': x0,
                'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
            }[self.mean_type]
            # pow(1).abs() -> L1 per element; pow(2).abs() -> squared error
            loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2
                                      ).abs().flatten(1).mean(dim=1)
            if weight is not None:
                loss = loss * weight

            # div loss: only meaningful with more than one frame
            if use_div_loss and self.mean_type == 'eps' and x0.shape[2] > 1:

                # derive x0 from the predicted eps
                x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                    _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out

                # ncfhw, std on f: penalizes low variance across frames
                div_loss = 0.001 / (
                    x0_.std(dim=2).flatten(1).mean(dim=1) + 1e-4)
                loss = loss + div_loss

            # total loss
            loss = loss + loss_vlb
        elif self.loss_type in ['charbonnier']:
            out = model(xt, self._scale_timesteps(t), **model_kwargs)

            # VLB for variation
            loss_vlb = 0.0
            if self.var_type in ['learned', 'learned_range']:
                out, var = out.chunk(2, dim=1)
                frozen = torch.cat([out.detach(), var], dim=1)
                loss_vlb, _ = self.variational_lower_bound(
                    x0, xt, t, model=lambda *args, **kwargs: frozen)
                if self.loss_type.startswith('rescaled_'):
                    loss_vlb = loss_vlb * self.num_timesteps / 1000.0

            # MSE/L1 for x0/eps
            target = {
                'eps': noise,
                'x0': x0,
                'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
            }[self.mean_type]
            # smooth L1: sqrt(err^2 + epsilon)
            loss = torch.sqrt((out - target)**2 + self.epsilon)
            if weight is not None:
                loss = loss * weight
            loss = loss.flatten(1).mean(dim=1)

            # total loss
            loss = loss + loss_vlb
        return loss
+
    def variational_lower_bound(self,
                                x0,
                                xt,
                                t,
                                model,
                                model_kwargs={},
                                clamp=None,
                                percentile=None):
        r"""One VLB term in bits per dim: discretized NLL at t == 0, KL
        between true and predicted posterior otherwise.

        Returns (vlb, predicted x0).
        """
        # compute groundtruth and predicted distributions
        mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
        mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                                    clamp, percentile)

        # compute KL loss (divide by log 2 -> bits)
        kl = kl_divergence(mu1, log_var1, mu2, log_var2)
        kl = kl.flatten(1).mean(dim=1) / math.log(2.0)

        # compute discretized NLL loss (for p(x0 | x1) only)
        nll = -discretized_gaussian_log_likelihood(
            x0, mean=mu2, log_scale=0.5 * log_var2)
        nll = nll.flatten(1).mean(dim=1) / math.log(2.0)

        # NLL for p(x0 | x1) and KL otherwise
        vlb = torch.where(t == 0, nll, kl)
        return vlb, x0
+
    @torch.no_grad()
    def variational_lower_bound_loop(self,
                                     x0,
                                     model,
                                     model_kwargs={},
                                     clamp=None,
                                     percentile=None):
        r"""Compute the entire variational lower bound, measured in bits-per-dim.

        Returns a dict with per-step 'vlb', 'mse', 'x0_mse' (each of shape
        (B, T)) plus 'prior_bits_per_dim' and 'total_bits_per_dim'.
        """
        # prepare input and output
        b = x0.size(0)
        metrics = {'vlb': [], 'mse': [], 'x0_mse': []}

        # loop over t = T-1, ..., 0
        for step in torch.arange(self.num_timesteps).flip(0):
            # compute VLB
            t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
            noise = torch.randn_like(x0)
            xt = self.q_sample(x0, t, noise)
            vlb, pred_x0 = self.variational_lower_bound(
                x0, xt, t, model, model_kwargs, clamp, percentile)

            # predict eps from x0
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa

            # collect metrics
            metrics['vlb'].append(vlb)
            metrics['x0_mse'].append(
                (pred_x0 - x0).square().flatten(1).mean(dim=1))
            metrics['mse'].append(
                (eps - noise).square().flatten(1).mean(dim=1))
        metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}

        # compute the prior KL term for VLB, measured in bits-per-dim
        mu, _, log_var = self.q_mean_variance(x0, t)
        kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
                                 torch.zeros_like(log_var))
        kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)

        # update metrics
        metrics['prior_bits_per_dim'] = kl_prior
        metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
        return metrics
+
    def _scale_timesteps(self, t):
        # optionally rescale discrete steps into the fixed [0, 1000) range
        # expected by some pretrained models
        if self.rescale_timesteps:  # noqa
            return t.float() * 1000.0 / self.num_timesteps
        return t
+
+
class GaussianDiffusion_style(object):
    """Variant of GaussianDiffusion whose samplers preserve the input dtype
    (results are cast back with .type(dtype)), take explicit prev/next
    timesteps for DDIM, and skip the unconditional forward pass when
    guide_scale == 1.0. Accepts no 'charbonnier' loss_type and defaults
    var_type to 'fixed_small'.
    """

    def __init__(self,
                 betas,
                 mean_type='eps',
                 var_type='fixed_small',
                 loss_type='mse',
                 rescale_timesteps=False):
        # check input
        if not isinstance(betas, torch.DoubleTensor):
            betas = torch.tensor(betas, dtype=torch.float64)
        assert min(betas) > 0 and max(betas) <= 1
        assert mean_type in ['x0', 'x_{t-1}', 'eps']
        assert var_type in [
            'learned', 'learned_range', 'fixed_large', 'fixed_small'
        ]
        assert loss_type in [
            'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1'
        ]
        self.betas = betas
        self.num_timesteps = len(betas)
        self.mean_type = mean_type
        self.var_type = var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps

        # alphas (cumulative products define alpha_bar_t and its neighbours)
        alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat(
            [alphas.new_ones([1]), self.alphas_cumprod[:-1]])
        self.alphas_cumprod_next = torch.cat(
            [self.alphas_cumprod[1:],
             alphas.new_zeros([1])])

        # q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
                                                        - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0
                                                      - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
                                                      - 1)

        # q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
            1.0 - self.alphas_cumprod)
        # clamp avoids log(0): the posterior variance is 0 at t == 0
        self.posterior_log_variance_clipped = torch.log(
            self.posterior_variance.clamp(1e-20))
        self.posterior_mean_coef1 = betas * torch.sqrt(
            self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (
            1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
                1.0 - self.alphas_cumprod)
+
    def q_sample(self, x0, t, noise=None):
        r"""Sample from q(x_t | x_0).

        The result is explicitly cast to x0's dtype (fp16-friendly).
        """
        noise = torch.randn_like(x0) if noise is None else noise
        xt = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \
            _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise  # noqa
        return xt.type_as(x0)
+
    def q_mean_variance(self, x0, t):
        r"""Distribution of q(x_t | x_0).

        Returns (mean, variance, log-variance), each broadcastable to x0.
        """
        mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
        var = _i(1.0 - self.alphas_cumprod, t, x0)
        log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
        return mu, var, log_var
+
    def q_posterior_mean_variance(self, x0, xt, t):
        r"""Distribution of q(x_{t-1} | x_t, x_0).

        Returns (mean, variance, clipped log-variance) of the true posterior.
        """
        mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
            self.posterior_mean_coef2, t, xt) * xt
        var = _i(self.posterior_variance, t, xt)
        log_var = _i(self.posterior_log_variance_clipped, t, xt)
        return mu, var, log_var
+
    @torch.no_grad()
    def p_sample(self,
                 xt,
                 t,
                 model,
                 model_kwargs={},
                 clamp=None,
                 percentile=None,
                 condition_fn=None,
                 guide_scale=None):
        r"""Sample from p(x_{t-1} | x_t).
        - condition_fn: for classifier-based guidance (guided-diffusion).
        - guide_scale: for classifier-free guidance (glide/dalle-2).

        Returns (x_{t-1}, predicted x0), both cast back to xt's dtype.
        """
        dtype = xt.dtype

        # predict distribution of p(x_{t-1} | x_t)
        mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                                    clamp, percentile,
                                                    guide_scale)

        # random sample (with optional conditional function)
        noise = torch.randn_like(xt)
        # t_mask zeroes the noise at t == 0 (the final denoising step)
        t_mask = t.ne(0).float().view(
            -1,
            *((1, ) *  # noqa
              (xt.ndim - 1)))
        if condition_fn is not None:
            # shift the mean along the classifier gradient
            grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
            mu = mu.float() + var * grad.float()
        xt_1 = mu + t_mask * torch.exp(0.5 * log_var) * noise
        return xt_1.type(dtype), x0.type(dtype)
+
    @torch.no_grad()
    def p_sample_loop(self,
                      noise,
                      model,
                      model_kwargs={},
                      clamp=None,
                      percentile=None,
                      condition_fn=None,
                      guide_scale=None):
        r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).

        Runs the full ancestral sampling chain starting from pure noise.
        """
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process: iterate t = T-1, ..., 0
        for step in torch.arange(self.num_timesteps).flip(0):
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
                                  percentile, condition_fn, guide_scale)
        return xt
+
    def p_mean_variance(self,
                        xt,
                        t,
                        model,
                        model_kwargs={},
                        clamp=None,
                        percentile=None,
                        guide_scale=None):
        r"""Distribution of p(x_{t-1} | x_t).

        Returns (mu, var, log_var, x0); mu is always recomputed from the
        (possibly clamped/thresholded) x0.
        """
        # predict distribution
        if guide_scale is None:
            out = model(xt, t=self._scale_timesteps(t), **model_kwargs)
        else:
            # classifier-free guidance
            # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            y_out = model(xt, t=self._scale_timesteps(t), **model_kwargs[0])
            # guide_scale == 1.0 reduces to the conditional output, so the
            # unconditional forward pass is skipped entirely
            if guide_scale != 1.0:
                u_out = model(
                    xt, t=self._scale_timesteps(t), **model_kwargs[1])
                dim = y_out.size(1) if self.var_type.startswith(
                    'fixed') else y_out.size(1) // 2
                out = torch.cat(
                    [
                        u_out[:, :dim] + guide_scale *  # noqa
                        (y_out[:, :dim] - u_out[:, :dim]),
                        y_out[:, dim:]
                    ],
                    dim=1)  # noqa
            else:
                out = y_out

        # compute variance
        if self.var_type == 'learned':
            out, log_var = out.chunk(2, dim=1)
            var = torch.exp(log_var)
        elif self.var_type == 'learned_range':
            # model predicts a fraction in [-1, 1] interpolating between the
            # posterior (min) and beta (max) log-variances
            out, fraction = out.chunk(2, dim=1)
            min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
            max_log_var = _i(torch.log(self.betas), t, xt)
            fraction = (fraction + 1) / 2.0
            log_var = fraction * max_log_var + (1 - fraction) * min_log_var
            var = torch.exp(log_var)
        elif self.var_type == 'fixed_large':
            var = _i(
                torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
                xt)
            log_var = torch.log(var)
        elif self.var_type == 'fixed_small':
            var = _i(self.posterior_variance, t, xt)
            log_var = _i(self.posterior_log_variance_clipped, t, xt)

        # compute mean and x0
        if self.mean_type == 'x_{t-1}':
            mu = out  # x_{t-1}
            # NOTE(review): unlike GaussianDiffusion.p_mean_variance, mu is
            # later recomputed from x0 even for this mean_type, discarding the
            # direct x_{t-1} prediction — confirm intended.
            x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \
                _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt  # noqa
        elif self.mean_type == 'x0':
            x0 = out
        elif self.mean_type == 'eps':
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out  # noqa

        # restrict the range of x0
        if percentile is not None:
            assert percentile > 0 and percentile <= 1  # e.g., 0.995
            # 5-D view: assumes video tensors of shape (n, c, f, h, w) —
            # TODO confirm against callers
            s = torch.quantile(
                x0.flatten(1).abs(), percentile,
                dim=1).clamp_(1.0).view(-1, 1, 1, 1, 1)
            # s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1) # old
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)

        # recompute mu using the restricted x0
        mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
        return mu, var, log_var, x0
+
    @torch.no_grad()
    def ddim_sample(self,
                    xt,
                    t,
                    t_prev,
                    model,
                    model_kwargs={},
                    clamp=None,
                    percentile=None,
                    condition_fn=None,
                    guide_scale=None,
                    ddim_timesteps=20,
                    eta=0.0):
        r"""Sample from p(x_{t-1} | x_t) using DDIM.
        - condition_fn: for classifier-based guidance (guided-diffusion).
        - guide_scale: for classifier-free guidance (glide/dalle-2).
        - t_prev: the previous (less noisy) timestep, passed explicitly
          instead of being derived from a fixed stride.
        """
        dtype = xt.dtype

        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)
        if condition_fn is not None:
            # x0 -> eps
            alpha = _i(self.alphas_cumprod, t, xt)
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
            eps = eps - (1 - alpha).sqrt() * condition_fn(
                xt, self._scale_timesteps(t), **model_kwargs)

            # eps -> x0
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps  # noqa

        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
            _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
        alphas = _i(self.alphas_cumprod, t, xt)
        alphas_prev = _i(self.alphas_cumprod, t_prev, xt)
        sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) *  # noqa
                                  (1 - alphas / alphas_prev))

        # random sample; t_mask disables noise at t == 0
        noise = torch.randn_like(xt)
        direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
        t_mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
        xt_1 = torch.sqrt(
            alphas_prev) * x0 + direction + t_mask * sigmas * noise
        return xt_1.type(dtype), x0.type(dtype)
+
    @torch.no_grad()
    def ddim_sample_loop(self,
                         noise,
                         model,
                         model_kwargs={},
                         clamp=None,
                         percentile=None,
                         condition_fn=None,
                         guide_scale=None,
                         ddim_timesteps=20,
                         eta=0.0):
        r"""Run the full DDIM sampling chain from pure noise, passing each
        step's successor explicitly as t_prev (0 at the final step)."""
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process: steps 1 + k * stride, clipped, high to low noise
        steps = (1 + torch.arange(0, self.num_timesteps,
                                  self.num_timesteps // ddim_timesteps)).clamp(
                                      0, self.num_timesteps - 1).flip(0)
        for i, step in enumerate(steps):
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            # the next (less noisy) step in the schedule, or 0 at the end
            t_prev = torch.full((b, ),
                                steps[i + 1] if i < len(steps) - 1 else 0,
                                dtype=torch.long,
                                device=xt.device)
            xt, _ = self.ddim_sample(xt, t, t_prev, model, model_kwargs, clamp,
                                     percentile, condition_fn, guide_scale,
                                     ddim_timesteps, eta)
        return xt
+
    @torch.no_grad()
    def ddim_reverse_sample(self,
                            xt,
                            t,
                            t_next,
                            model,
                            model_kwargs={},
                            clamp=None,
                            percentile=None,
                            guide_scale=None,
                            ddim_timesteps=20):
        r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).

        t_next is the next (noisier) timestep, passed explicitly. Outputs are
        cast back to xt's dtype.
        """
        dtype = xt.dtype

        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)

        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
            _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
        # the table is zero-padded so t_next == num_timesteps maps to
        # alpha_bar = 0
        alphas_next = _i(
            torch.cat(
                [self.alphas_cumprod,
                 self.alphas_cumprod.new_zeros([1])]), t_next, xt)

        # reverse sample
        mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
        return mu.type(dtype), x0.type(dtype)
+
+ @torch.no_grad()
+ def ddim_reverse_sample_loop(self,
+ x0,
+ model,
+ model_kwargs={},
+ clamp=None,
+ percentile=None,
+ guide_scale=None,
+ ddim_timesteps=20):
+ # prepare input
+ b = x0.size(0)
+ xt = x0
+
+ # reconstruction steps
+ steps = (1 + torch.arange(0, self.num_timesteps,
+ self.num_timesteps // ddim_timesteps)).clamp(
+ 0, self.num_timesteps - 1)
+ for i, step in enumerate(steps):
+ t = torch.full((b, ),
+ steps[i - 1] if i > 0 else 0,
+ dtype=torch.long,
+ device=xt.device)
+ t_next = torch.full((b, ),
+ step,
+ dtype=torch.long,
+ device=xt.device)
+ xt, _ = self.ddim_reverse_sample(xt, t, t_next, model,
+ model_kwargs, clamp, percentile,
+ guide_scale, ddim_timesteps)
+ return xt
+
+ @torch.no_grad()
+ def plms_sample(self,
+ xt,
+ t,
+ t_prev,
+ model,
+ model_kwargs={},
+ clamp=None,
+ percentile=None,
+ condition_fn=None,
+ guide_scale=None,
+ plms_timesteps=20):
+ r"""Sample from p(x_{t-1} | x_t) using PLMS.
+ - condition_fn: for classifier-based guidance (guided-diffusion).
+ - guide_scale: for classifier-free guidance (glide/dalle-2).
+ """
+
+ # function for compute eps
+ def compute_eps(xt, t):
+ dtype = xt.dtype
+
+ # predict distribution of p(x_{t-1} | x_t)
+ _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
+ clamp, percentile, guide_scale)
+
+ # condition
+ if condition_fn is not None:
+ # x0 -> eps
+ alpha = _i(self.alphas_cumprod, t, xt)
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa
+ eps = eps - (1 - alpha).sqrt() * condition_fn(
+ xt, self._scale_timesteps(t), **model_kwargs)
+
+ # eps -> x0
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa
+
+ # derive eps
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa
+ return eps.type(dtype)
+
+ # function for compute x_0 and x_{t-1}
+ def compute_x0(eps, t):
+ dtype = eps.dtype
+
+ # eps -> x0
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa
+
+ # deterministic sample
+ alphas_prev = _i(self.alphas_cumprod, t_prev, xt)
+ direction = torch.sqrt(1 - alphas_prev) * eps
+ xt_1 = torch.sqrt(alphas_prev) * x0 + direction
+ return xt_1.type(dtype), x0.type(dtype)
+
+ # PLMS sample
+ eps = compute_eps(xt, t)
+ if len(eps_cache) == 0:
+ # 2nd order pseudo improved Euler
+ xt_1, x0 = compute_x0(eps, t)
+ eps_next = compute_eps(xt_1, t_prev)
+ eps_prime = (eps + eps_next) / 2.0
+ elif len(eps_cache) == 1:
+ # 2nd order pseudo linear multistep (Adams-Bashforth)
+ eps_prime = (3 * eps - eps_cache[-1]) / 2.0
+ elif len(eps_cache) == 2:
+ # 3nd order pseudo linear multistep (Adams-Bashforth)
+ eps_prime = (23 * eps - 16 * eps_cache[-1]
+ + 5 * eps_cache[-2]) / 12.0
+ elif len(eps_cache) >= 3:
+ # 4nd order pseudo linear multistep (Adams-Bashforth)
+ eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
+ - 9 * eps_cache[-3]) / 24.0
+ xt_1, x0 = compute_x0(eps_prime, t)
+ return xt_1, x0, eps
+
+ @torch.no_grad()
+ def plms_sample_loop(self,
+ noise,
+ model,
+ model_kwargs={},
+ clamp=None,
+ percentile=None,
+ condition_fn=None,
+ guide_scale=None,
+ plms_timesteps=20):
+ # prepare input
+ b = noise.size(0)
+ xt = noise
+
+ # diffusion process
+ steps = (1 + torch.arange(0, self.num_timesteps,
+ self.num_timesteps // plms_timesteps)).clamp(
+ 0, self.num_timesteps - 1).flip(0)
+ eps_cache = []
+ for i, step in enumerate(steps):
+ # PLMS sampling step
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
+ t_prev = torch.full((b, ),
+ steps[i + 1] if i < len(steps) - 1 else 0,
+ dtype=torch.long,
+ device=xt.device)
+ xt, _, eps = self.plms_sample(xt, t, t_prev, model, model_kwargs,
+ clamp, percentile, condition_fn,
+ guide_scale, plms_timesteps,
+ eps_cache)
+
+ # update eps cache
+ eps_cache.append(eps)
+ if len(eps_cache) >= 4:
+ eps_cache.pop(0)
+ return xt
+
+ @torch.no_grad()
+ def dpm_solver_sample_loop(self,
+ noise,
+ model,
+ model_kwargs={},
+ order=2,
+ skip_type='logSNR',
+ method='multistep',
+ clamp=None,
+ percentile=None,
+ condition_fn=None,
+ guide_scale=None,
+ dpm_solver_timesteps=20,
+ algorithm_type='dpmsolver++',
+ t_start=None,
+ t_end=None,
+ lower_order_final=True,
+ denoise_to_zero=False,
+ solver_type='dpmsolver'):
+ r"""Sample using DPM-Solver-based method.
+ - condition_fn: for classifier-based guidance (guided-diffusion).
+ - guide_scale: for classifier-free guidance (glide/dalle-2).
+
+ Please check all the parameters in `dpm_solver.sample` before using.
+ """
+ assert self.mean_type in ('eps', 'x0')
+ assert percentile in (None, 0.995)
+ assert clamp is None or percentile is None
+ noise_schedule = NoiseScheduleVP(
+ schedule='discrete', betas=self.betas.float())
+ model_fn = model_wrapper_guided_diffusion(
+ model=model,
+ noise_schedule=noise_schedule,
+ var_type=self.var_type,
+ mean_type=self.mean_type,
+ model_kwargs=model_kwargs,
+ rescale_timesteps=self.rescale_timesteps,
+ num_timesteps=self.num_timesteps,
+ guide_scale=guide_scale,
+ condition_fn=condition_fn)
+ dpm_solver = DPM_Solver(
+ model_fn=model_fn,
+ noise_schedule=noise_schedule,
+ algorithm_type=algorithm_type,
+ percentile=percentile,
+ clamp=clamp)
+ xt = dpm_solver.sample(
+ noise,
+ steps=dpm_solver_timesteps,
+ order=order,
+ skip_type=skip_type,
+ method=method,
+ solver_type=solver_type,
+ t_start=t_start,
+ t_end=t_end,
+ lower_order_final=lower_order_final,
+ denoise_to_zero=denoise_to_zero)
+ return xt
+
+ @torch.no_grad()
+ def inpaint_p_sample(self,
+ xt,
+ t,
+ y,
+ mask,
+ model,
+ model_kwargs={},
+ clamp=None,
+ percentile=None,
+ guide_scale=None):
+ r"""DDPM sampling step for inpainting.
+ """
+ dtype = xt.dtype
+
+ # predict distribution of p(x_{t-1} | x_t), conditioned on y and mask
+ xt = self.q_sample(y, t) * mask + xt * (1 - mask)
+ mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
+ clamp, percentile,
+ guide_scale)
+
+ # random sample
+ t_mask = t.ne(0).float().view(
+ -1,
+ *((1, ) * # noqa
+ (xt.ndim - 1)))
+ xt_1 = mu + t_mask * torch.exp(0.5 * log_var) * torch.randn_like(xt)
+ return xt_1.type(dtype), x0.type(dtype)
+
+ @torch.no_grad()
+ def inpaint_p_sample_loop(self,
+ noise,
+ y,
+ mask,
+ model,
+ model_kwargs={},
+ clamp=None,
+ percentile=None,
+ guide_scale=None):
+ r"""DDPM sampling loop for inpainting.
+ """
+ # prepare input
+ b = noise.size(0)
+ xt = noise
+
+ # diffusion process
+ for step in torch.arange(self.num_timesteps).flip(0):
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
+ xt, _ = self.inpaint_p_sample(xt, t, y, mask, model, model_kwargs,
+ clamp, percentile, guide_scale)
+ return xt
+
    @torch.no_grad()
    def inpaint_mcg_p_sample(self,
                             xt,
                             t,
                             y,
                             mask,
                             model,
                             model_kwargs={},
                             clamp=None,
                             percentile=None,
                             guide_scale=None,
                             mcg_scale=1.0):
        r"""DDPM sampling step for inpainting, with Manifold Constrained Gradient (MCG) correction.

        Besides the plain inpainting step, the sample is corrected along the
        gradient of the masked reconstruction error mean((y*mask - x0*mask)^2)
        with respect to xt, scaled by ``mcg_scale``; the known region is then
        re-pasted from a noised ``y``.
        """
        dtype = xt.dtype

        # predict distribution of p(x_{t-1} | x_t), conditioned on y and mask
        # (gradients are enabled locally despite the surrounding no_grad)
        with torch.enable_grad():
            # NOTE(review): flips requires_grad on the caller's tensor in place
            xt.requires_grad_(True)
            mu, var, log_var, x0 = self.p_mean_variance(
                xt, t, model, model_kwargs, clamp, percentile, guide_scale)
            loss = (y * mask - x0 * mask).square().mean()
            grad = torch.autograd.grad(loss, xt)[0]

        # random sample (no noise is added at the final step t == 0)
        t_mask = t.ne(0).float().view(
            -1,
            *((1, ) *  # noqa
              (xt.ndim - 1)))
        xt_1 = mu + t_mask * torch.exp(0.5 * log_var) * torch.randn_like(xt)
        # MCG correction along the reconstruction-error gradient
        xt_1 = xt_1 - mcg_scale * grad

        # merge foreground and background
        xt_1 = self.q_sample(y, t) * mask + xt_1 * (1 - mask)
        return xt_1.type(dtype), x0.type(dtype)
+
+ @torch.no_grad()
+ def inpaint_mcg_p_sample_loop(self,
+ noise,
+ y,
+ mask,
+ model,
+ model_kwargs={},
+ clamp=None,
+ percentile=None,
+ guide_scale=None,
+ mcg_scale=1.0):
+ r"""DDPM sampling loop for inpainting, with Manifold Constrained Gradient (MCG) correction.
+ """
+ # prepare input
+ b = noise.size(0)
+ xt = noise
+
+ # diffusion process
+ for step in torch.arange(self.num_timesteps).flip(0):
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
+ xt, _ = self.inpaint_mcg_p_sample(xt, t, y, mask, model,
+ model_kwargs, clamp, percentile,
+ guide_scale, mcg_scale)
+ return xt
+
    def loss(self,
             x0,
             t,
             model,
             model_kwargs={},
             noise=None,
             input_x0=None,
             reduction='mean'):
        r"""Compute the per-sample training loss at timesteps ``t``.

        Depending on ``self.loss_type`` this is either the variational bound
        ('kl' / 'rescaled_kl') or an MSE/L1 regression against the target
        selected by ``self.mean_type`` ('eps', 'x0' or 'x_{t-1}'), plus a
        VLB term when the variance is learned.

        Args:
            x0: clean data the regression targets are derived from.
            t: per-sample integer timesteps.
            model: denoising network, called as model(xt, t=..., **model_kwargs).
            noise: optional pre-drawn gaussian noise; sampled here when None.
            input_x0: optional tensor to diffuse instead of ``x0``
                (targets still come from ``x0``).
            reduction: 'mean' for one scalar per sample, 'none' for an
                elementwise loss.
        """
        assert reduction in ['mean', 'none']
        noise = torch.randn_like(x0) if noise is None else noise
        input_x0 = x0 if input_x0 is None else input_x0
        xt = self.q_sample(input_x0, t, noise=noise)

        # compute loss
        if self.loss_type in ['kl', 'rescaled_kl']:
            loss, _ = self.variational_lower_bound(x0, xt, t, model,
                                                   model_kwargs)
            if self.loss_type == 'rescaled_kl':
                loss = loss * self.num_timesteps
        elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
            out = model(xt, t=self._scale_timesteps(t), **model_kwargs)

            # VLB for variation
            loss_vlb = 0.0
            if self.var_type in ['learned', 'learned_range']:
                out, var = out.chunk(2, dim=1)
                frozen = torch.cat([
                    out.detach(), var
                ], dim=1)  # learn var without affecting the prediction of mean
                loss_vlb, _ = self.variational_lower_bound(
                    x0,
                    xt,
                    t,
                    model=lambda *args, **kwargs: frozen,
                    reduction=reduction)
                if self.loss_type.startswith('rescaled_'):
                    # keep the VLB term on the original 1000-step scale
                    loss_vlb = loss_vlb * self.num_timesteps / 1000.0

            # MSE/L1 for x0/eps
            target = {
                'eps': noise,
                'x0': x0,
                'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
            }[self.mean_type]
            # .abs() makes the exponent-1 branch a proper L1 distance
            loss = (
                out
                - target).pow(1 if self.loss_type.endswith('l1') else 2).abs()
            if reduction == 'mean':
                loss = loss.flatten(1).mean(dim=1)

            # total loss
            loss = loss + loss_vlb
        return loss
+
+ def variational_lower_bound(self,
+ x0,
+ xt,
+ t,
+ model,
+ model_kwargs={},
+ clamp=None,
+ percentile=None,
+ reduction='mean'):
+ assert reduction in ['mean', 'none']
+
+ # compute groundtruth and predicted distributions
+ mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
+ mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
+ clamp, percentile)
+
+ # compute KL loss
+ kl = kl_divergence(mu1, log_var1, mu2, log_var2) / math.log(2.0)
+ if reduction == 'mean':
+ kl = kl.flatten(1).mean(dim=1)
+
+ # compute discretized NLL loss (for p(x0 | x1) only)
+ nll = -discretized_gaussian_log_likelihood(
+ x0, mean=mu2, log_scale=0.5 * log_var2) / math.log(2.0)
+ if reduction == 'mean':
+ nll = nll.flatten(1).mean(dim=1)
+
+ # NLL for p(x0 | x1) and KL otherwise
+ t = t.view(-1, *(1, ) * (nll.ndim - 1))
+ vlb = torch.where(t == 0, nll, kl)
+ return vlb, x0
+
+ @torch.no_grad()
+ def variational_lower_bound_loop(self,
+ x0,
+ model,
+ model_kwargs={},
+ clamp=None,
+ percentile=None):
+ r"""Compute the entire variational lower bound, measured in bits-per-dim.
+ """
+ # prepare input and output
+ b = x0.size(0)
+ metrics = {'vlb': [], 'mse': [], 'x0_mse': []}
+
+ # loop
+ for step in torch.arange(self.num_timesteps).flip(0):
+ # compute VLB
+ t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
+ noise = torch.randn_like(x0)
+ xt = self.q_sample(x0, t, noise)
+ vlb, pred_x0 = self.variational_lower_bound(
+ x0, xt, t, model, model_kwargs, clamp, percentile)
+
+ # predict eps from x0
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
+ _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa
+
+ # collect metrics
+ metrics['vlb'].append(vlb)
+ metrics['x0_mse'].append(
+ (pred_x0 - x0).square().flatten(1).mean(dim=1))
+ metrics['mse'].append(
+ (eps - noise).square().flatten(1).mean(dim=1))
+ metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}
+
+ # compute the prior KL term for VLB, measured in bits-per-dim
+ mu, _, log_var = self.q_mean_variance(x0, t)
+ kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
+ torch.zeros_like(log_var))
+ kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)
+
+ # update metrics
+ metrics['prior_bits_per_dim'] = kl_prior
+ metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
+ return metrics
+
+ def _scale_timesteps(self, t):
+ if self.rescale_timesteps:
+ return t.float() * 1000.0 / self.num_timesteps
+ return t
diff --git a/modelscope/models/multi_modal/videocomposer/dpm_solver.py b/modelscope/models/multi_modal/videocomposer/dpm_solver.py
new file mode 100644
index 00000000..a1e15a2f
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/dpm_solver.py
@@ -0,0 +1,1697 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+__all__ = [
+ 'NoiseScheduleVP', 'model_wrapper', 'model_wrapper_guided_diffusion',
+ 'DPM_Solver', 'interpolate_fn'
+]
+
+
+def _i(tensor, t, x):
+ r"""Index tensor using t and format the output according to x.
+ """
+ shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
+ device = x.device
+ return tensor[t.to(device)].view(shape).to(device)
+
+
class NoiseScheduleVP:
    # Wrapper around the forward VP-SDE noise schedule used by DPM-Solver.
    # Supports a discrete-time schedule (piecewise-linear interpolation of
    # log alpha via `interpolate_fn`) and continuous 'linear'/'cosine' ones.

    def __init__(self,
                 schedule='discrete',
                 betas=None,
                 alphas_cumprod=None,
                 continuous_beta_0=0.1,
                 continuous_beta_1=20.,
                 dtype=torch.float32):
        r"""Create a wrapper class for the forward SDE (VP type).

        ***
        Update: We support discrete-time diffusion models by implementing a picewise
        linear interpolation for log_alpha_t.
        We recommend to use schedule='discrete' for the discrete-time diffusion models,
        especially for high-resolution images.
        ***

        The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t),
        which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)

        Moreover, as lambda(t) is an invertible function, we also support its inverse function:

            t = self.inverse_lambda(lambda_t)

        ===============================================================

        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1)
        and continuous-time DPMs (trained on t in [t_0, T]).

        1. For discrete-time DPMs:

            For discrete-time DPMs trained on n = 0, 1, ..., N-1,
            we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.

            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM.
                    (See the original DDPM paper for details)
                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM.
                    (See the original DDPM paper for details)

            Note that we always have alphas_cumprod = cumprod(1 - betas).
            Therefore, we only need to set one of `betas` and `alphas_cumprod`.

            **Important**: Please pay special attention for the args for `alphas_cumprod`:
                The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically,
                DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver.
                In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).

        2. For continuous-time DPMs:

            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM).
            The hyperparameters for the noise
            schedule are the default settings in DDPM and improved-DDPM:

            Args:
                beta_min: A `float` number. The smallest beta for the linear schedule.
                beta_max: A `float` number. The largest beta for the linear schedule.
                cosine_s: A `float` number. The hyperparameter in the cosine schedule.
                cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
                T: A `float` number. The ending time of the forward process.

        ===============================================================

        Args:
            schedule: A `str`. The noise schedule of the forward SDE.
                'discrete' for discrete-time DPMs,
                'linear' or 'cosine' for continuous-time DPMs.
        Returns:
            A wrapper object of the forward SDE (VP type).

        ===============================================================

        Example:

        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', betas=betas)

        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n}
        # array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

        # For continuous-time DPMs (VPSDE), linear schedule:
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)

        """

        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'"
                .format(schedule))

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                # log(alpha_{t_n}) = 0.5 * log(cumprod(1 - beta))
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.
            # discrete step i is mapped to continuous time (i + 1) / N
            self.t_array = torch.linspace(0., 1.,
                                          self.total_N + 1)[1:].reshape(
                                              (1, -1)).to(dtype=dtype)
            self.log_alpha_array = log_alphas.reshape((
                1,
                -1,
            )).to(dtype=dtype)
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            # constants of the improved-DDPM cosine schedule
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(
                self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
                    1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(
                math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            self.schedule = schedule
            if schedule == 'cosine':
                # For the cosine schedule,
                # T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            return interpolate_fn(
                t.reshape((-1, 1)), self.t_array.to(t.device),
                self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t**2 * (self.beta_1
                                   - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            log_alpha_fn = lambda s: torch.log(  # noqa
                torch.cos((s + self.cosine_s) /  # noqa
                          (1. + self.cosine_s) * math.pi / 2.))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        # sigma_t = sqrt(1 - alpha_t^2)
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(
                -2. * lamb,
                torch.zeros((1, )).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (
                self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            # invert the interpolation table by flipping it (lambda decreases
            # as t increases)
            log_alpha = -0.5 * torch.logaddexp(
                torch.zeros((1, )).to(lamb.device), -2. * lamb)
            t = interpolate_fn(
                log_alpha.reshape((-1, 1)),
                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1, ))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb,
                                               torch.zeros((1, )).to(lamb))
            t_fn = lambda log_alpha_t: torch.arccos(  # noqa
                torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
                    1. + self.cosine_s) / math.pi - self.cosine_s
            t = t_fn(log_alpha)
            return t
+
+
def model_wrapper(model,
                  noise_schedule,
                  model_type='noise',
                  model_kwargs={},
                  guidance_type='uncond',
                  condition=None,
                  unconditional_condition=None,
                  guidance_scale=1.,
                  classifier_fn=None,
                  classifier_kwargs={}):
    r"""Create a wrapper function for the noise prediction model.

    DPM-Solver needs to solve the continuous-time diffusion ODEs.
    For DPMs trained on discrete-time labels, we need to
    firstly wrap the model function to a noise prediction model
    that accepts the continuous time as the input.

    We support four types of the diffusion model by setting `model_type`:

    1. "noise": noise prediction model. (Trained by predicting noise).

    2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).

    3. "v": velocity prediction model. (Trained by predicting the velocity).
        The "v" prediction is derivation detailed in Appendix D of [1],
        and is used in Imagen-Video [2].

        [1] Salimans, Tim, and Jonathan Ho.
        "Progressive distillation for fast sampling of diffusion models."
        arXiv preprint arXiv:2202.00512 (2022).
        [2] Ho, Jonathan, et al. "Imagen Video: High Definition
        Video Generation with Diffusion Models."
        arXiv preprint arXiv:2210.02303 (2022).

    4. "score": marginal score function. (Trained by denoising score matching).
        Note that the score function and the noise prediction model follows a simple relationship:
        ```
            noise(x_t, t) = -sigma_t * score(x_t, t)
        ```

    We support three types of guided sampling by DPMs by setting `guidance_type`:
    1. "uncond": unconditional sampling by DPMs.
        The input `model` has the following format:
        ``
            model(x, t_input, **model_kwargs) -> noise | x_start | v | score
        ``

    2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
        The input `model` has the following format:
        ``
            model(x, t_input, **model_kwargs) -> noise | x_start | v | score
        ``

        The input `classifier_fn` has the following format:
        ``
            classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
        ``

        [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
        in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.

    3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
        The input `model` has the following format:
        ``
            model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
        ``
        And if cond == `unconditional_condition`, the model output is the unconditional DPM output.

        [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
        arXiv preprint arXiv:2207.12598 (2022).


    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
    or continuous-time labels (i.e. epsilon to T).

    We wrap the model function to accept only `x` and `t_continuous` as inputs,
    and outputs the predicted noise:
    ``
        def model_fn(x, t_continuous) -> noise:
            t_input = get_model_input_time(t_continuous)
            return noise_pred(model, x, t_input, **model_kwargs)
    ``
    where `t_continuous` is the continuous time labels (i.e. epsilon to T).
    And we use `model_fn` for DPM-Solver.

    ===============================================================

    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
            "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict`. A dict for the other inputs of the model function.
        guidance_type: A `str`. The type of the guidance for sampling.
            "uncond" or "classifier" or "classifier-free".
        condition: A pytorch tensor. The condition for the guided sampling.
            Only used for "classifier" or "classifier-free" guidance type.
        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
            Only used for "classifier-free" guidance type.
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
    Returns:
        A noise prediction model that accepts the noised data and the continuous time as the inputs.
    """

    # BUGFIX: 'score' was documented and implemented in noise_pred_fn below
    # but rejected by this assertion, so it could never be used. Validation
    # is also hoisted before the closures are defined.
    assert model_type in ['noise', 'x_start', 'v', 'score']
    assert guidance_type in ['uncond', 'classifier', 'classifier-free']

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        # run the raw model and convert its output to a noise prediction
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == 'noise':
            return output
        elif model_type == 'x_start':
            alpha_t, sigma_t = noise_schedule.marginal_alpha(
                t_continuous), noise_schedule.marginal_std(t_continuous)
            return (x - alpha_t * output) / sigma_t
        elif model_type == 'v':
            alpha_t, sigma_t = noise_schedule.marginal_alpha(
                t_continuous), noise_schedule.marginal_std(t_continuous)
            return alpha_t * output + sigma_t * x
        elif model_type == 'score':
            # noise(x_t, t) = -sigma_t * score(x_t, t)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return -sigma_t * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition,
                                     **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise predicition model function that is used for DPM-Solver.
        """
        if guidance_type == 'uncond':
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == 'classifier':
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * sigma_t * cond_grad
        elif guidance_type == 'classifier-free':
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # run conditional and unconditional branches in one batch
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(
                    x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
+
+
+def model_wrapper_guided_diffusion(model,
+ noise_schedule,
+ var_type,
+ mean_type,
+ model_kwargs={},
+ rescale_timesteps=False,
+ num_timesteps=1000,
+ guide_scale=None,
+ condition_fn=None):
+ """Create a wrapper function for the noise prediction model guided diffusion.
+
+ DPM-Solver needs to solve the continuous-time diffusion ODEs.
+ For DPMs trained on discrete-time labels, we need to
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+
+ We support three types of the diffusion model by setting `mean_type`:
+
+ 1. "x_{t-1}": previous data prediction model.
+ (Trained by predicting the data x_{t-1} at time t-1).
+
+ 2. "x0": data prediction model. (Trained by predicting the data x0 at time 0).
+
+ 3. "eps": noise prediction model. (Trained by predicting the noise).
+
+ We support three types of guided sampling by DPMs:
+ 1. unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(xt, t_input, **model_kwargs) -> x_{t-1} | x0 | eps
+ ``
+
+ 2. classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(xt, t_input, **model_kwargs) -> x_{t-1} | x0 | eps
+ ``
+
+ The input `condition_fn` has the following format:
+ ``
+ condition_fn(xt, t_input, **model_kwargs) -> logits(xt, t_input)
+ ``
+
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+
+ 3. classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(xt, t_input, **model_kwargs) -> x_{t-1} | x0 | eps
+ ``
+
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+
+ The `t_input` is the time label of the model, which may be discrete-time labels
+ or continuous-time labels (i.e. epsilon to T).
+
+ We wrap the model function to accept only `x` and `t_continuous` as inputs,
+ and outputs the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred_fn(model, x, t_input, **model_kwargs)
+ ``
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T).
+ And we use `model_fn` for DPM-Solver.
+
+ ===============================================================
+
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ var_type: A `str`. The variance type of the diffusion model.
+ "learned" or "learned_range" or "fixed_large" or "fixed_small".
+ mean_type: A `str`. The prediction type of the diffusion model.
+ "x_{t-1}" or "x0" or "eps".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ clamp: A `float`. The range to clamp the data.
+ rescale_timesteps: A `bool`. Whether to rescale the timesteps.
+ num_timesteps: An `int`. The number of the total diffusion steps.
+ guide_scale: A `float`. The strength of the classifier-free guidance.
+ condition_fn: A function. The function of the classifier guidance.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
+ """
+
+ def _scale_timesteps(t):
+ if rescale_timesteps:
+ return t.float() * 1000.0 / num_timesteps
+ return t
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
+ if noise_schedule.schedule == 'discrete':
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+ else:
+ return t_continuous
+
+ def noise_pred_fn(xt, t_continuous, **model_kwargs):
+ if t_continuous.reshape((-1, )).shape[0] == 1:
+ t_continuous = t_continuous.expand((xt.shape[0]))
+ t_input = get_model_input_time(_scale_timesteps(t_continuous))
+ # predict distribution
+ out = model(xt, t_input, **model_kwargs)
+
+ if var_type == 'learned':
+ out, _ = out.chunk(2, dim=1)
+ elif var_type == 'learned_range':
+ out, _ = out.chunk(2, dim=1)
+
+ if mean_type == 'eps':
+ eps = out
+ elif mean_type == 'x_{t-1}':
+ raise NotImplementedError
+ assert noise_schedule.schedule == 'discrete'
+ mu = out
+ posterior_mean_coef1 = None
+ posterior_mean_coef2 = None
+ x0 = expand_dims(
+ 1. / posterior_mean_coef1, dims) * mu - expand_dims(
+ posterior_mean_coef2 / posterior_mean_coef1, dims) * xt
+ eps = (xt - expand_dims(alpha_t, dims) * x0) / expand_dims(
+ sigma_t, dims)
+ elif mean_type == 'x0':
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(
+ t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = out.dim()
+ eps = (xt - expand_dims(alpha_t, dims) * out) / expand_dims(
+ sigma_t, dims)
+
+ if condition_fn is not None:
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(
+ t_continuous), noise_schedule.marginal_std(t_continuous)
+ eps = eps - (1 - alpha_t).sqrt() * condition_fn(
+ xt, t_input, **model_kwargs)
+
+ return eps
+
+ def model_fn(x, t_continuous):
+ """
+ The noise predicition model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1, )).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+
+ if guide_scale is None:
+ eps = noise_pred_fn(x, t_continuous, **model_kwargs)
+ else:
+ # classifier-free guidance
+ # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
+ assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
+ y_out = noise_pred_fn(x, t_continuous, **model_kwargs[0])
+ u_out = noise_pred_fn(x, t_continuous, **model_kwargs[1])
+ dim = y_out.size(1) if var_type.startswith(
+ 'fixed') else y_out.size(1) // 2
+ eps = u_out[:, :dim] + guide_scale * (
+ y_out[:, :dim] - u_out[:, :dim])
+ return eps
+
+ return model_fn
+
+
+class DPM_Solver:
+
+ def __init__(self,
+ model_fn,
+ noise_schedule,
+ algorithm_type='dpmsolver++',
+ percentile=None,
+ thresholding_max_val=1.,
+ clamp=None):
+ r"""Construct a DPM-Solver.
+
+ We support both the noise prediction model ("predicting epsilon")
+ and the data prediction model ("predicting x0").
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+ In such case, we further support the "dynamic thresholding"
+ in [1] when `thresholding` is True.
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs
+ with large guidance scales.
+
+ Args:
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+ ``
+ def model_fn(x, t_continuous):
+ return noise
+ ``
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+ max_val: A `float`. Valid when both `predict_x0` and
+ `thresholding` are True. The max value for thresholding.
+
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang,
+ Emily Denton, Seyed Kamyar Seyed Ghasemipour,
+ Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al.
+ Photorealistic text-to-image diffusion models with deep language understanding.
+ arXiv preprint arXiv:2205.11487, 2022b.
+ """
+ self.model = model_fn
+ self.noise_schedule = noise_schedule
+ self.algorithm_type = algorithm_type
+ self.percentile = percentile
+ self.thresholding_max_val = thresholding_max_val
+ self.clamp = clamp
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
+ t), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(
+ alpha_t, dims)
+ if self.percentile is not None:
+ s = torch.quantile(
+ torch.abs(x0).reshape((x0.shape[0], -1)),
+ self.percentile,
+ dim=1)
+ s = expand_dims(
+ torch.maximum(
+ s,
+ self.thresholding_max_val
+ * torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ elif self.clamp is not None:
+ x0 = x0.clamp(-clamp, clamp)
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.algorithm_type == 'dpmsolver++':
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling.
+
+ Args:
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ N: A `int`. The total number of the spacing of the time steps.
+ device: A torch device.
+ Returns:
+ A pytorch tensor of the time steps, with the shape (N + 1,).
+ """
+ if skip_type == 'logSNR':
+ lambda_T = self.noise_schedule.marginal_lambda(
+ torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(
+ torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(),
+ lambda_0.cpu().item(),
+ N + 1).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == 'time_uniform':
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == 'time_quadratic':
+ t_order = 2
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order),
+ N + 1).pow(t_order).to(device)
+ return t
+ else:
+ raise ValueError(
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'"
+ .format(skip_type))
+
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order,
+ skip_type, t_T, t_0,
+ device):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+ - If order == 1:
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
+ - If order == 2:
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If order == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3,
+ and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+
+ ============================================
+ Args:
+ order: A `int`. The max order for the solver (2 or 3).
+ steps: A `int`. The total number of function evaluations (NFE).
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ device: A torch device.
+ Returns:
+ orders: A list of the solver order of each step.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [
+ 3,
+ ] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [
+ 3,
+ ] * (K - 1) + [1]
+ else:
+ orders = [
+ 3,
+ ] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [
+ 2,
+ ] * K
+ else:
+ K = steps // 2 + 1
+ orders = [
+ 2,
+ ] * (K - 1) + [1]
+ elif order == 1:
+ K = 1
+ orders = [
+ 1,
+ ] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
+ if skip_type == 'logSNR':
+ # To reproduce the results in DPM-Solver paper
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K,
+ device)
+ else:
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps,
+ device)[torch.cumsum(
+ torch.tensor([
+ 0,
+ ] + orders),
+ 0).to(device)]
+ return timesteps_outer, orders
+
    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solve
        the ODE from lambda_s to infty by first-order discretization.

        Returns the predicted clean data x0 at time `s` (with the same
        thresholding/clamping behavior as `data_prediction_fn`).
        """
        return self.data_prediction_fn(x, s)
+
+ def dpm_solver_first_update(self,
+ x,
+ s,
+ t,
+ model_s=None,
+ return_intermediate=False):
+ """
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (1,).
+ t: A pytorch tensor. The ending time, with the shape (1,).
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(
+ s), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ if self.algorithm_type == 'dpmsolver++':
+ phi_1 = torch.expm1(-h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (sigma_t / sigma_s * x - alpha_t * phi_1 * model_s)
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+ else:
+ phi_1 = torch.expm1(h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ torch.exp(log_alpha_t - log_alpha_s) * x # noqa
+ - (sigma_t * phi_1) * model_s)
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+
    def singlestep_dpm_solver_second_update(self,
                                            x,
                                            s,
                                            t,
                                            r1=0.5,
                                            model_s=None,
                                            return_intermediate=False,
                                            solver_type='dpmsolver'):
        """
        Singlestep solver DPM-Solver-2 from time `s` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            r1: A `float`. The hyperparameter of the second-order solver.
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            return_intermediate: A `bool`.
                If true, also return the model value at time `s` and `s1` (the intermediate time).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ['dpmsolver', 'taylor']:
            raise ValueError(
                "'solver_type' must be either 'dpmsolver' or 'taylor', got {}".
                format(solver_type))
        if r1 is None:
            r1 = 0.5
        ns = self.noise_schedule
        # Work in lambda (half-logSNR) space: h is the full step size and s1
        # is the intermediate time at ratio r1 along the step.
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        lambda_s1 = lambda_s + r1 * h
        s1 = ns.inverse_lambda(lambda_s1)
        log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(
            s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
        sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(
            s1), ns.marginal_std(t)
        alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)

        if self.algorithm_type == 'dpmsolver++':
            # Data-prediction (x0) formulation; phi_* are exponential
            # integrator coefficients built from expm1(-h).
            phi_11 = torch.expm1(-r1 * h)
            phi_1 = torch.expm1(-h)

            if model_s is None:
                model_s = self.model_fn(x, s)
            # First (Euler) stage to the intermediate time s1.
            x_s1 = ((sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s)
            model_s1 = self.model_fn(x_s1, s1)
            if solver_type == 'dpmsolver':
                x_t = ((sigma_t / sigma_s) * x - # noqa
                       (alpha_t * phi_1) * model_s # noqa
                       - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s))
            elif solver_type == 'taylor':
                x_t = ((sigma_t / sigma_s) * x - # noqa
                       (alpha_t * phi_1) * model_s # noqa
                       + (1. / r1) * (
                           alpha_t * # noqa
                           (phi_1 / h + 1.)) * (model_s1 - model_s))
        else:
            # Noise-prediction (eps) formulation.
            phi_11 = torch.expm1(r1 * h)
            phi_1 = torch.expm1(h)

            if model_s is None:
                model_s = self.model_fn(x, s)
            # First (Euler) stage to the intermediate time s1.
            x_s1 = (
                torch.exp(log_alpha_s1 - log_alpha_s) * x # noqa
                - (sigma_s1 * phi_11) * model_s)
            model_s1 = self.model_fn(x_s1, s1)
            if solver_type == 'dpmsolver':
                x_t = (
                    torch.exp(log_alpha_t - log_alpha_s) * x # noqa
                    - (sigma_t * phi_1) * model_s - (0.5 / r1) # noqa
                    * (sigma_t * phi_1) * (model_s1 - model_s))
            elif solver_type == 'taylor':
                x_t = (
                    torch.exp(log_alpha_t - log_alpha_s) * x # noqa
                    - (sigma_t * phi_1) * model_s - (1. / r1) # noqa
                    * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s))
        if return_intermediate:
            return x_t, {'model_s': model_s, 'model_s1': model_s1}
        else:
            return x_t
+
    def singlestep_dpm_solver_third_update(self,
                                           x,
                                           s,
                                           t,
                                           r1=1. / 3.,
                                           r2=2. / 3.,
                                           model_s=None,
                                           model_s1=None,
                                           return_intermediate=False,
                                           solver_type='dpmsolver'):
        """
        Singlestep solver DPM-Solver-3 from time `s` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            r1: A `float`. The hyperparameter of the third-order solver.
            r2: A `float`. The hyperparameter of the third-order solver.
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
                If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
            return_intermediate: A `bool`. If true,
                also return the model value at time `s`, `s1` and `s2` (the intermediate times).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ['dpmsolver', 'taylor']:
            raise ValueError(
                "'solver_type' must be either 'dpmsolver' or 'taylor', got {}".
                format(solver_type))
        if r1 is None:
            r1 = 1. / 3.
        if r2 is None:
            r2 = 2. / 3.
        ns = self.noise_schedule
        # Work in lambda (half-logSNR) space: s1 and s2 are the two
        # intermediate times at ratios r1 and r2 along the full step h.
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        lambda_s1 = lambda_s + r1 * h
        lambda_s2 = lambda_s + r2 * h
        s1 = ns.inverse_lambda(lambda_s1)
        s2 = ns.inverse_lambda(lambda_s2)
        log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
            s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(
                s2), ns.marginal_log_mean_coeff(t)
        sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(
            s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
        alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(
            log_alpha_s2), torch.exp(log_alpha_t)

        if self.algorithm_type == 'dpmsolver++':
            # Data-prediction (x0) formulation; phi_* are exponential
            # integrator coefficients built from expm1(-h).
            phi_11 = torch.expm1(-r1 * h)
            phi_12 = torch.expm1(-r2 * h)
            phi_1 = torch.expm1(-h)
            phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
            phi_2 = phi_1 / h + 1.
            phi_3 = phi_2 / h - 0.5

            if model_s is None:
                model_s = self.model_fn(x, s)
            if model_s1 is None:
                # First stage to the intermediate time s1.
                x_s1 = ((sigma_s1 / sigma_s) * x # noqa
                        - (alpha_s1 * phi_11) * model_s)
                model_s1 = self.model_fn(x_s1, s1)
            # Second stage to the intermediate time s2.
            x_s2 = ((sigma_s2 / sigma_s) * x - # noqa
                    (alpha_s2 * phi_12) * model_s + r2 / r1 # noqa
                    * (alpha_s2 * phi_22) * (model_s1 - model_s))
            model_s2 = self.model_fn(x_s2, s2)
            if solver_type == 'dpmsolver':
                x_t = ((sigma_t / sigma_s) * x - # noqa
                       (alpha_t * phi_1) * model_s # noqa
                       + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s))
            elif solver_type == 'taylor':
                # D1_*, D1, D2: first- and second-order finite differences
                # of the model outputs over the intermediate stages.
                D1_0 = (1. / r1) * (model_s1 - model_s)
                D1_1 = (1. / r2) * (model_s2 - model_s)
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
                x_t = ((sigma_t / sigma_s) * x - # noqa
                       (alpha_t * phi_1) * model_s # noqa
                       + (alpha_t * phi_2) * D1 - (alpha_t * phi_3) * D2)
        else:
            # Noise-prediction (eps) formulation.
            phi_11 = torch.expm1(r1 * h)
            phi_12 = torch.expm1(r2 * h)
            phi_1 = torch.expm1(h)
            phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
            phi_2 = phi_1 / h - 1.
            phi_3 = phi_2 / h - 0.5

            if model_s is None:
                model_s = self.model_fn(x, s)
            if model_s1 is None:
                # First stage to the intermediate time s1.
                x_s1 = ((torch.exp(log_alpha_s1 - log_alpha_s)) * x # noqa
                        - (sigma_s1 * phi_11) * model_s)
                model_s1 = self.model_fn(x_s1, s1)
            # Second stage to the intermediate time s2.
            x_s2 = ((torch.exp(log_alpha_s2 - log_alpha_s)) * x # noqa
                    - (sigma_s2 * phi_12) * model_s - r2 / r1 # noqa
                    * (sigma_s2 * phi_22) * (model_s1 - model_s))
            model_s2 = self.model_fn(x_s2, s2)
            if solver_type == 'dpmsolver':
                x_t = ((torch.exp(log_alpha_t - log_alpha_s)) * x # noqa
                       - (sigma_t * phi_1) * model_s - (1. / r2) # noqa
                       * (sigma_t * phi_2) * (model_s2 - model_s))
            elif solver_type == 'taylor':
                # D1_*, D1, D2: first- and second-order finite differences
                # of the model outputs over the intermediate stages.
                D1_0 = (1. / r1) * (model_s1 - model_s)
                D1_1 = (1. / r2) * (model_s2 - model_s)
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
                x_t = ((torch.exp(log_alpha_t - log_alpha_s)) * x # noqa
                       - (sigma_t * phi_1) * model_s - # noqa
                       (sigma_t * phi_2) * D1 # noqa
                       - (sigma_t * phi_3) * D2)

        if return_intermediate:
            return x_t, {
                'model_s': model_s,
                'model_s1': model_s1,
                'model_s2': model_s2
            }
        else:
            return x_t
+
    def multistep_dpm_solver_second_update(self,
                                           x,
                                           model_prev_list,
                                           t_prev_list,
                                           t,
                                           solver_type='dpmsolver'):
        """
        Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            model_prev_list: A list of pytorch tensor. The previous computed model values.
            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
            t: A pytorch tensor. The ending time, with the shape (1,).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ['dpmsolver', 'taylor']:
            raise ValueError(
                "'solver_type' must be either 'dpmsolver' or 'taylor', got {}".
                format(solver_type))
        ns = self.noise_schedule
        # Reuse the two most recent model evaluations and times.
        model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
        t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
        lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(
            t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
            t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        # h_0: previous step size, h: current step size (in lambda space);
        # D1_0 is the first-order backward difference of the model outputs.
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0 = h_0 / h
        D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
        if self.algorithm_type == 'dpmsolver++':
            phi_1 = torch.expm1(-h)
            if solver_type == 'dpmsolver':
                x_t = ((sigma_t / sigma_prev_0) * x # noqa
                       - (alpha_t * phi_1) * model_prev_0 - 0.5 # noqa
                       * (alpha_t * phi_1) * D1_0)
            elif solver_type == 'taylor':
                x_t = ((sigma_t / sigma_prev_0) * x - # noqa
                       (alpha_t * phi_1) * model_prev_0 + # noqa
                       (alpha_t * (phi_1 / h + 1.)) * D1_0)
        else:
            phi_1 = torch.expm1(h)
            if solver_type == 'dpmsolver':
                x_t = ((torch.exp(log_alpha_t - log_alpha_prev_0)) * x # noqa
                       - (sigma_t * phi_1) * model_prev_0 - 0.5 * # noqa
                       (sigma_t * phi_1) * D1_0)
            elif solver_type == 'taylor':
                x_t = ((torch.exp(log_alpha_t - log_alpha_prev_0)) * x # noqa
                       - (sigma_t * phi_1) * model_prev_0 # noqa
                       - (sigma_t * (phi_1 / h - 1.)) * D1_0)
        return x_t
+
    def multistep_dpm_solver_third_update(self,
                                          x,
                                          model_prev_list,
                                          t_prev_list,
                                          t,
                                          solver_type='dpmsolver'):
        """
        Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            model_prev_list: A list of pytorch tensor. The previous computed model values.
            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
            t: A pytorch tensor. The ending time, with the shape (1,).
            solver_type: either 'dpmsolver' or 'taylor'. Accepted for API
                symmetry with the second-order update; the third-order
                multistep formula below does not depend on it.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        ns = self.noise_schedule
        # Reuse the three most recent model evaluations and times.
        model_prev_2, model_prev_1, model_prev_0 = model_prev_list
        t_prev_2, t_prev_1, t_prev_0 = t_prev_list
        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(
            t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
                t_prev_0), ns.marginal_lambda(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
            t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        # h_1, h_0: the two previous step sizes; h: current step size
        # (in lambda space). D1, D2 approximate the first and second
        # derivatives of the model output from backward differences.
        h_1 = lambda_prev_1 - lambda_prev_2
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0, r1 = h_0 / h, h_1 / h
        D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
        D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
        if self.algorithm_type == 'dpmsolver++':
            phi_1 = torch.expm1(-h)
            phi_2 = phi_1 / h + 1.
            phi_3 = phi_2 / h - 0.5
            x_t = ((sigma_t / sigma_prev_0) * x - # noqa
                   (alpha_t * phi_1) * model_prev_0 + # noqa
                   (alpha_t * phi_2) * D1 # noqa
                   - (alpha_t * phi_3) * D2)
        else:
            phi_1 = torch.expm1(h)
            phi_2 = phi_1 / h - 1.
            phi_3 = phi_2 / h - 0.5
            x_t = ((torch.exp(log_alpha_t - log_alpha_prev_0)) * x # noqa
                   - (sigma_t * phi_1) * model_prev_0 - # noqa
                   (sigma_t * phi_2) * D1 # noqa
                   - (sigma_t * phi_3) * D2)
        return x_t
+
+ def singlestep_dpm_solver_update(self,
+ x,
+ s,
+ t,
+ order,
+ return_intermediate=False,
+ solver_type='dpmsolver',
+ r1=None,
+ r2=None):
+ """
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (1,).
+ t: A pytorch tensor. The ending time, with the shape (1,).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`,
+ `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(
+ x, s, t, return_intermediate=return_intermediate)
+ elif order == 2:
+ return self.singlestep_dpm_solver_second_update(
+ x,
+ s,
+ t,
+ return_intermediate=return_intermediate,
+ solver_type=solver_type,
+ r1=r1)
+ elif order == 3:
+ return self.singlestep_dpm_solver_third_update(
+ x,
+ s,
+ t,
+ return_intermediate=return_intermediate,
+ solver_type=solver_type,
+ r1=r1,
+ r2=r2)
+ else:
+ raise ValueError(
+ 'Solver order must be 1 or 2 or 3, got {}'.format(order))
+
+ def multistep_dpm_solver_update(self,
+ x,
+ model_prev_list,
+ t_prev_list,
+ t,
+ order,
+ solver_type='dpmsolver'):
+ """
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
+ t: A pytorch tensor. The ending time, with the shape (1,).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(
+ x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+ elif order == 2:
+ return self.multistep_dpm_solver_second_update(
+ x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ elif order == 3:
+ return self.multistep_dpm_solver_third_update(
+ x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ else:
+ raise ValueError(
+ 'Solver order must be 1 or 2 or 3, got {}'.format(order))
+
    def dpm_solver_adaptive(self,
                            x,
                            order,
                            t_T,
                            t_0,
                            h_init=0.05,
                            atol=0.0078,
                            rtol=0.05,
                            theta=0.9,
                            t_err=1e-5,
                            solver_type='dpmsolver'):
        """
        The adaptive step size solver based on singlestep DPM-Solver.

        Args:
            x: A pytorch tensor. The initial value at time `t_T`.
            order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            h_init: A `float`. The initial step size (for logSNR).
            atol: A `float`. The absolute tolerance of the solver.
                For image data, the default setting is 0.0078, followed [1].
            rtol: A `float`. The relative tolerance of the solver.
                The default setting is 0.05.
            theta: A `float`. The safety hyperparameter for adapting the step size.
                The default setting is 0.9, followed [1].
            t_err: A `float`. The tolerance for the time.
                We solve the diffusion ODE until the absolute error between the
                current time and `t_0` is less than `t_err`. The default setting is 1e-5.
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
        Returns:
            x_0: A pytorch tensor. The approximated solution at time `t_0`.

        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas,
        "Gotta go fast when generating data with score-based models,"
        arXiv preprint arXiv:2105.14080, 2021.
        """
        ns = self.noise_schedule
        s = t_T * torch.ones((1, )).to(x)
        lambda_s = ns.marginal_lambda(s)
        lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
        # Step size h is controlled in lambda (half-logSNR) space.
        h = h_init * torch.ones_like(s).to(x)
        x_prev = x
        nfe = 0
        # Pair a lower-order and a higher-order update; their difference is
        # used as the local error estimate (embedded pair).
        if order == 2:
            r1 = 0.5
            lower_update = lambda x, s, t: self.dpm_solver_first_update( # noqa
                x, s, t, return_intermediate=True)
            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update( # noqa
                x, s, t, r1=r1, solver_type=solver_type, **kwargs)
        elif order == 3:
            r1, r2 = 1. / 3., 2. / 3.
            lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update( # noqa
                x,
                s,
                t,
                r1=r1,
                return_intermediate=True,
                solver_type=solver_type)
            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update( # noqa
                x,
                s,
                t,
                r1=r1,
                r2=r2,
                solver_type=solver_type,
                **kwargs)
        else:
            raise ValueError(
                'For adaptive step size solver, order must be 2 or 3, got {}'.
                format(order))
        while torch.abs((s - t_0)).mean() > t_err:
            t = ns.inverse_lambda(lambda_s + h)
            # The higher-order update reuses the intermediate model values
            # computed by the lower-order update.
            x_lower, lower_noise_kwargs = lower_update(x, s, t)
            x_higher = higher_update(x, s, t, **lower_noise_kwargs)
            # Mixed absolute/relative tolerance, as in [1].
            delta = torch.max(
                torch.ones_like(x).to(x) * atol,
                rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
            norm_fn = lambda v: torch.sqrt( # noqa
                torch.square(v.reshape(
                    (v.shape[0], -1))).mean(dim=-1, keepdim=True))
            E = norm_fn((x_higher - x_lower) / delta).max()
            if torch.all(E <= 1.):
                # Accept the step only when the error estimate is within
                # tolerance; otherwise retry from the same `s` with smaller h.
                x = x_higher
                s = t
                x_prev = x_lower
                lambda_s = ns.marginal_lambda(s)
            h = torch.min(
                theta * h * torch.float_power(E, -1. / order).float(),
                lambda_0 - lambda_s)
            nfe += order
            # NOTE(review): this debug print runs on every iteration of the
            # loop (it is inside the while body).
            print('adaptive solver nfe', nfe)
        return x
+
    def add_noise(self, x, t, noise=None):
        """
        Compute the noised input xt = alpha_t * x + sigma_t * noise.

        Args:
            x: A `torch.Tensor` with shape `(batch_size, *shape)`.
            t: A `torch.Tensor` with shape `(t_size,)`.
            noise: An optional `torch.Tensor`; when None, standard Gaussian
                noise of shape `(t_size, *x.shape)` is sampled.
        Returns:
            xt with shape `(t_size, batch_size, *shape)`.
            When `t_size == 1` the leading dimension is squeezed away.
        """
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
            t), self.noise_schedule.marginal_std(t)
        if noise is None:
            noise = torch.randn((t.shape[0], *x.shape), device=x.device)
        # Prepend a size-1 dim so the (t_size,)-shaped coefficients and the
        # noise broadcast over x.
        x = x.reshape((-1, *x.shape))
        xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t,
                                                             x.dim()) * noise
        if t.shape[0] == 1:
            return xt.squeeze(0)
        else:
            return xt
+
+ def inverse(self,
+ x,
+ steps=20,
+ t_start=None,
+ t_end=None,
+ order=2,
+ skip_type='time_uniform',
+ method='multistep',
+ lower_order_final=True,
+ denoise_to_zero=False,
+ solver_type='dpmsolver',
+ atol=0.0078,
+ rtol=0.05,
+ return_intermediate=False):
+ """
+ Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver.
+ For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
+ """
+ t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start
+ t_T = self.noise_schedule.T if t_end is None else t_end
+ assert t_0 > 0 and t_T > 0, 'Time range needs to be greater than 0. For discrete-time DPMs, \
+ it needs to be in [1 / N, 1], where N is the length of betas array'
+
+ return self.sample(
+ x,
+ steps=steps,
+ t_start=t_0,
+ t_end=t_T,
+ order=order,
+ skip_type=skip_type,
+ method=method,
+ lower_order_final=lower_order_final,
+ denoise_to_zero=denoise_to_zero,
+ solver_type=solver_type,
+ atol=atol,
+ rtol=rtol,
+ return_intermediate=return_intermediate)
+
    def sample(self,
               x,
               steps=20,
               t_start=None,
               t_end=None,
               order=2,
               skip_type='time_uniform',
               method='multistep',
               lower_order_final=True,
               denoise_to_zero=False,
               solver_type='dpmsolver',
               atol=0.0078,
               rtol=0.05):
        """
        Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.

        =====================================================

        We support the following algorithms for both noise prediction model and data prediction model:
        - 'singlestep':
            Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper),
            which combines different orders of singlestep DPM-Solver.
            We combine all the singlestep solvers with order <=
            `order` to use up all the function evaluations (steps).
            The total number of function evaluations (NFE) == `steps`.
            Given a fixed NFE == `steps`, the sampling procedure is:
                - If `order` == 1:
                    - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
                - If `order` == 2:
                    - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
                    - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
                    - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                - If `order` == 3:
                    - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
                    - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3,
                        and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                    - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
                    - If steps % 3 == 2, we use (K - 1)
                        steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
        - 'multistep':
            Multistep DPM-Solver with the order of `order`.
            The total number of function evaluations (NFE) == `steps`.
            We initialize the first `order` values by lower order multistep solvers.
            Given a fixed NFE == `steps`, the sampling procedure is:
                Denote K = steps.
                - If `order` == 1:
                    - We use K steps of DPM-Solver-1 (i.e. DDIM).
                - If `order` == 2:
                    - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
                - If `order` == 3:
                    - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2,
                        then (K - 2) step of multistep DPM-Solver-3.
        - 'singlestep_fixed':
            Fixed order singlestep DPM-Solver
            (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
            We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3,
            with total [`steps` // `order`] * `order` NFE.
        - 'adaptive':
            Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
            We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
            You can adjust the absolute tolerance `atol` and the relative tolerance
            `rtol` to balance the computation costs
            (NFE) and the sample quality.
                - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1
                    and singlestep DPM-Solver-2.
                - If `order` == 3, we use DPM-Solver-23 which combines singlestep
                    DPM-Solver-2 and singlestep DPM-Solver-3.

        =====================================================

        Some advice for choosing the algorithm:
            - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
                Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
                e.g.
                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
                        skip_type='time_uniform', method='singlestep')
            - For **guided sampling with large guidance scale** by DPMs:
                Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
                e.g.
                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
                        skip_type='time_uniform', method='multistep')

        We support three types of `skip_type`:
            - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
            - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
            - 'time_quadratic': quadratic time for the time steps.

        =====================================================
        Args:
            x: A pytorch tensor. The initial value at time `t_start`
                e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
            steps: A `int`. The total number of function evaluations (NFE).
            t_start: A `float`. The starting time of the sampling.
                If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
            t_end: A `float`. The ending time of the sampling.
                If `t_end` is None, we use 1. / self.noise_schedule.total_N.
                e.g. if total_N == 1000, we have `t_end` == 1e-3.
                For discrete-time DPMs:
                    - We recommend `t_end` == 1. / self.noise_schedule.total_N.
                For continuous-time DPMs:
                    - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
            order: A `int`. The order of DPM-Solver.
            skip_type: A `str`. The type for the spacing of the time steps.
                'time_uniform' or 'logSNR' or 'time_quadratic'.
            method: A `str`. The method for sampling.
                'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).

                This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
                score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
                for diffusion models sampling by diffusion SDEs for low-resolutional images
                (such as CIFAR-10). However, we observed that such trick does not matter for
                high-resolutional images. As it needs an additional NFE, we do not recommend
                it for high-resolutional images.
            lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
                Only valid for `method=multistep` and `steps < 15`. We empirically find that
                this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
                (especially for steps <= 10). So we recommend to set it to be `True`.
            solver_type: A `str`. The taylor expansion type for the solver.
                `dpmsolver` or `taylor`. We recommend `dpmsolver`.
            atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
            rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
        Returns:
            x_end: A pytorch tensor. The approximated solution at time `t_end`.

        """
        # Sampling integrates backwards in time: t_T (noise level) down to t_0.
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert t_0 > 0 and t_T > 0, 'Time range needs to be greater than 0. For discrete-time DPMs, \
            it needs to be in [1 / N, 1], where N is the length of betas array'

        device = x.device
        with torch.no_grad():
            if method == 'adaptive':
                # Adaptive step size solver ignores `steps`; NFE is governed by atol/rtol.
                x = self.dpm_solver_adaptive(
                    x,
                    order=order,
                    t_T=t_T,
                    t_0=t_0,
                    atol=atol,
                    rtol=rtol,
                    solver_type=solver_type)
            elif method == 'multistep':
                assert steps >= order
                timesteps = self.get_time_steps(
                    skip_type=skip_type,
                    t_T=t_T,
                    t_0=t_0,
                    N=steps,
                    device=device)
                assert timesteps.shape[0] - 1 == steps
                # Init the initial values.
                step = 0
                t = timesteps[step]
                t_prev_list = [t]
                model_prev_list = [self.model_fn(x, t)]
                # Init the first `order` values by lower order multistep DPM-Solver.
                # (the k-th order update needs k previous model evaluations)
                for step in range(1, order):
                    t = timesteps[step]
                    x = self.multistep_dpm_solver_update(
                        x,
                        model_prev_list,
                        t_prev_list,
                        t,
                        step,
                        solver_type=solver_type)
                    t_prev_list.append(t)
                    model_prev_list.append(self.model_fn(x, t))
                # Compute the remaining values by `order`-th order multistep DPM-Solver.
                for step in range(order, steps + 1):
                    t = timesteps[step]
                    # We only use lower order for steps < 10
                    if lower_order_final and steps < 10:
                        step_order = min(order, steps + 1 - step)
                    else:
                        step_order = order
                    x = self.multistep_dpm_solver_update(
                        x,
                        model_prev_list,
                        t_prev_list,
                        t,
                        step_order,
                        solver_type=solver_type)
                    # Shift the history window: drop the oldest (t, model) pair.
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        model_prev_list[i] = model_prev_list[i + 1]
                    t_prev_list[-1] = t
                    # We do not need to evaluate the final model value.
                    if step < steps:
                        model_prev_list[-1] = self.model_fn(x, t)
            elif method in ['singlestep', 'singlestep_fixed']:
                if method == 'singlestep':
                    timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(
                        steps=steps,
                        order=order,
                        skip_type=skip_type,
                        t_T=t_T,
                        t_0=t_0,
                        device=device)
                elif method == 'singlestep_fixed':
                    K = steps // order
                    orders = [
                        order,
                    ] * K
                    timesteps_outer = self.get_time_steps(
                        skip_type=skip_type,
                        t_T=t_T,
                        t_0=t_0,
                        N=K,
                        device=device)
                for step, order in enumerate(orders):
                    s, t = timesteps_outer[step], timesteps_outer[step + 1]
                    # Inner time steps place the intermediate evaluation points
                    # of the single-step update within [t, s].
                    timesteps_inner = self.get_time_steps(
                        skip_type=skip_type,
                        t_T=s.item(),
                        t_0=t.item(),
                        N=order,
                        device=device)
                    lambda_inner = self.noise_schedule.marginal_lambda(
                        timesteps_inner)
                    h = lambda_inner[-1] - lambda_inner[0]
                    # r1/r2 are the relative positions (in logSNR) of the
                    # intermediate points; undefined for orders that don't use them.
                    r1 = None if order <= 1 else (lambda_inner[1]
                                                  - lambda_inner[0]) / h
                    r2 = None if order <= 2 else (lambda_inner[2]
                                                  - lambda_inner[0]) / h
                    x = self.singlestep_dpm_solver_update(
                        x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)

            else:
                raise ValueError('Got wrong method {}'.format(method))
            if denoise_to_zero:
                # One extra model evaluation to denoise from t_0 to (approximately) 0.
                t = torch.ones((1, )).to(device) * t_0
                x = self.denoise_to_zero_fn(x, t)
        return x
+
+
def interpolate_fn(x, xp, yp):
    """
    A differentiable piecewise-linear function y = f(x) with keypoints (xp, yp).

    Outside the range of `xp`, the outermost segment is linearly extrapolated,
    so f is well-defined on the whole x-axis. Implemented with sort/gather so
    autograd can flow through it.

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size,
            C is the number of channels (C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_kp = x.shape[0], xp.shape[1]
    # Insert each query among its keypoints and sort to find where it lands.
    merged = torch.cat(
        [x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    sorted_merged, sort_order = torch.sort(merged, dim=2)
    # Rank of the query value within the sorted row.
    rank = torch.argmin(sort_order, dim=2)
    lower = rank - 1
    # Clamp the segment start so boundary queries extrapolate the outer segment.
    seg_start = torch.where(
        torch.eq(rank, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(rank, num_kp),
            torch.tensor(num_kp - 2, device=x.device),
            lower,
        ),
    )
    seg_end = torch.where(
        torch.eq(seg_start, lower), seg_start + 2, seg_start + 1)
    x_lo = torch.gather(
        sorted_merged, dim=2, index=seg_start.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(
        sorted_merged, dim=2, index=seg_end.unsqueeze(2)).squeeze(2)
    # Index of the segment in the (unsorted) keypoint arrays.
    y_idx = torch.where(
        torch.eq(rank, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(rank, num_kp),
            torch.tensor(num_kp - 2, device=x.device),
            lower,
        ),
    )
    yp_expanded = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(
        yp_expanded, dim=2, index=y_idx.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(
        yp_expanded, dim=2, index=(y_idx + 1).unsqueeze(2)).squeeze(2)
    # Linear interpolation on the selected segment.
    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
+
+
def expand_dims(v, dims):
    """
    Append trailing singleton axes to `v` until it has `dims` dimensions.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: a `int`, the desired total number of dimensions.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and total dimension `dims`.
    """
    out = v
    for _ in range(dims - 1):
        out = out[..., None]
    return out
diff --git a/modelscope/models/multi_modal/videocomposer/mha_flash.py b/modelscope/models/multi_modal/videocomposer/mha_flash.py
new file mode 100644
index 00000000..009efd44
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/mha_flash.py
@@ -0,0 +1,120 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import math
+import os
+import random
+import time
+
+import numpy as np
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+import torch.nn.functional as F
+from flash_attn.flash_attention import FlashAttention
+
+
class FlashAttentionBlock(nn.Module):
    """Spatial (self-/cross-)attention block backed by FlashAttention.

    Attends over the H*W positions of a feature map `x` of shape [B, C, H, W];
    when `context` ([B, L, C_ctx]) is given, its projected key/value pairs are
    prepended so spatial queries can also attend to the context. The result is
    added back to the input (residual connection).
    """

    def __init__(self,
                 dim,
                 context_dim=None,
                 num_heads=None,
                 head_dim=None,
                 batch_size=4):
        """
        Args:
            dim: channel dimension of the input feature map.
            context_dim: channel dimension of the optional context; enables
                the context key/value projection when not None.
            num_heads: number of attention heads (used only when head_dim is falsy).
            head_dim: per-head dimension; takes priority over num_heads.
            batch_size: unused; kept for interface compatibility.
        """
        # consider head_dim first, then num_heads
        num_heads = dim // head_dim if head_dim else num_heads
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(FlashAttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # NOTE(review): `scale` is computed but never applied in forward();
        # FlashAttention's default softmax_scale is used instead — confirm intended.
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # FlashAttention supports head_dim <= 128 in multiples of 8; otherwise
        # `self.flash_attn` is left undefined and forward() will raise.
        if self.head_dim <= 128 and (self.head_dim % 8) == 0:
            self.flash_attn = FlashAttention(
                softmax_scale=None, attention_dropout=0.0)

        # zero out the last layer params so the block starts as an identity mapping
        nn.init.zeros_(self.proj.weight)

    def _init_weight(self, module):
        # Optional manual initializer for sub-modules (not applied automatically).
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.15)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Conv2d):
            module.weight.data.normal_(mean=0.0, std=0.15)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x, context=None):
        r"""x: [B, C, H, W].
        context: [B, L, C] or None.
        NOTE(review): the zero-query padding below has a hard-coded length of 4,
        which only lines up when the context has exactly 4 tokens — confirm callers.
        """
        identity = x
        b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        x = self.norm(x)
        q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
        if context is not None:
            ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
                                                      d).permute(0, 2, 3,
                                                                 1).chunk(
                                                                     2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)
            # pad queries with zeros so q/k/v share the same token length
            cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device)
            q = torch.cat([q, cq], dim=-1)

        qkv = torch.cat([q, k, v], dim=1)
        origin_dtype = qkv.dtype
        qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n,
                                              d).half().contiguous()
        out, _ = self.flash_attn(qkv)
        # BUG FIX: Tensor.to() is not in-place; the original `out.to(origin_dtype)`
        # discarded the result, leaving `out` in fp16 regardless of input dtype.
        out = out.to(origin_dtype)

        if context is not None:
            out = out[:, :-4, :, :]  # drop the padded zero-query positions
        out = out.permute(0, 2, 3, 1).reshape(b, c, h, w)

        # output
        x = self.proj(out)
        return x + identity
+
+
if __name__ == '__main__':
    # Micro-benchmark for FlashAttentionBlock. Requires a CUDA device and the
    # `flash_attn` package: warms up for 5 iterations, then times 10 forward
    # passes under autocast and prints the average latency per pass.
    batch_size = 8
    flash_net = FlashAttentionBlock(
        dim=1280,
        context_dim=512,
        num_heads=None,
        head_dim=64,
        batch_size=batch_size).cuda()

    x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda()
    context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda()
    # context = None
    flash_net.eval()

    with amp.autocast(enabled=True):
        # warm up
        for i in range(5):
            y = flash_net(x, context)
        torch.cuda.synchronize()  # ensure warm-up kernels finish before timing
        s1 = time.time()
        for i in range(10):
            y = flash_net(x, context)
        torch.cuda.synchronize()  # wait for all timed kernels to complete
        s2 = time.time()

        print(f'Average cost time {(s2-s1)*1000/10} ms')
diff --git a/modelscope/models/multi_modal/videocomposer/models/__init__.py b/modelscope/models/multi_modal/videocomposer/models/__init__.py
new file mode 100644
index 00000000..13e46148
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/models/__init__.py
@@ -0,0 +1,2 @@
+from .clip import *
+from .midas import *
diff --git a/modelscope/models/multi_modal/videocomposer/models/clip.py b/modelscope/models/multi_modal/videocomposer/models/clip.py
new file mode 100644
index 00000000..3ce2d818
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/models/clip.py
@@ -0,0 +1,460 @@
+import math
+import os.path as osp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import modelscope.models.multi_modal.videocomposer.ops as ops
+
+__all__ = [
+ 'CLIP', 'clip_vit_b_32', 'clip_vit_b_16', 'clip_vit_l_14',
+ 'clip_vit_l_14_336px', 'clip_vit_h_16'
+]
+
+
def DOWNLOAD_TO_CACHE(oss_key,
                      file_or_dirname=None,
                      cache_dir=osp.join(
                          '/'.join(osp.abspath(__file__).split('/')[:-2]),
                          'model_weights')):
    r"""Resolve the local cache path for an OSS key.

    NOTE(review): despite the name, nothing is downloaded here — the body only
    joins `cache_dir` with the basename of `oss_key` (or `file_or_dirname` when
    given) and returns that path. Presumably the original rank-0 download +
    barrier logic was stripped for the ModelScope port; confirm the weights are
    staged at this path by the pipeline before use.
    """
    # source and target paths
    base_path = osp.join(cache_dir, file_or_dirname or osp.basename(oss_key))

    return base_path
+
+
def to_fp16(m):
    """Cast a module's parameters to fp16 in place (for use with `Module.apply`).

    Handles Linear/Conv2d weights and biases, plus modules exposing a raw
    `head` parameter (e.g. the projection heads in this file).
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        weight = m.weight
        weight.data = weight.data.half()
        bias = m.bias
        if bias is not None:
            bias.data = bias.data.half()
    elif hasattr(m, 'head'):
        head = getattr(m, 'head')
        head.data = head.data.half()
+
+
class QuickGELU(nn.Module):
    """Fast GELU approximation used by CLIP: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
+
+
class LayerNorm(nn.LayerNorm):
    r"""LayerNorm variant that normalizes in fp32 and casts the result back to
    the input dtype, so it is safe to use inside fp16 models.
    """

    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type_as(x)
+
+
class SelfAttention(nn.Module):
    """Multi-head self-attention with optional additive mask, fp32 softmax."""

    def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
        assert dim % num_heads == 0
        super(SelfAttention, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        # layers (creation order preserved for reproducible initialization)
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_dropout = nn.Dropout(proj_dropout)

    def forward(self, x, mask=None):
        r"""x: [B, L, C].
        mask: [*, L, L]; zero entries are masked out.
        """
        batch, seq_len = x.size(0), x.size(1)
        heads = self.num_heads

        # project to query/key/value; fold heads into the batch dimension
        q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1)
        q = q.reshape(seq_len, batch * heads, -1).transpose(0, 1)
        k = k.reshape(seq_len, batch * heads, -1).transpose(0, 1)
        v = v.reshape(seq_len, batch * heads, -1).transpose(0, 1)

        # scaled dot-product attention (softmax computed in fp32 for stability)
        attn = self.scale * torch.bmm(q, k.transpose(1, 2))
        if mask is not None:
            attn = attn.masked_fill(mask[:, :seq_len, :seq_len] == 0,
                                    float('-inf'))
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        attn = self.attn_dropout(attn)

        # weighted sum of values, then merge the heads back
        out = torch.bmm(attn, v)
        out = out.view(batch, heads, seq_len,
                       -1).transpose(1, 2).reshape(batch, seq_len, -1)

        # output projection
        out = self.proj(out)
        return self.proj_dropout(out)
+
+
class AttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention then a 4x-expansion MLP,
    each wrapped in a residual connection."""

    def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0):
        super(AttentionBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads

        # layers (creation order preserved for reproducible initialization)
        self.norm1 = LayerNorm(dim)
        self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout)
        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim),
            nn.Dropout(proj_dropout))

    def forward(self, x, mask=None):
        out = x + self.attn(self.norm1(x), mask)
        out = out + self.mlp(self.norm2(out))
        return out
+
+
class VisionTransformer(nn.Module):
    r"""CLIP image encoder: a ViT with a learned class token whose final
    representation is projected into the joint text-image embedding space.
    """

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0):
        """
        Args:
            image_size: input resolution; must be divisible by patch_size.
            patch_size: side length of each square patch.
            dim: transformer width.
            out_dim: dimension of the joint embedding space.
            num_heads / num_layers: transformer depth configuration.
            attn_dropout / proj_dropout / embedding_dropout: dropout rates.
        """
        assert image_size % patch_size == 0
        super(VisionTransformer, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.dim = dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_patches = (image_size // patch_size)**2

        # embeddings; gain keeps the initial token scale ~1/sqrt(dim)
        gain = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(
            3, dim, kernel_size=patch_size, stride=patch_size, bias=False)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(
            gain * torch.randn(1, self.num_patches + 1, dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim)
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
            for _ in range(num_layers)
        ])
        self.post_norm = LayerNorm(dim)

        # head: raw projection matrix applied to the class token
        self.head = nn.Parameter(gain * torch.randn(dim, out_dim))

    def forward(self, x):
        """x: [B, 3, image_size, image_size] -> [B, out_dim] embeddings."""
        # follow the dtype of the head so the module works after .fp16()
        b, dtype = x.size(0), self.head.dtype
        x = x.type(dtype)

        # patch-embedding
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)  # [b, n, c]
        x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x],
                      dim=1)
        x = self.dropout(x + self.pos_embedding.type(dtype))
        x = self.pre_norm(x)

        # transformer
        x = self.transformer(x)

        # head: project the class-token representation only
        x = self.post_norm(x)
        x = torch.mm(x[:, 0, :], self.head)
        return x

    def fp16(self):
        # Cast Linear/Conv2d params and the head to half precision in place.
        return self.apply(to_fp16)
+
+
class TextTransformer(nn.Module):
    r"""CLIP text encoder: causal transformer over token ids, projecting the
    end-of-text token's representation into the joint embedding space.
    """

    def __init__(self,
                 vocab_size,
                 text_len,
                 dim=512,
                 out_dim=512,
                 num_heads=8,
                 num_layers=12,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0):
        """
        Args:
            vocab_size: tokenizer vocabulary size.
            text_len: fixed (padded) sequence length.
            dim / out_dim / num_heads / num_layers: transformer configuration.
            attn_dropout / proj_dropout / embedding_dropout: dropout rates.
        """
        super(TextTransformer, self).__init__()
        self.vocab_size = vocab_size
        self.text_len = text_len
        self.dim = dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer (ModuleList: blocks are called one by one to pass the mask)
        self.transformer = nn.ModuleList([
            AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
            for _ in range(num_layers)
        ])
        self.norm = LayerNorm(dim)

        # head
        gain = 1.0 / math.sqrt(dim)
        self.head = nn.Parameter(gain * torch.randn(dim, out_dim))

        # causal attention mask (lower-triangular: each token attends to itself
        # and to earlier positions only)
        self.register_buffer('attn_mask',
                             torch.tril(torch.ones(1, text_len, text_len)))

    def forward(self, x):
        """x: [B, text_len] of token ids -> [B, out_dim] embeddings."""
        # argmax picks the position of the largest token id per sequence;
        # assumes the end-of-text token has the highest id in the vocabulary
        # (standard CLIP tokenizer convention) — confirm with the tokenizer used.
        eot, dtype = x.argmax(dim=-1), self.head.dtype

        # embeddings
        x = self.dropout(
            self.token_embedding(x).type(dtype)
            + self.pos_embedding.type(dtype))

        # transformer
        for block in self.transformer:
            x = block(x, self.attn_mask)

        # head: project the end-of-text token's representation
        x = self.norm(x)
        x = torch.mm(x[torch.arange(x.size(0)), eot], self.head)
        return x

    def fp16(self):
        # Cast Linear/Conv2d params and the head to half precision in place.
        return self.apply(to_fp16)
+
+
class CLIP(nn.Module):
    """CLIP: paired image and text encoders trained with a contrastive loss.

    `forward` returns image-to-text / text-to-image logits (scaled by a learned
    temperature) together with the contrastive target labels for this rank.
    """

    def __init__(self,
                 embed_dim=512,
                 image_size=224,
                 patch_size=16,
                 vision_dim=768,
                 vision_heads=12,
                 vision_layers=12,
                 vocab_size=49408,
                 text_len=77,
                 text_dim=512,
                 text_heads=8,
                 text_layers=12,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0):
        super(CLIP, self).__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vocab_size = vocab_size
        self.text_len = text_len
        self.text_dim = text_dim
        self.text_heads = text_heads
        self.text_layers = text_layers

        # models
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout)
        self.textual = TextTransformer(
            vocab_size=vocab_size,
            text_len=text_len,
            dim=text_dim,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout)
        # learned log-temperature, initialized to log(1/0.07) as in the CLIP paper
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_tokens):
        r"""imgs: [B, C, H, W] of torch.float32.
        txt_tokens: [B, T] of torch.long.
        Returns (logits_i2t, logits_t2i, labels) for the contrastive loss.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_tokens)

        # normalize features
        xi = F.normalize(xi, p=2, dim=1)
        xt = F.normalize(xt, p=2, dim=1)

        # gather features from all ranks (differentiable all-gather)
        full_xi = ops.diff_all_gather(xi)
        full_xt = ops.diff_all_gather(xt)

        # logits against the globally-gathered features
        scale = self.log_scale.exp()
        logits_i2t = scale * torch.mm(xi, full_xt.t())
        logits_t2i = scale * torch.mm(xt, full_xi.t())

        # labels: each local sample matches its own global index
        labels = torch.arange(
            len(xi) * ops.get_rank(),
            len(xi) * (ops.get_rank() + 1),
            dtype=torch.long,
            device=xi.device)
        return logits_i2t, logits_t2i, labels

    def init_weights(self):
        # embeddings
        nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
        # BUG FIX: keyword was misspelled `tsd=0.1`, which raised a TypeError.
        nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)

        # attentions
        for modality in ['visual', 'textual']:
            # BUG FIX: the text branch previously bound `dim` to the string
            # 'textual', crashing in math.sqrt below.
            dim = self.vision_dim if modality == 'visual' else self.text_dim
            num_layers = (
                self.vision_layers
                if modality == 'visual' else self.text_layers)
            transformer = getattr(self, modality).transformer
            proj_gain = (1.0 / math.sqrt(dim)) * (
                1.0 / math.sqrt(2 * num_layers))
            attn_gain = 1.0 / math.sqrt(dim)
            mlp_gain = 1.0 / math.sqrt(2.0 * dim)
            # BUG FIX: iterate the container directly — nn.Sequential/ModuleList
            # have no `.layers` attribute.
            for block in transformer:
                nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
                nn.init.normal_(block.attn.proj.weight, std=proj_gain)
                nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
                nn.init.normal_(block.mlp[2].weight, std=proj_gain)

    def param_groups(self):
        # No weight decay on norm parameters and biases; decay on the rest.
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay':
            0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups

    def fp16(self):
        # Cast Linear/Conv2d params and the projection heads to half precision.
        return self.apply(to_fp16)
+
+
def _clip(name, pretrained=False, **kwargs):
    """Build a CLIP model from config kwargs; optionally load cached weights."""
    model = CLIP(**kwargs)
    if pretrained:
        weights_path = DOWNLOAD_TO_CACHE(f'models/clip/{name}.pth')
        state = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(state)
    return model
+
+
def clip_vit_b_32(pretrained=False, **kwargs):
    """CLIP ViT-B/32 configuration (overridable via kwargs)."""
    cfg = {
        'embed_dim': 512,
        'image_size': 224,
        'patch_size': 32,
        'vision_dim': 768,
        'vision_heads': 12,
        'vision_layers': 12,
        'vocab_size': 49408,
        'text_len': 77,
        'text_dim': 512,
        'text_heads': 8,
        'text_layers': 12,
    }
    cfg.update(**kwargs)
    return _clip('openai-clip-vit-base-32', pretrained, **cfg)
+
+
def clip_vit_b_16(pretrained=False, **kwargs):
    """CLIP ViT-B/16 configuration (overridable via kwargs)."""
    cfg = dict(
        embed_dim=512,
        image_size=224,
        # BUG FIX: was 32 (copy-pasted from the B/32 config); ViT-B/16 uses
        # 16x16 patches — with 32 the pretrained B/16 weights cannot load.
        patch_size=16,
        vision_dim=768,
        vision_heads=12,
        vision_layers=12,
        vocab_size=49408,
        text_len=77,
        text_dim=512,
        text_heads=8,
        text_layers=12)
    cfg.update(**kwargs)
    return _clip('openai-clip-vit-base-16', pretrained, **cfg)
+
+
def clip_vit_l_14(pretrained=False, **kwargs):
    """CLIP ViT-L/14 configuration (overridable via kwargs)."""
    cfg = {
        'embed_dim': 768,
        'image_size': 224,
        'patch_size': 14,
        'vision_dim': 1024,
        'vision_heads': 16,
        'vision_layers': 24,
        'vocab_size': 49408,
        'text_len': 77,
        'text_dim': 768,
        'text_heads': 12,
        'text_layers': 12,
    }
    cfg.update(**kwargs)
    return _clip('openai-clip-vit-large-14', pretrained, **cfg)
+
+
def clip_vit_l_14_336px(pretrained=False, **kwargs):
    """CLIP ViT-L/14 at 336px input resolution (overridable via kwargs)."""
    cfg = {
        'embed_dim': 768,
        'image_size': 336,
        'patch_size': 14,
        'vision_dim': 1024,
        'vision_heads': 16,
        'vision_layers': 24,
        'vocab_size': 49408,
        'text_len': 77,
        'text_dim': 768,
        'text_heads': 12,
        'text_layers': 12,
    }
    cfg.update(**kwargs)
    return _clip('openai-clip-vit-large-14-336px', pretrained, **cfg)
+
+
def clip_vit_h_16(pretrained=False, **kwargs):
    """CLIP ViT-H/16 configuration; no pretrained weights are published."""
    assert not pretrained, 'pretrained model for openai-clip-vit-huge-16 is not available!'
    cfg = {
        'embed_dim': 1024,
        'image_size': 256,
        'patch_size': 16,
        'vision_dim': 1280,
        'vision_heads': 16,
        'vision_layers': 32,
        'vocab_size': 49408,
        'text_len': 77,
        'text_dim': 1024,
        'text_heads': 16,
        'text_layers': 24,
    }
    cfg.update(**kwargs)
    return _clip('openai-clip-vit-huge-16', pretrained, **cfg)
diff --git a/modelscope/models/multi_modal/videocomposer/models/midas.py b/modelscope/models/multi_modal/videocomposer/models/midas.py
new file mode 100644
index 00000000..4992e721
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/models/midas.py
@@ -0,0 +1,320 @@
+r"""A much cleaner re-implementation of ``https://github.com/isl-org/MiDaS''.
+ Image augmentation: T.Compose([
+ Resize(
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ interpolation=cv2.INTER_CUBIC),
+ T.ToTensor(),
+ T.Normalize(
+ mean=[0.5, 0.5, 0.5],
+ std=[0.5, 0.5, 0.5])]).
+ Fast inference:
+ model = model.to(memory_format=torch.channels_last).half()
+ input = input.to(memory_format=torch.channels_last).half()
+ output = model(input)
+"""
+import math
+import os
+import os.path as osp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ['MiDaS', 'midas_v3']
+
+
class SelfAttention(nn.Module):
    """Multi-head self-attention over token sequences (einsum formulation),
    with the softmax computed in fp32 for numerical stability."""

    def __init__(self, dim, num_heads):
        assert dim % num_heads == 0
        super(SelfAttention, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        # layers (creation order preserved for reproducible initialization)
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """x: [B, L, C] -> [B, L, C]."""
        batch, seq_len, channels = x.size()
        heads, per_head = self.num_heads, self.head_dim

        # project to per-head query/key/value
        q, k, v = self.to_qkv(x).view(batch, seq_len, heads * 3,
                                      per_head).chunk(3, dim=2)

        # scaled dot-product attention
        attn = self.scale * torch.einsum('binc,bjnc->bnij', q, k)
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)

        # weighted sum of values, then merge the heads
        out = torch.einsum('bnij,bjnc->binc', attn, v)
        out = out.reshape(batch, seq_len, channels)

        # output projection
        return self.proj(out)
+
+
class AttentionBlock(nn.Module):
    """Pre-norm transformer encoder block: self-attention followed by a
    4x-expansion GELU MLP, each with a residual connection."""

    def __init__(self, dim, num_heads):
        super(AttentionBlock, self).__init__()
        self.dim = dim
        self.num_heads = num_heads

        # layers (creation order preserved for reproducible initialization)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = SelfAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x):
        out = x + self.attn(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out
+
+
class VisionTransformer(nn.Module):
    r"""Plain ViT classifier (ViT-L/16-style defaults).

    NOTE(review): MiDaS below re-implements its own embedding/blocks rather
    than composing this class; presumably kept for weight-name compatibility.
    """

    def __init__(self,
                 image_size=384,
                 patch_size=16,
                 dim=1024,
                 out_dim=1000,
                 num_heads=16,
                 num_layers=24):
        """
        Args:
            image_size: input resolution; must be divisible by patch_size.
            patch_size: side length of each square patch.
            dim: transformer width.
            out_dim: number of output classes / features.
            num_heads / num_layers: transformer depth configuration.
        """
        assert image_size % patch_size == 0
        super(VisionTransformer, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.dim = dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_patches = (image_size // patch_size)**2

        # embeddings: zero-init class token, small-normal position embedding
        self.patch_embedding = nn.Conv2d(
            3, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_embedding = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embedding = nn.Parameter(
            torch.empty(1, self.num_patches + 1, dim).normal_(std=0.02))

        # blocks
        self.blocks = nn.Sequential(
            *[AttentionBlock(dim, num_heads) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(dim)

        # head
        self.head = nn.Linear(dim, out_dim)

    def forward(self, x):
        """x: [B, 3, image_size, image_size] -> [B, num_patches + 1, out_dim].

        Note: the head is applied to every token (no class-token selection here).
        """
        b = x.size(0)

        # embeddings
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        x = torch.cat([self.cls_embedding.repeat(b, 1, 1), x], dim=1)
        x = x + self.pos_embedding

        # blocks
        x = self.blocks(x)
        x = self.norm(x)

        # head
        x = self.head(x)
        return x
+
+
class ResidualBlock(nn.Module):
    """Two-convolution residual unit used by the MiDaS fusion decoder."""

    def __init__(self, dim):
        super(ResidualBlock, self).__init__()
        self.dim = dim

        # layers; the leading ReLU is out-of-place so the block's input
        # tensor (shared with the skip connection) is never mutated
        self.residual = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, 3, padding=1))

    def forward(self, x):
        return self.residual(x) + x
+
+
class FusionBlock(nn.Module):
    """MiDaS decoder stage: (optionally) fuse a skip feature into the incoming
    feature, refine with residual blocks, then upsample 2x."""

    def __init__(self, dim):
        super(FusionBlock, self).__init__()
        self.dim = dim

        # layers (creation order preserved for reproducible initialization)
        self.layer1 = ResidualBlock(dim)
        self.layer2 = ResidualBlock(dim)
        self.conv_out = nn.Conv2d(dim, dim, 1)

    def forward(self, *xs):
        assert len(xs) in (1, 2), 'invalid number of inputs'
        # With two inputs, the second (skip feature) is refined and added first.
        merged = xs[0] if len(xs) == 1 else xs[0] + self.layer1(xs[1])
        out = self.layer2(merged)
        out = F.interpolate(
            out, scale_factor=2, mode='bilinear', align_corners=True)
        return self.conv_out(out)
+
+
class MiDaS(nn.Module):
    r"""MiDaS v3.0 DPT-Large from ``https://github.com/isl-org/MiDaS''.
    Monocular depth estimation using dense prediction transformers.

    A ViT backbone whose token features are tapped at four depths, reassembled
    into feature maps at 4x/8x/16x/32x strides, fused top-down by FusionBlocks,
    and decoded to a single-channel inverse-depth map at input resolution.
    """

    def __init__(self,
                 image_size=384,
                 patch_size=16,
                 dim=1024,
                 neck_dims=[256, 512, 1024, 1024],
                 fusion_dim=256,
                 num_heads=16,
                 num_layers=24):
        """
        Args:
            image_size: training resolution (position-embedding grid size);
                forward() accepts any size divisible by patch_size.
            patch_size: ViT patch side length.
            dim: ViT width.
            neck_dims: per-stage channel widths of the reassembly necks.
            fusion_dim: channel width of the fusion decoder.
            num_heads / num_layers: ViT configuration; num_layers must split
                evenly into the 4 tap stages.
        """
        assert image_size % patch_size == 0
        assert num_layers % 4 == 0
        super(MiDaS, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.dim = dim
        self.neck_dims = neck_dims
        self.fusion_dim = fusion_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_patches = (image_size // patch_size)**2

        # embeddings
        self.patch_embedding = nn.Conv2d(
            3, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_embedding = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embedding = nn.Parameter(
            torch.empty(1, self.num_patches + 1, dim).normal_(std=0.02))

        # blocks, partitioned into 4 equal slices (one tap per slice)
        stride = num_layers // 4
        self.blocks = nn.Sequential(
            *[AttentionBlock(dim, num_heads) for _ in range(num_layers)])
        self.slices = [slice(i * stride, (i + 1) * stride) for i in range(4)]

        # stage1 (4x): fc fuses token + class-token readout, then upsample 4x
        self.fc1 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim, neck_dims[0], 1),
            nn.ConvTranspose2d(neck_dims[0], neck_dims[0], 4, stride=4),
            nn.Conv2d(neck_dims[0], fusion_dim, 3, padding=1, bias=False))
        self.fusion1 = FusionBlock(fusion_dim)

        # stage2 (8x)
        self.fc2 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(dim, neck_dims[1], 1),
            nn.ConvTranspose2d(neck_dims[1], neck_dims[1], 2, stride=2),
            nn.Conv2d(neck_dims[1], fusion_dim, 3, padding=1, bias=False))
        self.fusion2 = FusionBlock(fusion_dim)

        # stage3 (16x): native patch stride, no resampling
        self.fc3 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
        self.conv3 = nn.Sequential(
            nn.Conv2d(dim, neck_dims[2], 1),
            nn.Conv2d(neck_dims[2], fusion_dim, 3, padding=1, bias=False))
        self.fusion3 = FusionBlock(fusion_dim)

        # stage4 (32x): strided conv downsamples once more
        self.fc4 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
        self.conv4 = nn.Sequential(
            nn.Conv2d(dim, neck_dims[3], 1),
            nn.Conv2d(neck_dims[3], neck_dims[3], 3, stride=2, padding=1),
            nn.Conv2d(neck_dims[3], fusion_dim, 3, padding=1, bias=False))
        self.fusion4 = FusionBlock(fusion_dim)

        # head: decode fused features to a 1-channel non-negative depth map
        self.head = nn.Sequential(
            nn.Conv2d(fusion_dim, fusion_dim // 2, 3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(fusion_dim // 2, 32, 3, padding=1),
            nn.ReLU(inplace=True), nn.ConvTranspose2d(32, 1, 1),
            nn.ReLU(inplace=True))

    def forward(self, x):
        """x: [B, 3, H, W] with H, W divisible by patch_size -> [B, 1, H, W]."""
        b, _, h, w, p = *x.size(), self.patch_size
        assert h % p == 0 and w % p == 0, f'Image size ({w}, {h}) is not divisible by patch size ({p}, {p})'
        hp, wp, grid = h // p, w // p, self.image_size // p

        # embeddings; the patch position embedding is bilinearly resized from
        # the training grid to the current (hp, wp) grid, class token kept as-is
        pos_embedding = torch.cat([
            self.pos_embedding[:, :1],
            F.interpolate(
                self.pos_embedding[:, 1:].reshape(1, grid, grid, -1).permute(
                    0, 3, 1, 2),
                size=(hp, wp),
                mode='bilinear',
                align_corners=False).permute(0, 2, 3, 1).reshape(
                    1, hp * wp, -1)
        ],
                                  dim=1)  # noqa
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        x = torch.cat([self.cls_embedding.repeat(b, 1, 1), x], dim=1)
        x = x + pos_embedding

        # stage1: readout = concat(patch tokens, broadcast class token)
        x = self.blocks[self.slices[0]](x)
        x1 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
        x1 = self.fc1(x1).permute(0, 2, 1).unflatten(2, (hp, wp))
        x1 = self.conv1(x1)

        # stage2
        x = self.blocks[self.slices[1]](x)
        x2 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
        x2 = self.fc2(x2).permute(0, 2, 1).unflatten(2, (hp, wp))
        x2 = self.conv2(x2)

        # stage3
        x = self.blocks[self.slices[2]](x)
        x3 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
        x3 = self.fc3(x3).permute(0, 2, 1).unflatten(2, (hp, wp))
        x3 = self.conv3(x3)

        # stage4
        x = self.blocks[self.slices[3]](x)
        x4 = torch.cat([x[:, 1:], x[:, :1].expand_as(x[:, 1:])], dim=-1)
        x4 = self.fc4(x4).permute(0, 2, 1).unflatten(2, (hp, wp))
        x4 = self.conv4(x4)

        # fusion: top-down, deepest stage first, each step upsampling 2x
        x4 = self.fusion4(x4)
        x3 = self.fusion3(x4, x3)
        x2 = self.fusion2(x3, x2)
        x1 = self.fusion1(x2, x1)

        # head
        x = self.head(x1)
        return x
+
+
def midas_v3(model_dir, pretrained=False, **kwargs):
    """Build MiDaS v3 DPT-Large; optionally load weights from `model_dir`."""
    cfg = {
        'image_size': 384,
        'patch_size': 16,
        'dim': 1024,
        'neck_dims': [256, 512, 1024, 1024],
        'fusion_dim': 256,
        'num_heads': 16,
        'num_layers': 24,
    }
    cfg.update(**kwargs)
    model = MiDaS(**cfg)
    if pretrained:
        ckpt = os.path.join(model_dir, 'midas_v3_dpt_large.pth')
        model.load_state_dict(torch.load(ckpt, map_location='cpu'))
    return model
diff --git a/modelscope/models/multi_modal/videocomposer/ops/__init__.py b/modelscope/models/multi_modal/videocomposer/ops/__init__.py
new file mode 100644
index 00000000..48b87bda
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/ops/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from .degration import *
+from .distributed import *
+from .losses import *
+from .random_mask import *
+from .utils import *
diff --git a/modelscope/models/multi_modal/videocomposer/ops/degration.py b/modelscope/models/multi_modal/videocomposer/ops/degration.py
new file mode 100644
index 00000000..de97be65
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/ops/degration.py
@@ -0,0 +1,998 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
import math
import os
import random
from datetime import datetime

import cv2
import numpy as np
import scipy
import scipy.stats as stats
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth
from torchvision.utils import make_grid
+
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
+
+__all__ = ['degradation_bsrgan_light', 'degradation_bsrgan']
+
+
def imread_uint(path, n_channels=3):
    """Read an image from *path* as a uint8 HxWxC array.

    Args:
        path: image file path.
        n_channels: 1 for grayscale (HxWx1), 3 for RGB (HxWx3).

    Raises:
        ValueError: for unsupported n_channels (previously this fell
            through and raised UnboundLocalError on the return).
    """
    if n_channels == 1:
        img = cv2.imread(path, 0)  # grayscale read
        img = np.expand_dims(img, axis=2)
    elif n_channels == 3:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img.ndim == 2:
            # grayscale file requested as RGB: replicate channel (GGG)
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        else:
            # OpenCV loads BGR; convert to RGB
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    else:
        raise ValueError(f'n_channels must be 1 or 3, got {n_channels}')
    return img
+
+
def uint2single(img):
    """Convert a uint8 image in [0, 255] to float32 in [0, 1]."""
    scaled = img / 255.
    return np.float32(scaled)
+
+
def single2uint(img):
    """Convert a float image in [0, 1] to uint8 (clipped, rounded)."""
    clipped = img.clip(0, 1)
    return np.uint8((clipped * 255.).round())
+
+
def uint162single(img):
    """Convert a uint16 image in [0, 65535] to float32 in [0, 1]."""
    scaled = img / 65535.
    return np.float32(scaled)
+
+
def single2uint16(img):
    """Convert a float image in [0, 1] to uint16 (clipped, rounded)."""
    clipped = img.clip(0, 1)
    return np.uint16((clipped * 65535.).round())
+
+
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output has the same dtype as the input (Y in [16, 235] for uint8).
    '''
    in_img_type = img.dtype
    # Work on a float copy. Previously `img.astype(np.float32)` discarded its
    # result, so the in-place `img *= 255.` below mutated the caller's float
    # array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (ITU-R BT.601 coefficients, MATLAB convention)
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img,
                        [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                         [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
+
+
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output has the same dtype as the input.
    '''
    in_img_type = img.dtype
    # Work on a float copy. Previously `img.astype(np.float32)` discarded its
    # result, so the in-place `img *= 255.` below mutated the caller's float
    # array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (inverse of the BT.601 matrix used in rgb2ycbcr)
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
                          [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [
                              -222.921, 135.576, -276.836
                          ]  # noqa
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
+
+
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output has the same dtype as the input.
    '''
    in_img_type = img.dtype
    # Work on a float copy. Previously `img.astype(np.float32)` discarded its
    # result, so the in-place `img *= 255.` below mutated the caller's float
    # array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (BT.601 coefficients with B/R columns swapped for BGR input)
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img,
                        [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                         [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
+
+
def channel_convert(in_c, tar_type, img_list):
    """Convert a list of images among BGR, grayscale and the Y channel.

    Unrecognized (in_c, tar_type) combinations return the list unchanged.
    """
    if in_c == 3 and tar_type == 'gray':
        return [
            np.expand_dims(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), axis=2)
            for img in img_list
        ]
    if in_c == 3 and tar_type == 'y':
        return [
            np.expand_dims(bgr2ycbcr(img, only_y=True), axis=2)
            for img in img_list
        ]
    if in_c == 1 and tar_type == 'RGB':
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    return img_list
+
+
# PSNR
def calculate_psnr(img1, img2, border=0):
    """Peak signal-to-noise ratio between two [0, 255] images.

    `border` pixels are stripped from each side before comparison.
    Returns inf for identical crops.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    crop1 = img1[border:h - border, border:w - border].astype(np.float64)
    crop2 = img2[border:h - border, border:w - border].astype(np.float64)
    mse = np.mean(np.square(crop1 - crop2))
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
# SSIM
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]; for HxWx3 inputs, channel SSIMs are averaged.
    '''
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h - border, border:w - border]
    img2 = img2[border:h - border, border:w - border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        if img1.shape[2] == 3:
            scores = [ssim(img1[..., c], img2[..., c]) for c in range(3)]
            return np.array(scores).mean()
        if img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
        # mirrors the original implicit fall-through for other channel counts
        return None
    raise ValueError('Wrong input image dimensions.')
+
+
def ssim(img1, img2):
    """Single-channel SSIM of two [0, 255] images, matching MATLAB's
    reference implementation (11x11 Gaussian window, sigma 1.5).

    The [5:-5] crops discard filter responses affected by the border so
    only 'valid' window positions contribute to the mean.
    """
    # Stabilizing constants from the SSIM paper: (K1*L)^2, (K2*L)^2, L=255.
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    # Outer product of the 1-D kernel gives the 2-D Gaussian window.
    window = np.outer(kernel, kernel.transpose())

    # Local means, then local (co)variances via E[x^2] - E[x]^2.
    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) # noqa
                * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) # noqa
                                         * # noqa
                                         (sigma1_sq + sigma2_sq + C2)) # noqa
    return ssim_map.mean()
+
+
# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    """Bicubic interpolation kernel (Keys, a = -0.5), evaluated elementwise.

    Nonzero only on |x| <= 2; cubic(0) == 1.
    """
    ax = torch.abs(x)
    ax2 = ax * ax
    ax3 = ax2 * ax
    inner = (1.5 * ax3 - 2.5 * ax2 + 1) * (ax <= 1).type_as(ax)
    outer_mask = ((ax > 1) * (ax <= 2)).type_as(ax)
    outer = (-0.5 * ax3 + 2.5 * ax2 - 4 * ax + 2) * outer_mask
    return inner + outer
+
+
def calculate_weights_indices(in_length, out_length, scale, kernel,
                              kernel_width, antialiasing):
    """Compute resampling weights and source indices for MATLAB-style resize.

    Returns (weights, indices, sym_len_s, sym_len_e): per-output-pixel weight
    rows and the input indices they apply to, plus the number of reflected
    border pixels needed before/after the input axis. Only the cubic kernel
    is used (see ``cubic``); the `kernel` argument is accepted but unused.
    """
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
        0, P - 1, P).view(1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    # (isclose against 0 here only matches an exact zero-count, i.e. the
    # column is dropped as soon as it contains any zero weight.)
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Reflect-padding lengths: how far indices reach past each end of the input.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
# imresize for tensor image [0, 1]
def imresize(img, scale, antialiasing=True):
    """MATLAB-compatible bicubic resize of a CHW (or HW) tensor in [0, 1].

    Resizes along H then W with the separable cubic kernel from
    ``calculate_weights_indices``; out-of-range source indices are served
    from symmetric (reflected) border copies. Output is not rounded.
    """
    # Now the scale should be the same for H and W
    # input: img: pytorch tensor, CHW or HW [0,1]
    # output: CHW or HW [0,1] w/o round
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        # NOTE: in-place; temporarily adds a channel dim to the caller's tensor
        img.unsqueeze_(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W
                                                                   * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying: build an H-padded copy with reflected borders
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # weighted sum over the kernel support for each output row
    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(
                0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    # weighted sum over the kernel support for each output column
    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :,
                                       idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2
+
+
# imresize for numpy image [0, 1]
def imresize_np(img, scale, antialiasing=True):
    """MATLAB-compatible bicubic resize of an HWC (or HW) numpy image in [0, 1].

    Same algorithm as ``imresize`` but with channel-last layout; borders are
    handled by symmetric (reflected) copies. Returns a numpy array, unrounded.
    """
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W
                                                                   * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying: pad rows with reflected borders
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # weighted sum over the kernel support for each output row
    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :,
                                     j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    # weighted sum over the kernel support for each output column
    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width,
                                       j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()
+
+
def modcrop_np(img, sf):
    '''Crop an image so its first two dimensions are divisible by sf.

    Args:
        img: numpy image, HxW or HxWxC (the original doc/variable names had
            H and W swapped; the crop itself is unchanged)
        sf: scale factor
    Return:
        cropped copy of the image
    '''
    h, w = img.shape[:2]
    return np.copy(img)[:h - h % sf, :w - w % sf, ...]
+
+
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)."""
    n = k.shape[0]
    # Big kernel: every tap of k stamps a scaled copy of k at stride 2.
    big = np.zeros((3 * n - 2, 3 * n - 2))
    for row in range(n):
        for col in range(n):
            big[2 * row:2 * row + n, 2 * col:2 * col + n] += k[row, col] * k
    # Trim the tiny border values to keep SR runtime down, then renormalize.
    margin = n // 2
    trimmed = big[margin:-margin, margin:-margin]
    return trimmed / trimmed.sum()
+
+
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """Generate an anisotropic Gaussian kernel.

    Args:
        ksize: kernel size, e.g. 15
        theta: rotation angle in [0, pi]
        l1, l2: eigenvalue scalings in [0.1, 50] (l1 == l2 gives an
            isotropic kernel)
    Returns:
        the ksize x ksize kernel
    """
    rotation = np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta), np.cos(theta)]])
    v = rotation @ np.array([1., 0.])
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = V @ D @ np.linalg.inv(V)
    return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+
def gm_blur_kernel(mean, cov, size=15):
    """Sample a Gaussian pdf with the given mean/covariance on a size x size
    grid and normalize it to sum to 1."""
    center = size / 2.0 + 0.5
    kernel = np.zeros([size, size])
    for row in range(size):
        for col in range(size):
            # grid coordinates relative to the kernel center
            point = [col - center + 1, row - center + 1]
            kernel[row, col] = stats.multivariate_normal.pdf(
                point, mean=mean, cov=cov)
    return kernel / np.sum(kernel)
+
+
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    # Sub-pixel offset that re-centers a kernel/image for scale factor sf.
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    # Clamp sample coordinates to the valid grid.
    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    # NOTE(review): scipy.interpolate.interp2d is deprecated and removed in
    # SciPy >= 1.14 -- migrate to RegularGridInterpolator when upgrading.
    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        # NOTE: this branch overwrites x's channels in place.
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x
+
+
def blur(x, k):
    '''Depthwise blur of a batch of images with per-sample kernels.

    x: image, NxcxHxW
    k: kernel, Nx1xhxw (one kernel per sample, shared across channels)
    '''
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    padded = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
    # Expand each sample's kernel to all channels and flatten to one grouped
    # conv: every (sample, channel) pair gets its own filter.
    kernels = k.repeat(1, c, 1, 1).view(-1, 1, k.shape[2], k.shape[3])
    flat = padded.view(1, -1, padded.shape[2], padded.shape[3])
    out = torch.nn.functional.conv2d(
        flat, kernels, bias=None, stride=1, padding=0, groups=n * c)
    return out.view(n, c, out.shape[2], out.shape[3])
+
+
def gen_kernel(
        k_size=np.array([15, 15]),
        scale_factor=np.array([4, 4]),
        min_var=0.6,
        max_var=10.,
        noise_level=0):
    """Sample a random anisotropic Gaussian blur kernel.

    Modified from https://github.com/assafshocher/BlindSR_dataset_generator
    (Kai Zhang). Covariance eigenvalues are drawn uniformly from
    [min_var, max_var]; the kernel is normalized to sum to 1.
    """
    # Random eigenvalues and rotation angle for the covariance matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta), np.cos(theta)]])
    covariance = rot @ np.diag([lambda_1, lambda_2]) @ rot.T
    inv_cov = np.linalg.inv(covariance)[None, None, :, :]

    # Expectation position, shifted so the downsampled image stays aligned
    mu = k_size // 2 - 0.5 * (scale_factor - 1)
    mu = mu[None, None, :, None]

    # Gaussian evaluated on the kernel grid
    grid_x, grid_y = np.meshgrid(range(k_size[0]), range(k_size[1]))
    coords = np.stack([grid_x, grid_y], 2)[:, :, :, None]
    delta = coords - mu
    delta_t = delta.transpose(0, 1, 3, 2)
    raw = np.exp(-0.5 * np.squeeze(delta_t @ inv_cov @ delta)) * (1 + noise)

    # Normalize the kernel and return
    return raw / np.sum(raw)
+
+
def fspecial_gaussian(hsize, sigma):
    """2-D Gaussian filter of size hsize x hsize with std sigma, normalized
    to sum to 1 (like MATLAB fspecial('gaussian'))."""
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(
        np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # Bug fix: `scipy.finfo` does not exist (raised AttributeError) --
    # machine epsilon comes from numpy. Zero out negligible tail values.
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h
+
+
def fspecial_laplacian(alpha):
    """3x3 Laplacian filter like MATLAB fspecial('laplacian', alpha);
    alpha is clamped to [0, 1]."""
    alpha = min(max(alpha, 0), 1)
    corner = alpha / (alpha + 1)
    edge = (1 - alpha) / (alpha + 1)
    center = -4 / (alpha + 1)
    return np.array([[corner, edge, corner],
                     [edge, center, edge],
                     [corner, edge, corner]])
+
+
def fspecial(filter_type, *args, **kwargs):
    """Dispatch to a MATLAB-fspecial-style filter generator.

    Supported: 'gaussian' -> fspecial_gaussian(hsize, sigma),
               'laplacian' -> fspecial_laplacian(alpha).

    Raises:
        ValueError: for any other filter type (previously returned None
            silently, deferring the failure to the caller).
    """
    if filter_type == 'gaussian':
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == 'laplacian':
        return fspecial_laplacian(*args, **kwargs)
    raise ValueError(f'unsupported filter type: {filter_type!r}')
+
+
def bicubic_degradation(x, sf=3):
    '''Bicubically downsample x by factor sf.

    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image
    '''
    return imresize_np(x, scale=1 / sf)
+
+
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # `scipy.ndimage.filters` was removed in SciPy 1.10; use ndimage.convolve.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    x = bicubic_degradation(x, sf=sf)
    return x
+
+
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # `scipy.ndimage.filters` was removed in SciPy 1.10; use ndimage.convolve.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
+
+
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    '''
    # `scipy.ndimage.filters` was removed in SciPy 1.10; use ndimage.convolve.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    st = 0  # sampling offset (kept at 0: take the top-left pixel of each block)
    return x[st::sf, st::sf, ...]
+
+
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 1.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int):
    """
    # GaussianBlur needs an odd kernel size.
    if radius % 2 == 0:
        radius += 1
    blurred = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blurred
    hard_mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(hard_mask, (radius, radius), 0)
    sharpened = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharpened + (1 - soft_mask) * img
+
+
def add_blur_1(img, sf=4):
    """Blur img with a random Gaussian kernel (light variant).

    Kernel strength scales with sf but both widths are divided by 4
    relative to ``add_blur_2``; the kernel is anisotropic with
    probability 0.5, isotropic otherwise.
    """
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf

    wd2 = wd2 / 4
    wd = wd / 4

    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(
            ksize=random.randint(2, 11) + 3,
            theta=random.random() * np.pi,
            l1=l1,
            l2=l2)
    else:
        k = fspecial('gaussian',
                     random.randint(2, 4) + 3, wd * random.random())
    # `scipy.ndimage.filters` was removed in SciPy 1.10; use ndimage.convolve.
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img
+
+
def add_resize(img, sf=4):
    """Randomly rescale img: upscale (p=0.2), downscale (p=0.7) or keep
    size (p=0.1), with a random interpolation method."""
    roll = np.random.rand()
    if roll > 0.8:
        factor = random.uniform(1, 2)
    elif roll < 0.7:
        factor = random.uniform(0.5 / sf, 1)
    else:
        factor = 1.0
    new_size = (int(factor * img.shape[1]), int(factor * img.shape[0]))
    img = cv2.resize(img, new_size, interpolation=random.choice([1, 2, 3]))
    return np.clip(img, 0.0, 1.0)
+
+
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add random Gaussian noise to a float image in [0, 1].

    With p=0.4: per-pixel/per-channel noise; p=0.4: a single luminance
    plane broadcast over channels; p=0.2: channel-correlated color noise.
    """
    sigma = random.randint(noise_level1, noise_level2) / 255.0
    branch = np.random.rand()
    if branch > 0.6:
        noise = np.random.normal(0, sigma, img.shape).astype(np.float32)
        img = img + noise
    elif branch < 0.4:
        noise = np.random.normal(0, sigma,
                                 (*img.shape[:2], 1)).astype(np.float32)
        img = img + noise
    else:
        # correlated color noise with a random covariance matrix
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        cov = np.transpose(U) @ D @ U
        noise = np.random.multivariate_normal(
            [0, 0, 0], np.abs(L**2 * cov), img.shape[:2]).astype(np.float32)
        img = img + noise
    return np.clip(img, 0.0, 1.0)
+
+
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add multiplicative (speckle) noise to a float image in [0, 1]."""
    sigma = random.randint(noise_level1, noise_level2) / 255.0
    img = np.clip(img, 0.0, 1.0)
    branch = random.random()
    if branch > 0.6:
        img += img * np.random.normal(0, sigma, img.shape).astype(np.float32)
    elif branch < 0.4:
        # one noise plane broadcast across channels
        img += img * np.random.normal(0, sigma,
                                      (*img.shape[:2], 1)).astype(np.float32)
    else:
        # channel-correlated noise with a random covariance matrix
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        cov = np.transpose(U) @ D @ U
        img += img * np.random.multivariate_normal(
            [0, 0, 0], np.abs(L**2 * cov), img.shape[:2]).astype(np.float32)
    return np.clip(img, 0.0, 1.0)
+
+
def add_Poisson_noise(img):
    """Add Poisson (shot) noise; the photon count scale is sampled
    log-uniformly from [1e2, 1e4]. With p=0.5 the noise is applied per
    channel, otherwise a single gray-derived noise plane is added."""
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10**(2 * random.random() + 2.0)
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        gray = np.clip((gray * 255.0).round(), 0, 255) / 255.
        noise = np.random.poisson(gray * vals).astype(np.float32) / vals - gray
        img = img + noise[:, :, np.newaxis]
    return np.clip(img, 0.0, 1.0)
+
+
def add_JPEG_noise(img):
    """Round-trip img through JPEG at a random quality factor in [80, 95]."""
    quality = random.randint(80, 95)
    bgr = cv2.cvtColor(single2uint(img), cv2.COLOR_RGB2BGR)
    _, encoded = cv2.imencode('.jpg', bgr,
                              [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    decoded = cv2.imdecode(encoded, 1)
    return cv2.cvtColor(uint2single(decoded), cv2.COLOR_BGR2RGB)
+
+
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Take an aligned random patch pair from an LQ/HQ image pair.

    The LQ patch is lq_patchsize square; the HQ patch covers the same
    region scaled by sf.
    """
    h, w = lq.shape[:2]
    top = random.randint(0, h - lq_patchsize)
    left = random.randint(0, w - lq_patchsize)
    lq = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]

    top_hq, left_hq = int(top * sf), int(left * sf)
    hq = hq[top_hq:top_hq + lq_patchsize * sf,
            left_hq:left_hq + lq_patchsize * sf, :]
    return lq, hq
+
+
def degradation_bsrgan_light(image, sf=4, isp_model=None):
    """
    This is the variant of the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    sf: scale factor
    isp_model: camera ISP model (unused here)
    Returns
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    # uint8 [0, 255] -> float32 [0, 1]
    image = uint2single(image)
    # jpeg_prob: chance of an intermediate JPEG pass; scale2_prob: chance of
    # pre-halving when sf == 4. The first unpacked value is unused.
    _, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25

    h1, w1 = image.shape[:2]
    # NOTE(review): h1/w1 appear swapped here -- dim 0 is cropped with w1 and
    # dim 1 with h1, so non-square inputs are not correctly mod-cropped.
    # Harmless for square inputs; confirm intent before changing.
    image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]
    h, w = image.shape[:2]

    # With probability scale2_prob, halve the image first and continue the
    # pipeline with sf = 2 so the total downscale is still x4.
    if sf == 4 and random.random() < scale2_prob:
        if np.random.rand() < 0.5:
            image = cv2.resize(
                image,
                (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                interpolation=random.choice([1, 2, 3]))
        else:
            image = imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    # Random stage order, except stage 2 (downsample2) must precede stage 3
    # (downsample3), which reuses the pre-resize dims (a, b) set in stage 2.
    # Stages 1 and 6 are intentionally no-ops in this light variant.
    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[
            idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            # random Gaussian blur (light variant)
            image = add_blur_1(image, sf=sf)
        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.8:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(
                    image, (int(1 / sf1 * image.shape[1]),
                            int(1 / sf1 * image.shape[0])),
                    interpolation=random.choice([1, 2, 3]))
            else:
                # blur with a shifted Gaussian, then nearest-style subsample
                # NOTE(review): ndimage.filters was removed in SciPy >= 1.10.
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()
                image = ndimage.filters.convolve(
                    image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]
            image = np.clip(image, 0.0, 1.0)
        elif i == 3:
            # downsample3 (uses a, b from stage 2, which always runs first)
            image = cv2.resize(
                image, (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)
        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    # float32 [0, 1] -> uint8 [0, 255]
    image = single2uint(image)
    return image
+
+
def add_blur_2(img, sf=4):
    """Blur img with a random Gaussian kernel (full-strength variant).

    Like ``add_blur_1`` but without the /4 width reduction and with an
    always-odd, larger kernel size.
    """
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(
            ksize=2 * random.randint(2, 11) + 3,
            theta=random.random() * np.pi,
            l1=l1,
            l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3,
                     wd * random.random())
    # `scipy.ndimage.filters` was removed in SciPy 1.10; use ndimage.convolve.
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
    return img
+
+
def degradation_bsrgan(image, sf=4, isp_model=None):
    """
    This is the variant of the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    sf: scale factor
    isp_model: camera ISP model (unused here)
    Returns
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    # uint8 [0, 255] -> float32 [0, 1]
    image = uint2single(image)
    # jpeg_prob: chance of an intermediate JPEG pass; scale2_prob: chance of
    # pre-halving when sf == 4. The first unpacked value is unused.
    _, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25

    h1, w1 = image.shape[:2]
    # NOTE(review): h1/w1 appear swapped here -- dim 0 is cropped with w1 and
    # dim 1 with h1, so non-square inputs are not correctly mod-cropped.
    # Harmless for square inputs; confirm intent before changing.
    image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]
    h, w = image.shape[:2]

    # With probability scale2_prob, halve the image first and continue the
    # pipeline with sf = 2 so the total downscale is still x4.
    if sf == 4 and random.random() < scale2_prob:
        if np.random.rand() < 0.5:
            image = cv2.resize(
                image,
                (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                interpolation=random.choice([1, 2, 3]))
        else:
            image = imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    # Random stage order, except stage 2 (downsample2) must precede stage 3
    # (downsample3), which reuses the pre-resize dims (a, b) set in stage 2.
    # Stage 6 is intentionally a no-op.
    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[
            idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            # first random blur pass
            image = add_blur_2(image, sf=sf)
        elif i == 1:
            # second, independent blur pass
            image = add_blur_2(image, sf=sf)
        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(
                    image, (int(1 / sf1 * image.shape[1]),
                            int(1 / sf1 * image.shape[0])),
                    interpolation=random.choice([1, 2, 3]))
            else:
                # blur with a shifted Gaussian, then nearest-style subsample
                # NOTE(review): ndimage.filters was removed in SciPy >= 1.10.
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()
                image = ndimage.filters.convolve(
                    image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]
            image = np.clip(image, 0.0, 1.0)
        elif i == 3:
            # downsample3 (uses a, b from stage 2, which always runs first)
            image = cv2.resize(
                image, (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)
        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    # float32 [0, 1] -> uint8 [0, 255]
    image = single2uint(image)
    return image
diff --git a/modelscope/models/multi_modal/videocomposer/ops/distributed.py b/modelscope/models/multi_modal/videocomposer/ops/distributed.py
new file mode 100644
index 00000000..201e156f
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/ops/distributed.py
@@ -0,0 +1,460 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import functools
+import pickle
+from collections import OrderedDict
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.autograd import Function
+
+__all__ = [
+ 'is_dist_initialized', 'get_world_size', 'get_rank', 'new_group',
+ 'destroy_process_group', 'barrier', 'broadcast', 'all_reduce', 'reduce',
+ 'gather', 'all_gather', 'reduce_dict', 'get_global_gloo_group',
+ 'generalized_all_gather', 'generalized_gather', 'scatter',
+ 'reduce_scatter', 'send', 'recv', 'isend', 'irecv', 'shared_random_seed',
+ 'diff_all_gather', 'diff_all_reduce', 'diff_scatter', 'diff_copy',
+ 'spherical_kmeans', 'sinkhorn'
+]
+
+
+def is_dist_initialized():
+ return dist.is_available() and dist.is_initialized()
+
+
+def get_world_size(group=None):
+ return dist.get_world_size(group) if is_dist_initialized() else 1
+
+
+def get_rank(group=None):
+ return dist.get_rank(group) if is_dist_initialized() else 0
+
+
+def new_group(ranks=None, **kwargs):
+ if is_dist_initialized():
+ return dist.new_group(ranks, **kwargs)
+ return None
+
+
+def destroy_process_group():
+ if is_dist_initialized():
+ dist.destroy_process_group()
+
+
+def barrier(group=None, **kwargs):
+ if get_world_size(group) > 1:
+ dist.barrier(group, **kwargs)
+
+
+def broadcast(tensor, src, group=None, **kwargs):
+ if get_world_size(group) > 1:
+ return dist.broadcast(tensor, src, group, **kwargs)
+
+
+def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs):
+ if get_world_size(group) > 1:
+ return dist.all_reduce(tensor, op, group, **kwargs)
+
+
+def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs):
+ if get_world_size(group) > 1:
+ return dist.reduce(tensor, dst, op, group, **kwargs)
+
+
+def gather(tensor, dst=0, group=None, **kwargs):
+ rank = get_rank()
+ world_size = get_world_size(group)
+ if world_size == 1:
+ return [tensor]
+ tensor_list = [torch.empty_like(tensor)
+ for _ in range(world_size)] if rank == dst else None
+ dist.gather(tensor, tensor_list, dst, group, **kwargs)
+ return tensor_list
+
+
+def all_gather(tensor, uniform_size=True, group=None, **kwargs):
+    """Gather ``tensor`` from every rank and return the list of gathered tensors.
+
+    Args:
+        tensor: contiguous tensor contributed by this rank.
+        uniform_size: when True all ranks are assumed to hold same-shape
+            tensors; when False shapes are exchanged first and each
+            contribution is zero-padded to the largest size before gathering.
+        group: optional process group (default group when None).
+
+    Returns:
+        List of tensors, one per rank (just ``[tensor]`` on a single process).
+    """
+    world_size = get_world_size(group)
+    if world_size == 1:
+        return [tensor]
+    assert tensor.is_contiguous(
+    ), 'ops.all_gather requires the tensor to be contiguous()'
+
+    if uniform_size:
+        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
+        dist.all_gather(tensor_list, tensor, group, **kwargs)
+        return tensor_list
+    else:
+        # collect tensor shapes across GPUs
+        shape = tuple(tensor.shape)
+        shape_list = generalized_all_gather(shape, group)
+
+        # flatten the tensor
+        tensor = tensor.reshape(-1)
+        size = int(np.prod(shape))
+        size_list = [int(np.prod(u)) for u in shape_list]
+        max_size = max(size_list)
+
+        # pad to maximum size so every rank contributes the same numel
+        if size != max_size:
+            padding = tensor.new_zeros(max_size - size)
+            tensor = torch.cat([tensor, padding], dim=0)
+
+        # all_gather the equally-sized flat tensors
+        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
+        dist.all_gather(tensor_list, tensor, group, **kwargs)
+
+        # strip padding and restore each rank's original shape
+        tensor_list = [
+            t[:n].view(s)
+            for t, n, s in zip(tensor_list, size_list, shape_list)
+        ]
+        return tensor_list
+
+
+@torch.no_grad()
+def reduce_dict(input_dict, group=None, reduction='mean', **kwargs):
+ assert reduction in ['mean', 'sum']
+ world_size = get_world_size(group)
+ if world_size == 1:
+ return input_dict
+
+ # ensure that the orders of keys are consistent across processes
+ if isinstance(input_dict, OrderedDict):
+ keys = list(input_dict.keys)
+ else:
+ keys = sorted(input_dict.keys())
+ vals = [input_dict[key] for key in keys]
+ vals = torch.stack(vals, dim=0)
+ dist.reduce(vals, dst=0, group=group, **kwargs)
+ if dist.get_rank(group) == 0 and reduction == 'mean':
+ vals /= world_size
+ dist.broadcast(vals, src=0, group=group, **kwargs)
+ reduced_dict = type(input_dict)([(key, val)
+ for key, val in zip(keys, vals)])
+ return reduced_dict
+
+
+@functools.lru_cache()
+def get_global_gloo_group():
+ backend = dist.get_backend()
+ assert backend in ['gloo', 'nccl']
+ if backend == 'nccl':
+ return dist.new_group(backend='gloo')
+ else:
+ return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+ backend = dist.get_backend(group)
+ assert backend in ['gloo', 'nccl']
+ device = torch.device('cpu' if backend == 'gloo' else 'cuda')
+
+ buffer = pickle.dumps(data)
+ if len(buffer) > 1024**3:
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ 'Rank {} trying to all-gather {:.2f} GB of data on device'
+ '{}'.format(get_rank(),
+ len(buffer) / (1024**3), device))
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to(device=device)
+ return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+    """Zero-pad a flat uint8 ``tensor`` to the largest numel in ``group``.
+
+    Returns:
+        (size_list, tensor): the original numel of every rank's tensor, and
+        this rank's (possibly padded) tensor.
+    """
+    world_size = dist.get_world_size(group=group)
+    assert world_size >= 1, \
+        'gather/all_gather must be called from ranks within' \
+        'the give group!'
+    local_size = torch.tensor([tensor.numel()],
+                              dtype=torch.int64,
+                              device=tensor.device)
+    size_list = [
+        torch.zeros([1], dtype=torch.int64, device=tensor.device)
+        for _ in range(world_size)
+    ]
+
+    # gather tensors and compute the maximum size
+    dist.all_gather(size_list, local_size, group=group)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # pad tensors to the same size (the tensor-vs-int comparison yields a
+    # single bool because local_size has exactly one element)
+    if local_size != max_size:
+        padding = torch.zeros((max_size - local_size, ),
+                              dtype=torch.uint8,
+                              device=tensor.device)
+        tensor = torch.cat((tensor, padding), dim=0)
+    return size_list, tensor
+
+
+def generalized_all_gather(data, group=None):
+    """All-gather arbitrary picklable ``data``; returns one object per rank.
+
+    Serialization goes through pickle over a gloo-capable group, so this
+    works for non-tensor Python objects as well.
+    """
+    if get_world_size(group) == 1:
+        return [data]
+    if group is None:
+        group = get_global_gloo_group()
+
+    tensor = _serialize_to_tensor(data, group)
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+    max_size = max(size_list)
+
+    # receiving tensors from all ranks
+    tensor_list = [
+        torch.empty((max_size, ), dtype=torch.uint8, device=tensor.device)
+        for _ in size_list
+    ]
+    dist.all_gather(tensor_list, tensor, group=group)
+
+    # strip padding and unpickle each rank's payload
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+    return data_list
+
+
+def generalized_gather(data, dst=0, group=None):
+    """Gather arbitrary picklable ``data`` from all ranks onto ``dst``.
+
+    Returns the list of per-rank objects on ``dst`` and ``[]`` elsewhere
+    (``[data]`` on a single process).
+    """
+    world_size = get_world_size(group)
+    if world_size == 1:
+        return [data]
+    if group is None:
+        group = get_global_gloo_group()
+    rank = dist.get_rank()
+
+    tensor = _serialize_to_tensor(data, group)
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+    # receiving tensors from all ranks to dst
+    if rank == dst:
+        max_size = max(size_list)
+        tensor_list = [
+            torch.empty((max_size, ), dtype=torch.uint8, device=tensor.device)
+            for _ in size_list
+        ]
+        dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+        # strip padding and unpickle each rank's payload
+        data_list = []
+        for size, tensor in zip(size_list, tensor_list):
+            buffer = tensor.cpu().numpy().tobytes()[:size]
+            data_list.append(pickle.loads(buffer))
+        return data_list
+    else:
+        # NOTE(review): non-dst ranks pass [] as gather_list; recent
+        # torch.distributed expects None on non-dst ranks — verify against
+        # the supported torch version.
+        dist.gather(tensor, [], dst=dst, group=group)
+        return []
+
+
+def scatter(data, scatter_list=None, src=0, group=None, **kwargs):
+ r"""NOTE: only supports CPU tensor communication.
+ """
+ if get_world_size(group) > 1:
+ return dist.scatter(data, scatter_list, src, group, **kwargs)
+
+
+def reduce_scatter(output,
+ input_list,
+ op=dist.ReduceOp.SUM,
+ group=None,
+ **kwargs):
+ if get_world_size(group) > 1:
+ return dist.reduce_scatter(output, input_list, op, group, **kwargs)
+
+
+def send(tensor, dst, group=None, **kwargs):
+ if get_world_size(group) > 1:
+ assert tensor.is_contiguous(
+ ), 'ops.send requires the tensor to be contiguous()'
+ return dist.send(tensor, dst, group, **kwargs)
+
+
+def recv(tensor, src=None, group=None, **kwargs):
+ if get_world_size(group) > 1:
+ assert tensor.is_contiguous(
+ ), 'ops.recv requires the tensor to be contiguous()'
+ return dist.recv(tensor, src, group, **kwargs)
+
+
+def isend(tensor, dst, group=None, **kwargs):
+ if get_world_size(group) > 1:
+ assert tensor.is_contiguous(
+ ), 'ops.isend requires the tensor to be contiguous()'
+ return dist.isend(tensor, dst, group, **kwargs)
+
+
+def irecv(tensor, src=None, group=None, **kwargs):
+ if get_world_size(group) > 1:
+ assert tensor.is_contiguous(
+ ), 'ops.irecv requires the tensor to be contiguous()'
+ return dist.irecv(tensor, src, group, **kwargs)
+
+
+def shared_random_seed(group=None):
+ seed = np.random.randint(2**31)
+ all_seeds = generalized_all_gather(seed, group)
+ return all_seeds[0]
+
+
+def _all_gather(x):
+ if not (dist.is_available()
+ and dist.is_initialized()) or dist.get_world_size() == 1:
+ return x
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ tensors = [torch.empty_like(x) for _ in range(world_size)]
+ tensors[rank] = x
+ dist.all_gather(tensors, x)
+ return torch.cat(tensors, dim=0).contiguous()
+
+
+def _all_reduce(x):
+ if not (dist.is_available()
+ and dist.is_initialized()) or dist.get_world_size() == 1:
+ return x
+ dist.all_reduce(x)
+ return x
+
+
+def _split(x):
+ if not (dist.is_available()
+ and dist.is_initialized()) or dist.get_world_size() == 1:
+ return x
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ return x.chunk(world_size, dim=0)[rank].contiguous()
+
+
+class DiffAllGather(Function):
+    r"""Differentiable all-gather.
+
+    Forward concatenates the input from every rank along dim 0; backward
+    routes each rank its own slice of the gradient.
+    """
+
+    @staticmethod
+    def symbolic(graph, input):
+        return _all_gather(input)
+
+    @staticmethod
+    def forward(ctx, input):
+        return _all_gather(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _split(grad_output)
+
+
+class DiffAllReduce(Function):
+    r"""Differentiable all-reduce.
+
+    Forward sums the input across ranks; backward passes gradients through
+    unchanged (every rank already holds the full reduced value).
+    """
+
+    @staticmethod
+    def symbolic(graph, input):
+        return _all_reduce(input)
+
+    @staticmethod
+    def forward(ctx, input):
+        return _all_reduce(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output
+
+
+class DiffScatter(Function):
+ r"""Differentiable scatter.
+ """
+
+ @staticmethod
+ def symbolic(ctx, input):
+ return _split(input)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _all_gather(grad_output)
+
+
+class DiffCopy(Function):
+    r"""Differentiable copy that reduces all gradients during backward.
+
+    Forward is the identity; backward all-reduces (sums) the gradient
+    across ranks.
+    """
+
+    @staticmethod
+    def symbolic(graph, input):
+        return input
+
+    @staticmethod
+    def forward(ctx, input):
+        return input
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _all_reduce(grad_output)
+
+
+# Functional aliases for the differentiable collectives defined above.
+diff_all_gather = DiffAllGather.apply
+diff_all_reduce = DiffAllReduce.apply
+diff_scatter = DiffScatter.apply
+diff_copy = DiffCopy.apply
+
+
+@torch.no_grad()
+def spherical_kmeans(feats, num_clusters, num_iters=10):
+    """Distributed spherical k-means over per-rank feature shards.
+
+    Args:
+        feats: (n, c) feature matrix on this rank; assumed unit-norm so a
+            dot product equals cosine similarity — TODO confirm.
+        num_clusters: number of centroids k.
+        num_iters: EM iterations; one extra E-step assigns final labels.
+
+    Returns:
+        (clusters, assigns, scores): (k, c) centroids, per-point cluster ids
+        and their similarity scores.
+    """
+    k, n, c = num_clusters, *feats.size()
+    ones = feats.new_ones(n, dtype=torch.long)
+
+    # distributed settings
+    world_size = get_world_size()
+
+    # init clusters: every rank contributes ~k/world_size random local points
+    rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))]
+    clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k]
+
+    # variables
+    new_clusters = feats.new_zeros(k, c)
+    counts = feats.new_zeros(k, dtype=torch.long)
+
+    # iterative Expectation-Maximization
+    for step in range(num_iters + 1):
+        # Expectation step: assign each point to its most similar centroid
+        simmat = torch.mm(feats, clusters.t())
+        scores, assigns = simmat.max(dim=1)
+        if step == num_iters:
+            break
+
+        # Maximization step: sum member features per cluster, across ranks
+        new_clusters.zero_().scatter_add_(0,
+                                          assigns.unsqueeze(1).repeat(1, c),
+                                          feats)
+        all_reduce(new_clusters)
+
+        counts.zero_()
+        counts.index_add_(0, assigns, ones)
+        all_reduce(counts)
+
+        # update only non-empty clusters, then re-project onto the sphere
+        mask = (counts > 0)
+        clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1)
+        clusters = F.normalize(clusters, p=2, dim=1)
+    return clusters, assigns, scores
+
+
+@torch.no_grad()
+def sinkhorn(Q, eps=0.5, num_iters=3):
+    """Sinkhorn-Knopp normalization of a score matrix ``Q``.
+
+    Rows and columns are iteratively rescaled toward uniform marginals, with
+    column mass shared across all ranks (SwAV-style). Returns the
+    column-normalized assignment matrix, transposed back, as float.
+    """
+    # normalize Q
+    Q = torch.exp(Q / eps).t()
+    sum_Q = Q.sum()
+    all_reduce(sum_Q)
+    Q /= sum_Q
+
+    # variables: target marginals r (rows) and c (columns, split over ranks)
+    n, m = Q.size()
+    u = Q.new_zeros(n)
+    r = Q.new_ones(n) / n
+    c = Q.new_ones(m) / (m * get_world_size())
+
+    # iterative update
+    cur_sum = Q.sum(dim=1)
+    all_reduce(cur_sum)
+    for i in range(num_iters):
+        u = cur_sum
+        Q *= (r / u).unsqueeze(1)
+        Q *= (c / Q.sum(dim=0)).unsqueeze(0)
+        cur_sum = Q.sum(dim=1)
+        all_reduce(cur_sum)
+    return (Q / Q.sum(dim=0, keepdim=True)).t().float()
diff --git a/modelscope/models/multi_modal/videocomposer/ops/losses.py b/modelscope/models/multi_modal/videocomposer/ops/losses.py
new file mode 100644
index 00000000..fbcb0b60
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/ops/losses.py
@@ -0,0 +1,37 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import math
+
+import torch
+
+__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood']
+
+
+def kl_divergence(mu1, logvar1, mu2, logvar2):
+ return 0.5 * (
+ -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + # noqa
+ ((mu1 - mu2)**2) * torch.exp(-logvar2))
+
+
+def standard_normal_cdf(x):
+ r"""A fast approximation of the cumulative distribution function of the standard normal.
+ """
+ return 0.5 * (1.0 + torch.tanh(
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+
+def discretized_gaussian_log_likelihood(x0, mean, log_scale):
+ assert x0.shape == mean.shape == log_scale.shape
+ cx = x0 - mean
+ inv_stdv = torch.exp(-log_scale)
+ cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
+ cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
+ log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = torch.where(
+ x0 < -0.999, log_cdf_plus,
+ torch.where(x0 > 0.999, log_one_minus_cdf_min,
+ torch.log(cdf_delta.clamp(min=1e-12))))
+ assert log_probs.shape == x0.shape
+ return log_probs
diff --git a/modelscope/models/multi_modal/videocomposer/ops/random_mask.py b/modelscope/models/multi_modal/videocomposer/ops/random_mask.py
new file mode 100644
index 00000000..f23219b5
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/ops/random_mask.py
@@ -0,0 +1,81 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import cv2
+import numpy as np
+
+__all__ = ['make_irregular_mask', 'make_rectangle_mask', 'make_uncrop']
+
+
+def make_irregular_mask(w,
+ h,
+ max_angle=4,
+ max_length=200,
+ max_width=100,
+ min_strokes=1,
+ max_strokes=5,
+ mode='line'):
+ # initialize mask
+ assert mode in ['line', 'circle', 'square']
+ mask = np.zeros((h, w), np.float32)
+
+ # draw strokes
+ num_strokes = np.random.randint(min_strokes, max_strokes + 1)
+ for i in range(num_strokes):
+ x1 = np.random.randint(w)
+ y1 = np.random.randint(h)
+ for j in range(1 + np.random.randint(5)):
+ angle = 0.01 + np.random.randint(max_angle)
+ if i % 2 == 0:
+ angle = 2 * 3.1415926 - angle
+ length = 10 + np.random.randint(max_length)
+ radius = 5 + np.random.randint(max_width)
+ x2 = np.clip((x1 + length * np.sin(angle)).astype(np.int32), 0, w)
+ y2 = np.clip((y1 + length * np.cos(angle)).astype(np.int32), 0, h)
+ if mode == 'line':
+ cv2.line(mask, (x1, y1), (x2, y2), 1.0, radius)
+ elif mode == 'circle':
+ cv2.circle(
+ mask, (x1, y1), radius=radius, color=1.0, thickness=-1)
+ elif mode == 'square':
+ radius = radius // 2
+ mask[y1 - radius:y1 + radius, x1 - radius:x1 + radius] = 1
+ x1, y1 = x2, y2
+ return mask
+
+
+def make_rectangle_mask(w,
+ h,
+ margin=10,
+ min_size=30,
+ max_size=150,
+ min_strokes=1,
+ max_strokes=4):
+ # initialize mask
+ mask = np.zeros((h, w), np.float32)
+
+ # draw rectangles
+ num_strokes = np.random.randint(min_strokes, max_strokes + 1)
+ for i in range(num_strokes):
+ box_w = np.random.randint(min_size, max_size)
+ box_h = np.random.randint(min_size, max_size)
+ x1 = np.random.randint(margin, w - margin - box_w + 1)
+ y1 = np.random.randint(margin, h - margin - box_h + 1)
+ mask[y1:y1 + box_h, x1:x1 + box_w] = 1
+ return mask
+
+
+def make_uncrop(w, h):
+ # initialize mask
+ mask = np.zeros((h, w), np.float32)
+
+ # randomly halve the image
+ side = np.random.choice([0, 1, 2, 3])
+ if side == 0:
+ mask[:h // 2, :] = 1
+ elif side == 1:
+ mask[h // 2:, :] = 1
+ elif side == 2:
+ mask[:, :w // 2] = 1
+ elif side == 3:
+ mask[:, w // 2:] = 1
+ return mask
diff --git a/modelscope/models/multi_modal/videocomposer/ops/utils.py b/modelscope/models/multi_modal/videocomposer/ops/utils.py
new file mode 100644
index 00000000..f9aadc15
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/ops/utils.py
@@ -0,0 +1,1037 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import base64
+import binascii
+import copy
+import glob
+import gzip
+import hashlib
+import logging
+import math
+import os
+import os.path as osp
+import pickle
+import sys
+import time
+import urllib.request
+import zipfile
+from io import BytesIO
+from multiprocessing.pool import ThreadPool as Pool
+
+import imageio
+import json
+import numpy as np
+import oss2 as oss
+import requests
+import skvideo.io
+import torch
+import torch.nn.functional as F
+import torchvision.utils as tvutils
+from einops import rearrange
+from PIL import Image
+
+from modelscope.models.multi_modal.videocomposer.autoencoder import \
+ DiagonalGaussianDistribution
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+__all__ = [
+ 'parse_oss_url', 'parse_bucket', 'read', 'read_image', 'read_gzip',
+ 'ceil_divide', 'to_device', 'put_object', 'put_torch_object',
+ 'put_object_from_file', 'get_object', 'get_object_to_file', 'rand_name',
+ 'save_image', 'save_video', 'save_video_vs_conditions',
+ 'save_video_multiple_conditions_with_data',
+ 'save_video_multiple_conditions', 'download_video_to_file',
+ 'save_video_grid_mp4', 'save_caps', 'ema', 'parallel', 'exists',
+ 'download', 'unzip', 'load_state_dict', 'inverse_indices',
+ 'detect_duplicates', 'read_tfs', 'md5', 'rope', 'format_state',
+ 'breakup_grid', 'huggingface_tokenizer', 'huggingface_model'
+]
+
+TFS_CLIENT = None
+
+
+def rand_name(length=8, suffix=''):
+ name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
+ if suffix:
+ if not suffix.startswith('.'):
+ suffix = '.' + suffix
+ name += suffix
+ return name
+
+
+def save_with_model_kwargs(model_kwargs, video_data, autoencoder, ori_video,
+ viz_num, step, caps, palette, cfg):
+ scale_factor = 0.18215
+ video_data = 1. / scale_factor * video_data
+
+ bs_vd = video_data.shape[0]
+ video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
+ chunk_size = min(16, video_data.shape[0])
+ video_data_list = torch.chunk(
+ video_data, video_data.shape[0] // chunk_size, dim=0)
+ decode_data = []
+ for vd_data in video_data_list:
+ tmp = autoencoder.decode(vd_data)
+ decode_data.append(tmp)
+ video_data = torch.cat(decode_data, dim=0)
+ video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b=bs_vd)
+ ori_video = ori_video[:viz_num]
+
+ oss_key = os.path.join(cfg.log_dir, 'rank.gif')
+ text_key = osp.join(cfg.log_dir, 'text_description.txt')
+
+ if not os.path.exists(cfg.log_dir):
+ os.mkdir(cfg.log_dir)
+
+ # Save videos and text inputs.
+ try:
+ del model_kwargs[0][list(model_kwargs[0].keys())[0]]
+ del model_kwargs[1][list(model_kwargs[1].keys())[0]]
+
+ save_video_multiple_conditions(
+ oss_key,
+ video_data,
+ model_kwargs,
+ ori_video,
+ palette,
+ cfg.mean,
+ cfg.std,
+ nrow=1,
+ save_origin_video=cfg.save_origin_video)
+
+ texts = '\n'.join(caps[:viz_num])
+ open(text_key, 'w').writelines(texts)
+ except Exception as e:
+ logger.error(f'Save text or video error. {e}')
+
+ logger.info(f'Save videos to {oss_key}')
+
+
+def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition):
+ for partial_key in partial_keys:
+ assert partial_key in [
+ 'y', 'depth', 'canny', 'masked', 'sketch', 'image', 'motion',
+ 'local_image', 'single_sketch'
+ ]
+
+ if use_fps_condition is True:
+ partial_keys.append('fps')
+
+ partial_model_kwargs = [{}, {}]
+ for partial_key in partial_keys:
+ partial_model_kwargs[0][partial_key] = full_model_kwargs[0][
+ partial_key]
+ partial_model_kwargs[1][partial_key] = full_model_kwargs[1][
+ partial_key]
+
+ return partial_model_kwargs
+
+
+@torch.no_grad()
+def get_first_stage_encoding(encoder_posterior):
+ scale_factor = 0.18215
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+ )
+ return scale_factor * z
+
+
+def make_masked_images(imgs, masks):
+ masked_imgs = []
+ for i, mask in enumerate(masks):
+ # concatenation
+ masked_imgs.append(
+ torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1))
+ return torch.stack(masked_imgs, dim=0)
+
+
+def DOWNLOAD_TO_CACHE(oss_key,
+ file_or_dirname=None,
+ cache_dir=osp.join(
+ '/'.join(osp.abspath(__file__).split('/')[:-2]),
+ 'model_weights')):
+ r"""Download OSS [file or folder] to the cache folder.
+ Only the 0th process on each node will run the downloading.
+ Barrier all processes until the downloading is completed.
+ """
+ # source and target paths
+ base_path = osp.join(cache_dir, file_or_dirname or osp.basename(oss_key))
+
+ return base_path
+
+
+def parse_oss_url(path):
+ if path.startswith('oss://'):
+ path = path[len('oss://'):]
+
+ # configs
+ configs = {
+ 'endpoint': os.getenv('OSS_ENDPOINT', None),
+ 'accessKeyID': os.getenv('OSS_ACCESS_KEY_ID', None),
+ 'accessKeySecret': os.getenv('OSS_ACCESS_KEY_SECRET', None),
+ 'securityToken': os.getenv('OSS_SECURITY_TOKEN', None)
+ }
+ bucket, path = path.split('/', maxsplit=1)
+ if '?' in bucket:
+ bucket, config = bucket.split('?', maxsplit=1)
+ for pair in config.split('&'):
+ k, v = pair.split('=', maxsplit=1)
+ configs[k] = v
+
+ # session
+ session = parse_oss_url._sessions.setdefault(f'{bucket}@{os.getpid()}',
+ oss.Session())
+
+ # bucket
+ bucket = oss.Bucket(
+ auth=oss.Auth(configs['accessKeyID'], configs['accessKeySecret']),
+ endpoint=configs['endpoint'],
+ bucket_name=bucket,
+ session=session)
+ return bucket, path
+
+
+parse_oss_url._sessions = {}
+
+
+def parse_bucket(url):
+ return parse_oss_url(osp.join(url, '_placeholder'))[0]
+
+
+def read(filename, mode='r', retry=5):
+ assert mode in ['r', 'rb']
+ exception = None
+ for _ in range(retry):
+ try:
+ if filename.startswith('oss://'):
+ bucket, path = parse_oss_url(filename)
+ content = bucket.get_object(path).read()
+ if mode == 'r':
+ content = content.decode('utf-8')
+ elif filename.startswith('http'):
+ content = requests.get(filename).content
+ if mode == 'r':
+ content = content.decode('utf-8')
+ else:
+ with open(filename, mode=mode) as f:
+ content = f.read()
+ return content
+ except Exception as e:
+ exception = e
+ continue
+ else:
+ raise exception
+
+
+def read_image(filename, retry=5):
+ exception = None
+ for _ in range(retry):
+ try:
+ return Image.open(BytesIO(read(filename, mode='rb', retry=retry)))
+ except Exception as e:
+ exception = e
+ continue
+ else:
+ raise exception
+
+
+def download_video_to_file(filename, local_file, retry=5):
+ exception = None
+ for _ in range(retry):
+ try:
+ bucket, path = parse_oss_url(filename)
+ bucket.get_object_to_file(path, local_file)
+ break
+ except Exception as e:
+ exception = e
+ continue
+ else:
+ raise exception
+
+
+def read_gzip(filename, retry=5):
+ exception = None
+ for _ in range(retry):
+ try:
+ remove = False
+ if filename.startswith('oss://'):
+ bucket, path = parse_oss_url(filename)
+ filename = rand_name(suffix=osp.splitext(filename)[1])
+ bucket.get_object_to_file(path, filename)
+ remove = True
+ with gzip.open(filename) as f:
+ content = f.read()
+ if remove:
+ os.remove(filename)
+ return content
+ except Exception as e:
+ exception = e
+ continue
+ else:
+ raise exception
+
+
+def ceil_divide(a, b):
+ return int(math.ceil(a / b))
+
+
+def to_device(batch, device, non_blocking=False):
+ if isinstance(batch, (list, tuple)):
+ return type(batch)([to_device(u, device, non_blocking) for u in batch])
+ elif isinstance(batch, dict):
+ return type(batch)([(k, to_device(v, device, non_blocking))
+ for k, v in batch.items()])
+ elif isinstance(batch, torch.Tensor) and batch.device != device:
+ batch = batch.to(device, non_blocking=non_blocking)
+ return batch
+
+
+def put_object(bucket, oss_key, data, retry=5):
+ exception = None
+ for _ in range(retry):
+ try:
+ return bucket.put_object(oss_key, data)
+ except Exception as e:
+ exception = e
+ continue
+ else:
+ logger.info(
+ f'put_object to {oss_key} failed with error: {exception}',
+ flush=True)
+
+
+def put_torch_object(bucket, oss_key, data, retry=5):
+ exception = None
+ for _ in range(retry):
+ try:
+ buffer = BytesIO()
+ torch.save(data, buffer)
+ return bucket.put_object(oss_key, buffer.getvalue())
+ except Exception as e:
+ exception = e
+ continue
+ else:
+ logger.info(
+ f'put_torch_object to {oss_key} failed with error: {exception}',
+ flush=True)
+
+
+def put_object_from_file(bucket, oss_key, filename, retry=5):
+ exception = None
+ for _ in range(retry):
+ try:
+ return bucket.put_object_from_file(oss_key, filename)
+ except Exception as e:
+ exception = e
+ continue
+ else:
+ logger.error(
+ f'put_object_from_file to {oss_key} failed with error: {exception}',
+ flush=True)
+
+
+def get_object(bucket, oss_key, retry=5):
+ exception = None
+ for _ in range(retry):
+ try:
+ return bucket.get_object(oss_key).read()
+ except Exception as e:
+ exception = e
+ continue
+ else:
+ logger.error(
+ f'get_object from {oss_key} failed with error: {exception}',
+ flush=True)
+
+
+def get_object_to_file(bucket, oss_key, filename, retry=5):
+ exception = None
+ for _ in range(retry):
+ try:
+ return bucket.get_object_to_file(oss_key, filename)
+ except Exception as e:
+ exception = e
+ continue
+ else:
+ logger.error(
+ f'get_object_to_file from {oss_key} failed with error: {exception}',
+ flush=True)
+
+
+@torch.no_grad()
+def save_image(bucket,
+ oss_key,
+ tensor,
+ nrow=8,
+ normalize=True,
+ range=(-1, 1),
+ retry=5):
+ filename = rand_name(suffix='.jpg')
+ for _ in [None] * retry:
+ try:
+ tvutils.save_image(
+ tensor, filename, nrow=nrow, normalize=normalize, range=range)
+ bucket.put_object_from_file(oss_key, filename)
+ exception = None
+ break
+ except Exception as e:
+ exception = e
+ continue
+
+ # remove temporary file
+ if osp.exists(filename):
+ os.remove(filename)
+ if exception is not None:
+ logger.error(
+ 'save image to {} failed, error: {}'.format(oss_key, exception),
+ flush=True)
+
+
+@torch.no_grad()
+def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):
+ tensor = tensor.permute(1, 2, 3, 0)
+ images = tensor.unbind(dim=0)
+ images = [(image.numpy() * 255).astype('uint8') for image in images]
+ imageio.mimwrite(path, images, duration=125)
+ return images
+
+
+@torch.no_grad()
+def save_video(bucket,
+ oss_key,
+ tensor,
+ mean=[0.5, 0.5, 0.5],
+ std=[0.5, 0.5, 0.5],
+ nrow=8,
+ retry=5):
+ mean = torch.tensor(mean, device=tensor.device).view(1, -1, 1, 1, 1)
+ std = torch.tensor(std, device=tensor.device).view(1, -1, 1, 1, 1)
+ tensor = tensor.mul_(std).add_(mean)
+ tensor.clamp_(0, 1)
+
+ filename = rand_name(suffix='.gif')
+ for _ in [None] * retry:
+ try:
+ one_gif = rearrange(
+ tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
+ video_tensor_to_gif(one_gif, filename)
+ bucket.put_object_from_file(oss_key, filename)
+ exception = None
+ break
+ except Exception as e:
+ exception = e
+ continue
+
+ # remove temporary file
+ if osp.exists(filename):
+ os.remove(filename)
+ if exception is not None:
+ logger.error(
+ 'save video to {} failed, error: {}'.format(oss_key, exception),
+ flush=True)
+
+
+@torch.no_grad()
+def save_video_multiple_conditions(oss_key,
+ video_tensor,
+ model_kwargs,
+ source_imgs,
+ palette,
+ mean=[0.5, 0.5, 0.5],
+ std=[0.5, 0.5, 0.5],
+ nrow=8,
+ retry=5,
+ save_origin_video=True,
+ bucket=None):
+ mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
+ std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
+ video_tensor = video_tensor.mul_(std).add_(mean)
+ try:
+ video_tensor.clamp_(0, 1)
+ except Exception as e:
+ logger.error(e)
+ video_tensor = video_tensor.float().clamp_(0, 1)
+ video_tensor = video_tensor.cpu()
+
+ b, c, n, h, w = video_tensor.shape
+ source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
+ source_imgs = source_imgs.cpu()
+
+ model_kwargs_channel3 = {}
+ for key, conditions in model_kwargs[0].items():
+ if conditions.shape[-1] == 1024:
+ # Skip for style embeding
+ continue
+ if len(conditions.shape) == 3:
+ conditions_np = conditions.cpu().numpy()
+ conditions = []
+ for i in conditions_np:
+ vis_i = []
+ for j in i:
+ vis_i.append(
+ palette.get_palette_image(
+ j, percentile=90, width=256, height=256))
+ conditions.append(np.stack(vis_i))
+ conditions = torch.from_numpy(np.stack(conditions))
+ conditions = rearrange(conditions, 'b n h w c -> b c n h w')
+ else:
+ if conditions.size(1) == 1:
+ conditions = torch.cat([conditions, conditions, conditions],
+ dim=1)
+ conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
+ if conditions.size(1) == 2:
+ conditions = torch.cat([conditions, conditions[:, :1, ]],
+ dim=1)
+ conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
+ elif conditions.size(1) == 3:
+ conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
+ elif conditions.size(1) == 4:
+ color = ((conditions[:, 0:3] + 1.) / 2.)
+ alpha = conditions[:, 3:4]
+ conditions = color * alpha + 1.0 * (1.0 - alpha)
+ conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
+ model_kwargs_channel3[key] = conditions.cpu(
+ ) if conditions.is_cuda else conditions
+
+ filename = oss_key
+ for _ in [None] * retry:
+ try:
+ vid_gif = rearrange(
+ video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
+ cons_list = [
+ rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
+ for _, con in model_kwargs_channel3.items()
+ ]
+ source_imgs = rearrange(
+ source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
+
+ if save_origin_video:
+ vid_gif = torch.cat(
+ [
+ source_imgs,
+ ] + cons_list + [
+ vid_gif,
+ ], dim=3)
+ else:
+ vid_gif = torch.cat(
+ cons_list + [
+ vid_gif,
+ ], dim=3)
+
+ video_tensor_to_gif(vid_gif, filename)
+ exception = None
+ break
+ except Exception as e:
+ exception = e
+ continue
+ if exception is not None:
+ logging.info('save video to {} failed, error: {}'.format(
+ oss_key, exception))
+
+
@torch.no_grad()
def save_video_multiple_conditions_with_data(bucket,
                                             video_save_key,
                                             gt_video_save_key,
                                             vis_oss_key,
                                             video_tensor,
                                             model_kwargs,
                                             source_imgs,
                                             palette,
                                             mean=[0.5, 0.5, 0.5],
                                             std=[0.5, 0.5, 0.5],
                                             nrow=8,
                                             retry=5):
    """Upload a visualization GIF plus raw prediction/GT pickles to OSS.

    The GIF stitches source frames, every condition in ``model_kwargs[0]``
    (converted to 3 channels) and the de-normalized prediction side by side.

    Args:
        bucket: OSS bucket object exposing ``put_object_from_file``.
        video_save_key / gt_video_save_key: OSS keys for prediction / GT pickles.
        vis_oss_key: OSS key for the visualization GIF.
        video_tensor: predicted video [B, C, N, H, W], normalized by mean/std.
        model_kwargs: list whose first element maps condition names to tensors.
        source_imgs: ground-truth video tensor.
        palette: helper exposing ``get_palette_image`` for 3-dim conditions.
        mean, std: de-normalization statistics.
        nrow: grid rows used when tiling the batch.
        retry: attempts per uploaded artifact.
    """
    # de-normalize the prediction in place and clamp to [0, 1]
    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
    video_tensor = video_tensor.mul_(std).add_(mean)
    video_tensor.clamp_(0, 1)

    b, c, n, h, w = video_tensor.shape
    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
    source_imgs = source_imgs.cpu()

    # convert every condition to a 3-channel video [B, 3, N, H, W]
    model_kwargs_channel3 = {}
    for key, conditions in model_kwargs[0].items():
        if len(conditions.shape) == 3:
            # palette-style condition: render each frame through the palette
            conditions_np = conditions.cpu().numpy()
            conditions = []
            for i in conditions_np:
                vis_i = []
                for j in i:
                    vis_i.append(
                        palette.get_palette_image(
                            j, percentile=90, width=256, height=256))
                conditions.append(np.stack(vis_i))
            conditions = torch.from_numpy(np.stack(conditions))
            conditions = rearrange(conditions, 'b n h w c -> b c n h w')
        else:
            if conditions.size(1) == 1:
                # single channel -> replicate to RGB
                conditions = torch.cat([conditions, conditions, conditions],
                                       dim=1)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            if conditions.size(1) == 2:
                # two channels -> duplicate the first as the third
                conditions = torch.cat([conditions, conditions[:, :1, ]],
                                       dim=1)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            elif conditions.size(1) == 3:
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            elif conditions.size(1) == 4:
                # RGBA -> composite over white
                color = ((conditions[:, 0:3] + 1.) / 2.)
                alpha = conditions[:, 3:4]
                conditions = color * alpha + 1.0 * (1.0 - alpha)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        model_kwargs_channel3[key] = conditions.cpu(
        ) if conditions.is_cuda else conditions

    # keep CPU copies for the raw pickle dumps below
    copy_video_tensor = video_tensor.clone()
    copy_source_imgs = source_imgs.clone()

    filename = rand_name(suffix='.gif')
    for _ in [None] * retry:
        try:
            vid_gif = rearrange(
                video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            cons_list = [
                rearrange(con, '(i j) c f h w -> c f (i h) (j w)', j=nrow)
                for _, con in model_kwargs_channel3.items()
            ]
            source_imgs = rearrange(
                source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            vid_gif = torch.cat(
                [
                    source_imgs,
                ] + cons_list + [
                    vid_gif,
                ], dim=3)

            video_tensor_to_gif(vid_gif, filename)
            bucket.put_object_from_file(vis_oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            continue

    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)

    filename_pred = rand_name(suffix='.pkl')
    for _ in [None] * retry:
        try:
            copy_video_np = (copy_video_tensor.numpy() * 255).astype('uint8')
            # context manager closes the handle even when the upload fails
            # (the original leaked the file object from ``open``)
            with open(filename_pred, 'wb') as f:
                pickle.dump(copy_video_np, f)
            bucket.put_object_from_file(video_save_key, filename_pred)
            break
        except Exception as e:
            # lazy %-formatting; the original passed bare positional args to
            # a message without placeholders, so they were never rendered
            logger.error('error! %s %s', video_save_key, e)
            continue

    # remove temporary file
    if osp.exists(filename_pred):
        os.remove(filename_pred)

    filename_gt = rand_name(suffix='.pkl')
    for _ in [None] * retry:
        try:
            copy_source_np = (copy_source_imgs.numpy() * 255).astype('uint8')
            with open(filename_gt, 'wb') as f:
                pickle.dump(copy_source_np, f)
            bucket.put_object_from_file(gt_video_save_key, filename_gt)
            break
        except Exception as e:
            logger.error('error! %s %s', gt_video_save_key, e)
            continue

    # remove temporary file
    if osp.exists(filename_gt):
        os.remove(filename_gt)

    if exception is not None:
        # logging.Logger.error does not accept print()'s ``flush`` kwarg
        logger.error('save video to {} failed, error: {}'.format(
            vis_oss_key, exception))
+
+
@torch.no_grad()
def save_video_vs_conditions(bucket,
                             oss_key,
                             video_tensor,
                             conditions,
                             source_imgs,
                             mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5],
                             nrow=8,
                             retry=5):
    """Upload a GIF stacking prediction, conditions and source frames.

    Args:
        bucket: OSS bucket exposing ``put_object_from_file``.
        oss_key: destination key of the GIF.
        video_tensor: prediction [B, C, N, H, W], normalized by mean/std.
        conditions: condition video; 1-channel inputs are replicated to RGB.
        source_imgs: ground-truth frames, pooled to the prediction size.
        mean, std: de-normalization statistics.
        nrow: grid rows used when tiling the batch.
        retry: upload attempts.
    """
    # de-normalize the prediction in place and clamp to [0, 1]
    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
    video_tensor = video_tensor.mul_(std).add_(mean)
    video_tensor.clamp_(0, 1)

    b, c, n, h, w = video_tensor.shape
    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
    source_imgs = source_imgs.cpu()

    if conditions.size(1) == 1:
        conditions = torch.cat([conditions, conditions, conditions], dim=1)
    conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))

    filename = rand_name(suffix='.gif')
    for _ in [None] * retry:
        try:
            vid_gif = rearrange(
                video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            con_gif = rearrange(
                conditions, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            source_imgs = rearrange(
                source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            # stack vertically: prediction on top, then conditions, then GT
            vid_gif = torch.cat([vid_gif, con_gif, source_imgs], dim=2)

            video_tensor_to_gif(vid_gif, filename)
            bucket.put_object_from_file(oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            logger.error(exception)
            continue

    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)
    if exception is not None:
        # logging.Logger.error does not accept print()'s ``flush`` kwarg
        logger.error(
            'save video to {} failed, error: {}'.format(oss_key, exception))
+
+
@torch.no_grad()
def save_video_grid_mp4(bucket,
                        oss_key,
                        tensor,
                        mean=[0.5, 0.5, 0.5],
                        std=[0.5, 0.5, 0.5],
                        nrow=None,
                        fps=5,
                        retry=5):
    """Tile a batch of videos into a padded grid and upload it as an MP4.

    Args:
        bucket: OSS bucket exposing ``put_object_from_file``.
        oss_key: destination key of the MP4.
        tensor: videos [B, C, T, H, W], normalized by mean/std.
        mean, std: de-normalization statistics.
        nrow: grid rows; defaults to ceil(sqrt(B)).
        fps: output frame rate.
        retry: upload attempts.
    """
    # de-normalize in place, move to uint8 HWC frames
    mean = torch.tensor(mean, device=tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=tensor.device).view(1, -1, 1, 1, 1)
    tensor = tensor.mul_(std).add_(mean)
    tensor.clamp_(0, 1)
    b, c, t, h, w = tensor.shape
    tensor = tensor.permute(0, 2, 3, 4, 1)
    tensor = (tensor.cpu().numpy() * 255).astype('uint8')

    filename = rand_name(suffix='.mp4')
    for _ in [None] * retry:
        try:
            if nrow is None:
                nrow = math.ceil(math.sqrt(b))
            ncol = math.ceil(b / nrow)
            padding = 1  # 1-pixel black border between grid cells
            video_grid = np.zeros((t, (padding + h) * nrow + padding,
                                   (padding + w) * ncol + padding, c),
                                  dtype='uint8')
            for i in range(b):
                r = i // ncol
                c_ = i % ncol

                start_r = (padding + h) * r
                start_c = (padding + w) * c_
                video_grid[:, start_r:start_r + h,
                           start_c:start_c + w] = tensor[i]
            skvideo.io.vwrite(filename, video_grid, inputdict={'-r': str(fps)})

            bucket.put_object_from_file(oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            logger.error(exception)
            continue

    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)
    if exception is not None:
        # logging.Logger.error does not accept print()'s ``flush`` kwarg
        logger.error(
            'save video to {} failed, error: {}'.format(oss_key, exception))
+
+
@torch.no_grad()
def save_text(bucket, oss_key, tensor, nrow=8, retry=5):
    """Decode a batch of byte-encoded texts and upload them to OSS.

    Args:
        bucket: OSS bucket exposing ``put_object``.
        oss_key: destination key.
        tensor: batch of serialized texts decodable by ``dec_bytes2obj``.
        nrow: number of groups; the batch size must be divisible by it.
        retry: upload attempts.
    """
    num = tensor.shape[0]  # renamed: the original shadowed the builtin ``len``
    num_per_row = int(num / nrow)
    assert (num == nrow * num_per_row)
    texts = ''
    for i in range(nrow):
        for j in range(num_per_row):
            text = dec_bytes2obj(tensor[i * num_per_row + j])
            texts += text + '\n'
        texts += '\n'

    for _ in [None] * retry:
        try:
            bucket.put_object(oss_key, texts)
            exception = None
            break
        except Exception as e:
            exception = e
            logger.error(exception)
            continue
    if exception is not None:
        # Logger.error has no ``flush`` kwarg; also say "text", not "video"
        logger.error(
            'save text to {} failed, error: {}'.format(oss_key, exception))
+
+
@torch.no_grad()
def save_caps(bucket, oss_key, caps, retry=5):
    """Upload a list of caption strings to OSS, one caption per line.

    Args:
        bucket: OSS bucket exposing ``put_object``.
        oss_key: destination key.
        caps: iterable of caption strings.
        retry: upload attempts.
    """
    texts = ''
    for cap in caps:
        texts += cap
        texts += '\n'

    for _ in [None] * retry:
        try:
            bucket.put_object(oss_key, texts)
            exception = None
            break
        except Exception as e:
            exception = e
            logger.error(exception)
            continue
    if exception is not None:
        # Logger.error has no ``flush`` kwarg; also say "captions", not "video"
        logger.error(
            'save captions to {} failed, error: {}'.format(oss_key, exception))
+
+
@torch.no_grad()
def ema(net_ema, net, beta, copy_buffer=False):
    """Update ``net_ema`` in place as an exponential moving average of ``net``.

    Each EMA parameter becomes ``beta * p_ema + (1 - beta) * p``; buffers are
    copied verbatim when ``copy_buffer`` is set.
    """
    assert 0.0 <= beta <= 1.0
    for target, source in zip(net_ema.parameters(), net.parameters()):
        target.copy_(source.lerp(target, beta))
    if copy_buffer:
        for target, source in zip(net_ema.buffers(), net.buffers()):
            target.copy_(source)
+
+
def parallel(func, args_list, num_workers=32, timeout=None):
    """Apply *func* over *args_list*, optionally via a multiprocessing pool.

    Args:
        func: callable applied to each argument tuple.
        args_list: list of argument tuples (bare values are wrapped in 1-tuples).
        num_workers: pool size; 0 runs everything in the current process.
        timeout: per-result timeout passed to ``AsyncResult.get``.

    Returns:
        list of results in input order.
    """
    assert isinstance(args_list, list)
    if not args_list:
        # the original indexed args_list[0] and crashed on an empty list
        return []
    if not isinstance(args_list[0], tuple):
        args_list = [(args, ) for args in args_list]
    if num_workers == 0:
        return [func(*args) for args in args_list]
    with Pool(processes=num_workers) as pool:
        futures = [pool.apply_async(func, args) for args in args_list]
        return [fut.get(timeout=timeout) for fut in futures]
+
+
def exists(filename):
    """Return True when *filename* exists, for both OSS URLs and local paths."""
    if not filename.startswith('oss://'):
        return osp.exists(filename)
    bucket, path = parse_oss_url(filename)
    return bucket.object_exists(path)
+
+
def download(url, filename=None, replace=False, quiet=False):
    """Download *url* (OSS or HTTP) to *filename* and return its absolute path.

    Skips the download when the file already exists, unless ``replace``.

    Raises:
        ValueError: when the download fails for any reason.
    """
    if filename is None:
        filename = osp.basename(url)
    if replace or not osp.exists(filename):
        try:
            if url.startswith('oss://'):
                bucket, oss_key = parse_oss_url(url)
                bucket.get_object_to_file(oss_key, filename)
            else:
                urllib.request.urlretrieve(url, filename)
            if not quiet:
                # the original logged a literal "(unknown)" placeholder and
                # passed a ``flush`` kwarg that Logger methods do not accept
                logger.info('Downloaded %s to %s', url, filename)
        except Exception as e:
            raise ValueError(
                f'Downloading {url} failed with error {e}') from e
    return osp.abspath(filename)
+
+
def unzip(filename, dst_dir=None):
    """Extract zip archive *filename* into *dst_dir* (defaults to its folder)."""
    target = osp.dirname(filename) if dst_dir is None else dst_dir
    with zipfile.ZipFile(filename, 'r') as archive:
        archive.extractall(target)
+
+
def load_state_dict(module, state_dict, drop_prefix=''):
    """Load *state_dict* into *module*, skipping incompatible entries.

    Keys are optionally stripped of *drop_prefix*; entries that are missing,
    unexpected, or shape-mismatched are reported via the logger instead of
    raising.
    """
    # find incompatible key-vals
    src, dst = state_dict, module.state_dict()
    if drop_prefix:
        src = type(src)([
            (k[len(drop_prefix):] if k.startswith(drop_prefix) else k, v)
            for k, v in src.items()
        ])
    missing = [k for k in dst if k not in src]
    unexpected = [k for k in src if k not in dst]
    unmatched = [
        k for k in src.keys() & dst.keys() if src[k].shape != dst[k].shape
    ]

    # keep only compatible key-vals
    incompatible = set(unexpected + unmatched)
    src = type(src)([(k, v) for k, v in src.items() if k not in incompatible])
    module.load_state_dict(src, strict=False)

    # report incompatible key-vals (Logger.info has no ``flush`` kwarg; the
    # original passed one and raised TypeError on any mismatch)
    if len(missing) != 0:
        logger.info(' Missing: ' + ', '.join(missing))
    if len(unexpected) != 0:
        logger.info(' Unexpected: ' + ', '.join(unexpected))
    if len(unmatched) != 0:
        logger.info(' Shape unmatched: ' + ', '.join(unmatched))
+
+
def inverse_indices(indices):
    r"""Inverse map of indices.
    E.g., if A[indices] == B, then B[inverse_indices(indices)] == A.
    """
    inverse = torch.empty_like(indices)
    positions = torch.arange(len(indices)).to(indices)
    inverse[indices] = positions
    return inverse
+
+
def detect_duplicates(feats, thr=0.9):
    """Return indices of rows in *feats* to keep after dropping near-duplicates.

    A row is dropped when any earlier row has cosine similarity above *thr*.

    Args:
        feats: 2-D feature matrix [N, D].
        thr: cosine-similarity threshold.

    Returns:
        1-D LongTensor of kept row indices.
    """
    assert feats.ndim == 2

    # cosine similarity matrix; keep only the strictly-upper triangle so each
    # pair is counted once and a row can only be suppressed by earlier rows
    feats = F.normalize(feats, p=2, dim=1)
    simmat = torch.mm(feats, feats.T)
    simmat.triu_(1)
    if feats.is_cuda:
        # the original synchronized unconditionally, crashing on CPU-only runs
        torch.cuda.synchronize()

    # detect duplicates
    mask = ~simmat.gt(thr).any(dim=0)
    return torch.where(mask)[0]
+
+
class TFSClient(object):
    """Minimal TFS (restful-store) client with round-robin server selection."""

    def __init__(self,
                 host='restful-store.vip.tbsite.net:3800',
                 app_key='5354c9fae75f5'):
        self.host = host
        self.app_key = app_key

        # candidate servers: the first line of url.list is a header, and only
        # "host:port"-shaped entries are usable
        entries = read(f'http://{host}/url.list').strip().split('\n')[1:]
        self.servers = [entry for entry in entries if ':' in entry]
        assert len(self.servers) >= 1
        self.__server_id = -1

    @property
    def server(self):
        # rotate through the candidate servers on every access
        self.__server_id = (self.__server_id + 1) % len(self.servers)
        return self.servers[self.__server_id]

    def read(self, tfs):
        """Fetch object *tfs*: read its metadata for the size, then the bytes,
        and decode them as an image."""
        key = osp.basename(tfs)
        meta_url = f'http://{self.server}/v1/{self.app_key}/metadata/{key}?force=0'
        meta = json.loads(read(meta_url))
        data_url = f'http://{self.server}/v1/{self.app_key}/{key}?offset=0&size={meta["SIZE"]}'
        return Image.open(BytesIO(read(data_url, 'rb')))
+
+
def read_tfs(tfs, retry=5):
    """Read *tfs* through a lazily-created shared TFSClient, with retries.

    Raises the last captured exception when all *retry* attempts fail.
    """
    global TFS_CLIENT
    exception = None
    for _ in range(retry):
        try:
            if TFS_CLIENT is None:
                TFS_CLIENT = TFSClient()
            return TFS_CLIENT.read(tfs)
        except Exception as e:
            exception = e
    raise exception
+
+
def md5(filename):
    """Return the hex MD5 digest of the file at *filename*."""
    digest = hashlib.md5()
    with open(filename, 'rb') as f:
        digest.update(f.read())
    return digest.hexdigest()
+
+
def rope(x):
    r"""Apply rotary position embedding on x of shape [B, *(spatial dimensions), C].
    """
    # flatten all spatial dimensions into one sequence axis
    original_shape = x.shape
    x = x.view(x.size(0), -1, x.size(-1))
    seq_len, channels = x.shape[-2:]
    assert channels % 2 == 0
    half = channels // 2

    # position-dependent rotation angles
    freqs = torch.pow(10000, -torch.arange(half).to(x).div(half))
    angles = torch.outer(torch.arange(seq_len).to(x), freqs)
    sin, cos = torch.sin(angles), torch.cos(angles)

    # rotate the two channel halves against each other
    first, second = x.chunk(2, dim=-1)
    rotated = torch.cat(
        [first * cos - second * sin, second * cos + first * sin], dim=-1)

    # restore the original layout
    return rotated.view(original_shape)
+
+
def format_state(state, filename=None):
    r"""Render a state_dict as ``name\tshape`` lines for comparing/aligning.

    Writes the text to *filename* when given. Also returns the rendered text
    (the original always returned None, so this is backward compatible and
    makes the no-filename call useful).
    """
    content = '\n'.join([f'{k}\t{tuple(v.shape)}' for k, v in state.items()])
    if filename:
        with open(filename, 'w') as f:
            f.write(content)
    return content
+
+
def breakup_grid(img, grid_size):
    r"""The inverse operator of ``torchvision.utils.make_grid``.

    Crops every ``grid_size`` x ``grid_size`` cell out of a tiled image,
    assuming a 2-pixel gap between cells, and returns them row by row.
    """
    rows = img.height // grid_size
    cols = img.width // grid_size
    pad = 2  # gap between cells, matching the grid this code undoes

    tiles = []
    for r in range(rows):
        y1 = r * grid_size + (r + 1) * pad
        for c in range(cols):
            x1 = c * grid_size + (c + 1) * pad
            tiles.append(img.crop((x1, y1, x1 + grid_size, y1 + grid_size)))
    return tiles
+
+
def huggingface_tokenizer(name='google/mt5-xxl', **kwargs):
    """Load a HuggingFace tokenizer for *name* from the local download cache."""
    from transformers import AutoTokenizer
    cache_path = DOWNLOAD_TO_CACHE(f'huggingface/tokenizers/{name}', name)
    return AutoTokenizer.from_pretrained(cache_path, **kwargs)
+
+
def huggingface_model(name='google/mt5-xxl', model_type='AutoModel', **kwargs):
    """Load HuggingFace model class *model_type* for *name* from the cache."""
    import transformers
    model_cls = getattr(transformers, model_type)
    cache_path = DOWNLOAD_TO_CACHE(f'huggingface/models/{name}', name)
    return model_cls.from_pretrained(cache_path, **kwargs)
diff --git a/modelscope/models/multi_modal/videocomposer/unet_sd.py b/modelscope/models/multi_modal/videocomposer/unet_sd.py
new file mode 100644
index 00000000..23ee29f6
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/unet_sd.py
@@ -0,0 +1,2102 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import math
+import os
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from fairscale.nn.checkpoint import checkpoint_wrapper
+from rotary_embedding_torch import RotaryEmbedding
+from torch import einsum
+
+__all__ = ['UNetSD_temporal']
+
+USE_TEMPORAL_TRANSFORMER = True
+
+
+# load all keys started with prefix and replace them with new_prefix
def load_Block(state, prefix, new_prefix=None):
    """Collect all entries whose key contains *prefix* and rename *prefix*
    to *new_prefix* (defaults to keeping the key unchanged)."""
    if new_prefix is None:
        new_prefix = prefix
    return {
        key.replace(prefix, new_prefix): value
        for key, value in state.items() if prefix in key
    }
+
+
def load_2d_pretrained_state_dict(state, cfg):
    """Remap a 2D (image) UNet checkpoint onto this temporal UNet's layout.

    Walks the encoder/middle/decoder structure implied by *cfg* and copies
    each sub-module's weights under the index it occupies once temporal
    layers are interleaved.

    Args:
        state: source state_dict of the 2D UNet.
        cfg: config exposing ``unet_dim``, ``unet_res_blocks``,
            ``unet_dim_mult``, ``unet_attn_scales`` and ``temporal_attn_times``.

    Returns:
        dict: the remapped state_dict.
    """
    new_state_dict = {}

    dim = cfg.unet_dim
    num_res_blocks = cfg.unet_res_blocks
    dim_mult = cfg.unet_dim_mult
    attn_scales = cfg.unet_attn_scales

    # params
    enc_dims = [dim * u for u in [1] + dim_mult]
    dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
    shortcut_dims = []
    scale = 1.0

    # embeddings
    state_dict = load_Block(state, prefix='time_embedding')
    new_state_dict.update(state_dict)
    state_dict = load_Block(state, prefix='y_embedding')
    new_state_dict.update(state_dict)
    state_dict = load_Block(state, prefix='context_embedding')
    new_state_dict.update(state_dict)

    encoder_idx = 0
    # init block
    state_dict = load_Block(
        state,
        prefix=f'encoder.{encoder_idx}',
        new_prefix=f'encoder.{encoder_idx}.0')
    new_state_dict.update(state_dict)
    encoder_idx += 1

    shortcut_dims.append(dim)
    for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
        for j in range(num_res_blocks):
            # residual (+attention) blocks
            idx = 0
            idx_ = 0
            state_dict = load_Block(
                state,
                prefix=f'encoder.{encoder_idx}.{idx}',
                new_prefix=f'encoder.{encoder_idx}.{idx_}')
            new_state_dict.update(state_dict)
            idx += 1
            idx_ = 2

            if scale in attn_scales:
                # AttentionBlock(out_dim, context_dim, num_heads, head_dim)
                state_dict = load_Block(
                    state,
                    prefix=f'encoder.{encoder_idx}.{idx}',
                    new_prefix=f'encoder.{encoder_idx}.{idx_}')
                new_state_dict.update(state_dict)
            in_dim = out_dim
            encoder_idx += 1
            shortcut_dims.append(out_dim)

            # downsample
            if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
                # BUGFIX: the original passed the literal strings
                # 'encoder.{encoder_idx}' (missing the f-prefix), so the
                # downsample weights were silently never copied.
                state_dict = load_Block(
                    state,
                    prefix=f'encoder.{encoder_idx}',
                    new_prefix=f'encoder.{encoder_idx}.0')
                new_state_dict.update(state_dict)

                shortcut_dims.append(out_dim)
                scale /= 2.0
                encoder_idx += 1

    # middle
    middle_idx = 0

    state_dict = load_Block(state, prefix=f'middle.{middle_idx}')
    new_state_dict.update(state_dict)
    middle_idx += 2

    state_dict = load_Block(
        state, prefix='middle.1', new_prefix=f'middle.{middle_idx}')
    new_state_dict.update(state_dict)
    middle_idx += 1

    # skip the indices occupied by interleaved temporal attention layers
    for _ in range(cfg.temporal_attn_times):
        middle_idx += 1

    state_dict = load_Block(
        state, prefix='middle.2', new_prefix=f'middle.{middle_idx}')
    new_state_dict.update(state_dict)
    middle_idx += 2

    decoder_idx = 0
    for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
        for j in range(num_res_blocks + 1):
            idx = 0
            idx_ = 0
            # residual (+attention) blocks
            state_dict = load_Block(
                state,
                prefix=f'decoder.{decoder_idx}.{idx}',
                new_prefix=f'decoder.{decoder_idx}.{idx_}')
            new_state_dict.update(state_dict)
            idx += 1
            idx_ += 2
            if scale in attn_scales:
                state_dict = load_Block(
                    state,
                    prefix=f'decoder.{decoder_idx}.{idx}',
                    new_prefix=f'decoder.{decoder_idx}.{idx_}')
                new_state_dict.update(state_dict)
                idx += 1
                idx_ += 1
            # temporal attention layers occupy the following indices
            for _ in range(cfg.temporal_attn_times):
                idx_ += 1

            # upsample
            if i != len(dim_mult) - 1 and j == num_res_blocks:
                state_dict = load_Block(
                    state,
                    prefix=f'decoder.{decoder_idx}.{idx}',
                    new_prefix=f'decoder.{decoder_idx}.{idx_}')
                new_state_dict.update(state_dict)
                idx += 1
                idx_ += 2

            scale *= 2.0
            decoder_idx += 1

    state_dict = load_Block(state, prefix='head')
    new_state_dict.update(state_dict)

    return new_state_dict
+
+
def sinusoidal_embedding(timesteps, dim):
    """Transformer-style sinusoidal timestep embedding: [cos | sin] halves,
    zero-padded by one column when *dim* is odd."""
    half = dim // 2
    timesteps = timesteps.float()

    # per-frequency angles for every timestep
    freqs = torch.pow(10000, -torch.arange(half).to(timesteps).div(half))
    angles = torch.outer(timesteps, freqs)
    emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
    if dim % 2 != 0:
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
    return emb
+
+
def exists(x):
    """Return True when *x* is not None."""
    return x is not None


def default(val, d):
    """Return *val* when set; otherwise *d*, calling it first if callable."""
    if val is not None:
        return val
    return d() if callable(d) else d
+
+
def prob_mask_like(shape, prob, device):
    """Boolean mask whose entries are True with probability *prob*.

    Degenerate probabilities short-circuit to all-True / all-False; for
    intermediate values an all-True draw is nudged so at least one entry is
    False (an all-True mask triggers find_unused_parameters errors in DDP).
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
    if mask.all():
        mask[0] = False
    return mask
+
+
class RelativePositionBias(nn.Module):
    """T5-style relative position bias, producing an [heads, n, n] additive
    attention bias from bucketed relative offsets."""

    def __init__(self, heads=8, num_buckets=32, max_distance=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position,
                                  num_buckets=32,
                                  max_distance=128):
        # Half the buckets encode the sign; near offsets get exact buckets,
        # far ones share log-spaced buckets up to max_distance.
        n = -relative_position

        num_buckets //= 2
        bucket = (n < 0).long() * num_buckets
        n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        log_ratio = torch.log(n.float() / max_exact) / math.log(
            max_distance / max_exact)
        val_if_large = max_exact + (log_ratio * (num_buckets
                                                 - max_exact)).long()
        val_if_large = torch.min(
            val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        return bucket + torch.where(is_small, n, val_if_large)

    def forward(self, n, device):
        q_pos = torch.arange(n, dtype=torch.long, device=device)
        k_pos = torch.arange(n, dtype=torch.long, device=device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(
            rel_pos,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> h i j')
+
+
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        # one context dim per transformer depth
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                disable_self_attn=disable_self_attn,
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            # BUGFIX: the original built nn.Linear(in_channels, inner_dim),
            # which maps the wrong direction and only ran when
            # inner_dim == in_channels; the weight shape is identical in that
            # case, so existing checkpoints still load.
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # conv projection happens in NCHW; linear projection after flattening
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        # residual connection around the whole transformer
        return x + x_in
+
+
# Precision used for the attention similarity matmul; set ATTN_PRECISION=fp16
# in the environment to skip the fp32 upcast in CrossAttention.forward.
_ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')
+
+
class CrossAttention(nn.Module):
    """Multi-head attention: cross-attention over ``context``, or plain
    self-attention when no context is given."""

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        """x: [B, N, query_dim]; context: [B, M, context_dim] or None;
        mask: boolean per-key mask, True = attend."""
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                      (q, k, v))

        # force cast to fp32 to avoid overflowing
        if _ATTN_PRECISION == 'fp32':
            with torch.autocast(enabled=False, device_type='cuda'):
                q, k = q.float(), k.float()
                sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        else:
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        del q, k

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            # BUGFIX: the original used einops ``repeat`` here, but this
            # module only imports ``rearrange``, so providing a mask raised
            # NameError. Replicate each batch row once per head instead
            # (b-major ordering matches '(b h)').
            mask = mask.repeat_interleave(h, dim=0).unsqueeze(1)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = torch.einsum('b i j, b j d -> b i d', sim, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)
+
+
class BasicTransformerBlock(nn.Module):
    """Standard transformer block: self-attn -> cross-attn -> feed-forward.

    Each sub-layer is pre-normed (LayerNorm before the op) and wrapped in a
    residual connection. When ``disable_self_attn`` is True, attn1 attends to
    ``context`` instead of performing pure self-attention.
    """

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        attn_cls = CrossAttention
        self.disable_self_attn = disable_self_attn
        # attn1: self-attention, unless disable_self_attn routes context here
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # attn2: cross-attention over ``context`` (self-attention when None)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward_(self, x, context=None):
        # NOTE(review): this path looks like dead code carried over from
        # latent-diffusion — ``checkpoint`` (the function) is never imported
        # in this module (only fairscale's checkpoint_wrapper is) and
        # ``self._forward`` does not exist, so calling it raises NameError /
        # AttributeError. Confirm before relying on it.
        return checkpoint(self._forward, (x, context), self.parameters(),
                          self.checkpoint)

    def forward(self, x, context=None):
        x = self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
+
+
+# feedforward
class GEGLU(nn.Module):
    """Gated GELU feed-forward projection: project to twice the output width
    and gate one half with the GELU of the other."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
+
+
def zero_module(module):
    """Zero every parameter of *module* in place and return the module."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
+
+
class FeedForward(nn.Module):
    """Transformer MLP block: (Linear+GELU or GEGLU) -> Dropout -> Linear."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.net = nn.Sequential(project_in, nn.Dropout(dropout),
                                 nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.net(x)
+
+
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = nn.Conv2d(
                self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # keep the depth axis, double only the two spatial axes
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode='nearest')
        else:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x) if self.use_conv else x
+
+
class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: inject the embedding as a GroupNorm
        scale/shift instead of an additive bias.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param use_temporal_conv: append a temporal conv over the frame axis.
    :param use_image_dataset: forwarded to the temporal conv block.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        use_temporal_conv=True,
        use_image_dataset=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_temporal_conv = use_temporal_conv

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels
                if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            # BUGFIX: the original called the undefined helper ``conv_nd``
            # here (NameError at construction time); this module builds 2D
            # convs everywhere else, so use nn.Conv2d directly.
            self.skip_connection = nn.Conv2d(
                channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)

        if self.use_temporal_conv:
            # keep the historical attribute name (typo included) so that
            # existing checkpoints continue to load
            self.temopral_conv = TemporalConvBlock_v2(
                self.out_channels,
                self.out_channels,
                dropout=0.1,
                use_image_dataset=use_image_dataset)

    def forward(self, x, emb, batch_size):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :param batch_size: videos per batch, used to fold frames back into a
            temporal axis for the temporal conv.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return self._forward(x, emb, batch_size)

    def _forward(self, x, emb, batch_size):
        if self.updown:
            # resample between the first norm/act and the conv
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            # BUGFIX: the original called ``th.chunk`` but ``th`` is never
            # imported in this module (NameError); use torch.chunk.
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        h = self.skip_connection(x) + h

        if self.use_temporal_conv:
            h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
            h = self.temopral_conv(h)
            h = rearrange(h, 'b c f h w -> (b f) c h w')
        return h
+
+
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D inputs keep their depth axis and halve only the spatial axes
        stride = (1, 2, 2) if dims == 3 else 2
        if use_conv:
            self.op = nn.Conv2d(
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding)
        else:
            assert self.channels == self.out_channels
            # NOTE(review): ``avg_pool_nd`` is not defined in the visible
            # part of this module — confirm it exists before constructing
            # with use_conv=False.
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
+
+
class Resample(nn.Module):
    """Spatially resample feature maps: 'upsample' to a reference's size,
    'downsample' by a factor of 2, or pass through unchanged ('none')."""

    def __init__(self, in_dim, out_dim, mode):
        assert mode in ['none', 'upsample', 'downsample']
        super(Resample, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.mode = mode

    def forward(self, x, reference=None):
        if self.mode == 'downsample':
            half = tuple(side // 2 for side in x.shape[-2:])
            return F.adaptive_avg_pool2d(x, output_size=half)
        if self.mode == 'upsample':
            assert reference is not None
            return F.interpolate(x, size=reference.shape[-2:], mode='nearest')
        return x
+
+
class ResidualBlock(nn.Module):
    """Residual block with timestep-embedding conditioning and optional
    resampling.

    Runs ``layer1`` (norm/act, resample, conv), injects the embedding ``e``
    either as a GroupNorm scale/shift (``use_scale_shift_norm``) or as an
    additive bias, applies ``layer2``, and adds a (resampled) shortcut.
    """

    def __init__(self,
                 in_dim,
                 embed_dim,
                 out_dim,
                 use_scale_shift_norm=True,
                 mode='none',
                 dropout=0.0):
        super(ResidualBlock, self).__init__()
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.use_scale_shift_norm = use_scale_shift_norm
        self.mode = mode  # 'none' | 'upsample' | 'downsample' (see Resample)

        # layers
        self.layer1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv2d(in_dim, out_dim, 3, padding=1))
        self.resample = Resample(in_dim, in_dim, mode)
        self.embedding = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim,
                      out_dim * 2 if use_scale_shift_norm else out_dim))
        self.layer2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv2d(out_dim, out_dim, 3, padding=1))
        self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
            in_dim, out_dim, 1)

        # zero out the last layer params so the block starts near-identity
        nn.init.zeros_(self.layer2[-1].weight)

    def forward(self, x, e, reference=None):
        # resample the identity path with the same mode as the main path
        identity = self.resample(x, reference)
        # norm+act first, then resample, then the conv — so the conv sees
        # the resampled feature map
        x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
        e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
        if self.use_scale_shift_norm:
            scale, shift = e.chunk(2, dim=1)
            x = self.layer2[0](x) * (1 + scale) + shift
            x = self.layer2[1:](x)
        else:
            x = x + e
            x = self.layer2(x)
        x = x + self.shortcut(identity)
        return x
+
+
class AttentionBlock(nn.Module):
    """Spatial self-attention over [B, C, H, W] maps, optionally attending to
    an extra [B, L, C] context sequence."""

    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
        # head_dim wins over num_heads when both are given
        num_heads = dim // head_dim if head_dim else num_heads
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(AttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # split the usual 1/sqrt(d) scaling over both q and k
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero-init the output projection weight so the block starts close
        # to an identity mapping
        nn.init.zeros_(self.proj.weight)

    def forward(self, x, context=None):
        r"""x: [B, C, H, W].
        context: [B, L, C] or None.
        """
        identity = x
        b, c, h, w = x.size()
        heads, d = self.num_heads, self.head_dim

        # compute query, key, value
        x = self.norm(x)
        q, k, v = self.to_qkv(x).view(b, heads * 3, d, h * w).chunk(3, dim=1)
        if context is not None:
            # prepend context-derived keys/values to the spatial ones
            ck, cv = self.context_kv(context).reshape(
                b, -1, heads * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)

        # compute attention
        attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
        attn = F.softmax(attn, dim=-1)

        # gather context
        out = torch.matmul(v, attn.transpose(-1, -2))
        out = out.reshape(b, c, h, w)

        # output
        return self.proj(out) + identity
+
+
class TemporalAttentionBlock(nn.Module):
    """Self-attention along the frame axis of a (b, c, f, h, w) video tensor.

    Each spatial location attends over its own frames. Supports relative
    position bias, rotary embeddings, a per-frame validity mask
    (``video_mask``) and a per-sample "focus on present" mask that restricts
    attention to the token itself.
    """

    def __init__(self,
                 dim,
                 heads=4,
                 dim_head=32,
                 rotary_emb=None,
                 use_image_dataset=False,
                 use_sim_mask=False):
        super().__init__()
        # consider num_heads first, as pos_bias needs fixed num_heads
        # (note: this intentionally overrides the dim_head argument)
        dim_head = dim // heads
        assert heads * dim_head == dim
        self.use_image_dataset = use_image_dataset
        self.use_sim_mask = use_sim_mask

        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = nn.GroupNorm(32, dim)
        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, hidden_dim * 3)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self,
                x,
                pos_bias=None,
                focus_present_mask=None,
                video_mask=None):
        """x: (b, c, f, h, w). pos_bias: additive bias over frame pairs.
        focus_present_mask: (b,) bool, True = attend only to own frame.
        video_mask: (b, f) bool validity mask over frames."""

        identity = x
        n, height, device = x.shape[2], x.shape[-2], x.device

        x = self.norm(x)
        # fold spatial locations into the batch; attention runs over frames
        x = rearrange(x, 'b c f h w -> b (h w) f c')

        qkv = self.to_qkv(x).chunk(3, dim=-1)

        if exists(focus_present_mask) and focus_present_mask.all():
            # if all batch samples are focusing on present
            # it would be equivalent to passing that token's values (v=qkv[-1]) through to the output
            values = qkv[-1]
            out = self.to_out(values)
            out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)

            return out + identity

        # split out heads
        q = rearrange(qkv[0], '... n (h d) -> ... h n d', h=self.heads)
        k = rearrange(qkv[1], '... n (h d) -> ... h n d', h=self.heads)
        v = rearrange(qkv[2], '... n (h d) -> ... h n d', h=self.heads)

        # scale
        q = q * self.scale

        # rotate positions into queries and keys for time attention
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        # similarity
        sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)

        # relative positional bias

        if exists(pos_bias):
            sim = sim + pos_bias

        if (focus_present_mask is None and video_mask is not None):
            # mask out attention to/from invalid frames
            mask = video_mask[:, None, :] * video_mask[:, :, None]
            mask = mask.unsqueeze(1).unsqueeze(1)
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
        elif exists(focus_present_mask) and not (~focus_present_mask).all():
            # mixed batch: some samples attend everywhere, others only to self
            attend_all_mask = torch.ones((n, n),
                                         device=device,
                                         dtype=torch.bool)
            attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)

            mask = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
            )

            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        if self.use_sim_mask:
            # causal mask: frame i may only attend to frames <= i
            sim_mask = torch.tril(
                torch.ones((n, n), device=device, dtype=torch.bool),
                diagonal=0)
            sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max)

        # numerical stability
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values
        out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
        out = rearrange(out, '... h n d -> ... n (h d)')
        out = self.to_out(out)

        out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)

        if self.use_image_dataset:
            # keep the graph alive but zero the temporal contribution
            out = identity + 0 * out
        else:
            out = identity + out
        return out
+
+
class TemporalTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image

    Args:
        in_channels: channels of the (b, c, f, h, w) input.
        n_heads / d_head: attention head count and per-head width.
        depth: number of stacked BasicTransformerBlock layers.
        context_dim: cross-attention context width(s); ignored (forced to
            None) when ``only_self_att`` is True.
        use_linear: use Linear instead of Conv1d projections.
        only_self_att: restrict the block to temporal self-attention.
        multiply_zero: if True, the transformer branch is multiplied by
            zero so the block reduces to a residual identity.
        frames: frames per sample; required by the ``use_linear`` and
            cross-attention paths, which rearrange by ``self.frames``.
            (Previously this attribute was never assigned, so those paths
            raised AttributeError; the default keeps the old interface.)
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True,
                 only_self_att=True,
                 multiply_zero=False,
                 frames=None):
        super().__init__()
        self.multiply_zero = multiply_zero
        self.only_self_att = only_self_att
        self.frames = frames
        self.use_adaptor = False
        if self.only_self_att:
            context_dim = None
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv1d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        if self.use_adaptor:
            self.adaptor_in = nn.Linear(frames, frames)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv1d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            # Fix: project back from inner_dim to in_channels, mirroring the
            # Conv1d branch. The previous nn.Linear(in_channels, inner_dim)
            # had the arguments swapped and only worked by accident when
            # inner_dim == in_channels (state-dict shapes are unchanged in
            # that case, so existing checkpoints still load).
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        if self.use_adaptor:
            self.adaptor_out = nn.Linear(frames, frames)
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """x: (b, c, f, h, w); context: cross-attention context tensor(s)
        or None (self-attention)."""
        # note: if no context is given, cross-attention defaults to self-attention
        if self.only_self_att:
            context = None
        if not isinstance(context, list):
            context = [context]
        b, c, f, h, w = x.shape
        x_in = x
        x = self.norm(x)

        if not self.use_linear:
            # each spatial location becomes an independent temporal sequence
            x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
            x = self.proj_in(x)
        if self.use_linear:
            # NOTE(review): this path assumes a 4D '(b f) c h w' input and a
            # valid ``frames`` argument -- confirm before enabling use_linear.
            x = rearrange(
                x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
            x = self.proj_in(x)

        if self.only_self_att:
            x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
            for i, block in enumerate(self.transformer_blocks):
                x = block(x)
            x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
        else:
            x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
            for i, block in enumerate(self.transformer_blocks):
                context[i] = rearrange(
                    context[i], '(b f) l con -> b f l con',
                    f=self.frames).contiguous()
                # calculate each batch one by one
                # (since number in shape could not greater then 65,535 for some package)
                for j in range(b):
                    context_i_j = repeat(
                        context[i][j],
                        'f l con -> (f r) l con',
                        r=(h * w) // self.frames,
                        f=self.frames).contiguous()
                    x[j] = block(x[j], context=context_i_j)

        if self.use_linear:
            x = self.proj_out(x)
            x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
            x = self.proj_out(x)
            x = rearrange(
                x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()

        # residual connection (optionally zeroing the transformer branch)
        if self.multiply_zero:
            x = 0.0 * x + x_in
        else:
            x = x + x_in
        return x
+
+
class TemporalAttentionMultiBlock(nn.Module):
    """Sequential stack of ``temporal_attn_times`` TemporalAttentionBlock
    layers sharing the same configuration."""

    def __init__(
        self,
        dim,
        heads=4,
        dim_head=32,
        rotary_emb=None,
        use_image_dataset=False,
        use_sim_mask=False,
        temporal_attn_times=1,
    ):
        super().__init__()
        layers = [
            TemporalAttentionBlock(dim, heads, dim_head, rotary_emb,
                                   use_image_dataset, use_sim_mask)
            for _ in range(temporal_attn_times)
        ]
        self.att_layers = nn.ModuleList(layers)

    def forward(self,
                x,
                pos_bias=None,
                focus_present_mask=None,
                video_mask=None):
        """Apply every attention layer in order, threading the same masks."""
        out = x
        for attn in self.att_layers:
            out = attn(out, pos_bias, focus_present_mask, video_mask)
        return out
+
+
class InitTemporalConvBlock(nn.Module):
    """Single temporal (3,1,1) Conv3d stage with a residual connection.

    The conv is zero-initialised, so a freshly constructed block is an
    exact identity mapping on its (b, c, f, h, w) input.
    """

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(InitTemporalConvBlock, self).__init__()
        out_dim = in_dim if out_dim is None else out_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # norm -> activation -> dropout -> temporal conv
        self.conv = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero the conv so the block starts as the identity mapping
        nn.init.zeros_(self.conv[-1].weight)
        nn.init.zeros_(self.conv[-1].bias)

    def forward(self, x):
        residual = self.conv(x)
        if self.use_image_dataset:
            # keep the autograd graph but drop the temporal contribution
            return x + 0 * residual
        return x + residual
+
+
class TemporalConvBlock(nn.Module):
    """Residual temporal conv block: two (3,1,1) Conv3d stages whose second
    conv is zero-initialised, so a fresh block is an exact identity."""

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(TemporalConvBlock, self).__init__()
        out_dim = in_dim if out_dim is None else out_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # stage 1: in_dim -> out_dim
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        # stage 2: out_dim -> in_dim (zero-initialised below)
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero the last conv so the block starts as the identity mapping
        nn.init.zeros_(self.conv2[-1].weight)
        nn.init.zeros_(self.conv2[-1].bias)

    def forward(self, x):
        branch = self.conv2(self.conv1(x))
        if self.use_image_dataset:
            # keep the autograd graph but drop the temporal contribution
            return x + 0 * branch
        return x + branch
+
+
class TemporalConvBlock_v2(nn.Module):
    """Residual temporal conv block with four stacked (3,1,1) Conv3d stages;
    the final conv is zero-initialised so a fresh block is an exact identity.

    NOTE(review): conv2/conv3 output in_dim channels while conv3/conv4
    normalise over out_dim channels, so the block only shape-checks when
    in_dim == out_dim (the default) -- confirm before passing out_dim.
    """

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(TemporalConvBlock_v2, self).__init__()
        out_dim = in_dim if out_dim is None else out_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # four norm -> SiLU -> (dropout) -> conv stages
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero the last conv so the block starts as the identity mapping
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, x):
        branch = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            branch = stage(branch)

        if self.use_image_dataset:
            # keep the autograd graph but drop the temporal contribution
            return x + 0.0 * branch
        return x + branch
+
+
class UNetSD_temporal(nn.Module):
    """Spatio-temporal UNet used by VideoComposer.

    Denoises a video latent ``x`` of shape (b, c, f, h, w) conditioned on a
    timestep ``t`` and an arbitrary subset of the compositions listed in
    ``video_compositions`` (text ``y``, depth maps, sketches, motion vectors,
    canny edges, masked images, a local image, a single sketch and a CLIP
    image embedding). Each spatial condition is embedded to ``concat_dim``
    channels and summed into a shared ``concat`` tensor that is concatenated
    to ``x`` before the first convolution; text/image conditions enter via
    cross-attention context. Conditions are randomly dropped via DropPath
    (``misc_dropout``) to enable classifier-free guidance.
    """

    def __init__(
        self,
        cfg,
        in_dim=7,
        dim=512,
        y_dim=512,
        context_dim=512,
        hist_dim=156,
        concat_dim=8,
        out_dim=6,
        dim_mult=[1, 2, 3, 4],
        num_heads=None,
        head_dim=64,
        num_res_blocks=3,
        attn_scales=[1 / 2, 1 / 4, 1 / 8],
        use_scale_shift_norm=True,
        dropout=0.1,
        temporal_attn_times=1,
        temporal_attention=True,
        use_checkpoint=False,
        use_image_dataset=False,
        use_fps_condition=False,
        use_sim_mask=False,
        misc_dropout=0.5,
        training=True,
        inpainting=True,
        video_compositions=['text', 'mask'],
        p_all_zero=0.1,
        p_all_keep=0.1,
        zero_y=None,
        black_image_feature=None,
    ):
        embed_dim = dim * 4
        num_heads = num_heads if num_heads else dim // 32
        super(UNetSD_temporal, self).__init__()
        self.zero_y = zero_y
        self.black_image_feature = black_image_feature
        self.cfg = cfg
        self.in_dim = in_dim
        self.dim = dim
        self.y_dim = y_dim
        self.context_dim = context_dim
        self.hist_dim = hist_dim
        self.concat_dim = concat_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.dim_mult = dim_mult
        # for temporal attention
        self.num_heads = num_heads
        # for spatial attention
        self.head_dim = head_dim
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.use_scale_shift_norm = use_scale_shift_norm
        self.temporal_attn_times = temporal_attn_times
        self.temporal_attention = temporal_attention
        self.use_checkpoint = use_checkpoint
        self.use_image_dataset = use_image_dataset
        self.use_fps_condition = use_fps_condition
        self.use_sim_mask = use_sim_mask
        # NOTE(review): this shadows nn.Module's built-in ``training`` flag,
        # which .train()/.eval() normally toggle -- confirm this is intended.
        self.training = training
        self.inpainting = inpainting
        self.video_compositions = video_compositions
        # NOTE(review): overwritten below with a DropPath module of the same
        # name; the float is only used to construct it.
        self.misc_dropout = misc_dropout
        self.p_all_zero = p_all_zero
        self.p_all_keep = p_all_keep

        use_linear_in_temporal = False
        transformer_depth = 1
        disabled_sa = False
        # params
        enc_dims = [dim * u for u in [1] + dim_mult]
        dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        shortcut_dims = []
        scale = 1.0
        if hasattr(cfg, 'adapter_transformer_layers'
                   ) and cfg.adapter_transformer_layers:
            adapter_transformer_layers = cfg.adapter_transformer_layers
        else:
            adapter_transformer_layers = 1

        # embeddings
        self.time_embed = nn.Sequential(
            nn.Linear(dim, embed_dim), nn.SiLU(),
            nn.Linear(embed_dim, embed_dim))
        # MLP applied to the CLIP image embedding (1024-d) before it joins
        # the cross-attention context
        self.pre_image_condition = nn.Sequential(
            nn.Linear(1024, 1024), nn.SiLU(), nn.Linear(1024, 1024))

        # depth embedding
        if 'depthmap' in self.video_compositions:
            self.depth_embedding = nn.Sequential(
                nn.Conv2d(1, concat_dim * 4, 3, padding=1), nn.SiLU(),
                nn.AdaptiveAvgPool2d((128, 128)),
                nn.Conv2d(
                    concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
                nn.SiLU(),
                nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))
            self.depth_embedding_after = Transformer_v2(
                heads=2,
                dim=concat_dim,
                dim_head_k=concat_dim,
                dim_head_v=concat_dim,
                dropout_atte=0.05,
                mlp_dim=concat_dim,
                dropout_ffn=0.05,
                depth=adapter_transformer_layers)

        # motion (optical-flow, 2-channel) embedding
        if 'motion' in self.video_compositions:
            self.motion_embedding = nn.Sequential(
                nn.Conv2d(2, concat_dim * 4, 3, padding=1), nn.SiLU(),
                nn.AdaptiveAvgPool2d((128, 128)),
                nn.Conv2d(
                    concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
                nn.SiLU(),
                nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))
            self.motion_embedding_after = Transformer_v2(
                heads=2,
                dim=concat_dim,
                dim_head_k=concat_dim,
                dim_head_v=concat_dim,
                dropout_atte=0.05,
                mlp_dim=concat_dim,
                dropout_ffn=0.05,
                depth=adapter_transformer_layers)

        # canny embedding
        if 'canny' in self.video_compositions:
            self.canny_embedding = nn.Sequential(
                nn.Conv2d(1, concat_dim * 4, 3, padding=1), nn.SiLU(),
                nn.AdaptiveAvgPool2d((128, 128)),
                nn.Conv2d(
                    concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
                nn.SiLU(),
                nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))
            self.canny_embedding_after = Transformer_v2(
                heads=2,
                dim=concat_dim,
                dim_head_k=concat_dim,
                dim_head_v=concat_dim,
                dropout_atte=0.05,
                mlp_dim=concat_dim,
                dropout_ffn=0.05,
                depth=adapter_transformer_layers)

        # masked-image embedding
        if 'mask' in self.video_compositions:
            self.masked_embedding = nn.Sequential(
                nn.Conv2d(4, concat_dim * 4, 3, padding=1), nn.SiLU(),
                nn.AdaptiveAvgPool2d((128, 128)),
                nn.Conv2d(
                    concat_dim * 4, concat_dim
                    * 4, 3, stride=2, padding=1), nn.SiLU(),
                nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2,
                          padding=1)) if inpainting else None
            self.mask_embedding_after = Transformer_v2(
                heads=2,
                dim=concat_dim,
                dim_head_k=concat_dim,
                dim_head_v=concat_dim,
                dropout_atte=0.05,
                mlp_dim=concat_dim,
                dropout_ffn=0.05,
                depth=adapter_transformer_layers)

        # sketch embedding
        if 'sketch' in self.video_compositions:
            self.sketch_embedding = nn.Sequential(
                nn.Conv2d(1, concat_dim * 4, 3, padding=1), nn.SiLU(),
                nn.AdaptiveAvgPool2d((128, 128)),
                nn.Conv2d(
                    concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
                nn.SiLU(),
                nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))
            self.sketch_embedding_after = Transformer_v2(
                heads=2,
                dim=concat_dim,
                dim_head_k=concat_dim,
                dim_head_v=concat_dim,
                dropout_atte=0.05,
                mlp_dim=concat_dim,
                dropout_ffn=0.05,
                depth=adapter_transformer_layers)

        # single-sketch embedding (same shape as sketch, one frame expanded)
        if 'single_sketch' in self.video_compositions:
            self.single_sketch_embedding = nn.Sequential(
                nn.Conv2d(1, concat_dim * 4, 3, padding=1), nn.SiLU(),
                nn.AdaptiveAvgPool2d((128, 128)),
                nn.Conv2d(
                    concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
                nn.SiLU(),
                nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))
            self.single_sketch_embedding_after = Transformer_v2(
                heads=2,
                dim=concat_dim,
                dim_head_k=concat_dim,
                dim_head_v=concat_dim,
                dropout_atte=0.05,
                mlp_dim=concat_dim,
                dropout_ffn=0.05,
                depth=adapter_transformer_layers)

        # local-image (RGB) embedding
        if 'local_image' in self.video_compositions:
            self.local_image_embedding = nn.Sequential(
                nn.Conv2d(3, concat_dim * 4, 3, padding=1), nn.SiLU(),
                nn.AdaptiveAvgPool2d((128, 128)),
                nn.Conv2d(
                    concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),
                nn.SiLU(),
                nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))
            self.local_image_embedding_after = Transformer_v2(
                heads=2,
                dim=concat_dim,
                dim_head_k=concat_dim,
                dim_head_v=concat_dim,
                dropout_atte=0.05,
                mlp_dim=concat_dim,
                dropout_ffn=0.05,
                depth=adapter_transformer_layers)

        # Condition Dropout (replaces the float stored above with the module
        # used for per-condition classifier-free-guidance dropout)
        self.misc_dropout = DropPath(misc_dropout)

        if temporal_attention and not USE_TEMPORAL_TRANSFORMER:
            self.rotary_emb = RotaryEmbedding(min(32, head_dim))
            self.time_rel_pos_bias = RelativePositionBias(
                heads=num_heads, max_distance=32)

        if self.use_fps_condition:
            self.fps_embedding = nn.Sequential(
                nn.Linear(dim, embed_dim), nn.SiLU(),
                nn.Linear(embed_dim, embed_dim))
            # zero-init so fps conditioning starts as a no-op
            nn.init.zeros_(self.fps_embedding[-1].weight)
            nn.init.zeros_(self.fps_embedding[-1].bias)

        # encoder
        self.input_blocks = nn.ModuleList()
        if cfg.resume:
            self.pre_image = nn.Sequential()
            init_block = nn.ModuleList(
                [nn.Conv2d(self.in_dim + concat_dim, dim, 3, padding=1)])
        else:
            # 1x1 conv fuses latent + condition channels back to in_dim
            self.pre_image = nn.Sequential(
                nn.Conv2d(self.in_dim + concat_dim, self.in_dim, 1, padding=0))
            init_block = nn.ModuleList(
                [nn.Conv2d(self.in_dim, dim, 3, padding=1)])

        # need an initial temporal attention?
        if temporal_attention:
            if USE_TEMPORAL_TRANSFORMER:
                init_block.append(
                    TemporalTransformer(
                        dim,
                        num_heads,
                        head_dim,
                        depth=transformer_depth,
                        context_dim=context_dim,
                        disable_self_attn=disabled_sa,
                        use_linear=use_linear_in_temporal,
                        multiply_zero=use_image_dataset))
            else:
                init_block.append(
                    TemporalAttentionMultiBlock(
                        dim,
                        num_heads,
                        head_dim,
                        rotary_emb=self.rotary_emb,
                        temporal_attn_times=temporal_attn_times,
                        use_image_dataset=use_image_dataset))

        self.input_blocks.append(init_block)
        shortcut_dims.append(dim)
        for i, (in_dim,
                out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
            for j in range(num_res_blocks):
                # residual (+attention) blocks
                block = nn.ModuleList([
                    ResBlock(
                        in_dim,
                        embed_dim,
                        dropout,
                        out_channels=out_dim,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    # spatial cross-attention at the selected resolutions
                    block.append(
                        SpatialTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=1,
                            context_dim=self.context_dim,
                            disable_self_attn=False,
                            use_linear=True))
                if self.temporal_attention:
                    if USE_TEMPORAL_TRANSFORMER:
                        block.append(
                            TemporalTransformer(
                                out_dim,
                                out_dim // head_dim,
                                head_dim,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_temporal,
                                multiply_zero=use_image_dataset))
                    else:
                        block.append(
                            TemporalAttentionMultiBlock(
                                out_dim,
                                num_heads,
                                head_dim,
                                rotary_emb=self.rotary_emb,
                                use_image_dataset=use_image_dataset,
                                use_sim_mask=use_sim_mask,
                                temporal_attn_times=temporal_attn_times))
                in_dim = out_dim
                self.input_blocks.append(block)
                shortcut_dims.append(out_dim)

                # downsample
                if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
                    downsample = Downsample(
                        out_dim, True, dims=2, out_channels=out_dim)
                    shortcut_dims.append(out_dim)
                    scale /= 2.0
                    self.input_blocks.append(downsample)

        # middle
        self.middle_block = nn.ModuleList([
            ResBlock(
                out_dim,
                embed_dim,
                dropout,
                use_scale_shift_norm=False,
                use_image_dataset=use_image_dataset,
            ),
            SpatialTransformer(
                out_dim,
                out_dim // head_dim,
                head_dim,
                depth=1,
                context_dim=self.context_dim,
                disable_self_attn=False,
                use_linear=True)
        ])

        if self.temporal_attention:
            if USE_TEMPORAL_TRANSFORMER:
                self.middle_block.append(
                    TemporalTransformer(
                        out_dim,
                        out_dim // head_dim,
                        head_dim,
                        depth=transformer_depth,
                        context_dim=context_dim,
                        disable_self_attn=disabled_sa,
                        use_linear=use_linear_in_temporal,
                        multiply_zero=use_image_dataset,
                    ))
            else:
                self.middle_block.append(
                    TemporalAttentionMultiBlock(
                        out_dim,
                        num_heads,
                        head_dim,
                        rotary_emb=self.rotary_emb,
                        use_image_dataset=use_image_dataset,
                        use_sim_mask=use_sim_mask,
                        temporal_attn_times=temporal_attn_times))

        self.middle_block.append(
            ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False))

        # decoder
        self.output_blocks = nn.ModuleList()
        for i, (in_dim,
                out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
            for j in range(num_res_blocks + 1):
                block = nn.ModuleList([
                    ResBlock(
                        in_dim + shortcut_dims.pop(),
                        embed_dim,
                        dropout,
                        out_dim,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    # NOTE(review): decoder uses context_dim=1024 while the
                    # encoder uses self.context_dim -- confirm intentional.
                    block.append(
                        SpatialTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=1,
                            context_dim=1024,
                            disable_self_attn=False,
                            use_linear=True))
                if self.temporal_attention:
                    if USE_TEMPORAL_TRANSFORMER:
                        block.append(
                            TemporalTransformer(
                                out_dim,
                                out_dim // head_dim,
                                head_dim,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_temporal,
                                multiply_zero=use_image_dataset))
                    else:
                        block.append(
                            TemporalAttentionMultiBlock(
                                out_dim,
                                num_heads,
                                head_dim,
                                rotary_emb=self.rotary_emb,
                                use_image_dataset=use_image_dataset,
                                use_sim_mask=use_sim_mask,
                                temporal_attn_times=temporal_attn_times))
                in_dim = out_dim

                # upsample
                if i != len(dim_mult) - 1 and j == num_res_blocks:
                    upsample = Upsample(
                        out_dim, True, dims=2.0, out_channels=out_dim)
                    scale *= 2.0
                    block.append(upsample)
                self.output_blocks.append(block)

        # head
        self.out = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(),
            nn.Conv2d(out_dim, self.out_dim, 3, padding=1))

        # zero out the last layer params
        nn.init.zeros_(self.out[-1].weight)

    def forward(
            self,
            x,
            t,
            y=None,
            depth=None,
            image=None,
            motion=None,
            local_image=None,
            single_sketch=None,
            masked=None,
            canny=None,
            sketch=None,
            histogram=None,
            fps=None,
            video_mask=None,
            focus_present_mask=None,
            prob_focus_present=0.,
            mask_last_frame_num=0  # mask last frame num
    ):
        """Denoise latent ``x`` (b, c, f, h, w) at timestep ``t``.

        ``y`` is the text context; ``image`` a CLIP image embedding; the
        remaining per-frame conditions (depth/motion/canny/sketch/
        single_sketch/masked/local_image) are (b, c, f, h, w) tensors that
        are embedded and summed into a shared ``concat`` tensor. Returns a
        tensor of shape (b, out_dim, f, h, w).
        """

        assert self.inpainting or masked is None, 'inpainting is not supported'

        batch, c, f, h, w = x.shape
        device = x.device
        self.batch = batch

        # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
        if mask_last_frame_num > 0:
            focus_present_mask = None
            # NOTE(review): this indexes the leading dimension of video_mask;
            # presumably the intent is to mask the last frames -- confirm.
            video_mask[-mask_last_frame_num:] = False
        else:
            focus_present_mask = default(
                focus_present_mask, lambda: prob_mask_like(
                    (batch, ), prob_focus_present, device=device))

        if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
            time_rel_pos_bias = self.time_rel_pos_bias(
                x.shape[2], device=x.device)
        else:
            time_rel_pos_bias = None

        # all-zero and all-keep masks
        zero = torch.zeros(batch, dtype=torch.bool).to(x.device)
        keep = torch.zeros(batch, dtype=torch.bool).to(x.device)
        if self.training:
            # sample disjoint subsets that drop all / keep all conditions
            nzero = (torch.rand(batch) < self.p_all_zero).sum()
            nkeep = (torch.rand(batch) < self.p_all_keep).sum()
            index = torch.randperm(batch)
            zero[index[0:nzero]] = True
            keep[index[nzero:nzero + nkeep]] = True
        assert not (zero & keep).any()
        misc_dropout = partial(self.misc_dropout, zero=zero, keep=keep)

        concat = x.new_zeros(batch, self.concat_dim, f, h, w)
        # NOTE(review): each condition branch below reassigns ``h`` to the
        # embedded feature height (clobbering the input spatial ``h``); the
        # original ``h`` is not used again afterwards.
        if depth is not None:
            # DropPath mask
            depth = rearrange(depth, 'b c f h w -> (b f) c h w')
            depth = self.depth_embedding(depth)
            h = depth.shape[2]
            depth = self.depth_embedding_after(
                rearrange(depth, '(b f) c h w -> (b h w) f c', b=batch))

            # back to video layout before summing into the shared tensor
            depth = rearrange(depth, '(b h w) f c -> b c f h w', b=batch, h=h)
            concat = concat + misc_dropout(depth)

        # local_image_embedding
        if local_image is not None:
            local_image = rearrange(local_image, 'b c f h w -> (b f) c h w')
            local_image = self.local_image_embedding(local_image)

            h = local_image.shape[2]
            local_image = self.local_image_embedding_after(
                rearrange(local_image, '(b f) c h w -> (b h w) f c', b=batch))
            local_image = rearrange(
                local_image, '(b h w) f c -> b c f h w', b=batch, h=h)

            concat = concat + misc_dropout(local_image)

        if motion is not None:
            motion = rearrange(motion, 'b c f h w -> (b f) c h w')
            motion = self.motion_embedding(motion)

            h = motion.shape[2]
            motion = self.motion_embedding_after(
                rearrange(motion, '(b f) c h w -> (b h w) f c', b=batch))
            motion = rearrange(
                motion, '(b h w) f c -> b c f h w', b=batch, h=h)

            # optionally zero motion independently of the shared DropPath
            if hasattr(self.cfg, 'p_zero_motion_alone'
                       ) and self.cfg.p_zero_motion_alone and self.training:
                motion_d = torch.rand(batch) < self.cfg.p_zero_motion
                motion_d = motion_d[:, None, None, None, None]
                motion = motion.masked_fill(motion_d.cuda(), 0)
                concat = concat + motion
            else:
                concat = concat + misc_dropout(motion)

        if canny is not None:
            # DropPath mask
            canny = rearrange(canny, 'b c f h w -> (b f) c h w')
            canny = self.canny_embedding(canny)

            h = canny.shape[2]
            canny = self.canny_embedding_after(
                rearrange(canny, '(b f) c h w -> (b h w) f c', b=batch))
            canny = rearrange(canny, '(b h w) f c -> b c f h w', b=batch, h=h)

            concat = concat + misc_dropout(canny)

        if sketch is not None:
            # DropPath mask
            sketch = rearrange(sketch, 'b c f h w -> (b f) c h w')
            sketch = self.sketch_embedding(sketch)

            h = sketch.shape[2]
            sketch = self.sketch_embedding_after(
                rearrange(sketch, '(b f) c h w -> (b h w) f c', b=batch))
            sketch = rearrange(
                sketch, '(b h w) f c -> b c f h w', b=batch, h=h)

            concat = concat + misc_dropout(sketch)

        if single_sketch is not None:
            # DropPath mask
            single_sketch = rearrange(single_sketch,
                                      'b c f h w -> (b f) c h w')
            single_sketch = self.single_sketch_embedding(single_sketch)

            h = single_sketch.shape[2]
            single_sketch = self.single_sketch_embedding_after(
                rearrange(
                    single_sketch, '(b f) c h w -> (b h w) f c', b=batch))
            single_sketch = rearrange(
                single_sketch, '(b h w) f c -> b c f h w', b=batch, h=h)

            concat = concat + misc_dropout(single_sketch)

        if masked is not None:
            # DropPath mask
            masked = rearrange(masked, 'b c f h w -> (b f) c h w')
            masked = self.masked_embedding(masked)

            h = masked.shape[2]
            masked = self.mask_embedding_after(
                rearrange(masked, '(b f) c h w -> (b h w) f c', b=batch))
            masked = rearrange(
                masked, '(b h w) f c -> b c f h w', b=batch, h=h)

            concat = concat + misc_dropout(masked)

        # fuse latent + condition channels through the 1x1 pre_image conv
        x = torch.cat([x, concat], dim=1)
        x = rearrange(x, 'b c f h w -> (b f) c h w')
        x = self.pre_image(x)
        x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)

        # embeddings
        if self.use_fps_condition and fps is not None:
            e = self.time_embed(sinusoidal_embedding(
                t, self.dim)) + self.fps_embedding(
                    sinusoidal_embedding(fps, self.dim))
        else:
            e = self.time_embed(sinusoidal_embedding(t, self.dim))

        # cross-attention context: text tokens (or the learned null text)
        # plus the optional projected CLIP image embedding
        context = x.new_zeros(batch, 0, self.context_dim)
        if y is not None:
            y_context = misc_dropout(y)
            context = torch.cat([context, y_context], dim=1)
        else:
            y_context = self.zero_y.repeat(batch, 1, 1)
            context = torch.cat([context, y_context], dim=1)

        if image is not None:
            image_context = misc_dropout(self.pre_image_condition(image))
            context = torch.cat([context, image_context], dim=1)

        # repeat f times for spatial e and context
        e = e.repeat_interleave(repeats=f, dim=0)
        context = context.repeat_interleave(repeats=f, dim=0)

        # always in shape (b f) c h w, except for temporal layer
        x = rearrange(x, 'b c f h w -> (b f) c h w')
        # encoder
        xs = []
        for block in self.input_blocks:
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask)
            xs.append(x)

        # middle
        for block in self.middle_block:
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask)

        # decoder
        for block in self.output_blocks:
            # UNet skip connection from the matching encoder level
            x = torch.cat([x, xs.pop()], dim=1)
            x = self._forward_single(
                block,
                x,
                e,
                context,
                time_rel_pos_bias,
                focus_present_mask,
                video_mask,
                reference=xs[-1] if len(xs) > 0 else None)

        # head
        x = self.out(x)

        # reshape back to (b c f h w)
        x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
        return x

    def _forward_single(self,
                        module,
                        x,
                        e,
                        context,
                        time_rel_pos_bias,
                        focus_present_mask,
                        video_mask,
                        reference=None):
        """Dispatch ``x`` through ``module`` with the arguments its type
        expects; temporal modules get the (b, c, f, h, w) layout and are
        rearranged back to (b f) c h w afterwards. ``checkpoint_wrapper``
        presumably enables gradient checkpointing -- applied when
        ``self.use_checkpoint`` is set."""
        if isinstance(module, ResidualBlock):
            module = checkpoint_wrapper(
                module) if self.use_checkpoint else module
            x = x.contiguous()
            x = module(x, e, reference)
        elif isinstance(module, ResBlock):
            module = checkpoint_wrapper(
                module) if self.use_checkpoint else module
            x = x.contiguous()
            x = module(x, e, self.batch)
        elif isinstance(module, SpatialTransformer):
            module = checkpoint_wrapper(
                module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, TemporalTransformer):
            module = checkpoint_wrapper(
                module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, context)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, CrossAttention):
            module = checkpoint_wrapper(
                module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, BasicTransformerBlock):
            module = checkpoint_wrapper(
                module) if self.use_checkpoint else module
            x = module(x, context)
        elif isinstance(module, FeedForward):
            x = module(x, context)
        elif isinstance(module, Upsample):
            x = module(x)
        elif isinstance(module, Downsample):
            x = module(x)
        elif isinstance(module, Resample):
            x = module(x, reference)
        elif isinstance(module, TemporalAttentionBlock):
            module = checkpoint_wrapper(
                module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, TemporalAttentionMultiBlock):
            module = checkpoint_wrapper(
                module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, InitTemporalConvBlock):
            module = checkpoint_wrapper(
                module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, TemporalConvBlock):
            module = checkpoint_wrapper(
                module) if self.use_checkpoint else module
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, nn.ModuleList):
            # recurse into composite blocks
            for block in module:
                x = self._forward_single(block, x, e, context,
                                         time_rel_pos_bias, focus_present_mask,
                                         video_mask, reference)
        else:
            x = module(x)
        return x
+
+
class PreNormattention(nn.Module):
    """Pre-norm residual wrapper: returns ``fn(LayerNorm(x)) + x``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return x + self.fn(normed, **kwargs)
+
+
class PreNormattention_qkv(nn.Module):
    """Pre-norm residual wrapper for separate q/k/v streams.

    Applies the same LayerNorm to each of q, k, v, runs ``fn`` on the
    normalized triple, and adds the (un-normalized) query as residual.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, q, k, v, **kwargs):
        out = self.fn(self.norm(q), self.norm(k), self.norm(v), **kwargs)
        return out + q
+
+
+class Attention(nn.Module):
+
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head**-0.5
+
+ self.attend = nn.Softmax(dim=-1)
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim),
+ nn.Dropout(dropout)) if project_out else nn.Identity()
+
+ def forward(self, x):
+ _, _, _, h = *x.shape, self.heads
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
+
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+
+ attn = self.attend(dots)
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+
+class Attention_qkv(nn.Module):
+
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head**-0.5
+
+ self.attend = nn.Softmax(dim=-1)
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim),
+ nn.Dropout(dropout)) if project_out else nn.Identity()
+
+ def forward(self, q, k, v):
+ _, _, _, h = *q.shape, self.heads
+ bk = k.shape[0]
+ q = self.to_q(q)
+ k = self.to_k(k)
+ v = self.to_v(v)
+ q = rearrange(q, 'b n (h d) -> b h n d', h=h)
+ k = rearrange(k, 'b n (h d) -> b h n d', b=bk, h=h)
+ v = rearrange(v, 'b n (h d) -> b h n d', b=bk, h=h)
+
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+
+ attn = self.attend(dots)
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+
class PostNormattention(nn.Module):
    """Post-norm residual wrapper: returns ``LayerNorm(fn(x) + x)``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        residual = self.fn(x, **kwargs) + x
        return self.norm(residual)
+
+
class Transformer_v2(nn.Module):
    """Stack of ``depth`` pre-norm attention + feed-forward blocks.

    Each layer applies ``PreNormattention`` (which already adds its own
    residual) followed by a residual feed-forward: ``x = ff(x) + x``.

    Args:
        heads: number of attention heads.
        dim: token embedding dimension.
        dim_head_k: per-head dimension of the attention projections.
        dim_head_v: accepted for interface compatibility but unused.
        dropout_atte: dropout inside the attention blocks.
        mlp_dim: hidden dimension of the feed-forward blocks.
        dropout_ffn: dropout inside the feed-forward blocks.
        depth: number of stacked layers.
    """

    def __init__(self,
                 heads=8,
                 dim=2048,
                 dim_head_k=256,
                 dim_head_v=256,
                 dropout_atte=0.05,
                 mlp_dim=2048,
                 dropout_ffn=0.05,
                 depth=1):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.depth = depth
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PreNormattention(
                        dim,
                        Attention(
                            dim,
                            heads=heads,
                            dim_head=dim_head_k,
                            dropout=dropout_atte)),
                    FeedForward(dim, mlp_dim, dropout=dropout_ffn),
                ]))

    def forward(self, x):
        # Single uniform pass over all layers. (The previous version split
        # this into layers[:1] followed by layers[1:] guarded by
        # ``depth > 1`` — behaviorally identical but redundant.)
        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x) + x
        return x
+
+
class DropPath(nn.Module):
    r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.

    Per-sample stochastic depth: with probability ``p`` a sample's output is
    multiplied by 0 (no 1/(1-p) rescaling is applied). ``zero``/``keep`` are
    optional index/mask arguments that force particular samples to always be
    zeroed or always be kept.
    """

    def __init__(self, p):
        super(DropPath, self).__init__()
        # Drop probability per sample.
        self.p = p

    def forward(self, *args, zero=None, keep=None):
        # Identity at eval time.
        if not self.training:
            return args[0] if len(args) == 1 else args

        # params
        # All tensors in *args are assumed to share batch dim 0 — the drop
        # decision is made once from args[0] and applied to every tensor.
        x = args[0]
        b = x.size(0)
        # Number of samples to drop this step (binomial draw with prob p).
        n = (torch.rand(b) < self.p).sum()

        # non-zero and non-keep mask: samples eligible for random dropping.
        mask = x.new_ones(b, dtype=torch.bool)
        if keep is not None:
            mask[keep] = False
        if zero is not None:
            mask[zero] = False

        # drop-path index: pick n random eligible samples...
        index = torch.where(mask)[0]
        index = index[torch.randperm(len(index))[:n]]
        # ...and additionally force-drop the `zero` samples.
        if zero is not None:
            index = torch.cat([index, torch.where(zero)[0]], dim=0)

        # drop-path multiplier: 1 everywhere except the dropped samples.
        multiplier = x.new_ones(b)
        multiplier[index] = 0.0
        output = tuple(u * self.broadcast(multiplier, u) for u in args)
        return output[0] if len(args) == 1 else output

    def broadcast(self, src, dst):
        # Reshape (b,) -> (b, 1, ..., 1) so src broadcasts over dst's
        # trailing dimensions.
        assert src.size(0) == dst.size(0)
        shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
        return src.view(shape)
+
+
if __name__ == '__main__':
    # Manual smoke test: build the temporal UNet from the experiment config
    # (UNetSD_temporal is defined earlier in this module; cfg comes from the
    # sibling config module) and print its parameter count.
    from config import cfg
    # [model] unet
    model = UNetSD_temporal(
        in_dim=cfg.unet_in_dim,
        dim=cfg.unet_dim,
        y_dim=cfg.unet_y_dim,
        context_dim=cfg.unet_context_dim,
        out_dim=cfg.unet_out_dim,
        dim_mult=cfg.unet_dim_mult,
        num_heads=cfg.unet_num_heads,
        head_dim=cfg.unet_head_dim,
        num_res_blocks=cfg.unet_res_blocks,
        attn_scales=cfg.unet_attn_scales,
        dropout=cfg.unet_dropout,
        temporal_attn_times=0,  # temporal attention disabled for this check
        use_checkpoint=cfg.use_checkpoint,
        use_image_dataset=True,
        use_fps_condition=cfg.use_fps_condition)

    # NOTE(review): divides by 1024**2 to report "M" parameters (MiB-style
    # mega, not 10**6).
    print(
        int(sum(p.numel() for k, p in model.named_parameters()) / (1024**2)),
        'M parameters')
diff --git a/modelscope/models/multi_modal/videocomposer/utils/__init__.py b/modelscope/models/multi_modal/videocomposer/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/models/multi_modal/videocomposer/utils/config.py b/modelscope/models/multi_modal/videocomposer/utils/config.py
new file mode 100644
index 00000000..18424257
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/utils/config.py
@@ -0,0 +1,273 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
import argparse
import copy
import os
import random

import json
import numpy as np
import torch
import yaml

from modelscope.utils.logger import get_logger

logger = get_logger()
+
+
def setup_seed(seed):
    """Seed all RNGs (python ``random``, numpy, torch CPU and CUDA).

    Fix: this module never imported ``random``/``np``/``torch``, so every
    call raised ``NameError`` — the imports are now declared at file top.

    Args:
        seed (int): seed value applied to every RNG.
    """
    print('Seed: ', seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
+
+
class Config(object):
    """Hierarchical experiment configuration loaded from YAML.

    Loads a YAML file (plus optional ``_BASE`` / ``_BASE_RUN`` /
    ``_BASE_MODEL`` parent files), merges command-line ``opts`` overrides,
    and exposes every entry as an attribute — nested dicts become nested
    ``Config`` objects via ``_update_dict``.
    """

    def __init__(self, load=True, cfg_dict=None, cfg_level=None):
        """Build a config.

        Args:
            load: when True, parse CLI args and load YAML from disk; when
                False, wrap the provided ``cfg_dict`` (used for nesting).
            cfg_dict: pre-built dict to wrap when ``load`` is False.
            cfg_level: dotted name suffix for nested sub-configs.
        """
        self._level = 'cfg' + ('.'
                               + cfg_level if cfg_level is not None else '')
        if load:
            self.args = self._parse_args()
            logger.info('Loading config from {}.'.format(self.args.cfg_file))
            self.need_initialization = True
            cfg_base = self._initialize_cfg()
            cfg_dict = self._load_yaml(self.args)
            cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict)
            cfg_dict = self._update_from_args(cfg_dict)
            self.cfg_dict = cfg_dict
        self._update_dict(cfg_dict)

    def _parse_args(self):
        """Parse the videocomposer inference CLI arguments."""
        parser = argparse.ArgumentParser(
            description=
            'Argparser for configuring [code base name to think of] codebase')
        parser.add_argument(
            '--cfg',
            dest='cfg_file',
            help='Path to the configuration file',
            default=
            './modelscope/models/multi_modal/videocomposer/configs/exp06_text_depths_vs_style.yaml'
        )
        parser.add_argument(
            '--init_method',
            help='Initialization method, includes TCP or shared file-system',
            default='tcp://localhost:9999',
            type=str,
        )
        parser.add_argument(
            '--seed',
            type=int,
            default=8888,
            help='Need to explore for different videos')
        parser.add_argument(
            '--debug',
            action='store_true',
            default=False,
            help='Into debug information')
        # NOTE(review): the trailing comma after this call makes the
        # statement a 1-tuple expression — harmless but unintentional.
        parser.add_argument(
            '--input_video',
            default='demo_video/video_8800.mp4',
            help='input video for full task, or motion vector of input videos',
            type=str,
        ),
        parser.add_argument(
            '--image_path', default='', help='Single Image Input', type=str)
        parser.add_argument(
            '--sketch_path', default='', help='Single Sketch Input', type=str)
        parser.add_argument(
            '--style_image', help='Single Sketch Input', type=str)
        # NOTE(review): same stray trailing comma as above.
        parser.add_argument(
            '--input_text_desc',
            default=
            'A colorful and beautiful fish swimming in a small glass bowl with \
            multicolored piece of stone, Macro Video',
            type=str,
        ),
        # Free-form "KEY VALUE ..." override pairs consumed by
        # _merge_cfg_from_command.
        parser.add_argument(
            'opts',
            help='other configurations',
            default=None,
            nargs=argparse.REMAINDER)
        return parser.parse_args()

    def _path_join(self, path_list):
        """Join path components with '/' (no trailing slash)."""
        path = ''
        for p in path_list:
            path += p + '/'
        return path[:-1]

    def _update_from_args(self, cfg_dict):
        """Copy every parsed CLI attribute into the config dict."""
        args = self.args
        for var in vars(args):
            cfg_dict[var] = getattr(args, var)
        return cfg_dict

    def _initialize_cfg(self):
        """Load the shared base.yaml once per Config instance.

        NOTE(review): both branches of the if/else are identical and rely
        on the process CWD being the repository root — the existence check
        is effectively dead code.
        """
        if self.need_initialization:
            self.need_initialization = False
            if os.path.exists(
                    './modelscope/models/multi_modal/videocomposer/configs/base.yaml'
            ):
                with open(
                        './modelscope/models/multi_modal/videocomposer/configs/base.yaml',
                        'r') as f:
                    cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
            else:
                with open(
                        './modelscope/models/multi_modal/videocomposer/configs/base.yaml',
                        'r') as f:
                    cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
        # NOTE(review): returns None on any call after the first
        # (need_initialization already False) — presumably only called once.
        return cfg

    def _load_yaml(self, args, file_name=''):
        """Load a YAML file, recursively resolving _BASE* parent configs.

        Args:
            args: parsed CLI args (``args.cfg_file`` is the entry config).
            file_name: when non-empty, load this base file instead of
                ``args.cfg_file``.
        """
        assert args.cfg_file is not None
        if not file_name == '':  # reading from base file
            with open(file_name, 'r') as f:
                cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
        else:
            # Normalize cfg_file when invoked from inside the package dir.
            if os.getcwd().split('/')[-1] == args.cfg_file.split('/')[0]:
                args.cfg_file = args.cfg_file.replace(
                    os.getcwd().split('/')[-1], './')
            try:
                with open(args.cfg_file, 'r') as f:
                    cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
                    file_name = args.cfg_file
            except Exception as e:
                # Fall back to a path relative to this source file.
                args.cfg_file = os.path.realpath(__file__).split(
                    '/')[-3] + '/' + args.cfg_file
                with open(args.cfg_file, 'r') as f:
                    cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
                    file_name = args.cfg_file
                print(e)

        # No parent configs referenced: done.
        if '_BASE_RUN' not in cfg.keys() and '_BASE_MODEL' not in cfg.keys(
        ) and '_BASE' not in cfg.keys():
            return cfg

        if '_BASE' in cfg.keys():
            # Resolve the parent path: '..'-relative to this file, or
            # './'-relative to the entry cfg_file's directory.
            if cfg['_BASE'][1] == '.':
                prev_count = cfg['_BASE'].count('..')
                cfg_base_file = self._path_join(
                    file_name.split('/')[:(-1 - cfg['_BASE'].count('..'))]
                    + cfg['_BASE'].split('/')[prev_count:])
            else:
                cfg_base_file = cfg['_BASE'].replace(
                    './',
                    args.cfg_file.replace(args.cfg_file.split('/')[-1], ''))
            cfg_base = self._load_yaml(args, cfg_base_file)
            cfg = self._merge_cfg_from_base(cfg_base, cfg)
        else:
            if '_BASE_RUN' in cfg.keys():
                if cfg['_BASE_RUN'][1] == '.':
                    prev_count = cfg['_BASE_RUN'].count('..')
                    cfg_base_file = self._path_join(
                        file_name.split('/')[:(-1 - prev_count)]
                        + cfg['_BASE_RUN'].split('/')[prev_count:])
                else:
                    cfg_base_file = cfg['_BASE_RUN'].replace(
                        './',
                        args.cfg_file.replace(
                            args.cfg_file.split('/')[-1], ''))
                cfg_base = self._load_yaml(args, cfg_base_file)
                # preserve_base keeps the _BASE_MODEL key alive for the
                # second merge below.
                cfg = self._merge_cfg_from_base(
                    cfg_base, cfg, preserve_base=True)
            if '_BASE_MODEL' in cfg.keys():
                if cfg['_BASE_MODEL'][1] == '.':
                    prev_count = cfg['_BASE_MODEL'].count('..')
                    cfg_base_file = self._path_join(
                        file_name.split('/')[:(
                            -1 - cfg['_BASE_MODEL'].count('..'))]
                        + cfg['_BASE_MODEL'].split('/')[prev_count:])
                else:
                    cfg_base_file = cfg['_BASE_MODEL'].replace(
                        './',
                        args.cfg_file.replace(
                            args.cfg_file.split('/')[-1], ''))
                cfg_base = self._load_yaml(args, cfg_base_file)
                cfg = self._merge_cfg_from_base(cfg_base, cfg)
        cfg = self._merge_cfg_from_command(args, cfg)
        return cfg

    def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False):
        """Recursively overlay ``cfg_new`` onto ``cfg_base`` (in place).

        Keys containing 'BASE' are dropped unless ``preserve_base``.
        """
        for k, v in cfg_new.items():
            if k in cfg_base.keys():
                if isinstance(v, dict):
                    self._merge_cfg_from_base(cfg_base[k], v)
                else:
                    cfg_base[k] = v
            else:
                if 'BASE' not in k or preserve_base:
                    cfg_base[k] = v
        return cfg_base

    def _merge_cfg_from_command(self, args, cfg):
        """Apply 'dotted.key value' override pairs from ``args.opts``.

        Supports key depth up to 3 (4 dotted components); every component
        must already exist in the config.
        """
        assert len(
            args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
                args.opts, len(args.opts))
        keys = args.opts[0::2]
        vals = args.opts[1::2]

        # maximum supported depth 3
        for idx, key in enumerate(keys):
            key_split = key.split('.')
            assert len(
                key_split
            ) <= 4, 'Key depth error. \nMaximum depth: 3\n Get depth: {}'.format(
                len(key_split))
            assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format(
                key_split[0])
            if len(key_split) == 2:
                assert key_split[1] in cfg[
                    key_split[0]].keys(), 'Non-existant key: {}.'.format(key)
            elif len(key_split) == 3:
                assert key_split[1] in cfg[
                    key_split[0]].keys(), 'Non-existant key: {}.'.format(key)
                assert key_split[2] in cfg[key_split[0]][
                    key_split[1]].keys(), 'Non-existant key: {}.'.format(key)
            elif len(key_split) == 4:
                assert key_split[1] in cfg[
                    key_split[0]].keys(), 'Non-existant key: {}.'.format(key)
                assert key_split[2] in cfg[key_split[0]][
                    key_split[1]].keys(), 'Non-existant key: {}.'.format(key)
                assert key_split[3] in cfg[key_split[0]][key_split[1]][
                    key_split[2]].keys(), 'Non-existant key: {}.'.format(key)
            if len(key_split) == 1:
                cfg[key_split[0]] = vals[idx]
            elif len(key_split) == 2:
                cfg[key_split[0]][key_split[1]] = vals[idx]
            elif len(key_split) == 3:
                cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx]
            elif len(key_split) == 4:
                cfg[key_split[0]][key_split[1]][key_split[2]][
                    key_split[3]] = vals[idx]
        return cfg

    def _update_dict(self, cfg_dict):
        """Expose cfg entries as attributes; dicts become nested Configs."""

        def recur(key, elem):
            if type(elem) is dict:
                return key, Config(load=False, cfg_dict=elem, cfg_level=key)
            else:
                # Convert YAML strings like '1e-4' that were not parsed as
                # floats. NOTE(review): the elem[1:3] check assumes at least
                # 3 characters — single-char strings would raise on float().
                if type(elem) is str and elem[1:3] == 'e-':
                    elem = float(elem)
                return key, elem

        dic = dict(recur(k, v) for k, v in cfg_dict.items())
        self.__dict__.update(dic)

    def get_args(self):
        """Return the parsed argparse namespace."""
        return self.args

    def __repr__(self):
        return '{}\n'.format(self.dump())

    def dump(self):
        """Serialize the raw config dict as pretty-printed JSON."""
        return json.dumps(self.cfg_dict, indent=2)

    def deep_copy(self):
        """Return an independent deep copy of this Config."""
        return copy.deepcopy(self)
+
+
if __name__ == '__main__':
    # debug
    # Manual smoke test: parse CLI args, load the default YAML config and
    # print its DATA section (requires running from the repository root so
    # the relative config paths resolve).
    cfg = Config(load=True)
    print(cfg.DATA)
diff --git a/modelscope/models/multi_modal/videocomposer/utils/distributed.py b/modelscope/models/multi_modal/videocomposer/utils/distributed.py
new file mode 100644
index 00000000..b622e700
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/utils/distributed.py
@@ -0,0 +1,297 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+"""Distributed helpers."""
+
+import functools
+import logging
+import pickle
+
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+
+
def all_gather(tensors):
    """
    All gathers the provided tensors from all processes across machines.
    Args:
        tensors (list): tensors to perform all gather across all processes in
        all machines.
    """
    world_size = dist.get_world_size()

    # Gather each tensor into per-rank placeholders first...
    gathered = []
    for tensor in tensors:
        placeholders = [torch.ones_like(tensor) for _ in range(world_size)]
        dist.all_gather(placeholders, tensor, async_op=False)
        gathered.append(placeholders)

    # ...then concatenate every rank's copy along the batch dimension.
    return [torch.cat(chunks, dim=0) for chunks in gathered]
+
+
def all_reduce(tensors, average=True):
    """
    All reduce the provided tensors from all processes across machines.
    Args:
        tensors (list): tensors to perform all reduce across all processes in
        all machines.
        average (bool): scales the reduced tensor by the number of overall
        processes across all machines.
    """
    for tensor in tensors:
        dist.all_reduce(tensor, async_op=False)
    if average:
        scale = 1.0 / dist.get_world_size()
        for tensor in tensors:
            tensor.mul_(scale)
    return tensors
+
+
def init_process_group(
    local_rank,
    local_world_size,
    shard_id,
    num_shards,
    init_method,
    dist_backend='nccl',
):
    """
    Initializes the default process group.
    Args:
        local_rank (int): the rank on the current local machine.
        local_world_size (int): the world size (number of processes running) on
        the current local machine.
        shard_id (int): the shard index (machine rank) of the current machine.
        num_shards (int): number of shards for distributed training.
        init_method (string): supporting three different methods for
            initializing process groups:
            "file": use shared file system to initialize the groups across
            different processes.
            "tcp": use tcp address to initialize the groups across different
        dist_backend (string): backend to use for distributed training. Options
            includes gloo, mpi and nccl, the details can be found here:
            https://pytorch.org/docs/stable/distributed.html
    """
    # Sets the GPU to use.
    torch.cuda.set_device(local_rank)
    # Initialize the process group.
    # Global rank = this process's offset across all machines (shards).
    proc_rank = local_rank + shard_id * local_world_size
    world_size = local_world_size * num_shards
    dist.init_process_group(
        backend=dist_backend,
        init_method=init_method,
        world_size=world_size,
        rank=proc_rank,
    )
+
+
def is_master_proc(num_gpus=8):
    """
    Determines if the current process is the master process.

    A process is a master when its global rank is a multiple of ``num_gpus``
    (i.e. the first process on its machine); non-distributed runs always
    count as master.
    """
    if not torch.distributed.is_initialized():
        return True
    return dist.get_rank() % num_gpus == 0
+
+
def get_world_size():
    """
    Get the size of the world (1 when torch.distributed is unavailable or
    uninitialized).
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
+
+
def get_rank():
    """
    Get the rank of the current process (0 when torch.distributed is
    unavailable or uninitialized).
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
+
+
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training. A no-op outside multi-process runs.
    """
    if not dist.is_available() or not dist.is_initialized():
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()
+
+
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.
    Returns:
        (group): pytorch dist group.
    """
    # A dedicated gloo group is needed for CPU-side (pickled) collectives
    # when the default backend is nccl.
    if dist.get_backend() == 'nccl':
        return dist.new_group(backend='gloo')
    else:
        return dist.group.WORLD
+
+
def _serialize_to_tensor(data, group):
    """
    Seriialize the tensor to ByteTensor. Note that only `gloo` and `nccl`
    backend is supported.
    Args:
        data (data): any picklable object to be serialized.
        group (group): pytorch dist group.
    Returns:
        tensor (ByteTensor): tensor that serialized.
    """

    backend = dist.get_backend(group)
    assert backend in ['gloo', 'nccl']
    # gloo gathers on CPU; nccl requires CUDA tensors.
    device = torch.device('cpu' if backend == 'gloo' else 'cuda')

    buffer = pickle.dumps(data)
    # Warn when a single rank tries to gather more than 1 GiB.
    if len(buffer) > 1024**3:
        logger = logging.getLogger(__name__)
        logger.warning(
            'Rank {} trying to all-gather {:.2f} GB of data on device {}'.
            format(get_rank(),
                   len(buffer) / (1024**3), device))
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor
+
+
def _pad_to_largest_tensor(tensor, group):
    """
    Padding all the tensors from different GPUs to the largest ones.
    Args:
        tensor (tensor): tensor to pad.
        group (group): pytorch dist group.
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), 'comm.gather/all_gather must be called from ranks within the given group!'
    # Exchange every rank's byte length first so each rank knows max_size.
    local_size = torch.tensor([tensor.numel()],
                              dtype=torch.int64,
                              device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_size != max_size:
        padding = torch.zeros((max_size - local_size, ),
                              dtype=torch.uint8,
                              device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor
+
+
def all_gather_unaligned(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    # Fast path: single process, nothing to exchange.
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    # Pickle the payload, pad all ranks' byte tensors to a common length,
    # gather, then trim each result back to its true size before unpickling.
    tensor = _serialize_to_tensor(data, group)

    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # receiving Tensor from all ranks
    tensor_list = [
        torch.empty((max_size, ), dtype=torch.uint8, device=tensor.device)
        for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
+
+
def init_distributed_training(cfg):
    """
    Initialize variables needed for distributed training.

    Creates one process group per machine; ``_LOCAL_PROCESS_GROUP`` is set
    to the group containing this process's machine (cfg.SHARD_ID).
    Expects cfg.NUM_GPUS (processes per machine) and cfg.SHARD_ID.
    """
    if cfg.NUM_GPUS <= 1:
        return
    num_gpus_per_machine = cfg.NUM_GPUS
    num_machines = dist.get_world_size() // num_gpus_per_machine
    # dist.new_group must be called by ALL ranks for every group, even ones
    # the rank does not belong to — hence the full loop on each process.
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == cfg.SHARD_ID:
            global _LOCAL_PROCESS_GROUP
            _LOCAL_PROCESS_GROUP = pg
+
+
def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine (1 when not distributed).
    """
    if not dist.is_available() or not dist.is_initialized():
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine)
        process group (0 when not distributed).
    """
    if not dist.is_available() or not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
diff --git a/modelscope/models/multi_modal/videocomposer/utils/utils.py b/modelscope/models/multi_modal/videocomposer/utils/utils.py
new file mode 100644
index 00000000..d01b3dd2
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/utils/utils.py
@@ -0,0 +1,955 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import base64
+import binascii
+import copy
+import glob
+import gzip
+import hashlib
+import logging
+import math
+import os
+import os.path as osp
+import pickle
+import random
+import sys
+import time
+import urllib.request
+import zipfile
+from io import BytesIO
+from multiprocessing.pool import ThreadPool as Pool
+
+import imageio
+import json
+import numpy as np
+import oss2 as oss
+import requests
+import skvideo.io
+import torch
+import torch.nn.functional as F
+import torchvision.utils as tvutils
+from einops import rearrange
+from PIL import Image
+
+__all__ = [
+ 'parse_oss_url', 'parse_bucket', 'read', 'read_image', 'read_gzip',
+ 'ceil_divide', 'to_device', 'put_object', 'put_torch_object',
+ 'put_object_from_file', 'get_object', 'get_object_to_file', 'rand_name',
+ 'save_image', 'save_video', 'save_video_vs_conditions',
+ 'save_video_multiple_conditions_with_data',
+ 'save_video_multiple_conditions', 'download_video_to_file',
+ 'save_video_grid_mp4', 'save_caps', 'ema', 'parallel', 'exists',
+ 'download', 'unzip', 'load_state_dict', 'inverse_indices',
+ 'detect_duplicates', 'read_tfs', 'md5', 'rope', 'format_state',
+ 'breakup_grid', 'huggingface_tokenizer', 'huggingface_model'
+]
+
+TFS_CLIENT = None
+
+
def DOWNLOAD_TO_CACHE(oss_key,
                      file_or_dirname=None,
                      cache_dir=osp.join(
                          '/'.join(osp.abspath(__file__).split('/')[:-2]),
                          'model_weights')):
    r"""Resolve the local cache path for an OSS key.

    NOTE(review): despite the name and the original docstring, this function
    no longer downloads anything or synchronizes processes — it only joins
    ``cache_dir`` with the key's basename (the download logic was removed).
    """
    # source and target paths
    base_path = osp.join(cache_dir, file_or_dirname or osp.basename(oss_key))

    return base_path
+
+
def find_free_port():
    """Ask the OS for a currently-unused TCP port and return it as a string.

    https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
    """
    import socket
    from contextlib import closing
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.bind(('', 0))
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        port = sock.getsockname()[1]
    return str(port)
+
+
def setup_seed(seed):
    """Seed python, numpy and torch RNGs and force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
+
+
def parse_oss_url(path):
    """Split an ``oss://bucket/key`` URL into (authenticated bucket, key).

    Credentials and endpoint default to the ``OSS_*`` environment variables
    and can be overridden per-URL via query pairs on the bucket segment,
    e.g. ``oss://bucket?endpoint=...&accessKeyID=.../key``.

    Returns:
        (oss.Bucket, str): the bucket client and the object key inside it.
    """
    if path.startswith('oss://'):
        path = path[len('oss://'):]

    # configs
    configs = {
        'endpoint': os.getenv('OSS_ENDPOINT', None),
        'accessKeyID': os.getenv('OSS_ACCESS_KEY_ID', None),
        'accessKeySecret': os.getenv('OSS_ACCESS_KEY_SECRET', None),
        'securityToken': os.getenv('OSS_SECURITY_TOKEN', None)
    }
    bucket, path = path.split('/', maxsplit=1)
    if '?' in bucket:
        # Per-URL overrides: bucket?k=v&k2=v2
        bucket, config = bucket.split('?', maxsplit=1)
        for pair in config.split('&'):
            k, v = pair.split('=', maxsplit=1)
            configs[k] = v

    # session — cached per bucket and per process (function attribute below).
    session = parse_oss_url._sessions.setdefault(f'{bucket}@{os.getpid()}',
                                                 oss.Session())

    # bucket
    # NOTE(review): configs['securityToken'] is collected but never passed
    # to oss.Auth — STS-token auth would need oss.StsAuth; confirm intent.
    bucket = oss.Bucket(
        auth=oss.Auth(configs['accessKeyID'], configs['accessKeySecret']),
        endpoint=configs['endpoint'],
        bucket_name=bucket,
        session=session)
    return bucket, path


# Per-bucket-per-pid session cache used by parse_oss_url above.
parse_oss_url._sessions = {}
+
+
def parse_bucket(url):
    """Return only the authenticated ``oss.Bucket`` for *url*, key discarded."""
    return parse_oss_url(osp.join(url, '_placeholder'))[0]
+
+
def read(filename, mode='r', retry=5):
    """Read a file's content with retries.

    Supports ``oss://`` URLs, ``http(s)`` URLs and local paths. ``mode='r'``
    returns text (utf-8), ``mode='rb'`` returns bytes. Raises the last
    encountered exception after *retry* failed attempts.
    """
    assert mode in ['r', 'rb']
    last_err = None
    for _ in range(retry):
        try:
            if filename.startswith('oss://'):
                bucket, path = parse_oss_url(filename)
                content = bucket.get_object(path).read()
                if mode == 'r':
                    content = content.decode('utf-8')
            elif filename.startswith('http'):
                content = requests.get(filename).content
                if mode == 'r':
                    content = content.decode('utf-8')
            else:
                with open(filename, mode=mode) as f:
                    content = f.read()
            return content
        except Exception as err:
            last_err = err
    raise last_err
+
+
def read_image(filename, retry=5):
    """Open *filename* (local, http or oss) as a PIL Image, with retries.

    Raises the last encountered exception after *retry* failed attempts.
    """
    last_err = None
    for _ in range(retry):
        try:
            data = read(filename, mode='rb', retry=retry)
            return Image.open(BytesIO(data))
        except Exception as err:
            last_err = err
    raise last_err
+
+
def download_video_to_file(filename, local_file, retry=5):
    """Download an ``oss://`` video object to *local_file*, with retries.

    Raises the last encountered exception when every attempt fails.
    """
    exception = None
    for _ in range(retry):
        try:
            bucket, path = parse_oss_url(filename)
            bucket.get_object_to_file(path, local_file)
            break
        except Exception as e:
            exception = e
            continue
    else:
        # for/else: runs only when the loop exhausted without a break,
        # i.e. every retry failed.
        raise exception
+
+
def read_gzip(filename, retry=5):
    """Read and decompress a gzip file (local or ``oss://``), with retries.

    OSS objects are first downloaded to a random temp name in the CWD and
    removed after reading. Returns the decompressed bytes; raises the last
    exception when every attempt fails.
    """
    exception = None
    for _ in range(retry):
        try:
            remove = False
            if filename.startswith('oss://'):
                bucket, path = parse_oss_url(filename)
                filename = rand_name(suffix=osp.splitext(filename)[1])
                bucket.get_object_to_file(path, filename)
                remove = True
            with gzip.open(filename) as f:
                content = f.read()
            # Clean up the temporary download.
            if remove:
                os.remove(filename)
            return content
        except Exception as e:
            exception = e
            continue
    else:
        raise exception
+
+
def ceil_divide(a, b):
    """Return ``ceil(a / b)`` as an int.

    Fix: uses exact integer arithmetic when both operands are ints —
    ``int(math.ceil(a / b))`` goes through float division and silently
    loses precision for values beyond 2**53.
    """
    if isinstance(a, int) and isinstance(b, int):
        return -(-a // b)
    return int(math.ceil(a / b))
+
+
def to_device(batch, device, non_blocking=False):
    """Recursively move tensors in nested lists/tuples/dicts to *device*.

    Non-tensor leaves are returned unchanged; tensors already on *device*
    are not copied. Container types are preserved.
    """
    if isinstance(batch, (list, tuple)):
        moved = [to_device(item, device, non_blocking) for item in batch]
        return type(batch)(moved)
    if isinstance(batch, dict):
        return type(batch)((key, to_device(value, device, non_blocking))
                           for key, value in batch.items())
    if isinstance(batch, torch.Tensor) and batch.device != device:
        batch = batch.to(device, non_blocking=non_blocking)
    return batch
+
+
def put_object(bucket, oss_key, data, retry=5):
    """Upload *data* to *oss_key*, retrying up to *retry* times.

    Returns the bucket's response on success; on total failure prints the
    last error and returns None (best-effort, never raises).
    """
    last_err = None
    for _ in range(retry):
        try:
            return bucket.put_object(oss_key, data)
        except Exception as err:
            last_err = err
    print(f'put_object to {oss_key} failed with error: {last_err}', flush=True)
+
+
def put_torch_object(bucket, oss_key, data, retry=5):
    """Serialize *data* with ``torch.save`` and upload it, with retries.

    Returns the bucket's response on success; on total failure prints the
    last error and returns None (best-effort, never raises).
    """
    exception = None
    for _ in range(retry):
        try:
            buffer = BytesIO()
            torch.save(data, buffer)
            return bucket.put_object(oss_key, buffer.getvalue())
        except Exception as e:
            exception = e
            continue
    else:
        # for/else: reached only when every retry failed (success returns).
        print(
            f'put_torch_object to {oss_key} failed with error: {exception}',
            flush=True)
+
+
def put_object_from_file(bucket, oss_key, filename, retry=5):
    """Upload a local file to *oss_key*, with retries.

    Returns the bucket's response on success; on total failure prints the
    last error and returns None (best-effort, never raises).
    """
    exception = None
    for _ in range(retry):
        try:
            return bucket.put_object_from_file(oss_key, filename)
        except Exception as e:
            exception = e
            continue
    else:
        # for/else: reached only when every retry failed (success returns).
        print(
            f'put_object_from_file to {oss_key} failed with error: {exception}',
            flush=True)
+
+
def get_object(bucket, oss_key, retry=5):
    """Download *oss_key* and return its raw bytes, with retries.

    On total failure prints the last error and returns None (best-effort,
    never raises).
    """
    last_err = None
    for _ in range(retry):
        try:
            return bucket.get_object(oss_key).read()
        except Exception as err:
            last_err = err
    print(f'get_object from {oss_key} failed with error: {last_err}', flush=True)
+
+
def get_object_to_file(bucket, oss_key, filename, retry=5):
    """Download *oss_key* into a local file, with retries.

    Returns the bucket's response on success; on total failure prints the
    last error and returns None (best-effort, never raises).
    """
    exception = None
    for _ in range(retry):
        try:
            return bucket.get_object_to_file(oss_key, filename)
        except Exception as e:
            exception = e
            continue
    else:
        # for/else: reached only when every retry failed (success returns).
        print(
            f'get_object_to_file from {oss_key} failed with error: {exception}',
            flush=True)
+
+
def rand_name(length=8, suffix=''):
    """Random hex name of ``2 * length`` characters plus an optional suffix.

    A non-empty suffix without a leading dot gets one prepended
    (``'jpg'`` -> ``'.jpg'``).
    """
    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if suffix and not suffix.startswith('.'):
        suffix = '.' + suffix
    return name + suffix
+
+
@torch.no_grad()
def save_image(bucket,
               oss_key,
               tensor,
               nrow=8,
               normalize=True,
               range=(-1, 1),
               retry=5):
    """Save an image grid to a temp file and upload it to OSS, with retries.

    NOTE(review): ``range`` shadows the builtin; it is forwarded to the
    legacy ``range=`` kwarg of ``torchvision.utils.save_image`` (renamed
    ``value_range`` in newer torchvision — confirm the pinned version).
    NOTE(review): if ``retry <= 0`` the loop body never runs and the final
    ``exception`` reference raises NameError.
    """
    filename = rand_name(suffix='.jpg')
    for _ in [None] * retry:
        try:
            tvutils.save_image(
                tensor, filename, nrow=nrow, normalize=normalize, range=range)
            bucket.put_object_from_file(oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            continue

    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)
    if exception is not None:
        print(
            'save image to {} failed, error: {}'.format(oss_key, exception),
            flush=True)
+
+
@torch.no_grad()
def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):
    """Write a (C, F, H, W) float tensor in [0, 1] as a GIF at *path*.

    Returns the list of uint8 (H, W, C) frames that were written.
    NOTE(review): ``duration``, ``loop`` and ``optimize`` are accepted but
    ignored — imageio is called with a fixed ``fps=8``.
    """
    # (C, F, H, W) -> (F, H, W, C), then split into per-frame arrays.
    tensor = tensor.permute(1, 2, 3, 0)
    images = tensor.unbind(dim=0)
    images = [(image.numpy() * 255).astype('uint8') for image in images]
    imageio.mimwrite(path, images, fps=8)
    return images
+
+
@torch.no_grad()
def save_video(bucket,
               oss_key,
               tensor,
               mean=[0.5, 0.5, 0.5],
               std=[0.5, 0.5, 0.5],
               nrow=8,
               retry=5):
    """De-normalize a (B, C, F, H, W) video batch, tile it into a grid GIF
    and upload it to OSS, with retries.

    NOTE(review): ``mul_``/``add_`` de-normalize the caller's tensor in
    place.
    """
    mean = torch.tensor(mean, device=tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=tensor.device).view(1, -1, 1, 1, 1)
    tensor = tensor.mul_(std).add_(mean)
    tensor.clamp_(0, 1)

    filename = rand_name(suffix='.gif')
    for _ in [None] * retry:
        try:
            # Tile the batch into an (i x j) grid of videos, i = nrow rows.
            one_gif = rearrange(
                tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            video_tensor_to_gif(one_gif, filename)
            bucket.put_object_from_file(oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            continue

    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)
    if exception is not None:
        print(
            'save video to {} failed, error: {}'.format(oss_key, exception),
            flush=True)
+
+
@torch.no_grad()
def save_video_multiple_conditions(oss_key,
                                   video_tensor,
                                   model_kwargs,
                                   source_imgs,
                                   palette,
                                   mean=[0.5, 0.5, 0.5],
                                   std=[0.5, 0.5, 0.5],
                                   nrow=8,
                                   retry=5,
                                   save_origin_video=True,
                                   bucket=None):
    """Save a GIF stacking the source video, each visualized condition and
    the generated video.

    Conditions come from ``model_kwargs[0]`` (dict of tensors); each is
    converted to a 3-channel (B, C, F, H, W) clip and concatenated along
    height. The GIF is written locally to the path given by ``oss_key``.

    NOTE(review): ``bucket`` is accepted but unused — nothing is uploaded.
    NOTE(review): ``mul_``/``add_`` de-normalize the caller's video_tensor
    in place.
    """
    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
    video_tensor = video_tensor.mul_(std).add_(mean)
    try:
        video_tensor.clamp_(0, 1)
    except Exception as e:
        # Some dtypes (e.g. half on CPU) lack clamp_ support — retry in fp32.
        video_tensor = video_tensor.float().clamp_(0, 1)
        print(e)
    video_tensor = video_tensor.cpu()

    b, c, n, h, w = video_tensor.shape
    # Resize the source clip to match the generated video's (n, h, w).
    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
    source_imgs = source_imgs.cpu()

    model_kwargs_channel3 = {}
    for key, conditions in model_kwargs[0].items():
        if conditions.shape[-1] == 1024:
            # Skip for style embeding
            continue
        if len(conditions.shape) == 3:
            # (B, N, tokens) condition — render each token map through the
            # palette into RGB images.
            conditions_np = conditions.cpu().numpy()
            conditions = []
            for i in conditions_np:
                vis_i = []
                for j in i:
                    vis_i.append(
                        palette.get_palette_image(
                            j, percentile=90, width=256, height=256))
                conditions.append(np.stack(vis_i))
            conditions = torch.from_numpy(np.stack(conditions))
            conditions = rearrange(conditions, 'b n h w c -> b c n h w')
        else:
            # Channel-wise normalization to 3 channels:
            if conditions.size(1) == 1:
                # grayscale -> replicate to RGB
                conditions = torch.cat([conditions, conditions, conditions],
                                       dim=1)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            if conditions.size(1) == 2:
                # 2-channel (e.g. motion vectors) -> repeat first channel
                conditions = torch.cat([conditions, conditions[:, :1, ]],
                                       dim=1)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            elif conditions.size(1) == 3:
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            elif conditions.size(1) == 4:
                # RGBA -> alpha-composite over white
                color = ((conditions[:, 0:3] + 1.) / 2.)
                alpha = conditions[:, 3:4]
                conditions = color * alpha + 1.0 * (1.0 - alpha)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        model_kwargs_channel3[key] = conditions.cpu(
        ) if conditions.is_cuda else conditions

    filename = oss_key
    for _ in [None] * retry:
        try:
            # Tile batches into (i x j) grids, then stack source /
            # conditions / output along the height axis (dim=3).
            vid_gif = rearrange(
                video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            cons_list = [
                rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
                for _, con in model_kwargs_channel3.items()
            ]
            source_imgs = rearrange(
                source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow)

            if save_origin_video:
                vid_gif = torch.cat(
                    [
                        source_imgs,
                    ] + cons_list + [
                        vid_gif,
                    ], dim=3)
            else:
                vid_gif = torch.cat(
                    cons_list + [
                        vid_gif,
                    ], dim=3)

            video_tensor_to_gif(vid_gif, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        logging.info('save video to {} failed, error: {}'.format(
            oss_key, exception))
+
+
@torch.no_grad()
def save_video_multiple_conditions_with_data(bucket,
                                             video_save_key,
                                             gt_video_save_key,
                                             vis_oss_key,
                                             video_tensor,
                                             model_kwargs,
                                             source_imgs,
                                             palette,
                                             mean=[0.5, 0.5, 0.5],
                                             std=[0.5, 0.5, 0.5],
                                             nrow=8,
                                             retry=5):
    r"""Upload to OSS a visualization gif (source + conditions + generated)
    plus raw uint8 pickles of the generated and ground-truth videos.

    Args:
        bucket: OSS bucket object exposing ``put_object_from_file``.
        video_save_key: OSS key for the pickled generated video.
        gt_video_save_key: OSS key for the pickled source video.
        vis_oss_key: OSS key for the visualization gif.
        video_tensor: generated videos [b, c, n, h, w], normalized with
            *mean*/*std*; denormalized in place.
        model_kwargs: pair of dicts; only the first (conditional) dict is
            visualized.
        source_imgs: original input videos, pooled to match [n, h, w].
        palette: renderer with ``get_palette_image`` for 3-D (depth-like)
            conditions.
        nrow: number of grid rows when tiling the batch.
        retry: number of attempts per upload before giving up.
    """
    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
    video_tensor = video_tensor.mul_(std).add_(mean)
    video_tensor.clamp_(0, 1)

    b, c, n, h, w = video_tensor.shape
    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
    source_imgs = source_imgs.cpu()

    # Convert every condition to a 3-channel [b, 3, n, h, w] tensor.
    model_kwargs_channel3 = {}
    for key, conditions in model_kwargs[0].items():
        if len(conditions.shape) == 3:
            # depth-like [b, n, l] condition: render each frame through the
            # palette into an RGB image
            conditions_np = conditions.cpu().numpy()
            conditions = []
            for i in conditions_np:
                vis_i = []
                for j in i:
                    vis_i.append(
                        palette.get_palette_image(
                            j, percentile=90, width=256, height=256))
                conditions.append(np.stack(vis_i))
            conditions = torch.from_numpy(np.stack(conditions))
            conditions = rearrange(conditions, 'b n h w c -> b c n h w')
        else:
            # channel-count normalization: 1 -> replicate to 3; 2 -> pad with
            # the first channel; 4 -> alpha-composite RGBA over white
            if conditions.size(1) == 1:
                conditions = torch.cat([conditions, conditions, conditions],
                                       dim=1)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            if conditions.size(1) == 2:
                conditions = torch.cat([conditions, conditions[:, :1, ]],
                                       dim=1)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            elif conditions.size(1) == 3:
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            elif conditions.size(1) == 4:
                color = ((conditions[:, 0:3] + 1.) / 2.)
                alpha = conditions[:, 3:4]
                conditions = color * alpha + 1.0 * (1.0 - alpha)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        model_kwargs_channel3[key] = conditions.cpu(
        ) if conditions.is_cuda else conditions

    # keep untouched copies for the raw-data pickles below
    copy_video_tensor = video_tensor.clone()
    copy_source_imgs = source_imgs.clone()

    filename = rand_name(suffix='.gif')
    # BUGFIX: initialize so the final check cannot raise NameError when
    # retry <= 0
    exception = None
    for _ in range(retry):
        try:
            vid_gif = rearrange(
                video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            cons_list = [
                # BUGFIX: was `j=nrow`, which transposed the condition grid
                # relative to the video/source grids (every other grid in this
                # module uses `i=nrow`) and made the torch.cat below fail
                # whenever the batch is not a perfect square.
                rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
                for _, con in model_kwargs_channel3.items()
            ]
            source_imgs = rearrange(
                source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            # stack source + conditions + generated along the width axis
            vid_gif = torch.cat(
                [
                    source_imgs,
                ] + cons_list + [
                    vid_gif,
                ], dim=3)

            video_tensor_to_gif(vid_gif, filename)
            bucket.put_object_from_file(vis_oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e

    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)

    # pickle + upload the generated video as uint8
    filename_pred = rand_name(suffix='.pkl')
    for _ in range(retry):
        try:
            copy_video_np = (copy_video_tensor.numpy() * 255).astype('uint8')
            # BUGFIX: use a context manager so the file handle is closed
            with open(filename_pred, 'wb') as f:
                pickle.dump(copy_video_np, f)
            bucket.put_object_from_file(video_save_key, filename_pred)
            break
        except Exception as e:
            print('error! ', video_save_key, e)

    # remove temporary file
    if osp.exists(filename_pred):
        os.remove(filename_pred)

    # pickle + upload the ground-truth video as uint8
    filename_gt = rand_name(suffix='.pkl')
    for _ in range(retry):
        try:
            copy_source_np = (copy_source_imgs.numpy() * 255).astype('uint8')
            with open(filename_gt, 'wb') as f:
                pickle.dump(copy_source_np, f)
            bucket.put_object_from_file(gt_video_save_key, filename_gt)
            break
        except Exception as e:
            print('error! ', gt_video_save_key, e)

    # remove temporary file
    if osp.exists(filename_gt):
        os.remove(filename_gt)

    if exception is not None:
        print(
            'save video to {} failed, error: {}'.format(
                vis_oss_key, exception),
            flush=True)
+
+
@torch.no_grad()
def save_video_vs_conditions(bucket,
                             oss_key,
                             video_tensor,
                             conditions,
                             source_imgs,
                             mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5],
                             nrow=8,
                             retry=5):
    r"""Stack generated video, a single condition stream and the source video
    vertically, write the result as a gif and upload it to OSS.

    Args:
        bucket: OSS bucket object exposing ``put_object_from_file``.
        oss_key: destination object key.
        video_tensor: generated videos [b, c, n, h, w], normalized with
            *mean*/*std*; denormalized in place.
        conditions: condition video batch; a single-channel condition is
            replicated to 3 channels and pooled to [n, h, w].
        source_imgs: original input videos, pooled to match [n, h, w].
        nrow: number of grid rows when tiling the batch.
        retry: number of upload attempts before reporting failure.
    """
    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
    # in-place denormalization to [0, 1]
    video_tensor = video_tensor.mul_(std).add_(mean)
    video_tensor.clamp_(0, 1)

    b, c, n, h, w = video_tensor.shape
    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
    source_imgs = source_imgs.cpu()

    # NOTE(review): only the 1-channel case is resized here; multi-channel
    # conditions are assumed to already match (n, h, w) — confirm upstream.
    if conditions.size(1) == 1:
        conditions = torch.cat([conditions, conditions, conditions], dim=1)
        conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))

    filename = rand_name(suffix='.gif')
    for _ in [None] * retry:
        try:
            # tile each stream into an nrow x (b / nrow) grid, then stack the
            # three grids along the height axis (dim=2)
            vid_gif = rearrange(
                video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            con_gif = rearrange(
                conditions, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            source_imgs = rearrange(
                source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            vid_gif = torch.cat([vid_gif, con_gif, source_imgs], dim=2)

            video_tensor_to_gif(vid_gif, filename)
            bucket.put_object_from_file(oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            continue

    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)
    if exception is not None:
        print(
            'save video to {} failed, error: {}'.format(oss_key, exception),
            flush=True)
+
+
@torch.no_grad()
def save_video_grid_mp4(bucket,
                        oss_key,
                        tensor,
                        mean=[0.5, 0.5, 0.5],
                        std=[0.5, 0.5, 0.5],
                        nrow=None,
                        fps=5,
                        retry=5):
    r"""Tile a batch of videos into a padded grid, encode it as mp4 with
    scikit-video/ffmpeg and upload it to OSS.

    Args:
        bucket: OSS bucket object exposing ``put_object_from_file``.
        oss_key: destination object key.
        tensor: video batch [b, c, t, h, w], normalized with *mean*/*std*;
            denormalized in place.
        nrow: number of grid rows; defaults to ceil(sqrt(b)).
        fps: output frame rate passed to ffmpeg via ``-r``.
        retry: number of upload attempts before reporting failure.
    """
    mean = torch.tensor(mean, device=tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=tensor.device).view(1, -1, 1, 1, 1)
    # in-place denormalization, then convert to uint8 frames [b, t, h, w, c]
    tensor = tensor.mul_(std).add_(mean)
    tensor.clamp_(0, 1)
    b, c, t, h, w = tensor.shape
    tensor = tensor.permute(0, 2, 3, 4, 1)
    tensor = (tensor.cpu().numpy() * 255).astype('uint8')

    filename = rand_name(suffix='.mp4')
    for _ in [None] * retry:
        try:
            if nrow is None:
                nrow = math.ceil(math.sqrt(b))
            ncol = math.ceil(b / nrow)
            # 1px black border between (and around) the grid cells
            padding = 1
            video_grid = np.zeros((t, (padding + h) * nrow + padding,
                                   (padding + w) * ncol + padding, c),
                                  dtype='uint8')
            for i in range(b):
                # row-major placement of sample i into the grid
                r = i // ncol
                c_ = i % ncol

                start_r = (padding + h) * r
                start_c = (padding + w) * c_
                video_grid[:, start_r:start_r + h,
                           start_c:start_c + w] = tensor[i]
            skvideo.io.vwrite(filename, video_grid, inputdict={'-r': str(fps)})

            bucket.put_object_from_file(oss_key, filename)
            exception = None
            break
        except Exception as e:
            exception = e
            continue

    # remove temporary file
    if osp.exists(filename):
        os.remove(filename)
    if exception is not None:
        print(
            'save video to {} failed, error: {}'.format(oss_key, exception),
            flush=True)
+
+
@torch.no_grad()
def save_text(bucket, oss_key, tensor, nrow=8, retry=5):
    r"""Decode a batch of serialized captions and upload them as one text
    object, one caption per line with a blank line between grid rows.

    Args:
        bucket: OSS bucket object exposing ``put_object``.
        oss_key: destination object key.
        tensor: batch of encoded captions; decoded via ``dec_bytes2obj``.
        nrow: number of grid rows; must evenly divide the batch size.
        retry: number of upload attempts before reporting failure.
    """
    # BUGFIX: the row count was previously stored in a variable named `len`,
    # shadowing the builtin
    total = tensor.shape[0]
    num_per_row = int(total / nrow)
    assert (total == nrow * num_per_row)
    texts = ''
    for i in range(nrow):
        for j in range(num_per_row):
            text = dec_bytes2obj(tensor[i * num_per_row + j])
            texts += text + '\n'
        texts += '\n'

    exception = None
    for _ in range(retry):
        try:
            bucket.put_object(oss_key, texts)
            exception = None
            break
        except Exception as e:
            exception = e
    if exception is not None:
        # BUGFIX: the message previously said 'save video' for a text upload
        print(
            'save text to {} failed, error: {}'.format(oss_key, exception),
            flush=True)
+
+
@torch.no_grad()
def save_caps(bucket, oss_key, caps, retry=5):
    r"""Join the captions in *caps* (one per line) and upload them to OSS.

    Args:
        bucket: OSS bucket object exposing ``put_object``.
        oss_key: destination object key.
        caps: iterable of caption strings.
        retry: number of upload attempts before reporting failure.
    """
    texts = ''
    for cap in caps:
        texts += cap
        texts += '\n'

    exception = None
    for _ in range(retry):
        try:
            bucket.put_object(oss_key, texts)
            exception = None
            break
        except Exception as e:
            exception = e
    if exception is not None:
        # BUGFIX: the message previously said 'save video' for a caption
        # upload
        print(
            'save caps to {} failed, error: {}'.format(oss_key, exception),
            flush=True)
+
+
@torch.no_grad()
def ema(net_ema, net, beta, copy_buffer=False):
    r"""Update *net_ema* in place as an exponential moving average of *net*.

    Every EMA parameter becomes ``beta * p_ema + (1 - beta) * p``. When
    *copy_buffer* is set, buffers (e.g. BatchNorm statistics) are copied
    verbatim from *net*.
    """
    assert 0.0 <= beta <= 1.0
    param_pairs = zip(net_ema.parameters(), net.parameters())
    for ema_p, live_p in param_pairs:
        ema_p.copy_(live_p.lerp(ema_p, beta))
    if copy_buffer:
        buffer_pairs = zip(net_ema.buffers(), net.buffers())
        for ema_b, live_b in buffer_pairs:
            ema_b.copy_(live_b)
+
+
def parallel(func, args_list, num_workers=32, timeout=None):
    r"""Apply *func* to every argument tuple in *args_list*.

    Args:
        func: callable applied to each argument tuple.
        args_list: list of argument tuples; bare items are wrapped as
            1-tuples.
        num_workers: process-pool size; ``0`` runs serially in-process.
        timeout: per-result timeout in seconds when using the pool.

    Returns:
        List of results in input order.
    """
    assert isinstance(args_list, list)
    # BUGFIX: previously raised IndexError on an empty args_list
    if not args_list:
        return []
    if not isinstance(args_list[0], tuple):
        args_list = [(args, ) for args in args_list]
    if num_workers == 0:
        return [func(*args) for args in args_list]
    with Pool(processes=num_workers) as pool:
        results = [pool.apply_async(func, args) for args in args_list]
        results = [res.get(timeout=timeout) for res in results]
    return results
+
+
def exists(filename):
    r"""Return whether *filename* exists; supports both local paths and
    ``oss://`` urls."""
    if not filename.startswith('oss://'):
        return osp.exists(filename)
    bucket, key = parse_oss_url(filename)
    return bucket.object_exists(key)
+
+
def download(url, filename=None, replace=False, quiet=False):
    r"""Download *url* (http(s), file or ``oss://``) to *filename*.

    Args:
        url: source url; ``oss://`` urls are fetched through OSS, everything
            else through ``urllib``.
        filename: destination path; defaults to the url's basename.
        replace: re-download even when the destination already exists.
        quiet: suppress the progress message.

    Returns:
        Absolute path of the (possibly pre-existing) destination file.

    Raises:
        ValueError: when the transfer fails.
    """
    if filename is None:
        filename = osp.basename(url)
    if not osp.exists(filename) or replace:
        try:
            if url.startswith('oss://'):
                bucket, oss_key = parse_oss_url(url)
                bucket.get_object_to_file(oss_key, filename)
            else:
                urllib.request.urlretrieve(url, filename)
            if not quiet:
                # BUGFIX: the message previously contained a mangled
                # "(unknown)" placeholder instead of the destination path
                print(f'Downloaded {url} to {filename}', flush=True)
        except Exception as e:
            # BUGFIX: same mangled placeholder instead of the url
            raise ValueError(f'Downloading {url} failed with error {e}')
    return osp.abspath(filename)
+
+
def unzip(filename, dst_dir=None):
    r"""Extract the zip archive *filename* into *dst_dir*, defaulting to the
    directory containing the archive."""
    target = dst_dir if dst_dir is not None else osp.dirname(filename)
    with zipfile.ZipFile(filename, 'r') as archive:
        archive.extractall(target)
+
+
def load_state_dict(module, state_dict, drop_prefix=''):
    r"""Load *state_dict* into *module*, silently skipping entries the module
    cannot accept, and print a summary of every incompatibility.

    Args:
        module: target module exposing ``state_dict``/``load_state_dict``.
        state_dict: source mapping of parameter names to tensors.
        drop_prefix: optional key prefix stripped from source names first.
    """
    src = state_dict
    dst = module.state_dict()

    # strip the prefix from source keys, preserving the mapping type
    if drop_prefix:
        cut = len(drop_prefix)
        src = type(src)([(k[cut:] if k.startswith(drop_prefix) else k, v)
                         for k, v in src.items()])

    # classify the incompatibilities
    missing = [k for k in dst if k not in src]
    unexpected = [k for k in src if k not in dst]
    shared = src.keys() & dst.keys()
    unmatched = [k for k in shared if src[k].shape != dst[k].shape]

    # load only the compatible entries
    skip = set(unexpected) | set(unmatched)
    src = type(src)([(k, v) for k, v in src.items() if k not in skip])
    module.load_state_dict(src, strict=False)

    # report what was dropped or left untouched
    if len(missing) != 0:
        print(' Missing: ' + ', '.join(missing), flush=True)
    if len(unexpected) != 0:
        print(' Unexpected: ' + ', '.join(unexpected), flush=True)
    if len(unmatched) != 0:
        print(' Shape unmatched: ' + ', '.join(unmatched), flush=True)
+
+
def inverse_indices(indices):
    r"""Inverse map of indices.
    E.g., if A[indices] == B, then B[inv_indices] == A.
    """
    positions = torch.arange(len(indices)).to(indices)
    inv = torch.empty_like(indices)
    # scatter each position to where its index points
    inv[indices] = positions
    return inv
+
+
def detect_duplicates(feats, thr=0.9):
    r"""Return the indices of rows in *feats* that are not near-duplicates of
    an earlier row.

    Args:
        feats: 2-D feature tensor [num_items, dim].
        thr: cosine-similarity threshold above which two rows are considered
            duplicates.

    Returns:
        1-D LongTensor of indices to keep, in ascending order.
    """
    assert feats.ndim == 2

    # cosine-similarity matrix; keep only the strict upper triangle so each
    # pair is counted once and no item is compared with itself
    feats = F.normalize(feats, p=2, dim=1)
    simmat = torch.mm(feats, feats.T)
    simmat.triu_(1)
    # BUGFIX: only synchronize for CUDA tensors — torch.cuda.synchronize()
    # raises on CPU-only builds
    if feats.is_cuda:
        torch.cuda.synchronize()

    # an item is a duplicate if any earlier item is too similar to it
    mask = ~simmat.gt(thr).any(dim=0)
    return torch.where(mask)[0]
+
+
class TFSClient(object):
    r"""Minimal read-only client for TFS (Taobao File System) images.

    The list of candidate REST servers is fetched once at construction time
    and requests are round-robined across them.
    NOTE(review): relies on a module-level ``read(url, mode)`` helper defined
    elsewhere in this file.
    """

    def __init__(self,
                 host='restful-store.vip.tbsite.net:3800',
                 app_key='5354c9fae75f5'):
        # service endpoint and application key used in every request url
        self.host = host
        self.app_key = app_key

        # candidate servers
        self.servers = [
            u for u in read(f'http://{host}/url.list').strip().split('\n')[1:]
            if ':' in u
        ]
        assert len(self.servers) >= 1
        self.__server_id = -1

    @property
    def server(self):
        # round-robin: each access advances to the next candidate server
        self.__server_id = (self.__server_id + 1) % len(self.servers)
        return self.servers[self.__server_id]

    def read(self, tfs):
        # fetch the object's metadata first to learn its size, then fetch the
        # payload in a single ranged request and decode it as an image
        tfs = osp.basename(tfs)
        meta = json.loads(
            read(
                f'http://{self.server}/v1/{self.app_key}/metadata/{tfs}?force=0'
            ))
        img = Image.open(
            BytesIO(
                read(
                    f'http://{self.server}/v1/{self.app_key}/{tfs}?offset=0&size={meta["SIZE"]}',
                    'rb')))
        return img
+
+
def read_tfs(tfs, retry=5):
    r"""Read an image from TFS, lazily creating the shared module-level
    client and retrying up to *retry* times before re-raising the last
    error."""
    last_error = None
    for _ in range(retry):
        try:
            global TFS_CLIENT
            if TFS_CLIENT is None:
                TFS_CLIENT = TFSClient()
            return TFS_CLIENT.read(tfs)
        except Exception as e:
            last_error = e
    raise last_error
+
+
def md5(filename):
    r"""Return the hexadecimal MD5 digest of the file's contents."""
    hasher = hashlib.md5()
    with open(filename, 'rb') as f:
        hasher.update(f.read())
    return hasher.hexdigest()
+
+
def rope(x):
    r"""Apply rotary position embedding on x of shape
    [B, *(spatial dimensions), C].
    """
    # flatten all spatial dimensions into one sequence axis
    orig_shape = x.shape
    flat = x.view(x.size(0), -1, x.size(-1))
    seq_len, dim = flat.shape[-2:]
    assert dim % 2 == 0
    half = dim // 2

    # angles[p, k] = p * 10000^(-k / half), the classic RoPE frequencies
    freqs = torch.pow(10000, -torch.arange(half).to(flat).div(half))
    angles = torch.outer(torch.arange(seq_len).to(flat), freqs)
    sin = torch.sin(angles)
    cos = torch.cos(angles)

    # rotate the two halves of the channel dimension
    a, b = flat.chunk(2, dim=-1)
    rotated = torch.cat([a * cos - b * sin, b * cos + a * sin], dim=-1)

    # restore the original layout
    return rotated.view(orig_shape)
+
+
def format_state(state, filename=None):
    r"""For comparing/aligning state_dict: render each key with its tensor
    shape, one entry per line, optionally writing the result to *filename*.
    """
    lines = [f'{name}\t{tuple(tensor.shape)}' for name, tensor in state.items()]
    content = '\n'.join(lines)
    if filename:
        with open(filename, 'w') as f:
            f.write(content)
+
+
def breakup_grid(img, grid_size):
    r"""The inverse operator of ``torchvision.utils.make_grid``: crop every
    ``grid_size`` x ``grid_size`` tile out of a grid image, row by row.
    """
    # tile counts; make_grid uses a 2px border between and around tiles
    n_rows = img.height // grid_size
    n_cols = img.width // grid_size
    pad = 2

    tiles = []
    for row in range(n_rows):
        top = row * grid_size + (row + 1) * pad
        for col in range(n_cols):
            left = col * grid_size + (col + 1) * pad
            tiles.append(img.crop((left, top, left + grid_size,
                                   top + grid_size)))
    return tiles
+
+
def huggingface_tokenizer(name='google/mt5-xxl', **kwargs):
    r"""Load a HuggingFace tokenizer after staging it in the local cache.

    NOTE(review): ``DOWNLOAD_TO_CACHE`` is defined elsewhere in this module;
    extra *kwargs* are forwarded to ``AutoTokenizer.from_pretrained``.
    """
    from transformers import AutoTokenizer
    return AutoTokenizer.from_pretrained(
        DOWNLOAD_TO_CACHE(f'huggingface/tokenizers/{name}', name), **kwargs)
+
+
def huggingface_model(name='google/mt5-xxl', model_type='AutoModel', **kwargs):
    r"""Load a HuggingFace model after staging it in the local cache.

    *model_type* names the ``transformers`` auto-class to use (e.g.
    ``'AutoModel'``); extra *kwargs* are forwarded to ``from_pretrained``.
    """
    import transformers
    return getattr(transformers, model_type).from_pretrained(
        DOWNLOAD_TO_CACHE(f'huggingface/models/{name}', name), **kwargs)
diff --git a/modelscope/models/multi_modal/videocomposer/videocomposer_model.py b/modelscope/models/multi_modal/videocomposer/videocomposer_model.py
new file mode 100644
index 00000000..3e2a910c
--- /dev/null
+++ b/modelscope/models/multi_modal/videocomposer/videocomposer_model.py
@@ -0,0 +1,480 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+from copy import copy, deepcopy
+from os import path as osp
+from typing import Any, Dict
+
+import open_clip
+import pynvml
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+from einops import rearrange
+
+import modelscope.models.multi_modal.videocomposer.models as models
+from modelscope.metainfo import Models
+from modelscope.models import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.multi_modal.videocomposer.annotator.sketch import (
+ pidinet_bsd, sketch_simplification_gan)
+from modelscope.models.multi_modal.videocomposer.autoencoder import \
+ AutoencoderKL
+from modelscope.models.multi_modal.videocomposer.clip import (
+ FrozenOpenCLIPEmbedder, FrozenOpenCLIPVisualEmbedder)
+from modelscope.models.multi_modal.videocomposer.diffusion import (
+ GaussianDiffusion, beta_schedule)
+from modelscope.models.multi_modal.videocomposer.ops.utils import (
+ get_first_stage_encoding, make_masked_images, prepare_model_kwargs,
+ save_with_model_kwargs)
+from modelscope.models.multi_modal.videocomposer.unet_sd import UNetSD_temporal
+from modelscope.models.multi_modal.videocomposer.utils.config import Config
+from modelscope.models.multi_modal.videocomposer.utils.utils import (
+ find_free_port, setup_seed, to_device)
+from modelscope.outputs import OutputKeys
+from modelscope.preprocessors.image import load_image
+from modelscope.utils.constant import ModelFile, Tasks
+from .config import cfg
+
+__all__ = ['VideoComposer']
+
+
@MODELS.register_module(
    Tasks.text_to_video_synthesis, module_name=Models.videocomposer)
class VideoComposer(TorchModel):
    r"""
    Task model for VideoComposer: compositional video synthesis with motion
    controllability.

    Attributes:
        model: the denoising UNet (``UNetSD_temporal``) used for sampling.
        diffusion: Gaussian diffusion wrapper providing DDIM sampling.
        autoencoder: VAE (``AutoencoderKL``) decoding latents to pixel space.
        clip_encoder: frozen CLIP text encoder producing text embeddings.
        clip_encoder_visual: frozen CLIP image encoder for style conditions.
    """

    def __init__(self, model_dir, *args, **kwargs):
        r"""
        Args:
            model_dir (`str` or `os.PathLike`)
                Can be either:
                - A string, the *model id* of a pretrained model hosted inside a model repo on modelscope
                  or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`,
                  or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
                - A path to a *directory* containing model weights saved using
                  [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
                  this case, `from_tf` should be set to `True` and a configuration object should be provided as
                  `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
                  PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
                - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
                  `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
                  `True`.
        """
        super().__init__(model_dir=model_dir, *args, **kwargs)
        self.device = torch.device('cuda') if torch.cuda.is_available() \
            else torch.device('cpu')
        # checkpoint file names inside model_dir, overridable through kwargs
        clip_checkpoint = kwargs.pop('clip_checkpoint',
                                     'open_clip_pytorch_model.bin')
        sd_checkpoint = kwargs.pop('sd_checkpoint', 'v2-1_512-ema-pruned.ckpt')
        # merge the on-disk config into the module-level cfg singleton
        _cfg = Config(load=True)
        cfg.update(_cfg.cfg_dict)
        # rank-wise params
        l1 = len(cfg.frame_lens)
        l2 = len(cfg.feature_framerates)
        cfg.max_frames = cfg.frame_lens[0 % (l1 * l2) // l2]
        cfg.batch_size = cfg.batch_sizes[str(cfg.max_frames)]
        # Copy update input parameter to current task
        self.cfg = cfg
        # single-process distributed defaults when not launched by a launcher
        if 'MASTER_ADDR' not in os.environ:
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = find_free_port()
        self.cfg.pmi_rank = int(os.getenv('RANK', 0))
        self.cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1))
        setup_seed(self.cfg.seed)
        # which auxiliary inputs are read from files vs. derived from data
        self.read_image = kwargs.pop('read_image', False)
        self.read_style = kwargs.pop('read_style', True)
        self.read_sketch = kwargs.pop('read_sketch', False)
        self.save_origin_video = kwargs.pop('save_origin_video', True)
        self.video_compositions = kwargs.pop('video_compositions', [
            'text', 'mask', 'depthmap', 'sketch', 'motion', 'image',
            'local_image', 'single_sketch'
        ])
        self.viz_num = self.cfg.batch_size
        # frozen CLIP text and image encoders
        self.clip_encoder = FrozenOpenCLIPEmbedder(
            layer='penultimate',
            pretrained=os.path.join(model_dir, clip_checkpoint))
        self.clip_encoder = self.clip_encoder.to(self.device)
        self.clip_encoder_visual = FrozenOpenCLIPVisualEmbedder(
            layer='penultimate',
            pretrained=os.path.join(model_dir, clip_checkpoint))
        self.clip_encoder_visual.model.to(self.device)
        # VAE architecture matching Stable Diffusion 2.1 (latent dim 4)
        ddconfig = {
            'double_z': True,
            'z_channels': 4,
            'resolution': 256,
            'in_channels': 3,
            'out_ch': 3,
            'ch': 128,
            'ch_mult': [1, 2, 4, 4],
            'num_res_blocks': 2,
            'attn_resolutions': [],
            'dropout': 0.0
        }
        self.autoencoder = AutoencoderKL(
            ddconfig, 4, ckpt_path=os.path.join(model_dir, sd_checkpoint))
        # embedding of the empty prompt, used for classifier-free guidance
        self.zero_y = self.clip_encoder('').detach()
        black_image_feature = self.clip_encoder_visual(
            self.clip_encoder_visual.black_image).unsqueeze(1)
        # zeroed visual feature acts as the "no style" condition
        black_image_feature = torch.zeros_like(black_image_feature)
        self.autoencoder.eval()
        for param in self.autoencoder.parameters():
            param.requires_grad = False
        # NOTE(review): unconditional .cuda() — fails on CPU-only hosts even
        # though self.device falls back to CPU above; confirm intended.
        self.autoencoder.cuda()
        self.model = UNetSD_temporal(
            cfg=self.cfg,
            in_dim=self.cfg.unet_in_dim,
            concat_dim=self.cfg.unet_concat_dim,
            dim=self.cfg.unet_dim,
            y_dim=self.cfg.unet_y_dim,
            context_dim=self.cfg.unet_context_dim,
            out_dim=self.cfg.unet_out_dim,
            dim_mult=self.cfg.unet_dim_mult,
            num_heads=self.cfg.unet_num_heads,
            head_dim=self.cfg.unet_head_dim,
            num_res_blocks=self.cfg.unet_res_blocks,
            attn_scales=self.cfg.unet_attn_scales,
            dropout=self.cfg.unet_dropout,
            temporal_attention=self.cfg.temporal_attention,
            temporal_attn_times=self.cfg.temporal_attn_times,
            use_checkpoint=self.cfg.use_checkpoint,
            use_fps_condition=self.cfg.use_fps_condition,
            use_sim_mask=self.cfg.use_sim_mask,
            video_compositions=self.cfg.video_compositions,
            misc_dropout=self.cfg.misc_dropout,
            p_all_zero=self.cfg.p_all_zero,
            # NOTE(review): p_all_keep is fed from cfg.p_all_zero — this
            # looks like a copy-paste slip; confirm against cfg.p_all_keep.
            p_all_keep=self.cfg.p_all_zero,
            zero_y=self.zero_y,
            black_image_feature=black_image_feature,
        ).to(self.device)

        # Load checkpoint
        if self.cfg.resume and self.cfg.resume_checkpoint:
            if hasattr(self.cfg, 'text_to_video_pretrain'
                       ) and self.cfg.text_to_video_pretrain:
                # text-to-video pretrain weights: drop the first input block
                # whose channel count differs from the composed-input UNet
                checkpoint_name = cfg.resume_checkpoint.split('/')[-1]
                ss = torch.load(
                    os.path.join(self.model_dir, cfg.resume_checkpoint))
                ss = {
                    key: p
                    for key, p in ss.items() if 'input_blocks.0.0' not in key
                }
                self.model.load_state_dict(ss, strict=False)
            else:
                checkpoint_name = cfg.resume_checkpoint.split('/')[-1]
                self.model.load_state_dict(
                    torch.load(
                        os.path.join(self.model_dir, checkpoint_name),
                        map_location='cpu'),
                    strict=False)

            torch.cuda.empty_cache()
        else:
            raise ValueError(
                f'The checkpoint file {self.cfg.resume_checkpoint} is wrong ')

        # diffusion
        betas = beta_schedule(
            'linear_sd',
            self.cfg.num_timesteps,
            init_beta=0.00085,
            last_beta=0.0120)
        self.diffusion = GaussianDiffusion(
            betas=betas,
            mean_type=self.cfg.mean_type,
            var_type=self.cfg.var_type,
            loss_type=self.cfg.loss_type,
            rescale_timesteps=False)

    def forward(self, input: Dict[str, Any]):
        # Synthesize videos from the preprocessed batch in *input* and save
        # the visualizations; returns the decoded latents and the config.
        frame_in = None
        if self.read_image:
            image_key = input['style_image']
            frame = load_image(image_key)
            # NOTE(review): ``misc_transforms`` is not defined or imported in
            # this module — this branch would raise NameError if taken.
            frame_in = misc_transforms([frame])

        frame_sketch = None
        if self.read_sketch:
            sketch_key = self.cfg.sketch_path
            frame_sketch = load_image(sketch_key)
            frame_sketch = misc_transforms([frame_sketch])

        frame_style = None
        if self.read_style:
            frame_style = load_image(input['style_image'])

        # Generators for various conditions
        # NOTE(review): this section gates on self.video_compositions while
        # the data branches below gate on self.cfg.video_compositions — a
        # mismatch would leave e.g. ``midas`` undefined; confirm they agree.
        if 'depthmap' in self.video_compositions:
            midas = models.midas_v3(
                pretrained=True,
                model_dir=self.model_dir).eval().requires_grad_(False).to(
                    memory_format=torch.channels_last).half().to(self.device)
        if 'canny' in self.video_compositions:
            # NOTE(review): CannyDetector is not imported in this module —
            # NameError if 'canny' is requested.
            canny_detector = CannyDetector()
        if 'sketch' in self.video_compositions:
            pidinet = pidinet_bsd(
                self.model_dir, pretrained=True,
                vanilla_cnn=True).eval().requires_grad_(False).to(self.device)
            cleaner = sketch_simplification_gan(
                self.model_dir,
                pretrained=True).eval().requires_grad_(False).to(self.device)
            pidi_mean = torch.tensor(self.cfg.sketch_mean).view(
                1, -1, 1, 1).to(self.device)
            pidi_std = torch.tensor(self.cfg.sketch_std).view(1, -1, 1, 1).to(
                self.device)
        # Placeholder for color inference
        palette = None

        self.model.eval()
        caps = input['cap_txt']
        # both branches currently unpack the same keys; kept for parity with
        # the image-dataset configuration
        if self.cfg.max_frames == 1 and self.cfg.use_image_dataset:
            ref_imgs = input['ref_frame']
            video_data = input['video_data']
            misc_data = input['misc_data']
            mask = input['mask']
            mv_data = input['mv_data']
            fps = torch.tensor(
                [self.cfg.feature_framerate] * self.cfg.batch_size,
                dtype=torch.long,
                device=self.device)
        else:
            ref_imgs = input['ref_frame']
            video_data = input['video_data']
            misc_data = input['misc_data']
            mask = input['mask']
            mv_data = input['mv_data']
            # add fps test
            fps = torch.tensor(
                [self.cfg.feature_framerate] * self.cfg.batch_size,
                dtype=torch.long,
                device=self.device)

        # save for visualization
        misc_backups = copy(misc_data)
        misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w')
        mv_data_video = []
        if 'motion' in self.cfg.video_compositions:
            mv_data_video = rearrange(mv_data, 'b f c h w -> b c f h w')

        # mask images
        masked_video = []
        if 'mask' in self.cfg.video_compositions:
            masked_video = make_masked_images(
                misc_data.sub(0.5).div_(0.5), mask)
            masked_video = rearrange(masked_video, 'b f c h w -> b c f h w')

        # Single Image
        image_local = []
        if 'local_image' in self.cfg.video_compositions:
            # NOTE(review): frames_num is only bound in this branch but is
            # also read in the sketch branch below when read_sketch is set.
            frames_num = misc_data.shape[1]
            bs_vd_local = misc_data.shape[0]
            if self.cfg.read_image:
                image_local = frame_in.unsqueeze(0).repeat(
                    bs_vd_local, frames_num, 1, 1, 1).cuda()
            else:
                image_local = misc_data[:, :1].clone().repeat(
                    1, frames_num, 1, 1, 1)
            image_local = rearrange(
                image_local, 'b f c h w -> b c f h w', b=bs_vd_local)

        # encode the video_data
        bs_vd = video_data.shape[0]
        video_data = rearrange(video_data, 'b f c h w -> (b f) c h w')
        misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w')

        # chunk the flattened frames so condition extractors fit in memory
        video_data_list = torch.chunk(
            video_data, video_data.shape[0] // self.cfg.chunk_size, dim=0)
        misc_data_list = torch.chunk(
            misc_data, misc_data.shape[0] // self.cfg.chunk_size, dim=0)

        with torch.no_grad():
            # VAE-encode the video frames into latents
            decode_data = []
            for vd_data in video_data_list:
                encoder_posterior = self.autoencoder.encode(vd_data)
                tmp = get_first_stage_encoding(encoder_posterior).detach()
                decode_data.append(tmp)
            video_data = torch.cat(decode_data, dim=0)
            video_data = rearrange(
                video_data, '(b f) c h w -> b c f h w', b=bs_vd)

            # depth maps via MiDaS (fp16, channels_last)
            depth_data = []
            if 'depthmap' in self.cfg.video_compositions:
                for misc_imgs in misc_data_list:
                    depth = midas(
                        misc_imgs.sub(0.5).div_(0.5).to(
                            memory_format=torch.channels_last).half())
                    depth = (depth / self.cfg.depth_std).clamp_(
                        0, self.cfg.depth_clamp)
                    depth_data.append(depth)
                depth_data = torch.cat(depth_data, dim=0)
                depth_data = rearrange(
                    depth_data, '(b f) c h w -> b c f h w', b=bs_vd)

            # canny edge maps, computed per frame
            canny_data = []
            if 'canny' in self.cfg.video_compositions:
                for misc_imgs in misc_data_list:
                    misc_imgs = rearrange(misc_imgs.clone(),
                                          'k c h w -> k h w c')
                    canny_condition = torch.stack(
                        [canny_detector(misc_img) for misc_img in misc_imgs])
                    canny_condition = rearrange(canny_condition,
                                                'k h w c-> k c h w')
                    canny_data.append(canny_condition)
                canny_data = torch.cat(canny_data, dim=0)
                canny_data = rearrange(
                    canny_data, '(b f) c h w -> b c f h w', b=bs_vd)

            # sketches: PiDiNet edges cleaned by the simplification GAN
            sketch_data = []
            if 'sketch' in self.cfg.video_compositions:
                sketch_list = misc_data_list
                if self.cfg.read_sketch:
                    sketch_repeat = frame_sketch.repeat(frames_num, 1, 1,
                                                        1).cuda()
                    sketch_list = [sketch_repeat]

                for misc_imgs in sketch_list:
                    sketch = pidinet(misc_imgs.sub(pidi_mean).div_(pidi_std))
                    sketch = 1.0 - cleaner(1.0 - sketch)
                    sketch_data.append(sketch)
                sketch_data = torch.cat(sketch_data, dim=0)
                sketch_data = rearrange(
                    sketch_data, '(b f) c h w -> b c f h w', b=bs_vd)

            # single sketch = the first frame's sketch repeated over time
            single_sketch_data = []
            if 'single_sketch' in self.cfg.video_compositions:
                single_sketch_data = sketch_data.clone()[:, :, :1].repeat(
                    1, 1, frames_num, 1, 1)

        # preprocess for input text descripts
        y = self.clip_encoder(caps).detach()
        y0 = y.clone()

        # style condition: CLIP feature of a style image or reference frame
        y_visual = []
        if 'image' in self.cfg.video_compositions:
            with torch.no_grad():
                if self.cfg.read_style:
                    y_visual = self.clip_encoder_visual(
                        self.clip_encoder_visual.preprocess(
                            frame_style).unsqueeze(0).cuda()).unsqueeze(0)
                    y_visual0 = y_visual.clone()
                else:
                    ref_imgs = ref_imgs.squeeze(1)
                    y_visual = self.clip_encoder_visual(ref_imgs).unsqueeze(1)
                    y_visual0 = y_visual.clone()

        with torch.no_grad():
            # Log memory
            pynvml.nvmlInit()
            # Sample images (DDIM)
            with amp.autocast(enabled=self.cfg.use_fp16):
                if self.cfg.share_noise:
                    # share one noise map across all frames of a sample
                    b, c, f, h, w = video_data.shape
                    noise = torch.randn((self.viz_num, c, h, w),
                                        device=self.device)
                    noise = noise.repeat_interleave(repeats=f, dim=0)
                    noise = rearrange(
                        noise, '(b f) c h w->b c f h w', b=self.viz_num)
                    noise = noise.contiguous()
                else:
                    noise = torch.randn_like(video_data[:self.viz_num])

            # conditional kwargs first, unconditional (guidance) second;
            # absent conditions are passed as None
            full_model_kwargs = [{
                'y':
                y0[:self.viz_num],
                'local_image':
                None
                if len(image_local) == 0 else image_local[:self.viz_num],
                'image':
                None if len(y_visual) == 0 else y_visual0[:self.viz_num],
                'depth':
                None
                if len(depth_data) == 0 else depth_data[:self.viz_num],
                'canny':
                None
                if len(canny_data) == 0 else canny_data[:self.viz_num],
                'sketch':
                None
                if len(sketch_data) == 0 else sketch_data[:self.viz_num],
                'masked':
                None
                if len(masked_video) == 0 else masked_video[:self.viz_num],
                'motion':
                None if len(mv_data_video) == 0 else
                mv_data_video[:self.viz_num],
                'single_sketch':
                None if len(single_sketch_data) == 0 else
                single_sketch_data[:self.viz_num],
                'fps':
                fps[:self.viz_num]
            }, {
                'y':
                self.zero_y.repeat(self.viz_num, 1, 1)
                if not self.cfg.use_fps_condition else
                torch.zeros_like(y0)[:self.viz_num],
                'local_image':
                None
                if len(image_local) == 0 else image_local[:self.viz_num],
                'image':
                None if len(y_visual) == 0 else torch.zeros_like(
                    y_visual0[:self.viz_num]),
                'depth':
                None
                if len(depth_data) == 0 else depth_data[:self.viz_num],
                'canny':
                None
                if len(canny_data) == 0 else canny_data[:self.viz_num],
                'sketch':
                None
                if len(sketch_data) == 0 else sketch_data[:self.viz_num],
                'masked':
                None
                if len(masked_video) == 0 else masked_video[:self.viz_num],
                'motion':
                None if len(mv_data_video) == 0 else
                mv_data_video[:self.viz_num],
                'single_sketch':
                None if len(single_sketch_data) == 0 else
                single_sketch_data[:self.viz_num],
                'fps':
                fps[:self.viz_num]
            }]

            # Save generated videos
            partial_keys = self.cfg.guidances
            noise_motion = noise.clone()
            model_kwargs = prepare_model_kwargs(
                partial_keys=partial_keys,
                full_model_kwargs=full_model_kwargs,
                use_fps_condition=self.cfg.use_fps_condition)
            video_output = self.diffusion.ddim_sample_loop(
                noise=noise_motion,
                model=self.model.eval(),
                model_kwargs=model_kwargs,
                guide_scale=9.0,
                ddim_timesteps=self.cfg.ddim_timesteps,
                eta=0.0)

            save_with_model_kwargs(
                model_kwargs=model_kwargs,
                video_data=video_output,
                autoencoder=self.autoencoder,
                ori_video=misc_backups,
                viz_num=self.viz_num,
                step=0,
                caps=caps,
                palette=palette,
                cfg=self.cfg)

        # NOTE(review): 'video_path' carries the whole cfg object rather than
        # a path string — confirm this is what the pipeline expects.
        return {
            'video': video_output.type(torch.float32).cpu(),
            'video_path': self.cfg
        }
diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py
index 11c28fdd..fd147513 100644
--- a/modelscope/pipelines/multi_modal/__init__.py
+++ b/modelscope/pipelines/multi_modal/__init__.py
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
from .soonet_video_temporal_grounding_pipeline import SOONetVideoTemporalGroundingPipeline
from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline
from .multimodal_dialogue_pipeline import MultimodalDialoguePipeline
+ from .videocomposer_pipeline import VideoComposerPipeline
else:
_import_structure = {
'image_captioning_pipeline': ['ImageCaptioningPipeline'],
@@ -46,7 +47,8 @@ else:
'soonet_video_temporal_grounding_pipeline':
['SOONetVideoTemporalGroundingPipeline'],
'text_to_video_synthesis_pipeline': ['TextToVideoSynthesisPipeline'],
- 'multimodal_dialogue_pipeline': ['MultimodalDialoguePipeline']
+ 'multimodal_dialogue_pipeline': ['MultimodalDialoguePipeline'],
+ 'videocomposer_pipeline': ['VideoComposerPipeline']
}
import sys
diff --git a/modelscope/pipelines/multi_modal/videocomposer_pipeline.py b/modelscope/pipelines/multi_modal/videocomposer_pipeline.py
new file mode 100644
index 00000000..539884fb
--- /dev/null
+++ b/modelscope/pipelines/multi_modal/videocomposer_pipeline.py
@@ -0,0 +1,382 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import random
+import subprocess
+import tempfile
+import time
+from functools import partial
+from typing import Any, Dict
+
+import cv2
+import imageio
+import numpy as np
+import torch
+import torchvision.transforms as T
+from mvextractor.videocap import VideoCap
+from PIL import Image
+
+import modelscope.models.multi_modal.videocomposer.data as data
+from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal.videocomposer.data.transforms import (
+ CenterCropV3, random_resize)
+from modelscope.models.multi_modal.videocomposer.ops.random_mask import (
+ make_irregular_mask, make_rectangle_mask, make_uncrop)
+from modelscope.models.multi_modal.videocomposer.utils.utils import rand_name
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.constant import Tasks
+from modelscope.utils.device import device_placement
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+ Tasks.text_to_video_synthesis, module_name=Pipelines.videocomposer)
+class VideoComposerPipeline(Pipeline):
+ r""" Video Composer Pipeline.
+
+ Examples:
+
+ >>> from modelscope.pipelines import pipeline
+ >>> from modelscope.utils.constant import Tasks
+ >>> pipe = pipeline(
+ task=Tasks.text_to_video_synthesis,
+ model='buptwq/videocomposer',
+ model_revision='v1.0.1')
+ >>> inputs = {'Video:FILE': 'path/input_video.mp4',
+ 'Image:FILE': 'path/input_image.png',
+ 'text': 'the text description'}
+ >>> output = pipe(inputs)
+ """
+
+ def __init__(self, model: str, **kwargs):
+ """
+ use `model` to create a videocomposer pipeline for prediction
+ Args:
+ model: model id on modelscope hub.
+ """
+ super().__init__(model=model)
+ self.log_dir = kwargs.pop('log_dir', './video_outputs')
+ if not os.path.exists(self.log_dir):
+ os.makedirs(self.log_dir)
+ self.feature_framerate = kwargs.pop('feature_framerate', 4)
+ self.frame_lens = kwargs.pop('frame_lens', [
+ 16,
+ 16,
+ 16,
+ 16,
+ ])
+ self.feature_framerates = kwargs.pop('feature_framerates', [
+ 4,
+ ])
+ self.batch_sizes = kwargs.pop('batch_sizes', {
+ '1': 1,
+ '4': 1,
+ '8': 1,
+ '16': 1,
+ })
+ l1 = len(self.frame_lens)
+ l2 = len(self.feature_framerates)
+ self.max_frames = self.frame_lens[0 % (l1 * l2) // l2]
+ self.batch_size = self.batch_sizes[str(self.max_frames)]
+ self.resolution = kwargs.pop('resolution', 256)
+ self.image_resolution = kwargs.pop('image_resolution', 256)
+ self.mean = kwargs.pop('mean', [0.5, 0.5, 0.5])
+ self.std = kwargs.pop('std', [0.5, 0.5, 0.5])
+ self.vit_image_size = kwargs.pop('vit_image_size', 224)
+ self.vit_mean = kwargs.pop('vit_mean',
+ [0.48145466, 0.4578275, 0.40821073])
+ self.vit_std = kwargs.pop('vit_std',
+ [0.26862954, 0.26130258, 0.27577711])
+ self.misc_size = kwargs.pop('kwargs.pop', 384)
+ self.visual_mv = kwargs.pop('visual_mv', False)
+ self.max_words = kwargs.pop('max_words', 1000)
+ self.mvs_visual = kwargs.pop('mvs_visual', False)
+
+ self.infer_trans = data.Compose([
+ data.CenterCropV2(size=self.resolution),
+ data.ToTensor(),
+ data.Normalize(mean=self.mean, std=self.std)
+ ])
+
+ self.misc_transforms = data.Compose([
+ T.Lambda(partial(random_resize, size=self.misc_size)),
+ data.CenterCropV2(self.misc_size),
+ data.ToTensor()
+ ])
+
+ self.mv_transforms = data.Compose(
+ [T.Resize(size=self.resolution),
+ T.CenterCrop(self.resolution)])
+
+ self.vit_transforms = T.Compose([
+ CenterCropV3(self.vit_image_size),
+ T.ToTensor(),
+ T.Normalize(mean=self.vit_mean, std=self.vit_std)
+ ])
+
+ def preprocess(self, input: Input) -> Dict[str, Any]:
+ video_key = input['Video:FILE']
+ cap_txt = input['text']
+ style_image = input['Image:FILE']
+
+ total_frames = None
+
+ feature_framerate = self.feature_framerate
+ if os.path.exists(video_key):
+ try:
+ ref_frame, vit_image, video_data, misc_data, mv_data = self.video_data_preprocess(
+ video_key, self.feature_framerate, total_frames,
+ self.mvs_visual)
+ except Exception as e:
+ logger.info(
+ '{} get frames failed... with error: {}'.format(
+ video_key, e),
+ flush=True)
+
+ ref_frame = torch.zeros(3, self.vit_image_size,
+ self.vit_image_size)
+ video_data = torch.zeros(self.max_frames, 3,
+ self.image_resolution,
+ self.image_resolution)
+ misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
+ self.misc_size)
+
+ mv_data = torch.zeros(self.max_frames, 2,
+ self.image_resolution,
+ self.image_resolution)
+ else:
+ logger.info(
+ 'The video path does not exist or no video dir provided!')
+ ref_frame = torch.zeros(3, self.vit_image_size,
+ self.vit_image_size)
+ _ = torch.zeros(3, self.vit_image_size, self.vit_image_size)
+ video_data = torch.zeros(self.max_frames, 3, self.image_resolution,
+ self.image_resolution)
+ misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
+ self.misc_size)
+ mv_data = torch.zeros(self.max_frames, 2, self.image_resolution,
+ self.image_resolution)
+
+ # inpainting mask
+ p = random.random()
+ if p < 0.7:
+ mask = make_irregular_mask(512, 512)
+ elif p < 0.9:
+ mask = make_rectangle_mask(512, 512)
+ else:
+ mask = make_uncrop(512, 512)
+ mask = torch.from_numpy(
+ cv2.resize(
+ mask, (self.misc_size, self.misc_size),
+ interpolation=cv2.INTER_NEAREST)).unsqueeze(0).float()
+
+ mask = mask.unsqueeze(0).repeat_interleave(
+ repeats=self.max_frames, dim=0)
+ video_input = {
+ 'ref_frame': ref_frame.unsqueeze(0),
+ 'cap_txt': cap_txt,
+ 'video_data': video_data.unsqueeze(0),
+ 'misc_data': misc_data.unsqueeze(0),
+ 'feature_framerate': feature_framerate,
+ 'mask': mask.unsqueeze(0),
+ 'mv_data': mv_data.unsqueeze(0),
+ 'style_image': style_image
+ }
+ return video_input
+
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the VideoComposer model on the preprocessed inputs.

        Delegates directly to the underlying model; the returned dict is
        passed unchanged to ``postprocess``.
        """
        return self.model(input)
+
+ def postprocess(self, inputs: Dict[str, Any],
+ **post_params) -> Dict[str, Any]:
+ output_video_path = post_params.get('output_video', None)
+ temp_video_file = False
+ if output_video_path is not None:
+ output_video_path = tempfile.NamedTemporaryFile(suffix='.gif').name
+ temp_video_file = True
+
+ if temp_video_file:
+ return {OutputKeys.OUTPUT_VIDEO: inputs['video_path']}
+ else:
+ return {OutputKeys.OUTPUT_VIDEO: inputs['video']}
+
+ def video_data_preprocess(self, video_key, feature_framerate, total_frames,
+ visual_mv):
+
+ filename = video_key
+ for _ in range(5):
+ try:
+ frame_types, frames, mvs, mvs_visual = self.extract_motion_vectors(
+ input_video=filename,
+ fps=feature_framerate,
+ visual_mv=visual_mv)
+ break
+ except Exception as e:
+ logger.error(
+ '{} read video frames and motion vectors failed with error: {}'
+ .format(video_key, e),
+ flush=True)
+
+ total_frames = len(frame_types)
+ start_indexs = np.where((np.array(frame_types) == 'I') & (
+ total_frames - np.arange(total_frames) >= self.max_frames))[0]
+ start_index = np.random.choice(start_indexs)
+ indices = np.arange(start_index, start_index + self.max_frames)
+
+ # note frames are in BGR mode, need to trans to RGB mode
+ frames = [Image.fromarray(frames[i][:, :, ::-1]) for i in indices]
+ mvs = [torch.from_numpy(mvs[i].transpose((2, 0, 1))) for i in indices]
+ mvs = torch.stack(mvs)
+
+ if visual_mv:
+ images = [(mvs_visual[i][:, :, ::-1]).astype('uint8')
+ for i in indices]
+ path = self.log_dir + '/visual_mv/' + video_key.split(
+ '/')[-1] + '.gif'
+ if not os.path.exists(self.log_dir + '/visual_mv/'):
+ os.makedirs(self.log_dir + '/visual_mv/', exist_ok=True)
+ logger.info('save motion vectors visualization to :', path)
+ imageio.mimwrite(path, images, fps=8)
+
+ have_frames = len(frames) > 0
+ middle_indix = int(len(frames) / 2)
+ if have_frames:
+ ref_frame = frames[middle_indix]
+ vit_image = self.vit_transforms(ref_frame)
+ misc_imgs_np = self.misc_transforms[:2](frames)
+ misc_imgs = self.misc_transforms[2:](misc_imgs_np)
+ frames = self.infer_trans(frames)
+ mvs = self.mv_transforms(mvs)
+ else:
+ vit_image = torch.zeros(3, self.vit_image_size,
+ self.vit_image_size)
+
+ video_data = torch.zeros(self.max_frames, 3, self.image_resolution,
+ self.image_resolution)
+ mv_data = torch.zeros(self.max_frames, 2, self.image_resolution,
+ self.image_resolution)
+ misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
+ self.misc_size)
+ if have_frames:
+ video_data[:len(frames), ...] = frames
+ misc_data[:len(frames), ...] = misc_imgs
+ mv_data[:len(frames), ...] = mvs
+
+ ref_frame = vit_image
+
+ del frames
+ del misc_imgs
+ del mvs
+
+ return ref_frame, vit_image, video_data, misc_data, mv_data
+
+ def extract_motion_vectors(self,
+ input_video,
+ fps=4,
+ dump=False,
+ verbose=False,
+ visual_mv=False):
+
+ if dump:
+ now = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
+ for child in ['frames', 'motion_vectors']:
+ os.makedirs(os.path.join(f'out-{now}', child), exist_ok=True)
+ temp = rand_name()
+ tmp_video = os.path.join(
+ input_video.split('/')[0], f'{temp}' + input_video.split('/')[-1])
+ videocapture = cv2.VideoCapture(input_video)
+ frames_num = videocapture.get(cv2.CAP_PROP_FRAME_COUNT)
+ fps_video = videocapture.get(cv2.CAP_PROP_FPS)
+ # check if enough frames
+ if frames_num / fps_video * fps > 16:
+ fps = max(fps, 1)
+ else:
+ fps = int(16 / (frames_num / fps_video)) + 1
+ ffmpeg_cmd = f'ffmpeg -threads 8 -loglevel error -i {input_video} -filter:v \
+ fps={fps} -c:v mpeg4 -f rawvideo {tmp_video}'
+
+ if os.path.exists(tmp_video):
+ os.remove(tmp_video)
+
+ subprocess.run(args=ffmpeg_cmd, shell=True, timeout=120)
+
+ cap = VideoCap()
+ # open the video file
+ ret = cap.open(tmp_video)
+ if not ret:
+ raise RuntimeError(f'Could not open {tmp_video}')
+
+ step = 0
+ times = []
+
+ frame_types = []
+ frames = []
+ mvs = []
+ mvs_visual = []
+ # continuously read and display video frames and motion vectors
+ while True:
+ if verbose:
+ logger.info('Frame: ', step, end=' ')
+
+ tstart = time.perf_counter()
+
+ # read next video frame and corresponding motion vectors
+ ret, frame, motion_vectors, frame_type, timestamp = cap.read()
+
+ tend = time.perf_counter()
+ telapsed = tend - tstart
+ times.append(telapsed)
+
+ # if there is an error reading the frame
+ if not ret:
+ if verbose:
+ logger.warning('No frame read. Stopping.')
+ break
+
+ frame_save = np.zeros(frame.copy().shape, dtype=np.uint8)
+ if visual_mv:
+ frame_save = draw_motion_vectors(frame_save, motion_vectors)
+
+ # store motion vectors, frames, etc. in output directory
+ dump = False
+ if frame.shape[1] >= frame.shape[0]:
+ w_half = (frame.shape[1] - frame.shape[0]) // 2
+ if dump:
+ cv2.imwrite(
+ os.path.join('./mv_visual/', f'frame-{step}.jpg'),
+ frame_save[:, w_half:-w_half])
+ mvs_visual.append(frame_save[:, w_half:-w_half])
+ else:
+ h_half = (frame.shape[0] - frame.shape[1]) // 2
+ if dump:
+ cv2.imwrite(
+ os.path.join('./mv_visual/', f'frame-{step}.jpg'),
+ frame_save[h_half:-h_half, :])
+ mvs_visual.append(frame_save[h_half:-h_half, :])
+
+ h, w = frame.shape[:2]
+ mv = np.zeros((h, w, 2))
+ position = motion_vectors[:, 5:7].clip((0, 0), (w - 1, h - 1))
+ mv[position[:, 1],
+ position[:,
+ 0]] = motion_vectors[:, 0:
+ 1] * motion_vectors[:, 7:
+ 9] / motion_vectors[:,
+ 9:]
+
+ step += 1
+ frame_types.append(frame_type)
+ frames.append(frame)
+ mvs.append(mv)
+ if verbose:
+ logger.info('average dt: ', np.mean(times))
+ cap.release()
+
+ if os.path.exists(tmp_video):
+ os.remove(tmp_video)
+
+ return frame_types, frames, mvs, mvs_visual
diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py
index 5edc8c48..d180289b 100644
--- a/modelscope/preprocessors/multi_modal.py
+++ b/modelscope/preprocessors/multi_modal.py
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+
import os.path as osp
import re
from io import BytesIO
diff --git a/tests/pipelines/test_videocomposer.py b/tests/pipelines/test_videocomposer.py
new file mode 100644
index 00000000..4cbca237
--- /dev/null
+++ b/tests/pipelines/test_videocomposer.py
@@ -0,0 +1,38 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import sys
+import unittest
+
+from modelscope.models import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import DownloadMode, Tasks
+from modelscope.utils.test_utils import test_level
+
+
# NOTE(review): the class name looks copy-pasted from a video-deinterlace
# test; renaming it to e.g. `VideoComposerTest` is recommended but kept
# as-is to avoid changing the public identifier.
class VideoDeinterlaceTest(unittest.TestCase):
    """Smoke test for the VideoComposer text-to-video-synthesis pipeline."""

    def setUp(self) -> None:
        self.task = Tasks.text_to_video_synthesis
        self.model_id = 'buptwq/videocomposer'
        self.model_revision = 'v1.0.1'
        self.dataset_id = 'buptwq/videocomposer-depths-style'
        self.text = 'A glittering and translucent fish swimming in a \
small glass bowl with multicolored piece of stone, like a glass fish'

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_pipeline(self):
        # Bug fix: MsDataset was referenced without being imported, raising
        # NameError before the pipeline even ran.
        from modelscope.msdatasets import MsDataset

        pipe = pipeline(
            task=Tasks.text_to_video_synthesis,
            model=self.model_id,
            model_revision=self.model_revision)
        ds = MsDataset.load(
            self.dataset_id,
            split='train',
            download_mode=DownloadMode.FORCE_REDOWNLOAD)
        inputs = next(iter(ds))
        inputs.update({'text': self.text})
        _ = pipe(inputs)
+
+
+if __name__ == '__main__':
+ unittest.main()