From ee8afd2d62d276b306fe9ba8b5c672bbc175da06 Mon Sep 17 00:00:00 2001 From: Wang Qiang <37444407+XDUWQ@users.noreply.github.com> Date: Tue, 15 Aug 2023 12:01:03 +0800 Subject: [PATCH] VideoComposer: Compositional Video Synthesis with Motion Controllability (#431) * VideoComposer: Compositional Video Synthesis with Motion Controllability * videocomposer pipeline * pre commit * delete xformers --- modelscope/metainfo.py | 2 + modelscope/models/multi_modal/__init__.py | 2 + .../text_to_video_synthesis_model.py | 3 + .../multi_modal/videocomposer/__init__.py | 23 + .../videocomposer/annotator/__init__.py | 1 + .../videocomposer/annotator/canny/__init__.py | 44 + .../annotator/histogram/__init__.py | 3 + .../annotator/histogram/palette.py | 155 ++ .../annotator/sketch/__init__.py | 4 + .../videocomposer/annotator/sketch/pidinet.py | 940 ++++++++ .../annotator/sketch/sketch_simplification.py | 110 + .../videocomposer/annotator/util.py | 42 + .../multi_modal/videocomposer/autoencoder.py | 650 +++++ .../models/multi_modal/videocomposer/clip.py | 143 ++ .../multi_modal/videocomposer/config.py | 156 ++ .../videocomposer/configs/base.yaml | 2 + .../configs/exp01_vidcomposer_full.yaml | 20 + .../configs/exp02_motion_transfer.yaml | 23 + .../exp02_motion_transfer_vs_style.yaml | 24 + .../configs/exp03_sketch2video_style.yaml | 26 + .../configs/exp04_sketch2video_wo_style.yaml | 26 + .../configs/exp05_text_depths_wo_style.yaml | 26 + .../configs/exp06_text_depths_vs_style.yaml | 26 + .../exp10_vidcomposer_no_watermark_full.yaml | 21 + .../videocomposer/data/__init__.py | 5 + .../videocomposer/data/samplers.py | 158 ++ .../videocomposer/data/tokenizers.py | 184 ++ .../videocomposer/data/transforms.py | 400 ++++ .../multi_modal/videocomposer/diffusion.py | 1514 ++++++++++++ .../multi_modal/videocomposer/dpm_solver.py | 1697 +++++++++++++ .../multi_modal/videocomposer/mha_flash.py | 120 + .../videocomposer/models/__init__.py | 2 + .../multi_modal/videocomposer/models/clip.py | 460 
++++ .../multi_modal/videocomposer/models/midas.py | 320 +++ .../multi_modal/videocomposer/ops/__init__.py | 7 + .../videocomposer/ops/degration.py | 998 ++++++++ .../videocomposer/ops/distributed.py | 460 ++++ .../multi_modal/videocomposer/ops/losses.py | 37 + .../videocomposer/ops/random_mask.py | 81 + .../multi_modal/videocomposer/ops/utils.py | 1037 ++++++++ .../multi_modal/videocomposer/unet_sd.py | 2102 +++++++++++++++++ .../videocomposer/utils/__init__.py | 0 .../multi_modal/videocomposer/utils/config.py | 273 +++ .../videocomposer/utils/distributed.py | 297 +++ .../multi_modal/videocomposer/utils/utils.py | 955 ++++++++ .../videocomposer/videocomposer_model.py | 480 ++++ modelscope/pipelines/multi_modal/__init__.py | 4 +- .../multi_modal/videocomposer_pipeline.py | 382 +++ modelscope/preprocessors/multi_modal.py | 1 + tests/pipelines/test_videocomposer.py | 38 + 50 files changed, 14483 insertions(+), 1 deletion(-) create mode 100644 modelscope/models/multi_modal/videocomposer/__init__.py create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/__init__.py create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/canny/__init__.py create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/histogram/__init__.py create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/histogram/palette.py create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/sketch/__init__.py create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/sketch/pidinet.py create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/sketch/sketch_simplification.py create mode 100644 modelscope/models/multi_modal/videocomposer/annotator/util.py create mode 100644 modelscope/models/multi_modal/videocomposer/autoencoder.py create mode 100644 modelscope/models/multi_modal/videocomposer/clip.py create mode 100644 modelscope/models/multi_modal/videocomposer/config.py create mode 100644 
modelscope/models/multi_modal/videocomposer/configs/base.yaml create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp01_vidcomposer_full.yaml create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer.yaml create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer_vs_style.yaml create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp03_sketch2video_style.yaml create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp04_sketch2video_wo_style.yaml create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp05_text_depths_wo_style.yaml create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp06_text_depths_vs_style.yaml create mode 100644 modelscope/models/multi_modal/videocomposer/configs/exp10_vidcomposer_no_watermark_full.yaml create mode 100644 modelscope/models/multi_modal/videocomposer/data/__init__.py create mode 100644 modelscope/models/multi_modal/videocomposer/data/samplers.py create mode 100644 modelscope/models/multi_modal/videocomposer/data/tokenizers.py create mode 100644 modelscope/models/multi_modal/videocomposer/data/transforms.py create mode 100644 modelscope/models/multi_modal/videocomposer/diffusion.py create mode 100644 modelscope/models/multi_modal/videocomposer/dpm_solver.py create mode 100644 modelscope/models/multi_modal/videocomposer/mha_flash.py create mode 100644 modelscope/models/multi_modal/videocomposer/models/__init__.py create mode 100644 modelscope/models/multi_modal/videocomposer/models/clip.py create mode 100644 modelscope/models/multi_modal/videocomposer/models/midas.py create mode 100644 modelscope/models/multi_modal/videocomposer/ops/__init__.py create mode 100644 modelscope/models/multi_modal/videocomposer/ops/degration.py create mode 100644 modelscope/models/multi_modal/videocomposer/ops/distributed.py create mode 100644 modelscope/models/multi_modal/videocomposer/ops/losses.py 
create mode 100644 modelscope/models/multi_modal/videocomposer/ops/random_mask.py create mode 100644 modelscope/models/multi_modal/videocomposer/ops/utils.py create mode 100644 modelscope/models/multi_modal/videocomposer/unet_sd.py create mode 100644 modelscope/models/multi_modal/videocomposer/utils/__init__.py create mode 100644 modelscope/models/multi_modal/videocomposer/utils/config.py create mode 100644 modelscope/models/multi_modal/videocomposer/utils/distributed.py create mode 100644 modelscope/models/multi_modal/videocomposer/utils/utils.py create mode 100644 modelscope/models/multi_modal/videocomposer/videocomposer_model.py create mode 100644 modelscope/pipelines/multi_modal/videocomposer_pipeline.py create mode 100644 tests/pipelines/test_videocomposer.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index df5e8a75..70fd9c86 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -218,6 +218,7 @@ class Models(object): mplug_owl = 'mplug-owl' clip_interrogator = 'clip-interrogator' stable_diffusion = 'stable-diffusion' + videocomposer = 'videocomposer' text_to_360panorama_image = 'text-to-360panorama-image' # science models @@ -525,6 +526,7 @@ class Pipelines(object): multi_modal_similarity = 'multi-modal-similarity' text_to_image_synthesis = 'text-to-image-synthesis' video_multi_modal_embedding = 'video-multi-modal-embedding' + videocomposer = 'videocomposer' image_text_retrieval = 'image-text-retrieval' ofa_ocr_recognition = 'ofa-ocr-recognition' ofa_asr = 'ofa-asr' diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py index 9fa34baf..31af94b2 100644 --- a/modelscope/models/multi_modal/__init__.py +++ b/modelscope/models/multi_modal/__init__.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from .efficient_diffusion_tuning import EfficientStableDiffusion from .mplug_owl import MplugOwlForConditionalGeneration from .clip_interrogator import CLIP_Interrogator + from .videocomposer import 
VideoComposer else: _import_structure = { @@ -42,6 +43,7 @@ else: 'efficient_diffusion_tuning': ['EfficientStableDiffusion'], 'mplug_owl': ['MplugOwlForConditionalGeneration'], 'clip_interrogator': ['CLIP_Interrogator'], + 'videocomposer': ['VideoComposer'], } import sys diff --git a/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py b/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py index 93a7e3ba..0ec66069 100644 --- a/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py +++ b/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py @@ -19,6 +19,9 @@ from modelscope.models.multi_modal.video_synthesis.diffusion import ( from modelscope.models.multi_modal.video_synthesis.unet_sd import UNetSD from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() __all__ = ['TextToVideoSynthesis'] diff --git a/modelscope/models/multi_modal/videocomposer/__init__.py b/modelscope/models/multi_modal/videocomposer/__init__.py new file mode 100644 index 00000000..ae9102c0 --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
class CannyDetector:
    """Callable that runs ``cv2.Canny`` on an image and returns a tensor.

    A ``torch.Tensor`` input is multiplied by 255 and converted with
    ``cv2.convertScaleAbs``; a ``numpy.ndarray`` input is handed to
    ``cv2.Canny`` unchanged, so it is assumed to already hold values in the
    0-255 range.
    """

    def __call__(self,
                 img,
                 low_threshold=None,
                 high_threshold=None,
                 random_threshold=True,
                 device=None):
        """Detect edges in ``img``.

        Args:
            img: ``torch.Tensor`` or ``numpy.ndarray`` laid out as (h, w, c).
            low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
                When ``None`` it is derived from the median intensity.
            high_threshold: Upper hysteresis threshold passed to
                ``cv2.Canny``. When ``None`` it is derived from the median
                intensity.
            random_threshold: If exactly ``False``, derived thresholds are
                ``0.67 * median`` and ``1.33 * median`` (clipped to 0..255).
                Otherwise the lower one is ``(1 - r) * median`` with ``r``
                drawn uniformly from [0.1, 0.4) and the upper one is twice
                the lower one.
            device: Device for the returned tensor. ``None`` selects
                ``'cuda'`` when available and ``'cpu'`` otherwise.

        Returns:
            A float tensor holding the edge map divided by 255, with a
            trailing singleton axis appended.

        Raises:
            TypeError: If ``img`` is neither a tensor nor an ndarray.
        """
        # Convert to a numpy image that cv2.Canny can consume.
        if isinstance(img, torch.Tensor):  # (h, w, c)
            # detach() so tensors that require grad can be exported to numpy.
            img_np = cv2.convertScaleAbs(img.detach().cpu().numpy() * 255.)
        elif isinstance(img, np.ndarray):  # (h, w, c)
            img_np = img  # we assume values are in the range from 0 to 255.
        else:
            # An explicit raise survives `python -O`, unlike `assert False`.
            raise TypeError(
                'img must be a torch.Tensor or numpy.ndarray, got %s' %
                type(img).__name__)

        # Derive whichever threshold the caller did not supply. The random
        # draw only happens when a threshold is actually missing, so callers
        # passing both thresholds leave numpy's RNG state untouched.
        if low_threshold is None or high_threshold is None:
            median_intensity = np.median(img_np)
            if random_threshold is False:
                auto_low = int(max(0, (1 - 0.33) * median_intensity))
                auto_high = int(min(255, (1 + 0.33) * median_intensity))
            else:
                random_canny = np.random.uniform(0.1, 0.4)
                # Might try other values
                auto_low = int(max(0, (1 - random_canny) * median_intensity))
                auto_high = 2 * auto_low
            if low_threshold is None:
                low_threshold = auto_low
            if high_threshold is None:
                high_threshold = auto_high

        # Detect canny edge
        canny_edge = cv2.Canny(img_np, low_threshold, high_threshold)

        if device is None:
            # Keep the historical CUDA placement, but do not crash on hosts
            # that have no GPU.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        canny_condition = torch.from_numpy(
            canny_edge.copy()).unsqueeze(dim=-1).float().to(device) / 255.0
        return canny_condition
+""" +import os +import os.path as osp + +import numpy as np +from skimage.color import hsv2rgb, lab2rgb, rgb2lab +from skimage.io import imsave +from sklearn.metrics import euclidean_distances + +__all__ = ['Palette'] + + +def rgb2hex(rgb): + return '#%02x%02x%02x' % tuple([int(round(255.0 * u)) for u in rgb]) + + +def hex2rgb(hex): + rgb = hex.strip('#') + fn = lambda u: round(int(u, 16) / 255.0, 5) # noqa + return fn(rgb[:2]), fn(rgb[2:4]), fn(rgb[4:6]) + + +class Palette(object): + r"""Create a color palette (codebook) in the form of a 2D grid of colors. + Further, the rightmost column has num_hues gradations from black to white. + + Parameters: + num_hues: number of colors with full lightness and saturation, in the middle. + num_sat: number of rows above middle row that show the same hues with decreasing saturation. + """ + + def __init__(self, num_hues=11, num_sat=5, num_light=4): + n = num_sat + 2 * num_light + + # hues + if num_hues == 8: + hues = np.tile( + np.array([0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]), + (n, 1)) + elif num_hues == 9: + hues = np.tile( + np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), + (n, 1)) + elif num_hues == 10: + hues = np.tile( + np.array( + [0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, + 0.87]), (n, 1)) + elif num_hues == 11: + hues = np.tile( + np.array([ + 0.0, 0.0833, 0.166, 0.25, 0.333, 0.5, 0.56333, 0.666, 0.73, + 0.803, 0.916 + ]), (n, 1)) + else: + hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (n, 1)) + + # saturations + sats = np.hstack(( + np.linspace(0, 1, num_sat + 2)[1:-1], + 1, + [1] * num_light, + [0.4] * # noqa + (num_light - 1))) + sats = np.tile(np.atleast_2d(sats).T, (1, num_hues)) + + # lights + lights = np.hstack( + ([1] * num_sat, 1, np.linspace(1, 0.2, num_light + 2)[1:-1], + np.linspace(1, 0.2, num_light + 2)[1:-2])) + lights = np.tile(np.atleast_2d(lights).T, (1, num_hues)) + + # colors + rgb = hsv2rgb(np.dstack([hues, sats, lights])) + gray = np.tile( + np.linspace(1, 
0, n)[:, np.newaxis, np.newaxis], (1, 1, 3)) + self.thumbnail = np.hstack([rgb, gray]) + + # flatten + rgb = rgb.T.reshape(3, -1).T + gray = gray.T.reshape(3, -1).T + self.rgb = np.vstack((rgb, gray)) + self.lab = rgb2lab(self.rgb[np.newaxis, :, :]).squeeze() + self.hex = [rgb2hex(u) for u in self.rgb] + self.lab_dists = euclidean_distances(self.lab, squared=True) + + def histogram(self, rgb_img, sigma=20): + # compute histogram + lab = rgb2lab(rgb_img).reshape((-1, 3)) + min_ind = np.argmin( + euclidean_distances(lab, self.lab, squared=True), axis=1) + hist = 1.0 * np.bincount( + min_ind, minlength=self.lab.shape[0]) / lab.shape[0] + + # smooth histogram + if sigma > 0: + weight = np.exp(-self.lab_dists / (2.0 * sigma**2)) + weight = weight / weight.sum(1)[:, np.newaxis] + hist = (weight * hist).sum(1) + hist[hist < 1e-5] = 0 + return hist + + def get_palette_image(self, hist, percentile=90, width=200, height=50): + # curate histogram + ind = np.argsort(-hist) + ind = ind[hist[ind] > np.percentile(hist, percentile)] + hist = hist[ind] / hist[ind].sum() + + # draw palette + nums = np.array(hist * width, dtype=int) + array = np.vstack([ + np.tile(np.array(u), (v, 1)) for u, v in zip(self.rgb[ind], nums) + ]) + array = np.tile(array[np.newaxis, :, :], (height, 1, 1)) + if array.shape[1] < width: + array = np.concatenate( + [array, np.zeros((height, width - array.shape[1], 3))], axis=1) + return array + + def quantize_image(self, rgb_img): + lab = rgb2lab(rgb_img).reshape((-1, 3)) + min_ind = np.argmin( + euclidean_distances(lab, self.lab, squared=True), axis=1) + quantized_lab = self.lab[min_ind] + img = lab2rgb(quantized_lab.reshape(rgb_img.shape)) + return img + + def export(self, dirname): + if not osp.exists(dirname): + os.makedirs(dirname) + + # save thumbnail + imsave(osp.join(dirname, 'palette.png'), self.thumbnail) + + # save html + with open(osp.join(dirname, 'palette.html'), 'w') as f: + html = ''' + + ''' + for row in self.thumbnail: + for col in row: + 
html += '\n'.format( + rgb2hex(col)) + html += '
\n' + f.write(html) diff --git a/modelscope/models/multi_modal/videocomposer/annotator/sketch/__init__.py b/modelscope/models/multi_modal/videocomposer/annotator/sketch/__init__.py new file mode 100644 index 00000000..51e28cc7 --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/annotator/sketch/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .pidinet import * +from .sketch_simplification import * diff --git a/modelscope/models/multi_modal/videocomposer/annotator/sketch/pidinet.py b/modelscope/models/multi_modal/videocomposer/annotator/sketch/pidinet.py new file mode 100644 index 00000000..86c17d5d --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/annotator/sketch/pidinet.py @@ -0,0 +1,940 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +r"""Modified from ``https://github.com/zhuoinoulu/pidinet''. + Image augmentation: T.Compose([ + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]). +""" +import math +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.multi_modal.videocomposer.utils.utils import \ + DOWNLOAD_TO_CACHE + +__all__ = [ + 'PiDiNet', 'pidinet_bsd_tiny', 'pidinet_bsd_small', 'pidinet_bsd', + 'pidinet_nyud', 'pidinet_multicue' +] + +CONFIGS = { + 'baseline': { + 'layer0': 'cv', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'c-v15': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'a-v15': { + 'layer0': 'ad', + 
'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'r-v15': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cvvv4': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'avvv4': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'rvvv4': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cccv4': { + 'layer0': 'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cv', + }, + 'aaav4': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'cv', + 
def create_conv_func(op_type):
    """Return the functional convolution for one pixel-difference op.

    Args:
        op_type: ``'cv'``, ``'cd'``, ``'ad'`` or ``'rd'``.

    Returns:
        ``F.conv2d`` itself for ``'cv'``; otherwise a function with the same
        signature ``(x, weights, bias=None, stride=1, padding=0, dilation=1,
        groups=1)`` that rewrites the 3x3 kernel before convolving:

        * ``'cd'``: the plain convolution minus a 1x1 convolution whose
          kernel is the sum of the 3x3 weights.
        * ``'ad'``: convolution with the kernel minus a clockwise-shifted
          copy of itself.
        * ``'rd'``: convolution with a 5x5 kernel whose outer ring holds
          weights 1..8 and whose inner ring holds their negation. The
          ``padding`` argument is ignored and replaced by ``2 * dilation``.

    Raises:
        AssertionError: If ``op_type`` is unknown (kept for backward
            compatibility), or if a returned function is called with an
            unsupported dilation, kernel size or padding.
        ValueError: If ``op_type`` is unknown and assertions are disabled.
    """
    assert op_type in ['cv', 'cd', 'ad',
                       'rd'], 'unknown op type: %s' % str(op_type)
    if op_type == 'cv':
        return F.conv2d

    if op_type == 'cd':

        def func(x,
                 weights,
                 bias=None,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1):
            assert dilation in [1,
                                2], 'dilation for cd_conv should be in 1 or 2'
            assert weights.size(2) == 3 and weights.size(
                3) == 3, 'kernel size for cd_conv should be 3x3'
            assert padding == dilation, 'padding for cd_conv set wrong'

            weights_c = weights.sum(dim=[2, 3], keepdim=True)
            yc = F.conv2d(
                x, weights_c, stride=stride, padding=0, groups=groups)
            y = F.conv2d(
                x,
                weights,
                bias,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups)
            return y - yc

        return func

    if op_type == 'ad':

        def func(x,
                 weights,
                 bias=None,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1):
            assert dilation in [1,
                                2], 'dilation for ad_conv should be in 1 or 2'
            assert weights.size(2) == 3 and weights.size(
                3) == 3, 'kernel size for ad_conv should be 3x3'
            assert padding == dilation, 'padding for ad_conv set wrong'

            shape = weights.shape
            flat = weights.view(shape[0], shape[1], -1)
            weights_conv = (flat
                            - flat[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(
                                shape)  # clock-wise
            return F.conv2d(
                x,
                weights_conv,
                bias,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups)

        return func

    if op_type == 'rd':

        def func(x,
                 weights,
                 bias=None,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1):
            assert dilation in [1,
                                2], 'dilation for rd_conv should be in 1 or 2'
            assert weights.size(2) == 3 and weights.size(
                3) == 3, 'kernel size for rd_conv should be 3x3'
            # The expanded kernel is 5x5, so the caller's padding is replaced.
            padding = 2 * dilation

            shape = weights.shape
            flat = weights.view(shape[0], shape[1], -1)
            # new_zeros inherits dtype and device from the weights, so the
            # kernel matches half/double inputs and non-default GPUs.
            buffer = weights.new_zeros(shape[0], shape[1], 5 * 5)
            buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = flat[:, :, 1:]
            buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -flat[:, :, 1:]
            buffer = buffer.view(shape[0], shape[1], 5, 5)
            return F.conv2d(
                x,
                buffer,
                bias,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups)

        return func

    # Only reachable when assertions are disabled: fail loudly instead of
    # printing and handing back None.
    raise ValueError('unknown op type: %s' % str(op_type))


def config_model(model):
    """Return the 16 per-layer convolution functions of a named config.

    Args:
        model: A key of the module-level ``CONFIGS`` table.

    Returns:
        A list with one ``create_conv_func`` result per layer, in layer
        order ``layer0`` .. ``layer15``.

    Raises:
        AssertionError: If ``model`` is not a key of ``CONFIGS``.
    """
    return [create_conv_func(op) for op in config_model_converted(model)]


def config_model_converted(model):
    """Return the 16 per-layer op names of a named config.

    Args:
        model: A key of the module-level ``CONFIGS`` table.

    Returns:
        A list of the op strings stored under ``layer0`` .. ``layer15``.

    Raises:
        AssertionError: If ``model`` is not a key of ``CONFIGS``.
    """
    model_options = list(CONFIGS.keys())
    assert model in model_options, \
        'unrecognized model, please choose from %s' % str(model_options)

    layer_ops = CONFIGS[model]
    return [layer_ops['layer%d' % i] for i in range(16)]


def convert_pdc(op, weight):
    """Fold a pixel-difference op into an ordinary convolution kernel.

    Args:
        op: ``'cv'``, ``'cd'``, ``'ad'`` or ``'rd'``.
        weight: Kernel tensor of shape (out, in, 3, 3). It is never modified.

    Returns:
        ``weight`` itself for ``'cv'``; a new (out, in, 3, 3) tensor for
        ``'cd'`` and ``'ad'``; a new (out, in, 5, 5) tensor for ``'rd'``.
        New tensors share the dtype and device of ``weight``.

    Raises:
        ValueError: If ``op`` is not one of the four known ops.
    """
    if op == 'cv':
        return weight

    shape = weight.shape
    if op == 'cd':
        # Work on a copy: writing through a view of `weight` would silently
        # alter the caller's state dict and double-subtract on a second call.
        converted = weight.clone().reshape(shape[0], shape[1], -1)
        converted[:, :, 4] = converted[:, :, 4] - weight.sum(dim=[2, 3])
        return converted.view(shape)
    if op == 'ad':
        flat = weight.view(shape[0], shape[1], -1)
        return (flat - flat[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape)
    if op == 'rd':
        flat = weight.view(shape[0], shape[1], -1)
        # Match dtype as well as device of the source kernel.
        buffer = weight.new_zeros(shape[0], shape[1], 5 * 5)
        buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = flat[:, :, 1:]
        buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -flat[:, :, 1:]
        return buffer.view(shape[0], shape[1], 5, 5)
    raise ValueError('wrong op {}'.format(str(op)))
class Conv2d(nn.Module):
    """Square-kernel convolution layer whose operator is supplied as ``pdc``.

    The layer owns ``weight`` (and optionally ``bias``) and forwards them,
    together with stride, padding, dilation and groups, to the callable
    ``pdc``, which must accept the positional signature of ``F.conv2d``.
    """

    def __init__(self,
                 pdc,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=False):
        super(Conv2d, self).__init__()
        # Both channel counts have to split evenly across the groups.
        for label, width in (('in_channels', in_channels),
                             ('out_channels', out_channels)):
            if width % groups != 0:
                raise ValueError('%s must be divisible by groups' % label)

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = kernel_size
        self.stride, self.padding = stride, padding
        self.dilation, self.groups = dilation, groups

        kernel_shape = (out_channels, in_channels // groups, kernel_size,
                        kernel_size)
        self.weight = nn.Parameter(torch.Tensor(*kernel_shape))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()
        self.pdc = pdc

    def reset_parameters(self):
        """Initialise ``weight`` (Kaiming uniform) and ``bias`` if present."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, input):
        """Apply ``pdc`` to ``input`` with this layer's parameters."""
        conv_args = (self.weight, self.bias, self.stride, self.padding,
                     self.dilation, self.groups)
        return self.pdc(input, *conv_args)


class CSAM(nn.Module):
    r"""
    Compact Spatial Attention Module

    Produces a single-channel sigmoid map from the ReLU-activated input
    (1x1 convolution to 4 channels, then a bias-free 3x3 convolution) and
    multiplies the input by it.
    """

    def __init__(self, channels):
        super(CSAM, self).__init__()

        mid_channels = 4
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(
            channels, mid_channels, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            mid_channels, 1, kernel_size=3, padding=1, bias=False)
        self.sigmoid = nn.Sigmoid()
        # Start the 1x1 projection without any offset.
        nn.init.constant_(self.conv1.bias, 0)

    def forward(self, x):
        attention = self.sigmoid(self.conv2(self.conv1(self.relu1(x))))
        return x * attention
class MapReduce(nn.Module):
    r"""
    Reduce feature maps into a single edge map

    A single 1x1 convolution from ``channels`` inputs to one output, with
    its bias initialised to zero.
    """

    def __init__(self, channels):
        super(MapReduce, self).__init__()
        reducer = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
        nn.init.constant_(reducer.bias, 0)
        self.conv = reducer

    def forward(self, x):
        return self.conv(x)


class PDCBlock(nn.Module):
    """Residual block: depthwise ``pdc`` 3x3 conv, ReLU, then a 1x1 conv.

    With ``stride > 1`` the input is first max-pooled by 2 and the skip
    path goes through a 1x1 convolution from ``inplane`` to ``ouplane``;
    otherwise the skip path is the input itself.
    """

    def __init__(self, pdc, inplane, ouplane, stride=1):
        super(PDCBlock, self).__init__()
        self.stride = stride

        if stride > 1:
            # Down-sampling blocks need a pool plus a projected shortcut.
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.shortcut = nn.Conv2d(
                inplane, ouplane, kernel_size=1, padding=0)
        self.conv1 = Conv2d(
            pdc,
            inplane,
            inplane,
            kernel_size=3,
            padding=1,
            groups=inplane,
            bias=False)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            inplane, ouplane, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        downsample = self.stride > 1
        if downsample:
            x = self.pool(x)
        branch = self.conv2(self.relu2(self.conv1(x)))
        skip = self.shortcut(x) if downsample else x
        return branch + skip
self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + + +class PiDiNet(nn.Module): + + def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False): + super(PiDiNet, self).__init__() + self.sa = sa + if dil is not None: + assert isinstance(dil, int), 'dil should be an int' + self.dil = dil + + self.fuseplanes = [] + + self.inplane = inplane + if convert: + if pdcs[0] == 'rd': + init_kernel_size = 5 + init_padding = 2 + else: + init_kernel_size = 3 + init_padding = 1 + self.init_block = nn.Conv2d( + 3, + self.inplane, + kernel_size=init_kernel_size, + padding=init_padding, + bias=False) + block_class = PDCBlock_converted + else: + self.init_block = Conv2d( + pdcs[0], 3, self.inplane, kernel_size=3, padding=1) + block_class = PDCBlock + + self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane) + self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane) + self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2) + self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane) + self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane) + self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 2C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2) + self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane) + self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane) + self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.block4_1 = block_class( + pdcs[12], self.inplane, self.inplane, stride=2) + self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane) + self.block4_3 = block_class(pdcs[14], self.inplane, 
self.inplane) + self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.conv_reduces = nn.ModuleList() + if self.sa and self.dil is not None: + self.attentions = nn.ModuleList() + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.attentions.append(CSAM(self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + elif self.sa: + self.attentions = nn.ModuleList() + for i in range(4): + self.attentions.append(CSAM(self.fuseplanes[i])) + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + elif self.dil is not None: + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + else: + for i in range(4): + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + + self.classifier = nn.Conv2d(4, 1, kernel_size=1) + nn.init.constant_(self.classifier.weight, 0.25) + nn.init.constant_(self.classifier.bias, 0) + + def get_weights(self): + conv_weights = [] + bn_weights = [] + relu_weights = [] + for pname, p in self.named_parameters(): + if 'bn' in pname: + bn_weights.append(p) + elif 'relu' in pname: + relu_weights.append(p) + else: + conv_weights.append(p) + + return conv_weights, bn_weights, relu_weights + + def forward(self, x): + H, W = x.size()[2:] + + x = self.init_block(x) + + x1 = self.block1_1(x) + x1 = self.block1_2(x1) + x1 = self.block1_3(x1) + + x2 = self.block2_1(x1) + x2 = self.block2_2(x2) + x2 = self.block2_3(x2) + x2 = self.block2_4(x2) + + x3 = self.block3_1(x2) + x3 = self.block3_2(x3) + x3 = self.block3_3(x3) + x3 = self.block3_4(x3) + + x4 = self.block4_1(x3) + x4 = self.block4_2(x4) + x4 = self.block4_3(x4) + x4 = self.block4_4(x4) + + x_fuses = [] + if self.sa and self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](self.dilations[i](xi))) + elif self.sa: + for 
i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](xi)) + elif self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.dilations[i](xi)) + else: + x_fuses = [x1, x2, x3, x4] + + e1 = self.conv_reduces[0](x_fuses[0]) + e1 = F.interpolate(e1, (H, W), mode='bilinear', align_corners=False) + + e2 = self.conv_reduces[1](x_fuses[1]) + e2 = F.interpolate(e2, (H, W), mode='bilinear', align_corners=False) + + e3 = self.conv_reduces[2](x_fuses[2]) + e3 = F.interpolate(e3, (H, W), mode='bilinear', align_corners=False) + + e4 = self.conv_reduces[3](x_fuses[3]) + e4 = F.interpolate(e4, (H, W), mode='bilinear', align_corners=False) + + outputs = [e1, e2, e3, e4] + output = self.classifier(torch.cat(outputs, dim=1)) + + outputs.append(output) + outputs = [torch.sigmoid(r) for r in outputs] + return outputs[-1] + + +def pidinet_bsd_tiny(pretrained=False, vanilla_cnn=True): + pdcs = config_model_converted('carv4') if vanilla_cnn else config_model( + 'carv4') + model = PiDiNet(20, pdcs, dil=8, sa=True, convert=vanilla_cnn) + if pretrained: + state = torch.load( + DOWNLOAD_TO_CACHE('models/pidinet/table5_pidinet-tiny.pth'), + map_location='cpu')['state_dict'] + if vanilla_cnn: + state = convert_pidinet(state, 'carv4') + state = { + k[len('module.'):] if k.startswith('module.') else k: v + for k, v in state.items() + } + model.load_state_dict(state) + return model + + +def pidinet_bsd_small(pretrained=False, vanilla_cnn=True): + pdcs = config_model_converted('carv4') if vanilla_cnn else config_model( + 'carv4') + model = PiDiNet(30, pdcs, dil=12, sa=True, convert=vanilla_cnn) + if pretrained: + state = torch.load( + DOWNLOAD_TO_CACHE('models/pidinet/table5_pidinet-small.pth'), + map_location='cpu')['state_dict'] + if vanilla_cnn: + state = convert_pidinet(state, 'carv4') + state = { + k[len('module.'):] if k.startswith('module.') else k: v + for k, v in state.items() + } + model.load_state_dict(state) + return model + + +def 
pidinet_bsd(model_dir, pretrained=False, vanilla_cnn=True): + pdcs = config_model_converted('carv4') if vanilla_cnn else config_model( + 'carv4') + model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn) + if pretrained: + state = torch.load( + os.path.join(model_dir, 'table5_pidinet.pth'), + map_location='cpu')['state_dict'] + if vanilla_cnn: + state = convert_pidinet(state, 'carv4') + state = { + k[len('module.'):] if k.startswith('module.') else k: v + for k, v in state.items() + } + model.load_state_dict(state) + return model + + +def pidinet_nyud(pretrained=False, vanilla_cnn=True): + pdcs = config_model_converted('carv4') if vanilla_cnn else config_model( + 'carv4') + model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn) + if pretrained: + state = torch.load( + DOWNLOAD_TO_CACHE('models/pidinet/table6_pidinet.pth'), + map_location='cpu')['state_dict'] + if vanilla_cnn: + state = convert_pidinet(state, 'carv4') + state = { + k[len('module.'):] if k.startswith('module.') else k: v + for k, v in state.items() + } + model.load_state_dict(state) + return model + + +def pidinet_multicue(pretrained=False, vanilla_cnn=True): + pdcs = config_model_converted('carv4') if vanilla_cnn else config_model( + 'carv4') + model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn) + if pretrained: + state = torch.load( + DOWNLOAD_TO_CACHE('models/pidinet/table7_pidinet.pth'), + map_location='cpu')['state_dict'] + if vanilla_cnn: + state = convert_pidinet(state, 'carv4') + state = { + k[len('module.'):] if k.startswith('module.') else k: v + for k, v in state.items() + } + model.load_state_dict(state) + return model diff --git a/modelscope/models/multi_modal/videocomposer/annotator/sketch/sketch_simplification.py b/modelscope/models/multi_modal/videocomposer/annotator/sketch/sketch_simplification.py new file mode 100644 index 00000000..2555d593 --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/annotator/sketch/sketch_simplification.py @@ -0,0 
# +1,110 @@  (remainder of the patch hunk header that begins on the previous line)
r"""PyTorch re-implementation adapted from the Lua code in
``https://github.com/bobbens/sketch_simplification``.
"""
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.models.multi_modal.videocomposer.utils.utils import \
    DOWNLOAD_TO_CACHE

__all__ = [
    'SketchSimplification', 'sketch_simplification_gan',
    'sketch_simplification_mse', 'sketch_to_pencil_v1', 'sketch_to_pencil_v2'
]


class SketchSimplification(nn.Module):
    r"""Fully convolutional network mapping a one-channel sketch to a
    one-channel sketch of the same spatial size.

    NOTE:
        1. Input image should have only one gray channel.
        2. Input image size should be divisible by 8 (the network
           downsamples three times by 2 and upsamples three times by 2).
        3. Sketch in the input/output image is in dark color while
           background in light color.

    Args:
        mean: value subtracted from the input before it enters the network.
        std: value the centred input is divided by; must be non-zero.

    Raises:
        TypeError: if ``mean`` or ``std`` is not an ``int`` or ``float``.
        ValueError: if ``std`` is zero.
    """

    def __init__(self, mean, std):
        # Explicit checks instead of ``assert``: asserts vanish under
        # ``python -O``, and plain ints are perfectly good constants here.
        for arg_name, value in (('mean', mean), ('std', std)):
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise TypeError(
                    f'{arg_name} must be a real number, '
                    f'got {type(value).__name__}')
        if std == 0:
            # A zero std would silently turn every input into inf/nan.
            raise ValueError('std must be non-zero')

        super(SketchSimplification, self).__init__()
        self.mean = float(mean)
        self.std = float(std)

        # layers
        # The order (and therefore the ``layers.<idx>`` state-dict keys) must
        # stay exactly as below so existing checkpoints keep loading.
        self.layers = nn.Sequential(
            # stride-2 stage 1
            nn.Conv2d(1, 48, 5, 2, 2), nn.ReLU(inplace=True),
            nn.Conv2d(48, 128, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
            # stride-2 stage 2
            nn.Conv2d(128, 128, 3, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=True),
            # stride-2 stage 3
            nn.Conv2d(256, 256, 3, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(512, 1024, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(1024, 512, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 3, 1, 1), nn.ReLU(inplace=True),
            # x2 upsampling stage 1
            nn.ConvTranspose2d(256, 256, 4, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, 3, 1, 1), nn.ReLU(inplace=True),
            # x2 upsampling stage 2
            nn.ConvTranspose2d(128, 128, 4, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 48, 3, 1, 1), nn.ReLU(inplace=True),
            # x2 upsampling stage 3 and output head
            nn.ConvTranspose2d(48, 48, 4, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(48, 24, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(24, 1, 3, 1, 1), nn.Sigmoid())

    def forward(self, x):
        r"""x: [B, 1, H, W] within range [0, 1]. Sketch pixels in dark color.

        Returns the network output after the final sigmoid, i.e. values in
        [0, 1].
        """
        x = (x - self.mean) / self.std
        return self.layers(x)


def _load_checkpoint(model, checkpoint_path):
    """Load the state dict stored at ``checkpoint_path`` into ``model``."""
    # NOTE(security): ``torch.load`` unpickles the file; only point this at
    # checkpoints from a trusted source.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    return model


def sketch_simplification_gan(model_dir, pretrained=False):
    """Build the ``sketch_simplification_gan`` model.

    Args:
        model_dir: directory containing ``sketch_simplification_gan.pth``;
            only read when ``pretrained`` is true.
        pretrained: load the checkpoint from ``model_dir`` when true.
    """
    model = SketchSimplification(
        mean=0.9664114577640158, std=0.0858381272736797)
    if pretrained:
        _load_checkpoint(
            model, os.path.join(model_dir, 'sketch_simplification_gan.pth'))
    return model


def sketch_simplification_mse(pretrained=False):
    """Build the ``sketch_simplification_mse`` model.

    When ``pretrained`` is true the checkpoint path is resolved through
    ``DOWNLOAD_TO_CACHE``.
    """
    model = SketchSimplification(
        mean=0.9664423107454593, std=0.08583666033640507)
    if pretrained:
        _load_checkpoint(
            model,
            DOWNLOAD_TO_CACHE(
                'models/sketch_simplification/sketch_simplification_mse.pth'))
    return model


def sketch_to_pencil_v1(pretrained=False):
    """Build the ``sketch_to_pencil_v1`` model.

    When ``pretrained`` is true the checkpoint path is resolved through
    ``DOWNLOAD_TO_CACHE``.
    """
    model = SketchSimplification(
        mean=0.9817833515894078, std=0.0925009022585048)
    if pretrained:
        _load_checkpoint(
            model,
            DOWNLOAD_TO_CACHE(
                'models/sketch_simplification/sketch_to_pencil_v1.pth'))
    return model


def sketch_to_pencil_v2(pretrained=False):
    """Build the ``sketch_to_pencil_v2`` model.

    When ``pretrained`` is true the checkpoint path is resolved through
    ``DOWNLOAD_TO_CACHE``.
    """
    model = SketchSimplification(
        mean=0.9851298627337799, std=0.07418377454883571)
    if pretrained:
        _load_checkpoint(
            model,
            DOWNLOAD_TO_CACHE(
                'models/sketch_simplification/sketch_to_pencil_v2.pth'))
    return model


# --- patch metadata for the next file, preserved verbatim as comments ---
# diff --git a/modelscope/models/multi_modal/videocomposer/annotator/util.py b/modelscope/models/multi_modal/videocomposer/annotator/util.py
# new file mode 100644
# index 00000000..b02755c9
# --- /dev/null
# +++ b/modelscope/models/multi_modal/videocomposer/annotator/util.py
# @@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os

import cv2
import numpy as np

# ``ckpts`` directory located next to this module.
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')


def HWC3(x):
    """Return ``x`` as an ``(H, W, 3)`` ``uint8`` image.

    Args:
        x: ``np.uint8`` array of shape ``(H, W)``, ``(H, W, 1)``,
            ``(H, W, 3)`` or ``(H, W, 4)``.

    Returns:
        * 3 channels: ``x`` itself (no copy is made).
        * 1 channel or 2-D input: the single channel repeated three times.
        * 4 channels: the first three channels blended over a white (255)
          background, using the fourth channel / 255 as the blend weight.

    Raises:
        TypeError: if ``x.dtype`` is not ``uint8``.
        ValueError: if ``x`` is not 2-D/3-D or has a channel count other
            than 1, 3 or 4.
    """
    # Explicit raises instead of ``assert``: with asserts stripped by
    # ``python -O`` an unsupported channel count used to fall through and
    # return ``None``.
    if x.dtype != np.uint8:
        raise TypeError(f'HWC3 expects a uint8 array, got dtype {x.dtype}')
    if x.ndim == 2:
        x = x[:, :, None]
    if x.ndim != 3:
        raise ValueError(
            f'HWC3 expects a 2-D or 3-D array, got {x.ndim} dimensions')

    channels = x.shape[2]
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)
    if channels == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        blended = color * alpha + 255.0 * (1.0 - alpha)
        return blended.clip(0, 255).astype(np.uint8)
    raise ValueError(
        f'HWC3 expects 1, 3 or 4 channels, got {channels}')


def resize_image(input_image, resolution):
    """Resize ``input_image`` so its short side is about ``resolution``.

    Both sides are scaled by ``resolution / min(H, W)`` and then snapped to
    the nearest multiple of 64, so the aspect ratio is only approximately
    preserved.

    Args:
        input_image: array whose first two axes are height and width
            (``(H, W)`` or ``(H, W, C)``).
        resolution: target length of the shorter side before snapping.

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    # Only the first two axes matter, so 2-D grayscale input works as well.
    height, width = input_image.shape[:2]
    scale = float(resolution) / min(height, width)

    # Snap to a multiple of 64. Never let a side round down to 0, which
    # previously made ``cv2.resize`` fail for small targets.
    new_height = max(64, int(np.round(height * scale / 64.0)) * 64)
    new_width = max(64, int(np.round(width * scale / 64.0)) * 64)

    # Lanczos when enlarging, area averaging when shrinking.
    interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    return cv2.resize(
        input_image, (new_width, new_height), interpolation=interpolation)


# --- patch metadata for the next file, preserved verbatim as comments ---
# diff --git a/modelscope/models/multi_modal/videocomposer/autoencoder.py b/modelscope/models/multi_modal/videocomposer/autoencoder.py
# new file mode 100644
# index 00000000..99b65f87
# --- /dev/null
# +++ b/modelscope/models/multi_modal/videocomposer/autoencoder.py
# @@ -0,0 +1,650 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['AutoencoderKL'] + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class DiagonalGaussianDistribution(object): + + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn( + self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +class Downsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = 
torch.nn.functional.pad(x, pad, mode='constant', value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + + def __init__(self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, 
padding=0) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class AttnBlock(nn.Module): # noqa + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i 
v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Upsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate( + x, scale_factor=2.0, mode='nearest') + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): # noqa + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode='constant', value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class Encoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type='vanilla', + **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1, ) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in 
range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type='vanilla', + **ignorekwargs): + super().__init__() + self.ch = ch 
+ self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print('Working with z of shape {} = {} dimensions.'.format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + 
h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class AutoencoderKL(nn.Module): + + def __init__(self, + ddconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key='image', + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + assert ddconfig['double_z'] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'], + 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, + ddconfig['z_channels'], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer('colorize', + torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.use_ema = ema_decay is not None + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location='cpu')['state_dict'] + keys = list(sd.keys()) + for key in keys: + print(key, sd[key].shape) + import collections + sd_new = collections.OrderedDict() + for k in keys: + if k.find('first_stage_model') >= 0: + k_new = k.split('first_stage_model.')[-1] + sd_new[k_new] = sd[k] + self.load_state_dict(sd_new, strict=True) + print(f'Restored from {path}') + + def init_from_ckpt2(self, path, ignore_keys=list()): + sd = torch.load(path, 
map_location='cpu')['state_dict'] + keys = list(sd.keys()) + + first_stage_model + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print('Deleting key {} from state_dict.'.format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f'Restored from {path}') + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, + 2).to(memory_format=torch.contiguous_format).float() + return x + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log['samples'] = self.decode(torch.randn_like(posterior.sample())) + log['reconstructions'] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log['samples_ema'] = self.decode( + torch.randn_like(posterior_ema.sample())) + log['reconstructions_ema'] = xrec_ema + log['inputs'] = x + return log + + def to_rgb(self, x): + assert self.image_key == 'segmentation' + if not 
hasattr(self, 'colorize'): + self.register_buffer('colorize', + torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/modelscope/models/multi_modal/videocomposer/clip.py b/modelscope/models/multi_modal/videocomposer/clip.py new file mode 100644 index 00000000..f55a5206 --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/clip.py @@ -0,0 +1,143 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import numpy as np +import open_clip +import torch +import torch.nn as nn +import torchvision.transforms as T + + +class FrozenOpenCLIPEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = ['last', 'penultimate'] + + def __init__(self, + arch='ViT-H-14', + pretrained='laion2b_s32b_b79k', + device='cuda', + max_length=77, + freeze=True, + layer='last'): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms( + arch, device=torch.device('cpu'), pretrained=pretrained) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == 'last': + self.layer_idx = 0 + elif self.layer == 'penultimate': + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = 
self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting( + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPVisualEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = ['last', 'penultimate'] + + def __init__(self, + arch='ViT-H-14', + pretrained='laion2b_s32b_b79k', + device='cuda', + max_length=77, + freeze=True, + layer='last', + input_shape=(224, 224, 3)): + super().__init__() + assert layer in self.LAYERS + model, _, preprocess = open_clip.create_model_and_transforms( + arch, device=torch.device('cpu'), pretrained=pretrained) + del model.transformer + self.model = model + data_white = np.ones(input_shape, dtype=np.uint8) * 255 + self.black_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0) + self.preprocess = preprocess + + self.device = device + self.max_length = max_length # 77 + if freeze: + self.freeze() + self.layer = layer # 'penultimate' + if self.layer == 'last': + self.layer_idx = 0 + elif self.layer == 'penultimate': + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, image): + # tokens = open_clip.tokenize(text) + z = 
self.model.encode_image(image.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting( + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) diff --git a/modelscope/models/multi_modal/videocomposer/config.py b/modelscope/models/multi_modal/videocomposer/config.py new file mode 100644 index 00000000..499d8484 --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/config.py @@ -0,0 +1,156 @@ +import logging +import os +import os.path as osp +from datetime import datetime + +import torch +from easydict import EasyDict + +cfg = EasyDict(__name__='Config: VideoComposer') + +pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) +gpus_per_machine = torch.cuda.device_count() +world_size = pmi_world_size * gpus_per_machine + +cfg.video_compositions = [ + 'text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', + 'single_sketch' +] + +# dataset +cfg.root_dir = 'webvid10m/' + +cfg.alpha = 0.7 + +cfg.misc_size = 384 +cfg.depth_std = 20.0 +cfg.depth_clamp = 10.0 +cfg.hist_sigma = 10.0 + +cfg.use_image_dataset = False +cfg.alpha_img = 0.7 + +cfg.resolution = 256 +cfg.mean = [0.5, 0.5, 0.5] +cfg.std = [0.5, 0.5, 0.5] + +# sketch +cfg.sketch_mean = [0.485, 0.456, 0.406] +cfg.sketch_std = [0.229, 0.224, 0.225] + +# dataloader +cfg.max_words = 1000 + +cfg.frame_lens = [ + 16, + 16, + 16, + 16, 
+] +cfg.feature_framerates = [ + 4, +] +cfg.feature_framerate = 4 +cfg.batch_sizes = { + str(1): 1, + str(4): 1, + str(8): 1, + str(16): 1, +} + +cfg.chunk_size = 64 +cfg.num_workers = 8 +cfg.prefetch_factor = 2 +cfg.seed = 8888 + +# diffusion +cfg.num_timesteps = 1000 +cfg.mean_type = 'eps' +cfg.var_type = 'fixed_small' +cfg.loss_type = 'mse' +cfg.ddim_timesteps = 50 +cfg.ddim_eta = 0.0 +cfg.clamp = 1.0 +cfg.share_noise = False +cfg.use_div_loss = False + +# classifier-free guidance +cfg.p_zero = 0.9 +cfg.guide_scale = 6.0 + +# stable diffusion +cfg.sd_checkpoint = 'v2-1_512-ema-pruned.ckpt' + +# clip vision encoder +cfg.vit_image_size = 336 +cfg.vit_patch_size = 14 +cfg.vit_dim = 1024 +cfg.vit_out_dim = 768 +cfg.vit_heads = 16 +cfg.vit_layers = 24 +cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073] +cfg.vit_std = [0.26862954, 0.26130258, 0.27577711] +cfg.clip_checkpoint = 'open_clip_pytorch_model.bin' +cfg.mvs_visual = False + +# unet +cfg.unet_in_dim = 4 +cfg.unet_concat_dim = 8 +cfg.unet_y_dim = cfg.vit_out_dim +cfg.unet_context_dim = 1024 +cfg.unet_out_dim = 8 if cfg.var_type.startswith('learned') else 4 +cfg.unet_dim = 320 +cfg.unet_dim_mult = [1, 2, 4, 4] +cfg.unet_res_blocks = 2 +cfg.unet_num_heads = 8 +cfg.unet_head_dim = 64 +cfg.unet_attn_scales = [1 / 1, 1 / 2, 1 / 4] +cfg.unet_dropout = 0.1 +cfg.misc_dropout = 0.5 +cfg.p_all_zero = 0.1 +cfg.p_all_keep = 0.1 +cfg.temporal_conv = False +cfg.temporal_attn_times = 1 +cfg.temporal_attention = True + +cfg.use_fps_condition = False +cfg.use_sim_mask = False + +# Default: load 2d pretrain +cfg.pretrained = False +cfg.fix_weight = False + +# Default resume +cfg.resume = True +cfg.resume_step = 148000 +cfg.resume_check_dir = '.' 
+cfg.resume_checkpoint = os.path.join( + cfg.resume_check_dir, + f'step_{cfg.resume_step}/non_ema_{cfg.resume_step}.pth') +cfg.resume_optimizer = False +if cfg.resume_optimizer: + cfg.resume_optimizer = os.path.join( + cfg.resume_check_dir, f'optimizer_step_{cfg.resume_step}.pt') + +# acceleration +cfg.use_ema = True +# for debug, no ema +if world_size < 2: + cfg.use_ema = False +cfg.load_from = None + +cfg.use_checkpoint = True +cfg.use_sharded_ddp = False +cfg.use_fsdp = False +cfg.use_fp16 = True + +# training +cfg.ema_decay = 0.9999 +cfg.viz_interval = 1000 +cfg.save_ckp_interval = 1000 + +# logging +cfg.log_interval = 100 +composition_strings = '_'.join(cfg.video_compositions) +# Default log_dir +cfg.log_dir = 'outputs/' diff --git a/modelscope/models/multi_modal/videocomposer/configs/base.yaml b/modelscope/models/multi_modal/videocomposer/configs/base.yaml new file mode 100644 index 00000000..42f756f8 --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/configs/base.yaml @@ -0,0 +1,2 @@ +ENABLE: true +DATASET: webvid10m diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp01_vidcomposer_full.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp01_vidcomposer_full.yaml new file mode 100644 index 00000000..ec312138 --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/configs/exp01_vidcomposer_full.yaml @@ -0,0 +1,20 @@ +TASK_TYPE: MULTI_TASK +ENABLE: true +DATASET: webvid10m +video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] +batch_sizes: { + "1": 1, + "4": 1, + "8": 1, + "16": 1, +} +vit_image_size: 224 +network_name: UNetSD_temporal +resume: true +resume_step: 228000 +num_workers: 1 +mvs_visual: False +chunk_size: 1 +resume_checkpoint: "model_weights/non_ema_228000.pth" +log_dir: 'outputs' +num_steps: 1 diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer.yaml 
b/modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer.yaml new file mode 100644 index 00000000..4b756d32 --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer.yaml @@ -0,0 +1,23 @@ +TASK_TYPE: SINGLE_TASK +read_image: True # You NEED Open It +ENABLE: true +DATASET: webvid10m +video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] +guidances: ['y', 'local_image', 'motion'] # You NEED Open It +batch_sizes: { + "1": 1, + "4": 1, + "8": 1, + "16": 1, +} +vit_image_size: 224 +network_name: UNetSD_temporal +resume: true +resume_step: 228000 +seed: 182 +num_workers: 0 +mvs_visual: False +chunk_size: 1 +resume_checkpoint: "model_weights/non_ema_228000.pth" +log_dir: 'outputs' +num_steps: 1 diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer_vs_style.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer_vs_style.yaml new file mode 100644 index 00000000..7928e7ba --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/configs/exp02_motion_transfer_vs_style.yaml @@ -0,0 +1,24 @@ +TASK_TYPE: SINGLE_TASK +read_image: True # You NEED Open It +read_style: True +ENABLE: true +DATASET: webvid10m +video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] +guidances: ['y', 'local_image', 'image', 'motion'] # You NEED Open It +batch_sizes: { + "1": 1, + "4": 1, + "8": 1, + "16": 1, +} +vit_image_size: 224 +network_name: UNetSD_temporal +resume: true +resume_step: 228000 +seed: 182 +num_workers: 0 +mvs_visual: False +chunk_size: 1 +resume_checkpoint: "model_weights/non_ema_228000.pth" +log_dir: 'outputs' +num_steps: 1 diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp03_sketch2video_style.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp03_sketch2video_style.yaml new file mode 100644 index 00000000..fd710ee5 --- 
/dev/null +++ b/modelscope/models/multi_modal/videocomposer/configs/exp03_sketch2video_style.yaml @@ -0,0 +1,26 @@ +TASK_TYPE: SINGLE_TASK +read_image: False # You NEED Open It +read_style: True +read_sketch: True +save_origin_video: False +ENABLE: true +DATASET: webvid10m +video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] +guidances: ['y', 'image', 'single_sketch'] # You NEED Open It +batch_sizes: { + "1": 1, + "4": 1, + "8": 1, + "16": 1, +} +vit_image_size: 224 +network_name: UNetSD_temporal +resume: true +resume_step: 228000 +seed: 182 +num_workers: 0 +mvs_visual: False +chunk_size: 1 +resume_checkpoint: "model_weights/non_ema_228000.pth" +log_dir: 'outputs' +num_steps: 1 diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp04_sketch2video_wo_style.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp04_sketch2video_wo_style.yaml new file mode 100644 index 00000000..a5cc54bf --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/configs/exp04_sketch2video_wo_style.yaml @@ -0,0 +1,26 @@ +TASK_TYPE: SINGLE_TASK +read_image: False # You NEED Open It +read_style: False +read_sketch: True +save_origin_video: False +ENABLE: true +DATASET: webvid10m +video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] +guidances: ['y', 'single_sketch'] # You NEED Open It +batch_sizes: { + "1": 1, + "4": 1, + "8": 1, + "16": 1, +} +vit_image_size: 224 +network_name: UNetSD_temporal +resume: true +resume_step: 228000 +seed: 182 +num_workers: 0 +mvs_visual: False +chunk_size: 1 +resume_checkpoint: "model_weights/non_ema_228000.pth" +log_dir: 'outputs' +num_steps: 1 diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp05_text_depths_wo_style.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp05_text_depths_wo_style.yaml new file mode 100644 index 00000000..29c053b1 --- /dev/null +++ 
b/modelscope/models/multi_modal/videocomposer/configs/exp05_text_depths_wo_style.yaml @@ -0,0 +1,26 @@ +TASK_TYPE: SINGLE_TASK +read_image: False # You NEED Open It +read_style: False +read_sketch: False +save_origin_video: True +ENABLE: true +DATASET: webvid10m +video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] +guidances: ['y', 'depth'] # You NEED Open It +batch_sizes: { + "1": 1, + "4": 1, + "8": 1, + "16": 1, +} +vit_image_size: 224 +network_name: UNetSD_temporal +resume: true +resume_step: 228000 +seed: 182 +num_workers: 0 +mvs_visual: False +chunk_size: 1 +resume_checkpoint: "model_weights/non_ema_228000.pth" +log_dir: 'outputs' +num_steps: 1 diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp06_text_depths_vs_style.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp06_text_depths_vs_style.yaml new file mode 100644 index 00000000..2732dc5c --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/configs/exp06_text_depths_vs_style.yaml @@ -0,0 +1,26 @@ +TASK_TYPE: SINGLE_TASK +read_image: False +read_style: True +read_sketch: False +save_origin_video: True +ENABLE: true +DATASET: webvid10m +video_compositions: ['text', 'mask', 'depthmap', 'sketch', 'motion', 'image', 'local_image', 'single_sketch'] +guidances: ['y', 'image', 'depth'] # You NEED Open It +batch_sizes: { + "1": 1, + "4": 1, + "8": 1, + "16": 1, +} +vit_image_size: 224 +network_name: UNetSD_temporal +resume: true +resume_step: 228000 +seed: 182 +num_workers: 0 +mvs_visual: False +chunk_size: 1 +resume_checkpoint: "model_weights/non_ema_141000_no_watermark.pth" +log_dir: 'outputs' +num_steps: 1 diff --git a/modelscope/models/multi_modal/videocomposer/configs/exp10_vidcomposer_no_watermark_full.yaml b/modelscope/models/multi_modal/videocomposer/configs/exp10_vidcomposer_no_watermark_full.yaml new file mode 100644 index 00000000..1be311d8 --- /dev/null +++ 
class BatchSampler(Sampler):
    r"""An infinite batch sampler.

    Each rank owns the dataset indices ``k * num_replicas + rank`` that are
    smaller than ``dataset_size`` and cycles over them forever, yielding
    lists of ``batch_size`` indices.

    Args:
        dataset_size: Number of samples in the dataset.
        batch_size: Number of indices per yielded batch.
        num_replicas: Number of ranks; falls back to ``get_world_size()``.
        rank: Rank of this process; falls back to ``get_rank()`` when None.
        shuffle: Permute the rank's indices up front and reshuffle whenever
            a batch wraps around the end of the index list.
        seed: Base seed; falls back to ``shared_random_seed()`` when None.
            The generator is seeded with ``seed + rank``.
    """

    def __init__(self,
                 dataset_size,
                 batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=False,
                 seed=None):
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.num_replicas = num_replicas or get_world_size()
        # `is None` instead of `or`: rank 0 and seed 0 are valid explicit
        # values and must not be replaced by the fallbacks.
        self.rank = get_rank() if rank is None else rank
        self.shuffle = shuffle
        self.seed = shared_random_seed() if seed is None else seed
        self.rng = np.random.default_rng(self.seed + self.rank)
        self.batches_per_rank = ceil_divide(
            dataset_size, self.num_replicas * self.batch_size)
        self.samples_per_rank = self.batches_per_rank * self.batch_size

        # rank indices
        if shuffle:
            indices = self.rng.permutation(self.samples_per_rank)
        else:
            indices = np.arange(self.samples_per_rank)
        indices = indices * self.num_replicas + self.rank
        # The per-rank count is rounded up, so drop indices past the dataset.
        indices = indices[indices < dataset_size]
        self.indices = indices

    def __iter__(self):
        """Yield batches of indices forever.

        Raises:
            ValueError: If this rank owns no indices (previously this
                surfaced as a ZeroDivisionError from the modulo below).
        """
        num_indices = len(self.indices)
        if num_indices == 0:
            raise ValueError(
                'BatchSampler has no indices for this rank; '
                'dataset_size is too small for the given rank')
        start = 0
        while True:
            batch = [
                self.indices[i % num_indices]
                for i in range(start, start + self.batch_size)
            ]
            # Reshuffle only after the batch that wrapped around was built.
            if self.shuffle and (start + self.batch_size) > num_indices:
                self.rng.shuffle(self.indices)
            start = (start + self.batch_size) % num_indices
            yield batch
class ImgGroupSampler(Sampler):
    """Infinite sampler that draws batches of lines from weighted groups.

    The group index is a JSON file (``group_file``); the list files it
    references live in a ``groups`` folder next to it. Every yielded batch
    is a list of ``batch_size`` entries, each one a line split at its first
    comma.
    """

    def __init__(self,
                 group_file,
                 batch_size,
                 alpha=0.7,
                 update_interval=5000,
                 seed=8888):
        self.group_file = group_file
        self.group_folder = osp.join(osp.dirname(group_file), 'groups')
        self.batch_size = batch_size
        self.alpha = alpha
        self.update_interval = update_interval
        self.seed = seed
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        while True:
            # Refresh the group index when it is due.
            self.update_groups()

            # Keep pulling list files until one batch can be filled.
            pool = self.sample()
            while len(pool) < self.batch_size:
                pool += self.sample()

            # Draw the batch; with-replacement only if the pool is short.
            with_replacement = len(pool) < self.batch_size
            chosen = self.rng.choice(
                pool, self.batch_size, replace=with_replacement)
            yield [line.strip().split(',', 1) for line in chosen]

    def update_groups(self):
        """Reload the group index every ``update_interval`` calls."""
        if not hasattr(self, '_step'):
            self._step = 0
        reload_due = self._step % self.update_interval == 0
        if reload_due:
            self.groups = json.loads(read(self.group_file))

        self._step += 1

    def sample(self):
        """Pick one group, then one of its list files, and return its lines.

        Groups are weighted by ``scale ** alpha``, where ``scale`` is the
        number after the last ``:`` in the group's first key.
        """
        scales = np.array([
            float(next(iter(entry)).split(':')[-1]) for entry in self.groups
        ])
        weights = scales**self.alpha
        p = weights / (scales**self.alpha).sum()
        group = self.rng.choice(self.groups, p=p)
        file_names = next(iter(group.values()))
        list_file = osp.join(self.group_folder, self.rng.choice(file_names))
        return read(list_file).strip().split('\n')
@lru_cache()
def default_bpe():
    """Return the path of the BPE vocabulary file next to this module."""
    # os.path.dirname instead of splitting on '/', which breaks on paths
    # that use a different separator.
    root = os.path.dirname(os.path.realpath(__file__))
    return os.path.join(root, 'bpe_simple_vocab_16e6.txt.gz')


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord('!'),
                    ord('~') + 1)) + list(range(
                        ord('¡'),
                        ord('¬') + 1)) + list(range(ord('®'),
                                                    ord('ÿ') + 1))
    cs = bs[:]
    n = 0
    # Bytes outside the ranges above are mapped to code points from 256 up.
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    An empty or single-symbol word yields an empty set.
    """
    return set(zip(word, word[1:]))


def basic_clean(text):
    """Fix mojibake with ftfy, unescape HTML entities twice, and strip."""
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    """Collapse every whitespace run into a single space and strip."""
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    """Byte-level BPE tokenizer driven by a gzip-compressed merges file.

    The last symbol of every word carries the end-of-word marker ``</w>``;
    ``decode`` turns that marker back into a space.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Context manager so the file handle is closed after reading.
        with gzip.open(bpe_path) as f:
            merges = f.read().decode('utf-8').split('\n')
        # Skip the first line and cap the number of merges.
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {
            '<|startoftext|>': '<|startoftext|>',
            '<|endoftext|>': '<|endoftext|>'
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE)

    def bpe(self, token):
        """Apply the learned merges to ``token``.

        Returns the resulting symbols joined by single spaces; the final
        symbol ends with ``</w>``. Results are memoized in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>', )
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            # Merge the adjacent pair with the lowest rank first.
            bigram = min(
                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the tail and stop.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[
                        i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lower-case ``text`` and return its list of BPE ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b]
                            for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token]
                              for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Turn BPE ids back into text; each ``</w>`` becomes a space."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            'utf-8', errors='replace').replace('</w>', ' ')
        return text


class CLIPTokenizer(object):
    """Fixed-length wrapper around ``SimpleTokenizer``.

    Every sequence is truncated to ``length - 2`` ids, wrapped in the
    start/end tokens and right-padded with zeros up to ``length``.
    """

    def __init__(self, length=77):
        self.length = length

        # init tokenizer
        self.tokenizer = SimpleTokenizer(bpe_path=default_bpe())
        self.sos_token = self.tokenizer.encoder['<|startoftext|>']
        self.eos_token = self.tokenizer.encoder['<|endoftext|>']
        self.vocab_size = len(self.tokenizer.encoder)

    def __call__(self, sequence):
        """Tokenize a string or a list of strings into a ``LongTensor``.

        Raises:
            TypeError: If ``sequence`` is neither a ``str`` nor a ``list``.
        """
        if isinstance(sequence, str):
            return torch.LongTensor(self._tokenizer(sequence))
        elif isinstance(sequence, list):
            return torch.LongTensor([self._tokenizer(u) for u in sequence])
        else:
            raise TypeError(
                f'Expected the "sequence" to be a string or a list, but got {type(sequence)}'
            )

    def _tokenizer(self, text):
        """Return exactly ``self.length`` ids for ``text``."""
        tokens = self.tokenizer.encode(text)[:self.length - 2]
        tokens = [self.sos_token] + tokens + [self.eos_token]
        tokens = tokens + [0] * (self.length - len(tokens))
        return tokens
class Compose(object):
    """Chain several transforms and apply them in order.

    Indexing with an integer returns the transform at that position;
    indexing with a slice returns a new ``Compose`` over the selected ones.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __getitem__(self, index):
        picked = self.transforms[index]
        # A slice selects a sub-pipeline, so wrap it again.
        return Compose(picked) if isinstance(index, slice) else picked

    def __len__(self):
        return len(self.transforms)

    def __call__(self, rgb):
        result = rgb
        for transform in self.transforms:
            result = transform(result)
        return result
class CenterCrop(object):
    """Cut the same centered ``size`` x ``size`` square out of every frame.

    The crop box is derived from the first frame's ``size`` and reused for
    all frames in the list.
    """

    def __init__(self, size=224):
        self.size = size

    def __call__(self, rgb):
        side = self.size
        width, height = rgb[0].size
        assert min(width, height) >= side
        left = (width - side) // 2
        top = (height - side) // 2
        box = (left, top, left + side, top + side)
        return [frame.crop(box) for frame in rgb]
class CenterCropV2(object):
    """Resize a list of frames so the short side equals ``size``, then
    center-crop every frame to ``size`` x ``size``.

    Scale and crop offsets are computed from the first frame.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        target = self.size
        frames = img

        # Coarse pass: halve with a box filter while the short side is
        # still at least twice the target.
        while min(frames[0].size) >= 2 * target:
            halved = []
            for frame in frames:
                half_size = (frame.width // 2, frame.height // 2)
                halved.append(frame.resize(half_size, resample=Image.BOX))
            frames = halved

        # Fine pass: bicubic resize to bring the short side to the target.
        ratio = target / min(frames[0].size)
        resized = []
        for frame in frames:
            new_size = (round(ratio * frame.width),
                        round(ratio * frame.height))
            resized.append(frame.resize(new_size, resample=Image.BICUBIC))
        frames = resized

        # Center crop, using the first frame's dimensions for the offsets.
        left = (frames[0].width - target) // 2
        top = (frames[0].height - target) // 2
        box = (left, top, left + target, top + target)
        return [frame.crop(box) for frame in frames]
class RandomCropV2(object):
    """Random-area, random-aspect crop shared by all frames, then resize.

    One crop window is sampled from the first frame and applied to every
    frame, each of which is resized to ``size``.
    """

    def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
        self.size = size if isinstance(size, (tuple, list)) else (size, size)
        self.min_area = min_area
        self.ratio = ratio

    def _get_params(self, img):
        """Return ``(top, left, height, width)`` of the crop window."""
        width, height = img.size
        area = height * width
        log_lo = math.log(self.ratio[0])
        log_hi = math.log(self.ratio[1])

        attempts = 10
        while attempts > 0:
            attempts -= 1
            target_area = random.uniform(self.min_area, 1.0) * area
            aspect_ratio = math.exp(random.uniform(log_lo, log_hi))

            crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
            crop_h = int(round(math.sqrt(target_area / aspect_ratio)))

            if not (0 < crop_w <= width and 0 < crop_h <= height):
                continue
            top = random.randint(0, height - crop_h)
            left = random.randint(0, width - crop_w)
            return top, left, crop_h, crop_w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        lo_ratio, hi_ratio = min(self.ratio), max(self.ratio)
        if in_ratio < lo_ratio:
            crop_w = width
            crop_h = int(round(crop_w / lo_ratio))
        elif in_ratio > hi_ratio:
            crop_h = height
            crop_w = int(round(crop_h * hi_ratio))
        else:  # whole image
            crop_w, crop_h = width, height
        top = (height - crop_h) // 2
        left = (width - crop_w) // 2
        return top, left, crop_h, crop_w

    def __call__(self, rgb):
        top, left, crop_h, crop_w = self._get_params(rgb[0])
        return [
            F.resized_crop(frame, top, left, crop_h, crop_w, self.size)
            for frame in rgb
        ]
class Normalize(object):
    """Clamp a batch of frames to ``[0, 1]`` and normalize per channel.

    The input is expected to be 4-D with channels on dim 1, since the
    statistics are broadcast as ``(1, C, 1, 1)``. The input tensor itself
    is never modified.

    Args:
        mean: Per-channel mean (sequence of numbers or a tensor).
        std: Per-channel standard deviation (sequence or tensor).
    """

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, rgb):
        rgb = rgb.clone()
        rgb.clamp_(0, 1)
        # Build the statistics for this input on every call. Caching them on
        # the instance tied them to the first input's dtype and device, which
        # broke later calls with tensors living on another device.
        mean = torch.as_tensor(
            self.mean, dtype=rgb.dtype, device=rgb.device).view(1, -1, 1, 1)
        std = torch.as_tensor(
            self.std, dtype=rgb.dtype, device=rgb.device).view(1, -1, 1, 1)
        rgb.sub_(mean).div_(std)
        return rgb
--git a/modelscope/models/multi_modal/videocomposer/diffusion.py b/modelscope/models/multi_modal/videocomposer/diffusion.py new file mode 100644 index 00000000..8b669ca9 --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/diffusion.py @@ -0,0 +1,1514 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math + +import torch + +from .dpm_solver import (DPM_Solver, NoiseScheduleVP, model_wrapper, + model_wrapper_guided_diffusion) +from .ops.losses import discretized_gaussian_log_likelihood, kl_divergence + +__all__ = ['GaussianDiffusion', 'beta_schedule', 'GaussianDiffusion_style'] + + +def _i(tensor, t, x): + r"""Index tensor using t and format the output according to x. + """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + if tensor.device != x.device: + tensor = tensor.to(x.device) + return tensor[t].view(shape).to(x) + + +def beta_schedule(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None): + ''' + This code defines a function beta_schedule that generates a sequence of beta + values based on the given input parameters. + These beta values can be used in video diffusion processes. The function has the following parameters: + schedule(str): Determines the type of beta schedule to be generated. + It can be 'linear', 'linear_sd', 'quadratic', or 'cosine'. + num_timesteps(int, optional): The number of timesteps for the generated beta schedule. Default is 1000. + init_beta(float, optional): The initial beta value. + If not provided, a default value is used based on the chosen schedule. + last_beta(float, optional): The final beta value. + If not provided, a default value is used based on the chosen schedule. + The function returns a PyTorch tensor containing the generated beta values. + The beta schedule is determined by the schedule parameter: + 1.Linear: Generates a linear sequence of beta values betweeninit_betaandlast_beta. 
+ 2.Linear_sd: Generates a linear sequence of beta values between the square root of + init_beta and the square root oflast_beta, and then squares the result. + 3.Quadratic: Similar to the 'linear_sd' schedule, but with different default values forinit_betaandlast_beta. + 4.Cosine: Generates a sequence of beta values based on a cosine function, + ensuring the values are between 0 and 0.999. + If an unsupported schedule is provided, a ValueError is raised with a message indicating the issue. + ''' + if schedule == 'linear': + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + last_beta = last_beta or scale * 0.02 + return torch.linspace( + init_beta, last_beta, num_timesteps, dtype=torch.float64) + elif schedule == 'linear_sd': + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'quadratic': + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'cosine': + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + fn = lambda u: math.cos( # noqa + (u + 0.008) / 1.008 * math.pi / 2)**2 # noqa + betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) + return torch.tensor(betas, dtype=torch.float64) + else: + raise ValueError(f'Unsupported schedule: {schedule}') + + +def load_stable_diffusion_pretrained(state_dict, temporal_attention): + import collections + sd_new = collections.OrderedDict() + keys = list(state_dict.keys()) + + for k in keys: + if k.find('diffusion_model') >= 0: + k_new = k.split('diffusion_model.')[-1] + if k_new in [ + 'input_blocks.3.0.op.weight', 'input_blocks.3.0.op.bias', + 'input_blocks.6.0.op.weight', 'input_blocks.6.0.op.bias', + 'input_blocks.9.0.op.weight', 'input_blocks.9.0.op.bias' + ]: + k_new = k_new.replace('0.op', 'op') + if temporal_attention: + if 
k_new.find('middle_block.2') >= 0: + k_new = k_new.replace('middle_block.2', 'middle_block.3') + if k_new.find('output_blocks.5.2') >= 0: + k_new = k_new.replace('output_blocks.5.2', + 'output_blocks.5.3') + if k_new.find('output_blocks.8.2') >= 0: + k_new = k_new.replace('output_blocks.8.2', + 'output_blocks.8.3') + sd_new[k_new] = state_dict[k] + + return sd_new + + +class AddGaussianNoise(object): + + def __init__(self, mean=0., std=0.1): + self.std = std + self.mean = mean + + def __call__(self, img): + assert isinstance(img, torch.Tensor) + dtype = img.dtype + if not img.is_floating_point(): + img = img.to(torch.float32) + out = img + self.std * torch.randn_like(img) + self.mean + if out.dtype != dtype: + out = out.to(dtype) + return out + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format( + self.mean, self.std) + + +class GaussianDiffusion(object): + + def __init__(self, + betas, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + epsilon=1e-12, + rescale_timesteps=False): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1', + 'charbonnier' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type # eps + self.var_type = var_type # 'fixed_small' + self.loss_type = loss_type # mse + self.epsilon = epsilon # 1e-12 + self.rescale_timesteps = rescale_timesteps # False + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | 
x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( + 1.0 - self.alphas_cumprod) + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). + """ + noise = torch.randn_like(x0) if noise is None else noise + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \ + _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise # noqa + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
+ """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, + guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + mask = t.ne(0).float().view( + -1, + *((1, ) * # noqa + (xt.ndim - 1))) + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). 
+ """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + dim = y_out.size(1) if self.var_type.startswith( + 'fixed') else y_out.size(1) // 2 + out = torch.cat( + [ + u_out[:, :dim] + guide_scale * # noqa + (y_out[:, :dim] - u_out[:, :dim]), + y_out[:, dim:] + ], + dim=1) # noqa + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i( + torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, + xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \ + _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt # noqa + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out # noqa + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and 
percentile <= 1 # e.g., 0.995 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * # noqa + (1 - alphas / alphas_prev)) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + 
condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + # import ipdb; ipdb.set_trace() + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa + alphas_next = _i( + torch.cat( + [self.alphas_cumprod, + self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, + percentile, guide_scale, + ddim_timesteps) + 
return xt + + @torch.no_grad() + def plms_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif 
len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] + - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // plms_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, + guide_scale, plms_timesteps, + eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, + x0, + t, + model, + model_kwargs={}, + noise=None, + weight=None, + use_div_loss=False): + noise = torch.randn_like( + x0) if noise is None else noise # [80, 4, 8, 32, 32] + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, + model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', + 'rescaled_l1']: # self.loss_type: mse + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range' + ]: # self.var_type: 'fixed_small' + out, var = out.chunk(2, dim=1) + frozen = torch.cat([ + out.detach(), var + ], dim=1) # learn var without 
affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound( + x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] + }[self.mean_type] + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2 + ).abs().flatten(1).mean(dim=1) + if weight is not None: + loss = loss * weight + + # div loss + if use_div_loss and self.mean_type == 'eps' and x0.shape[2] > 1: + + # derive x0 + x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + + # ncfhw, std on f + div_loss = 0.001 / ( + x0_.std(dim=2).flatten(1).mean(dim=1) + 1e-4) + loss = loss + div_loss + + # total loss + loss = loss + loss_vlb + elif self.loss_type in ['charbonnier']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) + loss_vlb, _ = self.variational_lower_bound( + x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] + }[self.mean_type] + loss = torch.sqrt((out - target)**2 + self.epsilon) + if weight is not None: + loss = loss * weight + loss = loss.flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, + x0, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, 
model_kwargs, + clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood( + x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. + """ + # prepare input and output + b = x0.size(0) + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + noise = torch.randn_like(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound( + x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append( + (pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append( + (eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), + torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: # noqa + return t.float() * 
1000.0 / self.num_timesteps + return t + + +class GaussianDiffusion_style(object): + + def __init__(self, + betas, + mean_type='eps', + var_type='fixed_small', + loss_type='mse', + rescale_timesteps=False): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type + self.var_type = var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( + 1.0 - self.alphas_cumprod) + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). 
+ """ + noise = torch.randn_like(x0) if noise is None else noise + xt = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \ + _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise # noqa + return xt.type_as(x0) + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + dtype = xt.dtype + + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, + guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + t_mask = t.ne(0).float().view( + -1, + *((1, ) * # noqa + (xt.ndim - 1))) + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + t_mask * torch.exp(0.5 * log_var) * noise + return xt_1.type(dtype), x0.type(dtype) + + @torch.no_grad() + def p_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). 
+ """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). + """ + # predict distribution + if guide_scale is None: + out = model(xt, t=self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, t=self._scale_timesteps(t), **model_kwargs[0]) + if guide_scale != 1.0: + u_out = model( + xt, t=self._scale_timesteps(t), **model_kwargs[1]) + dim = y_out.size(1) if self.var_type.startswith( + 'fixed') else y_out.size(1) // 2 + out = torch.cat( + [ + u_out[:, :dim] + guide_scale * # noqa + (y_out[:, :dim] - u_out[:, :dim]), + y_out[:, dim:] + ], + dim=1) # noqa + else: + out = y_out + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i( + torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, + xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} 
+ x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \ + _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt # noqa + elif self.mean_type == 'x0': + x0 = out + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out # noqa + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1, 1) + # s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1) # old + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + + # recompute mu using the restricted x0 + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + t_prev, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
+ """ + dtype = xt.dtype + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, t_prev, xt) + sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * # noqa + (1 - alphas / alphas_prev)) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + t_mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt( + alphas_prev) * x0 + direction + t_mask * sigmas * noise + return xt_1.type(dtype), x0.type(dtype) + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + for i, step in enumerate(steps): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + t_prev = torch.full((b, ), + steps[i + 1] if i < len(steps) - 1 else 0, + dtype=torch.long, + device=xt.device) + xt, _ = self.ddim_sample(xt, t, t_prev, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return 
xt + + @torch.no_grad() + def ddim_reverse_sample(self, + xt, + t, + t_next, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + dtype = xt.dtype + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa + alphas_next = _i( + torch.cat( + [self.alphas_cumprod, + self.alphas_cumprod.new_zeros([1])]), t_next, xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu.type(dtype), x0.type(dtype) + + @torch.no_grad() + def ddim_reverse_sample_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1) + for i, step in enumerate(steps): + t = torch.full((b, ), + steps[i - 1] if i > 0 else 0, + dtype=torch.long, + device=xt.device) + t_next = torch.full((b, ), + step, + dtype=torch.long, + device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, t_next, model, + model_kwargs, clamp, percentile, + guide_scale, ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, + xt, + t, + t_prev, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
+ """ + + # function for compute eps + def compute_eps(xt, t): + dtype = xt.dtype + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa + return eps.type(dtype) + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + dtype = eps.dtype + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps # noqa + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, t_prev, xt) + direction = torch.sqrt(1 - alphas_prev) * eps + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1.type(dtype), x0.type(dtype) + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, t_prev) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] + - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, 
t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // plms_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for i, step in enumerate(steps): + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + t_prev = torch.full((b, ), + steps[i + 1] if i < len(steps) - 1 else 0, + dtype=torch.long, + device=xt.device) + xt, _, eps = self.plms_sample(xt, t, t_prev, model, model_kwargs, + clamp, percentile, condition_fn, + guide_scale, plms_timesteps, + eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + @torch.no_grad() + def dpm_solver_sample_loop(self, + noise, + model, + model_kwargs={}, + order=2, + skip_type='logSNR', + method='multistep', + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + dpm_solver_timesteps=20, + algorithm_type='dpmsolver++', + t_start=None, + t_end=None, + lower_order_final=True, + denoise_to_zero=False, + solver_type='dpmsolver'): + r"""Sample using DPM-Solver-based method. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + + Please check all the parameters in `dpm_solver.sample` before using. 
+ """ + assert self.mean_type in ('eps', 'x0') + assert percentile in (None, 0.995) + assert clamp is None or percentile is None + noise_schedule = NoiseScheduleVP( + schedule='discrete', betas=self.betas.float()) + model_fn = model_wrapper_guided_diffusion( + model=model, + noise_schedule=noise_schedule, + var_type=self.var_type, + mean_type=self.mean_type, + model_kwargs=model_kwargs, + rescale_timesteps=self.rescale_timesteps, + num_timesteps=self.num_timesteps, + guide_scale=guide_scale, + condition_fn=condition_fn) + dpm_solver = DPM_Solver( + model_fn=model_fn, + noise_schedule=noise_schedule, + algorithm_type=algorithm_type, + percentile=percentile, + clamp=clamp) + xt = dpm_solver.sample( + noise, + steps=dpm_solver_timesteps, + order=order, + skip_type=skip_type, + method=method, + solver_type=solver_type, + t_start=t_start, + t_end=t_end, + lower_order_final=lower_order_final, + denoise_to_zero=denoise_to_zero) + return xt + + @torch.no_grad() + def inpaint_p_sample(self, + xt, + t, + y, + mask, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""DDPM sampling step for inpainting. + """ + dtype = xt.dtype + + # predict distribution of p(x_{t-1} | x_t), conditioned on y and mask + xt = self.q_sample(y, t) * mask + xt * (1 - mask) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, + guide_scale) + + # random sample + t_mask = t.ne(0).float().view( + -1, + *((1, ) * # noqa + (xt.ndim - 1))) + xt_1 = mu + t_mask * torch.exp(0.5 * log_var) * torch.randn_like(xt) + return xt_1.type(dtype), x0.type(dtype) + + @torch.no_grad() + def inpaint_p_sample_loop(self, + noise, + y, + mask, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""DDPM sampling loop for inpainting. 
+ """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.inpaint_p_sample(xt, t, y, mask, model, model_kwargs, + clamp, percentile, guide_scale) + return xt + + @torch.no_grad() + def inpaint_mcg_p_sample(self, + xt, + t, + y, + mask, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + mcg_scale=1.0): + r"""DDPM sampling step for inpainting, with Manifold Constrained Gradient (MCG) correction. + """ + dtype = xt.dtype + + # predict distribution of p(x_{t-1} | x_t), conditioned on y and mask + with torch.enable_grad(): + xt.requires_grad_(True) + mu, var, log_var, x0 = self.p_mean_variance( + xt, t, model, model_kwargs, clamp, percentile, guide_scale) + loss = (y * mask - x0 * mask).square().mean() + grad = torch.autograd.grad(loss, xt)[0] + + # random sample + t_mask = t.ne(0).float().view( + -1, + *((1, ) * # noqa + (xt.ndim - 1))) + xt_1 = mu + t_mask * torch.exp(0.5 * log_var) * torch.randn_like(xt) + xt_1 = xt_1 - mcg_scale * grad + + # merge foreground and background + xt_1 = self.q_sample(y, t) * mask + xt_1 * (1 - mask) + return xt_1.type(dtype), x0.type(dtype) + + @torch.no_grad() + def inpaint_mcg_p_sample_loop(self, + noise, + y, + mask, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + mcg_scale=1.0): + r"""DDPM sampling loop for inpainting, with Manifold Constrained Gradient (MCG) correction. 
+ """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.inpaint_mcg_p_sample(xt, t, y, mask, model, + model_kwargs, clamp, percentile, + guide_scale, mcg_scale) + return xt + + def loss(self, + x0, + t, + model, + model_kwargs={}, + noise=None, + input_x0=None, + reduction='mean'): + assert reduction in ['mean', 'none'] + noise = torch.randn_like(x0) if noise is None else noise + input_x0 = x0 if input_x0 is None else input_x0 + xt = self.q_sample(input_x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, + model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: + out = model(xt, t=self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([ + out.detach(), var + ], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound( + x0, + xt, + t, + model=lambda *args, **kwargs: frozen, + reduction=reduction) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] + }[self.mean_type] + loss = ( + out + - target).pow(1 if self.loss_type.endswith('l1') else 2).abs() + if reduction == 'mean': + loss = loss.flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, + x0, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + reduction='mean'): + assert reduction in ['mean', 'none'] + + # compute groundtruth and 
predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) / math.log(2.0) + if reduction == 'mean': + kl = kl.flatten(1).mean(dim=1) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood( + x0, mean=mu2, log_scale=0.5 * log_var2) / math.log(2.0) + if reduction == 'mean': + nll = nll.flatten(1).mean(dim=1) + + # NLL for p(x0 | x1) and KL otherwise + t = t.view(-1, *(1, ) * (nll.ndim - 1)) + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. + """ + # prepare input and output + b = x0.size(0) + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + noise = torch.randn_like(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound( + x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) # noqa + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append( + (pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append( + (eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), + torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # 
update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t diff --git a/modelscope/models/multi_modal/videocomposer/dpm_solver.py b/modelscope/models/multi_modal/videocomposer/dpm_solver.py new file mode 100644 index 00000000..a1e15a2f --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/dpm_solver.py @@ -0,0 +1,1697 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math + +import torch +import torch.nn.functional as F + +__all__ = [ + 'NoiseScheduleVP', 'model_wrapper', 'model_wrapper_guided_diffusion', + 'DPM_Solver', 'interpolate_fn' +] + + +def _i(tensor, t, x): + r"""Index tensor using t and format the output according to x. + """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + device = x.device + return tensor[t.to(device)].view(shape).to(device) + + +class NoiseScheduleVP: + + def __init__(self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + dtype=torch.float32): + r"""Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise + linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, + especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), + which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. 
For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) + and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, + we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. + (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. + (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). + Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, + DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. + In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + 2. For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). + The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. 
The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. + 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} + # array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'" + .format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., + self.total_N + 1)[1:].reshape( + (1, -1)).to(dtype=dtype) + self.log_alpha_array = log_alphas.reshape(( + 1, + -1, + )).to(dtype=dtype) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan( + self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( + 1. 
+ self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log( + math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, + # T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn( + t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t**2 * (self.beta_1 + - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log( # noqa + torch.cos((s + self.cosine_s) / # noqa + (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp( + -2. 
* lamb, + torch.zeros((1, )).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / ( + self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp( + torch.zeros((1, )).to(lamb.device), -2. * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), + torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1, )) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, + torch.zeros((1, )).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos( # noqa + torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper(model, + noise_schedule, + model_type='noise', + model_kwargs={}, + guidance_type='uncond', + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}): + r"""Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. + For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model + that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], + and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. + "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition + Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. 
"score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, + and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). + And we use `model_fn` for DPM-Solver. 
+ + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. 
+ else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == 'noise': + return output + elif model_type == 'x_start': + alpha_t, sigma_t = noise_schedule.marginal_alpha( + t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - alpha_t * output) / sigma_t + elif model_type == 'v': + alpha_t, sigma_t = noise_schedule.marginal_alpha( + t_continuous), noise_schedule.marginal_std(t_continuous) + return alpha_t * output + sigma_t * x + elif model_type == 'score': + sigma_t = noise_schedule.marginal_std(t_continuous) + return -sigma_t * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, + **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == 'uncond': + return noise_pred_fn(x, t_continuous) + elif guidance_type == 'classifier': + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * sigma_t * cond_grad + elif guidance_type == 'classifier-free': + if guidance_scale == 1. 
or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn( + x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ['noise', 'x_start', 'v'] + assert guidance_type in ['uncond', 'classifier', 'classifier-free'] + return model_fn + + +def model_wrapper_guided_diffusion(model, + noise_schedule, + var_type, + mean_type, + model_kwargs={}, + rescale_timesteps=False, + num_timesteps=1000, + guide_scale=None, + condition_fn=None): + """Create a wrapper function for the noise prediction model guided diffusion. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. + For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support three types of the diffusion model by setting `mean_type`: + + 1. "x_{t-1}": previous data prediction model. + (Trained by predicting the data x_{t-1} at time t-1). + + 2. "x0": data prediction model. (Trained by predicting the data x0 at time 0). + + 3. "eps": noise prediction model. (Trained by predicting the noise). + + We support three types of guided sampling by DPMs: + 1. unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(xt, t_input, **model_kwargs) -> x_{t-1} | x0 | eps + `` + + 2. classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(xt, t_input, **model_kwargs) -> x_{t-1} | x0 | eps + `` + + The input `condition_fn` has the following format: + `` + condition_fn(xt, t_input, **model_kwargs) -> logits(xt, t_input) + `` + + [3] P. Dhariwal and A. Q. 
Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(xt, t_input, **model_kwargs) -> x_{t-1} | x0 | eps + `` + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, + and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred_fn(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). + And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + var_type: A `str`. The variance type of the diffusion model. + "learned" or "learned_range" or "fixed_large" or "fixed_small". + mean_type: A `str`. The prediction type of the diffusion model. + "x_{t-1}" or "x0" or "eps". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + clamp: A `float`. The range to clamp the data. + rescale_timesteps: A `bool`. Whether to rescale the timesteps. + num_timesteps: An `int`. The number of the total diffusion steps. + guide_scale: A `float`. The strength of the classifier-free guidance. + condition_fn: A function. The function of the classifier guidance. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. 
+ """ + + def _scale_timesteps(t): + if rescale_timesteps: + return t.float() * 1000.0 / num_timesteps + return t + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(xt, t_continuous, **model_kwargs): + if t_continuous.reshape((-1, )).shape[0] == 1: + t_continuous = t_continuous.expand((xt.shape[0])) + t_input = get_model_input_time(_scale_timesteps(t_continuous)) + # predict distribution + out = model(xt, t_input, **model_kwargs) + + if var_type == 'learned': + out, _ = out.chunk(2, dim=1) + elif var_type == 'learned_range': + out, _ = out.chunk(2, dim=1) + + if mean_type == 'eps': + eps = out + elif mean_type == 'x_{t-1}': + raise NotImplementedError + assert noise_schedule.schedule == 'discrete' + mu = out + posterior_mean_coef1 = None + posterior_mean_coef2 = None + x0 = expand_dims( + 1. / posterior_mean_coef1, dims) * mu - expand_dims( + posterior_mean_coef2 / posterior_mean_coef1, dims) * xt + eps = (xt - expand_dims(alpha_t, dims) * x0) / expand_dims( + sigma_t, dims) + elif mean_type == 'x0': + alpha_t, sigma_t = noise_schedule.marginal_alpha( + t_continuous), noise_schedule.marginal_std(t_continuous) + dims = out.dim() + eps = (xt - expand_dims(alpha_t, dims) * out) / expand_dims( + sigma_t, dims) + + if condition_fn is not None: + alpha_t, sigma_t = noise_schedule.marginal_alpha( + t_continuous), noise_schedule.marginal_std(t_continuous) + eps = eps - (1 - alpha_t).sqrt() * condition_fn( + xt, t_input, **model_kwargs) + + return eps + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. 
+ """ + if t_continuous.reshape((-1, )).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + + if guide_scale is None: + eps = noise_pred_fn(x, t_continuous, **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = noise_pred_fn(x, t_continuous, **model_kwargs[0]) + u_out = noise_pred_fn(x, t_continuous, **model_kwargs[1]) + dim = y_out.size(1) if var_type.startswith( + 'fixed') else y_out.size(1) // 2 + eps = u_out[:, :dim] + guide_scale * ( + y_out[:, :dim] - u_out[:, :dim]) + return eps + + return model_fn + + +class DPM_Solver: + + def __init__(self, + model_fn, + noise_schedule, + algorithm_type='dpmsolver++', + percentile=None, + thresholding_max_val=1., + clamp=None): + r"""Construct a DPM-Solver. + + We support both the noise prediction model ("predicting epsilon") + and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" + in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs + with large guidance scales. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and + `thresholding` are True. The max value for thresholding. 
+ + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, + Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. + Photorealistic text-to-image diffusion models with deep language understanding. + arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.algorithm_type = algorithm_type + self.percentile = percentile + self.thresholding_max_val = thresholding_max_val + self.clamp = clamp + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha( + t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims( + alpha_t, dims) + if self.percentile is not None: + s = torch.quantile( + torch.abs(x0).reshape((x0.shape[0], -1)), + self.percentile, + dim=1) + s = expand_dims( + torch.maximum( + s, + self.thresholding_max_val + * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + elif self.clamp is not None: + x0 = x0.clamp(-self.clamp, self.clamp) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == 'dpmsolver++': + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) 
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda( + torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda( + torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), + lambda_0.cpu().item(), + N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), + N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'" + .format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, + skip_type, t_T, t_0, + device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. 
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, + and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. 
+ """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [ + 3, + ] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [ + 3, + ] * (K - 1) + [1] + else: + orders = [ + 3, + ] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [ + 2, + ] * K + else: + K = steps // 2 + 1 + orders = [ + 2, + ] * (K - 1) + [1] + elif order == 1: + K = steps  # K must equal len(orders) so the 'logSNR' grid has one interval per step + orders = [ + 1, + ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, + device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, + device)[torch.cumsum( + torch.tensor([ + 0, + ] + orders), + 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve + the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, + x, + s, + t, + model_s=None, + return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == 'dpmsolver++': + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = (sigma_t / sigma_s * x - alpha_t * phi_1 * model_s) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x # noqa + - (sigma_t * phi_1) * model_s) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, + x, + s, + t, + r1=0.5, + model_s=None, + return_intermediate=False, + solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. + If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError( + "'solver_type' must be either 'dpmsolver' or 'taylor', got {}". 
+ format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std( + s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == 'dpmsolver++': + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ((sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ((sigma_t / sigma_s) * x - # noqa + (alpha_t * phi_1) * model_s # noqa + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)) + elif solver_type == 'taylor': + x_t = ((sigma_t / sigma_s) * x - # noqa + (alpha_t * phi_1) * model_s # noqa + + (1. / r1) * ( + alpha_t * # noqa + (phi_1 / h + 1.)) * (model_s1 - model_s)) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + torch.exp(log_alpha_s1 - log_alpha_s) * x # noqa + - (sigma_s1 * phi_11) * model_s) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpmsolver': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x # noqa + - (sigma_t * phi_1) * model_s - (0.5 / r1) # noqa + * (sigma_t * phi_1) * (model_s1 - model_s)) + elif solver_type == 'taylor': + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x # noqa + - (sigma_t * phi_1) * model_s - (1. / r1) # noqa + * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s)) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, + x, + s, + t, + r1=1. 
/ 3., + r2=2. / 3., + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type='dpmsolver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, + also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError( + "'solver_type' must be either 'dpmsolver' or 'taylor', got {}". + format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. 
+ ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff( + s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std( + s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp( + log_alpha_s2), torch.exp(log_alpha_t) + + if self.algorithm_type == 'dpmsolver++': + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ((sigma_s1 / sigma_s) * x # noqa + - (alpha_s1 * phi_11) * model_s) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ((sigma_s2 / sigma_s) * x - # noqa + (alpha_s2 * phi_12) * model_s + r2 / r1 # noqa + * (alpha_s2 * phi_22) * (model_s1 - model_s)) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ((sigma_t / sigma_s) * x - # noqa + (alpha_t * phi_1) * model_s # noqa + + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s)) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ((sigma_t / sigma_s) * x - # noqa + (alpha_t * phi_1) * model_s # noqa + + (alpha_t * phi_2) * D1 - (alpha_t * phi_3) * D2) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. 
+ phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ((torch.exp(log_alpha_s1 - log_alpha_s)) * x # noqa + - (sigma_s1 * phi_11) * model_s) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ((torch.exp(log_alpha_s2 - log_alpha_s)) * x # noqa + - (sigma_s2 * phi_12) * model_s - r2 / r1 # noqa + * (sigma_s2 * phi_22) * (model_s1 - model_s)) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpmsolver': + x_t = ((torch.exp(log_alpha_t - log_alpha_s)) * x # noqa + - (sigma_t * phi_1) * model_s - (1. / r2) # noqa + * (sigma_t * phi_2) * (model_s2 - model_s)) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ((torch.exp(log_alpha_t - log_alpha_s)) * x # noqa + - (sigma_t * phi_1) * model_s - # noqa + (sigma_t * phi_2) * D1 # noqa + - (sigma_t * phi_3) * D2) + + if return_intermediate: + return x_t, { + 'model_s': model_s, + 'model_s1': model_s1, + 'model_s2': model_s2 + } + else: + return x_t + + def multistep_dpm_solver_second_update(self, + x, + model_prev_list, + t_prev_list, + t, + solver_type='dpmsolver'): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpmsolver', 'taylor']: + raise ValueError( + "'solver_type' must be either 'dpmsolver' or 'taylor', got {}". 
+ format(solver_type)) + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda( + t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff( + t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == 'dpmsolver++': + phi_1 = torch.expm1(-h) + if solver_type == 'dpmsolver': + x_t = ((sigma_t / sigma_prev_0) * x # noqa + - (alpha_t * phi_1) * model_prev_0 - 0.5 # noqa + * (alpha_t * phi_1) * D1_0) + elif solver_type == 'taylor': + x_t = ((sigma_t / sigma_prev_0) * x - # noqa + (alpha_t * phi_1) * model_prev_0 + # noqa + (alpha_t * (phi_1 / h + 1.)) * D1_0) + else: + phi_1 = torch.expm1(h) + if solver_type == 'dpmsolver': + x_t = ((torch.exp(log_alpha_t - log_alpha_prev_0)) * x # noqa + - (sigma_t * phi_1) * model_prev_0 - 0.5 * # noqa + (sigma_t * phi_1) * D1_0) + elif solver_type == 'taylor': + x_t = ((torch.exp(log_alpha_t - log_alpha_prev_0)) * x # noqa + - (sigma_t * phi_1) * model_prev_0 # noqa + - (sigma_t * (phi_1 / h - 1.)) * D1_0) + return x_t + + def multistep_dpm_solver_third_update(self, + x, + model_prev_list, + t_prev_list, + t, + solver_type='dpmsolver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. 
The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda( + t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda( + t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff( + t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1. / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == 'dpmsolver++': + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + x_t = ((sigma_t / sigma_prev_0) * x - # noqa + (alpha_t * phi_1) * model_prev_0 + # noqa + (alpha_t * phi_2) * D1 # noqa + - (alpha_t * phi_3) * D2) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + x_t = ((torch.exp(log_alpha_t - log_alpha_prev_0)) * x # noqa + - (sigma_t * phi_1) * model_prev_0 - # noqa + (sigma_t * phi_2) * D1 # noqa + - (sigma_t * phi_3) * D2) + return x_t + + def singlestep_dpm_solver_update(self, + x, + s, + t, + order, + return_intermediate=False, + solver_type='dpmsolver', + r1=None, + r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). 
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, + `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update( + x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update( + x, + s, + t, + return_intermediate=return_intermediate, + solver_type=solver_type, + r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update( + x, + s, + t, + return_intermediate=return_intermediate, + solver_type=solver_type, + r1=r1, + r2=r2) + else: + raise ValueError( + 'Solver order must be 1 or 2 or 3, got {}'.format(order)) + + def multistep_dpm_solver_update(self, + x, + model_prev_list, + t_prev_list, + t, + order, + solver_type='dpmsolver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + if order == 1: + return self.dpm_solver_first_update( + x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update( + x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update( + x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError( + 'Solver order must be 1 or 2 or 3, got {}'.format(order)) + + def dpm_solver_adaptive(self, + x, + order, + t_T, + t_0, + h_init=0.05, + atol=0.0078, + rtol=0.05, + theta=0.9, + t_err=1e-5, + solver_type='dpmsolver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. + For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. + The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. + The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. + We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, + "Gotta go fast when generating data with score-based models," + arXiv preprint arXiv:2105.14080, 2021. 
+ """ + ns = self.noise_schedule + s = t_T * torch.ones((1, )).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update( # noqa + x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update( # noqa + x, s, t, r1=r1, solver_type=solver_type, **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update( # noqa + x, + s, + t, + r1=r1, + return_intermediate=True, + solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update( # noqa + x, + s, + t, + r1=r1, + r2=r2, + solver_type=solver_type, + **kwargs) + else: + raise ValueError( + 'For adaptive step size solver, order must be 2 or 3, got {}'. + format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max( + torch.ones_like(x).to(x) * atol, + rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt( # noqa + torch.square(v.reshape( + (v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min( + theta * h * torch.float_power(E, -1. / order).float(), + lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. 
+ Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha( + t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, + x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse(self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type='time_uniform', + method='multistep', + lower_order_final=True, + denoise_to_zero=False, + solver_type='dpmsolver', + atol=0.0078, + rtol=0.05, + return_intermediate=False): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert t_0 > 0 and t_T > 0, 'Time range needs to be greater than 0. For discrete-time DPMs, \ + it needs to be in [1 / N, 1], where N is the length of betas array' + + # NOTE: `return_intermediate` is kept for API compatibility but not forwarded; `sample` does not accept it. + return self.sample( + x, + steps=steps, + t_start=t_0, + t_end=t_T, + order=order, + skip_type=skip_type, + method=method, + lower_order_final=lower_order_final, + denoise_to_zero=denoise_to_zero, + solver_type=solver_type, + atol=atol, + rtol=rtol) + + def sample(self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type='time_uniform', + method='multistep', + lower_order_final=True, + denoise_to_zero=False, + solver_type='dpmsolver', + atol=0.0078, + rtol=0.05): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. 
+ + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), + which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= + `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, + and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) + steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. + The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, + then (K - 2) step of multistep DPM-Solver-3. 
+ - 'singlestep_fixed': + Fixed order singlestep DPM-Solver + (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, + with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance + `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 + and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep + DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. 
+ + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. + 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. + 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. 
The taylor expansion type for the solver. + `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert t_0 > 0 and t_T > 0, 'Time range needs to be greater than 0. For discrete-time DPMs, \ + it needs to be in [1 / N, 1], where N is the length of betas array' + + device = x.device + with torch.no_grad(): + if method == 'adaptive': + x = self.dpm_solver_adaptive( + x, + order=order, + t_T=t_T, + t_0=t_0, + atol=atol, + rtol=rtol, + solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps( + skip_type=skip_type, + t_T=t_T, + t_0=t_0, + N=steps, + device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update( + x, + model_prev_list, + t_prev_list, + t, + step, + solver_type=solver_type) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. 
def interpolate_fn(x, xp, yp):
    """
    Evaluate a piecewise-linear function y = f(x) defined by keypoints (xp, yp).

    The computation only uses sort / gather / where and arithmetic, so it stays
    differentiable with respect to `x`. Outside the range covered by `xp` the
    outermost segment is extended linearly.

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size and
            C is the number of channels (C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]
    dev = x.device

    # Put every query next to its channel's keypoints (query at index 0) and
    # sort; the query's position in the sorted row tells which segment it is in.
    merged = torch.cat(
        (x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))), dim=2)
    merged_sorted, order = torch.sort(merged, dim=2)
    rank = torch.argmin(order, dim=2)  # where original index 0 (the query) landed
    below = rank - 1
    at_left = torch.eq(rank, 0)
    at_right = torch.eq(rank, num_keys)

    def _segment_start(left_value):
        # `below` for interior queries, clamped at either end of the row.
        inner = torch.where(at_right,
                            torch.tensor(num_keys - 2, device=dev), below)
        return torch.where(at_left, torch.tensor(left_value, device=dev),
                           inner)

    # Segment end points, indexed inside the sorted row that still contains
    # the query. For interior queries the query sits between both end points,
    # hence the "+ 2"; at the borders the two end points are adjacent.
    lo_sorted = _segment_start(1)
    hi_sorted = torch.where(
        torch.eq(lo_sorted, below), lo_sorted + 2, lo_sorted + 1)
    x_lo = torch.gather(
        merged_sorted, dim=2, index=lo_sorted.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(
        merged_sorted, dim=2, index=hi_sorted.unsqueeze(2)).squeeze(2)

    # The same segment, indexed inside the keypoint arrays (no query present).
    lo_key = _segment_start(0)
    yp_batched = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(
        yp_batched, dim=2, index=lo_key.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(
        yp_batched, dim=2, index=(lo_key + 1).unsqueeze(2)).squeeze(2)

    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
class FlashAttentionBlock(nn.Module):
    r"""Residual attention block over a 2D feature map with optional context.

    The feature map supplies queries, keys and values; when ``context`` is
    given, its projected tokens are prepended to the keys and values. The
    output projection is zero-initialised, so a freshly built block returns
    its input plus the projection bias.

    Args:
        dim: number of channels of the feature map (GroupNorm uses 32 groups).
        context_dim: channel size of the context tokens, or None to disable
            the context projection.
        num_heads: number of heads; only used when ``head_dim`` is not given.
        head_dim: channels per head; takes precedence over ``num_heads``.
        batch_size: unused, kept for backward compatibility.
    """

    def __init__(self,
                 dim,
                 context_dim=None,
                 num_heads=None,
                 head_dim=None,
                 batch_size=4):
        if not head_dim and not num_heads:
            raise ValueError('either head_dim or num_heads must be given')
        # consider head_dim first, then num_heads
        num_heads = dim // head_dim if head_dim else num_heads
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(FlashAttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # The fused kernel is only created for head sizes that are a multiple
        # of 8 and at most 128; other sizes use the plain attention path in
        # forward() instead of failing on a missing attribute.
        if self.head_dim <= 128 and (self.head_dim % 8) == 0:
            self.flash_attn = FlashAttention(
                softmax_scale=None, attention_dropout=0.0)
        else:
            self.flash_attn = None

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def _init_weight(self, module):
        """Initialise Linear/Conv2d weights with N(0, 0.15) and zero biases."""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=0.15)
            if module.bias is not None:
                module.bias.data.zero_()

    def _flash_attention(self, q, k, v):
        """Run the fused kernel; q, k, v are [B, N, D, L*] and the result is [B, Lq, N, D]."""
        b, n, d, num_q = q.size()
        # The packed qkv layout needs one common sequence length, so queries
        # are zero-padded by the number of context tokens prepended to k / v.
        pad = k.size(-1) - num_q
        if pad > 0:
            q = torch.cat([q, q.new_zeros(b, n, d, pad)], dim=-1)
        origin_dtype = q.dtype
        qkv = torch.cat([q, k, v], dim=1)
        qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n,
                                              d).half().contiguous()
        out, _ = self.flash_attn(qkv)
        # Drop the padded query rows and restore the caller's dtype; `.to`
        # is not in-place, so its result has to be kept.
        return out[:, :num_q].to(origin_dtype)

    def _math_attention(self, q, k, v):
        """Plain softmax attention with the same input/output layout as `_flash_attention`."""
        # Scaled by 1/sqrt(head_dim).
        # NOTE(review): presumably what FlashAttention uses for
        # softmax_scale=None - confirm against the installed flash-attn.
        scores = torch.matmul(q.transpose(-1, -2), k) / math.sqrt(q.size(2))
        probs = F.softmax(scores.float(), dim=-1).type_as(scores)
        out = torch.matmul(probs, v.transpose(-1, -2))  # [B, N, Lq, D]
        return out.permute(0, 2, 1, 3)

    def forward(self, x, context=None):
        r"""x:       [B, C, H, W].
            context: [B, L, C] or None.
        """
        identity = x
        b, c, h, w = x.size()
        n, d = self.num_heads, self.head_dim

        # compute query, key, value
        x = self.norm(x)
        q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
        if context is not None:
            ck, cv = self.context_kv(context).reshape(
                b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(
                    2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)

        if self.flash_attn is not None:
            out = self._flash_attention(q, k, v)
        else:
            out = self._math_attention(q, k, v)
        out = out.permute(0, 2, 3, 1).reshape(b, c, h, w)

        # output
        x = self.proj(out)
        return x + identity
def DOWNLOAD_TO_CACHE(oss_key,
                      file_or_dirname=None,
                      cache_dir=osp.join(
                          osp.dirname(osp.dirname(osp.abspath(__file__))),
                          'model_weights')):
    r"""Return the local cache path that corresponds to an OSS key.

    Despite its name this function performs no download and no process
    synchronisation: it only joins ``cache_dir`` with the target name.

    Args:
        oss_key: object key; its basename is the default target name.
        file_or_dirname: explicit target name; replaces the key's basename
            when given.
        cache_dir: cache root. Defaults to ``model_weights`` inside the
            parent of the directory that contains this file.
    Returns:
        The joined path as a string. The path is not checked for existence.
    """
    # source and target paths
    return osp.join(cache_dir, file_or_dirname or osp.basename(oss_key))
class VisionTransformer(nn.Module):
    r"""ViT image encoder: patch embedding, class token, transformer, linear head.

    forward() maps images [B, 3, H, W] to features [B, out_dim] read from the
    class-token position. The positional embedding has a fixed length of
    ``(image_size // patch_size) ** 2 + 1``.
    """

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0):
        assert image_size % patch_size == 0
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.dim = dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        grid = image_size // patch_size
        self.num_patches = grid**2

        # embeddings (random tensors are scaled by 1 / sqrt(dim))
        init_scale = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(
            3, dim, kernel_size=patch_size, stride=patch_size, bias=False)
        self.cls_embedding = nn.Parameter(init_scale * torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(
            init_scale * torch.randn(1, self.num_patches + 1, dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim)
        layers = [
            AttentionBlock(dim, num_heads, attn_dropout, proj_dropout)
            for _ in range(num_layers)
        ]
        self.transformer = nn.Sequential(*layers)
        self.post_norm = LayerNorm(dim)

        # head
        self.head = nn.Parameter(init_scale * torch.randn(dim, out_dim))

    def forward(self, x):
        # Compute in the dtype of the head so the fp16() variant works too.
        dtype = self.head.dtype
        batch = x.size(0)
        x = x.type(dtype)

        # patch-embedding: [b, c, h', w'] -> [b, n, c], class token first
        tokens = self.patch_embedding(x).flatten(2).transpose(1, 2)
        cls_token = self.cls_embedding.repeat(batch, 1, 1).type(dtype)
        tokens = torch.cat([cls_token, tokens], dim=1)
        tokens = tokens + self.pos_embedding.type(dtype)
        tokens = self.pre_norm(self.dropout(tokens))

        # transformer
        tokens = self.transformer(tokens)

        # head: project the class-token feature
        tokens = self.post_norm(tokens)
        return torch.mm(tokens[:, 0, :], self.head)

    def fp16(self):
        """Convert weights in place through `to_fp16` and return self."""
        return self.apply(to_fp16)
class CLIP(nn.Module):
    r"""Contrastive image-text model built from a vision and a text transformer.

    Both towers project into a shared ``embed_dim`` space. forward() returns
    the scaled similarity logits in both directions together with the
    matching labels for this rank.
    """

    def __init__(self,
                 embed_dim=512,
                 image_size=224,
                 patch_size=16,
                 vision_dim=768,
                 vision_heads=12,
                 vision_layers=12,
                 vocab_size=49408,
                 text_len=77,
                 text_dim=512,
                 text_heads=8,
                 text_layers=12,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0):
        super(CLIP, self).__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vocab_size = vocab_size
        self.text_len = text_len
        self.text_dim = text_dim
        self.text_heads = text_heads
        self.text_layers = text_layers

        # models
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout)
        self.textual = TextTransformer(
            vocab_size=vocab_size,
            text_len=text_len,
            dim=text_dim,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout)
        # learnable log-temperature, initialised to log(1 / 0.07)
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_tokens):
        r"""imgs:       [B, C, H, W] of torch.float32.
            txt_tokens: [B, T] of torch.long.

        Returns ``(logits_i2t, logits_t2i, labels)``; ``labels`` are the
        indices ``[B * rank, B * (rank + 1))`` into the gathered features.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_tokens)

        # normalize features
        xi = F.normalize(xi, p=2, dim=1)
        xt = F.normalize(xt, p=2, dim=1)

        # gather features from all ranks
        # NOTE(review): presumably a differentiable all-gather - see ops.
        full_xi = ops.diff_all_gather(xi)
        full_xt = ops.diff_all_gather(xt)

        # logits
        scale = self.log_scale.exp()
        logits_i2t = scale * torch.mm(xi, full_xt.t())
        logits_t2i = scale * torch.mm(xt, full_xi.t())

        # labels
        labels = torch.arange(
            len(xi) * ops.get_rank(),
            len(xi) * (ops.get_rank() + 1),
            dtype=torch.long,
            device=xi.device)
        return logits_i2t, logits_t2i, labels

    def init_weights(self):
        """Re-initialise embedding, attention and MLP weights in place."""
        # embeddings
        nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
        nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)

        # attentions
        for modality in ['visual', 'textual']:
            dim = self.vision_dim if modality == 'visual' else self.text_dim
            # `transformer` is an nn.Sequential (visual) or nn.ModuleList
            # (textual) of attention blocks: use len() and direct iteration,
            # neither container has `num_layers` or `layers` attributes.
            transformer = getattr(self, modality).transformer
            proj_gain = (1.0 / math.sqrt(dim)) * (
                1.0 / math.sqrt(2 * len(transformer)))
            attn_gain = 1.0 / math.sqrt(dim)
            mlp_gain = 1.0 / math.sqrt(2.0 * dim)
            for block in transformer:
                nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
                nn.init.normal_(block.attn.proj.weight, std=proj_gain)
                nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
                nn.init.normal_(block.mlp[2].weight, std=proj_gain)

    def param_groups(self):
        """Split parameters into a no-weight-decay group (norms, biases) and the rest."""
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay':
            0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups

    def fp16(self):
        """Convert weights in place through `to_fp16` and return self."""
        return self.apply(to_fp16)
def clip_vit_b_16(pretrained=False, **kwargs):
    """Build CLIP ViT-B/16; keyword arguments override the default config.

    With ``pretrained=True`` the 'openai-clip-vit-base-16' weights are loaded
    through `_clip`.
    """
    cfg = dict(
        embed_dim=512,
        image_size=224,
        # ViT-B/16 uses 16x16 patches; 32 was a copy of the B/32 config and
        # gives patch/positional embedding shapes of the wrong model.
        patch_size=16,
        vision_dim=768,
        vision_heads=12,
        vision_layers=12,
        vocab_size=49408,
        text_len=77,
        text_dim=512,
        text_heads=8,
        text_layers=12)
    cfg.update(**kwargs)
    return _clip('openai-clip-vit-base-16', pretrained, **cfg)
class SelfAttention(nn.Module):
    r"""Multi-head self-attention over a token sequence [B, L, C]."""

    def __init__(self, dim, num_heads):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        # layers
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        batch, length, channels = x.size()
        heads, head_dim = self.num_heads, self.head_dim

        # compute query, key, value: the first / middle / last third of the
        # projected heads, each moved to [B, N, L, D]
        packed = self.to_qkv(x).view(batch, length, heads * 3, head_dim)
        query, key, value = (
            part.permute(0, 2, 1, 3) for part in packed.chunk(3, dim=2))

        # compute attention (softmax is evaluated in float32)
        scores = self.scale * torch.matmul(query, key.transpose(-1, -2))
        weights = F.softmax(scores.float(), dim=-1).type_as(scores)

        # gather context and merge the heads back into the channel axis
        context = torch.matmul(weights, value)
        context = context.permute(0, 2, 1, 3).reshape(batch, length, channels)

        # output
        return self.proj(context)
class MiDaS(nn.Module):
    r"""MiDaS v3.0 DPT-Large from ``https://github.com/isl-org/MiDaS''.
        Monocular depth estimation using dense prediction transformers.

    Args:
        image_size: side length the positional embedding is defined for.
        patch_size: side length of one patch; input height and width must be
            multiples of it.
        dim: transformer width.
        neck_dims: four channel sizes, one per read-out stage; defaults to
            ``[256, 512, 1024, 1024]``.
        fusion_dim: channel size of the fusion blocks.
        num_heads: attention heads per block.
        num_layers: number of transformer blocks, a multiple of 4.
    """

    def __init__(self,
                 image_size=384,
                 patch_size=16,
                 dim=1024,
                 neck_dims=None,
                 fusion_dim=256,
                 num_heads=16,
                 num_layers=24):
        assert image_size % patch_size == 0
        assert num_layers % 4 == 0
        # Build a fresh list per instance: a list default in the signature
        # would be one object shared by every model that stores it.
        if neck_dims is None:
            neck_dims = [256, 512, 1024, 1024]
        else:
            neck_dims = list(neck_dims)
        assert len(neck_dims) == 4, 'neck_dims needs one entry per stage'
        super(MiDaS, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.dim = dim
        self.neck_dims = neck_dims
        self.fusion_dim = fusion_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_patches = (image_size // patch_size)**2

        # embeddings
        self.patch_embedding = nn.Conv2d(
            3, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_embedding = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embedding = nn.Parameter(
            torch.empty(1, self.num_patches + 1, dim).normal_(std=0.02))

        # blocks, split into four equally sized groups (one per stage)
        stride = num_layers // 4
        self.blocks = nn.Sequential(
            *[AttentionBlock(dim, num_heads) for _ in range(num_layers)])
        self.slices = [slice(i * stride, (i + 1) * stride) for i in range(4)]

        # stage1 (4x)
        self.fc1 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim, neck_dims[0], 1),
            nn.ConvTranspose2d(neck_dims[0], neck_dims[0], 4, stride=4),
            nn.Conv2d(neck_dims[0], fusion_dim, 3, padding=1, bias=False))
        self.fusion1 = FusionBlock(fusion_dim)

        # stage2 (8x)
        self.fc2 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(dim, neck_dims[1], 1),
            nn.ConvTranspose2d(neck_dims[1], neck_dims[1], 2, stride=2),
            nn.Conv2d(neck_dims[1], fusion_dim, 3, padding=1, bias=False))
        self.fusion2 = FusionBlock(fusion_dim)

        # stage3 (16x)
        self.fc3 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
        self.conv3 = nn.Sequential(
            nn.Conv2d(dim, neck_dims[2], 1),
            nn.Conv2d(neck_dims[2], fusion_dim, 3, padding=1, bias=False))
        self.fusion3 = FusionBlock(fusion_dim)

        # stage4 (32x)
        self.fc4 = nn.Sequential(nn.Linear(dim * 2, dim), nn.GELU())
        self.conv4 = nn.Sequential(
            nn.Conv2d(dim, neck_dims[3], 1),
            nn.Conv2d(neck_dims[3], neck_dims[3], 3, stride=2, padding=1),
            nn.Conv2d(neck_dims[3], fusion_dim, 3, padding=1, bias=False))
        self.fusion4 = FusionBlock(fusion_dim)

        # head
        self.head = nn.Sequential(
            nn.Conv2d(fusion_dim, fusion_dim // 2, 3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(fusion_dim // 2, 32, 3, padding=1),
            nn.ReLU(inplace=True), nn.ConvTranspose2d(32, 1, 1),
            nn.ReLU(inplace=True))

    def _resized_pos_embedding(self, hp, wp):
        """Bilinearly resize the patch positional embedding to an hp x wp grid."""
        grid = self.image_size // self.patch_size
        patch_pos = self.pos_embedding[:, 1:].reshape(1, grid, grid,
                                                      -1).permute(0, 3, 1, 2)
        patch_pos = F.interpolate(
            patch_pos, size=(hp, wp), mode='bilinear', align_corners=False)
        patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, hp * wp, -1)
        # the class-token position is kept as is
        return torch.cat([self.pos_embedding[:, :1], patch_pos], dim=1)

    @staticmethod
    def _readout(tokens, fc, conv, hp, wp):
        """Fuse the class token into every patch token and map to a 2D feature map."""
        patches, cls_token = tokens[:, 1:], tokens[:, :1]
        fused = torch.cat([patches, cls_token.expand_as(patches)], dim=-1)
        fused = fc(fused).permute(0, 2, 1).unflatten(2, (hp, wp))
        return conv(fused)

    def forward(self, x):
        """Map images [B, 3, H, W] (H, W multiples of patch_size) to the head output."""
        b, _, h, w, p = *x.size(), self.patch_size
        assert h % p == 0 and w % p == 0, f'Image size ({w}, {h}) is not divisible by patch size ({p}, {p})'
        hp, wp = h // p, w // p

        # embeddings
        pos_embedding = self._resized_pos_embedding(hp, wp)
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        x = torch.cat([self.cls_embedding.repeat(b, 1, 1), x], dim=1)
        x = x + pos_embedding

        # stage1..stage4: run one group of blocks, then read out a feature map
        feats = []
        for idx in range(4):
            x = self.blocks[self.slices[idx]](x)
            fc = getattr(self, f'fc{idx + 1}')
            conv = getattr(self, f'conv{idx + 1}')
            feats.append(self._readout(x, fc, conv, hp, wp))
        x1, x2, x3, x4 = feats

        # fusion, from the last stage back to the first
        x4 = self.fusion4(x4)
        x3 = self.fusion3(x4, x3)
        x2 = self.fusion2(x3, x2)
        x1 = self.fusion1(x2, x1)

        # head
        x = self.head(x1)
        return x
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr

    Args:
        img: RGB image with channels on the last axis, either
            uint8 in [0, 255] or float in [0, 1].
        only_y: only return Y channel

    Returns:
        Array with the same dtype as the input. The input array itself is
        never modified.
    '''
    in_img_type = img.dtype
    # astype returns a new array; keep it so the scaling below works on a
    # private float32 copy instead of the caller's buffer.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img,
                        [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                         [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb

    Args:
        img: YCbCr image with channels on the last axis, either
            uint8 in [0, 255] or float in [0, 1].

    Returns:
        Array with the same dtype as the input. The input array itself is
        never modified.
    '''
    in_img_type = img.dtype
    # work on a float32 copy (see rgb2ycbcr) so the caller's data is untouched
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
                          [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [
                              -222.921, 135.576, -276.836
                          ]  # noqa
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr

    Args:
        img: BGR image with channels on the last axis, either
            uint8 in [0, 255] or float in [0, 1].
        only_y: only return Y channel

    Returns:
        Array with the same dtype as the input. The input array itself is
        never modified.
    '''
    in_img_type = img.dtype
    # work on a float32 copy so the caller's data is untouched
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img,
                        [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                         [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    """Evaluate the piecewise cubic interpolation kernel element-wise.

    The kernel uses one polynomial for ``|x| <= 1``, another for
    ``1 < |x| <= 2`` and is zero beyond that.

    Args:
        x: tensor of signed distances.

    Returns:
        Tensor of kernel weights with the same shape as ``x``.
    """
    dist = torch.abs(x)
    dist_sq = dist**2
    dist_cu = dist**3

    # 0/1 masks selecting which polynomial piece applies to each element
    near_mask = (dist <= 1).type_as(dist)
    far_mask = ((dist > 1) * (dist <= 2)).type_as(dist)

    near_poly = 1.5 * dist_cu - 2.5 * dist_sq + 1
    far_poly = -0.5 * dist_cu + 2.5 * dist_sq - 4 * dist + 2
    return near_poly * near_mask + far_poly * far_mask
# imresize for tensor image [0, 1]
def imresize(img, scale, antialiasing=True):
    """Resize a tensor image with the bicubic weights from
    ``calculate_weights_indices``.

    Args:
        img: pytorch tensor, CHW or HW, values in [0, 1].
        scale: resize factor; the same factor is used for H and W.
        antialiasing: forwarded to ``calculate_weights_indices``.

    Returns:
        A new tensor, CHW or HW (matching the input rank), not rounded.
        The input tensor is left unchanged.
    """
    need_squeeze = img.dim() == 2
    if need_squeeze:
        # out-of-place: the in-place unsqueeze_ used previously permanently
        # turned the caller's HW tensor into a 1xHxW tensor
        img = img.unsqueeze(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W
                                                                   * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying: extend the image with mirrored rows on both sides
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(
                0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying: extend the intermediate result with mirrored columns
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :,
                                       idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        # out_2 is freshly allocated here, so squeezing it in place is safe
        out_2.squeeze_()
    return out_2
def modcrop_np(img, sf):
    '''Crop the two leading axes of an array to multiples of ``sf``.

    Args:
        img: numpy image with at least two axes; trailing axes are kept
        sf: scale factor
    Return:
        cropped copy of the image (the input is not modified)
    '''
    size0, size1 = img.shape[:2]
    keep0 = size0 - size0 % sf
    keep1 = size1 - size1 % sf
    # copy first so the returned crop does not alias the caller's array
    return np.copy(img)[:keep0, :keep1, ...]
+ + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot( + np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = stats.multivariate_normal.pdf([cx, cy], + mean=mean, + cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = 
def blur(x, k):
    '''Blur each image in a batch with its own kernel.

    Args:
        x: image, NxcxHxW
        k: kernel, Nx1xhxw

    Returns:
        Blurred tensor, Nxcx(H')x(W'); for odd kernel sizes H'=H and W'=W.
    '''
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    # F.pad orders the padding as (left, right, top, bottom), i.e. the last
    # (width) axis first, so the width padding p2 must come before the
    # height padding p1. Mixing them only works for square kernels.
    x = torch.nn.functional.pad(x, pad=(p2, p2, p1, p1), mode='replicate')
    # one kernel per (image, channel) pair, applied via a grouped convolution
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(
        x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])

    return x
def fspecial_gaussian(hsize, sigma):
    """Build a normalized, square 2-D Gaussian kernel.

    Args:
        hsize: side length of the kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        ``hsize`` x ``hsize`` numpy array; it sums to 1 unless every entry
        was truncated to zero.
    """
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(
        np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # np.finfo instead of the deprecated scipy.finfo alias; entries that are
    # negligible relative to the peak are zeroed
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    '''
    # ndimage.convolve is the public entry point; the ``ndimage.filters``
    # sub-namespace is deprecated. The kernel gets a trailing singleton axis
    # so every channel is filtered with the same 2-D kernel.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    st = 0
    # keep every sf-th pixel, starting at offset ``st``
    return x[st::sf, st::sf, ...]
def add_blur_1(img, sf=4):
    """Blur an image with a randomly drawn Gaussian kernel.

    With probability 0.5 an anisotropic Gaussian kernel is used, otherwise an
    isotropic one from ``fspecial``; the kernel widths scale with ``sf``.

    Args:
        img: image array with channels on the third axis (the kernel is
            expanded with a trailing singleton axis before convolving).
        sf: scale factor controlling the upper bound of the blur widths.

    Returns:
        The blurred image.
    """
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf

    wd2 = wd2 / 4
    wd = wd / 4

    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(
            ksize=random.randint(2, 11) + 3,
            theta=random.random() * np.pi,
            l1=l1,
            l2=l2)
    else:
        k = fspecial('gaussian',
                     random.randint(2, 4) + 3, wd * random.random())
    # ndimage.convolve replaces the deprecated ``ndimage.filters`` namespace
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add multiplicative (speckle) noise and clip the result to [0, 1].

    One of three noise models is picked at random: independent per-element
    noise, one noise value shared by all channels of a pixel, or noise drawn
    from a random 3x3 covariance.

    Args:
        img: image array; clipped to [0, 1] before the noise is applied.
        noise_level1: lower bound (inclusive) of the random noise level.
        noise_level2: upper bound (inclusive) of the random noise level; also
            scales the covariance of the third noise model.

    Returns:
        The noisy image, clipped to [0, 1].
    """
    sigma = random.randint(noise_level1, noise_level2) / 255.0
    img = np.clip(img, 0.0, 1.0)
    selector = random.random()
    if selector > 0.6:
        # independent noise for every element
        noise = np.random.normal(0, sigma, img.shape)
    elif selector < 0.4:
        # a single noise value per pixel, broadcast over the channels
        noise = np.random.normal(0, sigma, (*img.shape[:2], 1))
    else:
        # channel-correlated noise from a random covariance matrix
        scale = noise_level2 / 255.
        eigvals = np.diag(np.random.rand(3))
        basis = orth(np.random.rand(3, 3))
        cov = np.dot(np.dot(np.transpose(basis), eigvals), basis)
        noise = np.random.multivariate_normal([0, 0, 0],
                                              np.abs(scale**2 * cov),
                                              img.shape[:2])
    img += img * noise.astype(np.float32)
    return np.clip(img, 0.0, 1.0)
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Cut a random low-quality patch and the matching high-quality patch.

    Args:
        lq: low-quality image, indexed as [row, col, channel].
        hq: high-quality image, ``sf`` times larger than ``lq``.
        sf: scale factor between ``lq`` and ``hq``.
        lq_patchsize: side length of the low-quality patch.

    Returns:
        Tuple ``(lq_patch, hq_patch)``; the high-quality patch has side
        length ``lq_patchsize * sf``.
    """
    rows, cols = lq.shape[:2]
    top = random.randint(0, rows - lq_patchsize)
    left = random.randint(0, cols - lq_patchsize)
    lq_patch = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]

    # same window, scaled up to high-quality coordinates
    hq_top = int(top * sf)
    hq_left = int(left * sf)
    hq_side = lq_patchsize * sf
    hq_patch = hq[hq_top:hq_top + hq_side, hq_left:hq_left + hq_side, :]
    return lq_patch, hq_patch
+ h, w = image.shape[:2] + + if sf == 4 and random.random() < scale2_prob: + if np.random.rand() < 0.5: + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[ + idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur_1(image, sf=sf) + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize( + image, (int(1 / sf1 * image.shape[1]), + int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() + image = ndimage.filters.convolve( + image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] 
def degradation_bsrgan(image, sf=4, isp_model=None):
    """
    This is the variant of the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"

    The image is converted with ``uint2single``, cropped so both spatial
    sizes are multiples of ``sf``, passed through a randomly ordered sequence
    of blur / downsample / noise / JPEG steps and finally JPEG-compressed.

    Args:
        image: input image, indexed as [row, col, channel].
        sf: scale factor
        isp_model: camera ISP model (not used by this function)
    Returns:
        the degraded image after ``single2uint``
    """
    image = uint2single(image)
    jpeg_prob, scale2_prob = 0.9, 0.25

    h1, w1 = image.shape[:2]
    # mod crop: rows are trimmed by the height remainder and columns by the
    # width remainder (the two were swapped before, which mis-cropped
    # non-square images)
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]

    if sf == 4 and random.random() < scale2_prob:
        if np.random.rand() < 0.5:
            image = cv2.resize(
                image,
                (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                interpolation=random.choice([1, 2, 3]))
        else:
            image = imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    # step 3 reads the size recorded by step 2, so step 2 must run first
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[
            idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            image = add_blur_2(image, sf=sf)
        elif i == 1:
            image = add_blur_2(image, sf=sf)
        elif i == 2:
            # remember the pre-downsample size (width, height) for step 3
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(
                    image, (int(1 / sf1 * image.shape[1]),
                            int(1 / sf1 * image.shape[0])),
                    interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()
                image = ndimage.filters.convolve(
                    image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]
            image = np.clip(image, 0.0, 1.0)
        elif i == 3:
            # downsample3
            image = cv2.resize(
                image, (int(1 / sf * a), int(1 / sf * b)),
                interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)
        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)
        # i == 6 has no associated step

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = single2uint(image)
    return image
def gather(tensor, dst=0, group=None, **kwargs):
    r"""Collect ``tensor`` from every rank onto rank ``dst``.

    Args:
        tensor: Tensor contributed by the calling rank.
        dst: Rank that receives the gathered tensors.
        group: Process group forwarded to ``get_world_size`` and
            ``dist.gather``.
        **kwargs: Extra keyword arguments passed through to ``dist.gather``.

    Returns:
        ``[tensor]`` when the world size is 1. Otherwise the list handed to
        ``dist.gather``: ``world_size`` buffers on rank ``dst`` and ``None``
        on every other rank.
    """
    # The calling rank is queried without ``group`` and compared directly
    # against ``dst``.
    rank = get_rank()
    num_procs = get_world_size(group)
    if num_procs == 1:
        return [tensor]

    # Only the destination rank needs receive buffers.
    receive_buffers = None
    if rank == dst:
        receive_buffers = [
            torch.empty_like(tensor) for _ in range(num_procs)
        ]
    dist.gather(tensor, receive_buffers, dst, group, **kwargs)
    return receive_buffers
@torch.no_grad()
def reduce_dict(input_dict, group=None, reduction='mean', **kwargs):
    r"""Reduce the tensor values of a dict across all ranks.

    The values are stacked, reduced onto rank 0 with ``dist.reduce``,
    optionally divided by the world size there, and then broadcast back so
    every rank returns the same numbers.

    Args:
        input_dict: Mapping from key to tensor. All values must be
            stackable with ``torch.stack``.
        group: Process group used for the collectives.
        reduction: ``'mean'`` or ``'sum'``.
        **kwargs: Extra keyword arguments forwarded to ``dist.reduce`` and
            ``dist.broadcast``.

    Returns:
        ``input_dict`` itself when the world size is 1; otherwise a new
        mapping of the same type holding the reduced values.
    """
    assert reduction in ['mean', 'sum']
    world_size = get_world_size(group)
    if world_size == 1:
        return input_dict

    # Every rank must stack the values in the same order. An OrderedDict
    # keeps its insertion order; any other mapping is sorted by key.
    if isinstance(input_dict, OrderedDict):
        # ``keys`` has to be *called*; iterating the bound method itself
        # raises TypeError.
        keys = list(input_dict.keys())
    else:
        keys = sorted(input_dict.keys())
    vals = torch.stack([input_dict[key] for key in keys], dim=0)

    dist.reduce(vals, dst=0, group=group, **kwargs)
    if dist.get_rank(group) == 0 and reduction == 'mean':
        vals /= world_size
    dist.broadcast(vals, src=0, group=group, **kwargs)

    return type(input_dict)(list(zip(keys, vals)))
def generalized_all_gather(data, group=None):
    r"""All-gather an arbitrary picklable object from every rank.

    Args:
        data: Object contributed by the calling rank.
        group: Process group to use. When ``None`` (and more than one
            process is running) the group returned by
            ``get_global_gloo_group()`` is used.

    Returns:
        ``[data]`` when the world size is 1; otherwise a list with one
        unpickled object per entry of the gathered size list.
    """
    if get_world_size(group) == 1:
        return [data]
    if group is None:
        group = get_global_gloo_group()

    # Serialise the object, then pad so every rank sends the same length.
    payload = _serialize_to_tensor(data, group)
    sizes, payload = _pad_to_largest_tensor(payload, group)
    largest = max(sizes)

    received = [
        torch.empty((largest, ), dtype=torch.uint8, device=payload.device)
        for _ in sizes
    ]
    dist.all_gather(received, payload, group=group)

    # Strip the padding of each buffer before unpickling it.
    # NOTE(review): ``pickle.loads`` runs on bytes received from peer
    # ranks; this presumes all ranks are trusted.
    return [
        pickle.loads(chunk.cpu().numpy().tobytes()[:num_bytes])
        for num_bytes, chunk in zip(sizes, received)
    ]
data_list.append(pickle.loads(buffer)) + return data_list + else: + dist.gather(tensor, [], dst=dst, group=group) + return [] + + +def scatter(data, scatter_list=None, src=0, group=None, **kwargs): + r"""NOTE: only supports CPU tensor communication. + """ + if get_world_size(group) > 1: + return dist.scatter(data, scatter_list, src, group, **kwargs) + + +def reduce_scatter(output, + input_list, + op=dist.ReduceOp.SUM, + group=None, + **kwargs): + if get_world_size(group) > 1: + return dist.reduce_scatter(output, input_list, op, group, **kwargs) + + +def send(tensor, dst, group=None, **kwargs): + if get_world_size(group) > 1: + assert tensor.is_contiguous( + ), 'ops.send requires the tensor to be contiguous()' + return dist.send(tensor, dst, group, **kwargs) + + +def recv(tensor, src=None, group=None, **kwargs): + if get_world_size(group) > 1: + assert tensor.is_contiguous( + ), 'ops.recv requires the tensor to be contiguous()' + return dist.recv(tensor, src, group, **kwargs) + + +def isend(tensor, dst, group=None, **kwargs): + if get_world_size(group) > 1: + assert tensor.is_contiguous( + ), 'ops.isend requires the tensor to be contiguous()' + return dist.isend(tensor, dst, group, **kwargs) + + +def irecv(tensor, src=None, group=None, **kwargs): + if get_world_size(group) > 1: + assert tensor.is_contiguous( + ), 'ops.irecv requires the tensor to be contiguous()' + return dist.irecv(tensor, src, group, **kwargs) + + +def shared_random_seed(group=None): + seed = np.random.randint(2**31) + all_seeds = generalized_all_gather(seed, group) + return all_seeds[0] + + +def _all_gather(x): + if not (dist.is_available() + and dist.is_initialized()) or dist.get_world_size() == 1: + return x + rank = dist.get_rank() + world_size = dist.get_world_size() + tensors = [torch.empty_like(x) for _ in range(world_size)] + tensors[rank] = x + dist.all_gather(tensors, x) + return torch.cat(tensors, dim=0).contiguous() + + +def _all_reduce(x): + if not (dist.is_available() + and 
class DiffScatter(Function):
    r"""Differentiable scatter.

    The forward pass keeps only this rank's share of the input (``_split``);
    the backward pass reassembles the gradient from all ranks
    (``_all_gather``).
    """

    @staticmethod
    def symbolic(graph, input):
        return _split(input)

    @staticmethod
    def forward(ctx, input):
        # ``Function.apply`` dispatches to ``forward``; without it the base
        # class raises NotImplementedError.
        return _split(input)

    @staticmethod
    def backward(ctx, grad_output):
        return _all_gather(grad_output)
@torch.no_grad()
def sinkhorn(Q, eps=0.5, num_iters=3):
    r"""Run alternating row/column normalisation on ``exp(Q / eps)``.

    Args:
        Q: 2-D tensor of scores. It is exponentiated and transposed before
            the iterations; the result is transposed back.
        eps: Temperature dividing ``Q`` before the exponential.
        num_iters: Number of row/column normalisation rounds.

    Returns:
        A float tensor with the same layout as the input ``Q``, divided by
        its sums over dim 0 of the transposed matrix before being
        transposed back.
    """
    # Exponentiate, transpose and normalise by the global sum.
    plan = torch.exp(Q / eps).t()
    total = plan.sum()
    all_reduce(total)
    plan /= total

    # Target marginals: uniform over rows, and uniform over columns across
    # all processes.
    num_rows, num_cols = plan.size()
    row_target = plan.new_ones(num_rows) / num_rows
    col_target = plan.new_ones(num_cols) / (num_cols * get_world_size())

    # Row sums are shared across ranks; column sums stay local.
    row_sums = plan.sum(dim=1)
    all_reduce(row_sums)
    for _ in range(num_iters):
        plan *= (row_target / row_sums).unsqueeze(1)
        plan *= (col_target / plan.sum(dim=0)).unsqueeze(0)
        row_sums = plan.sum(dim=1)
        all_reduce(row_sums)
    return (plan / plan.sum(dim=0, keepdim=True)).t().float()
def kl_divergence(mu1, logvar1, mu2, logvar2):
    r"""Element-wise KL divergence between two diagonal Gaussians.

    Args:
        mu1, logvar1: Mean and log-variance of the first distribution.
        mu2, logvar2: Mean and log-variance of the second distribution.

    Returns:
        Tensor of ``KL(N(mu1, exp(logvar1)) || N(mu2, exp(logvar2)))``
        values, broadcast over the inputs.
    """
    variance_ratio = torch.exp(logvar1 - logvar2)
    scaled_sq_mean_gap = ((mu1 - mu2)**2) * torch.exp(-logvar2)
    bracket = -1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_mean_gap
    return 0.5 * bracket
def make_irregular_mask(w,
                        h,
                        max_angle=4,
                        max_length=200,
                        max_width=100,
                        min_strokes=1,
                        max_strokes=5,
                        mode='line'):
    r"""Draw a random mask made of chained strokes.

    Args:
        w, h: Width and height of the mask.
        max_angle: Exclusive upper bound of the random integer added to the
            base angle of each segment.
        max_length: Exclusive upper bound of the random extra segment
            length (added to 10).
        max_width: Exclusive upper bound of the random extra stroke radius
            (added to 5).
        min_strokes, max_strokes: Inclusive range of the stroke count.
        mode: ``'line'``, ``'circle'`` or ``'square'``; the shape stamped at
            each segment.

    Returns:
        ``float32`` array of shape ``(h, w)``, 1 where something was drawn.
    """
    # initialize mask
    assert mode in ['line', 'circle', 'square']
    mask = np.zeros((h, w), np.float32)

    # draw strokes
    num_strokes = np.random.randint(min_strokes, max_strokes + 1)
    for i in range(num_strokes):
        x1 = np.random.randint(w)
        y1 = np.random.randint(h)
        for _ in range(1 + np.random.randint(5)):
            angle = 0.01 + np.random.randint(max_angle)
            if i % 2 == 0:
                angle = 2 * 3.1415926 - angle
            length = 10 + np.random.randint(max_length)
            radius = 5 + np.random.randint(max_width)
            x2 = np.clip((x1 + length * np.sin(angle)).astype(np.int32), 0, w)
            y2 = np.clip((y1 + length * np.cos(angle)).astype(np.int32), 0, h)
            if mode == 'line':
                cv2.line(mask, (x1, y1), (x2, y2), 1.0, radius)
            elif mode == 'circle':
                cv2.circle(
                    mask, (x1, y1), radius=radius, color=1.0, thickness=-1)
            elif mode == 'square':
                radius = radius // 2
                # Clamp the lower bounds: a negative slice start would wrap
                # around to the far edge and draw the square in the wrong
                # place (or nowhere). Upper bounds are clipped by numpy.
                top = max(0, y1 - radius)
                left = max(0, x1 - radius)
                mask[top:y1 + radius, left:x1 + radius] = 1
            # The next segment starts where this one ended.
            x1, y1 = x2, y2
    return mask
def rand_name(length=8, suffix=''):
    r"""Build a random lowercase-hex file name.

    Args:
        length: Number of random bytes; the stem has ``2 * length`` hex
            characters.
        suffix: Optional extension. A leading dot is added when missing.

    Returns:
        The hex stem, followed by the dotted suffix when one was given.
    """
    stem = os.urandom(length).hex()
    if not suffix:
        return stem
    extension = suffix if suffix.startswith('.') else '.' + suffix
    return stem + extension
def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition):
    r"""Select a subset of conditions from a pair of kwargs dicts.

    Args:
        partial_keys: Condition names to keep. Each must be one of the
            supported names checked below. The list is not modified.
        full_model_kwargs: Sequence whose first two items are dicts holding
            every condition.
        use_fps_condition: When exactly ``True``, ``'fps'`` is selected in
            addition to ``partial_keys``.

    Returns:
        A two-element list of dicts containing only the selected keys.

    Raises:
        AssertionError: If ``partial_keys`` contains an unsupported name.
        KeyError: If a selected key is missing from either source dict.
    """
    supported = [
        'y', 'depth', 'canny', 'masked', 'sketch', 'image', 'motion',
        'local_image', 'single_sketch'
    ]
    for partial_key in partial_keys:
        assert partial_key in supported

    # Work on a copy: appending 'fps' to the caller's list would make a
    # second call with the same list fail the check above.
    selected_keys = list(partial_keys)
    if use_fps_condition is True:
        selected_keys.append('fps')

    first, second = full_model_kwargs[0], full_model_kwargs[1]
    return [
        {key: first[key] for key in selected_keys},
        {key: second[key] for key in selected_keys},
    ]
def read(filename, mode='r', retry=5):
    r"""Read a whole file from OSS, HTTP(S) or the local file system.

    Args:
        filename: ``oss://...`` key, ``http://`` / ``https://`` URL, or a
            local path.
        mode: ``'r'`` to return text (remote content is decoded as UTF-8)
            or ``'rb'`` to return bytes.
        retry: Maximum number of attempts.

    Returns:
        The file content as ``str`` or ``bytes`` depending on ``mode``.

    Raises:
        ValueError: If ``retry`` allows no attempt at all.
        Exception: The error of the last failed attempt.
    """
    assert mode in ['r', 'rb']
    last_error = None
    for _ in range(retry):
        try:
            if filename.startswith('oss://'):
                bucket, path = parse_oss_url(filename)
                content = bucket.get_object(path).read()
                if mode == 'r':
                    content = content.decode('utf-8')
            elif filename.startswith(('http://', 'https://')):
                # Match the full scheme so a local file such as
                # ``httpd.conf`` is not mistaken for a URL.
                response = requests.get(filename)
                # Turn 4xx/5xx into an exception so the attempt is retried
                # instead of returning the error page as content.
                response.raise_for_status()
                content = response.content
                if mode == 'r':
                    content = content.decode('utf-8')
            else:
                with open(filename, mode=mode) as f:
                    content = f.read()
            return content
        except Exception as e:
            last_error = e
    if last_error is None:
        # The loop never ran, so there is no error to re-raise.
        raise ValueError(f'retry must be a positive integer, got {retry!r}')
    raise last_error
def read_gzip(filename, retry=5):
    r"""Read and decompress a gzip file from OSS or the local file system.

    An ``oss://`` source is first downloaded to a randomly named temporary
    file in the working directory, which is removed again after every
    attempt.

    Args:
        filename: ``oss://...`` key or local path of the gzip file.
        retry: Maximum number of attempts.

    Returns:
        The decompressed bytes.

    Raises:
        ValueError: If ``retry`` allows no attempt at all.
        Exception: The error of the last failed attempt.
    """
    is_remote = filename.startswith('oss://')
    last_error = None
    for _ in range(retry):
        # Keep ``filename`` intact so every retry downloads from OSS again
        # instead of trying to open a temp file left by a failed attempt.
        local_path = filename
        try:
            if is_remote:
                bucket, path = parse_oss_url(filename)
                local_path = rand_name(suffix=osp.splitext(filename)[1])
                bucket.get_object_to_file(path, local_path)
            with gzip.open(local_path) as f:
                return f.read()
        except Exception as e:
            last_error = e
        finally:
            # Remove the temporary download on failure as well as success.
            if is_remote and local_path != filename and osp.exists(
                    local_path):
                os.remove(local_path)
    if last_error is None:
        # The loop never ran, so there is no error to re-raise.
        raise ValueError(f'retry must be a positive integer, got {retry!r}')
    raise last_error
def get_object(bucket, oss_key, retry=5):
    r"""Download an object from a bucket, retrying on failure.

    Args:
        bucket: Object exposing ``get_object(key)`` whose result has a
            ``read()`` method.
        oss_key: Key of the object to fetch.
        retry: Maximum number of attempts.

    Returns:
        The object's content, or ``None`` when every attempt failed (the
        last error is logged, not raised).
    """
    exception = None
    for _ in range(retry):
        try:
            return bucket.get_object(oss_key).read()
        except Exception as e:
            exception = e
    # ``Logger.error`` accepts no ``flush`` keyword; passing one raised
    # TypeError here instead of logging the failure.
    logger.error(f'get_object from {oss_key} failed with error: {exception}')
    return None
@torch.no_grad()
def save_video_multiple_conditions(oss_key,
                                   video_tensor,
                                   model_kwargs,
                                   source_imgs,
                                   palette,
                                   mean=[0.5, 0.5, 0.5],
                                   std=[0.5, 0.5, 0.5],
                                   nrow=8,
                                   retry=5,
                                   save_origin_video=True,
                                   bucket=None):
    r"""Write a GIF showing generated videos next to their conditions.

    The output concatenates, along the last (width) axis, the optional
    source video, one panel per visualisable condition in
    ``model_kwargs[0]``, and the generated video.

    Args:
        oss_key: Path of the GIF to write.
        video_tensor: Generated videos, unpacked as ``(b, c, n, h, w)``.
            De-normalised in place with ``mean`` / ``std``.
        model_kwargs: Sequence whose first item maps condition names to
            tensors.
        source_imgs: Source videos, resized to ``(n, h, w)`` by adaptive
            average pooling.
        palette: Object providing ``get_palette_image`` for 3-D conditions.
        mean, std: Per-channel statistics used to undo the normalisation.
        nrow: Number of rows when tiling the batch into a grid.
        retry: Maximum number of attempts at building and writing the GIF.
        save_origin_video: Whether to include the source video panel.
        bucket: Unused here.

    Failures are logged, not raised.
    """
    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)
    video_tensor = video_tensor.mul_(std).add_(mean)
    try:
        video_tensor.clamp_(0, 1)
    except Exception as e:
        # Fall back to float32 when clamping fails for the tensor's dtype.
        logger.error(e)
        video_tensor = video_tensor.float().clamp_(0, 1)
    video_tensor = video_tensor.cpu()

    b, c, n, h, w = video_tensor.shape
    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
    source_imgs = source_imgs.cpu()

    # Convert every condition into a 3-channel video of size (n, h, w).
    model_kwargs_channel3 = {}
    for key, conditions in model_kwargs[0].items():
        if conditions.shape[-1] == 1024:
            # Skip for style embedding
            continue
        if len(conditions.shape) == 3:
            # Render each entry through the palette helper.
            conditions_np = conditions.cpu().numpy()
            conditions = []
            for i in conditions_np:
                vis_i = []
                for j in i:
                    vis_i.append(
                        palette.get_palette_image(
                            j, percentile=90, width=256, height=256))
                conditions.append(np.stack(vis_i))
            conditions = torch.from_numpy(np.stack(conditions))
            conditions = rearrange(conditions, 'b n h w c -> b c n h w')
        else:
            if conditions.size(1) == 1:
                conditions = torch.cat([conditions, conditions, conditions],
                                       dim=1)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            if conditions.size(1) == 2:
                conditions = torch.cat([conditions, conditions[:, :1, ]],
                                       dim=1)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            elif conditions.size(1) == 3:
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
            elif conditions.size(1) == 4:
                # Composite the first three channels over white using the
                # fourth channel as alpha.
                color = ((conditions[:, 0:3] + 1.) / 2.)
                alpha = conditions[:, 3:4]
                conditions = color * alpha + 1.0 * (1.0 - alpha)
                conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        model_kwargs_channel3[key] = conditions.cpu(
        ) if conditions.is_cuda else conditions

    filename = oss_key
    grid_pattern = '(i j) c f h w -> c f (i h) (j w)'
    # Initialised up front so ``retry=0`` does not hit an unbound name.
    exception = None
    for _ in range(retry):
        try:
            vid_gif = rearrange(video_tensor, grid_pattern, i=nrow)
            cons_list = [
                rearrange(con, grid_pattern, i=nrow)
                for con in model_kwargs_channel3.values()
            ]
            # Use a separate name: rebinding ``source_imgs`` would hand an
            # already-tiled tensor to ``rearrange`` on the next retry.
            source_gif = rearrange(source_imgs, grid_pattern, i=nrow)

            if save_origin_video:
                panels = [source_gif] + cons_list + [vid_gif]
            else:
                panels = cons_list + [vid_gif]
            vid_gif = torch.cat(panels, dim=3)

            video_tensor_to_gif(vid_gif, filename)
            exception = None
            break
        except Exception as e:
            exception = e
    if exception is not None:
        # Report through the module logger at error level.
        logger.error('save video to {} failed, error: {}'.format(
            oss_key, exception))
source_imgs = source_imgs.cpu() + + model_kwargs_channel3 = {} + for key, conditions in model_kwargs[0].items(): + if len(conditions.shape) == 3: + conditions_np = conditions.cpu().numpy() + conditions = [] + for i in conditions_np: + vis_i = [] + for j in i: + vis_i.append( + palette.get_palette_image( + j, percentile=90, width=256, height=256)) + conditions.append(np.stack(vis_i)) + conditions = torch.from_numpy(np.stack(conditions)) + conditions = rearrange(conditions, 'b n h w c -> b c n h w') + else: + if conditions.size(1) == 1: + conditions = torch.cat([conditions, conditions, conditions], + dim=1) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + if conditions.size(1) == 2: + conditions = torch.cat([conditions, conditions[:, :1, ]], + dim=1) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + elif conditions.size(1) == 3: + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + elif conditions.size(1) == 4: + color = ((conditions[:, 0:3] + 1.) / 2.) + alpha = conditions[:, 3:4] + conditions = color * alpha + 1.0 * (1.0 - alpha) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + model_kwargs_channel3[key] = conditions.cpu( + ) if conditions.is_cuda else conditions + + copy_video_tensor = video_tensor.clone() + copy_source_imgs = source_imgs.clone() + + filename = rand_name(suffix='.gif') + for _ in [None] * retry: + try: + vid_gif = rearrange( + video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + cons_list = [ + rearrange(con, '(i j) c f h w -> c f (i h) (j w)', j=nrow) + for _, con in model_kwargs_channel3.items() + ] + source_imgs = rearrange( + source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + vid_gif = torch.cat( + [ + source_imgs, + ] + cons_list + [ + vid_gif, + ], dim=3) + + video_tensor_to_gif(vid_gif, filename) + bucket.put_object_from_file(vis_oss_key, filename) + exception = None + break + except Exception as e: + exception = e + continue + + # remove temporary file + if osp.exists(filename): 
+ os.remove(filename) + + filename_pred = rand_name(suffix='.pkl') + for _ in [None] * retry: + try: + copy_video_np = (copy_video_tensor.numpy() * 255).astype('uint8') + pickle.dump(copy_video_np, open(filename_pred, 'wb')) + bucket.put_object_from_file(video_save_key, filename_pred) + break + except Exception as e: + logger.error('error! ', video_save_key, e) + continue + + # remove temporary file + if osp.exists(filename_pred): + os.remove(filename_pred) + + filename_gt = rand_name(suffix='.pkl') + for _ in [None] * retry: + try: + copy_source_np = (copy_source_imgs.numpy() * 255).astype('uint8') + pickle.dump(copy_source_np, open(filename_gt, 'wb')) + bucket.put_object_from_file(gt_video_save_key, filename_gt) + break + except Exception as e: + logger.error('error! ', gt_video_save_key, e) + continue + + # remove temporary file + if osp.exists(filename_gt): + os.remove(filename_gt) + + if exception is not None: + logger.error( + 'save video to {} failed, error: {}'.format( + vis_oss_key, exception), + flush=True) + + +@torch.no_grad() +def save_video_vs_conditions(bucket, + oss_key, + video_tensor, + conditions, + source_imgs, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + nrow=8, + retry=5): + mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1) + video_tensor = video_tensor.mul_(std).add_(mean) + video_tensor.clamp_(0, 1) + + b, c, n, h, w = video_tensor.shape + source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w)) + source_imgs = source_imgs.cpu() + + if conditions.size(1) == 1: + conditions = torch.cat([conditions, conditions, conditions], dim=1) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + + filename = rand_name(suffix='.gif') + for _ in [None] * retry: + try: + vid_gif = rearrange( + video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + con_gif = rearrange( + conditions, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + source_imgs = 
@torch.no_grad()
def save_video_grid_mp4(bucket,
                        oss_key,
                        tensor,
                        mean=[0.5, 0.5, 0.5],
                        std=[0.5, 0.5, 0.5],
                        nrow=None,
                        fps=5,
                        retry=5):
    """De-normalize a batch of videos, tile them into a grid and upload as mp4.

    Args:
        bucket: object exposing ``put_object_from_file(key, filename)``.
        oss_key: destination key passed to ``bucket``.
        tensor: 5-D tensor unpacked as ``(b, c, t, h, w)``. It is
            de-normalized IN PLACE (``mul_``/``add_``/``clamp_``), so the
            caller's tensor is modified.
        mean, std: per-channel statistics used to undo normalization.
        nrow: number of grid rows; defaults to ``ceil(sqrt(b))``.
        fps: frame rate handed to the video writer (``-r`` input option).
        retry: maximum number of write+upload attempts.

    Failures are logged, never raised (best-effort behaviour is kept).
    """
    mean = torch.tensor(mean, device=tensor.device).view(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=tensor.device).view(1, -1, 1, 1, 1)
    tensor = tensor.mul_(std).add_(mean)
    tensor.clamp_(0, 1)
    b, c, t, h, w = tensor.shape
    # channels-last uint8 frames: (b, t, h, w, c)
    tensor = tensor.permute(0, 2, 3, 4, 1)
    tensor = (tensor.cpu().numpy() * 255).astype('uint8')

    filename = rand_name(suffix='.mp4')
    # initialised up front so that retry <= 0 does not hit an unbound name
    exception = None
    try:
        for _ in range(retry):
            try:
                if nrow is None:
                    nrow = math.ceil(math.sqrt(b))
                ncol = math.ceil(b / nrow)
                padding = 1
                video_grid = np.zeros((t, (padding + h) * nrow + padding,
                                       (padding + w) * ncol + padding, c),
                                      dtype='uint8')
                for i in range(b):
                    r, c_ = divmod(i, ncol)
                    start_r = (padding + h) * r
                    start_c = (padding + w) * c_
                    video_grid[:, start_r:start_r + h,
                               start_c:start_c + w] = tensor[i]
                skvideo.io.vwrite(
                    filename, video_grid, inputdict={'-r': str(fps)})

                bucket.put_object_from_file(oss_key, filename)
                exception = None
                break
            except Exception as e:
                exception = e
                logger.error(exception)
                continue
    finally:
        # remove temporary file, even if the loop is interrupted
        if osp.exists(filename):
            os.remove(filename)

    if exception is not None:
        # NOTE: logging methods do not accept print()'s ``flush`` keyword;
        # passing it raises TypeError on a standard ``logging.Logger``.
        logger.error('save video to {} failed, error: {}'.format(
            oss_key, exception))
def parallel(func, args_list, num_workers=32, timeout=None):
    """Apply ``func`` to every entry of ``args_list``, optionally in a pool.

    Args:
        func: callable to run. With ``num_workers > 0`` it is dispatched
            through ``Pool.apply_async``.
        args_list: list of argument tuples. If the first entry is not a
            tuple, every entry is treated as a single positional argument.
        num_workers: pool size; ``0`` runs everything in the calling process.
        timeout: per-result timeout forwarded to ``AsyncResult.get``.

    Returns:
        List of results in the same order as ``args_list``.
    """
    assert isinstance(args_list, list)
    # nothing to do; also avoids IndexError on ``args_list[0]`` below
    if not args_list:
        return []
    if not isinstance(args_list[0], tuple):
        args_list = [(args, ) for args in args_list]
    if num_workers == 0:
        return [func(*args) for args in args_list]
    with Pool(processes=num_workers) as pool:
        pending = [pool.apply_async(func, args) for args in args_list]
        # results must be collected while the pool is still alive
        results = [res.get(timeout=timeout) for res in pending]
    return results
def load_state_dict(module, state_dict, drop_prefix=''):
    """Load the compatible subset of ``state_dict`` into ``module``.

    Keys that are unknown to ``module`` or whose tensor shape differs are
    skipped instead of raising; skipped and missing keys are logged.

    Args:
        module: target ``nn.Module``.
        state_dict: mapping of parameter name to tensor.
        drop_prefix: if non-empty, this prefix is stripped from every key
            that starts with it before matching.
    """
    # find incompatible key-vals
    src, dst = state_dict, module.state_dict()
    if drop_prefix:
        src = type(src)([
            (k[len(drop_prefix):] if k.startswith(drop_prefix) else k, v)
            for k, v in src.items()
        ])
    missing = [k for k in dst if k not in src]
    unexpected = [k for k in src if k not in dst]
    # iterate ``src`` (not a set intersection) so the report order is stable
    unmatched = [
        k for k in src if k in dst and src[k].shape != dst[k].shape
    ]

    # keep only compatible key-vals
    incompatible = set(unexpected + unmatched)
    src = type(src)([(k, v) for k, v in src.items() if k not in incompatible])
    module.load_state_dict(src, strict=False)

    # report incompatible key-vals
    # NOTE: logging methods do not accept print()'s ``flush`` keyword; with a
    # standard ``logging.Logger`` it raised TypeError whenever a report was due.
    if missing:
        logger.info('  Missing: ' + ', '.join(missing))
    if unexpected:
        logger.info('  Unexpected: ' + ', '.join(unexpected))
    if unmatched:
        logger.info('  Shape unmatched: ' + ', '.join(unmatched))
def detect_duplicates(feats, thr=0.9):
    """Return the indices of rows that are not near-duplicates of an earlier row.

    Rows are L2-normalized and compared by dot product (cosine similarity).
    Row ``j`` is dropped when its similarity with any row ``i < j`` is
    greater than ``thr``; the first row of each duplicate group is kept.

    Args:
        feats: 2-D tensor, one feature vector per row.
        thr: similarity threshold above which a row counts as a duplicate.

    Returns:
        1-D tensor with the indices of the rows to keep.
    """
    assert feats.ndim == 2

    # compute simmat; only the strict upper triangle (i < j) is kept
    feats = F.normalize(feats, p=2, dim=1)
    simmat = torch.mm(feats, feats.T)
    simmat.triu_(1)
    # synchronizing is only meaningful (and only possible) for CUDA tensors;
    # calling it unconditionally fails on CPU-only installations
    if feats.is_cuda:
        torch.cuda.synchronize(feats.device)

    # detect duplicates
    mask = ~simmat.gt(thr).any(dim=0)
    return torch.where(mask)[0]
def format_state(state, filename=None):
    r"""Format a state_dict as ``name<TAB>shape`` lines for comparing/aligning.

    Args:
        state: mapping whose values expose a ``shape`` attribute.
        filename: optional path; when given, the text is also written there.

    Returns:
        The formatted text, one ``key\t(shape)`` entry per line. Previously
        the text was discarded (``None`` returned), which made the call a
        no-op when ``filename`` was omitted.
    """
    content = '\n'.join(f'{k}\t{tuple(v.shape)}' for k, v in state.items())
    if filename:
        with open(filename, 'w') as f:
            f.write(content)
    return content
# load all keys containing prefix and replace it with new_prefix
def load_Block(state, prefix, new_prefix=None):
    """Select the entries of ``state`` whose key contains ``prefix``.

    Matching is by substring (``prefix in key``), and every occurrence of
    ``prefix`` in a matching key is replaced by ``new_prefix``.

    Args:
        state: mapping of parameter name to value.
        prefix: text to look for in the keys.
        new_prefix: replacement text; defaults to ``prefix`` (keys unchanged).

    Returns:
        A new dict with the renamed keys and the original values.
    """
    if new_prefix is None:
        new_prefix = prefix
    return {
        key.replace(prefix, new_prefix): value
        for key, value in state.items() if prefix in key
    }


def load_2d_pretrained_state_dict(state, cfg):
    """Rename the keys of a 2D UNet state_dict to the temporal UNet layout.

    The walk mirrors the block indexing of the network (init block,
    encoder stages with optional attention and downsample, middle, decoder
    stages with optional attention and upsample, head) and shifts sub-module
    indices to leave room for the inserted temporal layers.

    Args:
        state: source state_dict (name -> tensor).
        cfg: config object providing ``unet_res_blocks``, ``unet_dim_mult``,
            ``unet_attn_scales`` and ``temporal_attn_times``.

    Returns:
        dict with the remapped keys. Keys matching none of the visited
        prefixes are not copied.
    """
    new_state_dict = {}

    num_res_blocks = cfg.unet_res_blocks
    dim_mult = cfg.unet_dim_mult
    attn_scales = cfg.unet_attn_scales
    num_stages = len(dim_mult)
    scale = 1.0

    # embeddings keep their names
    for prefix in ('time_embedding', 'y_embedding', 'context_embedding'):
        new_state_dict.update(load_Block(state, prefix=prefix))

    encoder_idx = 0
    # init block: encoder.0.* -> encoder.0.0.*
    new_state_dict.update(
        load_Block(
            state,
            prefix=f'encoder.{encoder_idx}',
            new_prefix=f'encoder.{encoder_idx}.0'))
    encoder_idx += 1

    for i in range(num_stages):
        for j in range(num_res_blocks):
            # residual block stays at sub-index 0
            new_state_dict.update(
                load_Block(
                    state,
                    prefix=f'encoder.{encoder_idx}.0',
                    new_prefix=f'encoder.{encoder_idx}.0'))

            if scale in attn_scales:
                # attention block moves from sub-index 1 to 2
                new_state_dict.update(
                    load_Block(
                        state,
                        prefix=f'encoder.{encoder_idx}.1',
                        new_prefix=f'encoder.{encoder_idx}.2'))
            encoder_idx += 1

            # downsample
            if i != num_stages - 1 and j == num_res_blocks - 1:
                # These prefixes were plain strings ('encoder.{encoder_idx}')
                # that matched no key, so downsample weights were dropped.
                new_state_dict.update(
                    load_Block(
                        state,
                        prefix=f'encoder.{encoder_idx}',
                        new_prefix=f'encoder.{encoder_idx}.0'))
                scale /= 2.0
                encoder_idx += 1

    # middle: 0 -> 0, 1 -> 2, 2 -> 3 + temporal_attn_times
    new_state_dict.update(load_Block(state, prefix='middle.0'))
    new_state_dict.update(
        load_Block(state, prefix='middle.1', new_prefix='middle.2'))
    middle_idx = 3 + len(range(cfg.temporal_attn_times))
    new_state_dict.update(
        load_Block(
            state, prefix='middle.2', new_prefix=f'middle.{middle_idx}'))

    decoder_idx = 0
    for i in range(num_stages):
        for j in range(num_res_blocks + 1):
            idx = 0
            idx_ = 0
            # residual (+attention) blocks
            new_state_dict.update(
                load_Block(
                    state,
                    prefix=f'decoder.{decoder_idx}.{idx}',
                    new_prefix=f'decoder.{decoder_idx}.{idx_}'))
            idx += 1
            idx_ += 2
            if scale in attn_scales:
                new_state_dict.update(
                    load_Block(
                        state,
                        prefix=f'decoder.{decoder_idx}.{idx}',
                        new_prefix=f'decoder.{decoder_idx}.{idx_}'))
                idx += 1
                idx_ += 1
                # NOTE(review): temporal layers are assumed to follow the
                # attention block only at attention scales; confirm against
                # the module layout of the target network.
                idx_ += len(range(cfg.temporal_attn_times))

            # upsample
            if i != num_stages - 1 and j == num_res_blocks:
                new_state_dict.update(
                    load_Block(
                        state,
                        prefix=f'decoder.{decoder_idx}.{idx}',
                        new_prefix=f'decoder.{decoder_idx}.{idx_}'))
                scale *= 2.0
            decoder_idx += 1

    new_state_dict.update(load_Block(state, prefix='head'))

    return new_state_dict
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs

    :param in_channels: channels of the input/output feature map.
    :param n_heads: number of attention heads.
    :param d_head: channels per head; inner width is ``n_heads * d_head``.
    :param depth: number of stacked ``BasicTransformerBlock``s.
    :param dropout: dropout rate forwarded to every block.
    :param context_dim: cross-attention context width; a single value (or
        None) is shared by all blocks, a list gives one value per block.
    :param disable_self_attn: forwarded to every block.
    :param use_linear: project with ``nn.Linear`` instead of 1x1 convs.
    :param use_checkpoint: forwarded to every block as ``checkpoint``.
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        # One entry per block. This also covers context_dim=None, which was
        # indexed as ``None[d]`` before, and depth > 1 with a single value.
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                disable_self_attn=disable_self_attn,
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            # Maps inner_dim back to in_channels. The arguments were swapped,
            # which only worked when both widths happened to be equal (the
            # parameter shapes are identical in that case).
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """
        :param x: feature map unpacked as ``(b, c, h, w)``.
        :param context: one context for all blocks, a list with one context
            per block, or None.
        :return: tensor of the same shape as ``x`` (residual connection).
        """
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # conv projection works on the image layout, linear on the token layout
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: attention, cross-attention, feed-forward.

    Each of the three sub-layers is applied to a LayerNorm'ed input and
    added back to its input (residual connection).

    Args:
        dim: token width.
        n_heads: number of attention heads.
        d_head: channels per head.
        dropout: dropout rate for the attention and feed-forward layers.
        context_dim: width of the cross-attention context (None: same as dim).
        gated_ff: use the GEGLU variant in the feed-forward layer.
        checkpoint: stored on the module; not used by ``forward``.
        disable_self_attn: if True the first attention also attends to
            ``context`` instead of the tokens themselves.
    """

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        attn_cls = CrossAttention
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward_(self, x, context=None):
        """Legacy entry point; returns the same result as ``forward``.

        It used to call an undefined ``checkpoint`` helper with a
        non-existent ``self._forward`` and therefore always raised. It now
        delegates to ``forward`` (no gradient checkpointing is applied).
        """
        return self.forward(x, context=context)

    def forward(self, x, context=None):
        """Apply the block to tokens ``x`` with optional ``context``."""
        x = self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding is split into a
        scale and a shift applied after the output normalization instead
        of being added to the features.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param use_temporal_conv: if True, append a temporal convolution block
        that runs over the frame axis.
    :param use_image_dataset: forwarded to the temporal convolution block.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        use_temporal_conv=True,
        use_image_dataset=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_temporal_conv = use_temporal_conv

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels
                if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            # NOTE(review): ``conv_nd`` is not defined in the part of the file
            # reviewed here - confirm it exists before using use_conv=True.
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)

        if self.use_temporal_conv:
            # The attribute name (including its spelling) is part of the
            # state_dict keys, so it must not be renamed.
            self.temopral_conv = TemporalConvBlock_v2(
                self.out_channels,
                self.out_channels,
                dropout=0.1,
                use_image_dataset=use_image_dataset)

    def forward(self, x, emb, batch_size):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :param batch_size: number of videos folded into N; used to split N
            into (batch, frames) for the temporal convolution.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return self._forward(x, emb, batch_size)

    def _forward(self, x, emb, batch_size):
        if self.updown:
            # resample between the norm/activation and the convolution
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # append singleton axes until the embedding broadcasts against h
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            # was ``th.chunk``: ``th`` is never imported in this module, so
            # this branch raised NameError
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        h = self.skip_connection(x) + h

        if self.use_temporal_conv:
            h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
            h = self.temopral_conv(h)
            h = rearrange(h, 'b c f h w -> (b f) c h w')
        return h
class Resample(nn.Module):
    """Parameter-free spatial resampling layer.

    ``mode`` selects the behaviour of ``forward``:
    ``'none'`` returns the input unchanged, ``'upsample'`` resizes to the
    spatial size of a reference tensor (nearest neighbour) and
    ``'downsample'`` halves the last two axes by adaptive average pooling.
    """

    def __init__(self, in_dim, out_dim, mode):
        assert mode in ['none', 'upsample', 'downsample']
        super(Resample, self).__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.mode = mode

    def forward(self, x, reference=None):
        """Resample ``x``; ``reference`` is required for ``'upsample'``."""
        if self.mode == 'upsample':
            assert reference is not None
            target_hw = reference.shape[-2:]
            return F.interpolate(x, size=target_hw, mode='nearest')
        if self.mode == 'downsample':
            half_hw = tuple(side // 2 for side in x.shape[-2:])
            return F.adaptive_avg_pool2d(x, output_size=half_hw)
        return x
class AttentionBlock(nn.Module):
    """Multi-head self-attention over the positions of a feature map.

    When a context is supplied, its projected keys/values are prepended to
    the keys/values computed from the feature map. The output projection
    weight is zero-initialized.
    """

    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
        # head_dim, when given, decides the head count; otherwise num_heads
        if head_dim:
            num_heads = dim // head_dim
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(AttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # applied to both q and k, i.e. head_dim ** -0.5 on their product
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x, context=None):
        r"""x:       [B, C, H, W].
            context: [B, L, C] or None.
        """
        batch, channels, height, width = x.size()
        heads, head_dim = self.num_heads, self.head_dim
        residual = x

        # compute query, key, value
        qkv = self.to_qkv(self.norm(x))
        qkv = qkv.view(batch, heads * 3, head_dim, height * width)
        q, k, v = qkv.chunk(3, dim=1)
        if context is not None:
            ctx = self.context_kv(context)
            ctx = ctx.reshape(batch, -1, heads * 2, head_dim)
            ctx_k, ctx_v = ctx.permute(0, 2, 3, 1).chunk(2, dim=1)
            k = torch.cat([ctx_k, k], dim=-1)
            v = torch.cat([ctx_v, v], dim=-1)

        # compute attention
        scores = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
        weights = F.softmax(scores, dim=-1)

        # gather context
        attended = torch.matmul(v, weights.transpose(-1, -2))
        attended = attended.reshape(batch, channels, height, width)

        # output
        return self.proj(attended) + residual
h n d', h=self.heads) + k = rearrange(qkv[1], '... n (h d) -> ... h n d', h=self.heads) + v = rearrange(qkv[2], '... n (h d) -> ... h n d', h=self.heads) + + # scale + q = q * self.scale + + # rotate positions into queries and keys for time attention + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + sim = sim + pos_bias + + if (focus_present_mask is None and video_mask is not None): + mask = video_mask[:, None, :] * video_mask[:, :, None] + mask = mask.unsqueeze(1).unsqueeze(1) + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + elif exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones((n, n), + device=device, + dtype=torch.bool) + attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + if self.use_sim_mask: + sim_mask = torch.tril( + torch.ones((n, n), device=device, dtype=torch.bool), + diagonal=0) + sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max) + + # numerical stability + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... n (h d)') + out = self.to_out(out) + + out = rearrange(out, 'b (h w) f c -> b c f h w', h=height) + + if self.use_image_dataset: + out = identity + 0 * out + else: + out = identity + out + return out + + +class TemporalTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. 
+ Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + only_self_att=True, + multiply_zero=False): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + self.use_adaptor = False + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + if self.use_adaptor: + self.adaptor_in = nn.Linear(frames, frames) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + checkpoint=use_checkpoint) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module( + nn.Conv1d( + inner_dim, in_channels, kernel_size=1, stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + if self.use_adaptor: + self.adaptor_out = nn.Linear(frames, frames) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + if self.use_linear: + x = rearrange( + x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous() + x = self.proj_in(x) + + if self.only_self_att: + x = 
class TemporalAttentionMultiBlock(nn.Module):
    """Chain of ``temporal_attn_times`` identical temporal attention layers.

    Every layer is a ``TemporalAttentionBlock`` built from the same
    constructor arguments; ``forward`` feeds the output of each layer into the
    next one, passing the same bias and masks to all of them.
    """

    def __init__(
        self,
        dim,
        heads=4,
        dim_head=32,
        rotary_emb=None,
        use_image_dataset=False,
        use_sim_mask=False,
        temporal_attn_times=1,
    ):
        super().__init__()
        stacked = []
        for _ in range(temporal_attn_times):
            stacked.append(
                TemporalAttentionBlock(
                    dim,
                    heads=heads,
                    dim_head=dim_head,
                    rotary_emb=rotary_emb,
                    use_image_dataset=use_image_dataset,
                    use_sim_mask=use_sim_mask))
        self.att_layers = nn.ModuleList(stacked)

    def forward(self,
                x,
                pos_bias=None,
                focus_present_mask=None,
                video_mask=None):
        """Run ``x`` through every attention layer in order and return it."""
        hidden = x
        for attention in self.att_layers:
            hidden = attention(hidden, pos_bias, focus_present_mask,
                               video_mask)
        return hidden
class TemporalConvBlock_v2(nn.Module):
    """Residual stack of four temporal convolutions that starts as identity.

    Each stage is ``GroupNorm -> SiLU -> (Dropout) -> Conv3d`` with a
    ``(3, 1, 1)`` kernel, i.e. the convolution only mixes neighbouring frames.
    The last convolution is zero-initialised, so a freshly constructed block
    returns its input unchanged.

    Args:
        in_dim: Channels of the input (and of the output, because of the
            residual connection). Must be divisible by 32 for ``GroupNorm``.
        out_dim: Hidden channel width; defaults to ``in_dim``. Must be
            divisible by 32 as well.
        dropout: Dropout probability used in stages 2-4.
        use_image_dataset: When true, the convolutional update is multiplied
            by zero before the residual add.
    """

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(TemporalConvBlock_v2, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        temporal_kernel = (3, 1, 1)
        temporal_padding = (1, 0, 0)

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv3d(
                in_dim, out_dim, temporal_kernel, padding=temporal_padding))
        # Intermediate stages stay at ``out_dim`` channels. Previously they
        # projected back to ``in_dim`` while the following GroupNorm/Conv3d
        # still expected ``out_dim``, which failed whenever the two widths
        # differed. With ``out_dim == in_dim`` the parameter shapes are
        # unchanged, so existing checkpoints still load.
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(
                out_dim, out_dim, temporal_kernel, padding=temporal_padding))
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(
                out_dim, out_dim, temporal_kernel, padding=temporal_padding))
        # Only the final stage maps back to ``in_dim`` for the residual add.
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(
                out_dim, in_dim, temporal_kernel, padding=temporal_padding))

        # zero out the last layer params,so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, x):
        """Return ``x`` plus the four-stage temporal convolution update.

        Args:
            x: 5-D tensor whose channel axis (dim 1) has ``in_dim`` entries.

        Returns:
            Tensor with the same shape as ``x``.
        """
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        if self.use_image_dataset:
            x = identity + 0.0 * x
        else:
            x = identity + x
        return x
+ self.temporal_attn_times = temporal_attn_times + self.temporal_attention = temporal_attention + self.use_checkpoint = use_checkpoint + self.use_image_dataset = use_image_dataset + self.use_fps_condition = use_fps_condition + self.use_sim_mask = use_sim_mask + self.training = training + self.inpainting = inpainting + self.video_compositions = video_compositions + self.misc_dropout = misc_dropout + self.p_all_zero = p_all_zero + self.p_all_keep = p_all_keep + + use_linear_in_temporal = False + transformer_depth = 1 + disabled_sa = False + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + if hasattr(cfg, 'adapter_transformer_layers' + ) and cfg.adapter_transformer_layers: + adapter_transformer_layers = cfg.adapter_transformer_layers + else: + adapter_transformer_layers = 1 + + # embeddings + self.time_embed = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.pre_image_condition = nn.Sequential( + nn.Linear(1024, 1024), nn.SiLU(), nn.Linear(1024, 1024)) + + # depth embedding + if 'depthmap' in self.video_compositions: + self.depth_embedding = nn.Sequential( + nn.Conv2d(1, concat_dim * 4, 3, padding=1), nn.SiLU(), + nn.AdaptiveAvgPool2d((128, 128)), + nn.Conv2d( + concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1)) + self.depth_embedding_after = Transformer_v2( + heads=2, + dim=concat_dim, + dim_head_k=concat_dim, + dim_head_v=concat_dim, + dropout_atte=0.05, + mlp_dim=concat_dim, + dropout_ffn=0.05, + depth=adapter_transformer_layers) + + if 'motion' in self.video_compositions: + self.motion_embedding = nn.Sequential( + nn.Conv2d(2, concat_dim * 4, 3, padding=1), nn.SiLU(), + nn.AdaptiveAvgPool2d((128, 128)), + nn.Conv2d( + concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, 
stride=2, padding=1)) + self.motion_embedding_after = Transformer_v2( + heads=2, + dim=concat_dim, + dim_head_k=concat_dim, + dim_head_v=concat_dim, + dropout_atte=0.05, + mlp_dim=concat_dim, + dropout_ffn=0.05, + depth=adapter_transformer_layers) + + # canny embedding + if 'canny' in self.video_compositions: + self.canny_embedding = nn.Sequential( + nn.Conv2d(1, concat_dim * 4, 3, padding=1), nn.SiLU(), + nn.AdaptiveAvgPool2d((128, 128)), + nn.Conv2d( + concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1)) + self.canny_embedding_after = Transformer_v2( + heads=2, + dim=concat_dim, + dim_head_k=concat_dim, + dim_head_v=concat_dim, + dropout_atte=0.05, + mlp_dim=concat_dim, + dropout_ffn=0.05, + depth=adapter_transformer_layers) + + # masked-image embedding + if 'mask' in self.video_compositions: + self.masked_embedding = nn.Sequential( + nn.Conv2d(4, concat_dim * 4, 3, padding=1), nn.SiLU(), + nn.AdaptiveAvgPool2d((128, 128)), + nn.Conv2d( + concat_dim * 4, concat_dim + * 4, 3, stride=2, padding=1), nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, + padding=1)) if inpainting else None + self.mask_embedding_after = Transformer_v2( + heads=2, + dim=concat_dim, + dim_head_k=concat_dim, + dim_head_v=concat_dim, + dropout_atte=0.05, + mlp_dim=concat_dim, + dropout_ffn=0.05, + depth=adapter_transformer_layers) + + # sketch embedding + if 'sketch' in self.video_compositions: + self.sketch_embedding = nn.Sequential( + nn.Conv2d(1, concat_dim * 4, 3, padding=1), nn.SiLU(), + nn.AdaptiveAvgPool2d((128, 128)), + nn.Conv2d( + concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1)) + self.sketch_embedding_after = Transformer_v2( + heads=2, + dim=concat_dim, + dim_head_k=concat_dim, + dim_head_v=concat_dim, + dropout_atte=0.05, + mlp_dim=concat_dim, + dropout_ffn=0.05, + depth=adapter_transformer_layers) + 
+ if 'single_sketch' in self.video_compositions: + self.single_sketch_embedding = nn.Sequential( + nn.Conv2d(1, concat_dim * 4, 3, padding=1), nn.SiLU(), + nn.AdaptiveAvgPool2d((128, 128)), + nn.Conv2d( + concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1)) + self.single_sketch_embedding_after = Transformer_v2( + heads=2, + dim=concat_dim, + dim_head_k=concat_dim, + dim_head_v=concat_dim, + dropout_atte=0.05, + mlp_dim=concat_dim, + dropout_ffn=0.05, + depth=adapter_transformer_layers) + + if 'local_image' in self.video_compositions: + self.local_image_embedding = nn.Sequential( + nn.Conv2d(3, concat_dim * 4, 3, padding=1), nn.SiLU(), + nn.AdaptiveAvgPool2d((128, 128)), + nn.Conv2d( + concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1)) + self.local_image_embedding_after = Transformer_v2( + heads=2, + dim=concat_dim, + dim_head_k=concat_dim, + dim_head_v=concat_dim, + dropout_atte=0.05, + mlp_dim=concat_dim, + dropout_ffn=0.05, + depth=adapter_transformer_layers) + + # Condition Dropout + self.misc_dropout = DropPath(misc_dropout) + + if temporal_attention and not USE_TEMPORAL_TRANSFORMER: + self.rotary_emb = RotaryEmbedding(min(32, head_dim)) + self.time_rel_pos_bias = RelativePositionBias( + heads=num_heads, max_distance=32) + + if self.use_fps_condition: + self.fps_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + nn.init.zeros_(self.fps_embedding[-1].weight) + nn.init.zeros_(self.fps_embedding[-1].bias) + + # encoder + self.input_blocks = nn.ModuleList() + if cfg.resume: + self.pre_image = nn.Sequential() + init_block = nn.ModuleList( + [nn.Conv2d(self.in_dim + concat_dim, dim, 3, padding=1)]) + else: + self.pre_image = nn.Sequential( + nn.Conv2d(self.in_dim + concat_dim, self.in_dim, 1, padding=0)) + init_block = nn.ModuleList( + 
[nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + + # need an initial temporal attention? + if temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + init_block.append( + TemporalTransformer( + dim, + num_heads, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + else: + init_block.append( + TemporalAttentionMultiBlock( + dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + temporal_attn_times=temporal_attn_times, + use_image_dataset=use_image_dataset)) + + self.input_blocks.append(init_block) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList([ + ResBlock( + in_dim, + embed_dim, + dropout, + out_channels=out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + # + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True)) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + else: + block.append( + TemporalAttentionMultiBlock( + out_dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + use_image_dataset=use_image_dataset, + use_sim_mask=use_sim_mask, + temporal_attn_times=temporal_attn_times)) + in_dim = out_dim + self.input_blocks.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + downsample = Downsample( + out_dim, True, dims=2, out_channels=out_dim) + shortcut_dims.append(out_dim) + scale /= 2.0 + 
self.input_blocks.append(downsample) + + # middle + self.middle_block = nn.ModuleList([ + ResBlock( + out_dim, + embed_dim, + dropout, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ), + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True) + ]) + + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + self.middle_block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset, + )) + else: + self.middle_block.append( + TemporalAttentionMultiBlock( + out_dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + use_image_dataset=use_image_dataset, + use_sim_mask=use_sim_mask, + temporal_attn_times=temporal_attn_times)) + + self.middle_block.append( + ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False)) + + # decoder + self.output_blocks = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + block = nn.ModuleList([ + ResBlock( + in_dim + shortcut_dims.pop(), + embed_dim, + dropout, + out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=1024, + disable_self_attn=False, + use_linear=True)) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + else: + block.append( + TemporalAttentionMultiBlock( + out_dim, + num_heads, + head_dim, + 
def _embed_condition(self, cond, embedding, embedding_after, batch):
    """Encode one per-frame condition map into the ``b c f h w`` layout.

    Frames are folded into the batch axis for the 2-D ``embedding`` stack,
    then every spatial location is treated as a sequence over frames for
    ``embedding_after``, and finally the result is unfolded again.
    """
    cond = rearrange(cond, 'b c f h w -> (b f) c h w')
    cond = embedding(cond)
    # height after the embedding stack; needed to split (b h w) again
    height = cond.shape[2]
    cond = embedding_after(
        rearrange(cond, '(b f) c h w -> (b h w) f c', b=batch))
    return rearrange(cond, '(b h w) f c -> b c f h w', b=batch, h=height)

def forward(
        self,
        x,
        t,
        y=None,
        depth=None,
        image=None,
        motion=None,
        local_image=None,
        single_sketch=None,
        masked=None,
        canny=None,
        sketch=None,
        histogram=None,
        fps=None,
        video_mask=None,
        focus_present_mask=None,
        prob_focus_present=0.,
        mask_last_frame_num=0  # mask last frame num
):
    """Run the UNet on ``x`` with the given conditions.

    Args:
        x: Tensor laid out as ``b c f h w``.
        t: Passed to ``sinusoidal_embedding`` to build the time embedding.
            # presumably diffusion timesteps of shape (b,) - TODO confirm.
        y: Optional context tensor concatenated along dim 1 of the context;
            when ``None``, ``self.zero_y`` repeated over the batch is used.
        depth, motion, local_image, single_sketch, masked, canny, sketch:
            Optional per-frame condition maps laid out as ``b c f h w``.
            Each one is embedded and summed into the tensor that is
            concatenated to ``x`` along the channel axis. Using one requires
            its embedding modules to exist on ``self``.
        image: Optional tensor passed through ``self.pre_image_condition``
            and appended to the context along dim 1.
        histogram: Accepted but not used by this method.
        fps: Optional value embedded and added to the time embedding when
            ``self.use_fps_condition`` is set.
        video_mask: Forwarded to the temporal layers.
        focus_present_mask: Forwarded to the temporal layers; sampled with
            ``prob_mask_like`` when ``None`` and ``mask_last_frame_num`` is 0.
        prob_focus_present: Probability used when sampling the mask above.
        mask_last_frame_num: When positive, ``focus_present_mask`` is dropped
            and the last ``mask_last_frame_num`` entries along dim 0 of
            ``video_mask`` are set to ``False`` in place.

    Returns:
        Tensor laid out as ``b c f h w`` produced by ``self.out``.
    """

    assert self.inpainting or masked is None, 'inpainting is not supported'

    batch, _, f, h, w = x.shape
    device = x.device
    self.batch = batch

    # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
    if mask_last_frame_num > 0:
        focus_present_mask = None
        # NOTE(review): this writes into the caller's tensor, requires
        # ``video_mask`` to be given, and slices dim 0 - confirm that dim 0
        # is the intended axis.
        video_mask[-mask_last_frame_num:] = False
    else:
        focus_present_mask = default(
            focus_present_mask, lambda: prob_mask_like(
                (batch, ), prob_focus_present, device=device))

    if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
        time_rel_pos_bias = self.time_rel_pos_bias(
            x.shape[2], device=x.device)
    else:
        time_rel_pos_bias = None

    # all-zero and all-keep masks
    zero = torch.zeros(batch, dtype=torch.bool, device=device)
    keep = torch.zeros(batch, dtype=torch.bool, device=device)
    if self.training:
        nzero = (torch.rand(batch) < self.p_all_zero).sum()
        nkeep = (torch.rand(batch) < self.p_all_keep).sum()
        index = torch.randperm(batch)
        # disjoint slices of one permutation, so no sample gets both flags
        zero[index[0:nzero]] = True
        keep[index[nzero:nzero + nkeep]] = True
    assert not (zero & keep).any()
    misc_dropout = partial(self.misc_dropout, zero=zero, keep=keep)

    # The call order below is kept fixed because ``misc_dropout`` may draw
    # random numbers on every call.
    concat = x.new_zeros(batch, self.concat_dim, f, h, w)
    if depth is not None:
        depth = self._embed_condition(depth, self.depth_embedding,
                                      self.depth_embedding_after, batch)
        concat = concat + misc_dropout(depth)

    # local_image_embedding
    if local_image is not None:
        local_image = self._embed_condition(
            local_image, self.local_image_embedding,
            self.local_image_embedding_after, batch)
        concat = concat + misc_dropout(local_image)

    if motion is not None:
        motion = self._embed_condition(motion, self.motion_embedding,
                                       self.motion_embedding_after, batch)

        if hasattr(self.cfg, 'p_zero_motion_alone'
                   ) and self.cfg.p_zero_motion_alone and self.training:
            motion_d = torch.rand(batch) < self.cfg.p_zero_motion
            motion_d = motion_d[:, None, None, None, None]
            # Follow the tensor's own device instead of a hard-coded
            # ``.cuda()``, which failed on CPU and on non-default GPUs.
            motion = motion.masked_fill(motion_d.to(motion.device), 0)
            concat = concat + motion
        else:
            concat = concat + misc_dropout(motion)

    if canny is not None:
        canny = self._embed_condition(canny, self.canny_embedding,
                                      self.canny_embedding_after, batch)
        concat = concat + misc_dropout(canny)

    if sketch is not None:
        sketch = self._embed_condition(sketch, self.sketch_embedding,
                                       self.sketch_embedding_after, batch)
        concat = concat + misc_dropout(sketch)

    if single_sketch is not None:
        single_sketch = self._embed_condition(
            single_sketch, self.single_sketch_embedding,
            self.single_sketch_embedding_after, batch)
        concat = concat + misc_dropout(single_sketch)

    if masked is not None:
        masked = self._embed_condition(masked, self.masked_embedding,
                                       self.mask_embedding_after, batch)
        concat = concat + misc_dropout(masked)

    x = torch.cat([x, concat], dim=1)
    x = rearrange(x, 'b c f h w -> (b f) c h w')
    x = self.pre_image(x)
    x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)

    # embeddings
    e = self.time_embed(sinusoidal_embedding(t, self.dim))
    if self.use_fps_condition and fps is not None:
        e = e + self.fps_embedding(sinusoidal_embedding(fps, self.dim))

    context = x.new_zeros(batch, 0, self.context_dim)
    if y is not None:
        y_context = misc_dropout(y)
    else:
        y_context = self.zero_y.repeat(batch, 1, 1)
    context = torch.cat([context, y_context], dim=1)

    if image is not None:
        image_context = misc_dropout(self.pre_image_condition(image))
        context = torch.cat([context, image_context], dim=1)

    # repeat f times for spatial e and context
    e = e.repeat_interleave(repeats=f, dim=0)
    context = context.repeat_interleave(repeats=f, dim=0)

    # always in shape (b f) c h w, except for temporal layer
    x = rearrange(x, 'b c f h w -> (b f) c h w')
    # encoder
    xs = []
    for block in self.input_blocks:
        x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                 focus_present_mask, video_mask)
        xs.append(x)

    # middle
    for block in self.middle_block:
        x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                 focus_present_mask, video_mask)

    # decoder
    for block in self.output_blocks:
        # skip connection: consume encoder activations in reverse order
        x = torch.cat([x, xs.pop()], dim=1)
        x = self._forward_single(
            block,
            x,
            e,
            context,
            time_rel_pos_bias,
            focus_present_mask,
            video_mask,
            reference=xs[-1] if len(xs) > 0 else None)

    # head
    x = self.out(x)

    # reshape back to (b c f h w)
    x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
    return x
class PreNormattention(nn.Module):
    """Residual wrapper that layer-normalises the input before calling ``fn``.

    Computes ``fn(LayerNorm(x), **kwargs) + x``.

    Args:
        dim: Size of the last axis, passed to ``nn.LayerNorm``.
        fn: Callable applied to the normalised input; its output must be
            addable to ``x``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return the wrapped function's output plus the un-normalised input."""
        normalised = self.norm(x)
        update = self.fn(normalised, **kwargs)
        return update + x
def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, q, k, v, **kwargs): + return self.fn(self.norm(q), self.norm(k), self.norm(v), **kwargs) + q + + +class Attention(nn.Module): + + def __init__(self, dim, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.attend = nn.Softmax(dim=-1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout)) if project_out else nn.Identity() + + def forward(self, x): + _, _, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = self.attend(dots) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Attention_qkv(nn.Module): + + def __init__(self, dim, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.attend = nn.Softmax(dim=-1) + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout)) if project_out else nn.Identity() + + def forward(self, q, k, v): + _, _, _, h = *q.shape, self.heads + bk = k.shape[0] + q = self.to_q(q) + k = self.to_k(k) + v = self.to_v(v) + q = rearrange(q, 'b n (h d) -> b h n d', h=h) + k = rearrange(k, 'b n (h d) -> b h n d', b=bk, h=h) + v = rearrange(v, 'b n (h d) -> b h n d', b=bk, h=h) + + dots = einsum('b h i d, b h j d -> b h i 
j', q, k) * self.scale + + attn = self.attend(dots) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class PostNormattention(nn.Module): + + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x, **kwargs): + return self.norm(self.fn(x, **kwargs) + x) + + +class Transformer_v2(nn.Module): + + def __init__(self, + heads=8, + dim=2048, + dim_head_k=256, + dim_head_v=256, + dropout_atte=0.05, + mlp_dim=2048, + dropout_ffn=0.05, + depth=1): + super().__init__() + self.layers = nn.ModuleList([]) + self.depth = depth + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PreNormattention( + dim, + Attention( + dim, + heads=heads, + dim_head=dim_head_k, + dropout=dropout_atte)), + FeedForward(dim, mlp_dim, dropout=dropout_ffn), + ])) + + def forward(self, x): + for attn, ff in self.layers[:1]: + x = attn(x) + x = ff(x) + x + if self.depth > 1: + for attn, ff in self.layers[1:]: + x = attn(x) + x = ff(x) + x + return x + + +class DropPath(nn.Module): + r"""DropPath but without rescaling and supports optional all-zero and/or all-keep. 
+ """ + + def __init__(self, p): + super(DropPath, self).__init__() + self.p = p + + def forward(self, *args, zero=None, keep=None): + if not self.training: + return args[0] if len(args) == 1 else args + + # params + x = args[0] + b = x.size(0) + n = (torch.rand(b) < self.p).sum() + + # non-zero and non-keep mask + mask = x.new_ones(b, dtype=torch.bool) + if keep is not None: + mask[keep] = False + if zero is not None: + mask[zero] = False + + # drop-path index + index = torch.where(mask)[0] + index = index[torch.randperm(len(index))[:n]] + if zero is not None: + index = torch.cat([index, torch.where(zero)[0]], dim=0) + + # drop-path multiplier + multiplier = x.new_ones(b) + multiplier[index] = 0.0 + output = tuple(u * self.broadcast(multiplier, u) for u in args) + return output[0] if len(args) == 1 else output + + def broadcast(self, src, dst): + assert src.size(0) == dst.size(0) + shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1) + return src.view(shape) + + +if __name__ == '__main__': + from config import cfg + # [model] unet + model = UNetSD_temporal( + in_dim=cfg.unet_in_dim, + dim=cfg.unet_dim, + y_dim=cfg.unet_y_dim, + context_dim=cfg.unet_context_dim, + out_dim=cfg.unet_out_dim, + dim_mult=cfg.unet_dim_mult, + num_heads=cfg.unet_num_heads, + head_dim=cfg.unet_head_dim, + num_res_blocks=cfg.unet_res_blocks, + attn_scales=cfg.unet_attn_scales, + dropout=cfg.unet_dropout, + temporal_attn_times=0, + use_checkpoint=cfg.use_checkpoint, + use_image_dataset=True, + use_fps_condition=cfg.use_fps_condition) + + print( + int(sum(p.numel() for k, p in model.named_parameters()) / (1024**2)), + 'M parameters') diff --git a/modelscope/models/multi_modal/videocomposer/utils/__init__.py b/modelscope/models/multi_modal/videocomposer/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/multi_modal/videocomposer/utils/config.py b/modelscope/models/multi_modal/videocomposer/utils/config.py new file mode 100644 index 00000000..18424257 
--- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/utils/config.py @@ -0,0 +1,273 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import argparse +import copy +import os + +import json +import yaml + +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def setup_seed(seed): + print('Seed: ', seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +class Config(object): + + def __init__(self, load=True, cfg_dict=None, cfg_level=None): + self._level = 'cfg' + ('.' + + cfg_level if cfg_level is not None else '') + if load: + self.args = self._parse_args() + logger.info('Loading config from {}.'.format(self.args.cfg_file)) + self.need_initialization = True + cfg_base = self._initialize_cfg() + cfg_dict = self._load_yaml(self.args) + cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict) + cfg_dict = self._update_from_args(cfg_dict) + self.cfg_dict = cfg_dict + self._update_dict(cfg_dict) + + def _parse_args(self): + parser = argparse.ArgumentParser( + description= + 'Argparser for configuring [code base name to think of] codebase') + parser.add_argument( + '--cfg', + dest='cfg_file', + help='Path to the configuration file', + default= + './modelscope/models/multi_modal/videocomposer/configs/exp06_text_depths_vs_style.yaml' + ) + parser.add_argument( + '--init_method', + help='Initialization method, includes TCP or shared file-system', + default='tcp://localhost:9999', + type=str, + ) + parser.add_argument( + '--seed', + type=int, + default=8888, + help='Need to explore for different videos') + parser.add_argument( + '--debug', + action='store_true', + default=False, + help='Into debug information') + parser.add_argument( + '--input_video', + default='demo_video/video_8800.mp4', + help='input video for full task, or motion vector of input videos', + type=str, + ), + parser.add_argument( + '--image_path', default='', help='Single Image Input', type=str) + 
parser.add_argument( + '--sketch_path', default='', help='Single Sketch Input', type=str) + parser.add_argument( + '--style_image', help='Single Sketch Input', type=str) + parser.add_argument( + '--input_text_desc', + default= + 'A colorful and beautiful fish swimming in a small glass bowl with \ + multicolored piece of stone, Macro Video', + type=str, + ), + parser.add_argument( + 'opts', + help='other configurations', + default=None, + nargs=argparse.REMAINDER) + return parser.parse_args() + + def _path_join(self, path_list): + path = '' + for p in path_list: + path += p + '/' + return path[:-1] + + def _update_from_args(self, cfg_dict): + args = self.args + for var in vars(args): + cfg_dict[var] = getattr(args, var) + return cfg_dict + + def _initialize_cfg(self): + if self.need_initialization: + self.need_initialization = False + if os.path.exists( + './modelscope/models/multi_modal/videocomposer/configs/base.yaml' + ): + with open( + './modelscope/models/multi_modal/videocomposer/configs/base.yaml', + 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) + else: + with open( + './modelscope/models/multi_modal/videocomposer/configs/base.yaml', + 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) + return cfg + + def _load_yaml(self, args, file_name=''): + assert args.cfg_file is not None + if not file_name == '': # reading from base file + with open(file_name, 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) + else: + if os.getcwd().split('/')[-1] == args.cfg_file.split('/')[0]: + args.cfg_file = args.cfg_file.replace( + os.getcwd().split('/')[-1], './') + try: + with open(args.cfg_file, 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) + file_name = args.cfg_file + except Exception as e: + args.cfg_file = os.path.realpath(__file__).split( + '/')[-3] + '/' + args.cfg_file + with open(args.cfg_file, 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) + file_name = args.cfg_file + print(e) + + if 
'_BASE_RUN' not in cfg.keys() and '_BASE_MODEL' not in cfg.keys( + ) and '_BASE' not in cfg.keys(): + return cfg + + if '_BASE' in cfg.keys(): + if cfg['_BASE'][1] == '.': + prev_count = cfg['_BASE'].count('..') + cfg_base_file = self._path_join( + file_name.split('/')[:(-1 - cfg['_BASE'].count('..'))] + + cfg['_BASE'].split('/')[prev_count:]) + else: + cfg_base_file = cfg['_BASE'].replace( + './', + args.cfg_file.replace(args.cfg_file.split('/')[-1], '')) + cfg_base = self._load_yaml(args, cfg_base_file) + cfg = self._merge_cfg_from_base(cfg_base, cfg) + else: + if '_BASE_RUN' in cfg.keys(): + if cfg['_BASE_RUN'][1] == '.': + prev_count = cfg['_BASE_RUN'].count('..') + cfg_base_file = self._path_join( + file_name.split('/')[:(-1 - prev_count)] + + cfg['_BASE_RUN'].split('/')[prev_count:]) + else: + cfg_base_file = cfg['_BASE_RUN'].replace( + './', + args.cfg_file.replace( + args.cfg_file.split('/')[-1], '')) + cfg_base = self._load_yaml(args, cfg_base_file) + cfg = self._merge_cfg_from_base( + cfg_base, cfg, preserve_base=True) + if '_BASE_MODEL' in cfg.keys(): + if cfg['_BASE_MODEL'][1] == '.': + prev_count = cfg['_BASE_MODEL'].count('..') + cfg_base_file = self._path_join( + file_name.split('/')[:( + -1 - cfg['_BASE_MODEL'].count('..'))] + + cfg['_BASE_MODEL'].split('/')[prev_count:]) + else: + cfg_base_file = cfg['_BASE_MODEL'].replace( + './', + args.cfg_file.replace( + args.cfg_file.split('/')[-1], '')) + cfg_base = self._load_yaml(args, cfg_base_file) + cfg = self._merge_cfg_from_base(cfg_base, cfg) + cfg = self._merge_cfg_from_command(args, cfg) + return cfg + + def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False): + for k, v in cfg_new.items(): + if k in cfg_base.keys(): + if isinstance(v, dict): + self._merge_cfg_from_base(cfg_base[k], v) + else: + cfg_base[k] = v + else: + if 'BASE' not in k or preserve_base: + cfg_base[k] = v + return cfg_base + + def _merge_cfg_from_command(self, args, cfg): + assert len( + args.opts) % 2 == 0, 
'Override list {} has odd length: {}.'.format( + args.opts, len(args.opts)) + keys = args.opts[0::2] + vals = args.opts[1::2] + + # maximum supported depth 3 + for idx, key in enumerate(keys): + key_split = key.split('.') + assert len( + key_split + ) <= 4, 'Key depth error. \nMaximum depth: 3\n Get depth: {}'.format( + len(key_split)) + assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format( + key_split[0]) + if len(key_split) == 2: + assert key_split[1] in cfg[ + key_split[0]].keys(), 'Non-existant key: {}.'.format(key) + elif len(key_split) == 3: + assert key_split[1] in cfg[ + key_split[0]].keys(), 'Non-existant key: {}.'.format(key) + assert key_split[2] in cfg[key_split[0]][ + key_split[1]].keys(), 'Non-existant key: {}.'.format(key) + elif len(key_split) == 4: + assert key_split[1] in cfg[ + key_split[0]].keys(), 'Non-existant key: {}.'.format(key) + assert key_split[2] in cfg[key_split[0]][ + key_split[1]].keys(), 'Non-existant key: {}.'.format(key) + assert key_split[3] in cfg[key_split[0]][key_split[1]][ + key_split[2]].keys(), 'Non-existant key: {}.'.format(key) + if len(key_split) == 1: + cfg[key_split[0]] = vals[idx] + elif len(key_split) == 2: + cfg[key_split[0]][key_split[1]] = vals[idx] + elif len(key_split) == 3: + cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx] + elif len(key_split) == 4: + cfg[key_split[0]][key_split[1]][key_split[2]][ + key_split[3]] = vals[idx] + return cfg + + def _update_dict(self, cfg_dict): + + def recur(key, elem): + if type(elem) is dict: + return key, Config(load=False, cfg_dict=elem, cfg_level=key) + else: + if type(elem) is str and elem[1:3] == 'e-': + elem = float(elem) + return key, elem + + dic = dict(recur(k, v) for k, v in cfg_dict.items()) + self.__dict__.update(dic) + + def get_args(self): + return self.args + + def __repr__(self): + return '{}\n'.format(self.dump()) + + def dump(self): + return json.dumps(self.cfg_dict, indent=2) + + def deep_copy(self): + return copy.deepcopy(self) + + +if 
__name__ == '__main__': + # debug + cfg = Config(load=True) + print(cfg.DATA) diff --git a/modelscope/models/multi_modal/videocomposer/utils/distributed.py b/modelscope/models/multi_modal/videocomposer/utils/distributed.py new file mode 100644 index 00000000..b622e700 --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/utils/distributed.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +"""Distributed helpers.""" + +import functools +import logging +import pickle + +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None + + +def all_gather(tensors): + """ + All gathers the provided tensors from all processes across machines. + Args: + tensors (list): tensors to perform all gather across all processes in + all machines. + """ + + gather_list = [] + output_tensor = [] + world_size = dist.get_world_size() + for tensor in tensors: + tensor_placeholder = [ + torch.ones_like(tensor) for _ in range(world_size) + ] + dist.all_gather(tensor_placeholder, tensor, async_op=False) + gather_list.append(tensor_placeholder) + for gathered_tensor in gather_list: + output_tensor.append(torch.cat(gathered_tensor, dim=0)) + return output_tensor + + +def all_reduce(tensors, average=True): + """ + All reduce the provided tensors from all processes across machines. + Args: + tensors (list): tensors to perform all reduce across all processes in + all machines. + average (bool): scales the reduced tensor by the number of overall + processes across all machines. + """ + + for tensor in tensors: + dist.all_reduce(tensor, async_op=False) + if average: + world_size = dist.get_world_size() + for tensor in tensors: + tensor.mul_(1.0 / world_size) + return tensors + + +def init_process_group( + local_rank, + local_world_size, + shard_id, + num_shards, + init_method, + dist_backend='nccl', +): + """ + Initializes the default process group. 
+ Args: + local_rank (int): the rank on the current local machine. + local_world_size (int): the world size (number of processes running) on + the current local machine. + shard_id (int): the shard index (machine rank) of the current machine. + num_shards (int): number of shards for distributed training. + init_method (string): supporting three different methods for + initializing process groups: + "file": use shared file system to initialize the groups across + different processes. + "tcp": use tcp address to initialize the groups across different + dist_backend (string): backend to use for distributed training. Options + includes gloo, mpi and nccl, the details can be found here: + https://pytorch.org/docs/stable/distributed.html + """ + # Sets the GPU to use. + torch.cuda.set_device(local_rank) + # Initialize the process group. + proc_rank = local_rank + shard_id * local_world_size + world_size = local_world_size * num_shards + dist.init_process_group( + backend=dist_backend, + init_method=init_method, + world_size=world_size, + rank=proc_rank, + ) + + +def is_master_proc(num_gpus=8): + """ + Determines if the current process is the master process. + """ + if torch.distributed.is_initialized(): + return dist.get_rank() % num_gpus == 0 + else: + return True + + +def get_world_size(): + """ + Get the size of the world. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + """ + Get the rank of the current process. 
+ """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + Returns: + (group): pytorch dist group. + """ + if dist.get_backend() == 'nccl': + return dist.new_group(backend='gloo') + else: + return dist.group.WORLD + + +def _serialize_to_tensor(data, group): + """ + Seriialize the tensor to ByteTensor. Note that only `gloo` and `nccl` + backend is supported. + Args: + data (data): data to be serialized. + group (group): pytorch dist group. + Returns: + tensor (ByteTensor): tensor that serialized. + """ + + backend = dist.get_backend(group) + assert backend in ['gloo', 'nccl'] + device = torch.device('cpu' if backend == 'gloo' else 'cuda') + + buffer = pickle.dumps(data) + if len(buffer) > 1024**3: + logger = logging.getLogger(__name__) + logger.warning( + 'Rank {} trying to all-gather {:.2f} GB of data on device {}'. + format(get_rank(), + len(buffer) / (1024**3), device)) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def _pad_to_largest_tensor(tensor, group): + """ + Padding all the tensors from different GPUs to the largest ones. + Args: + tensor (tensor): tensor to pad. + group (group): pytorch dist group. + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = dist.get_world_size(group=group) + assert ( + world_size >= 1 + ), 'comm.gather/all_gather must be called from ranks within the given group!' 
+ local_size = torch.tensor([tensor.numel()], + dtype=torch.int64, + device=tensor.device) + size_list = [ + torch.zeros([1], dtype=torch.int64, device=tensor.device) + for _ in range(world_size) + ] + dist.all_gather(size_list, local_size, group=group) + size_list = [int(size.item()) for size in size_list] + + max_size = max(size_list) + + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + if local_size != max_size: + padding = torch.zeros((max_size - local_size, ), + dtype=torch.uint8, + device=tensor.device) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + + +def all_gather_unaligned(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group) == 1: + return [data] + + tensor = _serialize_to_tensor(data, group) + + size_list, tensor = _pad_to_largest_tensor(tensor, group) + max_size = max(size_list) + + # receiving Tensor from all ranks + tensor_list = [ + torch.empty((max_size, ), dtype=torch.uint8, device=tensor.device) + for _ in size_list + ] + dist.all_gather(tensor_list, tensor, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def init_distributed_training(cfg): + """ + Initialize variables needed for distributed training. 
+ """ + if cfg.NUM_GPUS <= 1: + return + num_gpus_per_machine = cfg.NUM_GPUS + num_machines = dist.get_world_size() // num_gpus_per_machine + for i in range(num_machines): + ranks_on_i = list( + range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)) + pg = dist.new_group(ranks_on_i) + if i == cfg.SHARD_ID: + global _LOCAL_PROCESS_GROUP + _LOCAL_PROCESS_GROUP = pg + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) diff --git a/modelscope/models/multi_modal/videocomposer/utils/utils.py b/modelscope/models/multi_modal/videocomposer/utils/utils.py new file mode 100644 index 00000000..d01b3dd2 --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/utils/utils.py @@ -0,0 +1,955 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import base64 +import binascii +import copy +import glob +import gzip +import hashlib +import logging +import math +import os +import os.path as osp +import pickle +import random +import sys +import time +import urllib.request +import zipfile +from io import BytesIO +from multiprocessing.pool import ThreadPool as Pool + +import imageio +import json +import numpy as np +import oss2 as oss +import requests +import skvideo.io +import torch +import torch.nn.functional as F +import torchvision.utils as tvutils +from einops import rearrange +from PIL import Image + +__all__ = [ + 'parse_oss_url', 'parse_bucket', 'read', 'read_image', 'read_gzip', + 'ceil_divide', 'to_device', 'put_object', 'put_torch_object', + 'put_object_from_file', 'get_object', 'get_object_to_file', 'rand_name', + 'save_image', 'save_video', 'save_video_vs_conditions', + 'save_video_multiple_conditions_with_data', + 'save_video_multiple_conditions', 'download_video_to_file', + 'save_video_grid_mp4', 'save_caps', 'ema', 'parallel', 'exists', + 'download', 'unzip', 'load_state_dict', 'inverse_indices', + 'detect_duplicates', 'read_tfs', 'md5', 'rope', 'format_state', + 'breakup_grid', 'huggingface_tokenizer', 'huggingface_model' +] + +TFS_CLIENT = None + + +def DOWNLOAD_TO_CACHE(oss_key, + file_or_dirname=None, + cache_dir=osp.join( + '/'.join(osp.abspath(__file__).split('/')[:-2]), + 'model_weights')): + r"""Download OSS [file or folder] to the cache folder. + Only the 0th process on each node will run the downloading. + Barrier all processes until the downloading is completed. 
+ """ + # source and target paths + base_path = osp.join(cache_dir, file_or_dirname or osp.basename(oss_key)) + + return base_path + + +def find_free_port(): + """https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number""" + import socket + from contextlib import closing + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(('', 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return str(s.getsockname()[1]) + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +def parse_oss_url(path): + if path.startswith('oss://'): + path = path[len('oss://'):] + + # configs + configs = { + 'endpoint': os.getenv('OSS_ENDPOINT', None), + 'accessKeyID': os.getenv('OSS_ACCESS_KEY_ID', None), + 'accessKeySecret': os.getenv('OSS_ACCESS_KEY_SECRET', None), + 'securityToken': os.getenv('OSS_SECURITY_TOKEN', None) + } + bucket, path = path.split('/', maxsplit=1) + if '?' 
in bucket: + bucket, config = bucket.split('?', maxsplit=1) + for pair in config.split('&'): + k, v = pair.split('=', maxsplit=1) + configs[k] = v + + # session + session = parse_oss_url._sessions.setdefault(f'{bucket}@{os.getpid()}', + oss.Session()) + + # bucket + bucket = oss.Bucket( + auth=oss.Auth(configs['accessKeyID'], configs['accessKeySecret']), + endpoint=configs['endpoint'], + bucket_name=bucket, + session=session) + return bucket, path + + +parse_oss_url._sessions = {} + + +def parse_bucket(url): + return parse_oss_url(osp.join(url, '_placeholder'))[0] + + +def read(filename, mode='r', retry=5): + assert mode in ['r', 'rb'] + exception = None + for _ in range(retry): + try: + if filename.startswith('oss://'): + bucket, path = parse_oss_url(filename) + content = bucket.get_object(path).read() + if mode == 'r': + content = content.decode('utf-8') + elif filename.startswith('http'): + content = requests.get(filename).content + if mode == 'r': + content = content.decode('utf-8') + else: + with open(filename, mode=mode) as f: + content = f.read() + return content + except Exception as e: + exception = e + continue + else: + raise exception + + +def read_image(filename, retry=5): + exception = None + for _ in range(retry): + try: + return Image.open(BytesIO(read(filename, mode='rb', retry=retry))) + except Exception as e: + exception = e + continue + else: + raise exception + + +def download_video_to_file(filename, local_file, retry=5): + exception = None + for _ in range(retry): + try: + bucket, path = parse_oss_url(filename) + bucket.get_object_to_file(path, local_file) + break + except Exception as e: + exception = e + continue + else: + raise exception + + +def read_gzip(filename, retry=5): + exception = None + for _ in range(retry): + try: + remove = False + if filename.startswith('oss://'): + bucket, path = parse_oss_url(filename) + filename = rand_name(suffix=osp.splitext(filename)[1]) + bucket.get_object_to_file(path, filename) + remove = True + with 
gzip.open(filename) as f: + content = f.read() + if remove: + os.remove(filename) + return content + except Exception as e: + exception = e + continue + else: + raise exception + + +def ceil_divide(a, b): + return int(math.ceil(a / b)) + + +def to_device(batch, device, non_blocking=False): + if isinstance(batch, (list, tuple)): + return type(batch)([to_device(u, device, non_blocking) for u in batch]) + elif isinstance(batch, dict): + return type(batch)([(k, to_device(v, device, non_blocking)) + for k, v in batch.items()]) + elif isinstance(batch, torch.Tensor) and batch.device != device: + batch = batch.to(device, non_blocking=non_blocking) + return batch + + +def put_object(bucket, oss_key, data, retry=5): + exception = None + for _ in range(retry): + try: + return bucket.put_object(oss_key, data) + except Exception as e: + exception = e + continue + else: + print( + f'put_object to {oss_key} failed with error: {exception}', + flush=True) + + +def put_torch_object(bucket, oss_key, data, retry=5): + exception = None + for _ in range(retry): + try: + buffer = BytesIO() + torch.save(data, buffer) + return bucket.put_object(oss_key, buffer.getvalue()) + except Exception as e: + exception = e + continue + else: + print( + f'put_torch_object to {oss_key} failed with error: {exception}', + flush=True) + + +def put_object_from_file(bucket, oss_key, filename, retry=5): + exception = None + for _ in range(retry): + try: + return bucket.put_object_from_file(oss_key, filename) + except Exception as e: + exception = e + continue + else: + print( + f'put_object_from_file to {oss_key} failed with error: {exception}', + flush=True) + + +def get_object(bucket, oss_key, retry=5): + exception = None + for _ in range(retry): + try: + return bucket.get_object(oss_key).read() + except Exception as e: + exception = e + continue + else: + print( + f'get_object from {oss_key} failed with error: {exception}', + flush=True) + + +def get_object_to_file(bucket, oss_key, filename, retry=5): + 
exception = None + for _ in range(retry): + try: + return bucket.get_object_to_file(oss_key, filename) + except Exception as e: + exception = e + continue + else: + print( + f'get_object_to_file from {oss_key} failed with error: {exception}', + flush=True) + + +def rand_name(length=8, suffix=''): + name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') + if suffix: + if not suffix.startswith('.'): + suffix = '.' + suffix + name += suffix + return name + + +@torch.no_grad() +def save_image(bucket, + oss_key, + tensor, + nrow=8, + normalize=True, + range=(-1, 1), + retry=5): + filename = rand_name(suffix='.jpg') + for _ in [None] * retry: + try: + tvutils.save_image( + tensor, filename, nrow=nrow, normalize=normalize, range=range) + bucket.put_object_from_file(oss_key, filename) + exception = None + break + except Exception as e: + exception = e + continue + + # remove temporary file + if osp.exists(filename): + os.remove(filename) + if exception is not None: + print( + 'save image to {} failed, error: {}'.format(oss_key, exception), + flush=True) + + +@torch.no_grad() +def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True): + tensor = tensor.permute(1, 2, 3, 0) + images = tensor.unbind(dim=0) + images = [(image.numpy() * 255).astype('uint8') for image in images] + imageio.mimwrite(path, images, fps=8) + return images + + +@torch.no_grad() +def save_video(bucket, + oss_key, + tensor, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + nrow=8, + retry=5): + mean = torch.tensor(mean, device=tensor.device).view(1, -1, 1, 1, 1) + std = torch.tensor(std, device=tensor.device).view(1, -1, 1, 1, 1) + tensor = tensor.mul_(std).add_(mean) + tensor.clamp_(0, 1) + + filename = rand_name(suffix='.gif') + for _ in [None] * retry: + try: + one_gif = rearrange( + tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + video_tensor_to_gif(one_gif, filename) + bucket.put_object_from_file(oss_key, filename) + exception = None + break + except Exception as e: + 
exception = e + continue + + # remove temporary file + if osp.exists(filename): + os.remove(filename) + if exception is not None: + print( + 'save video to {} failed, error: {}'.format(oss_key, exception), + flush=True) + + +@torch.no_grad() +def save_video_multiple_conditions(oss_key, + video_tensor, + model_kwargs, + source_imgs, + palette, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + nrow=8, + retry=5, + save_origin_video=True, + bucket=None): + mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1) + video_tensor = video_tensor.mul_(std).add_(mean) + try: + video_tensor.clamp_(0, 1) + except Exception as e: + video_tensor = video_tensor.float().clamp_(0, 1) + print(e) + video_tensor = video_tensor.cpu() + + b, c, n, h, w = video_tensor.shape + source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w)) + source_imgs = source_imgs.cpu() + + model_kwargs_channel3 = {} + for key, conditions in model_kwargs[0].items(): + if conditions.shape[-1] == 1024: + # Skip for style embeding + continue + if len(conditions.shape) == 3: + conditions_np = conditions.cpu().numpy() + conditions = [] + for i in conditions_np: + vis_i = [] + for j in i: + vis_i.append( + palette.get_palette_image( + j, percentile=90, width=256, height=256)) + conditions.append(np.stack(vis_i)) + conditions = torch.from_numpy(np.stack(conditions)) + conditions = rearrange(conditions, 'b n h w c -> b c n h w') + else: + if conditions.size(1) == 1: + conditions = torch.cat([conditions, conditions, conditions], + dim=1) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + if conditions.size(1) == 2: + conditions = torch.cat([conditions, conditions[:, :1, ]], + dim=1) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + elif conditions.size(1) == 3: + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + elif conditions.size(1) == 4: + color = ((conditions[:, 0:3] + 1.) / 2.) 
+ alpha = conditions[:, 3:4] + conditions = color * alpha + 1.0 * (1.0 - alpha) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + model_kwargs_channel3[key] = conditions.cpu( + ) if conditions.is_cuda else conditions + + filename = oss_key + for _ in [None] * retry: + try: + vid_gif = rearrange( + video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + cons_list = [ + rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + for _, con in model_kwargs_channel3.items() + ] + source_imgs = rearrange( + source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + + if save_origin_video: + vid_gif = torch.cat( + [ + source_imgs, + ] + cons_list + [ + vid_gif, + ], dim=3) + else: + vid_gif = torch.cat( + cons_list + [ + vid_gif, + ], dim=3) + + video_tensor_to_gif(vid_gif, filename) + exception = None + break + except Exception as e: + exception = e + continue + if exception is not None: + logging.info('save video to {} failed, error: {}'.format( + oss_key, exception)) + + +@torch.no_grad() +def save_video_multiple_conditions_with_data(bucket, + video_save_key, + gt_video_save_key, + vis_oss_key, + video_tensor, + model_kwargs, + source_imgs, + palette, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + nrow=8, + retry=5): + mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1) + video_tensor = video_tensor.mul_(std).add_(mean) + video_tensor.clamp_(0, 1) + + b, c, n, h, w = video_tensor.shape + source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w)) + source_imgs = source_imgs.cpu() + + model_kwargs_channel3 = {} + for key, conditions in model_kwargs[0].items(): + if len(conditions.shape) == 3: + conditions_np = conditions.cpu().numpy() + conditions = [] + for i in conditions_np: + vis_i = [] + for j in i: + vis_i.append( + palette.get_palette_image( + j, percentile=90, width=256, height=256)) + conditions.append(np.stack(vis_i)) + conditions = 
torch.from_numpy(np.stack(conditions)) + conditions = rearrange(conditions, 'b n h w c -> b c n h w') + else: + if conditions.size(1) == 1: + conditions = torch.cat([conditions, conditions, conditions], + dim=1) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + if conditions.size(1) == 2: + conditions = torch.cat([conditions, conditions[:, :1, ]], + dim=1) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + elif conditions.size(1) == 3: + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + elif conditions.size(1) == 4: + color = ((conditions[:, 0:3] + 1.) / 2.) + alpha = conditions[:, 3:4] + conditions = color * alpha + 1.0 * (1.0 - alpha) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + model_kwargs_channel3[key] = conditions.cpu( + ) if conditions.is_cuda else conditions + + copy_video_tensor = video_tensor.clone() + copy_source_imgs = source_imgs.clone() + + filename = rand_name(suffix='.gif') + for _ in [None] * retry: + try: + vid_gif = rearrange( + video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + cons_list = [ + rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + for _, con in model_kwargs_channel3.items() + ] + source_imgs = rearrange( + source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + vid_gif = torch.cat( + [ + source_imgs, + ] + cons_list + [ + vid_gif, + ], dim=3) + + video_tensor_to_gif(vid_gif, filename) + bucket.put_object_from_file(vis_oss_key, filename) + exception = None + break + except Exception as e: + exception = e + continue + + # remove temporary file + if osp.exists(filename): + os.remove(filename) + + filename_pred = rand_name(suffix='.pkl') + for _ in [None] * retry: + try: + copy_video_np = (copy_video_tensor.numpy() * 255).astype('uint8') + pickle.dump(copy_video_np, open(filename_pred, 'wb')) + bucket.put_object_from_file(video_save_key, filename_pred) + break + except Exception as e: + print('error! 
', video_save_key, e) + continue + + # remove temporary file + if osp.exists(filename_pred): + os.remove(filename_pred) + + filename_gt = rand_name(suffix='.pkl') + for _ in [None] * retry: + try: + copy_source_np = (copy_source_imgs.numpy() * 255).astype('uint8') + pickle.dump(copy_source_np, open(filename_gt, 'wb')) + bucket.put_object_from_file(gt_video_save_key, filename_gt) + break + except Exception as e: + print('error! ', gt_video_save_key, e) + continue + + # remove temporary file + if osp.exists(filename_gt): + os.remove(filename_gt) + + if exception is not None: + print( + 'save video to {} failed, error: {}'.format( + vis_oss_key, exception), + flush=True) + + +@torch.no_grad() +def save_video_vs_conditions(bucket, + oss_key, + video_tensor, + conditions, + source_imgs, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + nrow=8, + retry=5): + mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1) + video_tensor = video_tensor.mul_(std).add_(mean) + video_tensor.clamp_(0, 1) + + b, c, n, h, w = video_tensor.shape + source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w)) + source_imgs = source_imgs.cpu() + + if conditions.size(1) == 1: + conditions = torch.cat([conditions, conditions, conditions], dim=1) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + + filename = rand_name(suffix='.gif') + for _ in [None] * retry: + try: + vid_gif = rearrange( + video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + con_gif = rearrange( + conditions, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + source_imgs = rearrange( + source_imgs, '(i j) c f h w -> c f (i h) (j w)', i=nrow) + vid_gif = torch.cat([vid_gif, con_gif, source_imgs], dim=2) + + video_tensor_to_gif(vid_gif, filename) + bucket.put_object_from_file(oss_key, filename) + exception = None + break + except Exception as e: + exception = e + continue + + # remove temporary file + if 
osp.exists(filename): + os.remove(filename) + if exception is not None: + print( + 'save video to {} failed, error: {}'.format(oss_key, exception), + flush=True) + + +@torch.no_grad() +def save_video_grid_mp4(bucket, + oss_key, + tensor, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + nrow=None, + fps=5, + retry=5): + mean = torch.tensor(mean, device=tensor.device).view(1, -1, 1, 1, 1) + std = torch.tensor(std, device=tensor.device).view(1, -1, 1, 1, 1) + tensor = tensor.mul_(std).add_(mean) + tensor.clamp_(0, 1) + b, c, t, h, w = tensor.shape + tensor = tensor.permute(0, 2, 3, 4, 1) + tensor = (tensor.cpu().numpy() * 255).astype('uint8') + + filename = rand_name(suffix='.mp4') + for _ in [None] * retry: + try: + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), + dtype='uint8') + for i in range(b): + r = i // ncol + c_ = i % ncol + + start_r = (padding + h) * r + start_c = (padding + w) * c_ + video_grid[:, start_r:start_r + h, + start_c:start_c + w] = tensor[i] + skvideo.io.vwrite(filename, video_grid, inputdict={'-r': str(fps)}) + + bucket.put_object_from_file(oss_key, filename) + exception = None + break + except Exception as e: + exception = e + continue + + # remove temporary file + if osp.exists(filename): + os.remove(filename) + if exception is not None: + print( + 'save video to {} failed, error: {}'.format(oss_key, exception), + flush=True) + + +@torch.no_grad() +def save_text(bucket, oss_key, tensor, nrow=8, retry=5): + len = tensor.shape[0] + num_per_row = int(len / nrow) + assert (len == nrow * num_per_row) + texts = '' + for i in range(nrow): + for j in range(num_per_row): + text = dec_bytes2obj(tensor[i * num_per_row + j]) + texts += text + '\n' + texts += '\n' + + for _ in [None] * retry: + try: + bucket.put_object(oss_key, texts) + exception = None + break + except Exception as e: + exception = e + continue + 
if exception is not None: + print( + 'save video to {} failed, error: {}'.format(oss_key, exception), + flush=True) + + +@torch.no_grad() +def save_caps(bucket, oss_key, caps, retry=5): + texts = '' + for cap in caps: + texts += cap + texts += '\n' + + for _ in [None] * retry: + try: + bucket.put_object(oss_key, texts) + exception = None + break + except Exception as e: + exception = e + continue + if exception is not None: + print( + 'save video to {} failed, error: {}'.format(oss_key, exception), + flush=True) + + +@torch.no_grad() +def ema(net_ema, net, beta, copy_buffer=False): + assert 0.0 <= beta <= 1.0 + for p_ema, p in zip(net_ema.parameters(), net.parameters()): + p_ema.copy_(p.lerp(p_ema, beta)) + if copy_buffer: + for b_ema, b in zip(net_ema.buffers(), net.buffers()): + b_ema.copy_(b) + + +def parallel(func, args_list, num_workers=32, timeout=None): + assert isinstance(args_list, list) + if not isinstance(args_list[0], tuple): + args_list = [(args, ) for args in args_list] + if num_workers == 0: + return [func(*args) for args in args_list] + with Pool(processes=num_workers) as pool: + results = [pool.apply_async(func, args) for args in args_list] + results = [res.get(timeout=timeout) for res in results] + return results + + +def exists(filename): + if filename.startswith('oss://'): + bucket, path = parse_oss_url(filename) + return bucket.object_exists(path) + else: + return osp.exists(filename) + + +def download(url, filename=None, replace=False, quiet=False): + if filename is None: + filename = osp.basename(url) + if not osp.exists(filename) or replace: + try: + if url.startswith('oss://'): + bucket, oss_key = parse_oss_url(url) + bucket.get_object_to_file(oss_key, filename) + else: + urllib.request.urlretrieve(url, filename) + if not quiet: + print(f'Downloaded {url} to {filename}', flush=True) + except Exception as e: + raise ValueError(f'Downloading {filename} failed with error {e}') + return osp.abspath(filename) + + +def unzip(filename, 
dst_dir=None): + if dst_dir is None: + dst_dir = osp.dirname(filename) + with zipfile.ZipFile(filename, 'r') as zip_ref: + zip_ref.extractall(dst_dir) + + +def load_state_dict(module, state_dict, drop_prefix=''): + # find incompatible key-vals + src, dst = state_dict, module.state_dict() + if drop_prefix: + src = type(src)([ + (k[len(drop_prefix):] if k.startswith(drop_prefix) else k, v) + for k, v in src.items() + ]) + missing = [k for k in dst if k not in src] + unexpected = [k for k in src if k not in dst] + unmatched = [ + k for k in src.keys() & dst.keys() if src[k].shape != dst[k].shape + ] + + # keep only compatible key-vals + incompatible = set(unexpected + unmatched) + src = type(src)([(k, v) for k, v in src.items() if k not in incompatible]) + module.load_state_dict(src, strict=False) + + # report incompatible key-vals + if len(missing) != 0: + print(' Missing: ' + ', '.join(missing), flush=True) + if len(unexpected) != 0: + print(' Unexpected: ' + ', '.join(unexpected), flush=True) + if len(unmatched) != 0: + print(' Shape unmatched: ' + ', '.join(unmatched), flush=True) + + +def inverse_indices(indices): + r"""Inverse map of indices. + E.g., if A[indices] == B, then B[inv_indices] == A. 
+ """ + inv_indices = torch.empty_like(indices) + inv_indices[indices] = torch.arange(len(indices)).to(indices) + return inv_indices + + +def detect_duplicates(feats, thr=0.9): + assert feats.ndim == 2 + + # compute simmat + feats = F.normalize(feats, p=2, dim=1) + simmat = torch.mm(feats, feats.T) + simmat.triu_(1) + torch.cuda.synchronize() + + # detect duplicates + mask = ~simmat.gt(thr).any(dim=0) + return torch.where(mask)[0] + + +class TFSClient(object): + + def __init__(self, + host='restful-store.vip.tbsite.net:3800', + app_key='5354c9fae75f5'): + self.host = host + self.app_key = app_key + + # candidate servers + self.servers = [ + u for u in read(f'http://{host}/url.list').strip().split('\n')[1:] + if ':' in u + ] + assert len(self.servers) >= 1 + self.__server_id = -1 + + @property + def server(self): + self.__server_id = (self.__server_id + 1) % len(self.servers) + return self.servers[self.__server_id] + + def read(self, tfs): + tfs = osp.basename(tfs) + meta = json.loads( + read( + f'http://{self.server}/v1/{self.app_key}/metadata/{tfs}?force=0' + )) + img = Image.open( + BytesIO( + read( + f'http://{self.server}/v1/{self.app_key}/{tfs}?offset=0&size={meta["SIZE"]}', + 'rb'))) + return img + + +def read_tfs(tfs, retry=5): + exception = None + for _ in range(retry): + try: + global TFS_CLIENT + if TFS_CLIENT is None: + TFS_CLIENT = TFSClient() + return TFS_CLIENT.read(tfs) + except Exception as e: + exception = e + continue + else: + raise exception + + +def md5(filename): + with open(filename, 'rb') as f: + return hashlib.md5(f.read()).hexdigest() + + +def rope(x): + r"""Apply rotary position embedding on x of shape [B, *(spatial dimensions), C]. 
+ """ + # reshape + shape = x.shape + x = x.view(x.size(0), -1, x.size(-1)) + l, c = x.shape[-2:] + assert c % 2 == 0 + half = c // 2 + + # apply rotary position embedding on x + sinusoid = torch.outer( + torch.arange(l).to(x), + torch.pow(10000, -torch.arange(half).to(x).div(half))) + sin, cos = torch.sin(sinusoid), torch.cos(sinusoid) + x1, x2 = x.chunk(2, dim=-1) + x = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + + # reshape back + return x.view(shape) + + +def format_state(state, filename=None): + r"""For comparing/aligning state_dict. + """ + content = '\n'.join([f'{k}\t{tuple(v.shape)}' for k, v in state.items()]) + if filename: + with open(filename, 'w') as f: + f.write(content) + + +def breakup_grid(img, grid_size): + r"""The inverse operator of ``torchvision.utils.make_grid``. + """ + # params + nrow = img.height // grid_size + ncol = img.width // grid_size + wrow = wcol = 2 + + # collect grids + grids = [] + for i in range(nrow): + for j in range(ncol): + x1 = j * grid_size + (j + 1) * wcol + y1 = i * grid_size + (i + 1) * wrow + grids.append(img.crop((x1, y1, x1 + grid_size, y1 + grid_size))) + return grids + + +def huggingface_tokenizer(name='google/mt5-xxl', **kwargs): + from transformers import AutoTokenizer + return AutoTokenizer.from_pretrained( + DOWNLOAD_TO_CACHE(f'huggingface/tokenizers/{name}', name), **kwargs) + + +def huggingface_model(name='google/mt5-xxl', model_type='AutoModel', **kwargs): + import transformers + return getattr(transformers, model_type).from_pretrained( + DOWNLOAD_TO_CACHE(f'huggingface/models/{name}', name), **kwargs) diff --git a/modelscope/models/multi_modal/videocomposer/videocomposer_model.py b/modelscope/models/multi_modal/videocomposer/videocomposer_model.py new file mode 100644 index 00000000..3e2a910c --- /dev/null +++ b/modelscope/models/multi_modal/videocomposer/videocomposer_model.py @@ -0,0 +1,480 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import os +from copy import copy, deepcopy +from os import path as osp +from typing import Any, Dict + +import open_clip +import pynvml +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from einops import rearrange + +import modelscope.models.multi_modal.videocomposer.models as models +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.videocomposer.annotator.sketch import ( + pidinet_bsd, sketch_simplification_gan) +from modelscope.models.multi_modal.videocomposer.autoencoder import \ + AutoencoderKL +from modelscope.models.multi_modal.videocomposer.clip import ( + FrozenOpenCLIPEmbedder, FrozenOpenCLIPVisualEmbedder) +from modelscope.models.multi_modal.videocomposer.diffusion import ( + GaussianDiffusion, beta_schedule) +from modelscope.models.multi_modal.videocomposer.ops.utils import ( + get_first_stage_encoding, make_masked_images, prepare_model_kwargs, + save_with_model_kwargs) +from modelscope.models.multi_modal.videocomposer.unet_sd import UNetSD_temporal +from modelscope.models.multi_modal.videocomposer.utils.config import Config +from modelscope.models.multi_modal.videocomposer.utils.utils import ( + find_free_port, setup_seed, to_device) +from modelscope.outputs import OutputKeys +from modelscope.preprocessors.image import load_image +from modelscope.utils.constant import ModelFile, Tasks +from .config import cfg + +__all__ = ['VideoComposer'] + + +@MODELS.register_module( + Tasks.text_to_video_synthesis, module_name=Models.videocomposer) +class VideoComposer(TorchModel): + r""" + task for video composer. + + Attributes: + sd_model: denosing model using in this task. + diffusion: diffusion model for DDIM. + autoencoder: decode the latent representation into visual space with VQGAN. + clip_encoder: encode the text into text embedding. 
+ """ + + def __init__(self, model_dir, *args, **kwargs): + r""" + Args: + model_dir (`str` or `os.PathLike`) + Can be either: + - A string, the *model id* of a pretrained model hosted inside a model repo on modelscope + or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`, + or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g, + `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to + `True`. 
+ """ + super().__init__(model_dir=model_dir, *args, **kwargs) + self.device = torch.device('cuda') if torch.cuda.is_available() \ + else torch.device('cpu') + clip_checkpoint = kwargs.pop('clip_checkpoint', + 'open_clip_pytorch_model.bin') + sd_checkpoint = kwargs.pop('sd_checkpoint', 'v2-1_512-ema-pruned.ckpt') + _cfg = Config(load=True) + cfg.update(_cfg.cfg_dict) + # rank-wise params + l1 = len(cfg.frame_lens) + l2 = len(cfg.feature_framerates) + cfg.max_frames = cfg.frame_lens[0 % (l1 * l2) // l2] + cfg.batch_size = cfg.batch_sizes[str(cfg.max_frames)] + # Copy update input parameter to current task + self.cfg = cfg + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = find_free_port() + self.cfg.pmi_rank = int(os.getenv('RANK', 0)) + self.cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) + setup_seed(self.cfg.seed) + self.read_image = kwargs.pop('read_image', False) + self.read_style = kwargs.pop('read_style', True) + self.read_sketch = kwargs.pop('read_sketch', False) + self.save_origin_video = kwargs.pop('save_origin_video', True) + self.video_compositions = kwargs.pop('video_compositions', [ + 'text', 'mask', 'depthmap', 'sketch', 'motion', 'image', + 'local_image', 'single_sketch' + ]) + self.viz_num = self.cfg.batch_size + self.clip_encoder = FrozenOpenCLIPEmbedder( + layer='penultimate', + pretrained=os.path.join(model_dir, clip_checkpoint)) + self.clip_encoder = self.clip_encoder.to(self.device) + self.clip_encoder_visual = FrozenOpenCLIPVisualEmbedder( + layer='penultimate', + pretrained=os.path.join(model_dir, clip_checkpoint)) + self.clip_encoder_visual.model.to(self.device) + ddconfig = { + 'double_z': True, + 'z_channels': 4, + 'resolution': 256, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0 + } + self.autoencoder = AutoencoderKL( + ddconfig, 4, ckpt_path=os.path.join(model_dir, sd_checkpoint)) + 
self.zero_y = self.clip_encoder('').detach() + black_image_feature = self.clip_encoder_visual( + self.clip_encoder_visual.black_image).unsqueeze(1) + black_image_feature = torch.zeros_like(black_image_feature) + self.autoencoder.eval() + for param in self.autoencoder.parameters(): + param.requires_grad = False + self.autoencoder.cuda() + self.model = UNetSD_temporal( + cfg=self.cfg, + in_dim=self.cfg.unet_in_dim, + concat_dim=self.cfg.unet_concat_dim, + dim=self.cfg.unet_dim, + y_dim=self.cfg.unet_y_dim, + context_dim=self.cfg.unet_context_dim, + out_dim=self.cfg.unet_out_dim, + dim_mult=self.cfg.unet_dim_mult, + num_heads=self.cfg.unet_num_heads, + head_dim=self.cfg.unet_head_dim, + num_res_blocks=self.cfg.unet_res_blocks, + attn_scales=self.cfg.unet_attn_scales, + dropout=self.cfg.unet_dropout, + temporal_attention=self.cfg.temporal_attention, + temporal_attn_times=self.cfg.temporal_attn_times, + use_checkpoint=self.cfg.use_checkpoint, + use_fps_condition=self.cfg.use_fps_condition, + use_sim_mask=self.cfg.use_sim_mask, + video_compositions=self.cfg.video_compositions, + misc_dropout=self.cfg.misc_dropout, + p_all_zero=self.cfg.p_all_zero, + p_all_keep=self.cfg.p_all_zero, + zero_y=self.zero_y, + black_image_feature=black_image_feature, + ).to(self.device) + + # Load checkpoint + if self.cfg.resume and self.cfg.resume_checkpoint: + if hasattr(self.cfg, 'text_to_video_pretrain' + ) and self.cfg.text_to_video_pretrain: + checkpoint_name = cfg.resume_checkpoint.split('/')[-1] + ss = torch.load( + os.path.join(self.model_dir, cfg.resume_checkpoint)) + ss = { + key: p + for key, p in ss.items() if 'input_blocks.0.0' not in key + } + self.model.load_state_dict(ss, strict=False) + else: + checkpoint_name = cfg.resume_checkpoint.split('/')[-1] + self.model.load_state_dict( + torch.load( + os.path.join(self.model_dir, checkpoint_name), + map_location='cpu'), + strict=False) + + torch.cuda.empty_cache() + else: + raise ValueError( + f'The checkpoint file 
{self.cfg.resume_checkpoint} is wrong ') + + # diffusion + betas = beta_schedule( + 'linear_sd', + self.cfg.num_timesteps, + init_beta=0.00085, + last_beta=0.0120) + self.diffusion = GaussianDiffusion( + betas=betas, + mean_type=self.cfg.mean_type, + var_type=self.cfg.var_type, + loss_type=self.cfg.loss_type, + rescale_timesteps=False) + + def forward(self, input: Dict[str, Any]): + frame_in = None + if self.read_image: + image_key = input['style_image'] + frame = load_image(image_key) + frame_in = misc_transforms([frame]) + + frame_sketch = None + if self.read_sketch: + sketch_key = self.cfg.sketch_path + frame_sketch = load_image(sketch_key) + frame_sketch = misc_transforms([frame_sketch]) + + frame_style = None + if self.read_style: + frame_style = load_image(input['style_image']) + + # Generators for various conditions + if 'depthmap' in self.video_compositions: + midas = models.midas_v3( + pretrained=True, + model_dir=self.model_dir).eval().requires_grad_(False).to( + memory_format=torch.channels_last).half().to(self.device) + if 'canny' in self.video_compositions: + canny_detector = CannyDetector() + if 'sketch' in self.video_compositions: + pidinet = pidinet_bsd( + self.model_dir, pretrained=True, + vanilla_cnn=True).eval().requires_grad_(False).to(self.device) + cleaner = sketch_simplification_gan( + self.model_dir, + pretrained=True).eval().requires_grad_(False).to(self.device) + pidi_mean = torch.tensor(self.cfg.sketch_mean).view( + 1, -1, 1, 1).to(self.device) + pidi_std = torch.tensor(self.cfg.sketch_std).view(1, -1, 1, 1).to( + self.device) + # Placeholder for color inference + palette = None + + self.model.eval() + caps = input['cap_txt'] + if self.cfg.max_frames == 1 and self.cfg.use_image_dataset: + ref_imgs = input['ref_frame'] + video_data = input['video_data'] + misc_data = input['misc_data'] + mask = input['mask'] + mv_data = input['mv_data'] + fps = torch.tensor( + [self.cfg.feature_framerate] * self.cfg.batch_size, + dtype=torch.long, + 
device=self.device) + else: + ref_imgs = input['ref_frame'] + video_data = input['video_data'] + misc_data = input['misc_data'] + mask = input['mask'] + mv_data = input['mv_data'] + # add fps test + fps = torch.tensor( + [self.cfg.feature_framerate] * self.cfg.batch_size, + dtype=torch.long, + device=self.device) + + # save for visualization + misc_backups = copy(misc_data) + misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w') + mv_data_video = [] + if 'motion' in self.cfg.video_compositions: + mv_data_video = rearrange(mv_data, 'b f c h w -> b c f h w') + + # mask images + masked_video = [] + if 'mask' in self.cfg.video_compositions: + masked_video = make_masked_images( + misc_data.sub(0.5).div_(0.5), mask) + masked_video = rearrange(masked_video, 'b f c h w -> b c f h w') + + # Single Image + image_local = [] + if 'local_image' in self.cfg.video_compositions: + frames_num = misc_data.shape[1] + bs_vd_local = misc_data.shape[0] + if self.cfg.read_image: + image_local = frame_in.unsqueeze(0).repeat( + bs_vd_local, frames_num, 1, 1, 1).cuda() + else: + image_local = misc_data[:, :1].clone().repeat( + 1, frames_num, 1, 1, 1) + image_local = rearrange( + image_local, 'b f c h w -> b c f h w', b=bs_vd_local) + + # encode the video_data + bs_vd = video_data.shape[0] + video_data = rearrange(video_data, 'b f c h w -> (b f) c h w') + misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w') + + video_data_list = torch.chunk( + video_data, video_data.shape[0] // self.cfg.chunk_size, dim=0) + misc_data_list = torch.chunk( + misc_data, misc_data.shape[0] // self.cfg.chunk_size, dim=0) + + with torch.no_grad(): + decode_data = [] + for vd_data in video_data_list: + encoder_posterior = self.autoencoder.encode(vd_data) + tmp = get_first_stage_encoding(encoder_posterior).detach() + decode_data.append(tmp) + video_data = torch.cat(decode_data, dim=0) + video_data = rearrange( + video_data, '(b f) c h w -> b c f h w', b=bs_vd) + + depth_data = [] + if 'depthmap' in 
self.cfg.video_compositions: + for misc_imgs in misc_data_list: + depth = midas( + misc_imgs.sub(0.5).div_(0.5).to( + memory_format=torch.channels_last).half()) + depth = (depth / self.cfg.depth_std).clamp_( + 0, self.cfg.depth_clamp) + depth_data.append(depth) + depth_data = torch.cat(depth_data, dim=0) + depth_data = rearrange( + depth_data, '(b f) c h w -> b c f h w', b=bs_vd) + + canny_data = [] + if 'canny' in self.cfg.video_compositions: + for misc_imgs in misc_data_list: + misc_imgs = rearrange(misc_imgs.clone(), + 'k c h w -> k h w c') + canny_condition = torch.stack( + [canny_detector(misc_img) for misc_img in misc_imgs]) + canny_condition = rearrange(canny_condition, + 'k h w c-> k c h w') + canny_data.append(canny_condition) + canny_data = torch.cat(canny_data, dim=0) + canny_data = rearrange( + canny_data, '(b f) c h w -> b c f h w', b=bs_vd) + + sketch_data = [] + if 'sketch' in self.cfg.video_compositions: + sketch_list = misc_data_list + if self.cfg.read_sketch: + sketch_repeat = frame_sketch.repeat(frames_num, 1, 1, + 1).cuda() + sketch_list = [sketch_repeat] + + for misc_imgs in sketch_list: + sketch = pidinet(misc_imgs.sub(pidi_mean).div_(pidi_std)) + sketch = 1.0 - cleaner(1.0 - sketch) + sketch_data.append(sketch) + sketch_data = torch.cat(sketch_data, dim=0) + sketch_data = rearrange( + sketch_data, '(b f) c h w -> b c f h w', b=bs_vd) + + single_sketch_data = [] + if 'single_sketch' in self.cfg.video_compositions: + single_sketch_data = sketch_data.clone()[:, :, :1].repeat( + 1, 1, frames_num, 1, 1) + + # preprocess for input text descripts + y = self.clip_encoder(caps).detach() + y0 = y.clone() + + y_visual = [] + if 'image' in self.cfg.video_compositions: + with torch.no_grad(): + if self.cfg.read_style: + y_visual = self.clip_encoder_visual( + self.clip_encoder_visual.preprocess( + frame_style).unsqueeze(0).cuda()).unsqueeze(0) + y_visual0 = y_visual.clone() + else: + ref_imgs = ref_imgs.squeeze(1) + y_visual = 
self.clip_encoder_visual(ref_imgs).unsqueeze(1) + y_visual0 = y_visual.clone() + + with torch.no_grad(): + # Log memory + pynvml.nvmlInit() + # Sample images (DDIM) + with amp.autocast(enabled=self.cfg.use_fp16): + if self.cfg.share_noise: + b, c, f, h, w = video_data.shape + noise = torch.randn((self.viz_num, c, h, w), + device=self.device) + noise = noise.repeat_interleave(repeats=f, dim=0) + noise = rearrange( + noise, '(b f) c h w->b c f h w', b=self.viz_num) + noise = noise.contiguous() + else: + noise = torch.randn_like(video_data[:self.viz_num]) + + full_model_kwargs = [{ + 'y': + y0[:self.viz_num], + 'local_image': + None + if len(image_local) == 0 else image_local[:self.viz_num], + 'image': + None if len(y_visual) == 0 else y_visual0[:self.viz_num], + 'depth': + None + if len(depth_data) == 0 else depth_data[:self.viz_num], + 'canny': + None + if len(canny_data) == 0 else canny_data[:self.viz_num], + 'sketch': + None + if len(sketch_data) == 0 else sketch_data[:self.viz_num], + 'masked': + None + if len(masked_video) == 0 else masked_video[:self.viz_num], + 'motion': + None if len(mv_data_video) == 0 else + mv_data_video[:self.viz_num], + 'single_sketch': + None if len(single_sketch_data) == 0 else + single_sketch_data[:self.viz_num], + 'fps': + fps[:self.viz_num] + }, { + 'y': + self.zero_y.repeat(self.viz_num, 1, 1) + if not self.cfg.use_fps_condition else + torch.zeros_like(y0)[:self.viz_num], + 'local_image': + None + if len(image_local) == 0 else image_local[:self.viz_num], + 'image': + None if len(y_visual) == 0 else torch.zeros_like( + y_visual0[:self.viz_num]), + 'depth': + None + if len(depth_data) == 0 else depth_data[:self.viz_num], + 'canny': + None + if len(canny_data) == 0 else canny_data[:self.viz_num], + 'sketch': + None + if len(sketch_data) == 0 else sketch_data[:self.viz_num], + 'masked': + None + if len(masked_video) == 0 else masked_video[:self.viz_num], + 'motion': + None if len(mv_data_video) == 0 else + mv_data_video[:self.viz_num], 
+ 'single_sketch': + None if len(single_sketch_data) == 0 else + single_sketch_data[:self.viz_num], + 'fps': + fps[:self.viz_num] + }] + + # Save generated videos + partial_keys = self.cfg.guidances + noise_motion = noise.clone() + model_kwargs = prepare_model_kwargs( + partial_keys=partial_keys, + full_model_kwargs=full_model_kwargs, + use_fps_condition=self.cfg.use_fps_condition) + video_output = self.diffusion.ddim_sample_loop( + noise=noise_motion, + model=self.model.eval(), + model_kwargs=model_kwargs, + guide_scale=9.0, + ddim_timesteps=self.cfg.ddim_timesteps, + eta=0.0) + + save_with_model_kwargs( + model_kwargs=model_kwargs, + video_data=video_output, + autoencoder=self.autoencoder, + ori_video=misc_backups, + viz_num=self.viz_num, + step=0, + caps=caps, + palette=palette, + cfg=self.cfg) + + return { + 'video': video_output.type(torch.float32).cpu(), + 'video_path': self.cfg + } diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py index 11c28fdd..fd147513 100644 --- a/modelscope/pipelines/multi_modal/__init__.py +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from .soonet_video_temporal_grounding_pipeline import SOONetVideoTemporalGroundingPipeline from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline from .multimodal_dialogue_pipeline import MultimodalDialoguePipeline + from .videocomposer_pipeline import VideoComposerPipeline else: _import_structure = { 'image_captioning_pipeline': ['ImageCaptioningPipeline'], @@ -46,7 +47,8 @@ else: 'soonet_video_temporal_grounding_pipeline': ['SOONetVideoTemporalGroundingPipeline'], 'text_to_video_synthesis_pipeline': ['TextToVideoSynthesisPipeline'], - 'multimodal_dialogue_pipeline': ['MultimodalDialoguePipeline'] + 'multimodal_dialogue_pipeline': ['MultimodalDialoguePipeline'], + 'videocomposer_pipeline': ['VideoComposerPipeline'] } import sys diff --git 
# File: modelscope/pipelines/multi_modal/videocomposer_pipeline.py
# (added as a new file by this patch; index 00000000..539884fb)
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import random
import subprocess
import tempfile
import time
from functools import partial
from typing import Any, Dict

import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as T
from mvextractor.videocap import VideoCap
from PIL import Image

import modelscope.models.multi_modal.videocomposer.data as data
from modelscope.metainfo import Pipelines
from modelscope.models.multi_modal.videocomposer.data.transforms import (
    CenterCropV3, random_resize)
from modelscope.models.multi_modal.videocomposer.ops.random_mask import (
    make_irregular_mask, make_rectangle_mask, make_uncrop)
from modelscope.models.multi_modal.videocomposer.utils.utils import rand_name
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.device import device_placement
from modelscope.utils.logger import get_logger

logger = get_logger()


def _draw_motion_vectors(frame, motion_vectors):
    """Draw one red arrow per motion vector onto ``frame`` and return it.

    Columns 3:5 are used as the arrow start and 5:7 as the arrow end.
    NOTE(review): this follows the mv-extractor column layout
    (src_x, src_y, dst_x, dst_y) -- confirm against the installed version.
    """
    for mv in motion_vectors:
        # OpenCV requires plain Python ints for point coordinates.
        start_pt = (int(mv[3]), int(mv[4]))
        end_pt = (int(mv[5]), int(mv[6]))
        cv2.arrowedLine(frame, start_pt, end_pt, (0, 0, 255), 1,
                        cv2.LINE_AA, 0, 0.1)
    return frame


@PIPELINES.register_module(
    Tasks.text_to_video_synthesis, module_name=Pipelines.videocomposer)
class VideoComposerPipeline(Pipeline):
    r""" Video Composer Pipeline.

    Examples:

    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.utils.constant import Tasks
    >>> pipe = pipeline(
    ...     task=Tasks.text_to_video_synthesis,
    ...     model='buptwq/videocomposer',
    ...     model_revision='v1.0.1')
    >>> inputs = {'Video:FILE': 'path/input_video.mp4',
    ...           'Image:FILE': 'path/input_image.png',
    ...           'text': 'the text description'}
    >>> output = pipe(inputs)
    """

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a videocomposer pipeline for prediction
        Args:
            model: model id on modelscope hub.
            **kwargs: optional overrides, each popped with a default:
                log_dir, feature_framerate, frame_lens, feature_framerates,
                batch_sizes, resolution, image_resolution, mean, std,
                vit_image_size, vit_mean, vit_std, misc_size, visual_mv,
                max_words, mvs_visual.
        """
        super().__init__(model=model)
        self.log_dir = kwargs.pop('log_dir', './video_outputs')
        # exist_ok avoids the check-then-create race.
        os.makedirs(self.log_dir, exist_ok=True)
        self.feature_framerate = kwargs.pop('feature_framerate', 4)
        self.frame_lens = kwargs.pop('frame_lens', [16, 16, 16, 16])
        self.feature_framerates = kwargs.pop('feature_framerates', [4])
        self.batch_sizes = kwargs.pop('batch_sizes', {
            '1': 1,
            '4': 1,
            '8': 1,
            '16': 1,
        })
        l1 = len(self.frame_lens)
        l2 = len(self.feature_framerates)
        # The index expression evaluates to 0, i.e. the first frame length.
        self.max_frames = self.frame_lens[0 % (l1 * l2) // l2]
        # Fall back to 1 (the value of every default entry) when a custom
        # frame length has no explicit batch size.
        self.batch_size = self.batch_sizes.get(str(self.max_frames), 1)
        self.resolution = kwargs.pop('resolution', 256)
        self.image_resolution = kwargs.pop('image_resolution', 256)
        self.mean = kwargs.pop('mean', [0.5, 0.5, 0.5])
        self.std = kwargs.pop('std', [0.5, 0.5, 0.5])
        self.vit_image_size = kwargs.pop('vit_image_size', 224)
        self.vit_mean = kwargs.pop('vit_mean',
                                   [0.48145466, 0.4578275, 0.40821073])
        self.vit_std = kwargs.pop('vit_std',
                                  [0.26862954, 0.26130258, 0.27577711])
        # Was keyed on the literal string 'kwargs.pop', so the value could
        # never be overridden by callers.
        self.misc_size = kwargs.pop('misc_size', 384)
        self.visual_mv = kwargs.pop('visual_mv', False)
        self.max_words = kwargs.pop('max_words', 1000)
        # This flag (not `visual_mv`) is the one preprocess() forwards.
        self.mvs_visual = kwargs.pop('mvs_visual', False)

        self.infer_trans = data.Compose([
            data.CenterCropV2(size=self.resolution),
            data.ToTensor(),
            data.Normalize(mean=self.mean, std=self.std)
        ])

        self.misc_transforms = data.Compose([
            T.Lambda(partial(random_resize, size=self.misc_size)),
            data.CenterCropV2(self.misc_size),
            data.ToTensor()
        ])

        self.mv_transforms = data.Compose(
            [T.Resize(size=self.resolution),
             T.CenterCrop(self.resolution)])

        self.vit_transforms = T.Compose([
            CenterCropV3(self.vit_image_size),
            T.ToTensor(),
            T.Normalize(mean=self.vit_mean, std=self.vit_std)
        ])

    def _blank_inputs(self):
        """Return zero tensors (ref_frame, video, misc, mv) used as fallback."""
        ref_frame = torch.zeros(3, self.vit_image_size, self.vit_image_size)
        video_data = torch.zeros(self.max_frames, 3, self.image_resolution,
                                 self.image_resolution)
        misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
                                self.misc_size)
        mv_data = torch.zeros(self.max_frames, 2, self.image_resolution,
                              self.image_resolution)
        return ref_frame, video_data, misc_data, mv_data

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Build the model input dict from a video path, caption and image.

        Args:
            input: mapping with the keys 'Video:FILE', 'text' and
                'Image:FILE'.

        Returns:
            Dict with 'ref_frame', 'cap_txt', 'video_data', 'misc_data',
            'feature_framerate', 'mask', 'mv_data' and 'style_image'.
            Tensors get a leading batch dimension of size 1. If the video
            is missing or cannot be decoded, zero tensors are used instead.
        """
        video_key = input['Video:FILE']
        cap_txt = input['text']
        style_image = input['Image:FILE']

        feature_framerate = self.feature_framerate
        # Guard falsy paths: os.path.exists(None) raises TypeError.
        if video_key and os.path.exists(video_key):
            try:
                ref_frame, _, video_data, misc_data, mv_data = \
                    self.video_data_preprocess(
                        video_key, self.feature_framerate, None,
                        self.mvs_visual)
            except Exception as e:
                # Best-effort by design: fall back to blank conditions.
                logger.info('{} get frames failed... with error: {}'.format(
                    video_key, e))
                ref_frame, video_data, misc_data, mv_data = \
                    self._blank_inputs()
        else:
            logger.info(
                'The video path does not exist or no video dir provided!')
            ref_frame, video_data, misc_data, mv_data = self._blank_inputs()

        # inpainting mask
        p = random.random()
        if p < 0.7:
            mask = make_irregular_mask(512, 512)
        elif p < 0.9:
            mask = make_rectangle_mask(512, 512)
        else:
            mask = make_uncrop(512, 512)
        mask = torch.from_numpy(
            cv2.resize(
                mask, (self.misc_size, self.misc_size),
                interpolation=cv2.INTER_NEAREST)).unsqueeze(0).float()

        # Same mask for every frame.
        mask = mask.unsqueeze(0).repeat_interleave(
            repeats=self.max_frames, dim=0)
        video_input = {
            'ref_frame': ref_frame.unsqueeze(0),
            'cap_txt': cap_txt,
            'video_data': video_data.unsqueeze(0),
            'misc_data': misc_data.unsqueeze(0),
            'feature_framerate': feature_framerate,
            'mask': mask.unsqueeze(0),
            'mv_data': mv_data.unsqueeze(0),
            'style_image': style_image
        }
        return video_input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the wrapped model on the preprocessed input."""
        return self.model(input)

    def postprocess(self, inputs: Dict[str, Any],
                    **post_params) -> Dict[str, Any]:
        """Select the pipeline output from the model result.

        Args:
            inputs: model output holding the keys 'video' and 'video_path'.
            **post_params: when 'output_video' is given (not None) the
                value of inputs['video_path'] is returned, otherwise
                inputs['video'].

        Returns:
            Dict with the single key OutputKeys.OUTPUT_VIDEO.
        """
        # NOTE(review): the 'output_video' value itself is never used as a
        # destination path; only its presence selects the returned entry.
        # The temp file the old code created here was never used, so it is
        # no longer created.
        if post_params.get('output_video', None) is not None:
            return {OutputKeys.OUTPUT_VIDEO: inputs['video_path']}
        return {OutputKeys.OUTPUT_VIDEO: inputs['video']}

    def video_data_preprocess(self, video_key, feature_framerate, total_frames,
                              visual_mv):
        """Decode a clip of `self.max_frames` frames plus motion vectors.

        Args:
            video_key: path of the input video.
            feature_framerate: target sampling fps handed to ffmpeg.
            total_frames: ignored; recomputed from the decoded video.
            visual_mv: if True, also save a gif visualising motion vectors
                under `<log_dir>/visual_mv/`.

        Returns:
            Tuple (ref_frame, vit_image, video_data, misc_data, mv_data);
            the first two entries are the same ViT-normalised tensor of the
            middle frame.

        Raises:
            RuntimeError: if extraction fails five times, or no I-frame is
                followed by at least `self.max_frames` frames.
        """
        last_error = None
        for _ in range(5):
            try:
                frame_types, frames, mvs, mvs_visual = \
                    self.extract_motion_vectors(
                        input_video=video_key,
                        fps=feature_framerate,
                        visual_mv=visual_mv)
                break
            except Exception as e:
                last_error = e
                logger.error(
                    '{} read video frames and motion vectors failed with '
                    'error: {}'.format(video_key, e))
        else:
            # Every attempt failed: report it instead of hitting a NameError.
            raise RuntimeError(
                'failed to extract motion vectors from {}'.format(
                    video_key)) from last_error

        total_frames = len(frame_types)
        # A clip must start on an I-frame and have max_frames frames left.
        start_indexs = np.where((np.array(frame_types) == 'I') & (
            total_frames - np.arange(total_frames) >= self.max_frames))[0]
        if len(start_indexs) == 0:
            raise RuntimeError(
                '{} has no I-frame followed by {} frames'.format(
                    video_key, self.max_frames))
        start_index = np.random.choice(start_indexs)
        indices = np.arange(start_index, start_index + self.max_frames)

        # note frames are in BGR mode, need to trans to RGB mode
        frames = [Image.fromarray(frames[i][:, :, ::-1]) for i in indices]
        mvs = torch.stack(
            [torch.from_numpy(mvs[i].transpose((2, 0, 1))) for i in indices])

        if visual_mv:
            images = [(mvs_visual[i][:, :, ::-1]).astype('uint8')
                      for i in indices]
            visual_dir = os.path.join(self.log_dir, 'visual_mv')
            os.makedirs(visual_dir, exist_ok=True)
            path = os.path.join(visual_dir,
                                video_key.split('/')[-1] + '.gif')
            logger.info('save motion vectors visualization to : %s', path)
            imageio.mimwrite(path, images, fps=8)

        video_data = torch.zeros(self.max_frames, 3, self.image_resolution,
                                 self.image_resolution)
        mv_data = torch.zeros(self.max_frames, 2, self.image_resolution,
                              self.image_resolution)
        misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
                                self.misc_size)

        if len(frames) > 0:
            ref_frame = frames[int(len(frames) / 2)]
            vit_image = self.vit_transforms(ref_frame)
            # Resize+crop on PIL images first, then convert to tensor.
            misc_imgs_np = self.misc_transforms[:2](frames)
            misc_imgs = self.misc_transforms[2:](misc_imgs_np)
            frames = self.infer_trans(frames)
            mvs = self.mv_transforms(mvs)
            video_data[:len(frames), ...] = frames
            misc_data[:len(frames), ...] = misc_imgs
            mv_data[:len(frames), ...] = mvs
        else:
            vit_image = torch.zeros(3, self.vit_image_size,
                                    self.vit_image_size)

        ref_frame = vit_image
        return ref_frame, vit_image, video_data, misc_data, mv_data

    def extract_motion_vectors(self,
                               input_video,
                               fps=4,
                               dump=False,
                               verbose=False,
                               visual_mv=False):
        """Re-encode `input_video` with ffmpeg and read frames + motion vectors.

        Args:
            input_video: path of the source video.
            fps: requested sampling rate; raised when the video is too
                short to yield more than `self.max_frames` frames.
            dump: if True, write each visualisation frame to './mv_visual/'.
            verbose: if True, log per-frame progress and mean read time.
            visual_mv: if True, draw motion vectors into the visualisation
                frames (otherwise they stay black).

        Returns:
            Tuple (frame_types, frames, mvs, mvs_visual) of per-frame lists.

        Raises:
            RuntimeError: if the video metadata is unusable or the
                re-encoded file cannot be opened.
        """
        dump_dir = './mv_visual/'
        if dump:
            os.makedirs(dump_dir, exist_ok=True)

        videocapture = cv2.VideoCapture(input_video)
        try:
            frames_num = videocapture.get(cv2.CAP_PROP_FRAME_COUNT)
            fps_video = videocapture.get(cv2.CAP_PROP_FPS)
        finally:
            videocapture.release()
        if frames_num <= 0 or fps_video <= 0:
            raise RuntimeError(
                f'Could not read frame count / fps of {input_video}')

        # check if enough frames
        duration = frames_num / fps_video
        if duration * fps > self.max_frames:
            fps = max(fps, 1)
        else:
            fps = int(self.max_frames / duration) + 1

        # Keep the scratch file in the system temp dir: deriving it from
        # the first path component breaks for bare filenames.
        tmp_video = os.path.join(
            tempfile.gettempdir(),
            f'{rand_name()}{os.path.basename(input_video)}')
        # Argument list (no shell) so odd characters in the path are safe.
        ffmpeg_cmd = [
            'ffmpeg', '-threads', '8', '-loglevel', 'error', '-i',
            input_video, '-filter:v', f'fps={fps}', '-c:v', 'mpeg4', '-f',
            'rawvideo', tmp_video
        ]

        if os.path.exists(tmp_video):
            os.remove(tmp_video)

        times = []
        frame_types = []
        frames = []
        mvs = []
        mvs_visual = []
        try:
            subprocess.run(args=ffmpeg_cmd, timeout=120)

            cap = VideoCap()
            # open the video file
            if not cap.open(tmp_video):
                raise RuntimeError(f'Could not open {tmp_video}')

            try:
                step = 0
                # continuously read video frames and motion vectors
                while True:
                    if verbose:
                        logger.info('Frame: %s', step)

                    tstart = time.perf_counter()
                    # read next video frame and its motion vectors
                    ret, frame, motion_vectors, frame_type, _ = cap.read()
                    times.append(time.perf_counter() - tstart)

                    # if there is an error reading the frame
                    if not ret:
                        if verbose:
                            logger.warning('No frame read. Stopping.')
                        break

                    frame_save = np.zeros(frame.shape, dtype=np.uint8)
                    if visual_mv:
                        frame_save = _draw_motion_vectors(
                            frame_save, motion_vectors)

                    h, w = frame.shape[:2]
                    # Centre-crop along the longer side. Explicit end
                    # indices keep square frames intact (x[0:-0] is empty).
                    if w >= h:
                        w_half = (w - h) // 2
                        cropped = frame_save[:, w_half:w - w_half]
                    else:
                        h_half = (h - w) // 2
                        cropped = frame_save[h_half:h - h_half, :]
                    if dump:
                        cv2.imwrite(
                            os.path.join(dump_dir, f'frame-{step}.jpg'),
                            cropped)
                    mvs_visual.append(cropped)

                    # Scatter each vector into a dense (h, w, 2) field at
                    # its clipped position taken from columns 5:7.
                    mv = np.zeros((h, w, 2))
                    position = motion_vectors[:, 5:7].clip((0, 0),
                                                           (w - 1, h - 1))
                    mv[position[:, 1], position[:, 0]] = (
                        motion_vectors[:, 0:1] * motion_vectors[:, 7:9]
                        / motion_vectors[:, 9:])

                    step += 1
                    frame_types.append(frame_type)
                    frames.append(frame)
                    mvs.append(mv)
            finally:
                cap.release()
        finally:
            # Always clean up the scratch file, also on failure.
            if os.path.exists(tmp_video):
                os.remove(tmp_video)

        if verbose:
            logger.info('average dt: %s', np.mean(times))

        return frame_types, frames, mvs, mvs_visual


# ---------------------------------------------------------------------------
# Remaining patch hunks that shared the last original line:
#   modelscope/preprocessors/multi_modal.py (index 5edc8c48..d180289b):
#       adds one blank line after the copyright header.
#   tests/pipelines/test_videocomposer.py (new file, index 00000000..4cbca237,
#       @@ -0,0 +1,38 @@), whose header line is:
# Copyright (c) Alibaba, Inc. and its affiliates.
import sys
import unittest

from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import DownloadMode, Tasks
from modelscope.utils.test_utils import test_level


# NOTE(review): the class name looks copied from a deinterlace test; it is
# kept unchanged here so existing test selectors keep working.
class VideoDeinterlaceTest(unittest.TestCase):
    """Smoke test that runs the VideoComposer pipeline on one dataset row."""

    def setUp(self) -> None:
        self.task = Tasks.text_to_video_synthesis
        self.model_id = 'buptwq/videocomposer'
        self.model_revision = 'v1.0.1'
        self.dataset_id = 'buptwq/videocomposer-depths-style'
        # Implicit concatenation instead of a backslash continuation inside
        # the literal, which embedded the source indentation in the prompt.
        self.text = ('A glittering and translucent fish swimming in a '
                     'small glass bowl with multicolored piece of stone, '
                     'like a glass fish')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_pipeline(self):
        pipe = pipeline(
            task=self.task,
            model=self.model_id,
            model_revision=self.model_revision)
        ds = MsDataset.load(
            self.dataset_id,
            split='train',
            download_mode=DownloadMode.FORCE_REDOWNLOAD)
        inputs = next(iter(ds))
        inputs.update({'text': self.text})
        output = pipe(inputs)
        # The pipeline's postprocess returns its result under this key.
        self.assertIn(OutputKeys.OUTPUT_VIDEO, output)


if __name__ == '__main__':
    unittest.main()