diff --git a/.dev_scripts/dockerci.sh b/.dev_scripts/dockerci.sh index e06fb101..b4332f39 100644 --- a/.dev_scripts/dockerci.sh +++ b/.dev_scripts/dockerci.sh @@ -4,11 +4,15 @@ CODE_DIR=$PWD CODE_DIR_IN_CONTAINER=/Maas-lib echo "$USER" gpus='0,1 2,3 4,5 6,7' -cpu_sets='45-58 31-44 16-30 0-15' +cpu_sets='0-15 16-31 32-47 48-63' cpu_sets_arr=($cpu_sets) is_get_file_lock=false CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh python tests/run.py --parallel 2 --run_config tests/run_config.yaml} echo "ci command: $CI_COMMAND" +PR_CHANGED_FILES="${PR_CHANGED_FILES:-''}" +echo "PR modified files: $PR_CHANGED_FILES" +PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#} +echo "PR_CHANGED_FILES: $PR_CHANGED_FILES" idx=0 for gpu in $gpus do @@ -42,6 +46,7 @@ do -e MODELSCOPE_ENVIRONMENT='ci' \ -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ -e MODEL_TAG_URL=$MODEL_TAG_URL \ + -e PR_CHANGED_FILES=$PR_CHANGED_FILES \ --workdir=$CODE_DIR_IN_CONTAINER \ ${IMAGE_NAME}:${IMAGE_VERSION} \ $CI_COMMAND @@ -64,6 +69,7 @@ do -e MODELSCOPE_ENVIRONMENT='ci' \ -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \ -e MODEL_TAG_URL=$MODEL_TAG_URL \ + -e PR_CHANGED_FILES=$PR_CHANGED_FILES \ --workdir=$CODE_DIR_IN_CONTAINER \ ${IMAGE_NAME}:${IMAGE_VERSION} \ $CI_COMMAND diff --git a/.github/workflows/citest.yaml b/.github/workflows/citest.yaml index 1ff78a65..8060f0bb 100644 --- a/.github/workflows/citest.yaml +++ b/.github/workflows/citest.yaml @@ -39,7 +39,7 @@ concurrency: jobs: unittest: # The type of runner that the job will run on - runs-on: [modelscope-self-hosted] + runs-on: [modelscope-self-hosted-us] timeout-minutes: 240 steps: - name: ResetFileMode @@ -52,10 +52,19 @@ jobs: sudo chown -R $USER:$USER $ACTION_RUNNER_DIR - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: lfs: 'true' submodules: 'true' + fetch-depth: ${{ github.event_name == 'pull_request' && 2 || 0 }} + - name: Get changed files + id: changed-files + run: | + if ${{ github.event_name == 
'pull_request' }}; then + echo "PR_CHANGED_FILES=$(git diff --name-only -r HEAD^1 HEAD | xargs)" >> $GITHUB_ENV + else + echo "PR_CHANGED_FILES=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | xargs)" >> $GITHUB_ENV + fi - name: Checkout LFS objects run: git lfs checkout - name: Run unittest diff --git a/.github/workflows/daily_regression.yaml b/.github/workflows/daily_regression.yaml index 0500b61c..85ca5e0b 100644 --- a/.github/workflows/daily_regression.yaml +++ b/.github/workflows/daily_regression.yaml @@ -12,7 +12,7 @@ concurrency: jobs: unittest: # The type of runner that the job will run on - runs-on: [modelscope-self-hosted] + runs-on: [modelscope-self-hosted-us] steps: - name: ResetFileMode shell: bash diff --git a/examples/pytorch/stable_diffusion/custom/finetune_stable_diffusion_custom.py b/examples/pytorch/stable_diffusion/custom/finetune_stable_diffusion_custom.py index 007ea82b..76a050c4 100644 --- a/examples/pytorch/stable_diffusion/custom/finetune_stable_diffusion_custom.py +++ b/examples/pytorch/stable_diffusion/custom/finetune_stable_diffusion_custom.py @@ -2,6 +2,7 @@ import os from dataclasses import dataclass, field import cv2 +import torch from modelscope.metainfo import Trainers from modelscope.msdatasets import MsDataset @@ -95,6 +96,12 @@ class StableDiffusionCustomArguments(TrainingArgs): 'help': 'Path to json containing multiple concepts.', }) + torch_type: str = field( + default='float32', + metadata={ + 'help': ' The torch type, default is float32.', + }) + training_args = StableDiffusionCustomArguments( task='text-to-image-synthesis').parse_cli() @@ -148,6 +155,8 @@ kwargs = dict( work_dir=training_args.work_dir, train_dataset=train_dataset, eval_dataset=validation_dataset, + torch_type=torch.float16 + if args.torch_type == 'float16' else torch.float32, cfg_modify_fn=cfg_modify_fn) # build trainer and training @@ -159,7 +168,7 @@ pipe = pipeline( task=Tasks.text_to_image_synthesis, 
model=training_args.model, custom_dir=training_args.work_dir + '/output', - modifier_token='+', + modifier_token=args.modifier_token, model_revision=args.model_revision) output = pipe({'text': args.instance_prompt}) diff --git a/examples/pytorch/stable_diffusion/custom/run_train_custom.sh b/examples/pytorch/stable_diffusion/custom/run_train_custom.sh index fab8e059..7f9cb500 100644 --- a/examples/pytorch/stable_diffusion/custom/run_train_custom.sh +++ b/examples/pytorch/stable_diffusion/custom/run_train_custom.sh @@ -7,11 +7,12 @@ PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/custom/finetune_stable_d --class_data_dir './tmp/class_data' \ --train_dataset_name 'buptwq/lora-stable-diffusion-finetune-dog' \ --max_epochs 250 \ - --modifier_token "+" \ + --modifier_token "" \ --num_class_images=200 \ --save_ckpt_strategy 'by_epoch' \ --logging_interval 1 \ --train.dataloader.workers_per_gpu 0 \ --evaluation.dataloader.workers_per_gpu 0 \ --train.optimizer.lr 1e-5 \ + --torch_type 'float32' \ --use_model_config true diff --git a/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py b/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py index 760396d0..5659a105 100644 --- a/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py +++ b/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py @@ -2,6 +2,7 @@ import os from dataclasses import dataclass, field import cv2 +import torch from modelscope.metainfo import Trainers from modelscope.msdatasets import MsDataset @@ -59,6 +60,12 @@ class StableDiffusionDreamboothArguments(TrainingArgs): 'help': 'The pipeline prompt.', }) + torch_type: str = field( + default='float32', + metadata={ + 'help': ' The torch type, default is float32.', + }) + training_args = StableDiffusionDreamboothArguments( task='text-to-image-synthesis').parse_cli() @@ -106,6 +113,8 @@ kwargs = dict( resolution=args.resolution, 
prior_loss_weight=args.prior_loss_weight, prompt=args.prompt, + torch_type=torch.float16 + if args.torch_type == 'float16' else torch.float32, cfg_modify_fn=cfg_modify_fn) # build trainer and training diff --git a/examples/pytorch/stable_diffusion/dreambooth/run_train_dreambooth.sh b/examples/pytorch/stable_diffusion/dreambooth/run_train_dreambooth.sh index 03c51ffb..461434ee 100644 --- a/examples/pytorch/stable_diffusion/dreambooth/run_train_dreambooth.sh +++ b/examples/pytorch/stable_diffusion/dreambooth/run_train_dreambooth.sh @@ -17,4 +17,5 @@ PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/dreambooth/finetune_stab --train.dataloader.workers_per_gpu 0 \ --evaluation.dataloader.workers_per_gpu 0 \ --train.optimizer.lr 5e-6 \ + --torch_type 'float32' \ --use_model_config true diff --git a/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py b/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py index 6001af48..f0c40be7 100644 --- a/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py +++ b/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py @@ -2,6 +2,7 @@ import os from dataclasses import dataclass, field import cv2 +import torch from modelscope.metainfo import Trainers from modelscope.msdatasets import MsDataset @@ -25,6 +26,12 @@ class StableDiffusionLoraArguments(TrainingArgs): 'help': 'The rank size of lora intermediate linear.', }) + torch_type: str = field( + default='float32', + metadata={ + 'help': ' The torch type, default is float32.', + }) + training_args = StableDiffusionLoraArguments( task='text-to-image-synthesis').parse_cli() @@ -66,6 +73,8 @@ kwargs = dict( train_dataset=train_dataset, eval_dataset=validation_dataset, lora_rank=args.lora_rank, + torch_type=torch.float16 + if args.torch_type == 'float16' else torch.float32, cfg_modify_fn=cfg_modify_fn) # build trainer and training diff --git a/examples/pytorch/stable_diffusion/lora/run_train_lora.sh 
b/examples/pytorch/stable_diffusion/lora/run_train_lora.sh index d4f1b07d..82e31aad 100644 --- a/examples/pytorch/stable_diffusion/lora/run_train_lora.sh +++ b/examples/pytorch/stable_diffusion/lora/run_train_lora.sh @@ -5,10 +5,11 @@ PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/lora/finetune_stable_dif --work_dir './tmp/lora_diffusion' \ --train_dataset_name 'buptwq/lora-stable-diffusion-finetune' \ --max_epochs 100 \ - --lora_rank 4 \ + --lora_rank 16 \ --save_ckpt_strategy 'by_epoch' \ --logging_interval 1 \ --train.dataloader.workers_per_gpu 0 \ --evaluation.dataloader.workers_per_gpu 0 \ --train.optimizer.lr 1e-4 \ + --torch_type 'float16' \ --use_model_config true diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index a5856e85..630d4aa5 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -220,6 +220,8 @@ class Models(object): stable_diffusion = 'stable-diffusion' videocomposer = 'videocomposer' text_to_360panorama_image = 'text-to-360panorama-image' + image_to_video_model = 'image-to-video-model' + video_to_video_model = 'video-to-video-model' # science models unifold = 'unifold' @@ -235,6 +237,7 @@ class TaskModels(object): feature_extraction = 'feature-extraction' text_generation = 'text-generation' text_ranking = 'text-ranking' + machine_reading_comprehension = 'machine-reading-comprehension' class Heads(object): @@ -487,6 +490,7 @@ class Pipelines(object): document_grounded_dialog_rerank = 'document-grounded-dialog-rerank' document_grounded_dialog_generate = 'document-grounded-dialog-generate' language_identification = 'language_identification' + machine_reading_comprehension_for_ner = 'machine-reading-comprehension-for-ner' # audio tasks sambert_hifigan_tts = 'sambert-hifigan-tts' @@ -543,6 +547,8 @@ class Pipelines(object): efficient_diffusion_tuning = 'efficient-diffusion-tuning' multimodal_dialogue = 'multimodal-dialogue' llama2_text_generation_pipeline = 'llama2-text-generation-pipeline' + 
image_to_video_task_pipeline = 'image-to-video-task-pipeline' + video_to_video_pipeline = 'video-to-video-pipeline' # science tasks protein_structure = 'unifold-protein-structure' @@ -1062,6 +1068,7 @@ class Preprocessors(object): document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval' document_grounded_dialog_rerank = 'document-grounded-dialog-rerank' document_grounded_dialog_generate = 'document-grounded-dialog-generate' + machine_reading_comprehension_for_ner = 'machine-reading-comprehension-for-ner' # audio preprocessor linear_aec_fbank = 'linear-aec-fbank' diff --git a/modelscope/models/cv/image_face_fusion/facegan/face_gan.py b/modelscope/models/cv/image_face_fusion/facegan/face_gan.py new file mode 100644 index 00000000..80d62997 --- /dev/null +++ b/modelscope/models/cv/image_face_fusion/facegan/face_gan.py @@ -0,0 +1,55 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os + +import numpy as np +import torch +import torch.nn.functional as F + +from .gpen_model import FullGenerator + + +class GPEN(object): + + def __init__(self, + model_path, + size=512, + channel_multiplier=2, + device=torch.device('cpu')): + self.mfile = model_path + self.n_mlp = 8 + self.resolution = size + self.device = device + self.load_model(channel_multiplier) + + def load_model(self, channel_multiplier=2): + self.model = FullGenerator(self.resolution, 512, self.n_mlp, + channel_multiplier).to(self.device) + pretrained_dict = torch.load(self.mfile) + self.model.load_state_dict(pretrained_dict) + self.model.eval() + + def process(self, im): + preds = [] + imt = self.img2tensor(im) + imt = F.interpolate(imt, (self.resolution, self.resolution)) + + with torch.no_grad(): + img_out, __ = self.model(imt) + + face = self.tensor2img(img_out) + + return face, preds + + def img2tensor(self, img): + img_t = torch.from_numpy(img).to(self.device) + img_t = (img_t / 255. 
- 0.5) / 0.5 + img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) # BGR->RGB + return img_t + + def tensor2img(self, image_tensor, pmax=255.0, imtype=np.uint8): + image_tensor = image_tensor * 0.5 + 0.5 + image_tensor = image_tensor.squeeze(0).permute(1, 2, + 0).flip(2) # RGB->BGR + image_numpy = np.clip(image_tensor.float().cpu().numpy(), 0, 1) * pmax + + return image_numpy.astype(imtype) diff --git a/modelscope/models/cv/image_face_fusion/facegan/gan_wrap.py b/modelscope/models/cv/image_face_fusion/facegan/gan_wrap.py deleted file mode 100644 index c46b17eb..00000000 --- a/modelscope/models/cv/image_face_fusion/facegan/gan_wrap.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -import os - -import cv2 -import numpy as np -import torch -import torch.nn.functional as F -from PIL import Image -from torchvision import transforms - -from .model import FullGenerator - - -class GANWrap(object): - - def __init__(self, - model_path, - size=256, - channel_multiplier=1, - device='cpu'): - self.device = device - self.mfile = model_path - self.transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), - inplace=True), - ]) - self.batchSize = 2 - self.n_mlp = 8 - self.resolution = size - self.load_model(channel_multiplier) - - def load_model(self, channel_multiplier=2): - self.model = FullGenerator(self.resolution, 512, self.n_mlp, - channel_multiplier).to(self.device) - pretrained_dict = torch.load( - self.mfile, map_location=torch.device('cpu')) - self.model.load_state_dict(pretrained_dict) - self.model.eval() - - def process_tensor(self, img_t, return_face=True): - b, c, h, w = img_t.shape - img_t = F.interpolate(img_t, (self.resolution, self.resolution)) - - with torch.no_grad(): - out, __ = self.model(img_t) - - out = F.interpolate(out, (w, h)) - return out - - def process(self, ims, return_face=True): - res = [] - faces = [] - for i in range(0, len(ims), self.batchSize): - sizes = 
[] - imt = None - for im in ims[i:i + self.batchSize]: - sizes.append(im.shape[0]) - im = cv2.resize(im, (self.resolution, self.resolution)) - im_pil = Image.fromarray(im) - imt = self.img2tensor(im_pil) if imt is None else torch.cat( - (imt, self.img2tensor(im_pil)), dim=0) - - imt = torch.flip(imt, [1]) - with torch.no_grad(): - img_outs, __ = self.model(imt) - - for sz, img_out in zip(sizes, img_outs): - img = self.tensor2img(img_out) - if return_face: - faces.append(img) - img = cv2.resize(img, (sz, sz), interpolation=cv2.INTER_AREA) - res.append(img) - - return res, faces - - def img2tensor(self, img): - img_t = self.transform(img).to(self.device) - img_t = torch.unsqueeze(img_t, 0) - return img_t - - def tensor2img(self, image_tensor, bytes=255.0, imtype=np.uint8): - if image_tensor.dim() == 3: - image_numpy = image_tensor.cpu().float().numpy() - else: - image_numpy = image_tensor[0].cpu().float().numpy() - image_numpy = np.transpose(image_numpy, (1, 2, 0)) - image_numpy = image_numpy[:, :, ::-1] - image_numpy = np.clip( - image_numpy * np.asarray([0.5, 0.5, 0.5]) - + np.asarray([0.5, 0.5, 0.5]), 0, 1) - image_numpy = image_numpy * bytes - return image_numpy.astype(imtype) diff --git a/modelscope/models/cv/image_face_fusion/facegan/model.py b/modelscope/models/cv/image_face_fusion/facegan/gpen_model.py similarity index 76% rename from modelscope/models/cv/image_face_fusion/facegan/model.py rename to modelscope/models/cv/image_face_fusion/facegan/gpen_model.py index eb142779..22dfb1e1 100644 --- a/modelscope/models/cv/image_face_fusion/facegan/model.py +++ b/modelscope/models/cv/image_face_fusion/facegan/gpen_model.py @@ -1,17 +1,19 @@ -# The implementation is adopted from stylegan2-pytorch, -# made public available under the MIT License at https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py +# The implementation is adopted from InsightFace_Pytorch, made publicly available under the MIT License +# at https://github.com/yangxy/GPEN +import 
functools +import itertools import math +import operator import random import torch from torch import nn +from torch.autograd import Function from torch.nn import functional as F from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d isconcat = True -sss = 2 if isconcat else 1 -ratio = 2 class PixelNorm(nn.Module): @@ -306,6 +308,7 @@ class NoiseInjection(nn.Module): def forward(self, image, noise=None): if noise is not None: + # print(image.shape, noise.shape) if isconcat: return torch.cat((image, self.weight * noise), dim=1) # concat return image + self.weight * noise @@ -356,7 +359,8 @@ class StyledConv(nn.Module): ) self.noise = NoiseInjection() - self.activate = FusedLeakyReLU(out_channel * sss) + feat_multiplier = 2 + self.activate = FusedLeakyReLU(out_channel * feat_multiplier) def forward(self, input, style, noise=None): out = self.conv(input, style) @@ -405,12 +409,14 @@ class Generator(nn.Module): channel_multiplier=2, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, + narrow=1, ): super().__init__() self.size = size self.n_mlp = n_mlp self.style_dim = style_dim + self.feat_multiplier = 2 layers = [PixelNorm()] @@ -425,15 +431,16 @@ class Generator(nn.Module): self.style = nn.Sequential(*layers) self.channels = { - 4: 512 // ratio, - 8: 512 // ratio, - 16: 512 // ratio, - 32: 512 // ratio, - 64: 256 // ratio * channel_multiplier, - 128: 128 // ratio * channel_multiplier, - 256: 64 // ratio * channel_multiplier, - 512: 32 // ratio * channel_multiplier, - 1024: 16 // ratio * channel_multiplier, + 4: int(512 * narrow), + 8: int(512 * narrow), + 16: int(512 * narrow), + 32: int(512 * narrow), + 64: int(256 * channel_multiplier * narrow), + 128: int(128 * channel_multiplier * narrow), + 256: int(64 * channel_multiplier * narrow), + 512: int(32 * channel_multiplier * narrow), + 1024: int(16 * channel_multiplier * narrow), + 2048: int(8 * channel_multiplier * narrow) } self.input = ConstantInput(self.channels[4]) @@ -443,7 +450,8 @@ class Generator(nn.Module): 3, 
style_dim, blur_kernel=blur_kernel) - self.to_rgb1 = ToRGB(self.channels[4] * sss, style_dim, upsample=False) + self.to_rgb1 = ToRGB( + self.channels[4] * self.feat_multiplier, style_dim, upsample=False) self.log_size = int(math.log(size, 2)) @@ -458,23 +466,23 @@ class Generator(nn.Module): self.convs.append( StyledConv( - in_channel * sss, + in_channel * self.feat_multiplier, out_channel, 3, style_dim, upsample=True, - blur_kernel=blur_kernel, - )) + blur_kernel=blur_kernel)) self.convs.append( StyledConv( - out_channel * sss, + out_channel * self.feat_multiplier, out_channel, 3, style_dim, blur_kernel=blur_kernel)) - self.to_rgbs.append(ToRGB(out_channel * sss, style_dim)) + self.to_rgbs.append( + ToRGB(out_channel * self.feat_multiplier, style_dim)) in_channel = out_channel @@ -515,6 +523,9 @@ class Generator(nn.Module): styles = [self.style(s) for s in styles] if noise is None: + ''' + noise = [None] * (2 * (self.log_size - 2) + 1) + ''' noise = [] batch = styles[0].shape[0] for i in range(self.n_mlp + 1): @@ -557,16 +568,14 @@ class Generator(nn.Module): skip = self.to_rgb1(out, latent[:, 1]) i = 1 - noise_i = 1 - - for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2], - self.to_rgbs): - out = conv1(out, latent[:, i], noise=noise[(noise_i + 1) // 2]) - out = conv2(out, latent[:, i + 1], noise=noise[(noise_i + 2) // 2]) + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], + self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) skip = to_rgb(out, latent[:, i + 2], skip) i += 2 - noise_i += 2 image = skip @@ -652,21 +661,106 @@ class ResBlock(nn.Module): return out +class FullGenerator(nn.Module): + + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + narrow=1, + ): + super().__init__() + channels = { + 4: int(512 * narrow), + 8: int(512 * narrow), + 16: int(512 
* narrow), + 32: int(512 * narrow), + 64: int(256 * channel_multiplier * narrow), + 128: int(128 * channel_multiplier * narrow), + 256: int(64 * channel_multiplier * narrow), + 512: int(32 * channel_multiplier * narrow), + 1024: int(16 * channel_multiplier * narrow), + 2048: int(8 * channel_multiplier * narrow) + } + + self.log_size = int(math.log(size, 2)) + self.generator = Generator( + size, + style_dim, + n_mlp, + channel_multiplier=channel_multiplier, + blur_kernel=blur_kernel, + lr_mlp=lr_mlp, + narrow=narrow) + + conv = [ConvLayer(3, channels[size], 1)] + self.ecd0 = nn.Sequential(*conv) + in_channel = channels[size] + + self.names = ['ecd%d' % i for i in range(self.log_size - 1)] + for i in range(self.log_size, 2, -1): + out_channel = channels[2**(i - 1)] + conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)] + setattr(self, self.names[self.log_size - i + 1], + nn.Sequential(*conv)) + in_channel = out_channel + self.final_linear = nn.Sequential( + EqualLinear( + channels[4] * 4 * 4, style_dim, activation='fused_lrelu')) + + def forward( + self, + inputs, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + ): + noise = [] + for i in range(self.log_size - 1): + ecd = getattr(self, self.names[i]) + inputs = ecd(inputs) + noise.append(inputs) + inputs = inputs.view(inputs.shape[0], -1) + outs = self.final_linear(inputs) + noise = list( + itertools.chain.from_iterable( + itertools.repeat(x, 2) for x in noise))[::-1] + outs = self.generator([outs], + return_latents, + inject_index, + truncation, + truncation_latent, + input_is_latent, + noise=noise[1:]) + return outs + + class Discriminator(nn.Module): - def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + def __init__(self, + size, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + narrow=1): super().__init__() channels = { - 4: 512, - 8: 512, - 16: 512, - 32: 512, - 64: 256 * channel_multiplier, - 128: 128 * 
channel_multiplier, - 256: 64 * channel_multiplier, - 512: 32 * channel_multiplier, - 1024: 16 * channel_multiplier, + 4: int(512 * narrow), + 8: int(512 * narrow), + 16: int(512 * narrow), + 32: int(512 * narrow), + 64: int(256 * channel_multiplier * narrow), + 128: int(128 * channel_multiplier * narrow), + 256: int(64 * channel_multiplier * narrow), + 512: int(32 * channel_multiplier * narrow), + 1024: int(16 * channel_multiplier * narrow), + 2048: int(8 * channel_multiplier * narrow) } convs = [ConvLayer(3, channels[size], 1)] @@ -713,48 +807,53 @@ class Discriminator(nn.Module): return out -class FullGenerator(nn.Module): +class FullGenerator_SR(nn.Module): def __init__( self, - size, + in_size, + out_size, style_dim, n_mlp, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, + narrow=1, ): super().__init__() channels = { - 4: 512 // ratio, - 8: 512 // ratio, - 16: 512 // ratio, - 32: 512 // ratio, - 64: 256 // ratio * channel_multiplier, - 128: 128 // ratio * channel_multiplier, - 256: 64 // ratio * channel_multiplier, - 512: 32 // ratio * channel_multiplier, - 1024: 16 // ratio * channel_multiplier, + 4: int(512 * narrow), + 8: int(512 * narrow), + 16: int(512 * narrow), + 32: int(512 * narrow), + 64: int(256 * channel_multiplier * narrow), + 128: int(128 * channel_multiplier * narrow), + 256: int(64 * channel_multiplier * narrow), + 512: int(32 * channel_multiplier * narrow), + 1024: int(16 * channel_multiplier * narrow), + 2048: int(8 * channel_multiplier * narrow), } - self.log_size = int(math.log(size, 2)) + self.log_insize = int(math.log(in_size, 2)) + self.log_outsize = int(math.log(out_size, 2)) self.generator = Generator( - size, + out_size, style_dim, n_mlp, channel_multiplier=channel_multiplier, blur_kernel=blur_kernel, - lr_mlp=lr_mlp) + lr_mlp=lr_mlp, + narrow=narrow) - conv = [ConvLayer(3, channels[size], 1)] + conv = [ConvLayer(3, channels[in_size], 1)] self.ecd0 = nn.Sequential(*conv) - in_channel = channels[size] + in_channel = 
channels[in_size] - self.names = ['ecd%d' % i for i in range(self.log_size - 1)] - for i in range(self.log_size, 2, -1): + self.names = ['ecd%d' % i for i in range(self.log_insize - 1)] + for i in range(self.log_insize, 2, -1): out_channel = channels[2**(i - 1)] conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)] - setattr(self, self.names[self.log_size - i + 1], + setattr(self, self.names[self.log_insize - i + 1], nn.Sequential(*conv)) in_channel = out_channel self.final_linear = nn.Sequential( @@ -771,18 +870,22 @@ class FullGenerator(nn.Module): input_is_latent=False, ): noise = [] - for i in range(self.log_size - 1): + for i in range(self.log_outsize - self.log_insize): + noise.append(None) + for i in range(self.log_insize - 1): ecd = getattr(self, self.names[i]) inputs = ecd(inputs) noise.append(inputs) inputs = inputs.view(inputs.shape[0], -1) outs = self.final_linear(inputs) - outs = self.generator([outs], - return_latents, - inject_index, - truncation, - truncation_latent, - input_is_latent, - noise=noise[::-1]) - - return outs + noise = list( + itertools.chain.from_iterable( + itertools.repeat(x, 2) for x in noise))[::-1] + image, latent = self.generator([outs], + return_latents, + inject_index, + truncation, + truncation_latent, + input_is_latent, + noise=noise[1:]) + return image, latent diff --git a/modelscope/models/cv/image_face_fusion/facegan/op/__init__.py b/modelscope/models/cv/image_face_fusion/facegan/op/__init__.py index 74477cfb..d0918d92 100644 --- a/modelscope/models/cv/image_face_fusion/facegan/op/__init__.py +++ b/modelscope/models/cv/image_face_fusion/facegan/op/__init__.py @@ -1,4 +1,2 @@ -# The implementation is adopted from stylegan2-pytorch, made public available under the MIT License -# at https://github.com/rosinality/stylegan2-pytorch from .fused_act import FusedLeakyReLU, fused_leaky_relu from .upfirdn2d import upfirdn2d diff --git a/modelscope/models/cv/image_face_fusion/facelib/align_trans.py 
b/modelscope/models/cv/image_face_fusion/facelib/align_trans.py index 554b0e7c..0d7ebbb6 100644 --- a/modelscope/models/cv/image_face_fusion/facelib/align_trans.py +++ b/modelscope/models/cv/image_face_fusion/facelib/align_trans.py @@ -15,6 +15,77 @@ REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], DEFAULT_CROP_SIZE = (96, 112) +def _umeyama(src, dst, estimate_scale=True, scale=1.0): + """Estimate N-D similarity transformation with or without scaling. + Parameters + ---------- + src : (M, N) array + Source coordinates. + dst : (M, N) array + Destination coordinates. + estimate_scale : bool + Whether to estimate scaling factor. + Returns + ------- + T : (N + 1, N + 1) + The homogeneous similarity transformation matrix. The matrix contains + NaN values only if the problem is not well-conditioned. + References + ---------- + .. [1] "Least-squares estimation of transformation parameters between two + point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` + """ + + num = src.shape[0] + dim = src.shape[1] + + # Compute mean of src and dst. + src_mean = src.mean(axis=0) + dst_mean = dst.mean(axis=0) + + # Subtract mean from src and dst. + src_demean = src - src_mean + dst_demean = dst - dst_mean + + # Eq. (38). + A = dst_demean.T @ src_demean / num + + # Eq. (39). + d = np.ones((dim, ), dtype=np.double) + if np.linalg.det(A) < 0: + d[dim - 1] = -1 + + T = np.eye(dim + 1, dtype=np.double) + + U, S, V = np.linalg.svd(A) + + # Eq. (40) and (43). + rank = np.linalg.matrix_rank(A) + if rank == 0: + return np.nan * T + elif rank == dim - 1: + if np.linalg.det(U) * np.linalg.det(V) > 0: + T[:dim, :dim] = U @ V + else: + s = d[dim - 1] + d[dim - 1] = -1 + T[:dim, :dim] = U @ np.diag(d) @ V + d[dim - 1] = s + else: + T[:dim, :dim] = U @ np.diag(d) @ V + + if estimate_scale: + # Eq. (41) and (42). 
+ scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) + else: + scale = scale + + T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) + T[:dim, :dim] *= scale + + return T, scale + + class FaceWarpException(Exception): def __str__(self): @@ -246,6 +317,65 @@ def warp_and_crop_face(src_img, return face_img +def warp_and_crop_face_enhance(src_img, + facial_pts, + reference_pts=None, + crop_size=(96, 112), + align_type='smilarity'): + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = False + inner_padding_factor = 0 + outer_padding = (0, 0) + output_size = crop_size + + reference_pts = get_reference_facial_points( + output_size, inner_padding_factor, outer_padding, + default_square) + + ref_pts = np.float32(reference_pts) + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException( + 'reference_pts.shape must be (K,2) or (2,K) and K>2') + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: + raise FaceWarpException( + 'facial_pts.shape must be (K,2) or (2,K) and K>2') + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException( + 'facial_pts and reference_pts must have the same shape') + + if align_type == 'cv2_affine': + tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) + tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3]) + elif align_type == 'affine': + tfm = get_affine_transform_matrix(src_pts, ref_pts) + tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) + else: + params, scale = _umeyama(src_pts, ref_pts) + tfm = params[:2, :] + + params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale) + tfm_inv = params[:2, :] + + face_img = cv2.warpAffine( + src_img, tfm, (crop_size[0], crop_size[1]), flags=3) + + return 
face_img, tfm_inv + + def get_f5p(landmarks, np_img): eye_left = find_pupil(landmarks[36:41], np_img) eye_right = find_pupil(landmarks[42:47], np_img) diff --git a/modelscope/models/cv/image_face_fusion/image_face_fusion.py b/modelscope/models/cv/image_face_fusion/image_face_fusion.py index 24907ceb..b7a876b0 100644 --- a/modelscope/models/cv/image_face_fusion/image_face_fusion.py +++ b/modelscope/models/cv/image_face_fusion/image_face_fusion.py @@ -16,9 +16,10 @@ from modelscope.models.builder import MODELS from modelscope.models.cv.face_detection.peppa_pig_face.facer import FaceAna from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger -from .facegan.gan_wrap import GANWrap +from .facegan.face_gan import GPEN from .facelib.align_trans import (get_f5p, get_reference_facial_points, - warp_and_crop_face) + warp_and_crop_face, + warp_and_crop_face_enhance) from .network.aei_flow_net import AEI_Net from .network.bfm import ParametricFaceModel from .network.facerecon_model import ReconNetWrapper @@ -78,14 +79,6 @@ class ImageFaceFusion(TorchModel): self.face_model = ParametricFaceModel(bfm_folder=bfm_dir) self.face_model.to(self.device) - face_enhance_path = os.path.join(model_dir, 'faceEnhance', - '350000-Ns256.pt') - self.ganwrap = GANWrap( - model_path=face_enhance_path, - size=256, - channel_multiplier=1, - device=self.device) - self.facer = FaceAna(model_dir) logger.info('load facefusion models done') @@ -94,6 +87,27 @@ class ImageFaceFusion(TorchModel): self.mask_init = cv2.resize(self.mask_init, (256, 256)) self.mask = self.image_transform(self.mask_init, is_norm=False) + face_enhance_path = os.path.join(model_dir, 'faceEnhance', + 'GPEN-BFR-1024.pth') + + if not os.path.exists(face_enhance_path): + logger.warning( + 'model path not found, please update the latest model!') + + self.ganwrap_1024 = GPEN(face_enhance_path, 1024, 2, self.device) + + self.mask_enhance = np.zeros((512, 512), np.float32) + 
cv2.rectangle(self.mask_enhance, (26, 26), (486, 486), (1, 1, 1), -1, + cv2.LINE_AA) + self.mask_enhance = cv2.GaussianBlur(self.mask_enhance, (101, 101), 11) + self.mask_enhance = cv2.GaussianBlur(self.mask_enhance, (101, 101), 11) + + default_square = True + inner_padding_factor = 0.25 + outer_padding = (0, 0) + self.reference_5pts_1024 = get_reference_facial_points( + (1024, 1024), inner_padding_factor, outer_padding, default_square) + self.test_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) @@ -157,7 +171,7 @@ class ImageFaceFusion(TorchModel): src_h, src_w, _ = img.shape boxes, landmarks, _ = self.facer.run(img) if boxes.shape[0] == 0: - return None + return None, None, None elif boxes.shape[0] > 1: max_area = 0 max_index = 0 @@ -168,9 +182,14 @@ class ImageFaceFusion(TorchModel): if area > max_area: max_index = i max_area = area - return landmarks[max_index] + + fw = boxes[max_index][2] - boxes[max_index][0] + fh = boxes[max_index][3] - boxes[max_index][1] + return landmarks[max_index], fw, fh else: - return landmarks[0] + fw = boxes[0][2] - boxes[0][0] + fh = boxes[0][3] - boxes[0][1] + return landmarks[0], fw, fh def compute_3d_params(self, Xs, Xt): kp_fuse = {} @@ -198,6 +217,51 @@ class ImageFaceFusion(TorchModel): return kp_fuse, kp_t + def process_enhance(self, im, f5p, fh, fw): + height, width, _ = im.shape + + of, tfm_inv = warp_and_crop_face_enhance( + im, + f5p, + reference_pts=self.reference_5pts_1024, + crop_size=(1024, 1024)) + ef, pred = self.ganwrap_1024.process(of) + + tmp_mask = self.mask_enhance + tmp_mask = cv2.resize(tmp_mask, ef.shape[:2]) + tmp_mask = cv2.warpAffine(tmp_mask, tfm_inv, (width, height), flags=3) + + full_mask = np.zeros((height, width), dtype=np.float32) + full_img = np.zeros(im.shape, dtype=np.uint8) + + if min(fh, fw) < 40: + ef = cv2.pyrDown(ef) + ef = cv2.pyrDown(ef) + ef = cv2.pyrUp(ef) + ef = cv2.pyrUp(ef) + elif min(fh, fw) < 60: + ef = cv2.pyrDown(ef) 
+ ef = cv2.resize(ef, (0, 0), fx=2, fy=2) + ef = cv2.resize(ef, (0, 0), fx=0.5, fy=0.5) + ef = cv2.pyrUp(ef) + elif min(fh, fw) < 80: + ef = cv2.pyrDown(ef) + ef = cv2.pyrUp(ef) + elif min(fh, fw) < 100: + ef = cv2.pyrDown(ef) + ef = cv2.resize(ef, (0, 0), fx=2, fy=2) + + tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3) + + mask = tmp_mask - full_mask + full_mask[np.where(mask > 0)] = tmp_mask[np.where(mask > 0)] + full_img[np.where(mask > 0)] = tmp_img[np.where(mask > 0)] + + full_mask = full_mask[:, :, np.newaxis] + im = cv2.convertScaleAbs(im * (1 - full_mask) + full_img * full_mask) + im = cv2.resize(im, (width, height)) + return im + def inference(self, template_img, user_img): ori_h, ori_w, _ = template_img.shape @@ -205,14 +269,14 @@ class ImageFaceFusion(TorchModel): user_img = user_img.cpu().numpy() user_img_bgr = user_img[:, :, ::-1] - landmark_source = self.detect_face(user_img) + landmark_source, _, _ = self.detect_face(user_img) if landmark_source is None: logger.warning('No face detected in user image!') return template_img f5p_user = get_f5p(landmark_source, user_img_bgr) template_img_bgr = template_img[:, :, ::-1] - landmark_template = self.detect_face(template_img) + landmark_template, fw, fh = self.detect_face(template_img) if landmark_template is None: logger.warning('No face detected in template image!') return template_img @@ -235,7 +299,6 @@ class ImageFaceFusion(TorchModel): with torch.no_grad(): kp_fuse, kp_t = self.compute_3d_params(Xs, Xt) Yt, _, _ = self.netG(Xt, Xs_embeds, kp_fuse, kp_t) - Yt = self.ganwrap.process_tensor(Yt) Yt = Yt * 0.5 + 0.5 Yt = torch.clamp(Yt, 0, 1) @@ -247,6 +310,7 @@ class ImageFaceFusion(TorchModel): 0).cpu().numpy() Yt_trans_inv = Yt_trans_inv.astype(np.float32) out_img = Yt_trans_inv[:, :, ::-1] * 255. 
+ out_img = self.process_enhance(out_img, f5p_template, fh, fw) logger.info('model inference done') diff --git a/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py b/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py index 3260b61a..f253ebbe 100644 --- a/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py +++ b/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py @@ -15,6 +15,7 @@ from diffusers.models import cross_attention from diffusers.utils import deprecation_utils from transformers import CLIPTextModel, CLIPTokenizer +from modelscope import snapshot_download from modelscope.metainfo import Models from modelscope.models import TorchModel from modelscope.models.builder import MODELS @@ -56,7 +57,10 @@ class EfficientStableDiffusion(TorchModel): super().__init__(model_dir, *args, **kwargs) tuner_name = kwargs.pop('tuner_name', 'lora') pretrained_model_name_or_path = kwargs.pop( - 'pretrained_model_name_or_path', 'runwayml/stable-diffusion-v1-5') + 'pretrained_model_name_or_path', + 'AI-ModelScope/stable-diffusion-v1-5') + pretrained_model_name_or_path = snapshot_download( + pretrained_model_name_or_path) tuner_config = kwargs.pop('tuner_config', None) pretrained_tuner = kwargs.pop('pretrained_tuner', None) revision = kwargs.pop('revision', None) diff --git a/modelscope/models/multi_modal/image_to_video/__init__.py b/modelscope/models/multi_modal/image_to_video/__init__.py new file mode 100644 index 00000000..e78958d1 --- /dev/null +++ b/modelscope/models/multi_modal/image_to_video/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .image_to_video_model import ImageToVideo + +else: + _import_structure = { + 'image_to_video_model': ['ImageToVideo'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/multi_modal/image_to_video/image_to_video_model.py b/modelscope/models/multi_modal/image_to_video/image_to_video_model.py new file mode 100755 index 00000000..5c79053c --- /dev/null +++ b/modelscope/models/multi_modal/image_to_video/image_to_video_model.py @@ -0,0 +1,215 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import os.path as osp +import random +from copy import copy +from typing import Any, Dict + +import torch +import torch.cuda.amp as amp + +import modelscope.models.multi_modal.image_to_video.utils.transforms as data +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.image_to_video.modules import * +from modelscope.models.multi_modal.image_to_video.modules import ( + AutoencoderKL, FrozenOpenCLIPVisualEmbedder, Img2VidSDUNet) +from modelscope.models.multi_modal.image_to_video.utils.config import cfg +from modelscope.models.multi_modal.image_to_video.utils.diffusion import \ + GaussianDiffusion +from modelscope.models.multi_modal.image_to_video.utils.seed import setup_seed +from modelscope.models.multi_modal.image_to_video.utils.shedule import \ + beta_schedule +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +__all__ = ['ImageToVideo'] + +logger = get_logger() + + +@MODELS.register_module( + Tasks.image_to_video, module_name=Models.image_to_video_model) +class 
ImageToVideo(TorchModel): + r""" + Image2Video aims to solve the task of generating high-definition videos based on input images. + Image2Video is a video generation basic model developed by Alibaba Cloud, with a parameter size + of approximately 2 billion. It has been pre trained on large-scale video and image data and + fine-tuned on a small amount of high-quality data. The data is widely distributed and diverse + in categories, and the model has good generalization ability for different types of data + + Paper link: https://arxiv.org/abs/2306.02018 + + Attributes: + diffusion: diffusion model for DDIM. + autoencoder: decode the latent representation into visual space. + clip_encoder: encode the image into image embedding. + """ + + def __init__(self, model_dir, *args, **kwargs): + r""" + Args: + model_dir (`str` or `os.PathLike`) + Can be either: + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co + or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`, + or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g, + `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to + `True`. 
+ """ + super().__init__(model_dir=model_dir, *args, **kwargs) + + self.config = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + + # assign default value + cfg.batch_size = self.config.model.model_cfg.batch_size + cfg.target_fps = self.config.model.model_cfg.target_fps + cfg.max_frames = self.config.model.model_cfg.max_frames + cfg.latent_hei = self.config.model.model_cfg.latent_hei + cfg.latent_wid = self.config.model.model_cfg.latent_wid + cfg.model_path = osp.join(model_dir, + self.config.model.model_args.ckpt_unet) + + self.device = torch.device( + 'cuda') if torch.cuda.is_available() else torch.device('cpu') + + if 'seed' in self.config.model.model_args.keys(): + cfg.seed = self.config.model.model_args.seed + else: + cfg.seed = random.randint(0, 99999) + setup_seed(cfg.seed) + + # transform + vid_trans = data.Compose([ + data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])), + data.Resize(cfg.vit_resolution), + data.ToTensor(), + data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std) + ]) + self.vid_trans = vid_trans + + cfg.embedder.pretrained = osp.join( + model_dir, self.config.model.model_args.ckpt_clip) + clip_encoder = FrozenOpenCLIPVisualEmbedder(**cfg.embedder) + clip_encoder.model.to(self.device) + self.clip_encoder = clip_encoder + logger.info(f'Build encoder with {cfg.embedder.type}') + + # [unet] + generator = Img2VidSDUNet(**cfg.UNet) + generator = generator.to(self.device) + generator.eval() + load_dict = torch.load(cfg.model_path, map_location='cpu') + ret = generator.load_state_dict(load_dict['state_dict'], strict=True) + self.generator = generator + logger.info('Load model {} path {}, with local status {}'.format( + cfg.UNet.type, cfg.model_path, ret)) + + # [diffusion] + betas = beta_schedule( + 'linear_sd', + cfg.num_timesteps, + init_beta=0.00085, + last_beta=0.0120) + diffusion = GaussianDiffusion( + betas=betas, + mean_type=cfg.mean_type, + var_type=cfg.var_type, + loss_type=cfg.loss_type, + 
rescale_timesteps=False, + noise_strength=getattr(cfg, 'noise_strength', 0)) + self.diffusion = diffusion + logger.info('Build diffusion with type of GaussianDiffusion') + + # [auotoencoder] + cfg.auto_encoder.pretrained = osp.join( + model_dir, self.config.model.model_args.ckpt_autoencoder) + autoencoder = AutoencoderKL(**cfg.auto_encoder) + autoencoder.eval() + for param in autoencoder.parameters(): + param.requires_grad = False + autoencoder.to(self.device) + self.autoencoder = autoencoder + torch.cuda.empty_cache() + + zero_feature = torch.zeros(1, 1, cfg.UNet.input_dim).to(self.device) + self.zero_feature = zero_feature + self.fps_tensor = torch.tensor([cfg.target_fps], + dtype=torch.long, + device=self.device) + self.cfg = cfg + + def forward(self, input: Dict[str, Any]): + r""" + The entry function of image to video task. + 1. Using diffusion model to generate the video's latent representation. + 2. Using autoencoder to decode the video's latent representation to visual space. + + Args: + input (`Dict[Str, Any]`): + The input of the task + Returns: + A generated video (as pytorch tensor). + """ + + vit_frame = input['vit_frame'] + cfg = self.cfg + + img_embedding = self.clip_encoder(vit_frame).unsqueeze(1) + + noise = self.build_noise() + zero_feature = copy(self.zero_feature) + with torch.no_grad(): + with amp.autocast(enabled=cfg.use_fp16): + model_kwargs = [{ + 'y': img_embedding, + 'fps': self.fps_tensor + }, { + 'y': zero_feature.repeat(cfg.batch_size, 1, 1), + 'fps': self.fps_tensor + }] + gen_video = self.diffusion.ddim_sample_loop( + noise=noise, + model=self.generator, + model_kwargs=model_kwargs, + guide_scale=cfg.guide_scale, + ddim_timesteps=cfg.ddim_timesteps, + eta=0.0) + + gen_video = 1. 
/ cfg.scale_factor * gen_video + gen_video = rearrange(gen_video, 'b c f h w -> (b f) c h w') + chunk_size = min(cfg.decoder_bs, gen_video.shape[0]) + gen_video_list = torch.chunk( + gen_video, gen_video.shape[0] // chunk_size, dim=0) + decode_generator = [] + for vd_data in gen_video_list: + gen_frames = self.autoencoder.decode(vd_data) + decode_generator.append(gen_frames) + + gen_video = torch.cat(decode_generator, dim=0) + gen_video = rearrange( + gen_video, '(b f) c h w -> b c f h w', b=cfg.batch_size) + + return gen_video.type(torch.float32).cpu() + + def build_noise(self): + cfg = self.cfg + noise = torch.randn( + [1, 4, cfg.max_frames, cfg.latent_hei, + cfg.latent_wid]).to(self.device) + if cfg.noise_strength > 0: + b, c, f, *_ = noise.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device) + noise = noise + cfg.noise_strength * offset_noise + return noise.contiguous() diff --git a/modelscope/models/multi_modal/image_to_video/modules/__init__.py b/modelscope/models/multi_modal/image_to_video/modules/__init__.py new file mode 100755 index 00000000..bc69102b --- /dev/null +++ b/modelscope/models/multi_modal/image_to_video/modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .autoencoder import * +from .embedder import * +from .unet_i2v import * diff --git a/modelscope/models/multi_modal/image_to_video/modules/autoencoder.py b/modelscope/models/multi_modal/image_to_video/modules/autoencoder.py new file mode 100755 index 00000000..935134bc --- /dev/null +++ b/modelscope/models/multi_modal/image_to_video/modules/autoencoder.py @@ -0,0 +1,573 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import collections +import logging + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class DiagonalGaussianDistribution(object): + + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn( + self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +class ResnetBlock(nn.Module): + + def __init__(self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = 
Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + 
+ # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Upsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate( + x, scale_factor=2.0, mode='nearest') + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode='constant', value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class Encoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type='vanilla', + **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1, ) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + 
block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type='vanilla', + 
**ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logging.info('Working with z of shape {} = {} dimensions.'.format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = 
self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class AutoencoderKL(nn.Module): + + def __init__(self, + ddconfig, + embed_dim, + pretrained=None, + ignore_keys=[], + image_key='image', + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False, + **kwargs): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + assert ddconfig['double_z'] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'], + 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, + ddconfig['z_channels'], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer('colorize', + torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + + if pretrained is not None: + self.init_from_ckpt(pretrained, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location='cpu')['state_dict'] + keys = list(sd.keys()) + sd_new = collections.OrderedDict() + for k in keys: + if k.find('first_stage_model') >= 0: + k_new = k.split('first_stage_model.')[-1] + sd_new[k_new] = sd[k] + self.load_state_dict(sd_new, strict=True) + logging.info(f'Restored from {path}') + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, 
x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, + 2).to(memory_format=torch.contiguous_format).float() + return x + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log['samples'] = self.decode(torch.randn_like(posterior.sample())) + log['reconstructions'] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log['samples_ema'] = self.decode( + torch.randn_like(posterior_ema.sample())) + log['reconstructions_ema'] = xrec_ema + log['inputs'] = x + return log + + def to_rgb(self, x): + assert self.image_key == 'segmentation' + if not hasattr(self, 'colorize'): + self.register_buffer('colorize', + torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
+ return x + + +class IdentityFirstStage(torch.nn.Module): + + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/modelscope/models/multi_modal/image_to_video/modules/embedder.py b/modelscope/models/multi_modal/image_to_video/modules/embedder.py new file mode 100755 index 00000000..39063a57 --- /dev/null +++ b/modelscope/models/multi_modal/image_to_video/modules/embedder.py @@ -0,0 +1,82 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import logging +import os + +import numpy as np +import open_clip +import torch +import torch.nn as nn +import torchvision.transforms as T + + +class FrozenOpenCLIPVisualEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = ['last', 'penultimate'] + + def __init__(self, + pretrained, + vit_resolution=(224, 224), + arch='ViT-H-14', + device='cuda', + max_length=77, + freeze=True, + layer='last', + **kwargs): + super().__init__() + assert layer in self.LAYERS + model, _, preprocess = open_clip.create_model_and_transforms( + arch, device=torch.device('cpu'), pretrained=pretrained) + + del model.transformer + self.model = model + data_white = np.ones( + (vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8) * 255 + self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0) + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == 'last': + self.layer_idx = 0 + elif self.layer == 'penultimate': + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = 
False + + def forward(self, image): + z = self.model.encode_image(image.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) + x = self.model.ln_final(x) + + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting( + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) diff --git a/modelscope/models/multi_modal/image_to_video/modules/unet_i2v.py b/modelscope/models/multi_modal/image_to_video/modules/unet_i2v.py new file mode 100644 index 00000000..dae226a5 --- /dev/null +++ b/modelscope/models/multi_modal/image_to_video/modules/unet_i2v.py @@ -0,0 +1,1504 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
# Building blocks (attention, transformer, resblocks) for a video/temporal
# diffusion UNet. Recovered from a collapsed patch; formatting restored and
# undefined-name bugs fixed (typing imports, einops.repeat, th -> torch,
# conv_nd / avg_pool_nd replacements).

import math
import os
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers
import xformers.ops
from einops import rearrange, repeat
from fairscale.nn.checkpoint import checkpoint_wrapper
from rotary_embedding_torch import RotaryEmbedding

USE_TEMPORAL_TRANSFORMER = True


def sinusoidal_embedding(timesteps, dim):
    """Standard transformer sinusoidal timestep embedding of width `dim`."""
    # check input
    half = dim // 2
    timesteps = timesteps.float()

    # compute sinusoidal embedding
    sinusoid = torch.outer(
        timesteps, torch.pow(10000,
                             -torch.arange(half).to(timesteps).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    if dim % 2 != 0:
        x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
    return x


def exists(x):
    """True if `x` is not None."""
    return x is not None


def default(val, d):
    """Return `val` if set, else `d` (calling it when it is a factory)."""
    if exists(val):
        return val
    return d() if callable(d) else d


def prob_mask_like(shape, prob, device):
    """Boolean mask of `shape` where each entry is True with probability `prob`."""
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
        # avoid mask all, which will cause find_unused_parameters error
        if mask.all():
            mask[0] = False
        return mask


class MemoryEfficientCrossAttention(nn.Module):
    """Cross-attention backed by xformers memory_efficient_attention.

    Falls back to self-attention when no `context` is given. `mask` is not
    supported by this implementation.
    """

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # Optional xformers attention operator override (None = auto-select).
        # Fixed: Optional/Any were previously unimported -> NameError here.
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        # fold heads into the batch dimension: [b, n, h*d] -> [b*h, n, d]
        q, k, v = map(
            lambda t: t.unsqueeze(3).reshape(b, t.shape[
                1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                    b * self.heads, t.shape[1], self.dim_head).contiguous(),
            (q, k, v),
        )

        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op)

        if exists(mask):
            raise NotImplementedError
        # unfold heads: [b*h, n, d] -> [b, n, h*d]
        out = (
            out.unsqueeze(0).reshape(
                b, self.heads, out.shape[1],
                self.dim_head).permute(0, 2, 1,
                                       3).reshape(b, out.shape[1],
                                                  self.heads * self.dim_head))
        return self.to_out(out)


class RelativePositionBias(nn.Module):
    """T5-style bucketed relative position bias for temporal attention."""

    def __init__(self, heads=8, num_buckets=32, max_distance=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position,
                                  num_buckets=32,
                                  max_distance=128):
        # half the buckets encode the sign, the rest encode log-spaced distance
        ret = 0
        n = -relative_position

        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact)
            / math.log(max_distance / max_exact) *  # noqa
            (num_buckets - max_exact)).long()
        val_if_large = torch.min(
            val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, n, device):
        q_pos = torch.arange(n, dtype=torch.long, device=device)
        k_pos = torch.arange(n, dtype=torch.long, device=device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(
            rel_pos,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> h i j')


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                disable_self_attn=disable_self_attn,
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            # Fixed: the linear projection must map inner_dim back to
            # in_channels (the original had the arguments transposed; it only
            # worked because callers always use inner_dim == in_channels).
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in


_ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')


class CrossAttention(nn.Module):
    """Vanilla (einsum-based) cross-attention; self-attention when no context."""

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                      (q, k, v))

        # force cast to fp32 to avoid overflowing
        if _ATTN_PRECISION == 'fp32':
            with torch.autocast(enabled=False, device_type='cuda'):
                q, k = q.float(), k.float()
                sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        else:
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        del q, k

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            # `repeat` was previously unimported (NameError on this path)
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = torch.einsum('b i j, b j d -> b i d', sim, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    """Self-attention + cross-attention + feed-forward, each with pre-norm."""

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False):
        super().__init__()

        attn_cls = MemoryEfficientCrossAttention
        self.disable_self_attn = disable_self_attn
        # attn1 is self-attention unless disable_self_attn turns it into
        # cross-attention against `context`
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward_(self, x, context=None):
        # NOTE(review): the original body called an undefined `checkpoint`
        # function and a nonexistent `self._forward`, crashing if ever used.
        # Delegate to the plain forward pass instead (same result, no
        # activation checkpointing).
        return self.forward(x, context)

    def forward(self, x, context=None):
        x = self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


# feedforward
class GEGLU(nn.Module):
    """Gated-GELU projection: Linear to 2*dim_out, gate half with GELU."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class FeedForward(nn.Module):
    """MLP with optional GEGLU gate; hidden width = dim * mult."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(nn.Linear(
            dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(project_in, nn.Dropout(dropout),
                                 nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.net(x)


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = nn.Conv2d(
                self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # 3D input: keep the temporal axis, double the spatial axes
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
                mode='nearest')
        else:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.use_conv:
            x = self.conv(x)
        return x


class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        use_temporal_conv=True,
        use_image_dataset=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_temporal_conv = use_temporal_conv

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels
                if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            # Fixed: original referenced undefined `conv_nd`; this module is
            # 2D-only (all other convs here are Conv2d).
            self.skip_connection = nn.Conv2d(
                channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)

        if self.use_temporal_conv:
            # NOTE: attribute name `temopral_conv` is a typo but is kept
            # deliberately — renaming it would break state_dict compatibility
            # with existing checkpoints.
            self.temopral_conv = TemporalConvBlock_v2(
                self.out_channels,
                self.out_channels,
                dropout=0.1,
                use_image_dataset=use_image_dataset)

    def forward(self, x, emb, batch_size):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return self._forward(x, emb, batch_size)

    def _forward(self, x, emb, batch_size):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            # Fixed: original called `th.chunk` with `th` undefined.
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        h = self.skip_connection(x) + h

        if self.use_temporal_conv:
            h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
            h = self.temopral_conv(h)
            h = rearrange(h, 'b c f h w -> (b f) c h w')
        return h


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = nn.Conv2d(
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding)
        else:
            assert self.channels == self.out_channels
            # Fixed: original referenced undefined `avg_pool_nd`; pick the
            # torch pooling layer matching the dimensionality.
            pool_cls = nn.AvgPool3d if dims == 3 else nn.AvgPool2d
            self.op = pool_cls(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class Resample(nn.Module):
    """Nearest upsample to a reference's spatial size, or 2x average-pool down."""

    def __init__(self, in_dim, out_dim, mode):
        assert mode in ['none', 'upsample', 'downsample']
        super(Resample, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.mode = mode

    def forward(self, x, reference=None):
        if self.mode == 'upsample':
            assert reference is not None
            x = F.interpolate(x, size=reference.shape[-2:], mode='nearest')
        elif self.mode == 'downsample':
            x = F.adaptive_avg_pool2d(
                x, output_size=tuple(u // 2 for u in x.shape[-2:]))
        return x


class ResidualBlock(nn.Module):
    """Residual block with embedding conditioning and optional resampling."""

    def __init__(self,
                 in_dim,
                 embed_dim,
                 out_dim,
                 use_scale_shift_norm=True,
                 mode='none',
                 dropout=0.0):
        super(ResidualBlock, self).__init__()
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.use_scale_shift_norm = use_scale_shift_norm
        self.mode = mode

        # layers
        self.layer1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv2d(in_dim, out_dim, 3, padding=1))
        self.resample = Resample(in_dim, in_dim, mode)
        self.embedding = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim,
                      out_dim * 2 if use_scale_shift_norm else out_dim))
        self.layer2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv2d(out_dim, out_dim, 3, padding=1))
        self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
            in_dim, out_dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.layer2[-1].weight)

    def forward(self, x, e, reference=None):
        identity = self.resample(x, reference)
        # resample between norm/act and the conv of layer1
        x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
        e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
        if self.use_scale_shift_norm:
            scale, shift = e.chunk(2, dim=1)
            x = self.layer2[0](x) * (1 + scale) + shift
            x = self.layer2[1:](x)
        else:
            x = x + e
            x = self.layer2(x)
        x = x + self.shortcut(identity)
        return x


class AttentionBlock(nn.Module):
    """Spatial self-attention over [B, C, H, W], optional extra context tokens."""

    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
        # consider head_dim first, then num_heads
        num_heads = dim // head_dim if head_dim else num_heads
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(AttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # scale split across q and k (q*scale) @ (k*scale) == qk * scale^2
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x, context=None):
        r"""x: [B, C, H, W].
        context: [B, L, C] or None.
        """
        identity = x
        b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        x = self.norm(x)
        q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
        if context is not None:
            ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
                                                      d).permute(0, 2, 3,
                                                                 1).chunk(
                                                                     2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)

        # compute attention
        attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
        attn = F.softmax(attn, dim=-1)

        # gather context
        x = torch.matmul(v, attn.transpose(-1, -2))
        x = x.reshape(b, c, h, w)

        # output
        x = self.proj(x)
        return x + identity


class TemporalAttentionBlock(nn.Module):
    """Attention across the frame axis of [B, C, F, H, W] video features."""

    def __init__(self,
                 dim,
                 heads=4,
                 dim_head=32,
                 rotary_emb=None,
                 use_image_dataset=False,
                 use_sim_mask=False):
        super().__init__()
        # consider num_heads first, as pos_bias needs fixed num_heads
        dim_head = dim // heads
        assert heads * dim_head == dim
        self.use_image_dataset = use_image_dataset
        self.use_sim_mask = use_sim_mask

        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = nn.GroupNorm(32, dim)
        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, hidden_dim * 3)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self,
                x,
                pos_bias=None,
                focus_present_mask=None,
                video_mask=None):

        identity = x
        n, height, device = x.shape[2], x.shape[-2], x.device

        x = self.norm(x)
        x = rearrange(x, 'b c f h w -> b (h w) f c')

        qkv = self.to_qkv(x).chunk(3, dim=-1)

        if exists(focus_present_mask) and focus_present_mask.all():
            # if all batch samples are focusing on present
            # it would be equivalent to passing that token's values (v=qkv[-1]) through to the output
            values = qkv[-1]
            out = self.to_out(values)
            out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)

            return out + identity

        # split out heads
        # shape [b (hw) h n c/h], n=f
        q = rearrange(qkv[0], '... n (h d) -> ... h n d', h=self.heads)
        k = rearrange(qkv[1], '... n (h d) -> ... h n d', h=self.heads)
        v = rearrange(qkv[2], '... n (h d) -> ... h n d', h=self.heads)

        # scale

        q = q * self.scale

        # rotate positions into queries and keys for time attention
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        # similarity
        # shape [b (hw) h n n], n=f
        sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)

        # relative positional bias

        if exists(pos_bias):
            sim = sim + pos_bias

        if (focus_present_mask is None and video_mask is not None):
            # video_mask: [B, n]
            mask = video_mask[:, None, :] * video_mask[:, :, None]
            mask = mask.unsqueeze(1).unsqueeze(1)
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
        elif exists(focus_present_mask) and not (~focus_present_mask).all():
            attend_all_mask = torch.ones((n, n),
                                         device=device,
                                         dtype=torch.bool)
            attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)

            mask = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
                rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
            )

            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        if self.use_sim_mask:
            # causal mask over frames
            sim_mask = torch.tril(
                torch.ones((n, n), device=device, dtype=torch.bool),
                diagonal=0)
            sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max)

        # numerical stability
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values

        out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
        out = rearrange(out, '... h n d -> ... n (h d)')
        out = self.to_out(out)

        out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)

        if self.use_image_dataset:
            # keep the layer in the graph but contribute nothing
            out = identity + 0 * out
        else:
            out = identity + out
        return out


class TemporalTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True,
                 only_self_att=True,
                 multiply_zero=False):
        super().__init__()
        self.multiply_zero = multiply_zero
        self.only_self_att = only_self_att
        self.use_adaptor = False
        if self.only_self_att:
            context_dim = None
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv1d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        if self.use_adaptor:
            # NOTE(review): `frames` is undefined in this scope; use_adaptor is
            # hard-coded False so this branch is dead — confirm before enabling.
            self.adaptor_in = nn.Linear(frames, frames)  # noqa: F821

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv1d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            # Fixed: map inner_dim back to in_channels (arguments were
            # transposed in the original; shapes coincided only because
            # callers always use inner_dim == in_channels).
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        if self.use_adaptor:
            self.adaptor_out = nn.Linear(frames, frames)  # noqa: F821
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if self.only_self_att:
            context = None
        if not isinstance(context, list):
            context = [context]
        b, c, f, h, w = x.shape
        x_in = x
        x = self.norm(x)

        if not self.use_linear:
            x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
            x = self.proj_in(x)
        # [16384, 16, 320]
        if self.use_linear:
            # NOTE(review): `self.frames` is never assigned anywhere in this
            # class; the use_linear / cross-attention paths would raise
            # AttributeError — confirm intended frame count upstream.
            x = rearrange(
                x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
            x = self.proj_in(x)

        if self.only_self_att:
            x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
            for i, block in enumerate(self.transformer_blocks):
                x = block(x)
            x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
        else:
            x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
            for i, block in enumerate(self.transformer_blocks):
                context[i] = rearrange(
                    context[i], '(b f) l con -> b f l con',
                    f=self.frames).contiguous()
                # calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
                for j in range(b):
                    context_i_j = repeat(
                        context[i][j],
                        'f l con -> (f r) l con',
                        r=(h * w) // self.frames,
                        f=self.frames).contiguous()
                    x[j] = block(x[j], context=context_i_j)

        if self.use_linear:
            x = self.proj_out(x)
            x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
            x = self.proj_out(x)
            x = rearrange(
                x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()

        if self.multiply_zero:
            # keep layer in graph but contribute nothing (image-only batches)
            x = 0.0 * x + x_in
        else:
            x = x + x_in
        return x


class TemporalAttentionMultiBlock(nn.Module):
    """Stack of `temporal_attn_times` TemporalAttentionBlocks applied in order."""

    def __init__(
        self,
        dim,
        heads=4,
        dim_head=32,
        rotary_emb=None,
        use_image_dataset=False,
        use_sim_mask=False,
        temporal_attn_times=1,
    ):
        super().__init__()
        self.att_layers = nn.ModuleList([
            TemporalAttentionBlock(dim, heads, dim_head, rotary_emb,
                                   use_image_dataset, use_sim_mask)
            for _ in range(temporal_attn_times)
        ])

    def forward(self,
                x,
                pos_bias=None,
                focus_present_mask=None,
                video_mask=None):
        for layer in self.att_layers:
            x = layer(x, pos_bias, focus_present_mask, video_mask)
        return x


class InitTemporalConvBlock(nn.Module):
    """Single zero-initialized temporal conv with residual connection."""

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(InitTemporalConvBlock, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # conv layers
        # NOTE(review): conv maps out_dim -> in_dim; the residual add below
        # assumes in_dim == out_dim — confirm callers never differ.
        self.conv = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv[-1].weight)
        nn.init.zeros_(self.conv[-1].bias)

    def forward(self, x):
        identity = x
        x = self.conv(x)
        if self.use_image_dataset:
            x = identity + 0 * x
        else:
            x = identity + x
        return x


class TemporalConvBlock(nn.Module):
    """Two temporal convs with residual connection; last conv zero-initialized."""

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(TemporalConvBlock, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv2[-1].weight)
        nn.init.zeros_(self.conv2[-1].bias)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        if self.use_image_dataset:
            x = identity + 0 * x
        else:
            x = identity + x
        return x


class TemporalConvBlock_v2(nn.Module):
    """Four temporal convs with residual connection; last conv zero-initialized."""

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(TemporalConvBlock_v2, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        if self.use_image_dataset:
            x = identity + 0.0 * x
        else:
            x = identity + x
        return x


class Img2VidSDUNet(nn.Module):
    # NOTE(review): this class definition continues beyond the visible chunk;
    # the head below is reproduced verbatim up to the cut point.

    def __init__(self,
                 in_dim=7,
                 dim=512,
                 y_dim=512,
                 num_tokens=4,
                 context_dim=512,
                 out_dim=6,
                 dim_mult=[1, 2, 3, 4],
                 num_heads=None,
                 head_dim=64,
                 num_res_blocks=3,
                 attn_scales=[1 / 2, 1 / 4, 1 / 8],
                 use_scale_shift_norm=True,
                 dropout=0.1,
                 default_fps=8,
                 temporal_attn_times=1,
                 temporal_attention=True,
                 use_checkpoint=False,
                 use_image_dataset=False,
                 use_sim_mask=False,
                 training=True,
                 inpainting=True,
                 **kwargs):
        embed_dim = dim * 4
        num_heads = num_heads if num_heads else dim // 32
        super(Img2VidSDUNet, self).__init__()
        self.in_dim = in_dim
        self.num_tokens = num_tokens
        self.dim = dim
        self.y_dim = y_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.dim_mult = dim_mult
        # for temporal attention
        self.num_heads = num_heads
        # for spatial attention
        self.default_fps = default_fps
        self.head_dim = head_dim
self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.use_scale_shift_norm = use_scale_shift_norm + self.temporal_attn_times = temporal_attn_times + self.temporal_attention = temporal_attention + self.use_checkpoint = use_checkpoint + self.use_image_dataset = use_image_dataset + self.use_sim_mask = use_sim_mask + self.training = training + self.inpainting = inpainting + + use_linear_in_temporal = False + transformer_depth = 1 + disabled_sa = False + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embed = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + + self.context_embedding = nn.Sequential( + nn.Linear(y_dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, context_dim * self.num_tokens)) + + self.fps_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + nn.init.zeros_(self.fps_embedding[-1].weight) + nn.init.zeros_(self.fps_embedding[-1].bias) + + if temporal_attention and not USE_TEMPORAL_TRANSFORMER: + self.rotary_emb = RotaryEmbedding(min(32, head_dim)) + self.time_rel_pos_bias = RelativePositionBias( + heads=num_heads, max_distance=32) + + # encoder + self.input_blocks = nn.ModuleList() + init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + # need an initial temporal attention? 
+ if temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + init_block.append( + TemporalTransformer( + dim, + num_heads, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + else: + init_block.append( + TemporalAttentionMultiBlock( + dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + temporal_attn_times=temporal_attn_times, + use_image_dataset=use_image_dataset)) + + self.input_blocks.append(init_block) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + block = nn.ModuleList([ + ResBlock( + in_dim, + embed_dim, + dropout, + out_channels=out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True)) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + else: + block.append( + TemporalAttentionMultiBlock( + out_dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + use_image_dataset=use_image_dataset, + use_sim_mask=use_sim_mask, + temporal_attn_times=temporal_attn_times)) + in_dim = out_dim + self.input_blocks.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + downsample = Downsample( + out_dim, True, dims=2, out_channels=out_dim) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.input_blocks.append(downsample) + + self.middle_block = nn.ModuleList([ + ResBlock( + out_dim, + embed_dim, + dropout, + 
use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ), + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True) + ]) + + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + self.middle_block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset, + )) + else: + self.middle_block.append( + TemporalAttentionMultiBlock( + out_dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + use_image_dataset=use_image_dataset, + use_sim_mask=use_sim_mask, + temporal_attn_times=temporal_attn_times)) + + self.middle_block.append( + ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False)) + + # decoder + self.output_blocks = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + block = nn.ModuleList([ + ResBlock( + in_dim + shortcut_dims.pop(), + embed_dim, + dropout, + out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=1024, + disable_self_attn=False, + use_linear=True)) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + else: + block.append( + TemporalAttentionMultiBlock( + out_dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + use_image_dataset=use_image_dataset, + use_sim_mask=use_sim_mask, + temporal_attn_times=temporal_attn_times)) + in_dim = out_dim 
+ + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + upsample = Upsample( + out_dim, True, dims=2.0, out_channels=out_dim) + scale *= 2.0 + block.append(upsample) + self.output_blocks.append(block) + + # head + self.out = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.out[-1].weight) + + def forward(self, + x, + t, + y, + fps=None, + video_mask=None, + focus_present_mask=None, + prob_focus_present=0., + mask_last_frame_num=0, + **kwargs): + + batch, c, f, h, w = x.shape + device = x.device + self.batch = batch + if fps is None: + fps = torch.tensor( + [cfg.default_fps] * batch, dtype=torch.long, device=device) + + # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored + if mask_last_frame_num > 0: + focus_present_mask = None + video_mask[-mask_last_frame_num:] = False + else: + focus_present_mask = default( + focus_present_mask, lambda: prob_mask_like( + (batch, ), prob_focus_present, device=device)) + + if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER: + time_rel_pos_bias = self.time_rel_pos_bias( + x.shape[2], device=x.device) + else: + time_rel_pos_bias = None + + # embeddings + embeddings = self.time_embed(sinusoidal_embedding( + t, self.dim)) + self.fps_embedding( + sinusoidal_embedding(fps, self.dim)) + + context = self.context_embedding(y) + context = context.view(-1, self.num_tokens, self.context_dim) + + # repeat f times for spatial e and context + embeddings = embeddings.repeat_interleave(repeats=f, dim=0) + context = context.repeat_interleave(repeats=f, dim=0) + + # always in shape (b f) c h w, except for temporal layer + x = rearrange(x, 'b c f h w -> (b f) c h w') + # encoder + xs = [] + for block in self.input_blocks: + x = self._forward_single(block, x, embeddings, context, + time_rel_pos_bias, focus_present_mask, + video_mask) + xs.append(x) + + # middle + for 
block in self.middle_block: + x = self._forward_single(block, x, embeddings, context, + time_rel_pos_bias, focus_present_mask, + video_mask) + + # decoder + for block in self.output_blocks: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single( + block, + x, + embeddings, + context, + time_rel_pos_bias, + focus_present_mask, + video_mask, + reference=xs[-1] if len(xs) > 0 else None) + + # head + x = self.out(x) + + # reshape back to (b c f h w) + x = rearrange(x, '(b f) c h w -> b c f h w', b=batch) + return x + + def _forward_single(self, + module, + x, + e, + context, + time_rel_pos_bias, + focus_present_mask, + video_mask, + reference=None): + if isinstance(module, ResidualBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = x.contiguous() + x = module(x, e, reference) + elif isinstance(module, ResBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = x.contiguous() + x = module(x, e, self.batch) + elif isinstance(module, SpatialTransformer): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, TemporalTransformer): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x, context) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, CrossAttention): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, MemoryEfficientCrossAttention): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, BasicTransformerBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, FeedForward): + x = module(x, context) + elif isinstance(module, Upsample): + x = module(x) + elif isinstance(module, 
Downsample): + x = module(x) + elif isinstance(module, Resample): + x = module(x, reference) + elif isinstance(module, TemporalAttentionBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x, time_rel_pos_bias, focus_present_mask, video_mask) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, TemporalAttentionMultiBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x, time_rel_pos_bias, focus_present_mask, video_mask) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, InitTemporalConvBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, TemporalConvBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context, + time_rel_pos_bias, focus_present_mask, + video_mask, reference) + else: + x = module(x) + return x diff --git a/modelscope/models/multi_modal/image_to_video/utils/__init__.py b/modelscope/models/multi_modal/image_to_video/utils/__init__.py new file mode 100755 index 00000000..92654e51 --- /dev/null +++ b/modelscope/models/multi_modal/image_to_video/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os diff --git a/modelscope/models/multi_modal/image_to_video/utils/config.py b/modelscope/models/multi_modal/image_to_video/utils/config.py new file mode 100755 index 00000000..f5e71cde --- /dev/null +++ b/modelscope/models/multi_modal/image_to_video/utils/config.py @@ -0,0 +1,161 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import logging +import os +import os.path as osp +from datetime import datetime + +import torch +from easydict import EasyDict + +cfg = EasyDict(__name__='Config: VideoLDM Decoder') + +# ---------------------------work dir-------------------------- +cfg.work_dir = 'workspace/' + +# ---------------------------Global Variable----------------------------------- +cfg.resolution = [448, 256] +# ----------------------------------------------------------------------------- + +# ---------------------------Dataset Parameter--------------------------------- +cfg.mean = [0.5, 0.5, 0.5] +cfg.std = [0.5, 0.5, 0.5] +cfg.max_words = 1000 + +# PlaceHolder +cfg.vit_out_dim = 1024 +cfg.vit_resolution = [224, 224] +cfg.depth_clamp = 10.0 +cfg.misc_size = 384 +cfg.depth_std = 20.0 + +cfg.frame_lens = 32 +cfg.sample_fps = 8 + +cfg.batch_sizes = 1 +# ----------------------------------------------------------------------------- + +# ---------------------------Mode Parameters----------------------------------- +# Diffusion +cfg.schedule = 'cosine' +cfg.num_timesteps = 1000 +cfg.mean_type = 'v' +cfg.var_type = 'fixed_small' +cfg.loss_type = 'mse' +cfg.ddim_timesteps = 50 +cfg.ddim_eta = 0.0 +cfg.clamp = 1.0 +cfg.share_noise = False +cfg.use_div_loss = False +cfg.noise_strength = 0.1 + +# classifier-free guidance +cfg.p_zero = 0.1 +cfg.guide_scale = 3.0 + +# clip vision encoder +cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073] +cfg.vit_std = [0.26862954, 0.26130258, 0.27577711] + +# Model +cfg.scale_factor = 0.18215 +cfg.use_fp16 = True +cfg.temporal_attention = True +cfg.decoder_bs = 8 + +cfg.UNet = { + 'type': 'Img2VidSDUNet', + 'in_dim': 4, + 
'dim': 320, + 'y_dim': cfg.vit_out_dim, + 'context_dim': 1024, + 'out_dim': 8 if cfg.var_type.startswith('learned') else 4, + 'dim_mult': [1, 2, 4, 4], + 'num_heads': 8, + 'head_dim': 64, + 'num_res_blocks': 2, + 'attn_scales': [1 / 1, 1 / 2, 1 / 4], + 'dropout': 0.1, + 'temporal_attention': cfg.temporal_attention, + 'temporal_attn_times': 1, + 'use_checkpoint': False, + 'use_fps_condition': False, + 'use_sim_mask': False, + 'num_tokens': 4, + 'default_fps': 8, + 'input_dim': 1024 +} + +cfg.guidances = [] + +# auotoencoder from stabel diffusion +cfg.auto_encoder = { + 'type': 'AutoencoderKL', + 'ddconfig': { + 'double_z': True, + 'z_channels': 4, + 'resolution': 256, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0 + }, + 'embed_dim': 4, + 'pretrained': 'v2-1_512-ema-pruned.ckpt' +} +# clip embedder +cfg.embedder = { + 'type': 'FrozenOpenCLIPVisualEmbedder', + 'layer': 'penultimate', + 'vit_resolution': [224, 224], + 'pretrained': 'open_clip_pytorch_model.bin' +} +# ----------------------------------------------------------------------------- + +# ---------------------------Training Settings--------------------------------- +# training and optimizer +cfg.ema_decay = 0.9999 +cfg.num_steps = 600000 +cfg.lr = 5e-5 +cfg.weight_decay = 0.0 +cfg.betas = (0.9, 0.999) +cfg.eps = 1.0e-8 +cfg.chunk_size = 16 +cfg.alpha = 0.7 +cfg.save_ckp_interval = 1000 +# ----------------------------------------------------------------------------- + +# ----------------------------Pretrain Settings--------------------------------- +# Default: load 2d pretrain +cfg.fix_weight = False +cfg.load_match = False +cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt' +cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json' +cfg.resume_checkpoint = 'img2video_ldm_0779000.pth' +# ----------------------------------------------------------------------------- + +# 
-----------------------------Visual------------------------------------------- +# Visual videos +cfg.viz_interval = 1000 +cfg.visual_train = { + 'type': 'VisualVideoTextDuringTrain', +} +cfg.visual_inference = { + 'type': 'VisualGeneratedVideos', +} +cfg.inference_list_path = '' + +# logging +cfg.log_interval = 100 + +# Default log_dir +cfg.log_dir = 'workspace/output_data' +# ----------------------------------------------------------------------------- + +# ---------------------------Others-------------------------------------------- +# seed +cfg.seed = 8888 +# ----------------------------------------------------------------------------- diff --git a/modelscope/models/multi_modal/image_to_video/utils/diffusion.py b/modelscope/models/multi_modal/image_to_video/utils/diffusion.py new file mode 100755 index 00000000..fc4b96cf --- /dev/null +++ b/modelscope/models/multi_modal/image_to_video/utils/diffusion.py @@ -0,0 +1,511 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math + +import torch + +__all__ = ['GaussianDiffusion', 'beta_schedule'] + + +def _i(tensor, t, x): + r"""Index tensor using t and format the output according to x. 
+ """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + if tensor.device != x.device: + tensor = tensor.to(x.device) + return tensor[t].view(shape).to(x) + + +def fn(u): + return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 + + +def beta_schedule(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None): + if schedule == 'linear': + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + last_beta = last_beta or scale * 0.02 + return torch.linspace( + init_beta, last_beta, num_timesteps, dtype=torch.float64) + elif schedule == 'quadratic': + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'cosine': + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) + return torch.tensor(betas, dtype=torch.float64) + else: + raise ValueError(f'Unsupported schedule: {schedule}') + + +class GaussianDiffusion(object): + + def __init__(self, + betas, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + epsilon=1e-12, + rescale_timesteps=False, + noise_strength=0.0): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1', + 'charbonnier' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type + self.var_type = var_type + self.loss_type = loss_type + self.epsilon = epsilon + self.rescale_timesteps = rescale_timesteps + self.noise_strength = noise_strength + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) 
+ self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( + 1.0 - self.alphas_cumprod) + + def sample_loss(self, x0, noise=None): + if noise is None: + noise = torch.randn_like(x0) + if self.noise_strength > 0: + b, c, f, _, _ = x0.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device) + noise = noise + self.noise_strength * offset_noise + return noise + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). + """ + # noise = torch.randn_like(x0) if noise is None else noise + noise = self.sample_loss(x0, noise) + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + ( + _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise) + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). 
+ """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, + guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). 
+ """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + dim = y_out.size(1) if self.var_type.startswith( + 'fixed') else y_out.size(1) // 2 + out = torch.cat( + [ + u_out[:, :dim] + guide_scale * # noqa + (y_out[:, :dim] - u_out[:, :dim]), + y_out[:, dim:] + ], + dim=1) + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i( + torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, + xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - ( + _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + xt) * xt) + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - ( + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out) + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'v': + x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - ( + 
_i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out) + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / ( + _i(self.sqrt_recipm1_alphas_cumprod, t, xt)) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - ( + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / ( + _i(self.sqrt_recipm1_alphas_cumprod, t, xt)) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * # noqa + (1 - alphas / alphas_prev)) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + 
mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / ( + _i(self.sqrt_recipm1_alphas_cumprod, t, xt)) + alphas_next = _i( + torch.cat( + [self.alphas_cumprod, + self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, 
device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, + percentile, guide_scale, + ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / ( + _i(self.sqrt_recipm1_alphas_cumprod, t, xt)) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - ( + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps) + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / ( + _i(self.sqrt_recipm1_alphas_cumprod, t, xt)) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - ( + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps) + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 
1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] + - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // plms_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, + guide_scale, plms_timesteps, + eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t diff --git a/modelscope/models/multi_modal/image_to_video/utils/seed.py b/modelscope/models/multi_modal/image_to_video/utils/seed.py new file mode 100755 index 00000000..df3c9c50 --- /dev/null +++ b/modelscope/models/multi_modal/image_to_video/utils/seed.py @@ -0,0 +1,14 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import random + +import numpy as np +import torch + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True diff --git a/modelscope/models/multi_modal/image_to_video/utils/shedule.py b/modelscope/models/multi_modal/image_to_video/utils/shedule.py new file mode 100644 index 00000000..a255be15 --- /dev/null +++ b/modelscope/models/multi_modal/image_to_video/utils/shedule.py @@ -0,0 +1,60 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math + +import torch + + +def fn(u): + return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 + + +def beta_schedule(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None): + ''' + This code defines a function beta_schedule that generates a sequence of beta values based on the given input + parameters. These beta values can be used in video diffusion processes. The function has the following parameters: + schedule(str): Determines the type of beta schedule to be generated. It can be 'linear', 'linear_sd', + 'quadratic', or 'cosine'. + num_timesteps(int, optional): The number of timesteps for the generated beta schedule. Default is 1000. + init_beta(float, optional): The initial beta value. If not provided, a default value is used based on the + chosen schedule. + last_beta(float, optional): The final beta value. If not provided, a default value is used based on the + chosen schedule. + The function returns a PyTorch tensor containing the generated beta values. + The beta schedule is determined by the schedule parameter: + 1.Linear: Generates a linear sequence of beta values betweeninit_betaandlast_beta. + 2.Linear_sd: Generates a linear sequence of beta values between the square root of init_beta and the square root + oflast_beta, and then squares the result. + 3.Quadratic: Similar to the 'linear_sd' schedule, but with different default values forinit_betaandlast_beta. 
+ 4.Cosine: Generates a sequence of beta values based on a cosine function, ensuring the values are between 0 + and 0.999. + If an unsupported schedule is provided, a ValueError is raised with a message indicating the issue. + ''' + if schedule == 'linear': + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + last_beta = last_beta or scale * 0.02 + return torch.linspace( + init_beta, last_beta, num_timesteps, dtype=torch.float64) + elif schedule == 'linear_sd': + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'quadratic': + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'cosine': + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) + return torch.tensor(betas, dtype=torch.float64) + else: + raise ValueError(f'Unsupported schedule: {schedule}') diff --git a/modelscope/models/multi_modal/image_to_video/utils/transforms.py b/modelscope/models/multi_modal/image_to_video/utils/transforms.py new file mode 100755 index 00000000..3663620f --- /dev/null +++ b/modelscope/models/multi_modal/image_to_video/utils/transforms.py @@ -0,0 +1,404 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import math +import random + +import numpy as np +import torch +import torchvision.transforms.functional as F +from PIL import Image, ImageFilter + +__all__ = [ + 'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2', + 'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip', + 'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize', + 'ResizeRandomCrop', 'ExtractResizeRandomCrop', 'ExtractResizeAssignCrop' +] + + +class Compose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __getitem__(self, index): + if isinstance(index, slice): + return Compose(self.transforms[index]) + else: + return self.transforms[index] + + def __len__(self): + return len(self.transforms) + + def __call__(self, rgb): + for t in self.transforms: + rgb = t(rgb) + return rgb + + +class Resize(object): + + def __init__(self, size=256): + if isinstance(size, int): + size = (size, size) + self.size = size + + def __call__(self, rgb): + if isinstance(rgb, list): + rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb] + else: + rgb = rgb.resize(self.size, Image.BILINEAR) + return rgb + + +class Rescale(object): + + def __init__(self, size=256, interpolation=Image.BILINEAR): + self.size = size + self.interpolation = interpolation + + def __call__(self, rgb): + w, h = rgb[0].size + scale = self.size / min(w, h) + out_w, out_h = int(round(w * scale)), int(round(h * scale)) + rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb] + return rgb + + +class CenterCrop(object): + + def __init__(self, size=224): + self.size = size + + def __call__(self, rgb): + w, h = rgb[0].size + assert min(w, h) >= self.size + x1 = (w - self.size) // 2 + y1 = (h - self.size) // 2 + rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb] + return rgb + + +class ResizeRandomCrop(object): + + def __init__(self, size=256, size_short=292): + self.size = size + self.size_short = size_short + + def __call__(self, rgb): + + # consistent 
crop between rgb and m + while min(rgb[0].size) >= 2 * self.size_short: + rgb = [ + u.resize((u.width // 2, u.height // 2), resample=Image.BOX) + for u in rgb + ] + scale = self.size_short / min(rgb[0].size) + rgb = [ + u.resize((round(scale * u.width), round(scale * u.height)), + resample=Image.BICUBIC) for u in rgb + ] + out_w = self.size + out_h = self.size + w, h = rgb[0].size + x1 = random.randint(0, w - out_w) + y1 = random.randint(0, h - out_h) + + rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] + return rgb + + +class ExtractResizeRandomCrop(object): + + def __init__(self, size=256, size_short=292): + self.size = size + self.size_short = size_short + + def __call__(self, rgb): + + # consistent crop between rgb and m + while min(rgb[0].size) >= 2 * self.size_short: + rgb = [ + u.resize((u.width // 2, u.height // 2), resample=Image.BOX) + for u in rgb + ] + scale = self.size_short / min(rgb[0].size) + rgb = [ + u.resize((round(scale * u.width), round(scale * u.height)), + resample=Image.BICUBIC) for u in rgb + ] + out_w = self.size + out_h = self.size + w, h = rgb[0].size + x1 = random.randint(0, w - out_w) + y1 = random.randint(0, h - out_h) + + rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] + wh = [x1, y1, x1 + out_w, y1 + out_h] + return rgb, wh + + +class ExtractResizeAssignCrop(object): + + def __init__(self, size=256, size_short=292): + self.size = size + self.size_short = size_short + + def __call__(self, rgb, wh): + + # consistent crop between rgb and m + while min(rgb[0].size) >= 2 * self.size_short: + rgb = [ + u.resize((u.width // 2, u.height // 2), resample=Image.BOX) + for u in rgb + ] + scale = self.size_short / min(rgb[0].size) + rgb = [ + u.resize((round(scale * u.width), round(scale * u.height)), + resample=Image.BICUBIC) for u in rgb + ] + + rgb = [u.crop(wh) for u in rgb] + rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] + + return rgb + + +class CenterCropV2(object): + + def __init__(self, 
size): + self.size = size + + def __call__(self, img): + # fast resize + while min(img[0].size) >= 2 * self.size: + img = [ + u.resize((u.width // 2, u.height // 2), resample=Image.BOX) + for u in img + ] + scale = self.size / min(img[0].size) + img = [ + u.resize((round(scale * u.width), round(scale * u.height)), + resample=Image.BICUBIC) for u in img + ] + + # center crop + x1 = (img[0].width - self.size) // 2 + y1 = (img[0].height - self.size) // 2 + img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img] + return img + + +class CenterCropWide(object): + + def __init__(self, size): + self.size = size + + def __call__(self, img): + if isinstance(img, list): + scale = min(img[0].size[0] / self.size[0], + img[0].size[1] / self.size[1]) + img = [ + u.resize((round(u.width // scale), round(u.height // scale)), + resample=Image.BOX) for u in img + ] + + # center crop + x1 = (img[0].width - self.size[0]) // 2 + y1 = (img[0].height - self.size[1]) // 2 + img = [ + u.crop((x1, y1, x1 + self.size[0], y1 + self.size[1])) + for u in img + ] + return img + else: + scale = min(img.size[0] / self.size[0], img.size[1] / self.size[1]) + img = img.resize( + (round(img.width // scale), round(img.height // scale)), + resample=Image.BOX) + x1 = (img.width - self.size[0]) // 2 + y1 = (img.height - self.size[1]) // 2 + img = img.crop((x1, y1, x1 + self.size[0], y1 + self.size[1])) + return img + + +class RandomCrop(object): + + def __init__(self, size=224, min_area=0.4): + self.size = size + self.min_area = min_area + + def __call__(self, rgb): + + # consistent crop between rgb and m + w, h = rgb[0].size + area = w * h + out_w, out_h = float('inf'), float('inf') + while out_w > w or out_h > h: + target_area = random.uniform(self.min_area, 1.0) * area + aspect_ratio = random.uniform(3. / 4., 4. / 3.) 
+ out_w = int(round(math.sqrt(target_area * aspect_ratio))) + out_h = int(round(math.sqrt(target_area / aspect_ratio))) + x1 = random.randint(0, w - out_w) + y1 = random.randint(0, h - out_h) + + rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] + rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] + + return rgb + + +class RandomCropV2(object): + + def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)): + if isinstance(size, (tuple, list)): + self.size = size + else: + self.size = (size, size) + self.min_area = min_area + self.ratio = ratio + + def _get_params(self, img): + width, height = img.size + area = height * width + + for _ in range(10): + target_area = random.uniform(self.min_area, 1.0) * area + log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if (in_ratio < min(self.ratio)): + w = width + h = int(round(w / min(self.ratio))) + elif (in_ratio > max(self.ratio)): + h = height + w = int(round(h * max(self.ratio))) + else: + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def __call__(self, rgb): + i, j, h, w = self._get_params(rgb[0]) + rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb] + return rgb + + +class RandomHFlip(object): + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb] + return rgb + + +class GaussianBlur(object): + + def __init__(self, sigmas=[0.1, 2.0], p=0.5): + self.sigmas = sigmas + self.p = p + + def __call__(self, rgb): + if random.random() < 
self.p: + sigma = random.uniform(*self.sigmas) + rgb = [ + u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb + ] + return rgb + + +class ColorJitter(object): + + def __init__(self, + brightness=0.4, + contrast=0.4, + saturation=0.4, + hue=0.1, + p=0.5): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + brightness, contrast, saturation, hue = self._random_params() + transforms = [ + lambda f: F.adjust_brightness(f, brightness), + lambda f: F.adjust_contrast(f, contrast), + lambda f: F.adjust_saturation(f, saturation), + lambda f: F.adjust_hue(f, hue) + ] + random.shuffle(transforms) + for t in transforms: + rgb = [t(u) for u in rgb] + + return rgb + + def _random_params(self): + brightness = random.uniform( + max(0, 1 - self.brightness), 1 + self.brightness) + contrast = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast) + saturation = random.uniform( + max(0, 1 - self.saturation), 1 + self.saturation) + hue = random.uniform(-self.hue, self.hue) + return brightness, contrast, saturation, hue + + +class RandomGray(object): + + def __init__(self, p=0.2): + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + rgb = [u.convert('L').convert('RGB') for u in rgb] + return rgb + + +class ToTensor(object): + + def __call__(self, rgb): + if isinstance(rgb, list): + rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0) + else: + rgb = F.to_tensor(rgb) + + return rgb + + +class Normalize(object): + + def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): + self.mean = mean + self.std = std + + def __call__(self, rgb): + rgb = rgb.clone() + rgb.clamp_(0, 1) + if not isinstance(self.mean, torch.Tensor): + self.mean = rgb.new_tensor(self.mean).view(-1) + if not isinstance(self.std, torch.Tensor): + self.std = rgb.new_tensor(self.std).view(-1) + if rgb.dim() == 4: + 
rgb.sub_(self.mean.view(1, -1, 1, + 1)).div_(self.std.view(1, -1, 1, 1)) + elif rgb.dim() == 3: + rgb.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1)) + return rgb diff --git a/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py b/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py index 6b829485..80d8ab28 100644 --- a/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py +++ b/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py @@ -39,7 +39,7 @@ class StableDiffusion(TorchModel): self.lora_tune = kwargs.pop('lora_tune', False) self.dreambooth_tune = kwargs.pop('dreambooth_tune', False) - self.weight_dtype = torch.float32 + self.weight_dtype = kwargs.pop('torch_type', torch.float32) self.device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu') @@ -59,14 +59,15 @@ class StableDiffusion(TorchModel): # Freeze gradient calculation and move to device if self.vae is not None: self.vae.requires_grad_(False) - self.vae = self.vae.to(self.device) + self.vae = self.vae.to(self.device, dtype=self.weight_dtype) if self.text_encoder is not None: self.text_encoder.requires_grad_(False) - self.text_encoder = self.text_encoder.to(self.device) + self.text_encoder = self.text_encoder.to( + self.device, dtype=self.weight_dtype) if self.unet is not None: if self.lora_tune: self.unet.requires_grad_(False) - self.unet = self.unet.to(self.device) + self.unet = self.unet.to(self.device, dtype=self.weight_dtype) # xformers accelerate memory efficient attention if xformers_enable: diff --git a/modelscope/models/multi_modal/video_to_video/__init__.py b/modelscope/models/multi_modal/video_to_video/__init__.py new file mode 100644 index 00000000..0e10a8eb --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .video_to_video_model import VideoToVideo + +else: + _import_structure = { + 'video_to_video_model': ['VideoToVideo'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/multi_modal/video_to_video/modules/__init__.py b/modelscope/models/multi_modal/video_to_video/modules/__init__.py new file mode 100644 index 00000000..6c882318 --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/modules/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from .autoencoder import * +from .embedder import * +from .unet_v2v import * diff --git a/modelscope/models/multi_modal/video_to_video/modules/autoencoder.py b/modelscope/models/multi_modal/video_to_video/modules/autoencoder.py new file mode 100644 index 00000000..714a8953 --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/modules/autoencoder.py @@ -0,0 +1,590 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import collections + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.no_grad() +def get_first_stage_encoding(encoder_posterior): + scale_factor = 0.18215 + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return scale_factor * z + + +class DiagonalGaussianDistribution(object): + + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn( + self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + + 
torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +class ResnetBlock(nn.Module): + + def __init__(self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d( + in_channels, in_channels, 
kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Upsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate( + x, scale_factor=2.0, mode='nearest') + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode='constant', value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class Encoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type='vanilla', + **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + 
self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1, ) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = 
self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + + def __init__(self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type='vanilla', + **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2**(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logger.info('Working with z of shape {} = {} dimensions.'.format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up 
= nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class AutoencoderKL(nn.Module): + + def __init__(self, + ddconfig, + embed_dim, + pretrained=None, + ignore_keys=[], + image_key='image', + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False, + **kwargs): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + assert ddconfig['double_z'] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'], + 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, + ddconfig['z_channels'], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer('colorize', + torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + + if pretrained is not None: + self.init_from_ckpt(pretrained, ignore_keys=ignore_keys) + + def init_from_ckpt(self, 
path, ignore_keys=list()): + sd = torch.load(path, map_location='cpu')['state_dict'] + keys = list(sd.keys()) + sd_new = collections.OrderedDict() + for k in keys: + if k.find('first_stage_model') >= 0: + k_new = k.split('first_stage_model.')[-1] + sd_new[k_new] = sd[k] + self.load_state_dict(sd_new, strict=True) + logger.info(f'Restored from {path}') + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, + 2).to(memory_format=torch.contiguous_format).float() + return x + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log['samples'] = self.decode(torch.randn_like(posterior.sample())) + log['reconstructions'] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log['samples_ema'] = self.decode( + torch.randn_like(posterior_ema.sample())) + log['reconstructions_ema'] = xrec_ema + log['inputs'] = x + return log + + def 
to_rgb(self, x): + assert self.image_key == 'segmentation' + if not hasattr(self, 'colorize'): + self.register_buffer('colorize', + torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/modelscope/models/multi_modal/video_to_video/modules/embedder.py b/modelscope/models/multi_modal/video_to_video/modules/embedder.py new file mode 100644 index 00000000..ae8889a6 --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/modules/embedder.py @@ -0,0 +1,76 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+
+import os
+
+import numpy as np
+import open_clip
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+
+
+class FrozenOpenCLIPEmbedder(nn.Module):
+    """
+    Uses the OpenCLIP transformer encoder for text
+    """
+    LAYERS = ['last', 'penultimate']
+
+    def __init__(self,
+                 pretrained,
+                 arch='ViT-H-14',
+                 device='cuda',
+                 max_length=77,
+                 freeze=True,
+                 layer='penultimate'):
+        super().__init__()
+        assert layer in self.LAYERS
+        model, _, preprocess = open_clip.create_model_and_transforms(
+            arch, device=torch.device('cpu'), pretrained=pretrained)
+
+        del model.visual
+        self.model = model
+        self.device = device
+        self.max_length = max_length
+
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        if self.layer == 'last':
+            self.layer_idx = 0
+        elif self.layer == 'penultimate':
+            self.layer_idx = 1
+        else:
+            raise NotImplementedError()
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        tokens = open_clip.tokenize(text)
+        z = self.encode_with_transformer(tokens.to(self.device))
+        return z
+
+    def encode_with_transformer(self, text):
+        x = self.model.token_embedding(text)
+        x = x + self.model.positional_embedding
+        x = x.permute(1, 0, 2)
+        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+        x = x.permute(1, 0, 2)
+        x = self.model.ln_final(x)
+        return x
+
+    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+        for i, r in enumerate(self.model.transformer.resblocks):
+            if i == len(self.model.transformer.resblocks) - self.layer_idx:
+                break
+            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
+            ):
+                x = torch.utils.checkpoint.checkpoint(r, x, attn_mask)
+            else:
+                x = r(x, attn_mask=attn_mask)
+        return x
+
+    def encode(self, text):
+        return self(text)
diff --git a/modelscope/models/multi_modal/video_to_video/modules/unet_v2v.py b/modelscope/models/multi_modal/video_to_video/modules/unet_v2v.py
new file mode 100644
index 00000000..219ddb43
--- /dev/null
+++ b/modelscope/models/multi_modal/video_to_video/modules/unet_v2v.py
@@ -0,0 +1,1530 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import math
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import xformers
+import xformers.ops
+from einops import rearrange, repeat
+from fairscale.nn.checkpoint import checkpoint_wrapper
+from rotary_embedding_torch import RotaryEmbedding
+
+USE_TEMPORAL_TRANSFORMER = True
+
+
+class DropPath(nn.Module):
+    r"""DropPath but without rescaling and supports optional all-zero and/or all-keep.
+    """
+
+    def __init__(self, p):
+        super(DropPath, self).__init__()
+        self.p = p
+
+    def forward(self, *args, zero=None, keep=None):
+        if not self.training:
+            return args[0] if len(args) == 1 else args
+
+        # params
+        x = args[0]
+        b = x.size(0)
+        n = (torch.rand(b) < self.p).sum()
+
+        # non-zero and non-keep mask
+        mask = x.new_ones(b, dtype=torch.bool)
+        if keep is not None:
+            mask[keep] = False
+        if zero is not None:
+            mask[zero] = False
+
+        # drop-path index
+        index = torch.where(mask)[0]
+        index = index[torch.randperm(len(index))[:n]]
+        if zero is not None:
+            index = torch.cat([index, torch.where(zero)[0]], dim=0)
+
+        # drop-path multiplier
+        multiplier = x.new_ones(b)
+        multiplier[index] = 0.0
+        output = tuple(u * self.broadcast(multiplier, u) for u in args)
+        return output[0] if len(args) == 1 else output
+
+    def broadcast(self, src, dst):
+        assert src.size(0) == dst.size(0)
+        shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1)
+        return src.view(shape)
+
+
+def sinusoidal_embedding(timesteps, dim):
+    # check input
+    half = dim // 2
+    timesteps = timesteps.float()
+
+    # compute sinusoidal embedding
+    sinusoid = torch.outer(
+        timesteps, torch.pow(10000,
+                             -torch.arange(half).to(timesteps).div(half)))
+    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+    if dim % 2 != 0:
+        x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
+    return x
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+
+def prob_mask_like(shape, prob, device):
+    if prob == 1:
+        return torch.ones(shape, device=device, dtype=torch.bool)
+    elif prob == 0:
+        return torch.zeros(shape, device=device, dtype=torch.bool)
+    else:
+        mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
+        # aviod mask all, which will cause find_unused_parameters error
+        if mask.all():
+            mask[0] = False
+        return mask
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+
+    def __init__(self,
+                 query_dim,
+                 context_dim=None,
+                 heads=8,
+                 dim_head=64,
+                 dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3).reshape(b, t.shape[
+                1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
+                    b * self.heads, t.shape[1], self.dim_head).contiguous(),
+            (q, k, v),
+        )
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(
+            q, k, v, attn_bias=None, op=self.attention_op)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0).reshape(
+                b, self.heads, out.shape[1],
+                self.dim_head).permute(0, 2, 1,
+                                       3).reshape(b, out.shape[1],
+                                                  self.heads * self.dim_head))
+        return self.to_out(out)
+
+
+class RelativePositionBias(nn.Module):
+
+    def __init__(self, heads=8, num_buckets=32, 
max_distance=128): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, + num_buckets=32, + max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) + / math.log(max_distance / max_exact) * # noqa + (num_buckets - max_exact)).long() + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype=torch.long, device=device) + k_pos = torch.arange(n, dtype=torch.long, device=device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket( + rel_pos, + num_buckets=self.num_buckets, + max_distance=self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +_ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32') + + +class CrossAttention(nn.Module): + + def __init__(self, + query_dim, + 
context_dim=None, + heads=8, + dim_head=64, + dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (q, k, v)) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION == 'fp32': + with torch.autocast(enabled=False, device_type='cuda'): + q, k = q.float(), k.float() + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + else: + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + + def __init__(self, + dim, + n_heads, + d_head, + dropout=0., + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_cls = MemoryEfficientCrossAttention + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward_(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), + self.checkpoint) + + def forward(self, x, context=None): + x = self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +# feedforward +class GEGLU(nn.Module): + + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class FeedForward(nn.Module): + + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear( + dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = nn.Conv2d( + self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode='nearest') + else: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = x[..., 1:-1, :] + if self.use_conv: + x = self.conv(x) + return x + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. 
+ :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + use_temporal_conv=True, + use_image_dataset=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.use_temporal_conv = use_temporal_conv + + self.in_layers = nn.Sequential( + nn.GroupNorm(32, channels), + nn.SiLU(), + nn.Conv2d(channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels + if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + nn.GroupNorm(32, self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) + + if self.use_temporal_conv: + self.temopral_conv = TemporalConvBlock_v2( + self.out_channels, + self.out_channels, + dropout=0.1, + use_image_dataset=use_image_dataset) + + def forward(self, 
x, emb, batch_size): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return self._forward(x, emb, batch_size) + + def _forward(self, x, emb, batch_size): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + h = self.skip_connection(x) + h + + if self.use_temporal_conv: + h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size) + h = self.temopral_conv(h) + h = rearrange(h, 'b c f h w -> (b f) c h w') + return h + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=(2, 1)): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = nn.Conv2d( + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, mode): + assert mode in ['none', 'upsample', 'downsample'] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.mode = mode + + def forward(self, x, reference=None): + if self.mode == 'upsample': + assert reference is not None + x = F.interpolate(x, size=reference.shape[-2:], mode='nearest') + elif self.mode == 'downsample': + x = F.adaptive_avg_pool2d( + x, output_size=tuple(u // 2 for u in x.shape[-2:])) + return x + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + mode='none', + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.mode = mode + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, mode) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) 
+ + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e, reference=None): + identity = self.resample(x, reference) + x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference)) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. 
+ """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, + d).permute(0, 2, 3, + 1).chunk( + 2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + + # compute attention + attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.matmul(v, attn.transpose(-1, -2)) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class TemporalAttentionBlock(nn.Module): + + def __init__(self, + dim, + heads=4, + dim_head=32, + rotary_emb=None, + use_image_dataset=False, + use_sim_mask=False): + super().__init__() + # consider num_heads first, as pos_bias needs fixed num_heads + dim_head = dim // heads + assert heads * dim_head == dim + self.use_image_dataset = use_image_dataset + self.use_sim_mask = use_sim_mask + + self.scale = dim_head**-0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.norm = nn.GroupNorm(32, dim) + self.rotary_emb = rotary_emb + self.to_qkv = nn.Linear(dim, hidden_dim * 3) + self.to_out = nn.Linear(hidden_dim, dim) + + def forward(self, + x, + pos_bias=None, + focus_present_mask=None, + video_mask=None): + + identity = x + n, height, device = x.shape[2], x.shape[-2], x.device + + x = self.norm(x) + x = rearrange(x, 'b c f h w -> b (h w) f c') + + qkv = self.to_qkv(x).chunk(3, dim=-1) + + if exists(focus_present_mask) and focus_present_mask.all(): + # if all batch samples are focusing on present + # it would be equivalent to passing that token's values (v=qkv[-1]) through to the output + values = qkv[-1] + out = self.to_out(values) + out = rearrange(out, 'b (h w) f c -> b c f h w', h=height) + + return out + identity + + # split out heads + q = rearrange(qkv[0], '... n (h d) -> ... 
h n d', h=self.heads) + k = rearrange(qkv[1], '... n (h d) -> ... h n d', h=self.heads) + v = rearrange(qkv[2], '... n (h d) -> ... h n d', h=self.heads) + + # scale + + q = q * self.scale + + # rotate positions into queries and keys for time attention + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + # shape [b (hw) h n n], n=f + sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + sim = sim + pos_bias + + if (focus_present_mask is None and video_mask is not None): + # video_mask: [B, n] + mask = video_mask[:, None, :] * video_mask[:, :, None] + mask = mask.unsqueeze(1).unsqueeze(1) + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + elif exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones((n, n), + device=device, + dtype=torch.bool) + attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + if self.use_sim_mask: + sim_mask = torch.tril( + torch.ones((n, n), device=device, dtype=torch.bool), + diagonal=0) + sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max) + + # numerical stability + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + + out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... n (h d)') + out = self.to_out(out) + + out = rearrange(out, 'b (h w) f c -> b c f h w', h=height) + + if self.use_image_dataset: + out = identity + 0 * out + else: + out = identity + out + return out + + +class TemporalTransformer(nn.Module): + """ + Transformer block for image-like data. 
+ First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + only_self_att=True, + multiply_zero=False): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + self.use_adaptor = False + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + if self.use_adaptor: + self.adaptor_in = nn.Linear(frames, frames) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + checkpoint=use_checkpoint) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module( + nn.Conv1d( + inner_dim, in_channels, kernel_size=1, stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + if self.use_adaptor: + self.adaptor_out = nn.Linear(frames, frames) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + # [16384, 16, 320] + if self.use_linear: + x = rearrange( + x, '(b f) c h w -> b (h w) f c', 
f=self.frames).contiguous() + x = self.proj_in(x) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + context[i] = rearrange( + context[i], '(b f) l con -> b f l con', + f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat( + context[i][j], + 'f l con -> (f r) l con', + r=(h * w) // self.frames, + f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange( + x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + + +class TemporalAttentionMultiBlock(nn.Module): + + def __init__( + self, + dim, + heads=4, + dim_head=32, + rotary_emb=None, + use_image_dataset=False, + use_sim_mask=False, + temporal_attn_times=1, + ): + super().__init__() + self.att_layers = nn.ModuleList([ + TemporalAttentionBlock(dim, heads, dim_head, rotary_emb, + use_image_dataset, use_sim_mask) + for _ in range(temporal_attn_times) + ]) + + def forward(self, + x, + pos_bias=None, + focus_present_mask=None, + video_mask=None): + for layer in self.att_layers: + x = layer(x, pos_bias, focus_present_mask, video_mask) + return x + + +class InitTemporalConvBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim=None, + dropout=0.0, + use_image_dataset=False): + super(InitTemporalConvBlock, self).__init__() + if out_dim is None: + out_dim = in_dim + self.in_dim 
= in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv[-1].weight) + nn.init.zeros_(self.conv[-1].bias) + + def forward(self, x): + identity = x + x = self.conv(x) + if self.use_image_dataset: + x = identity + 0 * x + else: + x = identity + x + return x + + +class TemporalConvBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim=None, + dropout=0.0, + use_image_dataset=False): + super(TemporalConvBlock, self).__init__() + if out_dim is None: + out_dim = in_dim + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv2[-1].weight) + nn.init.zeros_(self.conv2[-1].bias) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.conv2(x) + if self.use_image_dataset: + x = identity + 0 * x + else: + x = identity + x + return x + + +class TemporalConvBlock_v2(nn.Module): + + def __init__(self, + in_dim, + out_dim=None, + dropout=0.0, + use_image_dataset=False): + super(TemporalConvBlock_v2, self).__init__() + if out_dim is None: + out_dim = in_dim + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), 
nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + + if self.use_image_dataset: + x = identity + 0.0 * x + else: + x = identity + x + return x + + +class Vid2VidSDUNet(nn.Module): + + def __init__(self, + in_dim=4, + dim=320, + y_dim=1024, + context_dim=1024, + out_dim=4, + dim_mult=[1, 2, 4, 4], + num_heads=8, + head_dim=64, + num_res_blocks=2, + attn_scales=[1 / 1, 1 / 2, 1 / 4], + use_scale_shift_norm=True, + dropout=0.1, + temporal_attn_times=1, + temporal_attention=True, + use_checkpoint=True, + use_image_dataset=False, + use_fps_condition=False, + use_sim_mask=False, + training=False, + inpainting=True): + embed_dim = dim * 4 + num_heads = num_heads if num_heads else dim // 32 + super(Vid2VidSDUNet, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.context_dim = context_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + # for temporal attention + self.num_heads = num_heads + # for spatial attention + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.use_scale_shift_norm = use_scale_shift_norm + self.temporal_attn_times = temporal_attn_times + self.temporal_attention = temporal_attention + self.use_checkpoint = use_checkpoint + self.use_image_dataset = use_image_dataset + self.use_fps_condition = use_fps_condition + self.use_sim_mask = use_sim_mask 
+ self.training = training + self.inpainting = inpainting + + use_linear_in_temporal = False + transformer_depth = 1 + disabled_sa = False + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embed = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + + if temporal_attention and not USE_TEMPORAL_TRANSFORMER: + self.rotary_emb = RotaryEmbedding(min(32, head_dim)) + self.time_rel_pos_bias = RelativePositionBias( + heads=num_heads, max_distance=32) + + if self.use_fps_condition: + self.fps_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + nn.init.zeros_(self.fps_embedding[-1].weight) + nn.init.zeros_(self.fps_embedding[-1].bias) + + # encoder + self.input_blocks = nn.ModuleList() + init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + # need an initial temporal attention? 
+ if temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + init_block.append( + TemporalTransformer( + dim, + num_heads, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + else: + init_block.append( + TemporalAttentionMultiBlock( + dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + temporal_attn_times=temporal_attn_times, + use_image_dataset=use_image_dataset)) + self.input_blocks.append(init_block) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + block = nn.ModuleList([ + ResBlock( + in_dim, + embed_dim, + dropout, + out_channels=out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True)) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + else: + block.append( + TemporalAttentionMultiBlock( + out_dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + use_image_dataset=use_image_dataset, + use_sim_mask=use_sim_mask, + temporal_attn_times=temporal_attn_times)) + in_dim = out_dim + self.input_blocks.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + downsample = Downsample( + out_dim, True, dims=2, out_channels=out_dim) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.input_blocks.append(downsample) + + self.middle_block = nn.ModuleList([ + ResBlock( + out_dim, + embed_dim, + dropout, + 
use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ), + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True) + ]) + + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + self.middle_block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset, + )) + else: + self.middle_block.append( + TemporalAttentionMultiBlock( + out_dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + use_image_dataset=use_image_dataset, + use_sim_mask=use_sim_mask, + temporal_attn_times=temporal_attn_times)) + + self.middle_block.append( + ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False)) + + # decoder + self.output_blocks = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + block = nn.ModuleList([ + ResBlock( + in_dim + shortcut_dims.pop(), + embed_dim, + dropout, + out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=1024, + disable_self_attn=False, + use_linear=True)) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + else: + block.append( + TemporalAttentionMultiBlock( + out_dim, + num_heads, + head_dim, + rotary_emb=self.rotary_emb, + use_image_dataset=use_image_dataset, + use_sim_mask=use_sim_mask, + temporal_attn_times=temporal_attn_times)) + in_dim = out_dim 
+ + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + upsample = Upsample( + out_dim, True, dims=2.0, out_channels=out_dim) + scale *= 2.0 + block.append(upsample) + self.output_blocks.append(block) + + # head + self.out = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.out[-1].weight) + + def forward(self, + x, + t, + y, + x_lr=None, + fps=None, + video_mask=None, + focus_present_mask=None, + prob_focus_present=0., + mask_last_frame_num=0): + + batch, x_c, x_f, x_h, x_w = x.shape + device = x.device + self.batch = batch + + # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored + if mask_last_frame_num > 0: + focus_present_mask = None + video_mask[-mask_last_frame_num:] = False + else: + focus_present_mask = default( + focus_present_mask, lambda: prob_mask_like( + (batch, ), prob_focus_present, device=device)) + + if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER: + time_rel_pos_bias = self.time_rel_pos_bias( + x.shape[2], device=x.device) + else: + time_rel_pos_bias = None + + # embeddings + e = self.time_embed(sinusoidal_embedding(t, self.dim)) + context = y + + # repeat f times for spatial e and context + e = e.repeat_interleave(repeats=x_f, dim=0) + context = context.repeat_interleave(repeats=x_f, dim=0) + + # always in shape (b f) c h w, except for temporal layer + x = rearrange(x, 'b c f h w -> (b f) c h w') + # encoder + xs = [] + for block in self.input_blocks: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, + focus_present_mask, video_mask) + xs.append(x) + + # middle + for block in self.middle_block: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, + focus_present_mask, video_mask) + + # decoder + for block in self.output_blocks: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single( + block, + x, + e, + context, + 
time_rel_pos_bias, + focus_present_mask, + video_mask, + reference=xs[-1] if len(xs) > 0 else None) + + # head + x = self.out(x) + + # reshape back to (b c f h w) + x = rearrange(x, '(b f) c h w -> b c f h w', b=batch) + return x + + def _forward_single(self, + module, + x, + e, + context, + time_rel_pos_bias, + focus_present_mask, + video_mask, + reference=None): + if isinstance(module, ResidualBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = x.contiguous() + x = module(x, e, reference) + elif isinstance(module, ResBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = x.contiguous() + x = module(x, e, self.batch) + elif isinstance(module, SpatialTransformer): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, TemporalTransformer): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x, context) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, CrossAttention): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, MemoryEfficientCrossAttention): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, BasicTransformerBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, FeedForward): + x = module(x, context) + elif isinstance(module, Upsample): + x = module(x) + elif isinstance(module, Downsample): + x = module(x) + elif isinstance(module, Resample): + x = module(x, reference) + elif isinstance(module, TemporalAttentionBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x, 
time_rel_pos_bias, focus_present_mask, video_mask) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, TemporalAttentionMultiBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x, time_rel_pos_bias, focus_present_mask, video_mask) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, InitTemporalConvBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, TemporalConvBlock): + module = checkpoint_wrapper( + module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context, + time_rel_pos_bias, focus_present_mask, + video_mask, reference) + else: + x = module(x) + return x diff --git a/modelscope/models/multi_modal/video_to_video/utils/__init__.py b/modelscope/models/multi_modal/video_to_video/utils/__init__.py new file mode 100644 index 00000000..92654e51 --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os diff --git a/modelscope/models/multi_modal/video_to_video/utils/config.py b/modelscope/models/multi_modal/video_to_video/utils/config.py new file mode 100644 index 00000000..1c9586eb --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/utils/config.py @@ -0,0 +1,171 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import logging +import os +import os.path as osp +from datetime import datetime + +import torch +from easydict import EasyDict + +cfg = EasyDict(__name__='Config: VideoLDM Decoder') + +# ---------------------------work dir-------------------------- +cfg.work_dir = 'workspace/' + +# ---------------------------Global Variable----------------------------------- +cfg.resolution = [448, 256] +cfg.max_frames = 32 +# ----------------------------------------------------------------------------- + +# ---------------------------Dataset Parameter--------------------------------- +cfg.mean = [0.5, 0.5, 0.5] +cfg.std = [0.5, 0.5, 0.5] +cfg.max_words = 1000 + +# PlaceHolder +cfg.vit_out_dim = 1024 +cfg.vit_resolution = [224, 224] +cfg.depth_clamp = 10.0 +cfg.misc_size = 384 +cfg.depth_std = 20.0 + +cfg.frame_lens = 32 +cfg.sample_fps = 8 + +cfg.batch_sizes = 1 +# ----------------------------------------------------------------------------- + +# ---------------------------Mode Parameters----------------------------------- +# Diffusion +cfg.schedule = 'cosine' +cfg.num_timesteps = 1000 +cfg.mean_type = 'v' +cfg.var_type = 'fixed_small' +cfg.loss_type = 'mse' +cfg.ddim_timesteps = 50 +cfg.ddim_eta = 0.0 +cfg.clamp = 1.0 +cfg.share_noise = False +cfg.use_div_loss = False +cfg.noise_strength = 0.1 + +# classifier-free guidance +cfg.p_zero = 0.1 +cfg.guide_scale = 3.0 + +# clip vision encoder +cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073] +cfg.vit_std = [0.26862954, 0.26130258, 0.27577711] + +# Model +cfg.scale_factor = 0.18215 +cfg.use_fp16 = True +cfg.temporal_attention = True +cfg.decoder_bs = 8 + +cfg.UNet = { + 'type': 'Vid2VidSDUNet', + 'in_dim': 4, + 'dim': 320, + 'y_dim': cfg.vit_out_dim, + 'context_dim': 1024, + 'out_dim': 8 if cfg.var_type.startswith('learned') else 4, + 'dim_mult': [1, 2, 4, 4], + 'num_heads': 8, + 'head_dim': 64, + 'num_res_blocks': 2, + 'attn_scales': [1 / 1, 1 / 2, 1 / 4], + 'dropout': 0.1, + 'temporal_attention': cfg.temporal_attention, + 
'temporal_attn_times': 1, + 'use_checkpoint': False, + 'use_fps_condition': False, + 'use_sim_mask': False, + 'num_tokens': 4, + 'default_fps': 8, + 'input_dim': 1024 +} + +cfg.guidances = [] + +# auotoencoder from stabel diffusion +cfg.auto_encoder = { + 'type': 'AutoencoderKL', + 'ddconfig': { + 'double_z': True, + 'z_channels': 4, + 'resolution': 256, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0 + }, + 'embed_dim': 4, + 'pretrained': 'models/v2-1_512-ema-pruned.ckpt' +} +# clip embedder +cfg.embedder = { + 'type': 'FrozenOpenCLIPEmbedder', + 'layer': 'penultimate', + 'vit_resolution': [224, 224], + 'pretrained': 'open_clip_pytorch_model.bin' +} +# ----------------------------------------------------------------------------- + +# ---------------------------Training Settings--------------------------------- +# training and optimizer +cfg.ema_decay = 0.9999 +cfg.num_steps = 600000 +cfg.lr = 5e-5 +cfg.weight_decay = 0.0 +cfg.betas = (0.9, 0.999) +cfg.eps = 1.0e-8 +cfg.chunk_size = 16 +cfg.alpha = 0.7 +cfg.save_ckp_interval = 1000 +# ----------------------------------------------------------------------------- + +# ----------------------------Pretrain Settings--------------------------------- +# Default: load 2d pretrain +cfg.fix_weight = False +cfg.load_match = False +cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt' +cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json' +cfg.resume_checkpoint = 'img2video_ldm_0779000.pth' +# ----------------------------------------------------------------------------- + +# -----------------------------Visual------------------------------------------- +# Visual videos +cfg.viz_interval = 1000 +cfg.visual_train = { + 'type': 'VisualVideoTextDuringTrain', +} +cfg.visual_inference = { + 'type': 'VisualGeneratedVideos', +} +cfg.inference_list_path = '' + +# logging +cfg.log_interval = 100 + +# Default 
log_dir +cfg.log_dir = 'workspace/output_data' +# ----------------------------------------------------------------------------- + +# ---------------------------Others-------------------------------------------- +# seed +cfg.seed = 8888 +cfg.negative_prompt = 'worst quality, normal quality, low quality, low res, blurry, text, \ +watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, \ +sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting' + +cfg.positive_prompt = ', cinematic, High Contrast, highly detailed, unreal engine, \ +taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, \ +32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, \ +hyper sharpness, perfect without deformations, Unreal Engine 5, 4k render' + +# ----------------------------------------------------------------------------- diff --git a/modelscope/models/multi_modal/video_to_video/utils/diffusion_sdedit.py b/modelscope/models/multi_modal/video_to_video/utils/diffusion_sdedit.py new file mode 100644 index 00000000..be5b5f57 --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/utils/diffusion_sdedit.py @@ -0,0 +1,247 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import random + +import torch + +from .schedules_sdedit import karras_schedule +from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun + +__all__ = ['GaussianDiffusion_SDEdit'] + + +def _i(tensor, t, x): + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t.to(tensor.device)].view(shape).to(x.device) + + +class GaussianDiffusion_SDEdit(object): + + def __init__(self, sigmas, prediction_type='eps'): + assert prediction_type in {'x0', 'eps', 'v'} + self.sigmas = sigmas + self.alphas = torch.sqrt(1 - sigmas**2) + self.num_timesteps = len(sigmas) + self.prediction_type = prediction_type + + def diffuse(self, x0, t, noise=None): + noise = torch.randn_like(x0) if noise is None else noise + xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise + return xt + + def denoise(self, + xt, + t, + s, + model, + model_kwargs={}, + guide_scale=None, + guide_rescale=None, + clamp=None, + percentile=None): + s = t - 1 if s is None else s + + # hyperparams + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_s = _i(self.alphas, s.clamp(0), xt) + alphas_s[s < 0] = 1. 
+ sigmas_s = torch.sqrt(1 - alphas_s**2) + + # precompute variables + betas = 1 - (alphas / alphas_s)**2 + coef1 = betas * alphas_s / sigmas**2 + coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2) + var = betas * (sigmas_s / sigmas)**2 + log_var = torch.log(var).clamp_(-20, 20) + + # prediction + if guide_scale is None: + assert isinstance(model_kwargs, dict) + out = model(xt, t=t, **model_kwargs) + else: + # classifier-free guidance + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, t=t, **model_kwargs[0]) + if guide_scale == 1.: + out = y_out + else: + u_out = model(xt, t=t, **model_kwargs[1]) + out = u_out + guide_scale * (y_out - u_out) + + if guide_rescale is not None: + assert guide_rescale >= 0 and guide_rescale <= 1 + ratio = ( + y_out.flatten(1).std(dim=1) / # noqa + (out.flatten(1).std(dim=1) + 1e-12) + ).view((-1, ) + (1, ) * (y_out.ndim - 1)) + out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0 + + # compute x0 + if self.prediction_type == 'x0': + x0 = out + elif self.prediction_type == 'eps': + x0 = (xt - sigmas * out) / alphas + elif self.prediction_type == 'v': + x0 = alphas * xt - sigmas * out + else: + raise NotImplementedError( + f'prediction_type {self.prediction_type} not implemented') + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 + s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1) + s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1)) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + + # recompute eps using the restricted x0 + eps = (xt - alphas * x0) / sigmas + + # compute mu (mean of posterior distribution) using the restricted x0 + mu = coef1 * x0 + coef2 * xt + return mu, var, log_var, x0, eps + + @torch.no_grad() + def sample(self, + noise, + model, + model_kwargs={}, + condition_fn=None, + guide_scale=None, + guide_rescale=None, + clamp=None, + percentile=None, + 
solver='euler_a', + steps=20, + t_max=None, + t_min=None, + discretization=None, + discard_penultimate_step=None, + return_intermediate=None, + show_progress=False, + seed=-1, + **kwargs): + # sanity check + assert isinstance(steps, (int, torch.LongTensor)) + assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1) + assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1) + assert discretization in (None, 'leading', 'linspace', 'trailing') + assert discard_penultimate_step in (None, True, False) + assert return_intermediate in (None, 'x0', 'xt') + + # function of diffusion solver + solver_fn = { + 'heun': sample_heun, + 'dpmpp_2m_sde': sample_dpmpp_2m_sde + }[solver] + + # options + schedule = 'karras' if 'karras' in solver else None + discretization = discretization or 'linspace' + seed = seed if seed >= 0 else random.randint(0, 2**31) + if isinstance(steps, torch.LongTensor): + discard_penultimate_step = False + if discard_penultimate_step is None: + discard_penultimate_step = True if solver in ( + 'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras', + 'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False + + # function for denoising xt to get x0 + intermediates = [] + + def model_fn(xt, sigma): + # denoising + t = self._sigma_to_t(sigma).repeat(len(xt)).round().long() + x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale, + guide_rescale, clamp, percentile)[-2] + + # collect intermediate outputs + if return_intermediate == 'xt': + intermediates.append(xt) + elif return_intermediate == 'x0': + intermediates.append(x0) + return x0 + + # get timesteps + if isinstance(steps, int): + steps += 1 if discard_penultimate_step else 0 + t_max = self.num_timesteps - 1 if t_max is None else t_max + t_min = 0 if t_min is None else t_min + + # discretize timesteps + if discretization == 'leading': + steps = torch.arange(t_min, t_max + 1, + (t_max - t_min + 1) / steps).flip(0) + elif discretization == 'linspace': + steps = 
torch.linspace(t_max, t_min, steps) + elif discretization == 'trailing': + steps = torch.arange(t_max, t_min - 1, + -((t_max - t_min + 1) / steps)) + else: + raise NotImplementedError( + f'{discretization} discretization not implemented') + steps = steps.clamp_(t_min, t_max) + steps = torch.as_tensor( + steps, dtype=torch.float32, device=noise.device) + + # get sigmas + sigmas = self._t_to_sigma(steps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + if schedule == 'karras': + if sigmas[0] == float('inf'): + sigmas = karras_schedule( + n=len(steps) - 1, + sigma_min=sigmas[sigmas > 0].min().item(), + sigma_max=sigmas[sigmas < float('inf')].max().item(), + rho=7.).to(sigmas) + sigmas = torch.cat([ + sigmas.new_tensor([float('inf')]), sigmas, + sigmas.new_zeros([1]) + ]) + else: + sigmas = karras_schedule( + n=len(steps), + sigma_min=sigmas[sigmas > 0].min().item(), + sigma_max=sigmas.max().item(), + rho=7.).to(sigmas) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + if discard_penultimate_step: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + + # sampling + x0 = solver_fn( + noise, model_fn, sigmas, show_progress=show_progress, **kwargs) + return (x0, intermediates) if return_intermediate is not None else x0 + + def _sigma_to_t(self, sigma): + if sigma == float('inf'): + t = torch.full_like(sigma, len(self.sigmas) - 1) + else: + log_sigmas = torch.sqrt(self.sigmas**2 / # noqa + (1 - self.sigmas**2)).log().to(sigma) + log_sigma = sigma.log() + dists = log_sigma - log_sigmas[:, None] + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp( + max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + low, high = log_sigmas[low_idx], log_sigmas[high_idx] + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + t = t.view(sigma.shape) + if t.ndim == 0: + t = t.unsqueeze(0) + return t + + def _t_to_sigma(self, t): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + log_sigmas = 
torch.sqrt(self.sigmas**2 / # noqa + (1 - self.sigmas**2)).log().to(t) + log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx] + log_sigma[torch.isnan(log_sigma) + | torch.isinf(log_sigma)] = float('inf') + return log_sigma.exp() diff --git a/modelscope/models/multi_modal/video_to_video/utils/schedules_sdedit.py b/modelscope/models/multi_modal/video_to_video/utils/schedules_sdedit.py new file mode 100644 index 00000000..06fd4e8a --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/utils/schedules_sdedit.py @@ -0,0 +1,85 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math + +import torch + + +def betas_to_sigmas(betas): + return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0)) + + +def sigmas_to_betas(sigmas): + square_alphas = 1 - sigmas**2 + betas = 1 - torch.cat( + [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]]) + return betas + + +def logsnrs_to_sigmas(logsnrs): + return torch.sqrt(torch.sigmoid(-logsnrs)) + + +def sigmas_to_logsnrs(sigmas): + square_sigmas = sigmas**2 + return torch.log(square_sigmas / (1 - square_sigmas)) + + +def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15): + t_min = math.atan(math.exp(-0.5 * logsnr_min)) + t_max = math.atan(math.exp(-0.5 * logsnr_max)) + t = torch.linspace(1, 0, n) + logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min))) + return logsnrs + + +def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2): + logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max) + logsnrs += 2 * math.log(1 / scale) + return logsnrs + + +def _logsnr_cosine_interp(n, + logsnr_min=-15, + logsnr_max=15, + scale_min=2, + scale_max=4): + t = torch.linspace(1, 0, n) + logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min) + logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max) + logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max + return logsnrs + + +def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0): + ramp = 
torch.linspace(1, 0, n) + min_inv_rho = sigma_min**(1 / rho) + max_inv_rho = sigma_max**(1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho + sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2)) + return sigmas + + +def logsnr_cosine_interp_schedule(n, + logsnr_min=-15, + logsnr_max=15, + scale_min=2, + scale_max=4): + return logsnrs_to_sigmas( + _logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min, scale_max)) + + +def noise_schedule(schedule='logsnr_cosine_interp', + n=1000, + zero_terminal_snr=False, + **kwargs): + # compute sigmas + sigmas = { + 'logsnr_cosine_interp': logsnr_cosine_interp_schedule + }[schedule](n, **kwargs) + + # post-processing + if zero_terminal_snr and sigmas.max() != 1.0: + scale = (1.0 - sigmas.min()) / (sigmas.max() - sigmas.min()) + sigmas = sigmas.min() + scale * (sigmas - sigmas.min()) + return sigmas diff --git a/modelscope/models/multi_modal/video_to_video/utils/seed.py b/modelscope/models/multi_modal/video_to_video/utils/seed.py new file mode 100644 index 00000000..df3c9c50 --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/utils/seed.py @@ -0,0 +1,14 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import random + +import numpy as np +import torch + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True diff --git a/modelscope/models/multi_modal/video_to_video/utils/solvers_sdedit.py b/modelscope/models/multi_modal/video_to_video/utils/solvers_sdedit.py new file mode 100644 index 00000000..8d00a39f --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/utils/solvers_sdedit.py @@ -0,0 +1,194 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import torch +import torchsde +from tqdm.auto import trange + + +def get_ancestral_step(sigma_from, sigma_to, eta=1.): + """ + Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step. + """ + if not eta: + return sigma_to, 0. + sigma_up = min( + sigma_to, + eta * ( + sigma_to**2 * # noqa + (sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5) + sigma_down = (sigma_to**2 - sigma_up**2)**0.5 + return sigma_down, sigma_up + + +def get_scalings(sigma): + c_out = -sigma + c_in = 1 / (sigma**2 + 1.**2)**0.5 + return c_out, c_in + + +@torch.no_grad() +def sample_heun(noise, + model, + sigmas, + s_churn=0., + s_tmin=0., + s_tmax=float('inf'), + s_noise=1., + show_progress=True): + """ + Implements Algorithm 2 (Heun steps) from Karras et al. (2022). + """ + x = noise * sigmas[0] + for i in trange(len(sigmas) - 1, disable=not show_progress): + gamma = 0. + if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'): + gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5 + if sigmas[i] == float('inf'): + # Euler method + denoised = model(noise, sigma_hat) + x = denoised + sigmas[i + 1] * (gamma + 1) * noise + else: + _, c_in = get_scalings(sigma_hat) + denoised = model(x * c_in, sigma_hat) + d = (x - denoised) / sigma_hat + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == 0: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + _, c_in = get_scalings(sigmas[i + 1]) + denoised_2 = model(x_2 * c_in, sigmas[i + 1]) + d_2 = (x_2 - denoised_2) / sigmas[i + 1] + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + return x + + +class BatchedBrownianTree: + """ + A wrapper around torchsde.BrownianTree that enables batches of entropy. 
+ """ + + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get('w0', torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2**63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [ + torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) + for s in seed + ] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * ( + self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """ + A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will + use one BrownianTree per batch item, each with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + + def __init__(self, + x, + sigma_min, + sigma_max, + seed=None, + transform=lambda x: x): + self.transform = transform + t0 = self.transform(torch.as_tensor(sigma_min)) + t1 = self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0 = self.transform(torch.as_tensor(sigma)) + t1 = self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + +@torch.no_grad() +def sample_dpmpp_2m_sde(noise, + model, + sigmas, + eta=1., + s_noise=1., + solver_type='midpoint', + show_progress=True): + """ + DPM-Solver++ (2M) SDE. 
+ """ + assert solver_type in {'heun', 'midpoint'} + + x = noise * sigmas[0] + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[ + sigmas < float('inf')].max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) + old_denoised = None + h_last = None + + for i in trange(len(sigmas) - 1, disable=not show_progress): + if sigmas[i] == float('inf'): + # Euler method + denoised = model(noise, sigmas[i]) + x = denoised + sigmas[i + 1] * noise + else: + _, c_in = get_scalings(sigmas[i]) + denoised = model(x * c_in, sigmas[i]) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + # DPM-Solver++(2M) SDE + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + eta_h = eta * h + + x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \ + (-h - eta_h).expm1().neg() * denoised + + if old_denoised is not None: + r = h_last / h + if solver_type == 'heun': + x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \ + (1 / r) * (denoised - old_denoised) + elif solver_type == 'midpoint': + x = x + 0.5 * (-h - eta_h).expm1().neg() * \ + (1 / r) * (denoised - old_denoised) + + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[ + i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + + old_denoised = denoised + h_last = h + return x diff --git a/modelscope/models/multi_modal/video_to_video/utils/transforms.py b/modelscope/models/multi_modal/video_to_video/utils/transforms.py new file mode 100644 index 00000000..3663620f --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/utils/transforms.py @@ -0,0 +1,404 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import math +import random + +import numpy as np +import torch +import torchvision.transforms.functional as F +from PIL import Image, ImageFilter + +__all__ = [ + 'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2', + 'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip', + 'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize', + 'ResizeRandomCrop', 'ExtractResizeRandomCrop', 'ExtractResizeAssignCrop' +] + + +class Compose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __getitem__(self, index): + if isinstance(index, slice): + return Compose(self.transforms[index]) + else: + return self.transforms[index] + + def __len__(self): + return len(self.transforms) + + def __call__(self, rgb): + for t in self.transforms: + rgb = t(rgb) + return rgb + + +class Resize(object): + + def __init__(self, size=256): + if isinstance(size, int): + size = (size, size) + self.size = size + + def __call__(self, rgb): + if isinstance(rgb, list): + rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb] + else: + rgb = rgb.resize(self.size, Image.BILINEAR) + return rgb + + +class Rescale(object): + + def __init__(self, size=256, interpolation=Image.BILINEAR): + self.size = size + self.interpolation = interpolation + + def __call__(self, rgb): + w, h = rgb[0].size + scale = self.size / min(w, h) + out_w, out_h = int(round(w * scale)), int(round(h * scale)) + rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb] + return rgb + + +class CenterCrop(object): + + def __init__(self, size=224): + self.size = size + + def __call__(self, rgb): + w, h = rgb[0].size + assert min(w, h) >= self.size + x1 = (w - self.size) // 2 + y1 = (h - self.size) // 2 + rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb] + return rgb + + +class ResizeRandomCrop(object): + + def __init__(self, size=256, size_short=292): + self.size = size + self.size_short = size_short + + def __call__(self, rgb): + + # consistent 
crop between rgb and m + while min(rgb[0].size) >= 2 * self.size_short: + rgb = [ + u.resize((u.width // 2, u.height // 2), resample=Image.BOX) + for u in rgb + ] + scale = self.size_short / min(rgb[0].size) + rgb = [ + u.resize((round(scale * u.width), round(scale * u.height)), + resample=Image.BICUBIC) for u in rgb + ] + out_w = self.size + out_h = self.size + w, h = rgb[0].size + x1 = random.randint(0, w - out_w) + y1 = random.randint(0, h - out_h) + + rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] + return rgb + + +class ExtractResizeRandomCrop(object): + + def __init__(self, size=256, size_short=292): + self.size = size + self.size_short = size_short + + def __call__(self, rgb): + + # consistent crop between rgb and m + while min(rgb[0].size) >= 2 * self.size_short: + rgb = [ + u.resize((u.width // 2, u.height // 2), resample=Image.BOX) + for u in rgb + ] + scale = self.size_short / min(rgb[0].size) + rgb = [ + u.resize((round(scale * u.width), round(scale * u.height)), + resample=Image.BICUBIC) for u in rgb + ] + out_w = self.size + out_h = self.size + w, h = rgb[0].size + x1 = random.randint(0, w - out_w) + y1 = random.randint(0, h - out_h) + + rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] + wh = [x1, y1, x1 + out_w, y1 + out_h] + return rgb, wh + + +class ExtractResizeAssignCrop(object): + + def __init__(self, size=256, size_short=292): + self.size = size + self.size_short = size_short + + def __call__(self, rgb, wh): + + # consistent crop between rgb and m + while min(rgb[0].size) >= 2 * self.size_short: + rgb = [ + u.resize((u.width // 2, u.height // 2), resample=Image.BOX) + for u in rgb + ] + scale = self.size_short / min(rgb[0].size) + rgb = [ + u.resize((round(scale * u.width), round(scale * u.height)), + resample=Image.BICUBIC) for u in rgb + ] + + rgb = [u.crop(wh) for u in rgb] + rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] + + return rgb + + +class CenterCropV2(object): + + def __init__(self, 
size): + self.size = size + + def __call__(self, img): + # fast resize + while min(img[0].size) >= 2 * self.size: + img = [ + u.resize((u.width // 2, u.height // 2), resample=Image.BOX) + for u in img + ] + scale = self.size / min(img[0].size) + img = [ + u.resize((round(scale * u.width), round(scale * u.height)), + resample=Image.BICUBIC) for u in img + ] + + # center crop + x1 = (img[0].width - self.size) // 2 + y1 = (img[0].height - self.size) // 2 + img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img] + return img + + +class CenterCropWide(object): + + def __init__(self, size): + self.size = size + + def __call__(self, img): + if isinstance(img, list): + scale = min(img[0].size[0] / self.size[0], + img[0].size[1] / self.size[1]) + img = [ + u.resize((round(u.width // scale), round(u.height // scale)), + resample=Image.BOX) for u in img + ] + + # center crop + x1 = (img[0].width - self.size[0]) // 2 + y1 = (img[0].height - self.size[1]) // 2 + img = [ + u.crop((x1, y1, x1 + self.size[0], y1 + self.size[1])) + for u in img + ] + return img + else: + scale = min(img.size[0] / self.size[0], img.size[1] / self.size[1]) + img = img.resize( + (round(img.width // scale), round(img.height // scale)), + resample=Image.BOX) + x1 = (img.width - self.size[0]) // 2 + y1 = (img.height - self.size[1]) // 2 + img = img.crop((x1, y1, x1 + self.size[0], y1 + self.size[1])) + return img + + +class RandomCrop(object): + + def __init__(self, size=224, min_area=0.4): + self.size = size + self.min_area = min_area + + def __call__(self, rgb): + + # consistent crop between rgb and m + w, h = rgb[0].size + area = w * h + out_w, out_h = float('inf'), float('inf') + while out_w > w or out_h > h: + target_area = random.uniform(self.min_area, 1.0) * area + aspect_ratio = random.uniform(3. / 4., 4. / 3.) 
+ out_w = int(round(math.sqrt(target_area * aspect_ratio))) + out_h = int(round(math.sqrt(target_area / aspect_ratio))) + x1 = random.randint(0, w - out_w) + y1 = random.randint(0, h - out_h) + + rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] + rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] + + return rgb + + +class RandomCropV2(object): + + def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)): + if isinstance(size, (tuple, list)): + self.size = size + else: + self.size = (size, size) + self.min_area = min_area + self.ratio = ratio + + def _get_params(self, img): + width, height = img.size + area = height * width + + for _ in range(10): + target_area = random.uniform(self.min_area, 1.0) * area + log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if (in_ratio < min(self.ratio)): + w = width + h = int(round(w / min(self.ratio))) + elif (in_ratio > max(self.ratio)): + h = height + w = int(round(h * max(self.ratio))) + else: + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def __call__(self, rgb): + i, j, h, w = self._get_params(rgb[0]) + rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb] + return rgb + + +class RandomHFlip(object): + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb] + return rgb + + +class GaussianBlur(object): + + def __init__(self, sigmas=[0.1, 2.0], p=0.5): + self.sigmas = sigmas + self.p = p + + def __call__(self, rgb): + if random.random() < 
self.p: + sigma = random.uniform(*self.sigmas) + rgb = [ + u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb + ] + return rgb + + +class ColorJitter(object): + + def __init__(self, + brightness=0.4, + contrast=0.4, + saturation=0.4, + hue=0.1, + p=0.5): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + brightness, contrast, saturation, hue = self._random_params() + transforms = [ + lambda f: F.adjust_brightness(f, brightness), + lambda f: F.adjust_contrast(f, contrast), + lambda f: F.adjust_saturation(f, saturation), + lambda f: F.adjust_hue(f, hue) + ] + random.shuffle(transforms) + for t in transforms: + rgb = [t(u) for u in rgb] + + return rgb + + def _random_params(self): + brightness = random.uniform( + max(0, 1 - self.brightness), 1 + self.brightness) + contrast = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast) + saturation = random.uniform( + max(0, 1 - self.saturation), 1 + self.saturation) + hue = random.uniform(-self.hue, self.hue) + return brightness, contrast, saturation, hue + + +class RandomGray(object): + + def __init__(self, p=0.2): + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + rgb = [u.convert('L').convert('RGB') for u in rgb] + return rgb + + +class ToTensor(object): + + def __call__(self, rgb): + if isinstance(rgb, list): + rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0) + else: + rgb = F.to_tensor(rgb) + + return rgb + + +class Normalize(object): + + def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): + self.mean = mean + self.std = std + + def __call__(self, rgb): + rgb = rgb.clone() + rgb.clamp_(0, 1) + if not isinstance(self.mean, torch.Tensor): + self.mean = rgb.new_tensor(self.mean).view(-1) + if not isinstance(self.std, torch.Tensor): + self.std = rgb.new_tensor(self.std).view(-1) + if rgb.dim() == 4: + 
rgb.sub_(self.mean.view(1, -1, 1, + 1)).div_(self.std.view(1, -1, 1, 1)) + elif rgb.dim() == 3: + rgb.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1)) + return rgb diff --git a/modelscope/models/multi_modal/video_to_video/video_to_video_model.py b/modelscope/models/multi_modal/video_to_video/video_to_video_model.py new file mode 100755 index 00000000..283de03f --- /dev/null +++ b/modelscope/models/multi_modal/video_to_video/video_to_video_model.py @@ -0,0 +1,227 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import os.path as osp +import random +from copy import copy +from typing import Any, Dict + +import torch +import torch.cuda.amp as amp +import torch.nn.functional as F + +import modelscope.models.multi_modal.video_to_video.utils.transforms as data +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.multi_modal.video_to_video.modules import * +from modelscope.models.multi_modal.video_to_video.modules import ( + AutoencoderKL, FrozenOpenCLIPEmbedder, Vid2VidSDUNet, + get_first_stage_encoding) +from modelscope.models.multi_modal.video_to_video.utils.config import cfg +from modelscope.models.multi_modal.video_to_video.utils.diffusion_sdedit import \ + GaussianDiffusion_SDEdit +from modelscope.models.multi_modal.video_to_video.utils.schedules_sdedit import \ + noise_schedule +from modelscope.models.multi_modal.video_to_video.utils.seed import setup_seed +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +__all__ = ['VideoToVideo'] + +logger = get_logger() + + +@MODELS.register_module( + Tasks.video_to_video, module_name=Models.video_to_video_model) +class VideoToVideo(TorchModel): + r""" + Video2Video aims to solve the task of generating super-resolution videos based on input + video and text, which is a video generation basic model 
developed by Alibaba Cloud. + + Paper link: https://arxiv.org/abs/2306.02018 + + Attributes: + diffusion: diffusion model for DDIM. + autoencoder: decode the latent representation of input video into visual space. + clip_encoder: encode the text into text embedding. + """ + + def __init__(self, model_dir, *args, **kwargs): + r""" + Args: + model_dir (`str` or `os.PathLike`) + Can be either: + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co + or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`, + or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g, + `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to + `True`. 
+ """ + super().__init__(model_dir=model_dir, *args, **kwargs) + + self.config = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + + cfg.solver_mode = self.config.model.model_args.solver_mode + + # assign default value + cfg.batch_size = self.config.model.model_cfg.batch_size + cfg.target_fps = self.config.model.model_cfg.target_fps + cfg.max_frames = self.config.model.model_cfg.max_frames + cfg.latent_hei = self.config.model.model_cfg.latent_hei + cfg.latent_wid = self.config.model.model_cfg.latent_wid + cfg.model_path = osp.join(model_dir, + self.config.model.model_args.ckpt_unet) + + self.device = torch.device( + 'cuda') if torch.cuda.is_available() else torch.device('cpu') + + if 'seed' in self.config.model.model_args.keys(): + cfg.seed = self.config.model.model_args.seed + else: + cfg.seed = random.randint(0, 99999) + setup_seed(cfg.seed) + + # transform + vid_trans = data.Compose( + [data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std)]) + self.vid_trans = vid_trans + + cfg.embedder.pretrained = osp.join( + model_dir, self.config.model.model_args.ckpt_clip) + clip_encoder = FrozenOpenCLIPEmbedder( + pretrained=cfg.embedder.pretrained) + clip_encoder.model.to(self.device) + self.clip_encoder = clip_encoder + logger.info(f'Build encoder with {cfg.embedder.type}') + + # [unet] + generator = Vid2VidSDUNet() + generator = generator.to(self.device) + generator.eval() + load_dict = torch.load(cfg.model_path, map_location='cpu') + ret = generator.load_state_dict(load_dict['state_dict'], strict=True) + self.generator = generator + logger.info('Load model {} path {}, with local status {}'.format( + cfg.UNet.type, cfg.model_path, ret)) + + # [diffusion] + sigmas = noise_schedule( + schedule='logsnr_cosine_interp', + n=1000, + zero_terminal_snr=True, + scale_min=2.0, + scale_max=4.0) + diffusion = GaussianDiffusion_SDEdit( + sigmas=sigmas, prediction_type='v') + self.diffusion = diffusion + logger.info('Build diffusion with type of 
GaussianDiffusion_SDEdit') + + # [auotoencoder] + cfg.auto_encoder.pretrained = osp.join( + model_dir, self.config.model.model_args.ckpt_autoencoder) + autoencoder = AutoencoderKL(**cfg.auto_encoder) + autoencoder.eval() + for param in autoencoder.parameters(): + param.requires_grad = False + autoencoder.to(self.device) + self.autoencoder = autoencoder + torch.cuda.empty_cache() + + negative_prompt = cfg.negative_prompt + negative_y = clip_encoder(negative_prompt).detach() + self.negative_y = negative_y + + positive_prompt = cfg.positive_prompt + self.positive_prompt = positive_prompt + + self.cfg = cfg + + def forward(self, input: Dict[str, Any]): + r""" + The entry function of video to video task. + 1. Using CLIP to encode text into embeddings. + 2. Using diffusion model to generate the video's latent representation. + 3. Using autoencoder to decode the video's latent representation to visual space. + + Args: + input (`Dict[Str, Any]`): + The input of the task + Returns: + A generated video (as pytorch tensor). 
+ """ + + video_data = input['video_data'] + y = input['y'] + cfg = self.cfg + + video_data = F.interpolate( + video_data, size=(720, 1280), mode='bilinear') + video_data = video_data.unsqueeze(0) + video_data = video_data.to(self.device) + + batch_size, frames_num, _, _, _ = video_data.shape + video_data = rearrange(video_data, 'b f c h w -> (b f) c h w') + + video_data_list = torch.chunk( + video_data, video_data.shape[0] // 2, dim=0) + with torch.no_grad(): + decode_data = [] + for vd_data in video_data_list: + encoder_posterior = self.autoencoder.encode(vd_data) + tmp = get_first_stage_encoding(encoder_posterior).detach() + decode_data.append(tmp) + video_data_feature = torch.cat(decode_data, dim=0) + video_data_feature = rearrange( + video_data_feature, '(b f) c h w -> b c f h w', b=batch_size) + + with amp.autocast(enabled=True): + total_noise_levels = 600 + t = torch.randint( + total_noise_levels - 1, + total_noise_levels, (1, ), + dtype=torch.long).to(self.device) + + noise = torch.randn_like(video_data_feature) + noised_lr = self.diffusion.diffuse(video_data_feature, t, noise) + model_kwargs = [{'y': y}, {'y': self.negative_y}] + + gen_vid = self.diffusion.sample( + noise=noised_lr, + model=self.generator, + model_kwargs=model_kwargs, + guide_scale=7.5, + guide_rescale=0.2, + solver='dpmpp_2m_sde' if cfg.solver_mode == 'fast' else 'heun', + steps=30 if cfg.solver_mode == 'fast' else 50, + t_max=total_noise_levels - 1, + t_min=0, + discretization='trailing') + + scale_factor = 0.18215 + vid_tensor_feature = 1. 
/ scale_factor * gen_vid + + vid_tensor_feature = rearrange(vid_tensor_feature, + 'b c f h w -> (b f) c h w') + vid_tensor_feature_list = torch.chunk( + vid_tensor_feature, vid_tensor_feature.shape[0] // 2, dim=0) + decode_data = [] + for vd_data in vid_tensor_feature_list: + tmp = self.autoencoder.decode(vd_data) + decode_data.append(tmp) + vid_tensor_gen = torch.cat(decode_data, dim=0) + + gen_video = rearrange( + vid_tensor_gen, '(b f) c h w -> b c f h w', b=cfg.batch_size) + + return gen_video.type(torch.float32).cpu() diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index 1090588b..0f4a0fe6 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -65,6 +65,7 @@ if TYPE_CHECKING: ModelForTextRanking, ModelForTokenClassification, ModelForTokenClassificationWithCRF, + ModelForMachineReadingComprehension, ) from .unite import UniTEForTranslationEvaluation from .use import UserSatisfactionEstimation @@ -159,6 +160,7 @@ else: 'ModelForTextRanking', 'ModelForTokenClassification', 'ModelForTokenClassificationWithCRF', + 'ModelForMachineReadingComprehension', ], 'sentence_embedding': ['SentenceEmbedding'], 'T5': ['T5ForConditionalGeneration'], diff --git a/modelscope/models/nlp/task_models/__init__.py b/modelscope/models/nlp/task_models/__init__.py index cd8ca926..6900d5ca 100644 --- a/modelscope/models/nlp/task_models/__init__.py +++ b/modelscope/models/nlp/task_models/__init__.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: ModelForTokenClassificationWithCRF) from .text_generation import ModelForTextGeneration from .text_ranking import ModelForTextRanking + from .machine_reading_comprehension import ModelForMachineReadingComprehension else: _import_structure = { @@ -25,6 +26,8 @@ else: ['ModelForTokenClassification', 'ModelForTokenClassificationWithCRF'], 'text_generation': ['ModelForTextGeneration'], 'text_ranking': ['ModelForTextRanking'], + 'machine_reading_comprehension': + 
['ModelForMachineReadingComprehension'], } import sys diff --git a/modelscope/models/nlp/task_models/machine_reading_comprehension.py b/modelscope/models/nlp/task_models/machine_reading_comprehension.py new file mode 100644 index 00000000..034e53ce --- /dev/null +++ b/modelscope/models/nlp/task_models/machine_reading_comprehension.py @@ -0,0 +1,139 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import AutoConfig +from transformers.modeling_outputs import ModelOutput +from transformers.models.roberta.modeling_roberta import ( + RobertaModel, RobertaPreTrainedModel) + +from modelscope.metainfo import Heads, Models, TaskModels +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.nlp.task_models.task_model import EncoderModel +from modelscope.outputs import MachineReadingComprehensionOutput, OutputKeys +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.hub import parse_label_mapping + +__all__ = ['ModelForMachineReadingComprehension'] + + +@MODELS.register_module( + Tasks.machine_reading_comprehension, + module_name=TaskModels.machine_reading_comprehension) +class ModelForMachineReadingComprehension(TorchModel): + ''' + Pretrained Machine Reader (PMR) model (https://arxiv.org/pdf/2212.04755.pdf) + + ''' + + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [r'position_ids'] + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.config = AutoConfig.from_pretrained(model_dir) + self.num_labels = self.config.num_labels + self.roberta = RobertaModel(self.config, add_pooling_layer=False) + self.span_transfer = MultiNonLinearProjection( + self.config.hidden_size, + self.config.hidden_size, + 
self.config.hidden_dropout_prob, + intermediate_hidden_size=self.config. + projection_intermediate_hidden_size) + self.load_state_dict( + torch.load( + os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE))) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + label_mask=None, + match_labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + # adapted from https://github.com/ShannonAI/mrc-for-flat-nested-ner + # for every position $i$ in sequence, should concate $j$ to + # predict if $i$ and $j$ are start_pos and end_pos for an entity. 
+ # [batch, seq_len, hidden] + span_intermediate = self.span_transfer(sequence_output) + # [batch, seq_len, seq_len] + span_logits = torch.matmul(span_intermediate, + sequence_output.transpose(-1, -2)) + + total_loss = None + if match_labels is not None: + match_loss = self.compute_loss(span_logits, match_labels, + label_mask) + total_loss = match_loss + if not return_dict: + output = (span_logits) + outputs[2:] + return ((total_loss, ) + + output) if total_loss is not None else output + + return MachineReadingComprehensionOutput( + loss=total_loss, + span_logits=span_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class MultiNonLinearProjection(nn.Module): + + def __init__(self, + hidden_size, + num_label, + dropout_rate, + act_func='gelu', + intermediate_hidden_size=None): + super(MultiNonLinearProjection, self).__init__() + self.num_label = num_label + self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size + self.classifier1 = nn.Linear(hidden_size, + self.intermediate_hidden_size) + self.classifier2 = nn.Linear(self.intermediate_hidden_size, + self.num_label) + self.dropout = nn.Dropout(dropout_rate) + self.act_func = act_func + + def forward(self, input_features): + features_output1 = self.classifier1(input_features) + if self.act_func == 'gelu': + features_output1 = F.gelu(features_output1) + elif self.act_func == 'relu': + features_output1 = F.relu(features_output1) + elif self.act_func == 'tanh': + features_output1 = F.tanh(features_output1) + else: + raise ValueError + features_output1 = self.dropout(features_output1) + features_output2 = self.classifier2(features_output1) + return features_output2 diff --git a/modelscope/outputs/nlp_outputs.py b/modelscope/outputs/nlp_outputs.py index d6b934c2..ed42cb5a 100644 --- a/modelscope/outputs/nlp_outputs.py +++ b/modelscope/outputs/nlp_outputs.py @@ -464,3 +464,25 @@ class TranslationEvaluationOutput(ModelOutputBase): 
score: Tensor = None loss: Tensor = None input_format: List[str] = None + + +@dataclass +class MachineReadingComprehensionOutput(ModelOutputBase): + """The output class for machine reading comprehension models. + + Args: + loss (`Tensor`, *optional*): The training loss of the current batch + match_loss (`Tensor`, *optional*): The match loss of the current batch + span_logits (`Tensor`): The logits of the span matrix output by the model + hidden_states (`Tuple[Tensor]`, *optional*): The hidden states output by the model + attentions (`Tuple[Tensor]`, *optional*): The attention scores output by the model + input_ids (`Tensor`): The token ids of the input sentence + + """ + + loss: Optional[Tensor] = None + match_loss: Optional[Tensor] = None + span_logits: Tensor = None + hidden_states: Optional[Tuple[Tensor]] = None + attentions: Optional[Tuple[Tensor]] = None + input_ids: Tensor = None diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py index 4e044054..8c7d3780 100644 --- a/modelscope/pipeline_inputs.py +++ b/modelscope/pipeline_inputs.py @@ -323,6 +323,8 @@ TASK_INPUTS = { 'positive': InputType.LIST, 'negative': InputType.LIST }, + Tasks.machine_reading_comprehension: + InputType.TEXT, # ============ audio tasks =================== Tasks.auto_speech_recognition: # input can be audio, or audio and text.
diff --git a/modelscope/pipelines/audio/speaker_verification_light_pipeline.py b/modelscope/pipelines/audio/speaker_verification_light_pipeline.py index 41282067..e3d1968a 100644 --- a/modelscope/pipelines/audio/speaker_verification_light_pipeline.py +++ b/modelscope/pipelines/audio/speaker_verification_light_pipeline.py @@ -81,28 +81,24 @@ class SpeakerVerificationPipeline(Pipeline): inputs: torch.Tensor, in_audios: Union[np.ndarray, list], save_dir=None): - if isinstance(in_audios[0], str): - if save_dir is not None: - # save the embeddings - os.makedirs(save_dir, exist_ok=True) - for i, p in enumerate(in_audios): - save_path = os.path.join( - save_dir, '%s.npy' % - (os.path.basename(p).rsplit('.', 1)[0])) - np.save(save_path, inputs[i].numpy()) + if isinstance(in_audios[0], str) and save_dir is not None: + # save the embeddings + os.makedirs(save_dir, exist_ok=True) + for i, p in enumerate(in_audios): + save_path = os.path.join( + save_dir, '%s.npy' % + (os.path.basename(p).rsplit('.', 1)[0])) + np.save(save_path, inputs[i].numpy()) - if len(in_audios) == 2: - # compute the score - score = self.compute_cos_similarity(inputs[0], inputs[1]) - score = round(score, 5) - if score >= self.thr: - ans = 'yes' - else: - ans = 'no' - output = {OutputKeys.SCORE: score, OutputKeys.TEXT: ans} + if len(inputs) == 2: + # compute the score + score = self.compute_cos_similarity(inputs[0], inputs[1]) + score = round(score, 5) + if score >= self.thr: + ans = 'yes' else: - output = {OutputKeys.TEXT: 'No similarity score output'} - + ans = 'no' + output = {OutputKeys.SCORE: score, OutputKeys.TEXT: ans} else: output = {OutputKeys.TEXT: 'No similarity score output'} diff --git a/modelscope/pipelines/multi_modal/diffusers_wrapped/stable_diffusion/stable_diffusion_pipeline.py b/modelscope/pipelines/multi_modal/diffusers_wrapped/stable_diffusion/stable_diffusion_pipeline.py index 7e56f24c..e5345543 100644 --- 
a/modelscope/pipelines/multi_modal/diffusers_wrapped/stable_diffusion/stable_diffusion_pipeline.py +++ b/modelscope/pipelines/multi_modal/diffusers_wrapped/stable_diffusion/stable_diffusion_pipeline.py @@ -40,6 +40,7 @@ class StableDiffusionPipeline(DiffusersPipeline): use_safetensors: load safetensors weights. """ use_safetensors = kwargs.pop('use_safetensors', False) + torch_type = kwargs.pop('torch_type', torch.float32) # check custom diffusion input value if custom_dir is None and modifier_token is not None: raise ValueError( @@ -50,7 +51,6 @@ class StableDiffusionPipeline(DiffusersPipeline): self.device = 'cuda' if torch.cuda.is_available() else 'cpu' # load pipeline - torch_type = torch.float16 if self.device == 'cuda' else torch.float32 self.pipeline = DiffusionPipeline.from_pretrained( model, use_safetensors=use_safetensors, torch_dtype=torch_type) self.pipeline = self.pipeline.to(self.device) diff --git a/modelscope/pipelines/multi_modal/image_to_video_pipeline.py b/modelscope/pipelines/multi_modal/image_to_video_pipeline.py new file mode 100644 index 00000000..7ac71d16 --- /dev/null +++ b/modelscope/pipelines/multi_modal/image_to_video_pipeline.py @@ -0,0 +1,104 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import tempfile +from typing import Any, Dict, Optional + +import cv2 +import torch +from einops import rearrange + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors.image import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_to_video, module_name=Pipelines.image_to_video_task_pipeline) +class ImageToVideoPipeline(Pipeline): + r""" Image To Video Pipeline. 
+ + Examples: + >>> from modelscope.pipelines import pipeline + >>> from modelscope.outputs import OutputKeys + + >>> p = pipeline('image-to-video', 'damo/Image-to-Video') + >>> input = 'path_to_image' + >>> p(input,) + + >>> {OutputKeys.OUTPUT_VIDEO: path-to-the-generated-video} + >>> + """ + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]: + img_path = input + + image = LoadImage.convert_to_img(img_path) + if image.mode != 'RGB': + image = image.convert('RGB') + + vit_frame = self.model.vid_trans(image) + vit_frame = vit_frame.unsqueeze(0) + vit_frame = vit_frame.to(self.model.device) + + return {'vit_frame': vit_frame} + + def forward(self, input: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + video = self.model(input) + return {'video': video} + + def postprocess(self, inputs: Dict[str, Any], + **post_params) -> Dict[str, Any]: + video = tensor2vid(inputs['video'], self.model.cfg.mean, + self.model.cfg.std) + output_video_path = post_params.get('output_video', None) + temp_video_file = False + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name + temp_video_file = True + + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + h, w, c = video[0].shape + video_writer = cv2.VideoWriter( + output_video_path, fourcc, fps=8, frameSize=(w, h)) + for i in range(len(video)): + img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + video_writer.release() + if temp_video_file: + video_file_content = b'' + with open(output_video_path, 'rb') as f: + video_file_content = f.read() + os.remove(output_video_path) + return {OutputKeys.OUTPUT_VIDEO: video_file_content} + else: + return {OutputKeys.OUTPUT_VIDEO: output_video_path} + + +def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 
0.5, 0.5]): + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + video = video * 255.0 + + images = rearrange(video, 'b c f h w -> b f h w c')[0] + images = [(img.numpy()).astype('uint8') for img in images] + + return images diff --git a/modelscope/pipelines/multi_modal/video_to_video_pipeline.py b/modelscope/pipelines/multi_modal/video_to_video_pipeline.py new file mode 100644 index 00000000..36e6544d --- /dev/null +++ b/modelscope/pipelines/multi_modal/video_to_video_pipeline.py @@ -0,0 +1,140 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import tempfile +from typing import Any, Dict, Optional + +import cv2 +import torch +from einops import rearrange + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors.image import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_to_video, module_name=Pipelines.video_to_video_pipeline) +class VideoToVideoPipeline(Pipeline): + r""" Video To Video Pipeline, generating super-resolution videos based on input + video and text + + Examples: + >>> from modelscope.pipelines import pipeline + >>> from modelscope.outputs import OutputKeys + + >>> # YOUR_VIDEO_PATH: your video url or local position in low resolution + >>> # INPUT_TEXT: when we do video super-resolution, we will add the text content + >>> # into results + >>> # output_video_path: path-to-the-generated-video + + >>> p = pipeline('video-to-video', 'damo/Video-to-Video') + >>> input = {"video_path":YOUR_VIDEO_PATH, "text": INPUT_TEXT} + >>> output_video_path = 
p(input,output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO] + + """ + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]: + vid_path = input['video_path'] + if 'text' in input.keys(): + text = input['text'] + else: + text = '' + + caption = text + self.model.positive_prompt + y = self.model.clip_encoder(caption).detach() + + max_frames = self.model.cfg.max_frames + + capture = cv2.VideoCapture(vid_path) + _fps = capture.get(cv2.CAP_PROP_FPS) + sample_fps = _fps + _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT) + stride = round(_fps / sample_fps) + start_frame = 0 + + pointer = 0 + frame_list = [] + while len(frame_list) < max_frames: + ret, frame = capture.read() + pointer += 1 + if (not ret) or (frame is None): + break + if pointer < start_frame: + continue + if pointer >= _total_frame_num + 1: + break + if (pointer - start_frame) % stride == 0: + frame = LoadImage.convert_to_img(frame) + frame_list.append(frame) + capture.release() + + video_data = self.model.vid_trans(frame_list) + + return {'video_data': video_data, 'y': y} + + def forward(self, input: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + video = self.model(input) + return {'video': video} + + def postprocess(self, inputs: Dict[str, Any], + **post_params) -> Dict[str, Any]: + video = tensor2vid(inputs['video'], self.model.cfg.mean, + self.model.cfg.std) + output_video_path = post_params.get('output_video', None) + temp_video_file = False + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name + temp_video_file = True + + temp_dir = tempfile.mkdtemp() + for fid, frame in enumerate(video): + tpth = os.path.join(temp_dir, '%06d.png' % (fid + 1)) + cv2.imwrite(tpth, frame[:, :, ::-1], + [int(cv2.IMWRITE_JPEG_QUALITY), 
100]) + + cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8.0 -i {temp_dir}/%06d.png \ + -vcodec libx264 -crf 17 -pix_fmt yuv420p {output_video_path}' + + status = os.system(cmd) + if status != 0: + logger.info('Save Video Error with {}'.format(status)) + os.system(f'rm -rf {temp_dir}') + + if temp_video_file: + video_file_content = b'' + with open(output_video_path, 'rb') as f: + video_file_content = f.read() + os.remove(output_video_path) + return {OutputKeys.OUTPUT_VIDEO: video_file_content} + else: + return {OutputKeys.OUTPUT_VIDEO: output_video_path} + + +def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): + mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) + std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) + + video = video.mul_(std).add_(mean) + video.clamp_(0, 1) + video = video * 255.0 + + images = rearrange(video, 'b c f h w -> b f h w c')[0] + images = [(img.numpy()).astype('uint8') for img in images] + + return images diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 3f958826..1216464e 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -45,6 +45,7 @@ if TYPE_CHECKING: from .document_grounded_dialog_retrieval_pipeline import DocumentGroundedDialogRetrievalPipeline from .document_grounded_dialog_rerank_pipeline import DocumentGroundedDialogRerankPipeline from .language_identification_pipline import LanguageIdentificationPipeline + from .machine_reading_comprehension_pipeline import MachineReadingComprehensionForNERPipeline else: _import_structure = { @@ -108,7 +109,10 @@ else: 'document_grounded_dialog_retrieval_pipeline': [ 'DocumentGroundedDialogRetrievalPipeline' ], - 'language_identification_pipline': ['LanguageIdentificationPipeline'] + 'language_identification_pipline': ['LanguageIdentificationPipeline'], + 'machine_reading_comprehension_pipeline': [ + 'MachineReadingComprehensionForNERPipeline' + ], 
} import sys diff --git a/modelscope/pipelines/nlp/machine_reading_comprehension_pipeline.py b/modelscope/pipelines/nlp/machine_reading_comprehension_pipeline.py new file mode 100644 index 00000000..08234eef --- /dev/null +++ b/modelscope/pipelines/nlp/machine_reading_comprehension_pipeline.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, Union + +import torch + +from modelscope.metainfo import Pipelines, Preprocessors +from modelscope.models.base import Model +from modelscope.outputs import MachineReadingComprehensionOutput, OutputKeys +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import Preprocessor +from modelscope.utils.constant import Fields, Tasks + + +@PIPELINES.register_module( + Tasks.machine_reading_comprehension, + module_name=Pipelines.machine_reading_comprehension_for_ner) +class MachineReadingComprehensionForNERPipeline(Pipeline): + ''' + Pipeline for Pretrained Machine Reader (PMR) finetuned on Named Entity Recognition (NER) + + Examples: + + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline( + >>> task=Tasks.machine_reading_comprehension, + >>> model='damo/nlp_roberta_machine-reading-comprehension_for-ner') + >>> pipeline_ins('Soccer - Japan get lucky win , China in surprise defeat .') + >>> {'ORG': [], 'PER': [], 'LOC': [' Japan', ' China'], 'MISC': []} + ''' + + def __init__(self, + model: Union[Model, str], + preprocessor: Preprocessor = None, + config_file: str = None, + device: str = 'gpu', + auto_collate=True, + **kwargs): + super().__init__( + model=model, + preprocessor=preprocessor, + config_file=config_file, + device=device, + auto_collate=auto_collate) + + assert isinstance(self.model, Model), \ + f'please check whether model config exists in {ModelFile.CONFIGURATION}' + + if preprocessor is None: + self.preprocessor = Preprocessor.from_pretrained( + self.model.model_dir, **kwargs) + self.labels = [label for label in 
self.preprocessor.label2query] + + self.model.eval() + + def forward( + self, inputs, **forward_params + ) -> Union[Dict[str, Any], MachineReadingComprehensionOutput]: + with torch.no_grad(): + outputs = self.model(**inputs) + span_logits = outputs['span_logits'] + + return MachineReadingComprehensionOutput( + span_logits=span_logits, + input_ids=inputs['input_ids'], + ) + + def postprocess( + self, inputs: Union[Dict[str, Any], MachineReadingComprehensionOutput] + ) -> Dict[str, Any]: + + span_preds = inputs['span_logits'] > 0 + extracted_indices = torch.nonzero(span_preds.long()) + + result = {label: [] for label in self.labels} + for index in extracted_indices: + label = self.labels[index[0]] + start = index[1] + end = index[2] + 1 + ids = inputs['input_ids'][index[0], start:end] + entity = self.preprocessor.tokenizer.decode(ids) + result[label].append(entity) + + return result diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index dbcb0813..6bfa4330 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -46,7 +46,8 @@ if TYPE_CHECKING: CanmtTranslationPreprocessor, DialogueClassificationUsePreprocessor, SiameseUiePreprocessor, DocumentGroundedDialogGeneratePreprocessor, DocumentGroundedDialogRetrievalPreprocessor, - DocumentGroundedDialogRerankPreprocessor) + DocumentGroundedDialogRerankPreprocessor, + MachineReadingComprehensionForNERPreprocessor) from .video import ReadVideoData, MovieSceneSegmentationPreprocessor else: @@ -77,7 +78,8 @@ else: 'nlp': [ 'DocumentSegmentationTransformersPreprocessor', 'FaqQuestionAnsweringTransformersPreprocessor', - 'FillMaskPoNetPreprocessor', 'FillMaskTransformersPreprocessor', + 'FillMaskPoNetPreprocessor', + 'FillMaskTransformersPreprocessor', 'NLPTokenizerPreprocessorBase', 'TextRankingTransformersPreprocessor', 'RelationExtractionTransformersPreprocessor', @@ -85,26 +87,34 @@ else: 'TextGenerationSentencePiecePreprocessor', 
'TextClassificationTransformersPreprocessor', 'TokenClassificationTransformersPreprocessor', - 'TextErrorCorrectionPreprocessor', 'WordAlignmentPreprocessor', - 'TextGenerationTransformersPreprocessor', 'Tokenize', + 'TextErrorCorrectionPreprocessor', + 'WordAlignmentPreprocessor', + 'TextGenerationTransformersPreprocessor', + 'Tokenize', 'TextGenerationT5Preprocessor', 'WordSegmentationBlankSetToLabelPreprocessor', - 'MGLMSummarizationPreprocessor', 'CodeGeeXPreprocessor', + 'MGLMSummarizationPreprocessor', + 'CodeGeeXPreprocessor', 'ZeroShotClassificationTransformersPreprocessor', - 'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', - 'NERPreprocessorViet', 'NERPreprocessorThai', + 'TextGenerationJiebaPreprocessor', + 'SentencePiecePreprocessor', + 'NERPreprocessorViet', + 'NERPreprocessorThai', 'WordSegmentationPreprocessorThai', - 'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor', + 'DialogIntentPredictionPreprocessor', + 'DialogModelingPreprocessor', 'DialogStateTrackingPreprocessor', 'ConversationalTextToSqlPreprocessor', 'TableQuestionAnsweringPreprocessor', 'TranslationEvaluationTransformersPreprocessor', 'CanmtTranslationPreprocessor', - 'DialogueClassificationUsePreprocessor', 'SiameseUiePreprocessor', + 'DialogueClassificationUsePreprocessor', + 'SiameseUiePreprocessor', 'DialogueClassificationUsePreprocessor', 'DocumentGroundedDialogGeneratePreprocessor', 'DocumentGroundedDialogRetrievalPreprocessor', - 'DocumentGroundedDialogRerankPreprocessor' + 'DocumentGroundedDialogRerankPreprocessor', + 'MachineReadingComprehensionForNERPreprocessor', ], } diff --git a/modelscope/preprocessors/nlp/__init__.py b/modelscope/preprocessors/nlp/__init__.py index 19421fa0..ad15b965 100644 --- a/modelscope/preprocessors/nlp/__init__.py +++ b/modelscope/preprocessors/nlp/__init__.py @@ -36,6 +36,7 @@ if TYPE_CHECKING: from .document_grounded_dialog_generate_preprocessor import DocumentGroundedDialogGeneratePreprocessor from 
.document_grounded_dialog_retrieval_preprocessor import DocumentGroundedDialogRetrievalPreprocessor from .document_grounded_dialog_rerank_preprocessor import DocumentGroundedDialogRerankPreprocessor + from .machine_reading_comprehension_preprocessor import MachineReadingComprehensionForNERPreprocessor else: _import_structure = { 'bert_seq_cls_tokenizer': ['Tokenize'], @@ -102,7 +103,10 @@ else: 'document_grounded_dialog_retrieval_preprocessor': ['DocumentGroundedDialogRetrievalPreprocessor'], 'document_grounded_dialog_rerank_preprocessor': - ['DocumentGroundedDialogRerankPreprocessor'] + ['DocumentGroundedDialogRerankPreprocessor'], + 'machine_reading_comprehension_preprocessor': [ + 'MachineReadingComprehensionForNERPreprocessor' + ], } import sys diff --git a/modelscope/preprocessors/nlp/machine_reading_comprehension_preprocessor.py b/modelscope/preprocessors/nlp/machine_reading_comprehension_preprocessor.py new file mode 100644 index 00000000..04a29aaf --- /dev/null +++ b/modelscope/preprocessors/nlp/machine_reading_comprehension_preprocessor.py @@ -0,0 +1,252 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import os + +import torch +from torch.utils.data import DataLoader, Dataset, SequentialSampler +from transformers import AutoTokenizer +from transformers.tokenization_utils_base import TruncationStrategy +from transformers.utils import logging + +from modelscope.metainfo import Preprocessors +from modelscope.outputs import OutputKeys +from modelscope.preprocessors import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.config import Config +from modelscope.utils.constant import ConfigFields, Fields, ModeKeys, ModelFile + +logger = logging.get_logger(__name__) +MULTI_SEP_TOKENS_TOKENIZERS_SET = {'roberta', 'camembert', 'bart', 'mpnet'} + + +@PREPROCESSORS.register_module( + Fields.nlp, + module_name=Preprocessors.machine_reading_comprehension_for_ner) +class MachineReadingComprehensionForNERPreprocessor(Preprocessor): + ''' + Preprocessor for Pretrained Machiner Reader (PMR) finetuned on Named Entity Recognition (NER) + + ''' + + def __init__(self, model_dir, label2query=None): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained( + model_dir, use_fast=False) + + if label2query is None: + config_path = os.path.join(model_dir, ModelFile.CONFIGURATION) + config = Config.from_file(config_path) + self.label2query = config[ConfigFields.preprocessor].label2query + else: + self.label2query = label2query + + def __call__(self, data: str): + all_data = [] + for label in self.label2query: + all_data.append({ + 'context': data, + 'end_position': [], + 'entity_label': label, + 'impossible': False, + 'qas_id': '', + 'query': self.label2query[label], + 'span_position': [], + 'start_position': [] + }) + + all_data = self.prompt(all_data) + output = [] + for data in all_data: + output.append(self.encode(data)) + output = collate_to_max_length_roberta(output) + + output = { + 'input_ids': output[0], + 'attention_mask': output[1], + 'token_type_ids': output[2], + } + + return output + + def prompt(self, all_data, 
var=0): + new_datas = [] + for data in all_data: + label = data['entity_label'] + details = data['query'] + context = data['context'] + start_positions = data['start_position'] + end_positions = data['end_position'] + words = context.split() + assert len(words) == len(context.split(' ')) + if var == 0: + query = '"{}". {}'.format(label, details) # ori + elif var == 1: + query = 'What are the "{}" entity, where {}'.format( + label, details) # variant 1 + elif var == 2: + query = 'Identify the spans (if any) related to "{}" entity. Details: {}'.format( + label, details) # variant 2 + span_positions = { + '{};{}'.format(start_positions[i], end_positions[i]): + ' '.join(words[start_positions[i]:end_positions[i] + 1]) + for i in range(len(start_positions)) + } + new_data = { + 'context': words, + 'end_position': end_positions, + 'entity_label': label, + 'impossible': data['impossible'], + 'qas_id': data['qas_id'], + 'query': query, + 'span_position': span_positions, + 'start_position': start_positions, + } + new_datas.append(new_data) + return new_datas + + def encode(self, data, max_length=512, max_query_length=64): + + tokenizer = self.tokenizer + + query = data['query'] + context = data['context'] + start_positions = data['start_position'] + end_positions = data['end_position'] + + tokenizer_type = type(tokenizer).__name__.replace('Tokenizer', + '').lower() + sequence_added_tokens = ( + tokenizer.model_max_length - tokenizer.max_len_single_sentence + + 1 if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET else + tokenizer.model_max_length - tokenizer.max_len_single_sentence) + + tok_to_orig_index = [] + orig_to_tok_index = [] + all_doc_tokens = [] + for (i, token) in enumerate(context): + orig_to_tok_index.append(len(all_doc_tokens)) + if tokenizer.__class__.__name__ in [ + 'RobertaTokenizer', + 'LongformerTokenizer', + 'BartTokenizer', + 'RobertaTokenizerFast', + 'LongformerTokenizerFast', + 'BartTokenizerFast', + ]: + sub_tokens = tokenizer.tokenize(token, 
add_prefix_space=True) + elif tokenizer.__class__.__name__ in ['BertTokenizer']: + sub_tokens = tokenizer.tokenize(token) + elif tokenizer.__class__.__name__ in ['BertWordPieceTokenizer']: + sub_tokens = tokenizer.encode( + token, add_special_tokens=False).tokens + else: + sub_tokens = tokenizer.tokenize(token) + for sub_token in sub_tokens: + tok_to_orig_index.append(i) + all_doc_tokens.append(sub_token) + + tok_start_positions = [orig_to_tok_index[x] for x in start_positions] + tok_end_positions = [] + for x in end_positions: + if x < len(context) - 1: + tok_end_positions.append(orig_to_tok_index[x + 1] - 1) + else: + tok_end_positions.append(len(all_doc_tokens) - 1) + + truncation = TruncationStrategy.ONLY_SECOND.value + padding_strategy = 'do_not_pad' + + truncated_query = tokenizer.encode( + query, + add_special_tokens=False, + truncation=True, + max_length=max_query_length) + encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic + truncated_query, + all_doc_tokens, + truncation=truncation, + padding=padding_strategy, + max_length=max_length, + return_overflowing_tokens=True, + return_token_type_ids=True, + ) + tokens = encoded_dict['input_ids'] + type_ids = encoded_dict['token_type_ids'] + attn_mask = encoded_dict['attention_mask'] + + # find new start_positions/end_positions, considering + # 1. we add query tokens at the beginning + # 2. 
special tokens + doc_offset = len(truncated_query) + sequence_added_tokens + new_start_positions = [ + x + doc_offset for x in tok_start_positions + if (x + doc_offset) < max_length - 1 + ] + new_end_positions = [ + x + doc_offset if + (x + doc_offset) < max_length - 1 else max_length - 2 + for x in tok_end_positions + ] + new_end_positions = new_end_positions[:len(new_start_positions)] + + label_mask = [0] * doc_offset + [1] * (len(tokens) - doc_offset + - 1) + [0] + + assert all(label_mask[p] != 0 for p in new_start_positions) + assert all(label_mask[p] != 0 for p in new_end_positions) + + assert len(label_mask) == len(tokens) + + seq_len = len(tokens) + match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long) + for start, end in zip(new_start_positions, new_end_positions): + if start >= seq_len or end >= seq_len: + continue + match_labels[start, end] = 1 + + return [ + torch.LongTensor(tokens), + torch.LongTensor(attn_mask), + torch.LongTensor(type_ids), + torch.LongTensor(label_mask), + match_labels, + ] + + +def collate_to_max_length_roberta(batch): + """ + adapted form https://github.com/ShannonAI/mrc-for-flat-nested-ner + pad to maximum length of this batch + Args: + batch: a batch of samples, each contains a list of field data(Tensor): + tokens, token_type_ids, start_labels, end_labels, start_label_mask, + end_label_mask, match_labels, sample_idx, label_idx + Returns: + output: list of field batched data, which shape is [batch, max_length] + """ + batch_size = len(batch) + max_length = max(x[0].shape[0] for x in batch) + output = [] + + for field_idx in range(4): + if field_idx == 0: + pad_output = torch.full([batch_size, max_length], + 1, + dtype=batch[0][field_idx].dtype) + else: + pad_output = torch.full([batch_size, max_length], + 0, + dtype=batch[0][field_idx].dtype) + for sample_idx in range(batch_size): + data = batch[sample_idx][field_idx] + pad_output[sample_idx][:data.shape[0]] = data + output.append(pad_output) + + pad_match_labels = 
torch.zeros([batch_size, max_length, max_length], + dtype=torch.long) + for sample_idx in range(batch_size): + data = batch[sample_idx][4] + pad_match_labels[sample_idx, :data.shape[1], :data.shape[1]] = data + output.append(pad_match_labels) + + return output diff --git a/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py b/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py index 28140fb2..435fd2e3 100644 --- a/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py +++ b/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py @@ -37,24 +37,31 @@ from modelscope.utils.torch_utils import is_dist class CustomCheckpointProcessor(CheckpointProcessor): - def __init__(self, modifier_token, modifier_token_id): + def __init__(self, + modifier_token, + modifier_token_id, + torch_type=torch.float32): """Checkpoint processor for custom diffusion. Args: modifier_token: The token to use as a modifier for the concept. modifier_token_id: The modifier token id for the concept. + torch_type: The torch type, default is float32. + """ self.modifier_token = modifier_token self.modifier_token_id = modifier_token_id + self.torch_type = torch_type def save_checkpoints(self, trainer, checkpoint_path_prefix, output_dir, - meta=None): + meta=None, + save_optimizers=True): """Save the state dict for custom diffusion model. 
""" - trainer.model.unet = trainer.model.unet.to(torch.float32) + trainer.model.unet = trainer.model.unet.to(self.torch_type) trainer.model.unet.save_attn_procs(output_dir) learned_embeds = trainer.model.text_encoder.get_input_embeddings( @@ -281,6 +288,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer): instance_prompt = kwargs.pop('instance_prompt', 'a photo of sks dog') class_prompt = kwargs.pop('class_prompt', 'dog') class_data_dir = kwargs.pop('class_data_dir', '/tmp/class_data') + self.torch_type = kwargs.pop('torch_type', torch.float32) self.real_prior = kwargs.pop('real_prior', False) self.num_class_images = kwargs.pop('num_class_images', 200) self.resolution = kwargs.pop('resolution', 512) @@ -387,7 +395,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer): self.hooks))[0] ckpt_hook.set_processor( CustomCheckpointProcessor(self.modifier_token, - self.modifier_token_id)) + self.modifier_token_id, self.torch_type)) # Add new Custom Diffusion weights to the attention layers attention_class = CustomDiffusionAttnProcessor @@ -477,7 +485,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer): size=self.resolution, mask_size=self.model.vae.encode( torch.randn(1, 3, self.resolution, - self.resolution).to(dtype=torch.float32).to( + self.resolution).to(dtype=self.torch_type).to( self.device)).latent_dist.sample().size()[-1], center_crop=self.center_crop, num_class_images=self.num_class_images, @@ -534,8 +542,8 @@ class CustomDiffusionTrainer(EpochBasedTrainer): if cur_class_images < self.num_class_images: pipeline = DiffusionPipeline.from_pretrained( self.model_dir, - torch_dtype=torch.float32, safety_checker=None, + torch_dtype=self.torch_type, revision=None, ) pipeline.set_progress_bar_config(disable=True) @@ -656,7 +664,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer): batch = next(self.iter_train_dataloader) # Convert images to latent space latents = self.model.vae.encode(batch['pixel_values'].to( - dtype=torch.float32).to(self.device)).latent_dist.sample() 
+ dtype=self.torch_type).to(self.device)).latent_dist.sample() latents = latents * self.model.vae.config.scaling_factor # Sample noise that we'll add to the latents diff --git a/modelscope/trainers/multi_modal/dreambooth_diffusion/dreambooth_diffusion_trainer.py b/modelscope/trainers/multi_modal/dreambooth_diffusion/dreambooth_diffusion_trainer.py index 3b300ea4..f391f87a 100644 --- a/modelscope/trainers/multi_modal/dreambooth_diffusion/dreambooth_diffusion_trainer.py +++ b/modelscope/trainers/multi_modal/dreambooth_diffusion/dreambooth_diffusion_trainer.py @@ -32,8 +32,16 @@ from modelscope.utils.torch_utils import is_dist class DreamboothCheckpointProcessor(CheckpointProcessor): - def __init__(self, model_dir): + def __init__(self, model_dir, torch_type=torch.float32): + """Checkpoint processor for dreambooth diffusion. + + Args: + model_dir: The model id or local model dir. + torch_type: The torch type, default is float32. + + """ self.model_dir = model_dir + self.torch_type = torch_type def save_checkpoints(self, trainer, @@ -49,6 +57,7 @@ class DreamboothCheckpointProcessor(CheckpointProcessor): pipeline = DiffusionPipeline.from_pretrained( self.model_dir, unet=trainer.model.unet, + torch_type=self.torch_type, **pipeline_args, ) scheduler_args = {} @@ -174,6 +183,7 @@ class DreamboothDiffusionTrainer(EpochBasedTrainer): prior_loss_weight: the weight of the prior loss. 
""" + self.torch_type = kwargs.pop('torch_type', torch.float32) self.with_prior_preservation = kwargs.pop('with_prior_preservation', False) self.instance_prompt = kwargs.pop('instance_prompt', @@ -219,7 +229,7 @@ class DreamboothDiffusionTrainer(EpochBasedTrainer): warnings.warn('Multiple GPU inference not yet supported.') pipeline = DiffusionPipeline.from_pretrained( self.model_dir, - torch_dtype=torch.float32, + torch_dtype=self.torch_type, safety_checker=None, revision=None, ) @@ -309,7 +319,8 @@ class DreamboothDiffusionTrainer(EpochBasedTrainer): input_ids = batch['input_ids'].to(self.device) with torch.no_grad(): latents = self.model.vae.encode( - target_prior.to(dtype=torch.float32)).latent_dist.sample() + target_prior.to( + dtype=self.torch_type)).latent_dist.sample() latents = latents * self.model.vae.config.scaling_factor # Sample noise that we'll add to the latents diff --git a/modelscope/trainers/multi_modal/lora_diffusion/lora_diffusion_trainer.py b/modelscope/trainers/multi_modal/lora_diffusion/lora_diffusion_trainer.py index 7c6644bd..2e4d4090 100644 --- a/modelscope/trainers/multi_modal/lora_diffusion/lora_diffusion_trainer.py +++ b/modelscope/trainers/multi_modal/lora_diffusion/lora_diffusion_trainer.py @@ -17,6 +17,15 @@ from modelscope.utils.config import ConfigDict class LoraDiffusionCheckpointProcessor(CheckpointProcessor): + def __init__(self, torch_type=torch.float32): + """Checkpoint processor for lora diffusion. + + Args: + torch_type: The torch type, default is float32. + + """ + self.torch_type = torch_type + def save_checkpoints(self, trainer, checkpoint_path_prefix, @@ -25,7 +34,7 @@ class LoraDiffusionCheckpointProcessor(CheckpointProcessor): save_optimizers=True): """Save the state dict for lora tune model. 
""" - trainer.model.unet = trainer.model.unet.to(torch.float32) + trainer.model.unet = trainer.model.unet.to(self.torch_type) trainer.model.unet.save_attn_procs(output_dir) @@ -38,15 +47,18 @@ class LoraDiffusionTrainer(EpochBasedTrainer): Args: lora_rank: The rank size of lora intermediate linear. + torch_type: The torch type, default is float32. """ lora_rank = kwargs.pop('lora_rank', 4) + torch_type = kwargs.pop('torch_type', torch.float32) # set lora save checkpoint processor ckpt_hook = list( filter(lambda hook: isinstance(hook, CheckpointHook), self.hooks))[0] - ckpt_hook.set_processor(LoraDiffusionCheckpointProcessor()) + ckpt_hook.set_processor( + LoraDiffusionCheckpointProcessor(torch_type=torch_type)) # Set correct lora layers lora_attn_procs = {} for name in self.model.unet.attn_processors.keys(): diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 51924829..3bcad94c 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -213,6 +213,7 @@ class NLPTasks(object): document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval' document_grounded_dialog_rerank = 'document-grounded-dialog-rerank' document_grounded_dialog_generate = 'document-grounded-dialog-generate' + machine_reading_comprehension = 'machine-reading-comprehension' class AudioTasks(object): @@ -255,6 +256,8 @@ class MultiModalTasks(object): text_to_video_synthesis = 'text-to-video-synthesis' efficient_diffusion_tuning = 'efficient-diffusion-tuning' multimodal_dialogue = 'multimodal-dialogue' + image_to_video = 'image-to-video' + video_to_video = 'video-to-video' class ScienceTasks(object): diff --git a/tests/pipelines/test_efficient_diffusion_tuning.py b/tests/pipelines/test_efficient_diffusion_tuning.py index f1aa52de..330aee57 100644 --- a/tests/pipelines/test_efficient_diffusion_tuning.py +++ b/tests/pipelines/test_efficient_diffusion_tuning.py @@ -13,39 +13,45 @@ class EfficientDiffusionTuningTest(unittest.TestCase): def 
setUp(self) -> None: self.task = Tasks.efficient_diffusion_tuning - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_lora_run_pipeline(self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-lora' + model_revision = 'v1.0.2' inputs = {'prompt': 'pale golden rod circle with old lace background'} - edt_pipeline = pipeline(self.task, model_id) + edt_pipeline = pipeline( + self.task, model_id, model_revision=model_revision) result = edt_pipeline(inputs) print(f'Efficient-diffusion-tuning-lora output: {result}.') - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_lora_load_model_from_pretrained(self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-lora' - model = Model.from_pretrained(model_id) + model_revision = 'v1.0.2' + model = Model.from_pretrained(model_id, model_revision=model_revision) self.assertTrue(model.__class__ == EfficientStableDiffusion) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_control_lora_run_pipeline(self): # TODO: to be fixed in the future model_id = 'damo/multi-modal_efficient-diffusion-tuning-control-lora' + model_revision = 'v1.0.2' inputs = { 'prompt': 'pale golden rod circle with old lace background', 'cond': 'data/test/images/efficient_diffusion_tuning_sd_control_lora_source.png' } - edt_pipeline = pipeline(self.task, model_id) + edt_pipeline = pipeline( + self.task, model_id, model_revision=model_revision) result = edt_pipeline(inputs) print(f'Efficient-diffusion-tuning-control-lora output: {result}.') - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + 
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_control_lora_load_model_from_pretrained( self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-control-lora' - model = Model.from_pretrained(model_id) + model_revision = 'v1.0.2' + model = Model.from_pretrained(model_id, model_revision=model_revision) self.assertTrue(model.__class__ == EfficientStableDiffusion) diff --git a/tests/pipelines/test_efficient_diffusion_tuning_swift.py b/tests/pipelines/test_efficient_diffusion_tuning_swift.py index a63a6e26..09b739a0 100644 --- a/tests/pipelines/test_efficient_diffusion_tuning_swift.py +++ b/tests/pipelines/test_efficient_diffusion_tuning_swift.py @@ -16,14 +16,16 @@ class EfficientDiffusionTuningTestSwift(unittest.TestCase): def setUp(self) -> None: self.task = Tasks.efficient_diffusion_tuning - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_swift_lora_run_pipeline(self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-lora' + model_revision = 'v1.0.2' inputs = { 'prompt': 'a street scene with a cafe and a restaurant sign in anime style' } - sd_tuner_pipeline = pipeline(self.task, model_id) + sd_tuner_pipeline = pipeline( + self.task, model_id, model_revision=model_revision) result = sd_tuner_pipeline(inputs, generator_seed=0) output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name cv2.imwrite(output_image_path, result['output_imgs'][0]) @@ -31,21 +33,24 @@ class EfficientDiffusionTuningTestSwift(unittest.TestCase): f'Efficient-diffusion-tuning-swift-lora output: {output_image_path}' ) - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_swift_lora_load_model_from_pretrained( self): model_id = 
'damo/multi-modal_efficient-diffusion-tuning-swift-lora' - model = Model.from_pretrained(model_id) + model_revision = 'v1.0.2' + model = Model.from_pretrained(model_id, model_revision=model_revision) self.assertTrue(model.__class__ == EfficientStableDiffusion) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_swift_adapter_run_pipeline(self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-adapter' + model_revision = 'v1.0.2' inputs = { 'prompt': 'a street scene with a cafe and a restaurant sign in anime style' } - sd_tuner_pipeline = pipeline(self.task, model_id) + sd_tuner_pipeline = pipeline( + self.task, model_id, model_revision=model_revision) result = sd_tuner_pipeline(inputs, generator_seed=0) output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name cv2.imwrite(output_image_path, result['output_imgs'][0]) @@ -53,21 +58,24 @@ class EfficientDiffusionTuningTestSwift(unittest.TestCase): f'Efficient-diffusion-tuning-swift-adapter output: {output_image_path}' ) - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_swift_adapter_load_model_from_pretrained( self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-adapter' - model = Model.from_pretrained(model_id) + model_revision = 'v1.0.2' + model = Model.from_pretrained(model_id, model_revision=model_revision) self.assertTrue(model.__class__ == EfficientStableDiffusion) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_swift_prompt_run_pipeline(self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-prompt' + model_revision = 'v1.0.2' inputs = { 'prompt': 
'a street scene with a cafe and a restaurant sign in anime style' } - sd_tuner_pipeline = pipeline(self.task, model_id) + sd_tuner_pipeline = pipeline( + self.task, model_id, model_revision=model_revision) result = sd_tuner_pipeline(inputs, generator_seed=0) output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name cv2.imwrite(output_image_path, result['output_imgs'][0]) @@ -75,11 +83,12 @@ class EfficientDiffusionTuningTestSwift(unittest.TestCase): f'Efficient-diffusion-tuning-swift-prompt output: {output_image_path}' ) - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_swift_prompt_load_model_from_pretrained( self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-prompt' - model = Model.from_pretrained(model_id) + model_revision = 'v1.0.2' + model = Model.from_pretrained(model_id, model_revision=model_revision) self.assertTrue(model.__class__ == EfficientStableDiffusion) diff --git a/tests/pipelines/test_image2video.py b/tests/pipelines/test_image2video.py new file mode 100644 index 00000000..b6daf73c --- /dev/null +++ b/tests/pipelines/test_image2video.py @@ -0,0 +1,28 @@ +import sys +import unittest + +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import DownloadMode, Tasks +from modelscope.utils.test_utils import test_level + + +class Image2VideoTest(unittest.TestCase): + + def setUp(self) -> None: + self.task = Tasks.image_to_video + self.model_id = 'damo/Image-to-Video' + self.path = 'https://video-generation-wulanchabu.oss-cn-wulanchabu.aliyuncs.com/baishao/test.jpeg' + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + pipe = pipeline(task=self.task, model=self.model_id) + + output_video_path = pipe( + self.path, 
output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO] + print(output_video_path) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_machine_reading_comprehension.py b/tests/pipelines/test_machine_reading_comprehension.py new file mode 100644 index 00000000..156c19b8 --- /dev/null +++ b/tests/pipelines/test_machine_reading_comprehension.py @@ -0,0 +1,59 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import ModelForMachineReadingComprehension +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import MachineReadingComprehensionForNERPipeline +from modelscope.preprocessors import \ + MachineReadingComprehensionForNERPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class MachineReadingComprehensionTest(unittest.TestCase): + sentence = 'Soccer - Japan get lucky win , China in surprise defeat .' 
+ model_id = 'damo/nlp_roberta_machine-reading-comprehension_for-ner' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_mrc_for_ner_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = MachineReadingComprehensionForNERPreprocessor(cache_path) + model = ModelForMachineReadingComprehension.from_pretrained(cache_path) + pipeline1 = MachineReadingComprehensionForNERPipeline( + model, preprocessor=tokenizer) + + pipeline2 = pipeline( + Tasks.machine_reading_comprehension, + model=model, + preprocessor=tokenizer) + print(f'sentence: {self.sentence}\n' + f'pipeline1:{pipeline1(input=self.sentence)}') + print() + print(f'pipeline2: {pipeline2(input=self.sentence)}') + # {'ORG': [], 'PER': [], 'LOC': [' Japan', ' China'], 'MISC': []} + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_mrc_for_ner_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = MachineReadingComprehensionForNERPreprocessor( + model.model_dir) + pipeline_ins = pipeline( + task=Tasks.machine_reading_comprehension, + model=model, + preprocessor=tokenizer) + print(f'sentence: {self.sentence}\n' + f'pipeline:{pipeline_ins(input=self.sentence)}') + # {'ORG': [], 'PER': [], 'LOC': [' Japan', ' China'], 'MISC': []} + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_mrc_for_ner_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.machine_reading_comprehension, model=self.model_id) + print(pipeline_ins(input=self.sentence)) + # {'ORG': [], 'PER': [], 'LOC': [' Japan', ' China'], 'MISC': []} + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/pipelines/test_text_to_360panorama_image.py b/tests/pipelines/test_text_to_360panorama_image.py index b2780597..fcf1ec44 100644 --- a/tests/pipelines/test_text_to_360panorama_image.py +++ b/tests/pipelines/test_text_to_360panorama_image.py @@ 
-52,7 +52,7 @@ class Text2360PanoramaImageTest(unittest.TestCase): print( 'pipeline: the output image path is {}'.format(output_image_path)) - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_model_from_modelhub(self): output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name pipeline_ins = pipeline( diff --git a/tests/pipelines/test_video2video.py b/tests/pipelines/test_video2video.py new file mode 100644 index 00000000..fcd9a7e5 --- /dev/null +++ b/tests/pipelines/test_video2video.py @@ -0,0 +1,32 @@ +import sys +import unittest + +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class Video2VideoTest(unittest.TestCase): + + def setUp(self) -> None: + self.task = Tasks.video_to_video + self.model_id = 'damo/Video-to-Video' + self.path = 'https://video-generation-wulanchabu.oss-cn-wulanchabu.aliyuncs.com/baishao/test.mp4' + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + pipe = pipeline(task=self.task, model=self.model_id) + p_input = { + 'video_path': self.path, + 'text': 'A panda is surfing on the sea' + } + + output_video_path = pipe( + p_input, output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO] + print(output_video_path) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/run_analysis.py b/tests/run_analysis.py index ca0a0018..95c24698 100644 --- a/tests/run_analysis.py +++ b/tests/run_analysis.py @@ -125,16 +125,21 @@ def get_current_branch(): def get_modified_files(): - cmd = ['git', 'diff', '--name-only', 'origin/master...'] - cmd_output = run_command_get_output(cmd) - logger.info('Modified files: ') - logger.info(cmd_output) + if 'PR_CHANGED_FILES' in 
os.environ and os.environ[ + 'PR_CHANGED_FILES'] != '': + logger.info('Getting PR modified files.') + # get modify file from environment + diff_files = os.environ['PR_CHANGED_FILES'].replace('#', '\n') + else: + cmd = ['git', 'diff', '--name-only', 'origin/master...'] + diff_files = run_command_get_output(cmd) + logger.info('Diff files: ') + logger.info(diff_files) modified_files = [] # remove the deleted file. - for diff_file in cmd_output.splitlines(): - if os.path.exists(diff_file): - modified_files.append(diff_file) - + for diff_file in diff_files.splitlines(): + if os.path.exists(diff_file.strip()): + modified_files.append(diff_file.strip()) return modified_files diff --git a/tests/trainers/test_efficient_diffusion_tuning_trainer.py b/tests/trainers/test_efficient_diffusion_tuning_trainer.py index a19bf21d..de23782c 100644 --- a/tests/trainers/test_efficient_diffusion_tuning_trainer.py +++ b/tests/trainers/test_efficient_diffusion_tuning_trainer.py @@ -42,6 +42,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_efficient_diffusion_tuning_lora_train(self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-lora' + model_revision = 'v1.0.2' def cfg_modify_fn(cfg): cfg.train.max_epochs = self.max_epochs @@ -51,6 +52,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase): kwargs = dict( model=model_id, + model_revision=model_revision, work_dir=self.tmp_dir, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, @@ -70,6 +72,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_efficient_diffusion_tuning_lora_eval(self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-lora' + model_revision = 'v1.0.2' def cfg_modify_fn(cfg): cfg.model.inference = False @@ -77,6 +80,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase): kwargs 
= dict( model=model_id, + model_revision=model_revision, work_dir=self.tmp_dir, train_dataset=None, eval_dataset=self.eval_dataset, @@ -90,6 +94,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_efficient_diffusion_tuning_control_lora_train(self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-control-lora' + model_revision = 'v1.0.2' def cfg_modify_fn(cfg): cfg.train.max_epochs = self.max_epochs @@ -99,6 +104,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase): kwargs = dict( model=model_id, + model_revision=model_revision, work_dir=self.tmp_dir, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, @@ -119,6 +125,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_efficient_diffusion_tuning_control_lora_eval(self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-control-lora' + model_revision = 'v1.0.2' def cfg_modify_fn(cfg): cfg.model.inference = False @@ -126,6 +133,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase): kwargs = dict( model=model_id, + model_revision=model_revision, work_dir=self.tmp_dir, train_dataset=None, eval_dataset=self.eval_dataset, diff --git a/tests/trainers/test_efficient_diffusion_tuning_trainer_swift.py b/tests/trainers/test_efficient_diffusion_tuning_trainer_swift.py index 7c3ecd99..9e12335e 100644 --- a/tests/trainers/test_efficient_diffusion_tuning_trainer_swift.py +++ b/tests/trainers/test_efficient_diffusion_tuning_trainer_swift.py @@ -33,9 +33,10 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase): shutil.rmtree(self.tmp_dir) super().tearDown() - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_swift_lora_train(self): 
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-lora' + model_revision = 'v1.0.2' def cfg_modify_fn(cfg): cfg.train.max_epochs = self.max_epochs @@ -47,6 +48,7 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase): kwargs = dict( model=model_id, + model_revision=model_revision, work_dir=self.tmp_dir, train_dataset=self.train_dataset, cfg_modify_fn=cfg_modify_fn) @@ -60,9 +62,10 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase): self.assertIn(f'{trainer.timestamp}.log.json', results_files) self.assertIn(f'epoch_{self.max_epochs}.pth', results_files) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_swift_adapter_train(self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-adapter' + model_revision = 'v1.0.2' def cfg_modify_fn(cfg): cfg.train.max_epochs = self.max_epochs @@ -74,6 +77,7 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase): kwargs = dict( model=model_id, + model_revision=model_revision, work_dir=self.tmp_dir, train_dataset=self.train_dataset, cfg_modify_fn=cfg_modify_fn) @@ -87,9 +91,10 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase): self.assertIn(f'{trainer.timestamp}.log.json', results_files) self.assertIn(f'epoch_{self.max_epochs}.pth', results_files) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_efficient_diffusion_tuning_swift_prompt_train(self): model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-prompt' + model_revision = 'v1.0.2' def cfg_modify_fn(cfg): cfg.train.max_epochs = self.max_epochs @@ -101,6 +106,7 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase): kwargs = dict( model=model_id, + model_revision=model_revision, work_dir=self.tmp_dir, 
train_dataset=self.train_dataset, cfg_modify_fn=cfg_modify_fn)