merge master-github

This commit is contained in:
wenmeng.zwm
2023-08-21 18:57:09 +08:00
70 changed files with 8652 additions and 266 deletions

View File

@@ -4,11 +4,15 @@ CODE_DIR=$PWD
CODE_DIR_IN_CONTAINER=/Maas-lib
echo "$USER"
gpus='0,1 2,3 4,5 6,7'
cpu_sets='45-58 31-44 16-30 0-15'
cpu_sets='0-15 16-31 32-47 48-63'
cpu_sets_arr=($cpu_sets)
is_get_file_lock=false
CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh python tests/run.py --parallel 2 --run_config tests/run_config.yaml}
echo "ci command: $CI_COMMAND"
PR_CHANGED_FILES="${PR_CHANGED_FILES:-''}"
echo "PR modified files: $PR_CHANGED_FILES"
PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#}
echo "PR_CHANGED_FILES: $PR_CHANGED_FILES"
idx=0
for gpu in $gpus
do
@@ -42,6 +46,7 @@ do
-e MODELSCOPE_ENVIRONMENT='ci' \
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
-e MODEL_TAG_URL=$MODEL_TAG_URL \
-e PR_CHANGED_FILES=$PR_CHANGED_FILES \
--workdir=$CODE_DIR_IN_CONTAINER \
${IMAGE_NAME}:${IMAGE_VERSION} \
$CI_COMMAND
@@ -64,6 +69,7 @@ do
-e MODELSCOPE_ENVIRONMENT='ci' \
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
-e MODEL_TAG_URL=$MODEL_TAG_URL \
-e PR_CHANGED_FILES=$PR_CHANGED_FILES \
--workdir=$CODE_DIR_IN_CONTAINER \
${IMAGE_NAME}:${IMAGE_VERSION} \
$CI_COMMAND

View File

@@ -39,7 +39,7 @@ concurrency:
jobs:
unittest:
# The type of runner that the job will run on
runs-on: [modelscope-self-hosted]
runs-on: [modelscope-self-hosted-us]
timeout-minutes: 240
steps:
- name: ResetFileMode
@@ -52,10 +52,19 @@ jobs:
sudo chown -R $USER:$USER $ACTION_RUNNER_DIR
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
lfs: 'true'
submodules: 'true'
fetch-depth: ${{ github.event_name == 'pull_request' && 2 || 0 }}
- name: Get changed files
id: changed-files
run: |
if ${{ github.event_name == 'pull_request' }}; then
echo "PR_CHANGED_FILES=$(git diff --name-only -r HEAD^1 HEAD | xargs)" >> $GITHUB_ENV
else
echo "PR_CHANGED_FILES=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | xargs)" >> $GITHUB_ENV
fi
- name: Checkout LFS objects
run: git lfs checkout
- name: Run unittest

View File

@@ -12,7 +12,7 @@ concurrency:
jobs:
unittest:
# The type of runner that the job will run on
runs-on: [modelscope-self-hosted]
runs-on: [modelscope-self-hosted-us]
steps:
- name: ResetFileMode
shell: bash

View File

@@ -2,6 +2,7 @@ import os
from dataclasses import dataclass, field
import cv2
import torch
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
@@ -95,6 +96,12 @@ class StableDiffusionCustomArguments(TrainingArgs):
'help': 'Path to json containing multiple concepts.',
})
torch_type: str = field(
default='float32',
metadata={
'help': ' The torch type, default is float32.',
})
training_args = StableDiffusionCustomArguments(
task='text-to-image-synthesis').parse_cli()
@@ -148,6 +155,8 @@ kwargs = dict(
work_dir=training_args.work_dir,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
torch_type=torch.float16
if args.torch_type == 'float16' else torch.float32,
cfg_modify_fn=cfg_modify_fn)
# build trainer and training
@@ -159,7 +168,7 @@ pipe = pipeline(
task=Tasks.text_to_image_synthesis,
model=training_args.model,
custom_dir=training_args.work_dir + '/output',
modifier_token='<new1>+<new2>',
modifier_token=args.modifier_token,
model_revision=args.model_revision)
output = pipe({'text': args.instance_prompt})

View File

@@ -7,11 +7,12 @@ PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/custom/finetune_stable_d
--class_data_dir './tmp/class_data' \
--train_dataset_name 'buptwq/lora-stable-diffusion-finetune-dog' \
--max_epochs 250 \
--modifier_token "<new1>+<new2>" \
--modifier_token "<new1>" \
--num_class_images=200 \
--save_ckpt_strategy 'by_epoch' \
--logging_interval 1 \
--train.dataloader.workers_per_gpu 0 \
--evaluation.dataloader.workers_per_gpu 0 \
--train.optimizer.lr 1e-5 \
--torch_type 'float32' \
--use_model_config true

View File

@@ -2,6 +2,7 @@ import os
from dataclasses import dataclass, field
import cv2
import torch
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
@@ -59,6 +60,12 @@ class StableDiffusionDreamboothArguments(TrainingArgs):
'help': 'The pipeline prompt.',
})
torch_type: str = field(
default='float32',
metadata={
'help': ' The torch type, default is float32.',
})
training_args = StableDiffusionDreamboothArguments(
task='text-to-image-synthesis').parse_cli()
@@ -106,6 +113,8 @@ kwargs = dict(
resolution=args.resolution,
prior_loss_weight=args.prior_loss_weight,
prompt=args.prompt,
torch_type=torch.float16
if args.torch_type == 'float16' else torch.float32,
cfg_modify_fn=cfg_modify_fn)
# build trainer and training

View File

@@ -17,4 +17,5 @@ PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/dreambooth/finetune_stab
--train.dataloader.workers_per_gpu 0 \
--evaluation.dataloader.workers_per_gpu 0 \
--train.optimizer.lr 5e-6 \
--torch_type 'float32' \
--use_model_config true

View File

@@ -2,6 +2,7 @@ import os
from dataclasses import dataclass, field
import cv2
import torch
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
@@ -25,6 +26,12 @@ class StableDiffusionLoraArguments(TrainingArgs):
'help': 'The rank size of lora intermediate linear.',
})
torch_type: str = field(
default='float32',
metadata={
'help': ' The torch type, default is float32.',
})
training_args = StableDiffusionLoraArguments(
task='text-to-image-synthesis').parse_cli()
@@ -66,6 +73,8 @@ kwargs = dict(
train_dataset=train_dataset,
eval_dataset=validation_dataset,
lora_rank=args.lora_rank,
torch_type=torch.float16
if args.torch_type == 'float16' else torch.float32,
cfg_modify_fn=cfg_modify_fn)
# build trainer and training

View File

@@ -5,10 +5,11 @@ PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/lora/finetune_stable_dif
--work_dir './tmp/lora_diffusion' \
--train_dataset_name 'buptwq/lora-stable-diffusion-finetune' \
--max_epochs 100 \
--lora_rank 4 \
--lora_rank 16 \
--save_ckpt_strategy 'by_epoch' \
--logging_interval 1 \
--train.dataloader.workers_per_gpu 0 \
--evaluation.dataloader.workers_per_gpu 0 \
--train.optimizer.lr 1e-4 \
--torch_type 'float16' \
--use_model_config true

View File

@@ -220,6 +220,8 @@ class Models(object):
stable_diffusion = 'stable-diffusion'
videocomposer = 'videocomposer'
text_to_360panorama_image = 'text-to-360panorama-image'
image_to_video_model = 'image-to-video-model'
video_to_video_model = 'video-to-video-model'
# science models
unifold = 'unifold'
@@ -235,6 +237,7 @@ class TaskModels(object):
feature_extraction = 'feature-extraction'
text_generation = 'text-generation'
text_ranking = 'text-ranking'
machine_reading_comprehension = 'machine-reading-comprehension'
class Heads(object):
@@ -487,6 +490,7 @@ class Pipelines(object):
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
language_identification = 'language_identification'
machine_reading_comprehension_for_ner = 'machine-reading-comprehension-for-ner'
# audio tasks
sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -543,6 +547,8 @@ class Pipelines(object):
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
multimodal_dialogue = 'multimodal-dialogue'
llama2_text_generation_pipeline = 'llama2-text-generation-pipeline'
image_to_video_task_pipeline = 'image-to-video-task-pipeline'
video_to_video_pipeline = 'video-to-video-pipeline'
# science tasks
protein_structure = 'unifold-protein-structure'
@@ -1062,6 +1068,7 @@ class Preprocessors(object):
document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval'
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
machine_reading_comprehension_for_ner = 'machine-reading-comprehension-for-ner'
# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'

View File

@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import numpy as np
import torch
import torch.nn.functional as F
from .gpen_model import FullGenerator
class GPEN(object):
    """Face-enhancement wrapper around the GPEN ``FullGenerator``.

    Loads a pretrained generator checkpoint and exposes :meth:`process`
    to enhance a single square BGR uint8 face crop.
    """

    def __init__(self,
                 model_path,
                 size=512,
                 channel_multiplier=2,
                 device=torch.device('cpu')):
        """
        Args:
            model_path: Path to the pretrained generator state dict.
            size: Square resolution the generator was trained at.
            channel_multiplier: Width multiplier of the generator.
            device: Device to run inference on.
        """
        self.mfile = model_path
        self.n_mlp = 8
        self.resolution = size
        self.device = device

        self.load_model(channel_multiplier)

    def load_model(self, channel_multiplier=2):
        """Build the generator, load weights and switch to eval mode."""
        self.model = FullGenerator(self.resolution, 512, self.n_mlp,
                                   channel_multiplier).to(self.device)
        # map_location keeps CUDA-saved checkpoints loadable on CPU-only
        # hosts (matches the behavior of the previous GANWrap loader).
        pretrained_dict = torch.load(self.mfile, map_location=self.device)
        self.model.load_state_dict(pretrained_dict)
        self.model.eval()

    def process(self, im):
        """Enhance one BGR uint8 image.

        Returns:
            (face, preds): the enhanced BGR uint8 face and an (empty)
            predictions list kept for interface compatibility.
        """
        preds = []
        imt = self.img2tensor(im)
        imt = F.interpolate(imt, (self.resolution, self.resolution))
        with torch.no_grad():
            img_out, __ = self.model(imt)

        face = self.tensor2img(img_out)
        return face, preds

    def img2tensor(self, img):
        # HWC uint8 BGR -> 1x3xHxW float tensor in [-1, 1], RGB order.
        img_t = torch.from_numpy(img).to(self.device)
        img_t = (img_t / 255. - 0.5) / 0.5
        img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1)  # BGR->RGB
        return img_t

    def tensor2img(self, image_tensor, pmax=255.0, imtype=np.uint8):
        # 1x3xHxW float tensor in [-1, 1] -> HWC uint8 BGR image.
        image_tensor = image_tensor * 0.5 + 0.5
        image_tensor = image_tensor.squeeze(0).permute(1, 2,
                                                       0).flip(2)  # RGB->BGR
        image_numpy = np.clip(image_tensor.float().cpu().numpy(), 0, 1) * pmax
        return image_numpy.astype(imtype)

View File

@@ -1,93 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from .model import FullGenerator
class GANWrap(object):
    """Batched face-enhancement wrapper around ``FullGenerator``.

    Accepts lists of square BGR uint8 face crops, runs them through the
    generator in mini-batches and resizes the results back to input size.
    """

    def __init__(self,
                 model_path,
                 size=256,
                 channel_multiplier=1,
                 device='cpu'):
        # model_path: path to the pretrained generator state dict.
        # size: square resolution the generator operates at.
        self.device = device
        self.mfile = model_path
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                                 inplace=True),
        ])
        self.batchSize = 2
        self.n_mlp = 8
        self.resolution = size

        self.load_model(channel_multiplier)

    def load_model(self, channel_multiplier=2):
        """Build the generator, load weights and switch to eval mode."""
        self.model = FullGenerator(self.resolution, 512, self.n_mlp,
                                   channel_multiplier).to(self.device)
        # CPU map_location keeps CUDA-saved checkpoints loadable anywhere.
        pretrained_dict = torch.load(
            self.mfile, map_location=torch.device('cpu'))
        self.model.load_state_dict(pretrained_dict)
        self.model.eval()

    def process_tensor(self, img_t, return_face=True):
        # Enhance a batch tensor at generator resolution, then restore
        # the original spatial size.
        b, c, h, w = img_t.shape
        img_t = F.interpolate(img_t, (self.resolution, self.resolution))
        with torch.no_grad():
            out, __ = self.model(img_t)
        # NOTE(review): F.interpolate takes size as (h, w); passing (w, h)
        # only matters for non-square inputs — confirm inputs are square.
        out = F.interpolate(out, (w, h))
        return out

    def process(self, ims, return_face=True):
        """Enhance a list of square BGR uint8 crops.

        Returns:
            (res, faces): results resized back to each input's size and,
            when ``return_face``, the raw generator-resolution faces.
        """
        res = []
        faces = []
        # Mini-batch the inputs to bound memory use.
        for i in range(0, len(ims), self.batchSize):
            sizes = []
            imt = None
            for im in ims[i:i + self.batchSize]:
                sizes.append(im.shape[0])
                im = cv2.resize(im, (self.resolution, self.resolution))
                im_pil = Image.fromarray(im)
                imt = self.img2tensor(im_pil) if imt is None else torch.cat(
                    (imt, self.img2tensor(im_pil)), dim=0)
            imt = torch.flip(imt, [1])  # BGR -> RGB channel order
            with torch.no_grad():
                img_outs, __ = self.model(imt)
            for sz, img_out in zip(sizes, img_outs):
                img = self.tensor2img(img_out)
                if return_face:
                    faces.append(img)
                img = cv2.resize(img, (sz, sz), interpolation=cv2.INTER_AREA)
                res.append(img)
        return res, faces

    def img2tensor(self, img):
        # PIL image -> normalized 1x3xHxW tensor in [-1, 1].
        img_t = self.transform(img).to(self.device)
        img_t = torch.unsqueeze(img_t, 0)
        return img_t

    def tensor2img(self, image_tensor, bytes=255.0, imtype=np.uint8):
        # CxHxW (or NxCxHxW, first item) tensor in [-1, 1] -> HWC uint8 BGR.
        if image_tensor.dim() == 3:
            image_numpy = image_tensor.cpu().float().numpy()
        else:
            image_numpy = image_tensor[0].cpu().float().numpy()
        image_numpy = np.transpose(image_numpy, (1, 2, 0))
        image_numpy = image_numpy[:, :, ::-1]  # RGB -> BGR
        image_numpy = np.clip(
            image_numpy * np.asarray([0.5, 0.5, 0.5])
            + np.asarray([0.5, 0.5, 0.5]), 0, 1)
        image_numpy = image_numpy * bytes
        return image_numpy.astype(imtype)

View File

@@ -1,17 +1,19 @@
# The implementation is adopted from stylegan2-pytorch,
# made public available under the MIT License at https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
# The implementation is adopted from InsightFace_Pytorch, made publicly available under the MIT License
# at https://github.com/yangxy/GPEN
import functools
import itertools
import math
import operator
import random
import torch
from torch import nn
from torch.autograd import Function
from torch.nn import functional as F
from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
isconcat = True
sss = 2 if isconcat else 1
ratio = 2
class PixelNorm(nn.Module):
@@ -306,6 +308,7 @@ class NoiseInjection(nn.Module):
def forward(self, image, noise=None):
if noise is not None:
# print(image.shape, noise.shape)
if isconcat:
return torch.cat((image, self.weight * noise), dim=1) # concat
return image + self.weight * noise
@@ -356,7 +359,8 @@ class StyledConv(nn.Module):
)
self.noise = NoiseInjection()
self.activate = FusedLeakyReLU(out_channel * sss)
feat_multiplier = 2
self.activate = FusedLeakyReLU(out_channel * feat_multiplier)
def forward(self, input, style, noise=None):
out = self.conv(input, style)
@@ -405,12 +409,14 @@ class Generator(nn.Module):
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
narrow=1,
):
super().__init__()
self.size = size
self.n_mlp = n_mlp
self.style_dim = style_dim
self.feat_multiplier = 2
layers = [PixelNorm()]
@@ -425,15 +431,16 @@ class Generator(nn.Module):
self.style = nn.Sequential(*layers)
self.channels = {
4: 512 // ratio,
8: 512 // ratio,
16: 512 // ratio,
32: 512 // ratio,
64: 256 // ratio * channel_multiplier,
128: 128 // ratio * channel_multiplier,
256: 64 // ratio * channel_multiplier,
512: 32 // ratio * channel_multiplier,
1024: 16 // ratio * channel_multiplier,
4: int(512 * narrow),
8: int(512 * narrow),
16: int(512 * narrow),
32: int(512 * narrow),
64: int(256 * channel_multiplier * narrow),
128: int(128 * channel_multiplier * narrow),
256: int(64 * channel_multiplier * narrow),
512: int(32 * channel_multiplier * narrow),
1024: int(16 * channel_multiplier * narrow),
2048: int(8 * channel_multiplier * narrow)
}
self.input = ConstantInput(self.channels[4])
@@ -443,7 +450,8 @@ class Generator(nn.Module):
3,
style_dim,
blur_kernel=blur_kernel)
self.to_rgb1 = ToRGB(self.channels[4] * sss, style_dim, upsample=False)
self.to_rgb1 = ToRGB(
self.channels[4] * self.feat_multiplier, style_dim, upsample=False)
self.log_size = int(math.log(size, 2))
@@ -458,23 +466,23 @@ class Generator(nn.Module):
self.convs.append(
StyledConv(
in_channel * sss,
in_channel * self.feat_multiplier,
out_channel,
3,
style_dim,
upsample=True,
blur_kernel=blur_kernel,
))
blur_kernel=blur_kernel))
self.convs.append(
StyledConv(
out_channel * sss,
out_channel * self.feat_multiplier,
out_channel,
3,
style_dim,
blur_kernel=blur_kernel))
self.to_rgbs.append(ToRGB(out_channel * sss, style_dim))
self.to_rgbs.append(
ToRGB(out_channel * self.feat_multiplier, style_dim))
in_channel = out_channel
@@ -515,6 +523,9 @@ class Generator(nn.Module):
styles = [self.style(s) for s in styles]
if noise is None:
'''
noise = [None] * (2 * (self.log_size - 2) + 1)
'''
noise = []
batch = styles[0].shape[0]
for i in range(self.n_mlp + 1):
@@ -557,16 +568,14 @@ class Generator(nn.Module):
skip = self.to_rgb1(out, latent[:, 1])
i = 1
noise_i = 1
for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2],
self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise[(noise_i + 1) // 2])
out = conv2(out, latent[:, i + 1], noise=noise[(noise_i + 2) // 2])
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2],
self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
noise_i += 2
image = skip
@@ -652,21 +661,106 @@ class ResBlock(nn.Module):
return out
class FullGenerator(nn.Module):
    """GPEN full generator.

    A convolutional encoder downsamples the input image; its intermediate
    feature maps are injected (as "noise") into a StyleGAN2 ``Generator``
    decoder, and the encoder bottleneck provides the latent style code.
    """

    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
        narrow=1,
    ):
        super().__init__()
        # Channel width per spatial resolution; ``narrow`` thins the model.
        channels = {
            4: int(512 * narrow),
            8: int(512 * narrow),
            16: int(512 * narrow),
            32: int(512 * narrow),
            64: int(256 * channel_multiplier * narrow),
            128: int(128 * channel_multiplier * narrow),
            256: int(64 * channel_multiplier * narrow),
            512: int(32 * channel_multiplier * narrow),
            1024: int(16 * channel_multiplier * narrow),
            2048: int(8 * channel_multiplier * narrow)
        }

        self.log_size = int(math.log(size, 2))
        self.generator = Generator(
            size,
            style_dim,
            n_mlp,
            channel_multiplier=channel_multiplier,
            blur_kernel=blur_kernel,
            lr_mlp=lr_mlp,
            narrow=narrow)

        # Encoder: a stem conv followed by log2(size) - 2 downsampling convs.
        conv = [ConvLayer(3, channels[size], 1)]
        self.ecd0 = nn.Sequential(*conv)
        in_channel = channels[size]

        self.names = ['ecd%d' % i for i in range(self.log_size - 1)]
        for i in range(self.log_size, 2, -1):
            out_channel = channels[2**(i - 1)]
            conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)]
            setattr(self, self.names[self.log_size - i + 1],
                    nn.Sequential(*conv))
            in_channel = out_channel
        # Bottleneck: flattened 4x4 features -> style vector.
        self.final_linear = nn.Sequential(
            EqualLinear(
                channels[4] * 4 * 4, style_dim, activation='fused_lrelu'))

    def forward(
        self,
        inputs,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
    ):
        noise = []
        # Run the encoder, keeping every intermediate feature map as noise.
        for i in range(self.log_size - 1):
            ecd = getattr(self, self.names[i])
            inputs = ecd(inputs)
            noise.append(inputs)
        inputs = inputs.view(inputs.shape[0], -1)
        outs = self.final_linear(inputs)
        # Each decoder resolution consumes two noise maps, ordered
        # coarse-to-fine, so duplicate each feature and reverse the list.
        noise = list(
            itertools.chain.from_iterable(
                itertools.repeat(x, 2) for x in noise))[::-1]
        outs = self.generator([outs],
                              return_latents,
                              inject_index,
                              truncation,
                              truncation_latent,
                              input_is_latent,
                              noise=noise[1:])
        return outs
class Discriminator(nn.Module):
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
def __init__(self,
size,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
narrow=1):
super().__init__()
channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
4: int(512 * narrow),
8: int(512 * narrow),
16: int(512 * narrow),
32: int(512 * narrow),
64: int(256 * channel_multiplier * narrow),
128: int(128 * channel_multiplier * narrow),
256: int(64 * channel_multiplier * narrow),
512: int(32 * channel_multiplier * narrow),
1024: int(16 * channel_multiplier * narrow),
2048: int(8 * channel_multiplier * narrow)
}
convs = [ConvLayer(3, channels[size], 1)]
@@ -713,48 +807,53 @@ class Discriminator(nn.Module):
return out
class FullGenerator(nn.Module):
class FullGenerator_SR(nn.Module):
def __init__(
self,
size,
in_size,
out_size,
style_dim,
n_mlp,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
narrow=1,
):
super().__init__()
channels = {
4: 512 // ratio,
8: 512 // ratio,
16: 512 // ratio,
32: 512 // ratio,
64: 256 // ratio * channel_multiplier,
128: 128 // ratio * channel_multiplier,
256: 64 // ratio * channel_multiplier,
512: 32 // ratio * channel_multiplier,
1024: 16 // ratio * channel_multiplier,
4: int(512 * narrow),
8: int(512 * narrow),
16: int(512 * narrow),
32: int(512 * narrow),
64: int(256 * channel_multiplier * narrow),
128: int(128 * channel_multiplier * narrow),
256: int(64 * channel_multiplier * narrow),
512: int(32 * channel_multiplier * narrow),
1024: int(16 * channel_multiplier * narrow),
2048: int(8 * channel_multiplier * narrow),
}
self.log_size = int(math.log(size, 2))
self.log_insize = int(math.log(in_size, 2))
self.log_outsize = int(math.log(out_size, 2))
self.generator = Generator(
size,
out_size,
style_dim,
n_mlp,
channel_multiplier=channel_multiplier,
blur_kernel=blur_kernel,
lr_mlp=lr_mlp)
lr_mlp=lr_mlp,
narrow=narrow)
conv = [ConvLayer(3, channels[size], 1)]
conv = [ConvLayer(3, channels[in_size], 1)]
self.ecd0 = nn.Sequential(*conv)
in_channel = channels[size]
in_channel = channels[in_size]
self.names = ['ecd%d' % i for i in range(self.log_size - 1)]
for i in range(self.log_size, 2, -1):
self.names = ['ecd%d' % i for i in range(self.log_insize - 1)]
for i in range(self.log_insize, 2, -1):
out_channel = channels[2**(i - 1)]
conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)]
setattr(self, self.names[self.log_size - i + 1],
setattr(self, self.names[self.log_insize - i + 1],
nn.Sequential(*conv))
in_channel = out_channel
self.final_linear = nn.Sequential(
@@ -771,18 +870,22 @@ class FullGenerator(nn.Module):
input_is_latent=False,
):
noise = []
for i in range(self.log_size - 1):
for i in range(self.log_outsize - self.log_insize):
noise.append(None)
for i in range(self.log_insize - 1):
ecd = getattr(self, self.names[i])
inputs = ecd(inputs)
noise.append(inputs)
inputs = inputs.view(inputs.shape[0], -1)
outs = self.final_linear(inputs)
outs = self.generator([outs],
return_latents,
inject_index,
truncation,
truncation_latent,
input_is_latent,
noise=noise[::-1])
return outs
noise = list(
itertools.chain.from_iterable(
itertools.repeat(x, 2) for x in noise))[::-1]
image, latent = self.generator([outs],
return_latents,
inject_index,
truncation,
truncation_latent,
input_is_latent,
noise=noise[1:])
return image, latent

View File

@@ -1,4 +1,2 @@
# The implementation is adopted from stylegan2-pytorch, made public available under the MIT License
# at https://github.com/rosinality/stylegan2-pytorch
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d

View File

@@ -15,6 +15,77 @@ REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051],
DEFAULT_CROP_SIZE = (96, 112)
def _umeyama(src, dst, estimate_scale=True, scale=1.0):
"""Estimate N-D similarity transformation with or without scaling.
Parameters
----------
src : (M, N) array
Source coordinates.
dst : (M, N) array
Destination coordinates.
estimate_scale : bool
Whether to estimate scaling factor.
Returns
-------
T : (N + 1, N + 1)
The homogeneous similarity transformation matrix. The matrix contains
NaN values only if the problem is not well-conditioned.
References
----------
.. [1] "Least-squares estimation of transformation parameters between two
point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
"""
num = src.shape[0]
dim = src.shape[1]
# Compute mean of src and dst.
src_mean = src.mean(axis=0)
dst_mean = dst.mean(axis=0)
# Subtract mean from src and dst.
src_demean = src - src_mean
dst_demean = dst - dst_mean
# Eq. (38).
A = dst_demean.T @ src_demean / num
# Eq. (39).
d = np.ones((dim, ), dtype=np.double)
if np.linalg.det(A) < 0:
d[dim - 1] = -1
T = np.eye(dim + 1, dtype=np.double)
U, S, V = np.linalg.svd(A)
# Eq. (40) and (43).
rank = np.linalg.matrix_rank(A)
if rank == 0:
return np.nan * T
elif rank == dim - 1:
if np.linalg.det(U) * np.linalg.det(V) > 0:
T[:dim, :dim] = U @ V
else:
s = d[dim - 1]
d[dim - 1] = -1
T[:dim, :dim] = U @ np.diag(d) @ V
d[dim - 1] = s
else:
T[:dim, :dim] = U @ np.diag(d) @ V
if estimate_scale:
# Eq. (41) and (42).
scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
else:
scale = scale
T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
T[:dim, :dim] *= scale
return T, scale
class FaceWarpException(Exception):
def __str__(self):
@@ -246,6 +317,65 @@ def warp_and_crop_face(src_img,
return face_img
def warp_and_crop_face_enhance(src_img,
                               facial_pts,
                               reference_pts=None,
                               crop_size=(96, 112),
                               align_type='smilarity'):
    """Align a face to reference landmarks and also return the inverse warp.

    Args:
        src_img: Input image (HWC).
        facial_pts: (K, 2) or (2, K) facial landmarks, K > 2.
        reference_pts: Target landmarks; defaults to the canonical 5-point
            template (scaled when crop_size differs from 96x112).
        crop_size: (width, height) of the output crop.
        align_type: 'cv2_affine', 'affine', or anything else selects the
            similarity branch.  NOTE(review): the default keeps the
            historic 'smilarity' spelling for backward compatibility.

    Returns:
        (face_img, tfm_inv): the aligned crop and the 2x3 affine matrix
        mapping crop coordinates back into src_img coordinates.

    Raises:
        FaceWarpException: if the landmark shapes are invalid or mismatched.
    """
    if reference_pts is None:
        if crop_size[0] == 96 and crop_size[1] == 112:
            reference_pts = REFERENCE_FACIAL_POINTS
        else:
            default_square = False
            inner_padding_factor = 0
            outer_padding = (0, 0)
            output_size = crop_size

            reference_pts = get_reference_facial_points(
                output_size, inner_padding_factor, outer_padding,
                default_square)

    ref_pts = np.float32(reference_pts)
    ref_pts_shp = ref_pts.shape
    if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
        raise FaceWarpException(
            'reference_pts.shape must be (K,2) or (2,K) and K>2')

    # Normalize both point sets to (K, 2).
    if ref_pts_shp[0] == 2:
        ref_pts = ref_pts.T

    src_pts = np.float32(facial_pts)
    src_pts_shp = src_pts.shape
    if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
        raise FaceWarpException(
            'facial_pts.shape must be (K,2) or (2,K) and K>2')

    if src_pts_shp[0] == 2:
        src_pts = src_pts.T

    if src_pts.shape != ref_pts.shape:
        raise FaceWarpException(
            'facial_pts and reference_pts must have the same shape')

    if align_type == 'cv2_affine':
        tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
        tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
    elif align_type == 'affine':
        tfm = get_affine_transform_matrix(src_pts, ref_pts)
        tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
    else:
        # Similarity transform; the inverse is computed with the fixed
        # reciprocal scale so forward and inverse stay consistent.
        params, scale = _umeyama(src_pts, ref_pts)
        tfm = params[:2, :]

        params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale)
        tfm_inv = params[:2, :]

    face_img = cv2.warpAffine(
        src_img, tfm, (crop_size[0], crop_size[1]), flags=3)

    return face_img, tfm_inv
def get_f5p(landmarks, np_img):
eye_left = find_pupil(landmarks[36:41], np_img)
eye_right = find_pupil(landmarks[42:47], np_img)

View File

@@ -16,9 +16,10 @@ from modelscope.models.builder import MODELS
from modelscope.models.cv.face_detection.peppa_pig_face.facer import FaceAna
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .facegan.gan_wrap import GANWrap
from .facegan.face_gan import GPEN
from .facelib.align_trans import (get_f5p, get_reference_facial_points,
warp_and_crop_face)
warp_and_crop_face,
warp_and_crop_face_enhance)
from .network.aei_flow_net import AEI_Net
from .network.bfm import ParametricFaceModel
from .network.facerecon_model import ReconNetWrapper
@@ -78,14 +79,6 @@ class ImageFaceFusion(TorchModel):
self.face_model = ParametricFaceModel(bfm_folder=bfm_dir)
self.face_model.to(self.device)
face_enhance_path = os.path.join(model_dir, 'faceEnhance',
'350000-Ns256.pt')
self.ganwrap = GANWrap(
model_path=face_enhance_path,
size=256,
channel_multiplier=1,
device=self.device)
self.facer = FaceAna(model_dir)
logger.info('load facefusion models done')
@@ -94,6 +87,27 @@ class ImageFaceFusion(TorchModel):
self.mask_init = cv2.resize(self.mask_init, (256, 256))
self.mask = self.image_transform(self.mask_init, is_norm=False)
face_enhance_path = os.path.join(model_dir, 'faceEnhance',
'GPEN-BFR-1024.pth')
if not os.path.exists(face_enhance_path):
logger.warning(
'model path not found, please update the latest model!')
self.ganwrap_1024 = GPEN(face_enhance_path, 1024, 2, self.device)
self.mask_enhance = np.zeros((512, 512), np.float32)
cv2.rectangle(self.mask_enhance, (26, 26), (486, 486), (1, 1, 1), -1,
cv2.LINE_AA)
self.mask_enhance = cv2.GaussianBlur(self.mask_enhance, (101, 101), 11)
self.mask_enhance = cv2.GaussianBlur(self.mask_enhance, (101, 101), 11)
default_square = True
inner_padding_factor = 0.25
outer_padding = (0, 0)
self.reference_5pts_1024 = get_reference_facial_points(
(1024, 1024), inner_padding_factor, outer_padding, default_square)
self.test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
@@ -157,7 +171,7 @@ class ImageFaceFusion(TorchModel):
src_h, src_w, _ = img.shape
boxes, landmarks, _ = self.facer.run(img)
if boxes.shape[0] == 0:
return None
return None, None, None
elif boxes.shape[0] > 1:
max_area = 0
max_index = 0
@@ -168,9 +182,14 @@ class ImageFaceFusion(TorchModel):
if area > max_area:
max_index = i
max_area = area
return landmarks[max_index]
fw = boxes[max_index][2] - boxes[max_index][0]
fh = boxes[max_index][3] - boxes[max_index][1]
return landmarks[max_index], fw, fh
else:
return landmarks[0]
fw = boxes[0][2] - boxes[0][0]
fh = boxes[0][3] - boxes[0][1]
return landmarks[0], fw, fh
def compute_3d_params(self, Xs, Xt):
kp_fuse = {}
@@ -198,6 +217,51 @@ class ImageFaceFusion(TorchModel):
return kp_fuse, kp_t
def process_enhance(self, im, f5p, fh, fw):
    """Enhance the face region of ``im`` with GPEN and blend it back.

    Args:
        im: Full image (BGR, HWC).
        f5p: Five facial landmark points used for alignment.
        fh: Detected face height, used to pick a smoothing level.
        fw: Detected face width, used to pick a smoothing level.

    Returns:
        The input image with the enhanced face blended in, same size.
    """
    height, width, _ = im.shape
    # Warp the face onto the 1024x1024 GPEN template; keep the inverse
    # transform to paste the result back.
    of, tfm_inv = warp_and_crop_face_enhance(
        im,
        f5p,
        reference_pts=self.reference_5pts_1024,
        crop_size=(1024, 1024))
    ef, pred = self.ganwrap_1024.process(of)

    tmp_mask = self.mask_enhance
    tmp_mask = cv2.resize(tmp_mask, ef.shape[:2])
    tmp_mask = cv2.warpAffine(tmp_mask, tfm_inv, (width, height), flags=3)

    full_mask = np.zeros((height, width), dtype=np.float32)
    full_img = np.zeros(im.shape, dtype=np.uint8)

    # Smaller faces get progressively stronger pyramid smoothing so the
    # upsampled enhancement does not look oversharpened when scaled down.
    if min(fh, fw) < 40:
        ef = cv2.pyrDown(ef)
        ef = cv2.pyrDown(ef)
        ef = cv2.pyrUp(ef)
        ef = cv2.pyrUp(ef)
    elif min(fh, fw) < 60:
        ef = cv2.pyrDown(ef)
        ef = cv2.resize(ef, (0, 0), fx=2, fy=2)
        ef = cv2.resize(ef, (0, 0), fx=0.5, fy=0.5)
        ef = cv2.pyrUp(ef)
    elif min(fh, fw) < 80:
        ef = cv2.pyrDown(ef)
        ef = cv2.pyrUp(ef)
    elif min(fh, fw) < 100:
        ef = cv2.pyrDown(ef)
        ef = cv2.resize(ef, (0, 0), fx=2, fy=2)

    # Warp the enhanced face back and alpha-blend with the feathered mask.
    tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3)

    mask = tmp_mask - full_mask
    full_mask[np.where(mask > 0)] = tmp_mask[np.where(mask > 0)]
    full_img[np.where(mask > 0)] = tmp_img[np.where(mask > 0)]

    full_mask = full_mask[:, :, np.newaxis]
    im = cv2.convertScaleAbs(im * (1 - full_mask) + full_img * full_mask)
    im = cv2.resize(im, (width, height))
    return im
def inference(self, template_img, user_img):
ori_h, ori_w, _ = template_img.shape
@@ -205,14 +269,14 @@ class ImageFaceFusion(TorchModel):
user_img = user_img.cpu().numpy()
user_img_bgr = user_img[:, :, ::-1]
landmark_source = self.detect_face(user_img)
landmark_source, _, _ = self.detect_face(user_img)
if landmark_source is None:
logger.warning('No face detected in user image!')
return template_img
f5p_user = get_f5p(landmark_source, user_img_bgr)
template_img_bgr = template_img[:, :, ::-1]
landmark_template = self.detect_face(template_img)
landmark_template, fw, fh = self.detect_face(template_img)
if landmark_template is None:
logger.warning('No face detected in template image!')
return template_img
@@ -235,7 +299,6 @@ class ImageFaceFusion(TorchModel):
with torch.no_grad():
kp_fuse, kp_t = self.compute_3d_params(Xs, Xt)
Yt, _, _ = self.netG(Xt, Xs_embeds, kp_fuse, kp_t)
Yt = self.ganwrap.process_tensor(Yt)
Yt = Yt * 0.5 + 0.5
Yt = torch.clamp(Yt, 0, 1)
@@ -247,6 +310,7 @@ class ImageFaceFusion(TorchModel):
0).cpu().numpy()
Yt_trans_inv = Yt_trans_inv.astype(np.float32)
out_img = Yt_trans_inv[:, :, ::-1] * 255.
out_img = self.process_enhance(out_img, f5p_template, fh, fw)
logger.info('model inference done')

View File

@@ -15,6 +15,7 @@ from diffusers.models import cross_attention
from diffusers.utils import deprecation_utils
from transformers import CLIPTextModel, CLIPTokenizer
from modelscope import snapshot_download
from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
@@ -56,7 +57,10 @@ class EfficientStableDiffusion(TorchModel):
super().__init__(model_dir, *args, **kwargs)
tuner_name = kwargs.pop('tuner_name', 'lora')
pretrained_model_name_or_path = kwargs.pop(
'pretrained_model_name_or_path', 'runwayml/stable-diffusion-v1-5')
'pretrained_model_name_or_path',
'AI-ModelScope/stable-diffusion-v1-5')
pretrained_model_name_or_path = snapshot_download(
pretrained_model_name_or_path)
tuner_config = kwargs.pop('tuner_config', None)
pretrained_tuner = kwargs.pop('pretrained_tuner', None)
revision = kwargs.pop('revision', None)

View File

@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Package initializer: exposes ImageToVideo lazily so importing the package
# does not pull in heavy model dependencies until first attribute access.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    # Real import only for static type checkers / IDEs.
    from .image_to_video_model import ImageToVideo

else:
    # Map submodule name -> public names it provides; LazyImportModule
    # resolves them on demand.
    _import_structure = {
        'image_to_video_model': ['ImageToVideo'],
    }

    import sys

    # Replace this module object with a lazy proxy in sys.modules.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1,215 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import random
from copy import copy
from typing import Any, Dict
import torch
import torch.cuda.amp as amp
import modelscope.models.multi_modal.image_to_video.utils.transforms as data
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.image_to_video.modules import *
from modelscope.models.multi_modal.image_to_video.modules import (
AutoencoderKL, FrozenOpenCLIPVisualEmbedder, Img2VidSDUNet)
from modelscope.models.multi_modal.image_to_video.utils.config import cfg
from modelscope.models.multi_modal.image_to_video.utils.diffusion import \
GaussianDiffusion
from modelscope.models.multi_modal.image_to_video.utils.seed import setup_seed
from modelscope.models.multi_modal.image_to_video.utils.shedule import \
beta_schedule
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
__all__ = ['ImageToVideo']
logger = get_logger()
@MODELS.register_module(
    Tasks.image_to_video, module_name=Models.image_to_video_model)
class ImageToVideo(TorchModel):
    r"""
    Image2Video aims to solve the task of generating high-definition videos based on input images.
    Image2Video is a video generation basic model developed by Alibaba Cloud, with a parameter size
    of approximately 2 billion. It has been pre-trained on large-scale video and image data and
    fine-tuned on a small amount of high-quality data. The data is widely distributed and diverse
    in categories, and the model has good generalization ability for different types of data.
    Paper link: https://arxiv.org/abs/2306.02018

    Attributes:
        diffusion: diffusion model for DDIM sampling.
        autoencoder: decodes the latent representation into visual space.
        clip_encoder: encodes the conditioning image into an image embedding.
        generator: the UNet denoiser operating in latent space.
    """

    def __init__(self, model_dir, *args, **kwargs):
        r"""
        Args:
            model_dir (`str` or `os.PathLike`)
                Can be either:
                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co
                      or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`,
                      or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.

        Raises:
            RuntimeError: if the UNet checkpoint state dict does not match
                (``load_state_dict`` is called with ``strict=True``).
        """
        super().__init__(model_dir=model_dir, *args, **kwargs)
        self.config = Config.from_file(
            osp.join(model_dir, ModelFile.CONFIGURATION))

        # Override the module-level default cfg with per-model configuration.
        cfg.batch_size = self.config.model.model_cfg.batch_size
        cfg.target_fps = self.config.model.model_cfg.target_fps
        cfg.max_frames = self.config.model.model_cfg.max_frames
        cfg.latent_hei = self.config.model.model_cfg.latent_hei
        cfg.latent_wid = self.config.model.model_cfg.latent_wid
        cfg.model_path = osp.join(model_dir,
                                  self.config.model.model_args.ckpt_unet)

        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')

        # Seed from config when given, otherwise pick a random seed per run.
        if 'seed' in self.config.model.model_args.keys():
            cfg.seed = self.config.model.model_args.seed
        else:
            cfg.seed = random.randint(0, 99999)
        setup_seed(cfg.seed)

        # Preprocessing pipeline for the conditioning image fed to CLIP.
        vid_trans = data.Compose([
            data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])),
            data.Resize(cfg.vit_resolution),
            data.ToTensor(),
            data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)
        ])
        self.vid_trans = vid_trans

        # [clip encoder] frozen OpenCLIP visual tower for image conditioning
        cfg.embedder.pretrained = osp.join(
            model_dir, self.config.model.model_args.ckpt_clip)
        clip_encoder = FrozenOpenCLIPVisualEmbedder(**cfg.embedder)
        clip_encoder.model.to(self.device)
        self.clip_encoder = clip_encoder
        logger.info(f'Build encoder with {cfg.embedder.type}')

        # [unet] latent-space video denoiser, loaded strictly from checkpoint
        generator = Img2VidSDUNet(**cfg.UNet)
        generator = generator.to(self.device)
        generator.eval()
        load_dict = torch.load(cfg.model_path, map_location='cpu')
        ret = generator.load_state_dict(load_dict['state_dict'], strict=True)
        self.generator = generator
        logger.info('Load model {} path {}, with local status {}'.format(
            cfg.UNet.type, cfg.model_path, ret))

        # [diffusion] Gaussian diffusion with a linear-sd beta schedule
        betas = beta_schedule(
            'linear_sd',
            cfg.num_timesteps,
            init_beta=0.00085,
            last_beta=0.0120)
        diffusion = GaussianDiffusion(
            betas=betas,
            mean_type=cfg.mean_type,
            var_type=cfg.var_type,
            loss_type=cfg.loss_type,
            rescale_timesteps=False,
            noise_strength=getattr(cfg, 'noise_strength', 0))
        self.diffusion = diffusion
        logger.info('Build diffusion with type of GaussianDiffusion')

        # [autoencoder] frozen KL-VAE used only for decoding latents
        cfg.auto_encoder.pretrained = osp.join(
            model_dir, self.config.model.model_args.ckpt_autoencoder)
        autoencoder = AutoencoderKL(**cfg.auto_encoder)
        autoencoder.eval()
        for param in autoencoder.parameters():
            param.requires_grad = False
        autoencoder.to(self.device)
        self.autoencoder = autoencoder
        torch.cuda.empty_cache()

        # Reusable constants for classifier-free guidance and fps conditioning.
        zero_feature = torch.zeros(1, 1, cfg.UNet.input_dim).to(self.device)
        self.zero_feature = zero_feature
        self.fps_tensor = torch.tensor([cfg.target_fps],
                                       dtype=torch.long,
                                       device=self.device)
        self.cfg = cfg

    def forward(self, input: Dict[str, Any]):
        r"""
        The entry function of the image-to-video task.
        1. Uses the diffusion model (DDIM) to generate the video's latent representation.
        2. Uses the autoencoder to decode the latent representation into visual space.

        Args:
            input (`Dict[str, Any]`):
                The input of the task; must contain key 'vit_frame'
                (the CLIP-preprocessed conditioning image tensor).

        Returns:
            A generated video as a float32 CPU tensor of shape (b, c, f, h, w).
        """
        vit_frame = input['vit_frame']
        cfg = self.cfg
        img_embedding = self.clip_encoder(vit_frame).unsqueeze(1)

        noise = self.build_noise()
        zero_feature = copy(self.zero_feature)
        with torch.no_grad():
            with amp.autocast(enabled=cfg.use_fp16):
                # Paired kwargs for classifier-free guidance:
                # [0] conditional (image embedding), [1] unconditional (zeros).
                model_kwargs = [{
                    'y': img_embedding,
                    'fps': self.fps_tensor
                }, {
                    'y': zero_feature.repeat(cfg.batch_size, 1, 1),
                    'fps': self.fps_tensor
                }]
                gen_video = self.diffusion.ddim_sample_loop(
                    noise=noise,
                    model=self.generator,
                    model_kwargs=model_kwargs,
                    guide_scale=cfg.guide_scale,
                    ddim_timesteps=cfg.ddim_timesteps,
                    eta=0.0)

                # Undo VAE scaling, then decode frames in chunks to bound memory.
                # NOTE(review): `rearrange` is presumably provided by the
                # `modules` star import (einops-style) — confirm it resolves.
                gen_video = 1. / cfg.scale_factor * gen_video
                gen_video = rearrange(gen_video, 'b c f h w -> (b f) c h w')
                chunk_size = min(cfg.decoder_bs, gen_video.shape[0])
                gen_video_list = torch.chunk(
                    gen_video, gen_video.shape[0] // chunk_size, dim=0)
                decode_generator = []
                for vd_data in gen_video_list:
                    gen_frames = self.autoencoder.decode(vd_data)
                    decode_generator.append(gen_frames)

        gen_video = torch.cat(decode_generator, dim=0)
        gen_video = rearrange(
            gen_video, '(b f) c h w -> b c f h w', b=cfg.batch_size)
        return gen_video.type(torch.float32).cpu()

    def build_noise(self):
        """Sample the initial latent noise, optionally adding per-frame
        offset noise controlled by ``cfg.noise_strength``."""
        cfg = self.cfg
        noise = torch.randn(
            [1, 4, cfg.max_frames, cfg.latent_hei,
             cfg.latent_wid]).to(self.device)
        if cfg.noise_strength > 0:
            b, c, f, *_ = noise.shape
            # Offset noise: shared across spatial dims, distinct per frame.
            offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device)
            noise = noise + cfg.noise_strength * offset_noise
        return noise.contiguous()

View File

@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Re-export all model building blocks (KL-VAE, CLIP embedder, UNet)
# at the package level for convenient star-imports.
from .autoencoder import *
from .embedder import *
from .unet_i2v import *

View File

@@ -0,0 +1,573 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import collections
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def nonlinearity(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""
    return torch.sigmoid(x) * x
def Normalize(in_channels, num_groups=32):
    """Build a GroupNorm layer with the conventions used throughout this file
    (eps=1e-6, learnable affine parameters)."""
    norm = torch.nn.GroupNorm(
        num_channels=in_channels, num_groups=num_groups, eps=1e-6, affine=True)
    return norm
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents, parameterized by concatenated mean/logvar.

    ``parameters`` is split in half along dim=1: first half is the mean, second
    half the log-variance (clamped to [-30, 20] for numerical stability).
    With ``deterministic=True`` the variance is forced to zero so that
    ``sample()`` returns the mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp log-variance so exp() stays finite.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean).to(device=self.parameters.device)

    def sample(self):
        # Reparameterized draw: mean + std * eps.
        x = self.mean + self.std * torch.randn(
            self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        """KL divergence to ``other`` (or to a standard normal when None),
        summed over the non-batch dims [1, 2, 3]."""
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        """Negative log-likelihood of ``sample`` under this Gaussian.

        NOTE(review): the mutable default ``dims`` is shared between calls but
        is never mutated here, so it is harmless in practice.
        """
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar
            + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        # The mode of a Gaussian is its mean.
        return self.mean
class ResnetBlock(nn.Module):
    """Standard VAE residual block: two (norm -> swish -> conv) stages with an
    optional timestep-embedding projection and a learned shortcut when the
    channel count changes."""

    def __init__(self,
                 *,
                 in_channels,
                 out_channels=None,
                 conv_shortcut=False,
                 dropout,
                 temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Project the timestep embedding into the channel dim (unused when
        # temb_channels == 0, as in this file's Encoder/Decoder).
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Shortcut matches channel counts: 3x3 conv or 1x1 ("network-in-network").
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        # Add the (swished) timestep embedding, broadcast over H and W.
        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over (H*W) positions, with 1x1-conv
    q/k/v projections and a residual connection."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention: w_[b, i, j] = softmax_j(q_i . k_j / sqrt(c))
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        # Residual connection around the attention.
        return x + h_
class Upsample(nn.Module):
    """2x nearest-neighbour spatial upsampling, optionally followed by a
    3x3 convolution to smooth the result."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode='nearest')
        return self.conv(upsampled) if self.with_conv else upsampled
class Downsample(nn.Module):
    """2x spatial downsampling: strided 3x3 conv with asymmetric (right/bottom)
    padding when ``with_conv``, otherwise 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # torch convs have no asymmetric padding option, so the padding
            # is applied by hand in forward().
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        padded = torch.nn.functional.pad(
            x, (0, 1, 0, 1), mode='constant', value=0)
        return self.conv(padded)
class Encoder(nn.Module):
    """VAE encoder: maps an image (B, in_channels, H, W) to latent moments
    (B, 2*z_channels when ``double_z`` else z_channels, H/2^k, W/2^k) via a
    stack of residual blocks with progressive downsampling, a middle
    res-attn-res bottleneck, and a final projection."""

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 double_z=True,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the VAE
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                # Attach attention only at the configured resolutions.
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle: res -> attn -> res bottleneck
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)

        # end: project to z moments (2x channels when double_z for mean+logvar)
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        # timestep embedding (unused: temb_ch == 0)
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class Decoder(nn.Module):
    """VAE decoder: maps a latent (B, z_channels, h, w) back to an image
    (B, out_ch, H, W) via a middle res-attn-res bottleneck followed by a stack
    of residual blocks with progressive upsampling."""

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 give_pre_end=False,
                 tanh_out=False,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the VAE
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logging.info('Working with z of shape {} = {} dimensions.'.format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle: res -> attn -> res bottleneck
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)

        # upsampling (built from lowest to highest res, inserted at index 0
        # so that self.up is indexed in forward-resolution order)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            # num_res_blocks + 1 blocks per level (one more than the encoder).
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        self.last_z_shape = z.shape

        # timestep embedding (unused: temb_ch == 0)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end: optionally return the pre-activation feature map
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
class AutoencoderKL(nn.Module):
    """KL-regularized image autoencoder (the Stable Diffusion first stage).

    Wraps an Encoder/Decoder pair with 1x1 "quant" convs mapping between
    z-moments and the embedding dim, and can restore weights from a full SD
    checkpoint (keys prefixed with ``first_stage_model.``).
    """

    def __init__(self,
                 ddconfig,
                 embed_dim,
                 pretrained=None,
                 ignore_keys=[],
                 image_key='image',
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False,
                 **kwargs):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # double_z is required: the encoder must emit mean and logvar.
        assert ddconfig['double_z']
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
                                          2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig['z_channels'], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            # Random projection used by to_rgb() for segmentation inputs.
            self.register_buffer('colorize',
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None

        if pretrained is not None:
            self.init_from_ckpt(pretrained, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load weights from a full SD checkpoint, keeping only the keys under
        the ``first_stage_model.`` prefix (with the prefix stripped).

        NOTE(review): ``ignore_keys`` is accepted but never applied here —
        confirm whether key filtering was intended.
        """
        sd = torch.load(path, map_location='cpu')['state_dict']
        keys = list(sd.keys())
        sd_new = collections.OrderedDict()
        for k in keys:
            if k.find('first_stage_model') >= 0:
                k_new = k.split('first_stage_model.')[-1]
                sd_new[k_new] = sd[k]
        self.load_state_dict(sd_new, strict=True)
        logging.info(f'Restored from {path}')

    def on_train_batch_end(self, *args, **kwargs):
        # NOTE(review): self.model_ema is never defined in this class; this
        # path would fail if use_ema were True — verify against callers.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode an image into a DiagonalGaussianDistribution over latents."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent sample back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Full reconstruction pass; returns (decoded, posterior)."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        # (B, H, W[, C]) -> contiguous float (B, C, H, W)
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1,
                      2).to(memory_format=torch.contiguous_format).float()
        return x

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Collect input/reconstruction/sample images for logging.

        NOTE(review): relies on ``self.device`` and ``self.ema_scope`` which
        are not defined in this class (Lightning-era leftovers) — confirm the
        calling framework provides them.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log['samples'] = self.decode(torch.randn_like(posterior.sample()))
            log['reconstructions'] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log['samples_ema'] = self.decode(
                        torch.randn_like(posterior_ema.sample()))
                    log['reconstructions_ema'] = xrec_ema
        log['inputs'] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to 3 channels for display,
        normalized to [-1, 1]."""
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize',
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
class IdentityFirstStage(torch.nn.Module):
    """No-op stand-in for a first-stage (VAE) model: every operation returns
    its input unchanged.

    Args:
        vq_interface: when True, ``quantize`` mimics a VQ quantizer's
            ``(z_q, loss, info)`` return signature so callers written against
            a VQ model keep working.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Initialize nn.Module machinery first; setting attributes on an
        # uninitialized Module is fragile and relies on __setattr__ internals.
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        """Identity: return the input unchanged."""
        return x

    def decode(self, x, *args, **kwargs):
        """Identity: return the input unchanged."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Identity; optionally wrapped in a VQ-style (z_q, loss, info) tuple."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Identity: return the input unchanged."""
        return x

View File

@@ -0,0 +1,82 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import numpy as np
import open_clip
import torch
import torch.nn as nn
import torchvision.transforms as T
class FrozenOpenCLIPVisualEmbedder(nn.Module):
    """Frozen OpenCLIP *visual* tower used to embed conditioning images.

    The model is built on CPU (the caller moves it to the target device) and
    its text transformer is deleted to save memory. ``forward`` returns the
    CLIP image embedding for a preprocessed image batch.

    Args:
        pretrained: path to the OpenCLIP checkpoint.
        vit_resolution: (H, W) of the ViT input; also used for the cached
            all-white reference image.
        arch: OpenCLIP architecture name.
        device: device the *inputs* are moved to in forward().
        max_length: token budget (kept for interface parity with the text
            embedder; unused by the visual path).
        freeze: when True, eval() the model and disable all gradients.
        layer: 'last' or 'penultimate' transformer layer to stop at.
    """

    LAYERS = ['last', 'penultimate']

    def __init__(self,
                 pretrained,
                 vit_resolution=(224, 224),
                 arch='ViT-H-14',
                 device='cuda',
                 max_length=77,
                 freeze=True,
                 layer='last',
                 **kwargs):
        super().__init__()
        assert layer in self.LAYERS
        # Build on CPU; the caller is responsible for .to(device).
        model, _, preprocess = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=pretrained)
        # The text transformer is unused for visual embedding; free it.
        del model.transformer
        self.model = model
        data_white = np.ones(
            (vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8) * 255
        # Cached embedding input for an all-white image (used as a neutral /
        # placeholder frame by callers).
        self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0)

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == 'last':
            self.layer_idx = 0
        elif self.layer == 'penultimate':
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Switch to eval mode and disable gradients for all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image):
        """Return the CLIP image embedding for ``image`` (moved to self.device)."""
        z = self.model.encode_image(image.to(self.device))
        return z

    def encode_with_transformer(self, text):
        # NOTE(review): self.model.transformer is deleted in __init__, so the
        # text paths below appear vestigial (copied from the text embedder);
        # calling them would fail — confirm before removing.
        x = self.model.token_embedding(text)
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run the text transformer up to self.layer_idx layers from the end."""
        # Fix: `checkpoint` was previously an undefined name in this module.
        from torch.utils.checkpoint import checkpoint
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
            ):
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        """Alias for forward()."""
        return self(text)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

View File

@@ -0,0 +1,161 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import os.path as osp
from datetime import datetime

import torch
from easydict import EasyDict

# Default configuration for the image-to-video (VideoLDM decoder) model.
# Values here are overridden per-model by ImageToVideo.__init__ from the
# model's configuration.json.
cfg = EasyDict(__name__='Config: VideoLDM Decoder')

# ---------------------------work dir--------------------------
cfg.work_dir = 'workspace/'

# ---------------------------Global Variable-----------------------------------
cfg.resolution = [448, 256]
# -----------------------------------------------------------------------------

# ---------------------------Dataset Parameter---------------------------------
cfg.mean = [0.5, 0.5, 0.5]
cfg.std = [0.5, 0.5, 0.5]
cfg.max_words = 1000

# Placeholder
cfg.vit_out_dim = 1024
cfg.vit_resolution = [224, 224]
cfg.depth_clamp = 10.0
cfg.misc_size = 384
cfg.depth_std = 20.0

cfg.frame_lens = 32
cfg.sample_fps = 8

# NOTE(review): this is `batch_sizes` (plural) while the model code reads
# `cfg.batch_size` (singular, set from configuration.json) — confirm the
# plural key is intentional / unused here.
cfg.batch_sizes = 1
# -----------------------------------------------------------------------------

# ---------------------------Model Parameters----------------------------------
# Diffusion
cfg.schedule = 'cosine'
cfg.num_timesteps = 1000
cfg.mean_type = 'v'  # v-prediction parameterization
cfg.var_type = 'fixed_small'
cfg.loss_type = 'mse'
cfg.ddim_timesteps = 50
cfg.ddim_eta = 0.0
cfg.clamp = 1.0
cfg.share_noise = False
cfg.use_div_loss = False
cfg.noise_strength = 0.1

# classifier-free guidance
cfg.p_zero = 0.1  # probability of dropping the condition during training
cfg.guide_scale = 3.0

# clip vision encoder (OpenCLIP normalization constants)
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]

# Model
cfg.scale_factor = 0.18215  # SD latent scaling factor
cfg.use_fp16 = True
cfg.temporal_attention = True
cfg.decoder_bs = 8  # VAE decode chunk size (frames per chunk)

cfg.UNet = {
    'type': 'Img2VidSDUNet',
    'in_dim': 4,
    'dim': 320,
    'y_dim': cfg.vit_out_dim,
    'context_dim': 1024,
    'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
    'dim_mult': [1, 2, 4, 4],
    'num_heads': 8,
    'head_dim': 64,
    'num_res_blocks': 2,
    'attn_scales': [1 / 1, 1 / 2, 1 / 4],
    'dropout': 0.1,
    'temporal_attention': cfg.temporal_attention,
    'temporal_attn_times': 1,
    'use_checkpoint': False,
    'use_fps_condition': False,
    'use_sim_mask': False,
    'num_tokens': 4,
    'default_fps': 8,
    'input_dim': 1024
}
cfg.guidances = []

# autoencoder from stable diffusion
cfg.auto_encoder = {
    'type': 'AutoencoderKL',
    'ddconfig': {
        'double_z': True,
        'z_channels': 4,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0
    },
    'embed_dim': 4,
    'pretrained': 'v2-1_512-ema-pruned.ckpt'
}

# clip embedder
cfg.embedder = {
    'type': 'FrozenOpenCLIPVisualEmbedder',
    'layer': 'penultimate',
    'vit_resolution': [224, 224],
    'pretrained': 'open_clip_pytorch_model.bin'
}
# -----------------------------------------------------------------------------

# ---------------------------Training Settings---------------------------------
# training and optimizer
cfg.ema_decay = 0.9999
cfg.num_steps = 600000
cfg.lr = 5e-5
cfg.weight_decay = 0.0
cfg.betas = (0.9, 0.999)
cfg.eps = 1.0e-8
cfg.chunk_size = 16
cfg.alpha = 0.7
cfg.save_ckp_interval = 1000
# -----------------------------------------------------------------------------

# ----------------------------Pretrain Settings--------------------------------
# Default: load 2d pretrain
cfg.fix_weight = False
cfg.load_match = False
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
# -----------------------------------------------------------------------------

# -----------------------------Visual------------------------------------------
# Visual videos
cfg.viz_interval = 1000
cfg.visual_train = {
    'type': 'VisualVideoTextDuringTrain',
}
cfg.visual_inference = {
    'type': 'VisualGeneratedVideos',
}
cfg.inference_list_path = ''

# logging
cfg.log_interval = 100

# Default log_dir
cfg.log_dir = 'workspace/output_data'
# -----------------------------------------------------------------------------

# ---------------------------Others---------------------------------------------
# seed
cfg.seed = 8888
# -----------------------------------------------------------------------------

View File

@@ -0,0 +1,511 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import torch
__all__ = ['GaussianDiffusion', 'beta_schedule']
def _i(tensor, t, x):
r"""Index tensor using t and format the output according to x.
"""
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
if tensor.device != x.device:
tensor = tensor.to(x.device)
return tensor[t].view(shape).to(x)
def fn(u):
    """Squared-cosine warp on [0, 1] used by the cosine beta schedule
    (Nichol & Dhariwal, 'Improved DDPM')."""
    angle = (u + 0.008) / 1.008 * math.pi / 2
    return math.cos(angle) ** 2
def beta_schedule(schedule,
                  num_timesteps=1000,
                  init_beta=None,
                  last_beta=None):
    """Return a float64 tensor of per-step diffusion betas.

    Supported schedules: 'linear' (SD-style, scaled by 1000/num_timesteps),
    'quadratic' (linear in sqrt-beta space), and 'cosine'. Raises ValueError
    for any other name.
    """
    if schedule == 'linear':
        scale = 1000.0 / num_timesteps
        start = init_beta or scale * 0.0001
        end = last_beta or scale * 0.02
        return torch.linspace(start, end, num_timesteps, dtype=torch.float64)
    if schedule == 'quadratic':
        start = init_beta or 0.0015
        end = last_beta or 0.0195
        sqrt_betas = torch.linspace(
            start**0.5, end**0.5, num_timesteps, dtype=torch.float64)
        return sqrt_betas**2
    if schedule == 'cosine':
        # beta_t = 1 - alpha_bar(t+1)/alpha_bar(t), capped at 0.999.
        betas = [
            min(1.0 - fn((step + 1) / num_timesteps)
                / fn(step / num_timesteps), 0.999)
            for step in range(num_timesteps)
        ]
        return torch.tensor(betas, dtype=torch.float64)
    raise ValueError(f'Unsupported schedule: {schedule}')
class GaussianDiffusion(object):
def __init__(self,
             betas,
             mean_type='eps',
             var_type='learned_range',
             loss_type='mse',
             epsilon=1e-12,
             rescale_timesteps=False,
             noise_strength=0.0):
    """Precompute all diffusion constants from the beta schedule.

    Args:
        betas: per-step betas (any sequence; coerced to float64), each in (0, 1].
        mean_type: model output parameterization ('x0', 'x_{t-1}', 'eps', 'v').
        var_type: variance handling ('learned', 'learned_range',
            'fixed_large', 'fixed_small').
        loss_type: training loss name (validated only; used elsewhere).
        epsilon: small constant kept for numerical use by subcl/callers.
        rescale_timesteps: if True, timesteps are rescaled to ~[0, 1000]
            before being passed to the model.
        noise_strength: per-frame offset-noise strength (see sample_loss).
    """
    # check input
    if not isinstance(betas, torch.DoubleTensor):
        betas = torch.tensor(betas, dtype=torch.float64)
    assert min(betas) > 0 and max(betas) <= 1
    assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v']
    assert var_type in [
        'learned', 'learned_range', 'fixed_large', 'fixed_small'
    ]
    assert loss_type in [
        'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1',
        'charbonnier'
    ]
    self.betas = betas
    self.num_timesteps = len(betas)
    self.mean_type = mean_type
    self.var_type = var_type
    self.loss_type = loss_type
    self.epsilon = epsilon
    self.rescale_timesteps = rescale_timesteps
    self.noise_strength = noise_strength

    # alphas and cumulative products (alpha_bar), with shifted variants
    alphas = 1 - self.betas
    self.alphas_cumprod = torch.cumprod(alphas, dim=0)
    self.alphas_cumprod_prev = torch.cat(
        [alphas.new_ones([1]), self.alphas_cumprod[:-1]])
    self.alphas_cumprod_next = torch.cat(
        [self.alphas_cumprod[1:],
         alphas.new_zeros([1])])

    # q(x_t | x_{t-1}) coefficients
    self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
    self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
                                                    - self.alphas_cumprod)
    self.log_one_minus_alphas_cumprod = torch.log(1.0
                                                  - self.alphas_cumprod)
    self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
    self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
                                                  - 1)

    # q(x_{t-1} | x_t, x_0) posterior coefficients (variance clamped before
    # log to avoid -inf at t=0)
    self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
        1.0 - self.alphas_cumprod)
    self.posterior_log_variance_clipped = torch.log(
        self.posterior_variance.clamp(1e-20))
    self.posterior_mean_coef1 = betas * torch.sqrt(
        self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
    self.posterior_mean_coef2 = (
        1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
            1.0 - self.alphas_cumprod)
def sample_loss(self, x0, noise=None):
    """Return the noise to use for diffusing ``x0``.

    Draws standard Gaussian noise when ``noise`` is None, then (if
    ``noise_strength > 0``) adds "offset noise": an extra noise term shared
    across spatial dims but distinct per batch/channel/frame.
    """
    if noise is None:
        noise = torch.randn_like(x0)
        if self.noise_strength > 0:
            b, c, f, _, _ = x0.shape
            offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device)
            noise = noise + self.noise_strength * offset_noise
    return noise
def q_sample(self, x0, t, noise=None):
    r"""Sample from q(x_t | x_0):
    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise.
    """
    # noise = torch.randn_like(x0) if noise is None else noise
    noise = self.sample_loss(x0, noise)
    return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + (
        _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise)
def q_mean_variance(self, x0, t):
    r"""Mean, variance and log-variance of q(x_t | x_0)."""
    mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
    var = _i(1.0 - self.alphas_cumprod, t, x0)
    log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
    return mu, var, log_var
def q_posterior_mean_variance(self, x0, xt, t):
    r"""Mean, variance and log-variance of the true posterior
    q(x_{t-1} | x_t, x_0), using the precomputed posterior coefficients."""
    mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
        self.posterior_mean_coef2, t, xt) * xt
    var = _i(self.posterior_variance, t, xt)
    log_var = _i(self.posterior_log_variance_clipped, t, xt)
    return mu, var, log_var
@torch.no_grad()
def p_sample(self,
             xt,
             t,
             model,
             model_kwargs={},
             clamp=None,
             percentile=None,
             condition_fn=None,
             guide_scale=None):
    r"""Sample x_{t-1} from p(x_{t-1} | x_t) for one reverse step.

    - condition_fn: for classifier-based guidance (guided-diffusion).
    - guide_scale: for classifier-free guidance (glide/dalle-2).

    Returns:
        (x_{t-1}, predicted x0).
    """
    # predict distribution of p(x_{t-1} | x_t)
    mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                                clamp, percentile,
                                                guide_scale)

    # random sample (with optional conditional function); the mask zeroes
    # the noise at t == 0 so the final step is deterministic
    noise = torch.randn_like(xt)
    mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
    if condition_fn is not None:
        # shift the mean along the classifier gradient
        grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
        mu = mu.float() + var * grad.float()
    xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
    return xt_1, x0
@torch.no_grad()
def p_sample_loop(self,
                  noise,
                  model,
                  model_kwargs={},
                  clamp=None,
                  percentile=None,
                  condition_fn=None,
                  guide_scale=None):
    r"""Full ancestral sampling: iterate p_sample from t = T-1 down to 0,
    i.e. p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).

    ``noise`` is the initial x_T; returns the final x_0 sample.
    """
    # prepare input
    b = noise.size(0)
    xt = noise

    # diffusion process, from the last timestep back to 0
    for step in torch.arange(self.num_timesteps).flip(0):
        t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
        xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
                              percentile, condition_fn, guide_scale)
    return xt
def p_mean_variance(self,
                    xt,
                    t,
                    model,
                    model_kwargs={},
                    clamp=None,
                    percentile=None,
                    guide_scale=None):
    r"""Compute the model's distribution p(x_{t-1} | x_t).

    Runs the denoiser (optionally twice for classifier-free guidance),
    derives the variance per ``self.var_type``, converts the model output
    into a predicted x0 per ``self.mean_type``, optionally restricts x0's
    range, and returns (mu, var, log_var, x0).
    """
    # predict distribution
    if guide_scale is None:
        out = model(xt, self._scale_timesteps(t), **model_kwargs)
    else:
        # classifier-free guidance
        # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
        assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
        y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
        u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
        dim = y_out.size(1) if self.var_type.startswith(
            'fixed') else y_out.size(1) // 2
        # guide only the mean channels; keep the conditional variance channels
        out = torch.cat(
            [
                u_out[:, :dim] + guide_scale *  # noqa
                (y_out[:, :dim] - u_out[:, :dim]),
                y_out[:, dim:]
            ],
            dim=1)

    # compute variance
    if self.var_type == 'learned':
        out, log_var = out.chunk(2, dim=1)
        var = torch.exp(log_var)
    elif self.var_type == 'learned_range':
        # model predicts an interpolation fraction between the clipped
        # posterior variance (min) and beta_t (max)
        out, fraction = out.chunk(2, dim=1)
        min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
        max_log_var = _i(torch.log(self.betas), t, xt)
        fraction = (fraction + 1) / 2.0
        log_var = fraction * max_log_var + (1 - fraction) * min_log_var
        var = torch.exp(log_var)
    elif self.var_type == 'fixed_large':
        var = _i(
            torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
            xt)
        log_var = torch.log(var)
    elif self.var_type == 'fixed_small':
        var = _i(self.posterior_variance, t, xt)
        log_var = _i(self.posterior_log_variance_clipped, t, xt)

    # compute mean and x0 from the model output, per parameterization
    if self.mean_type == 'x_{t-1}':
        mu = out
        x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - (
            _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
               xt) * xt)
    elif self.mean_type == 'x0':
        x0 = out
        mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
    elif self.mean_type == 'eps':
        x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
            _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out)
        mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
    elif self.mean_type == 'v':
        # v-prediction: x0 = sqrt(alpha_bar) * x_t - sqrt(1 - alpha_bar) * v
        x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - (
            _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out)
        mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)

    # restrict the range of x0: dynamic thresholding (percentile) or clamp
    if percentile is not None:
        assert percentile > 0 and percentile <= 1
        s = torch.quantile(
            x0.flatten(1).abs(), percentile,
            dim=1).clamp_(1.0).view(-1, 1, 1, 1)
        x0 = torch.min(s, torch.max(-s, x0)) / s
    elif clamp is not None:
        x0 = x0.clamp(-clamp, clamp)
    return mu, var, log_var, x0
    @torch.no_grad()
    def ddim_sample(self,
                    xt,
                    t,
                    model,
                    model_kwargs={},
                    clamp=None,
                    percentile=None,
                    condition_fn=None,
                    guide_scale=None,
                    ddim_timesteps=20,
                    eta=0.0):
        r"""Sample from p(x_{t-1} | x_t) using DDIM.

        - condition_fn: for classifier-based guidance (guided-diffusion);
          shifts the predicted eps by the classifier gradient.
        - guide_scale: for classifier-free guidance (glide/dalle-2).
        - eta: 0.0 gives the deterministic DDIM update; larger values blend
          in ancestral noise through the sigma term.

        Returns:
            (x_{t-1}, predicted x0).
        """
        stride = self.num_timesteps // ddim_timesteps
        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)
        if condition_fn is not None:
            # x0 -> eps
            alpha = _i(self.alphas_cumprod, t, xt)
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt))
            # shift eps along the classifier gradient (guided-diffusion)
            eps = eps - (1 - alpha).sqrt() * condition_fn(
                xt, self._scale_timesteps(t), **model_kwargs)
            # eps -> x0
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps)
        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
            _i(self.sqrt_recipm1_alphas_cumprod, t, xt))
        alphas = _i(self.alphas_cumprod, t, xt)
        alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
        sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) *  # noqa
                                  (1 - alphas / alphas_prev))
        # random sample; mask suppresses noise at the final step (t == 0)
        noise = torch.randn_like(xt)
        direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
        mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
        xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
        return xt_1, x0
@torch.no_grad()
def ddim_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
ddim_timesteps=20,
eta=0.0):
# prepare input
b = noise.size(0)
xt = noise
# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
steps = (1 + torch.arange(0, self.num_timesteps,
self.num_timesteps // ddim_timesteps)).clamp(
0, self.num_timesteps - 1).flip(0)
for step in steps:
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn, guide_scale,
ddim_timesteps, eta)
return xt
    @torch.no_grad()
    def ddim_reverse_sample(self,
                            xt,
                            t,
                            model,
                            model_kwargs={},
                            clamp=None,
                            percentile=None,
                            guide_scale=None,
                            ddim_timesteps=20):
        r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).

        One encoding step: predicts x0/eps at timestep ``t`` and recombines
        them at the more-noisy timestep ``t + stride``. Used to invert data
        into the model's noise space.

        Returns:
            (x_{t+stride}, predicted x0).
        """
        stride = self.num_timesteps // ddim_timesteps
        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)
        # derive variables: recover eps from xt and the x0 prediction
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
            _i(self.sqrt_recipm1_alphas_cumprod, t, xt))
        # alphas_cumprod padded with a trailing zero so that t + stride may
        # index one past the final timestep
        alphas_next = _i(
            torch.cat(
                [self.alphas_cumprod,
                 self.alphas_cumprod.new_zeros([1])]),
            (t + stride).clamp(0, self.num_timesteps), xt)
        # reverse sample
        mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
        return mu, x0
@torch.no_grad()
def ddim_reverse_sample_loop(self,
x0,
model,
model_kwargs={},
clamp=None,
percentile=None,
guide_scale=None,
ddim_timesteps=20):
# prepare input
b = x0.size(0)
xt = x0
# reconstruction steps
steps = torch.arange(0, self.num_timesteps,
self.num_timesteps // ddim_timesteps)
for step in steps:
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
percentile, guide_scale,
ddim_timesteps)
return xt
@torch.no_grad()
def plms_sample(self,
xt,
t,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
plms_timesteps=20):
r"""Sample from p(x_{t-1} | x_t) using PLMS.
- condition_fn: for classifier-based guidance (guided-diffusion).
- guide_scale: for classifier-free guidance (glide/dalle-2).
"""
stride = self.num_timesteps // plms_timesteps
# function for compute eps
def compute_eps(xt, t):
# predict distribution of p(x_{t-1} | x_t)
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
clamp, percentile, guide_scale)
# condition
if condition_fn is not None:
# x0 -> eps
alpha = _i(self.alphas_cumprod, t, xt)
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
eps = eps - (1 - alpha).sqrt() * condition_fn(
xt, self._scale_timesteps(t), **model_kwargs)
# eps -> x0
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps)
# derive eps
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
return eps
# function for compute x_0 and x_{t-1}
def compute_x0(eps, t):
# eps -> x0
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps)
# deterministic sample
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
direction = torch.sqrt(1 - alphas_prev) * eps
xt_1 = torch.sqrt(alphas_prev) * x0 + direction
return xt_1, x0
# PLMS sample
eps = compute_eps(xt, t)
if len(eps_cache) == 0:
# 2nd order pseudo improved Euler
xt_1, x0 = compute_x0(eps, t)
eps_next = compute_eps(xt_1, (t - stride).clamp(0))
eps_prime = (eps + eps_next) / 2.0
elif len(eps_cache) == 1:
# 2nd order pseudo linear multistep (Adams-Bashforth)
eps_prime = (3 * eps - eps_cache[-1]) / 2.0
elif len(eps_cache) == 2:
# 3nd order pseudo linear multistep (Adams-Bashforth)
eps_prime = (23 * eps - 16 * eps_cache[-1]
+ 5 * eps_cache[-2]) / 12.0
elif len(eps_cache) >= 3:
# 4nd order pseudo linear multistep (Adams-Bashforth)
eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
- 9 * eps_cache[-3]) / 24.0
xt_1, x0 = compute_x0(eps_prime, t)
return xt_1, x0, eps
@torch.no_grad()
def plms_sample_loop(self,
noise,
model,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
plms_timesteps=20):
# prepare input
b = noise.size(0)
xt = noise
# diffusion process
steps = (1 + torch.arange(0, self.num_timesteps,
self.num_timesteps // plms_timesteps)).clamp(
0, self.num_timesteps - 1).flip(0)
eps_cache = []
for step in steps:
# PLMS sampling step
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
percentile, condition_fn,
guide_scale, plms_timesteps,
eps_cache)
# update eps cache
eps_cache.append(eps)
if len(eps_cache) >= 4:
eps_cache.pop(0)
return xt
def _scale_timesteps(self, t):
if self.rescale_timesteps:
return t.float() * 1000.0 / self.num_timesteps
return t

View File

@@ -0,0 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
import numpy as np
import torch
def setup_seed(seed):
    """Seed every RNG used here (python, numpy, torch CPU and all GPUs) and
    force deterministic cuDNN kernels, for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

View File

@@ -0,0 +1,60 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import torch
def fn(u):
    """Squared-cosine kernel used by the cosine beta schedule:
    cos^2(((u + 0.008) / 1.008) * pi / 2)."""
    angle = (u + 0.008) / 1.008 * math.pi / 2
    return math.cos(angle)**2
def beta_schedule(schedule,
                  num_timesteps=1000,
                  init_beta=None,
                  last_beta=None):
    """Generate a sequence of beta values for a diffusion process.

    Args:
        schedule (str): Type of schedule: 'linear', 'linear_sd',
            'quadratic' or 'cosine'.
        num_timesteps (int, optional): Number of timesteps. Default 1000.
        init_beta (float, optional): First beta value; falls back to a
            schedule-specific default when omitted.
        last_beta (float, optional): Last beta value; falls back to a
            schedule-specific default when omitted.

    Returns:
        torch.Tensor: 1-D float64 tensor of ``num_timesteps`` betas.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
    """
    if schedule == 'linear':
        # scale keeps the integrated noise level comparable across step counts
        scale = 1000.0 / num_timesteps
        init_beta = init_beta or scale * 0.0001
        last_beta = last_beta or scale * 0.02
        return torch.linspace(
            init_beta, last_beta, num_timesteps, dtype=torch.float64)
    elif schedule == 'linear_sd':
        # Linear in sqrt(beta) space. Defaults follow the Stable Diffusion
        # convention (0.00085 / 0.012); previously this branch crashed with a
        # TypeError when init_beta/last_beta were omitted, contradicting the
        # documented behavior.
        init_beta = init_beta or 0.00085
        last_beta = last_beta or 0.012
        return torch.linspace(
            init_beta**0.5, last_beta**0.5, num_timesteps,
            dtype=torch.float64)**2
    elif schedule == 'quadratic':
        init_beta = init_beta or 0.0015
        last_beta = last_beta or 0.0195
        return torch.linspace(
            init_beta**0.5, last_beta**0.5, num_timesteps,
            dtype=torch.float64)**2
    elif schedule == 'cosine':
        # squared-cosine schedule; betas capped at 0.999 for stability
        betas = []
        for step in range(num_timesteps):
            t1 = step / num_timesteps
            t2 = (step + 1) / num_timesteps
            betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
        return torch.tensor(betas, dtype=torch.float64)
    else:
        raise ValueError(f'Unsupported schedule: {schedule}')

View File

@@ -0,0 +1,404 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import random
import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image, ImageFilter
# Public API of this transforms module; star imports pick up exactly these.
__all__ = [
    'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2',
    'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',
    'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize',
    'ResizeRandomCrop', 'ExtractResizeRandomCrop', 'ExtractResizeAssignCrop'
]
class Compose(object):
    """Sequentially apply a list of transforms to a frame (or frame list).

    Supports indexing and slicing so sub-pipelines can be extracted.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __getitem__(self, index):
        # a slice yields a new Compose over the sub-list;
        # an int yields the single transform
        if isinstance(index, slice):
            return Compose(self.transforms[index])
        return self.transforms[index]

    def __len__(self):
        return len(self.transforms)

    def __call__(self, rgb):
        for transform in self.transforms:
            rgb = transform(rgb)
        return rgb
class Resize(object):
    """Bilinearly resize a PIL image (or list of images) to a fixed size."""

    def __init__(self, size=256):
        # an int means a square target
        self.size = (size, size) if isinstance(size, int) else size

    def __call__(self, rgb):
        if isinstance(rgb, list):
            return [frame.resize(self.size, Image.BILINEAR) for frame in rgb]
        return rgb.resize(self.size, Image.BILINEAR)
class Rescale(object):
    """Scale a list of PIL frames so the shorter side equals ``size``."""

    def __init__(self, size=256, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, rgb):
        width, height = rgb[0].size
        ratio = self.size / min(width, height)
        target = (int(round(width * ratio)), int(round(height * ratio)))
        return [frame.resize(target, self.interpolation) for frame in rgb]
class CenterCrop(object):
    """Crop a centered ``size`` x ``size`` square out of every frame.

    Frames must already be at least ``size`` on their shorter side.
    """

    def __init__(self, size=224):
        self.size = size

    def __call__(self, rgb):
        width, height = rgb[0].size
        assert min(width, height) >= self.size
        left = (width - self.size) // 2
        top = (height - self.size) // 2
        box = (left, top, left + self.size, top + self.size)
        return [frame.crop(box) for frame in rgb]
class ResizeRandomCrop(object):
    """Resize frames so the short side equals ``size_short``, then take one
    random ``size`` x ``size`` crop shared by all frames."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):
        # cheap box-filter halving while the frames are far too large
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in rgb
            ]
        # bicubic resize down to the target short side
        ratio = self.size_short / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(ratio * frame.width), round(ratio * frame.height)),
                resample=Image.BICUBIC) for frame in rgb
        ]
        # one crop window applied consistently to every frame
        width, height = rgb[0].size
        left = random.randint(0, width - self.size)
        top = random.randint(0, height - self.size)
        box = (left, top, left + self.size, top + self.size)
        return [frame.crop(box) for frame in rgb]
class ExtractResizeRandomCrop(object):
    """Like ResizeRandomCrop, but also returns the crop box so the same
    window can be re-applied to other streams."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):
        # cheap box-filter halving while the frames are far too large
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in rgb
            ]
        ratio = self.size_short / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(ratio * frame.width), round(ratio * frame.height)),
                resample=Image.BICUBIC) for frame in rgb
        ]
        width, height = rgb[0].size
        left = random.randint(0, width - self.size)
        top = random.randint(0, height - self.size)
        box = [left, top, left + self.size, top + self.size]
        rgb = [frame.crop(tuple(box)) for frame in rgb]
        return rgb, box
class ExtractResizeAssignCrop(object):
    """Apply a crop box produced elsewhere (e.g. by ExtractResizeRandomCrop)
    after the same short-side resize, then rescale to ``size`` square."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb, wh):
        # cheap box-filter halving while the frames are far too large
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in rgb
            ]
        ratio = self.size_short / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(ratio * frame.width), round(ratio * frame.height)),
                resample=Image.BICUBIC) for frame in rgb
        ]
        # apply the shared crop box, then rescale to the final square size
        rgb = [frame.crop(wh) for frame in rgb]
        return [
            frame.resize((self.size, self.size), Image.BILINEAR)
            for frame in rgb
        ]
class CenterCropV2(object):
    """Short-side resize (box halving + bicubic) followed by a center crop."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        # fast halving with a box filter while far larger than needed
        while min(img[0].size) >= 2 * self.size:
            img = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in img
            ]
        ratio = self.size / min(img[0].size)
        img = [
            frame.resize(
                (round(ratio * frame.width), round(ratio * frame.height)),
                resample=Image.BICUBIC) for frame in img
        ]
        # center crop
        left = (img[0].width - self.size) // 2
        top = (img[0].height - self.size) // 2
        box = (left, top, left + self.size, top + self.size)
        return [frame.crop(box) for frame in img]
class CenterCropWide(object):
    """Center crop to a rectangular ``size`` = (width, height), scaling the
    frame first so the target fits inside it. Accepts one image or a list."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        if isinstance(img, list):
            scale = min(img[0].size[0] / self.size[0],
                        img[0].size[1] / self.size[1])
            img = [
                frame.resize((round(frame.width // scale),
                              round(frame.height // scale)),
                             resample=Image.BOX) for frame in img
            ]
            # center crop
            left = (img[0].width - self.size[0]) // 2
            top = (img[0].height - self.size[1]) // 2
            box = (left, top, left + self.size[0], top + self.size[1])
            return [frame.crop(box) for frame in img]
        scale = min(img.size[0] / self.size[0], img.size[1] / self.size[1])
        img = img.resize(
            (round(img.width // scale), round(img.height // scale)),
            resample=Image.BOX)
        left = (img.width - self.size[0]) // 2
        top = (img.height - self.size[1]) // 2
        return img.crop((left, top, left + self.size[0], top + self.size[1]))
class RandomCrop(object):
    """Random area/aspect crop (at least ``min_area`` of the frame) shared
    across all frames, then resized to a ``size`` x ``size`` square."""

    def __init__(self, size=224, min_area=0.4):
        self.size = size
        self.min_area = min_area

    def __call__(self, rgb):
        width, height = rgb[0].size
        area = width * height
        # rejection-sample a window that fits inside the frame
        out_w, out_h = float('inf'), float('inf')
        while out_w > width or out_h > height:
            target_area = random.uniform(self.min_area, 1.0) * area
            aspect_ratio = random.uniform(3. / 4., 4. / 3.)
            out_w = int(round(math.sqrt(target_area * aspect_ratio)))
            out_h = int(round(math.sqrt(target_area / aspect_ratio)))
        left = random.randint(0, width - out_w)
        top = random.randint(0, height - out_h)
        rgb = [
            frame.crop((left, top, left + out_w, top + out_h))
            for frame in rgb
        ]
        return [
            frame.resize((self.size, self.size), Image.BILINEAR)
            for frame in rgb
        ]
class RandomCropV2(object):
    """torchvision-style random resized crop over a list of frames: sample an
    area/aspect window (10 tries, then a central fallback) and resized-crop
    every frame with the same window."""

    def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
        self.size = size if isinstance(size, (tuple, list)) else (size, size)
        self.min_area = min_area
        self.ratio = ratio

    def _get_params(self, img):
        width, height = img.size
        area = height * width
        for _ in range(10):
            target_area = random.uniform(self.min_area, 1.0) * area
            log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w
        # fallback: central crop at the nearest admissible aspect ratio
        in_ratio = float(width) / float(height)
        if in_ratio < min(self.ratio):
            w = width
            h = int(round(w / min(self.ratio)))
        elif in_ratio > max(self.ratio):
            h = height
            w = int(round(h * max(self.ratio)))
        else:
            w, h = width, height
        return (height - h) // 2, (width - w) // 2, h, w

    def __call__(self, rgb):
        i, j, h, w = self._get_params(rgb[0])
        return [F.resized_crop(frame, i, j, h, w, self.size) for frame in rgb]
class RandomHFlip(object):
    """Horizontally flip all frames together with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        return [frame.transpose(Image.FLIP_LEFT_RIGHT) for frame in rgb]
class GaussianBlur(object):
    """With probability ``p``, blur all frames with one sigma drawn uniformly
    from the ``sigmas`` range."""

    def __init__(self, sigmas=[0.1, 2.0], p=0.5):
        self.sigmas = sigmas
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        radius = random.uniform(*self.sigmas)
        return [
            frame.filter(ImageFilter.GaussianBlur(radius=radius))
            for frame in rgb
        ]
class ColorJitter(object):
    """With probability ``p``, jitter brightness/contrast/saturation/hue of
    all frames using shared factors applied in a random order."""

    def __init__(self,
                 brightness=0.4,
                 contrast=0.4,
                 saturation=0.4,
                 hue=0.1,
                 p=0.5):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue
        self.p = p

    def _random_params(self):
        # multiplicative factors around 1.0, additive hue shift around 0.0
        brightness = random.uniform(
            max(0, 1 - self.brightness), 1 + self.brightness)
        contrast = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
        saturation = random.uniform(
            max(0, 1 - self.saturation), 1 + self.saturation)
        hue = random.uniform(-self.hue, self.hue)
        return brightness, contrast, saturation, hue

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        brightness, contrast, saturation, hue = self._random_params()
        ops = [
            lambda f: F.adjust_brightness(f, brightness),
            lambda f: F.adjust_contrast(f, contrast),
            lambda f: F.adjust_saturation(f, saturation),
            lambda f: F.adjust_hue(f, hue)
        ]
        random.shuffle(ops)
        for op in ops:
            rgb = [op(frame) for frame in rgb]
        return rgb
class RandomGray(object):
    """With probability ``p``, convert all frames to grayscale (kept as
    3-channel RGB so downstream shapes are unchanged)."""

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        return [frame.convert('L').convert('RGB') for frame in rgb]
class ToTensor(object):
    """Convert a PIL image to a float tensor in [0, 1]; a list of images is
    stacked along a new leading (frame) dimension."""

    def __call__(self, rgb):
        if isinstance(rgb, list):
            return torch.stack([F.to_tensor(frame) for frame in rgb], dim=0)
        return F.to_tensor(rgb)
class Normalize(object):
    """Clamp a tensor to [0, 1] and apply channelwise (x - mean) / std.

    Works on (C, H, W) or (N/T, C, H, W) tensors; other ranks are clamped but
    not normalized. mean/std are converted to tensors lazily on first call.
    """

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, rgb):
        rgb = rgb.clone()
        rgb.clamp_(0, 1)
        # lazily cache mean/std as tensors matching the input's dtype/device
        if not isinstance(self.mean, torch.Tensor):
            self.mean = rgb.new_tensor(self.mean).view(-1)
        if not isinstance(self.std, torch.Tensor):
            self.std = rgb.new_tensor(self.std).view(-1)
        if rgb.dim() == 4:
            mean = self.mean.view(1, -1, 1, 1)
            std = self.std.view(1, -1, 1, 1)
            rgb.sub_(mean).div_(std)
        elif rgb.dim() == 3:
            rgb.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1))
        return rgb

View File

@@ -39,7 +39,7 @@ class StableDiffusion(TorchModel):
self.lora_tune = kwargs.pop('lora_tune', False)
self.dreambooth_tune = kwargs.pop('dreambooth_tune', False)
self.weight_dtype = torch.float32
self.weight_dtype = kwargs.pop('torch_type', torch.float32)
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
@@ -59,14 +59,15 @@ class StableDiffusion(TorchModel):
# Freeze gradient calculation and move to device
if self.vae is not None:
self.vae.requires_grad_(False)
self.vae = self.vae.to(self.device)
self.vae = self.vae.to(self.device, dtype=self.weight_dtype)
if self.text_encoder is not None:
self.text_encoder.requires_grad_(False)
self.text_encoder = self.text_encoder.to(self.device)
self.text_encoder = self.text_encoder.to(
self.device, dtype=self.weight_dtype)
if self.unet is not None:
if self.lora_tune:
self.unet.requires_grad_(False)
self.unet = self.unet.to(self.device)
self.unet = self.unet.to(self.device, dtype=self.weight_dtype)
# xformers accelerate memory efficient attention
if xformers_enable:

View File

@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Lazy-import shim: exposes VideoToVideo without importing the heavy
# submodule at package-import time.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    # static type checkers see the real import
    from .video_to_video_model import VideoToVideo
else:
    # at runtime, replace this module with a LazyImportModule that resolves
    # the mapping below on first attribute access
    _import_structure = {
        'video_to_video_model': ['VideoToVideo'],
    }
    import sys
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Flatten the module namespace: re-export the public names of the
# autoencoder, embedder and video-to-video UNet submodules.
from .autoencoder import *
from .embedder import *
from .unet_v2v import *

View File

@@ -0,0 +1,590 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import collections
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.utils.logger import get_logger
# module-level logger shared by the classes below
logger = get_logger()
def nonlinearity(x):
    """Swish activation, x * sigmoid(x) (a.k.a. SiLU)."""
    return torch.sigmoid(x) * x
def Normalize(in_channels, num_groups=32):
    """Build the GroupNorm layer used throughout the autoencoder
    (eps=1e-6, affine parameters enabled)."""
    return torch.nn.GroupNorm(
        num_channels=in_channels,
        num_groups=num_groups,
        eps=1e-6,
        affine=True)
@torch.no_grad()
def get_first_stage_encoding(encoder_posterior):
    """Turn an encoder output into a scaled latent.

    Accepts either a DiagonalGaussianDistribution (sampled) or an already
    materialized tensor, scaled by the fixed latent factor 0.18215.
    """
    # fixed Stable-Diffusion latent scaling factor
    scale_factor = 0.18215
    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
        return scale_factor * encoder_posterior.sample()
    if isinstance(encoder_posterior, torch.Tensor):
        return scale_factor * encoder_posterior
    raise NotImplementedError(
        f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
    )
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents, parameterized by a tensor whose
    channel dimension concatenates [mean, logvar].

    ``deterministic=True`` collapses the distribution onto its mean (zero
    variance): ``sample`` returns the mean and KL/NLL vanish.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # clamp for numerical stability before exponentiation
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean).to(device=self.parameters.device)

    def sample(self):
        """Reparameterized draw: mean + std * eps."""
        eps = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * eps

    def kl(self, other=None):
        """KL divergence to ``other`` (or to the standard normal), summed
        over the non-batch dims."""
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        """Negative log-likelihood of ``sample`` under this Gaussian."""
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar
            + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Most likely value (the mean)."""
        return self.mean
class ResnetBlock(nn.Module):
    """Pre-activation residual block: two GroupNorm -> swish -> conv stages,
    an optional timestep-embedding injection between them, and a learned
    shortcut when the channel count changes."""

    def __init__(self,
                 *,
                 in_channels,
                 out_channels=None,
                 conv_shortcut=False,
                 dropout,
                 temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        # default to channel-preserving when out_channels is omitted
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # temb_channels == 0 disables the timestep-embedding projection
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            # match channels on the skip path: 3x3 conv or 1x1 projection
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0)

    def forward(self, x, temb):
        # residual branch: norm -> swish -> conv, twice, with the optional
        # (projected) timestep embedding added in between
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over a feature map, using 1x1-conv
    q/k/v projections and a residual connection."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        # compute attention over the h*w flattened spatial positions
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        w_ = torch.bmm(q, k)
        # scale by 1/sqrt(c) before the softmax
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)
        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)
        h_ = self.proj_out(h_)
        return x + h_
class Upsample(nn.Module):
    """2x nearest-neighbor upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        out = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode='nearest')
        if not self.with_conv:
            return out
        return self.conv(out)
class Downsample(nn.Module):
    """2x downsampling: strided 3x3 conv (with manual right/bottom padding)
    when ``with_conv``, average pooling otherwise."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # torch convs only support symmetric padding, so we pad manually
            # (right/bottom only) in forward and use padding=0 here
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        padded = torch.nn.functional.pad(
            x, (0, 1, 0, 1), mode='constant', value=0)
        return self.conv(padded)
class Encoder(nn.Module):
    """Convolutional VAE encoder.

    Downsamples the input through ``len(ch_mult)`` resolution levels of
    ResNet blocks (with attention at resolutions listed in
    ``attn_resolutions``), applies a ResNet-attention-ResNet middle, and
    projects to ``2 * z_channels`` channels when ``double_z`` (mean and
    logvar of the latent posterior).
    """

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 double_z=True,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignore_kwargs):
        # NOTE(review): use_linear_attn / attn_type are accepted for config
        # compatibility but not used by this implementation.
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        curr_res = resolution
        # channel multiplier of the *input* to each level (prepend 1)
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            # every level except the last halves the spatial resolution
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        # end: project to 2*z_channels (mean, logvar) when double_z
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        # timestep embedding (unused: temb_ch == 0)
        temb = None
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
class Decoder(nn.Module):
    """Convolutional VAE decoder, mirroring ``Encoder``.

    Maps a latent of ``z_channels`` channels through a ResNet-attention-
    ResNet middle and ``len(ch_mult)`` upsampling levels back to ``out_ch``
    image channels. ``give_pre_end`` returns the features before the final
    norm/conv; ``tanh_out`` squashes the output to (-1, 1).
    """

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 give_pre_end=False,
                 tanh_out=False,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignorekwargs):
        # NOTE(review): use_linear_attn / attn_type are accepted for config
        # compatibility but not used by this implementation.
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep conditioning in the autoencoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        # compute block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logger.info('Working with z of shape {} = {} dimensions.'.format(
            self.z_shape, np.prod(self.z_shape)))
        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        # upsampling (built highest-level-first, inserted at index 0 so that
        # self.up is ordered low-to-high like the encoder's down path)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            # one extra ResNet block per level compared to the encoder
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        self.last_z_shape = z.shape
        # timestep embedding (unused: temb_ch == 0)
        temb = None
        # z to block_in
        h = self.conv_in(z)
        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
        # end
        if self.give_pre_end:
            return h
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
class AutoencoderKL(nn.Module):
    """KL-regularized image autoencoder (Stable Diffusion first-stage VAE).

    Encodes images to a diagonal Gaussian over latents and decodes latent
    samples back to image space.

    Args:
        ddconfig: encoder/decoder hyperparameters; must set ``double_z``.
        embed_dim: latent channel count after the quant conv.
        pretrained: optional checkpoint path, loaded via ``init_from_ckpt``.
        ignore_keys: key prefixes to skip when loading the checkpoint.
        image_key: batch-dict key holding the input images.
        colorize_nlabels: if set, number of segmentation classes for ``to_rgb``.
        monitor: optional metric name tracked externally.
        ema_decay: if not None, marks that EMA weights are in use.
        learn_logvar: stored flag only; no log-variance parameter is created.
    """

    def __init__(self,
                 ddconfig,
                 embed_dim,
                 pretrained=None,
                 ignore_keys=(),
                 image_key='image',
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False,
                 **kwargs):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # The encoder must emit both mean and log-variance, hence double_z.
        assert ddconfig['double_z']
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
                                          2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig['z_channels'], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            # Fixed random projection used by to_rgb for visualization.
            self.register_buffer('colorize',
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        # NOTE(review): use_ema is tracked but no EMA module is created here;
        # on_train_batch_end/log_images assume `model_ema`/`ema_scope` exist —
        # confirm a subclass provides them before enabling ema_decay.
        self.use_ema = ema_decay is not None
        if pretrained is not None:
            self.init_from_ckpt(pretrained, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load first-stage weights from a full LDM checkpoint.

        Keeps only keys under ``first_stage_model.`` and drops any whose
        stripped name starts with a prefix in `ignore_keys`. (The original
        accepted `ignore_keys` but never applied it — fixed here.)
        """
        sd = torch.load(path, map_location='cpu')['state_dict']
        sd_new = collections.OrderedDict()
        for k in list(sd.keys()):
            if 'first_stage_model' not in k:
                continue
            k_new = k.split('first_stage_model.')[-1]
            if any(k_new.startswith(prefix) for prefix in ignore_keys):
                logger.info('Deleting key %s from state_dict.', k)
                continue
            sd_new[k_new] = sd[k]
        self.load_state_dict(sd_new, strict=True)
        logger.info(f'Restored from {path}')

    def on_train_batch_end(self, *args, **kwargs):
        """Update EMA weights after each training batch (if enabled)."""
        # NOTE(review): `model_ema` is never created in this class — this
        # path requires a subclass/mixin that defines it.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode images to a DiagonalGaussianDistribution over latents."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent sample back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Autoencode `input`; returns (reconstruction, posterior).

        Args:
            sample_posterior: sample the posterior if True, else use its mode.
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch images from `batch[k]` as float NCHW (adds a channel dim
        to 3-D inputs, then permutes NHWC -> NCHW)."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1,
                      2).to(memory_format=torch.contiguous_format).float()
        return x

    def get_last_layer(self):
        """Weight of the final decoder conv (used for adaptive loss weights)."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        """Return a dict of inputs, reconstructions and prior samples for logging."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        # Bug fix: plain nn.Module has no `.device` attribute; derive the
        # device from the module's parameters instead.
        x = x.to(next(self.parameters()).device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log['samples'] = self.decode(torch.randn_like(posterior.sample()))
            log['reconstructions'] = xrec
            if log_ema or self.use_ema:
                # NOTE(review): `ema_scope` is not defined on this class;
                # this branch requires a subclass/mixin that provides it.
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log['samples_ema'] = self.decode(
                        torch.randn_like(posterior_ema.sample()))
                    log['reconstructions_ema'] = xrec_ema
        log['inputs'] = x
        return log

    def to_rgb(self, x):
        """Project a segmentation map to RGB via a fixed random projection,
        normalized to [-1, 1]."""
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize',
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: passes data through unchanged.

    Drop-in stand-in for a VAE/VQ first stage when no latent compression is
    wanted. With ``vq_interface=True``, ``quantize`` mimics a VQ layer's
    (quantized, loss, info) return shape.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Bug fix: initialize nn.Module FIRST. nn.Module overrides
        # __setattr__ and expects its bookkeeping dicts to exist before any
        # attribute is assigned; setting attributes pre-init is fragile.
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        """Identity encode."""
        return x

    def decode(self, x, *args, **kwargs):
        """Identity decode."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Identity quantize; mirrors a VQ layer's tuple when vq_interface is set."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Identity forward."""
        return x

View File

@@ -0,0 +1,76 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import numpy as np
import open_clip
import torch
import torch.nn as nn
import torchvision.transforms as T
class FrozenOpenCLIPEmbedder(nn.Module):
    """
    Uses the OpenCLIP transformer encoder for text.

    The visual tower is deleted after loading; with ``freeze=True`` (default)
    all weights are frozen. ``layer`` selects which transformer output to
    return: 'last' or 'penultimate'.
    """
    LAYERS = ['last', 'penultimate']

    def __init__(self,
                 pretrained,
                 arch='ViT-H-14',
                 device='cuda',
                 max_length=77,
                 freeze=True,
                 layer='penultimate'):
        super().__init__()
        assert layer in self.LAYERS
        # Build on CPU; the (unused) preprocess transform is discarded.
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=pretrained)
        # Text-only usage: drop the vision tower to save memory.
        del model.visual
        self.model = model
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        # layer_idx = number of final resblocks to skip.
        if self.layer == 'last':
            self.layer_idx = 0
        elif self.layer == 'penultimate':
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Switch to eval mode and disable gradients for all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize `text` and return per-token transformer features."""
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        """Run token ids through embedding, transformer and final LayerNorm."""
        x = self.model.token_embedding(text)
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # transformer expects (seq, batch, dim)
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # back to (batch, seq, dim)
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run the transformer resblocks, stopping `layer_idx` blocks early."""
        # Bug fix: the original referenced an undefined name `checkpoint`,
        # raising NameError whenever grad checkpointing was enabled.
        from torch.utils.checkpoint import checkpoint
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
            ):
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        """Alias for forward()."""
        return self(text)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

View File

@@ -0,0 +1,171 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os
import os.path as osp
from datetime import datetime
import torch
from easydict import EasyDict
# Flat EasyDict configuration for the VideoLDM decoder pipeline.
cfg = EasyDict(__name__='Config: VideoLDM Decoder')

# ---------------------------work dir--------------------------
cfg.work_dir = 'workspace/'

# ---------------------------Global Variable-----------------------------------
# Output resolution as [width, height]; maximum frames per video clip.
cfg.resolution = [448, 256]
cfg.max_frames = 32
# -----------------------------------------------------------------------------

# ---------------------------Dataset Parameter---------------------------------
cfg.mean = [0.5, 0.5, 0.5]
cfg.std = [0.5, 0.5, 0.5]
cfg.max_words = 1000

# PlaceHolder
cfg.vit_out_dim = 1024
cfg.vit_resolution = [224, 224]
cfg.depth_clamp = 10.0
cfg.misc_size = 384
cfg.depth_std = 20.0

cfg.frame_lens = 32
cfg.sample_fps = 8
cfg.batch_sizes = 1
# -----------------------------------------------------------------------------

# ---------------------------Mode Parameters-----------------------------------
# Diffusion
cfg.schedule = 'cosine'
cfg.num_timesteps = 1000
# 'v'-prediction parameterization; variance is fixed, loss is plain MSE.
cfg.mean_type = 'v'
cfg.var_type = 'fixed_small'
cfg.loss_type = 'mse'
cfg.ddim_timesteps = 50
cfg.ddim_eta = 0.0
cfg.clamp = 1.0
cfg.share_noise = False
cfg.use_div_loss = False
cfg.noise_strength = 0.1

# classifier-free guidance: p_zero is the condition-dropout probability.
cfg.p_zero = 0.1
cfg.guide_scale = 3.0

# clip vision encoder (OpenAI CLIP normalization statistics)
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]

# Model
cfg.scale_factor = 0.18215
cfg.use_fp16 = True
cfg.temporal_attention = True
cfg.decoder_bs = 8

cfg.UNet = {
    'type': 'Vid2VidSDUNet',
    'in_dim': 4,
    'dim': 320,
    'y_dim': cfg.vit_out_dim,
    'context_dim': 1024,
    'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
    'dim_mult': [1, 2, 4, 4],
    'num_heads': 8,
    'head_dim': 64,
    'num_res_blocks': 2,
    'attn_scales': [1 / 1, 1 / 2, 1 / 4],
    'dropout': 0.1,
    'temporal_attention': cfg.temporal_attention,
    'temporal_attn_times': 1,
    'use_checkpoint': False,
    'use_fps_condition': False,
    'use_sim_mask': False,
    'num_tokens': 4,
    'default_fps': 8,
    'input_dim': 1024
}
cfg.guidances = []

# autoencoder from stable diffusion
cfg.auto_encoder = {
    'type': 'AutoencoderKL',
    'ddconfig': {
        'double_z': True,
        'z_channels': 4,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0
    },
    'embed_dim': 4,
    'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
}

# clip embedder
cfg.embedder = {
    'type': 'FrozenOpenCLIPEmbedder',
    'layer': 'penultimate',
    'vit_resolution': [224, 224],
    'pretrained': 'open_clip_pytorch_model.bin'
}
# -----------------------------------------------------------------------------

# ---------------------------Training Settings---------------------------------
# training and optimizer
cfg.ema_decay = 0.9999
cfg.num_steps = 600000
cfg.lr = 5e-5
cfg.weight_decay = 0.0
cfg.betas = (0.9, 0.999)
cfg.eps = 1.0e-8
cfg.chunk_size = 16
cfg.alpha = 0.7
cfg.save_ckp_interval = 1000
# -----------------------------------------------------------------------------

# ----------------------------Pretrain Settings---------------------------------
# Default: load 2d pretrain
cfg.fix_weight = False
cfg.load_match = False
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
# -----------------------------------------------------------------------------

# -----------------------------Visual-------------------------------------------
# Visual videos
cfg.viz_interval = 1000
cfg.visual_train = {
    'type': 'VisualVideoTextDuringTrain',
}
cfg.visual_inference = {
    'type': 'VisualGeneratedVideos',
}
cfg.inference_list_path = ''

# logging
cfg.log_interval = 100

# Default log_dir
cfg.log_dir = 'workspace/output_data'
# -----------------------------------------------------------------------------

# ---------------------------Others--------------------------------------------
# seed
cfg.seed = 8888

cfg.negative_prompt = 'worst quality, normal quality, low quality, low res, blurry, text, \
watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, \
sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting'
cfg.positive_prompt = ', cinematic, High Contrast, highly detailed, unreal engine, \
taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, \
32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, \
hyper sharpness, perfect without deformations, Unreal Engine 5, 4k render'
# -----------------------------------------------------------------------------

View File

@@ -0,0 +1,247 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
import torch
from .schedules_sdedit import karras_schedule
from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun
__all__ = ['GaussianDiffusion_SDEdit']
def _i(tensor, t, x):
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
return tensor[t.to(tensor.device)].view(shape).to(x.device)
class GaussianDiffusion_SDEdit(object):
    """Variance-preserving Gaussian diffusion used for SDEdit-style sampling.

    `sigmas` is a 1-D tensor of noise levels for t = 0..T-1; alphas are derived
    so that alpha_t^2 + sigma_t^2 = 1. `prediction_type` selects what the
    network predicts: 'x0', 'eps' (noise) or 'v' (velocity).
    """

    def __init__(self, sigmas, prediction_type='eps'):
        assert prediction_type in {'x0', 'eps', 'v'}
        self.sigmas = sigmas
        # Variance-preserving parameterization: alpha = sqrt(1 - sigma^2).
        self.alphas = torch.sqrt(1 - sigmas**2)
        self.num_timesteps = len(sigmas)
        self.prediction_type = prediction_type

    def diffuse(self, x0, t, noise=None):
        """Sample q(x_t | x_0): x_t = alpha_t * x_0 + sigma_t * noise."""
        noise = torch.randn_like(x0) if noise is None else noise
        xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
        return xt

    def denoise(self,
                xt,
                t,
                s,
                model,
                model_kwargs={},
                guide_scale=None,
                guide_rescale=None,
                clamp=None,
                percentile=None):
        """Run the model at timestep t and derive posterior statistics.

        Returns (mu, var, log_var, x0, eps) for the transition t -> s
        (s defaults to t - 1). Supports classifier-free guidance with
        optional guidance rescaling, and x0 range restriction via either
        `clamp` or a dynamic `percentile` threshold.
        """
        s = t - 1 if s is None else s

        # hyperparams
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        alphas_s = _i(self.alphas, s.clamp(0), xt)
        alphas_s[s < 0] = 1.
        sigmas_s = torch.sqrt(1 - alphas_s**2)

        # precompute variables
        betas = 1 - (alphas / alphas_s)**2
        coef1 = betas * alphas_s / sigmas**2
        coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
        var = betas * (sigmas_s / sigmas)**2
        log_var = torch.log(var).clamp_(-20, 20)

        # prediction
        if guide_scale is None:
            assert isinstance(model_kwargs, dict)
            out = model(xt, t=t, **model_kwargs)
        else:
            # classifier-free guidance: kwargs[0] conditional, kwargs[1] unconditional
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            y_out = model(xt, t=t, **model_kwargs[0])
            if guide_scale == 1.:
                out = y_out
            else:
                u_out = model(xt, t=t, **model_kwargs[1])
                out = u_out + guide_scale * (y_out - u_out)
                if guide_rescale is not None:
                    # Rescale to match the conditional output's std
                    # (mitigates over-saturation from large guide scales).
                    assert guide_rescale >= 0 and guide_rescale <= 1
                    ratio = (
                        y_out.flatten(1).std(dim=1) /  # noqa
                        (out.flatten(1).std(dim=1) + 1e-12)
                    ).view((-1, ) + (1, ) * (y_out.ndim - 1))
                    out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0

        # compute x0 from the network output per prediction_type
        if self.prediction_type == 'x0':
            x0 = out
        elif self.prediction_type == 'eps':
            x0 = (xt - sigmas * out) / alphas
        elif self.prediction_type == 'v':
            x0 = alphas * xt - sigmas * out
        else:
            raise NotImplementedError(
                f'prediction_type {self.prediction_type} not implemented')

        # restrict the range of x0 (dynamic thresholding or hard clamp)
        if percentile is not None:
            assert percentile > 0 and percentile <= 1
            s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
            s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)

        # recompute eps using the restricted x0
        eps = (xt - alphas * x0) / sigmas

        # compute mu (mean of posterior distribution) using the restricted x0
        mu = coef1 * x0 + coef2 * xt
        return mu, var, log_var, x0, eps

    @torch.no_grad()
    def sample(self,
               noise,
               model,
               model_kwargs={},
               condition_fn=None,
               guide_scale=None,
               guide_rescale=None,
               clamp=None,
               percentile=None,
               solver='euler_a',
               steps=20,
               t_max=None,
               t_min=None,
               discretization=None,
               discard_penultimate_step=None,
               return_intermediate=None,
               show_progress=False,
               seed=-1,
               **kwargs):
        """Sample from the diffusion model starting from `noise`.

        Selects a solver ('heun' or 'dpmpp_2m_sde'), builds a timestep
        discretization and a sigma schedule (optionally Karras), then runs
        the solver with a closure that denoises x_t to x_0.
        """
        # sanity check
        assert isinstance(steps, (int, torch.LongTensor))
        assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
        assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
        assert discretization in (None, 'leading', 'linspace', 'trailing')
        assert discard_penultimate_step in (None, True, False)
        assert return_intermediate in (None, 'x0', 'xt')

        # function of diffusion solver
        solver_fn = {
            'heun': sample_heun,
            'dpmpp_2m_sde': sample_dpmpp_2m_sde
        }[solver]

        # options
        schedule = 'karras' if 'karras' in solver else None
        discretization = discretization or 'linspace'
        seed = seed if seed >= 0 else random.randint(0, 2**31)
        if isinstance(steps, torch.LongTensor):
            discard_penultimate_step = False
        if discard_penultimate_step is None:
            discard_penultimate_step = True if solver in (
                'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
                'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False

        # function for denoising xt to get x0
        intermediates = []

        def model_fn(xt, sigma):
            # denoising: map sigma back to an (interpolated) timestep index
            t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
            x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
                              guide_rescale, clamp, percentile)[-2]

            # collect intermediate outputs
            if return_intermediate == 'xt':
                intermediates.append(xt)
            elif return_intermediate == 'x0':
                intermediates.append(x0)
            return x0

        # get timesteps
        if isinstance(steps, int):
            steps += 1 if discard_penultimate_step else 0
            t_max = self.num_timesteps - 1 if t_max is None else t_max
            t_min = 0 if t_min is None else t_min

            # discretize timesteps (descending from t_max to t_min)
            if discretization == 'leading':
                steps = torch.arange(t_min, t_max + 1,
                                     (t_max - t_min + 1) / steps).flip(0)
            elif discretization == 'linspace':
                steps = torch.linspace(t_max, t_min, steps)
            elif discretization == 'trailing':
                steps = torch.arange(t_max, t_min - 1,
                                     -((t_max - t_min + 1) / steps))
            else:
                raise NotImplementedError(
                    f'{discretization} discretization not implemented')
            steps = steps.clamp_(t_min, t_max)
        steps = torch.as_tensor(
            steps, dtype=torch.float32, device=noise.device)

        # get sigmas (terminated with an explicit 0)
        sigmas = self._t_to_sigma(steps)
        sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if schedule == 'karras':
            if sigmas[0] == float('inf'):
                # keep the leading inf marker when rebuilding the schedule
                sigmas = karras_schedule(
                    n=len(steps) - 1,
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas[sigmas < float('inf')].max().item(),
                    rho=7.).to(sigmas)
                sigmas = torch.cat([
                    sigmas.new_tensor([float('inf')]), sigmas,
                    sigmas.new_zeros([1])
                ])
            else:
                sigmas = karras_schedule(
                    n=len(steps),
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas.max().item(),
                    rho=7.).to(sigmas)
                sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if discard_penultimate_step:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        # sampling
        x0 = solver_fn(
            noise, model_fn, sigmas, show_progress=show_progress, **kwargs)
        return (x0, intermediates) if return_intermediate is not None else x0

    def _sigma_to_t(self, sigma):
        """Invert the sigma schedule: map a (VP) sigma to a fractional timestep
        by linear interpolation in log-SNR space."""
        if sigma == float('inf'):
            t = torch.full_like(sigma, len(self.sigmas) - 1)
        else:
            log_sigmas = torch.sqrt(self.sigmas**2 /  # noqa
                                    (1 - self.sigmas**2)).log().to(sigma)
            log_sigma = sigma.log()
            dists = log_sigma - log_sigmas[:, None]
            # lower bracketing index, clamped so high_idx stays in range
            low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
                max=log_sigmas.shape[0] - 2)
            high_idx = low_idx + 1
            low, high = log_sigmas[low_idx], log_sigmas[high_idx]
            w = (low - log_sigma) / (low - high)
            w = w.clamp(0, 1)
            t = (1 - w) * low_idx + w * high_idx
            t = t.view(sigma.shape)
        if t.ndim == 0:
            t = t.unsqueeze(0)
        return t

    def _t_to_sigma(self, t):
        """Map (possibly fractional) timesteps to sigmas by linear
        interpolation in log-SNR space; NaN/inf collapse to inf (pure noise)."""
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigmas = torch.sqrt(self.sigmas**2 /  # noqa
                                (1 - self.sigmas**2)).log().to(t)
        log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
        log_sigma[torch.isnan(log_sigma)
                  | torch.isinf(log_sigma)] = float('inf')
        return log_sigma.exp()

View File

@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import torch
def betas_to_sigmas(betas):
    """Convert per-step betas to cumulative noise levels sigma_t = sqrt(1 - prod(1 - beta))."""
    alphas_bar = torch.cumprod(1 - betas, dim=0)
    return torch.sqrt(1 - alphas_bar)
def sigmas_to_betas(sigmas):
    """Invert betas_to_sigmas: recover per-step betas from cumulative sigmas."""
    alphas_bar = 1 - sigmas**2
    ratios = alphas_bar[1:] / alphas_bar[:-1]
    return 1 - torch.cat([alphas_bar[:1], ratios])
def logsnrs_to_sigmas(logsnrs):
    """Map log-SNR values to sigmas: sigma = sqrt(sigmoid(-logsnr))."""
    return torch.sigmoid(-logsnrs).sqrt()
def sigmas_to_logsnrs(sigmas):
    """Map sigmas to log-SNR: log(sigma^2 / (1 - sigma^2))."""
    s2 = sigmas**2
    return torch.log(s2 / (1 - s2))
def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
t_min = math.atan(math.exp(-0.5 * logsnr_min))
t_max = math.atan(math.exp(-0.5 * logsnr_max))
t = torch.linspace(1, 0, n)
logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
return logsnrs
def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
    """Cosine log-SNR schedule shifted by 2*log(1/scale) for resolution scaling."""
    base = _logsnr_cosine(n, logsnr_min, logsnr_max)
    return base + 2 * math.log(1 / scale)
def _logsnr_cosine_interp(n,
                          logsnr_min=-15,
                          logsnr_max=15,
                          scale_min=2,
                          scale_max=4):
    """Interpolate over time between two shifted cosine log-SNR schedules."""
    weight = torch.linspace(1, 0, n)
    sched_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
    sched_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
    return weight * sched_min + (1 - weight) * sched_max
def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras et al. (2022) rho-schedule, mapped to variance-preserving sigmas.

    Returns n increasing sigma values in (0, 1)."""
    ramp = torch.linspace(1, 0, n)
    lo = sigma_min**(1 / rho)
    hi = sigma_max**(1 / rho)
    sigmas = (hi + ramp * (lo - hi))**rho
    # Convert variance-exploding sigmas to variance-preserving ones.
    return torch.sqrt(sigmas**2 / (1 + sigmas**2))
def logsnr_cosine_interp_schedule(n,
                                  logsnr_min=-15,
                                  logsnr_max=15,
                                  scale_min=2,
                                  scale_max=4):
    """Sigma schedule derived from the interpolated cosine log-SNR schedule."""
    logsnrs = _logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min,
                                    scale_max)
    return logsnrs_to_sigmas(logsnrs)
def noise_schedule(schedule='logsnr_cosine_interp',
                   n=1000,
                   zero_terminal_snr=False,
                   **kwargs):
    """Build a sigma schedule by name; optionally enforce zero terminal SNR.

    With zero_terminal_snr, the sigmas are affinely rescaled so the largest
    value becomes exactly 1 while the smallest is preserved.
    """
    builders = {'logsnr_cosine_interp': logsnr_cosine_interp_schedule}
    sigmas = builders[schedule](n, **kwargs)
    if zero_terminal_snr and sigmas.max() != 1.0:
        scale = (1.0 - sigmas.min()) / (sigmas.max() - sigmas.min())
        sigmas = sigmas.min() + scale * (sigmas - sigmas.min())
    return sigmas

View File

@@ -0,0 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
import numpy as np
import torch
def setup_seed(seed: int) -> None:
    """Seed all relevant RNGs (torch CPU/CUDA, numpy, python) for reproducibility.

    Also forces deterministic cuDNN kernels, which may reduce throughput.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Deterministic convolution algorithms; trades speed for reproducibility.
    torch.backends.cudnn.deterministic = True

View File

@@ -0,0 +1,194 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torchsde
from tqdm.auto import trange
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """
    Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step.

    eta=0 gives a deterministic step (no added noise).
    """
    if not eta:
        return sigma_to, 0.
    variance_ratio = (sigma_from**2 - sigma_to**2) / sigma_from**2
    sigma_up = min(sigma_to, eta * (sigma_to**2 * variance_ratio)**0.5)
    sigma_down = (sigma_to**2 - sigma_up**2)**0.5
    return sigma_down, sigma_up
def get_scalings(sigma):
    """Return (c_out, c_in) scalings for a denoiser at noise level `sigma`."""
    c_in = 1 / (sigma**2 + 1.**2)**0.5
    return -sigma, c_in
@torch.no_grad()
def sample_heun(noise,
                model,
                sigmas,
                s_churn=0.,
                s_tmin=0.,
                s_tmax=float('inf'),
                s_noise=1.,
                show_progress=True):
    """
    Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    `model(x, sigma)` must return the denoised estimate x0. `sigmas` is a
    decreasing schedule ending in 0; an infinite leading sigma triggers a
    plain Euler bootstrap step. s_churn/s_tmin/s_tmax/s_noise control the
    optional stochastic "churn" that temporarily raises the noise level.
    """
    x = noise * sigmas[0]
    for i in trange(len(sigmas) - 1, disable=not show_progress):
        # Amount of extra noise (churn) to inject at this step, if any.
        gamma = 0.
        if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'):
            gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Raise the noise level from sigmas[i] to sigma_hat.
            x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
        if sigmas[i] == float('inf'):
            # Euler method
            denoised = model(noise, sigma_hat)
            x = denoised + sigmas[i + 1] * (gamma + 1) * noise
        else:
            _, c_in = get_scalings(sigma_hat)
            denoised = model(x * c_in, sigma_hat)
            d = (x - denoised) / sigma_hat
            dt = sigmas[i + 1] - sigma_hat
            if sigmas[i + 1] == 0:
                # Euler method
                x = x + d * dt
            else:
                # Heun's method: average the derivative at both endpoints.
                x_2 = x + d * dt
                _, c_in = get_scalings(sigmas[i + 1])
                denoised_2 = model(x_2 * c_in, sigmas[i + 1])
                d_2 = (x_2 - denoised_2) / sigmas[i + 1]
                d_prime = (d + d_2) / 2
                x = x + d_prime * dt
    return x
class BatchedBrownianTree:
    """
    A wrapper around torchsde.BrownianTree that enables batches of entropy.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        # Normalize the time interval; `sign` restores the orientation later.
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()
        self.batched = True
        try:
            # If `seed` is a sequence, use one tree per batch element; an int
            # seed raises TypeError on len() and falls back to a single tree.
            # NOTE(review): a sequence of the wrong length raises
            # AssertionError, which is NOT caught here — confirm intended.
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        self.trees = [
            torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs)
            for s in seed
        ]

    @staticmethod
    def sort(a, b):
        """Return (min, max, sign) where sign is -1 if the pair was swapped."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        """Sample the Brownian increment(s) between t0 and t1."""
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (
            self.sign * sign)
        return w if self.batched else w[0]
class BrownianTreeNoiseSampler:
    """
    A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self,
                 x,
                 sigma_min,
                 sigma_max,
                 seed=None,
                 transform=lambda x: x):
        self.transform = transform
        # Build the tree over the transformed [sigma_min, sigma_max] interval.
        t0 = self.transform(torch.as_tensor(sigma_min))
        t1 = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        """Return unit-variance noise for the step sigma -> sigma_next."""
        t0 = self.transform(torch.as_tensor(sigma))
        t1 = self.transform(torch.as_tensor(sigma_next))
        # Normalize the Brownian increment by sqrt(|dt|) to unit variance.
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
@torch.no_grad()
def sample_dpmpp_2m_sde(noise,
                        model,
                        sigmas,
                        eta=1.,
                        s_noise=1.,
                        solver_type='midpoint',
                        show_progress=True):
    """
    DPM-Solver++ (2M) SDE.

    `model(x, sigma)` must return the denoised estimate x0. `sigmas` is a
    decreasing schedule ending in 0; an infinite leading sigma triggers a
    plain Euler bootstrap step. `eta` scales the stochastic (SDE) part;
    `solver_type` selects the second-order correction ('heun' or 'midpoint').
    """
    assert solver_type in {'heun', 'midpoint'}

    x = noise * sigmas[0]
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[
        sigmas < float('inf')].max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
    # State from the previous step, needed for the 2nd-order correction.
    old_denoised = None
    h_last = None
    for i in trange(len(sigmas) - 1, disable=not show_progress):
        if sigmas[i] == float('inf'):
            # Euler method
            denoised = model(noise, sigmas[i])
            x = denoised + sigmas[i + 1] * noise
        else:
            _, c_in = get_scalings(sigmas[i])
            denoised = model(x * c_in, sigmas[i])
            if sigmas[i + 1] == 0:
                # Denoising step
                x = denoised
            else:
                # DPM-Solver++(2M) SDE: work in log-sigma time t = -log(sigma).
                t, s = -sigmas[i].log(), -sigmas[i + 1].log()
                h = s - t
                eta_h = eta * h

                x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
                    (-h - eta_h).expm1().neg() * denoised

                if old_denoised is not None:
                    # Second-order correction using the previous denoised value.
                    r = h_last / h
                    if solver_type == 'heun':
                        x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
                            (1 / r) * (denoised - old_denoised)
                    elif solver_type == 'midpoint':
                        x = x + 0.5 * (-h - eta_h).expm1().neg() * \
                            (1 / r) * (denoised - old_denoised)

                # Inject the stochastic (Brownian) part of the step.
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[
                    i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

            old_denoised = denoised
            h_last = h
    return x

View File

@@ -0,0 +1,404 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import random
import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image, ImageFilter
__all__ = [
'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2',
'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',
'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize',
'ResizeRandomCrop', 'ExtractResizeRandomCrop', 'ExtractResizeAssignCrop'
]
class Compose(object):
    """Chain a sequence of transforms, applying them left to right.

    Supports len(), integer indexing (returns the raw transform) and slicing
    (returns a new Compose over the sub-sequence).
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __getitem__(self, index):
        # A slice yields a new pipeline; a plain index yields one transform.
        if isinstance(index, slice):
            return Compose(self.transforms[index])
        return self.transforms[index]

    def __len__(self):
        return len(self.transforms)

    def __call__(self, rgb):
        out = rgb
        for transform in self.transforms:
            out = transform(out)
        return out
class Resize(object):
    """Bilinear-resize a PIL image (or every frame of a list) to a fixed size."""

    def __init__(self, size=256):
        # An int means a square target.
        self.size = (size, size) if isinstance(size, int) else size

    def __call__(self, rgb):
        if isinstance(rgb, list):
            return [frame.resize(self.size, Image.BILINEAR) for frame in rgb]
        return rgb.resize(self.size, Image.BILINEAR)
class Rescale(object):
    """Scale all frames so the short side equals `size`, keeping aspect ratio."""

    def __init__(self, size=256, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, rgb):
        width, height = rgb[0].size
        ratio = self.size / min(width, height)
        target = (int(round(width * ratio)), int(round(height * ratio)))
        return [frame.resize(target, self.interpolation) for frame in rgb]
class CenterCrop(object):
    """Take a centered square crop of side `size` from every frame."""

    def __init__(self, size=224):
        self.size = size

    def __call__(self, rgb):
        width, height = rgb[0].size
        # Frames must be at least as large as the crop.
        assert min(width, height) >= self.size
        left = (width - self.size) // 2
        top = (height - self.size) // 2
        box = (left, top, left + self.size, top + self.size)
        return [frame.crop(box) for frame in rgb]
class ResizeRandomCrop(object):
    """Downscale so the short side equals `size_short`, then take one random
    square crop of side `size`, shared by every frame in the list."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):
        # Halve repeatedly with a cheap BOX filter while the short side is at
        # least twice the target, then finish with one BICUBIC resize.
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in rgb
            ]
        ratio = self.size_short / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(ratio * frame.width), round(ratio * frame.height)),
                resample=Image.BICUBIC) for frame in rgb
        ]
        # One random offset for the whole clip keeps frames aligned.
        width, height = rgb[0].size
        left = random.randint(0, width - self.size)
        top = random.randint(0, height - self.size)
        box = (left, top, left + self.size, top + self.size)
        return [frame.crop(box) for frame in rgb]
class ExtractResizeRandomCrop(object):
    """Like ResizeRandomCrop, but also returns the crop box [x1, y1, x2, y2]
    so the same region can be re-applied elsewhere."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):
        # Fast power-of-two downscale, then one BICUBIC resize to size_short.
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in rgb
            ]
        ratio = self.size_short / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(ratio * frame.width), round(ratio * frame.height)),
                resample=Image.BICUBIC) for frame in rgb
        ]
        # Single random offset shared by all frames.
        width, height = rgb[0].size
        left = random.randint(0, width - self.size)
        top = random.randint(0, height - self.size)
        rgb = [
            frame.crop((left, top, left + self.size, top + self.size))
            for frame in rgb
        ]
        wh = [left, top, left + self.size, top + self.size]
        return rgb, wh
class ExtractResizeAssignCrop(object):
    """Resize as in ResizeRandomCrop, then crop every frame to a GIVEN box
    `wh` and resize the result to (size, size)."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb, wh):
        # Fast power-of-two downscale, then one BICUBIC resize to size_short.
        while min(rgb[0].size) >= 2 * self.size_short:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in rgb
            ]
        ratio = self.size_short / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(ratio * frame.width), round(ratio * frame.height)),
                resample=Image.BICUBIC) for frame in rgb
        ]
        # Apply the externally supplied crop box, then normalize the size.
        rgb = [frame.crop(wh) for frame in rgb]
        return [
            frame.resize((self.size, self.size), Image.BILINEAR)
            for frame in rgb
        ]
class CenterCropV2(object):
    """Center-crop frames to a square of side `size` after a fast
    aspect-preserving resize of the short side to `size`."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        # fast resize: halve with BOX while still well above the target,
        # then one BICUBIC resize to bring the short side to `size`.
        while min(img[0].size) >= 2 * self.size:
            img = [
                frame.resize((frame.width // 2, frame.height // 2),
                             resample=Image.BOX) for frame in img
            ]
        ratio = self.size / min(img[0].size)
        img = [
            frame.resize(
                (round(ratio * frame.width), round(ratio * frame.height)),
                resample=Image.BICUBIC) for frame in img
        ]
        # center crop
        left = (img[0].width - self.size) // 2
        top = (img[0].height - self.size) // 2
        box = (left, top, left + self.size, top + self.size)
        return [frame.crop(box) for frame in img]
class CenterCropWide(object):
    """Center-crop to a (width, height) target after an aspect-preserving BOX
    resize; accepts a single image or a list of frames."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        target_w, target_h = self.size[0], self.size[1]
        if isinstance(img, list):
            ratio = min(img[0].size[0] / target_w, img[0].size[1] / target_h)
            img = [
                frame.resize(
                    (round(frame.width // ratio), round(frame.height // ratio)),
                    resample=Image.BOX) for frame in img
            ]
            # center crop
            left = (img[0].width - target_w) // 2
            top = (img[0].height - target_h) // 2
            box = (left, top, left + target_w, top + target_h)
            return [frame.crop(box) for frame in img]
        ratio = min(img.size[0] / target_w, img.size[1] / target_h)
        img = img.resize(
            (round(img.width // ratio), round(img.height // ratio)),
            resample=Image.BOX)
        left = (img.width - target_w) // 2
        top = (img.height - target_h) // 2
        return img.crop((left, top, left + target_w, top + target_h))
class RandomCrop(object):
    """Random area/aspect-ratio crop shared across frames, resized to `size`.

    Crop area is uniform in [min_area, 1] of the image; aspect ratio is
    uniform in [3/4, 4/3]. Proposals are re-drawn until one fits.
    """

    def __init__(self, size=224, min_area=0.4):
        self.size = size
        self.min_area = min_area

    def __call__(self, rgb):
        width, height = rgb[0].size
        total_area = width * height
        # Re-draw until the proposed crop fits inside the image.
        crop_w, crop_h = float('inf'), float('inf')
        while crop_w > width or crop_h > height:
            target_area = random.uniform(self.min_area, 1.0) * total_area
            aspect = random.uniform(3. / 4., 4. / 3.)
            crop_w = int(round(math.sqrt(target_area * aspect)))
            crop_h = int(round(math.sqrt(target_area / aspect)))
        # One offset for all frames keeps the clip spatially aligned.
        left = random.randint(0, width - crop_w)
        top = random.randint(0, height - crop_h)
        rgb = [
            frame.crop((left, top, left + crop_w, top + crop_h))
            for frame in rgb
        ]
        return [
            frame.resize((self.size, self.size), Image.BILINEAR)
            for frame in rgb
        ]
class RandomCropV2(object):
    """Random resized crop (torchvision-style): sample an area/aspect crop,
    falling back to a central crop if no valid proposal is found in 10 tries.
    The same crop parameters are applied to every frame."""

    def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
        if isinstance(size, (tuple, list)):
            self.size = size
        else:
            self.size = (size, size)
        self.min_area = min_area
        self.ratio = ratio

    def _get_params(self, img):
        """Sample crop (top, left, height, width) for a PIL image."""
        width, height = img.size
        area = height * width
        for _ in range(10):
            target_area = random.uniform(self.min_area, 1.0) * area
            # Sample aspect ratio log-uniformly for symmetry around 1.
            log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback to central crop, clipping to the allowed aspect range.
        in_ratio = float(width) / float(height)
        if (in_ratio < min(self.ratio)):
            w = width
            h = int(round(w / min(self.ratio)))
        elif (in_ratio > max(self.ratio)):
            h = height
            w = int(round(h * max(self.ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def __call__(self, rgb):
        # Sample once from the first frame; apply to all frames.
        i, j, h, w = self._get_params(rgb[0])
        rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb]
        return rgb
class RandomHFlip(object):
    """Horizontally flip every frame with probability `p` (one draw per clip)."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        return [frame.transpose(Image.FLIP_LEFT_RIGHT) for frame in rgb]
class GaussianBlur(object):
    """With probability `p`, blur all frames with one shared random sigma
    drawn uniformly from the `sigmas` interval."""

    def __init__(self, sigmas=[0.1, 2.0], p=0.5):
        self.sigmas = sigmas
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        radius = random.uniform(*self.sigmas)
        return [
            frame.filter(ImageFilter.GaussianBlur(radius=radius))
            for frame in rgb
        ]
class ColorJitter(object):
    """With probability `p`, jitter brightness/contrast/saturation/hue of all
    frames using one shared set of factors and a shared random order."""

    def __init__(self,
                 brightness=0.4,
                 contrast=0.4,
                 saturation=0.4,
                 hue=0.1,
                 p=0.5):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue
        self.p = p

    def __call__(self, rgb):
        if random.random() < self.p:
            brightness, contrast, saturation, hue = self._random_params()
            transforms = [
                lambda f: F.adjust_brightness(f, brightness),
                lambda f: F.adjust_contrast(f, contrast),
                lambda f: F.adjust_saturation(f, saturation),
                lambda f: F.adjust_hue(f, hue)
            ]
            # Apply the four adjustments in a random (but shared) order.
            random.shuffle(transforms)
            for t in transforms:
                rgb = [t(u) for u in rgb]
        return rgb

    def _random_params(self):
        """Draw jitter factors: multiplicative factors around 1 for
        brightness/contrast/saturation, additive offset for hue."""
        brightness = random.uniform(
            max(0, 1 - self.brightness), 1 + self.brightness)
        contrast = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
        saturation = random.uniform(
            max(0, 1 - self.saturation), 1 + self.saturation)
        hue = random.uniform(-self.hue, self.hue)
        return brightness, contrast, saturation, hue
class RandomGray(object):
    """With probability ``p``, convert all frames to grayscale.

    Frames are converted to mode 'L' and back to 'RGB' so the channel
    count is preserved for downstream transforms.
    """

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, rgb):
        if random.random() >= self.p:
            return rgb
        return [u.convert('L').convert('RGB') for u in rgb]
class ToTensor(object):
    """Convert a clip (list of frames) or a single frame to float tensors.

    A list becomes one stacked tensor with a leading frame dimension; a
    single image is converted directly.
    """

    def __call__(self, rgb):
        if isinstance(rgb, list):
            return torch.stack([F.to_tensor(u) for u in rgb], dim=0)
        return F.to_tensor(rgb)
class Normalize(object):
    """Channel-wise normalization: clamp to [0, 1], subtract mean, divide by std.

    Works on a single clip/frame tensor of 3 dims (C, H, W) or 4 dims
    (F, C, H, W); other ranks are only clamped. The input is cloned, so
    the caller's tensor is never modified.
    """

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, rgb):
        out = rgb.clone()
        out.clamp_(0, 1)
        # Lazily convert mean/std to tensors on the first call and cache
        # them on the instance, matching the input's device/dtype.
        if not isinstance(self.mean, torch.Tensor):
            self.mean = out.new_tensor(self.mean).view(-1)
        if not isinstance(self.std, torch.Tensor):
            self.std = out.new_tensor(self.std).view(-1)
        if out.dim() == 4:
            shape = (1, -1, 1, 1)
        elif out.dim() == 3:
            shape = (-1, 1, 1)
        else:
            return out
        out.sub_(self.mean.view(*shape)).div_(self.std.view(*shape))
        return out

View File

@@ -0,0 +1,227 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import random
from copy import copy
from typing import Any, Dict
import torch
import torch.cuda.amp as amp
import torch.nn.functional as F
import modelscope.models.multi_modal.video_to_video.utils.transforms as data
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.video_to_video.modules import *
from modelscope.models.multi_modal.video_to_video.modules import (
AutoencoderKL, FrozenOpenCLIPEmbedder, Vid2VidSDUNet,
get_first_stage_encoding)
from modelscope.models.multi_modal.video_to_video.utils.config import cfg
from modelscope.models.multi_modal.video_to_video.utils.diffusion_sdedit import \
GaussianDiffusion_SDEdit
from modelscope.models.multi_modal.video_to_video.utils.schedules_sdedit import \
noise_schedule
from modelscope.models.multi_modal.video_to_video.utils.seed import setup_seed
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
__all__ = ['VideoToVideo']
logger = get_logger()
@MODELS.register_module(
    Tasks.video_to_video, module_name=Models.video_to_video_model)
class VideoToVideo(TorchModel):
    r"""
    Video2Video aims to solve the task of generating super-resolution videos based on input
    video and text, which is a video generation basic model developed by Alibaba Cloud.

    Paper link: https://arxiv.org/abs/2306.02018

    Attributes:
        diffusion: diffusion model for DDIM.
        autoencoder: decode the latent representation of input video into visual space.
        clip_encoder: encode the text into text embedding.
    """

    def __init__(self, model_dir, *args, **kwargs):
        r"""Load all sub-networks from a local model directory.

        Builds the CLIP text encoder, the Vid2Vid UNet generator, the
        SDEdit diffusion schedule and the KL autoencoder, using paths and
        hyper-parameters read from the directory's configuration file.

        Args:
            model_dir (`str` or `os.PathLike`):
                Path to a local directory containing the pipeline
                configuration (ModelFile.CONFIGURATION) and the checkpoint
                files it references (UNet, CLIP and autoencoder weights).
        """
        super().__init__(model_dir=model_dir, *args, **kwargs)
        self.config = Config.from_file(
            osp.join(model_dir, ModelFile.CONFIGURATION))
        # NOTE(review): these assignments mutate the module-level `cfg`
        # object shared across instances — confirm single-instance use.
        cfg.solver_mode = self.config.model.model_args.solver_mode
        # assign default value
        cfg.batch_size = self.config.model.model_cfg.batch_size
        cfg.target_fps = self.config.model.model_cfg.target_fps
        cfg.max_frames = self.config.model.model_cfg.max_frames
        cfg.latent_hei = self.config.model.model_cfg.latent_hei
        cfg.latent_wid = self.config.model.model_cfg.latent_wid
        cfg.model_path = osp.join(model_dir,
                                  self.config.model.model_args.ckpt_unet)
        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')
        # Use the configured seed when present, otherwise draw one.
        if 'seed' in self.config.model.model_args.keys():
            cfg.seed = self.config.model.model_args.seed
        else:
            cfg.seed = random.randint(0, 99999)
        setup_seed(cfg.seed)
        # transform: PIL frames -> normalized tensor clip
        vid_trans = data.Compose(
            [data.ToTensor(),
             data.Normalize(mean=cfg.mean, std=cfg.std)])
        self.vid_trans = vid_trans
        # [clip text encoder]
        cfg.embedder.pretrained = osp.join(
            model_dir, self.config.model.model_args.ckpt_clip)
        clip_encoder = FrozenOpenCLIPEmbedder(
            pretrained=cfg.embedder.pretrained)
        clip_encoder.model.to(self.device)
        self.clip_encoder = clip_encoder
        logger.info(f'Build encoder with {cfg.embedder.type}')
        # [unet]
        generator = Vid2VidSDUNet()
        generator = generator.to(self.device)
        generator.eval()
        load_dict = torch.load(cfg.model_path, map_location='cpu')
        ret = generator.load_state_dict(load_dict['state_dict'], strict=True)
        self.generator = generator
        logger.info('Load model {} path {}, with local status {}'.format(
            cfg.UNet.type, cfg.model_path, ret))
        # [diffusion]
        sigmas = noise_schedule(
            schedule='logsnr_cosine_interp',
            n=1000,
            zero_terminal_snr=True,
            scale_min=2.0,
            scale_max=4.0)
        diffusion = GaussianDiffusion_SDEdit(
            sigmas=sigmas, prediction_type='v')
        self.diffusion = diffusion
        logger.info('Build diffusion with type of GaussianDiffusion_SDEdit')
        # [autoencoder] frozen: no gradients are needed at inference time
        cfg.auto_encoder.pretrained = osp.join(
            model_dir, self.config.model.model_args.ckpt_autoencoder)
        autoencoder = AutoencoderKL(**cfg.auto_encoder)
        autoencoder.eval()
        for param in autoencoder.parameters():
            param.requires_grad = False
        autoencoder.to(self.device)
        self.autoencoder = autoencoder
        torch.cuda.empty_cache()
        # Pre-compute the negative-prompt embedding used for classifier-free
        # guidance in forward().
        negative_prompt = cfg.negative_prompt
        negative_y = clip_encoder(negative_prompt).detach()
        self.negative_y = negative_y
        positive_prompt = cfg.positive_prompt
        self.positive_prompt = positive_prompt
        self.cfg = cfg

    def forward(self, input: Dict[str, Any]):
        r"""
        The entry function of video to video task.
        1. Using CLIP to encode text into embeddings.
        2. Using diffusion model to generate the video's latent representation.
        3. Using autoencoder to decode the video's latent representation to visual space.

        Args:
            input (`Dict[Str, Any]`):
                The input of the task, with keys 'video_data' (frame tensor)
                and 'y' (CLIP text embedding).
        Returns:
            A generated video (as pytorch tensor) of shape
            (batch, channel, frame, height, width) on CPU.
        """
        video_data = input['video_data']
        y = input['y']
        cfg = self.cfg
        # Upsample the input clip to the fixed 720x1280 working resolution.
        video_data = F.interpolate(
            video_data, size=(720, 1280), mode='bilinear')
        video_data = video_data.unsqueeze(0)
        video_data = video_data.to(self.device)
        batch_size, frames_num, _, _, _ = video_data.shape
        video_data = rearrange(video_data, 'b f c h w -> (b f) c h w')
        # Encode frames in small chunks to bound peak memory.
        # NOTE(review): torch.chunk requires shape[0] // 2 >= 1, i.e. at
        # least 2 frames — confirm upstream guarantees this.
        video_data_list = torch.chunk(
            video_data, video_data.shape[0] // 2, dim=0)
        with torch.no_grad():
            decode_data = []
            for vd_data in video_data_list:
                encoder_posterior = self.autoencoder.encode(vd_data)
                tmp = get_first_stage_encoding(encoder_posterior).detach()
                decode_data.append(tmp)
            video_data_feature = torch.cat(decode_data, dim=0)
            video_data_feature = rearrange(
                video_data_feature, '(b f) c h w -> b c f h w', b=batch_size)
        with amp.autocast(enabled=True):
            # SDEdit: diffuse the low-res latent to an intermediate noise
            # level, then denoise it with text guidance.
            total_noise_levels = 600
            t = torch.randint(
                total_noise_levels - 1,
                total_noise_levels, (1, ),
                dtype=torch.long).to(self.device)
            noise = torch.randn_like(video_data_feature)
            noised_lr = self.diffusion.diffuse(video_data_feature, t, noise)
            # Positive and negative conditions for classifier-free guidance.
            model_kwargs = [{'y': y}, {'y': self.negative_y}]
            gen_vid = self.diffusion.sample(
                noise=noised_lr,
                model=self.generator,
                model_kwargs=model_kwargs,
                guide_scale=7.5,
                guide_rescale=0.2,
                solver='dpmpp_2m_sde' if cfg.solver_mode == 'fast' else 'heun',
                steps=30 if cfg.solver_mode == 'fast' else 50,
                t_max=total_noise_levels - 1,
                t_min=0,
                discretization='trailing')
            # Undo the latent scaling before decoding (SD convention).
            scale_factor = 0.18215
            vid_tensor_feature = 1. / scale_factor * gen_vid
            vid_tensor_feature = rearrange(vid_tensor_feature,
                                           'b c f h w -> (b f) c h w')
            vid_tensor_feature_list = torch.chunk(
                vid_tensor_feature, vid_tensor_feature.shape[0] // 2, dim=0)
            decode_data = []
            for vd_data in vid_tensor_feature_list:
                tmp = self.autoencoder.decode(vd_data)
                decode_data.append(tmp)
            vid_tensor_gen = torch.cat(decode_data, dim=0)
        gen_video = rearrange(
            vid_tensor_gen, '(b f) c h w -> b c f h w', b=cfg.batch_size)
        return gen_video.type(torch.float32).cpu()

View File

@@ -65,6 +65,7 @@ if TYPE_CHECKING:
ModelForTextRanking,
ModelForTokenClassification,
ModelForTokenClassificationWithCRF,
ModelForMachineReadingComprehension,
)
from .unite import UniTEForTranslationEvaluation
from .use import UserSatisfactionEstimation
@@ -159,6 +160,7 @@ else:
'ModelForTextRanking',
'ModelForTokenClassification',
'ModelForTokenClassificationWithCRF',
'ModelForMachineReadingComprehension',
],
'sentence_embedding': ['SentenceEmbedding'],
'T5': ['T5ForConditionalGeneration'],

View File

@@ -13,6 +13,7 @@ if TYPE_CHECKING:
ModelForTokenClassificationWithCRF)
from .text_generation import ModelForTextGeneration
from .text_ranking import ModelForTextRanking
from .machine_reading_comprehension import ModelForMachineReadingComprehension
else:
_import_structure = {
@@ -25,6 +26,8 @@ else:
['ModelForTokenClassification', 'ModelForTokenClassificationWithCRF'],
'text_generation': ['ModelForTextGeneration'],
'text_ranking': ['ModelForTextRanking'],
'machine_reading_comprehension':
['ModelForMachineReadingComprehension'],
}
import sys

View File

@@ -0,0 +1,139 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoConfig
from transformers.modeling_outputs import ModelOutput
from transformers.models.roberta.modeling_roberta import (
RobertaModel, RobertaPreTrainedModel)
from modelscope.metainfo import Heads, Models, TaskModels
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.nlp.task_models.task_model import EncoderModel
from modelscope.outputs import MachineReadingComprehensionOutput, OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.hub import parse_label_mapping
__all__ = ['ModelForMachineReadingComprehension']
@MODELS.register_module(
    Tasks.machine_reading_comprehension,
    module_name=TaskModels.machine_reading_comprehension)
class ModelForMachineReadingComprehension(TorchModel):
    '''
    Pretrained Machine Reader (PMR) model (https://arxiv.org/pdf/2212.04755.pdf)

    A RoBERTa encoder followed by a span-matrix head: for every pair of
    token positions (i, j) the model scores whether the span [i, j]
    answers the query, producing logits of shape
    [batch, seq_len, seq_len].
    '''
    _keys_to_ignore_on_load_unexpected = [r'pooler']
    _keys_to_ignore_on_load_missing = [r'position_ids']

    def __init__(self, model_dir: str, *args, **kwargs):
        """Build the encoder and span head, then load local weights.

        Args:
            model_dir: directory holding the HF-style config and the torch
                checkpoint (ModelFile.TORCH_MODEL_BIN_FILE).
        """
        super().__init__(model_dir, *args, **kwargs)
        self.config = AutoConfig.from_pretrained(model_dir)
        self.num_labels = self.config.num_labels
        self.roberta = RobertaModel(self.config, add_pooling_layer=False)
        # Projects token states before the pairwise matmul; projection
        # width comes from the model config.
        self.span_transfer = MultiNonLinearProjection(
            self.config.hidden_size,
            self.config.hidden_size,
            self.config.hidden_dropout_prob,
            intermediate_hidden_size=self.config.
            projection_intermediate_hidden_size)
        self.load_state_dict(
            torch.load(
                os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE)))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        label_mask=None,
        match_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Score every candidate span of the input against the query.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids,
            head_mask, inputs_embeds, output_attentions,
            output_hidden_states: standard RoBERTa encoder inputs.
            label_mask: mask of valid positions used when computing the loss.
            match_labels: gold [batch, seq_len, seq_len] span matrix; when
                given, a training loss is computed.
            return_dict: whether to return a ModelOutput instead of a tuple;
                defaults to ``self.config.use_return_dict``.

        Returns:
            MachineReadingComprehensionOutput (or a tuple when
            ``return_dict`` is False) with ``span_logits`` of shape
            [batch, seq_len, seq_len] and, if ``match_labels`` was given,
            the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        # adapted from https://github.com/ShannonAI/mrc-for-flat-nested-ner
        # for every position $i$ in sequence, should concate $j$ to
        # predict if $i$ and $j$ are start_pos and end_pos for an entity.
        # [batch, seq_len, hidden]
        span_intermediate = self.span_transfer(sequence_output)
        # [batch, seq_len, seq_len]
        span_logits = torch.matmul(span_intermediate,
                                   sequence_output.transpose(-1, -2))
        total_loss = None
        if match_labels is not None:
            # NOTE(review): compute_loss is not defined in this class or its
            # visible bases — confirm it is provided elsewhere before
            # relying on the training path.
            match_loss = self.compute_loss(span_logits, match_labels,
                                           label_mask)
            total_loss = match_loss
        if not return_dict:
            # Fix: wrap span_logits in a 1-tuple. The previous
            # `(span_logits) + outputs[2:]` added a tensor to a tuple and
            # raised a TypeError whenever return_dict was False.
            output = (span_logits, ) + outputs[2:]
            return ((total_loss, )
                    + output) if total_loss is not None else output
        return MachineReadingComprehensionOutput(
            loss=total_loss,
            span_logits=span_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class MultiNonLinearProjection(nn.Module):
def __init__(self,
hidden_size,
num_label,
dropout_rate,
act_func='gelu',
intermediate_hidden_size=None):
super(MultiNonLinearProjection, self).__init__()
self.num_label = num_label
self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size
self.classifier1 = nn.Linear(hidden_size,
self.intermediate_hidden_size)
self.classifier2 = nn.Linear(self.intermediate_hidden_size,
self.num_label)
self.dropout = nn.Dropout(dropout_rate)
self.act_func = act_func
def forward(self, input_features):
features_output1 = self.classifier1(input_features)
if self.act_func == 'gelu':
features_output1 = F.gelu(features_output1)
elif self.act_func == 'relu':
features_output1 = F.relu(features_output1)
elif self.act_func == 'tanh':
features_output1 = F.tanh(features_output1)
else:
raise ValueError
features_output1 = self.dropout(features_output1)
features_output2 = self.classifier2(features_output1)
return features_output2

View File

@@ -464,3 +464,25 @@ class TranslationEvaluationOutput(ModelOutputBase):
score: Tensor = None
loss: Tensor = None
input_format: List[str] = None
@dataclass
class MachineReadingComprehensionOutput(ModelOutputBase):
    """The output class for machine reading comprehension models.

    Args:
        loss (`Tensor`, *optional*): The training loss of the current batch
        match_loss (`Tensor`, *optional*): The match loss of the current batch
        span_logits (`Tensor`): The [batch, seq_len, seq_len] span-matrix logits output by the model
        hidden_states (`Tuple[Tensor]`, *optional*): The hidden states output by the model
        attentions (`Tuple[Tensor]`, *optional*): The attention scores output by the model
        input_ids (`Tensor`): The token ids of the input sentence
    """
    loss: Optional[Tensor] = None
    match_loss: Optional[Tensor] = None
    span_logits: Tensor = None
    hidden_states: Optional[Tuple[Tensor]] = None
    attentions: Optional[Tuple[Tensor]] = None
    input_ids: Tensor = None

View File

@@ -323,6 +323,8 @@ TASK_INPUTS = {
'positive': InputType.LIST,
'negative': InputType.LIST
},
Tasks.machine_reading_comprehension:
InputType.TEXT,
# ============ audio tasks ===================
Tasks.auto_speech_recognition: # input can be audio, or audio and text.

View File

@@ -81,28 +81,24 @@ class SpeakerVerificationPipeline(Pipeline):
inputs: torch.Tensor,
in_audios: Union[np.ndarray, list],
save_dir=None):
if isinstance(in_audios[0], str):
if save_dir is not None:
# save the embeddings
os.makedirs(save_dir, exist_ok=True)
for i, p in enumerate(in_audios):
save_path = os.path.join(
save_dir, '%s.npy' %
(os.path.basename(p).rsplit('.', 1)[0]))
np.save(save_path, inputs[i].numpy())
if isinstance(in_audios[0], str) and save_dir is not None:
# save the embeddings
os.makedirs(save_dir, exist_ok=True)
for i, p in enumerate(in_audios):
save_path = os.path.join(
save_dir, '%s.npy' %
(os.path.basename(p).rsplit('.', 1)[0]))
np.save(save_path, inputs[i].numpy())
if len(in_audios) == 2:
# compute the score
score = self.compute_cos_similarity(inputs[0], inputs[1])
score = round(score, 5)
if score >= self.thr:
ans = 'yes'
else:
ans = 'no'
output = {OutputKeys.SCORE: score, OutputKeys.TEXT: ans}
if len(inputs) == 2:
# compute the score
score = self.compute_cos_similarity(inputs[0], inputs[1])
score = round(score, 5)
if score >= self.thr:
ans = 'yes'
else:
output = {OutputKeys.TEXT: 'No similarity score output'}
ans = 'no'
output = {OutputKeys.SCORE: score, OutputKeys.TEXT: ans}
else:
output = {OutputKeys.TEXT: 'No similarity score output'}

View File

@@ -40,6 +40,7 @@ class StableDiffusionPipeline(DiffusersPipeline):
use_safetensors: load safetensors weights.
"""
use_safetensors = kwargs.pop('use_safetensors', False)
torch_type = kwargs.pop('torch_type', torch.float32)
# check custom diffusion input value
if custom_dir is None and modifier_token is not None:
raise ValueError(
@@ -50,7 +51,6 @@ class StableDiffusionPipeline(DiffusersPipeline):
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
# load pipeline
torch_type = torch.float16 if self.device == 'cuda' else torch.float32
self.pipeline = DiffusionPipeline.from_pretrained(
model, use_safetensors=use_safetensors, torch_dtype=torch_type)
self.pipeline = self.pipeline.to(self.device)

View File

@@ -0,0 +1,104 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
from typing import Any, Dict, Optional
import cv2
import torch
from einops import rearrange
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
    Tasks.image_to_video, module_name=Pipelines.image_to_video_task_pipeline)
class ImageToVideoPipeline(Pipeline):
    r""" Image To Video Pipeline.

    Examples:
    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.outputs import OutputKeys
    >>> p = pipeline('image-to-video', 'damo/Image-to-Video')
    >>> input = 'path_to_image'
    >>> p(input,)
    >>> {OutputKeys.OUTPUT_VIDEO: path-to-the-generated-video}
    >>>
    """

    def __init__(self, model: str, **kwargs):
        """Create an image-to-video pipeline for prediction.

        Args:
            model: model id on the modelscope hub.
        """
        super().__init__(model=model, **kwargs)

    def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
        """Load the input image and turn it into a model-ready tensor.

        Args:
            input: path or url of the conditioning image.

        Returns:
            dict with 'vit_frame': the transformed RGB frame with a
            leading batch dimension, moved to the model device.
        """
        img_path = input
        image = LoadImage.convert_to_img(img_path)
        # The model transform expects RGB input.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        vit_frame = self.model.vid_trans(image)
        vit_frame = vit_frame.unsqueeze(0)
        vit_frame = vit_frame.to(self.model.device)
        return {'vit_frame': vit_frame}

    def forward(self, input: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        """Run the generation model on the preprocessed frame."""
        video = self.model(input)
        return {'video': video}

    def postprocess(self, inputs: Dict[str, Any],
                    **post_params) -> Dict[str, Any]:
        """Encode the generated frames into an mp4 file.

        Returns the raw video bytes when no 'output_video' path was given
        in post_params, otherwise the path that was written.
        """
        video = tensor2vid(inputs['video'], self.model.cfg.mean,
                           self.model.cfg.std)
        output_video_path = post_params.get('output_video', None)
        temp_video_file = False
        if output_video_path is None:
            output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
            temp_video_file = True
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        h, w, c = video[0].shape
        video_writer = cv2.VideoWriter(
            output_video_path, fourcc, fps=8, frameSize=(w, h))
        for i in range(len(video)):
            # Frames are RGB; OpenCV writes BGR.
            img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
            video_writer.write(img)
        video_writer.release()
        if temp_video_file:
            # Return the bytes and delete the temporary file.
            video_file_content = b''
            with open(output_video_path, 'rb') as f:
                video_file_content = f.read()
            os.remove(output_video_path)
            return {OutputKeys.OUTPUT_VIDEO: video_file_content}
        else:
            return {OutputKeys.OUTPUT_VIDEO: output_video_path}
def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
    """Convert a normalized video tensor into a list of uint8 RGB frames.

    Args:
        video: CPU tensor of shape (batch, channel, frame, height, width),
            normalized with the given per-channel mean/std. Only the first
            batch element is converted.
        mean: per-channel mean used during normalization.
        std: per-channel std used during normalization.

    Returns:
        List of numpy uint8 arrays of shape (height, width, channel),
        one per frame.
    """
    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
    # De-normalize out-of-place: the previous mul_/add_ silently clobbered
    # the caller's tensor.
    video = video.mul(std).add(mean)
    video.clamp_(0, 1)
    video = video * 255.0
    # Equivalent to rearrange(video, 'b c f h w -> b f h w c')[0] without
    # needing einops.
    images = video.permute(0, 2, 3, 4, 1)[0]
    images = [(img.numpy()).astype('uint8') for img in images]
    return images

View File

@@ -0,0 +1,140 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
from typing import Any, Dict, Optional
import cv2
import torch
from einops import rearrange
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
    Tasks.video_to_video, module_name=Pipelines.video_to_video_pipeline)
class VideoToVideoPipeline(Pipeline):
    r""" Video To Video Pipeline, generating super-resolution videos based on input
    video and text

    Examples:
    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.outputs import OutputKeys
    >>> # YOUR_VIDEO_PATH: your video url or local position in low resolution
    >>> # INPUT_TEXT: when we do video super-resolution, we will add the text content
    >>> # into results
    >>> # output_video_path: path-to-the-generated-video
    >>> p = pipeline('video-to-video', 'damo/Video-to-Video')
    >>> input = {"video_path":YOUR_VIDEO_PATH, "text": INPUT_TEXT}
    >>> output_video_path = p(input,output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO]
    """

    def __init__(self, model: str, **kwargs):
        """Create a video-to-video pipeline for prediction.

        Args:
            model: model id on the modelscope hub.
        """
        super().__init__(model=model, **kwargs)

    def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
        """Decode up to cfg.max_frames frames and CLIP-encode the caption.

        Args:
            input: dict with 'video_path' and optional 'text'.

        Returns:
            dict with 'video_data' (normalized frame tensor) and 'y'
            (the detached CLIP text embedding).
        """
        vid_path = input['video_path']
        if 'text' in input.keys():
            text = input['text']
        else:
            text = ''
        # The configured positive prompt is appended to the user caption.
        caption = text + self.model.positive_prompt
        y = self.model.clip_encoder(caption).detach()
        max_frames = self.model.cfg.max_frames
        capture = cv2.VideoCapture(vid_path)
        _fps = capture.get(cv2.CAP_PROP_FPS)
        sample_fps = _fps
        _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)
        # sample_fps == _fps, so the stride is 1: every frame is kept.
        stride = round(_fps / sample_fps)
        start_frame = 0
        pointer = 0
        frame_list = []
        while len(frame_list) < max_frames:
            ret, frame = capture.read()
            pointer += 1
            if (not ret) or (frame is None):
                break
            if pointer < start_frame:
                continue
            if pointer >= _total_frame_num + 1:
                break
            if (pointer - start_frame) % stride == 0:
                frame = LoadImage.convert_to_img(frame)
                frame_list.append(frame)
        capture.release()
        video_data = self.model.vid_trans(frame_list)
        return {'video_data': video_data, 'y': y}

    def forward(self, input: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        """Run the super-resolution model on the preprocessed clip."""
        video = self.model(input)
        return {'video': video}

    def postprocess(self, inputs: Dict[str, Any],
                    **post_params) -> Dict[str, Any]:
        """Dump frames to disk and encode them into an mp4 with ffmpeg.

        Returns the raw video bytes when no 'output_video' path was given
        in post_params, otherwise the path that was written.
        """
        video = tensor2vid(inputs['video'], self.model.cfg.mean,
                           self.model.cfg.std)
        output_video_path = post_params.get('output_video', None)
        temp_video_file = False
        if output_video_path is None:
            output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
            temp_video_file = True
        temp_dir = tempfile.mkdtemp()
        for fid, frame in enumerate(video):
            tpth = os.path.join(temp_dir, '%06d.png' % (fid + 1))
            # NOTE(review): IMWRITE_JPEG_QUALITY has no effect on .png
            # files; IMWRITE_PNG_COMPRESSION was probably intended — confirm.
            cv2.imwrite(tpth, frame[:, :, ::-1],
                        [int(cv2.IMWRITE_JPEG_QUALITY), 100])
        cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8.0 -i {temp_dir}/%06d.png \
-vcodec libx264 -crf 17 -pix_fmt yuv420p {output_video_path}'
        status = os.system(cmd)
        if status != 0:
            logger.info('Save Video Error with {}'.format(status))
        # NOTE(review): shutil.rmtree would be safer than shelling out.
        os.system(f'rm -rf {temp_dir}')
        if temp_video_file:
            video_file_content = b''
            with open(output_video_path, 'rb') as f:
                video_file_content = f.read()
            os.remove(output_video_path)
            return {OutputKeys.OUTPUT_VIDEO: video_file_content}
        else:
            return {OutputKeys.OUTPUT_VIDEO: output_video_path}
def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
    """Convert a normalized video tensor into a list of uint8 RGB frames.

    Args:
        video: CPU tensor of shape (batch, channel, frame, height, width),
            normalized with the given per-channel mean/std. Only the first
            batch element is converted.
        mean: per-channel mean used during normalization.
        std: per-channel std used during normalization.

    Returns:
        List of numpy uint8 arrays of shape (height, width, channel),
        one per frame.
    """
    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
    # De-normalize out-of-place: the previous mul_/add_ silently clobbered
    # the caller's tensor.
    video = video.mul(std).add(mean)
    video.clamp_(0, 1)
    video = video * 255.0
    # Equivalent to rearrange(video, 'b c f h w -> b f h w c')[0] without
    # needing einops.
    images = video.permute(0, 2, 3, 4, 1)[0]
    images = [(img.numpy()).astype('uint8') for img in images]
    return images

View File

@@ -45,6 +45,7 @@ if TYPE_CHECKING:
from .document_grounded_dialog_retrieval_pipeline import DocumentGroundedDialogRetrievalPipeline
from .document_grounded_dialog_rerank_pipeline import DocumentGroundedDialogRerankPipeline
from .language_identification_pipline import LanguageIdentificationPipeline
from .machine_reading_comprehension_pipeline import MachineReadingComprehensionForNERPipeline
else:
_import_structure = {
@@ -108,7 +109,10 @@ else:
'document_grounded_dialog_retrieval_pipeline': [
'DocumentGroundedDialogRetrievalPipeline'
],
'language_identification_pipline': ['LanguageIdentificationPipeline']
'language_identification_pipline': ['LanguageIdentificationPipeline'],
'machine_reading_comprehension_pipeline': [
'MachineReadingComprehensionForNERPipeline'
],
}
import sys

View File

@@ -0,0 +1,83 @@
from typing import Any, Dict, Union

import torch

from modelscope.metainfo import Pipelines, Preprocessors
from modelscope.models.base import Model
from modelscope.outputs import MachineReadingComprehensionOutput, OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Fields, ModelFile, Tasks
@PIPELINES.register_module(
    Tasks.machine_reading_comprehension,
    module_name=Pipelines.machine_reading_comprehension_for_ner)
class MachineReadingComprehensionForNERPipeline(Pipeline):
    '''
    Pipeline for Pretrained Machine Reader (PMR) finetuned on Named Entity Recognition (NER)

    Examples:
    >>> from modelscope.pipelines import pipeline
    >>> pipeline_ins = pipeline(
    >>>     task=Tasks.machine_reading_comprehension,
    >>>     model='damo/nlp_roberta_machine-reading-comprehension_for-ner')
    >>> pipeline_ins('Soccer - Japan get lucky win , China in surprise defeat .')
    >>> {'ORG': [], 'PER': [], 'LOC': [' Japan', ' China'], 'MISC': []}
    '''

    def __init__(self,
                 model: Union[Model, str],
                 preprocessor: Preprocessor = None,
                 config_file: str = None,
                 device: str = 'gpu',
                 auto_collate=True,
                 **kwargs):
        """Build the pipeline, loading a preprocessor when none is given.

        Args:
            model: a loaded Model instance or a model id / local directory.
            preprocessor: optional preprocessor; built from the model dir
                when omitted.
            config_file: optional path to a pipeline configuration file.
            device: device to run on.
            auto_collate: whether to automatically collate batch inputs.
        """
        super().__init__(
            model=model,
            preprocessor=preprocessor,
            config_file=config_file,
            device=device,
            auto_collate=auto_collate)
        assert isinstance(self.model, Model), \
            f'please check whether model config exists in {ModelFile.CONFIGURATION}'
        if preprocessor is None:
            self.preprocessor = Preprocessor.from_pretrained(
                self.model.model_dir, **kwargs)
        # One query (and hence one output slot) per entity label.
        self.labels = [label for label in self.preprocessor.label2query]
        self.model.eval()

    def forward(
        self, inputs, **forward_params
    ) -> Union[Dict[str, Any], MachineReadingComprehensionOutput]:
        """Run the model and keep only the span logits and input ids."""
        with torch.no_grad():
            outputs = self.model(**inputs)
            span_logits = outputs['span_logits']
        return MachineReadingComprehensionOutput(
            span_logits=span_logits,
            input_ids=inputs['input_ids'],
        )

    def postprocess(
        self, inputs: Union[Dict[str, Any], MachineReadingComprehensionOutput]
    ) -> Dict[str, Any]:
        """Decode positive span logits into per-label entity strings.

        Returns:
            dict mapping each entity label to a list of decoded spans.
        """
        # A span is predicted when its logit is positive (sigmoid > 0.5).
        span_preds = inputs['span_logits'] > 0
        extracted_indices = torch.nonzero(span_preds.long())
        result = {label: [] for label in self.labels}
        for index in extracted_indices:
            # index = (label row, start position, end position).
            label = self.labels[index[0]]
            start = index[1]
            end = index[2] + 1
            ids = inputs['input_ids'][index[0], start:end]
            entity = self.preprocessor.tokenizer.decode(ids)
            result[label].append(entity)
        return result

View File

@@ -46,7 +46,8 @@ if TYPE_CHECKING:
CanmtTranslationPreprocessor, DialogueClassificationUsePreprocessor,
SiameseUiePreprocessor, DocumentGroundedDialogGeneratePreprocessor,
DocumentGroundedDialogRetrievalPreprocessor,
DocumentGroundedDialogRerankPreprocessor)
DocumentGroundedDialogRerankPreprocessor,
MachineReadingComprehensionForNERPreprocessor)
from .video import ReadVideoData, MovieSceneSegmentationPreprocessor
else:
@@ -77,7 +78,8 @@ else:
'nlp': [
'DocumentSegmentationTransformersPreprocessor',
'FaqQuestionAnsweringTransformersPreprocessor',
'FillMaskPoNetPreprocessor', 'FillMaskTransformersPreprocessor',
'FillMaskPoNetPreprocessor',
'FillMaskTransformersPreprocessor',
'NLPTokenizerPreprocessorBase',
'TextRankingTransformersPreprocessor',
'RelationExtractionTransformersPreprocessor',
@@ -85,26 +87,34 @@ else:
'TextGenerationSentencePiecePreprocessor',
'TextClassificationTransformersPreprocessor',
'TokenClassificationTransformersPreprocessor',
'TextErrorCorrectionPreprocessor', 'WordAlignmentPreprocessor',
'TextGenerationTransformersPreprocessor', 'Tokenize',
'TextErrorCorrectionPreprocessor',
'WordAlignmentPreprocessor',
'TextGenerationTransformersPreprocessor',
'Tokenize',
'TextGenerationT5Preprocessor',
'WordSegmentationBlankSetToLabelPreprocessor',
'MGLMSummarizationPreprocessor', 'CodeGeeXPreprocessor',
'MGLMSummarizationPreprocessor',
'CodeGeeXPreprocessor',
'ZeroShotClassificationTransformersPreprocessor',
'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor',
'NERPreprocessorViet', 'NERPreprocessorThai',
'TextGenerationJiebaPreprocessor',
'SentencePiecePreprocessor',
'NERPreprocessorViet',
'NERPreprocessorThai',
'WordSegmentationPreprocessorThai',
'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
'DialogIntentPredictionPreprocessor',
'DialogModelingPreprocessor',
'DialogStateTrackingPreprocessor',
'ConversationalTextToSqlPreprocessor',
'TableQuestionAnsweringPreprocessor',
'TranslationEvaluationTransformersPreprocessor',
'CanmtTranslationPreprocessor',
'DialogueClassificationUsePreprocessor', 'SiameseUiePreprocessor',
'DialogueClassificationUsePreprocessor',
'SiameseUiePreprocessor',
'DialogueClassificationUsePreprocessor',
'DocumentGroundedDialogGeneratePreprocessor',
'DocumentGroundedDialogRetrievalPreprocessor',
'DocumentGroundedDialogRerankPreprocessor'
'DocumentGroundedDialogRerankPreprocessor',
'MachineReadingComprehensionForNERPreprocessor',
],
}

View File

@@ -36,6 +36,7 @@ if TYPE_CHECKING:
from .document_grounded_dialog_generate_preprocessor import DocumentGroundedDialogGeneratePreprocessor
from .document_grounded_dialog_retrieval_preprocessor import DocumentGroundedDialogRetrievalPreprocessor
from .document_grounded_dialog_rerank_preprocessor import DocumentGroundedDialogRerankPreprocessor
from .machine_reading_comprehension_preprocessor import MachineReadingComprehensionForNERPreprocessor
else:
_import_structure = {
'bert_seq_cls_tokenizer': ['Tokenize'],
@@ -102,7 +103,10 @@ else:
'document_grounded_dialog_retrieval_preprocessor':
['DocumentGroundedDialogRetrievalPreprocessor'],
'document_grounded_dialog_rerank_preprocessor':
['DocumentGroundedDialogRerankPreprocessor']
['DocumentGroundedDialogRerankPreprocessor'],
'machine_reading_comprehension_preprocessor': [
'MachineReadingComprehensionForNERPreprocessor'
],
}
import sys

View File

@@ -0,0 +1,252 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import TruncationStrategy
from transformers.utils import logging
from modelscope.metainfo import Preprocessors
from modelscope.outputs import OutputKeys
from modelscope.preprocessors import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.config import Config
from modelscope.utils.constant import ConfigFields, Fields, ModeKeys, ModelFile
# Module-level logger shared by this preprocessor module.
logger = logging.get_logger(__name__)
# Tokenizer families that insert an extra separator token between the two
# segments of a pair; `encode` adds one more special token for these when
# computing `sequence_added_tokens`.
MULTI_SEP_TOKENS_TOKENIZERS_SET = {'roberta', 'camembert', 'bart', 'mpnet'}
@PREPROCESSORS.register_module(
    Fields.nlp,
    module_name=Preprocessors.machine_reading_comprehension_for_ner)
class MachineReadingComprehensionForNERPreprocessor(Preprocessor):
    """Preprocessor for a Pretrained Machine Reader (PMR) finetuned on
    Named Entity Recognition (NER).

    Each input sentence is expanded into one MRC-style (query, context)
    example per entity label, tokenized, and collated so the model can
    extract spans for all labels in a single batch.
    """

    def __init__(self, model_dir, label2query=None):
        """Initialize tokenizer and label-to-query mapping.

        Args:
            model_dir: Model id or local model directory; used to load the
                tokenizer and, if needed, the configuration file.
            label2query: Optional mapping from entity label to its natural
                language description. When None, it is read from the
                ``preprocessor.label2query`` field of the model's
                configuration file.
        """
        super().__init__()
        # use_fast=False: the per-word tokenization in `encode` relies on
        # slow-tokenizer behavior (e.g. `add_prefix_space` handling).
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, use_fast=False)
        if label2query is None:
            config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
            config = Config.from_file(config_path)
            self.label2query = config[ConfigFields.preprocessor].label2query
        else:
            self.label2query = label2query

    def __call__(self, data: str):
        """Convert a raw sentence into batched model inputs.

        Args:
            data: The input sentence (whitespace-tokenized internally).

        Returns:
            A dict with ``input_ids``, ``attention_mask`` and
            ``token_type_ids`` tensors, one row per entity label in
            ``label2query``.
        """
        # Build one MRC example per entity label; the position lists are
        # empty because inference has no gold spans.
        all_data = []
        for label in self.label2query:
            all_data.append({
                'context': data,
                'end_position': [],
                'entity_label': label,
                'impossible': False,
                'qas_id': '',
                'query': self.label2query[label],
                'span_position': [],
                'start_position': []
            })
        all_data = self.prompt(all_data)
        # `sample` (not `data`) so the loop does not shadow the parameter.
        encoded = [self.encode(sample) for sample in all_data]
        batched = collate_to_max_length_roberta(encoded)
        return {
            'input_ids': batched[0],
            'attention_mask': batched[1],
            'token_type_ids': batched[2],
        }

    def prompt(self, all_data, var=0):
        """Rewrite each example's query into a natural-language prompt.

        Args:
            all_data: List of MRC example dicts as built by ``__call__``.
            var: Prompt template variant, one of {0, 1, 2}.

        Returns:
            New example dicts whose ``context`` is the list of
            whitespace-separated words and whose ``query`` embeds the
            entity label and its description.

        Raises:
            ValueError: If ``var`` is not a supported variant.
        """
        new_datas = []
        for data in all_data:
            label = data['entity_label']
            details = data['query']
            context = data['context']
            start_positions = data['start_position']
            end_positions = data['end_position']
            words = context.split()
            # Span arithmetic below assumes single-space separation, so
            # both split styles must yield the same word count.
            assert len(words) == len(context.split(' '))
            if var == 0:
                query = '"{}". {}'.format(label, details)  # ori
            elif var == 1:
                query = 'What are the "{}" entity, where {}'.format(
                    label, details)  # variant 1
            elif var == 2:
                query = 'Identify the spans (if any) related to "{}" entity. Details: {}'.format(
                    label, details)  # variant 2
            else:
                # Previously an unknown variant fell through and later
                # raised an opaque UnboundLocalError; fail fast instead.
                raise ValueError(
                    'Unsupported prompt variant: {} (expected 0, 1 or 2)'.
                    format(var))
            span_positions = {
                '{};{}'.format(start_positions[i], end_positions[i]):
                ' '.join(words[start_positions[i]:end_positions[i] + 1])
                for i in range(len(start_positions))
            }
            new_data = {
                'context': words,
                'end_position': end_positions,
                'entity_label': label,
                'impossible': data['impossible'],
                'qas_id': data['qas_id'],
                'query': query,
                'span_position': span_positions,
                'start_position': start_positions,
            }
            new_datas.append(new_data)
        return new_datas

    def encode(self, data, max_length=512, max_query_length=64):
        """Tokenize one (query, context) pair into model-ready tensors.

        Args:
            data: A prompted example dict produced by ``prompt``; its
                ``context`` is a list of words.
            max_length: Maximum total sequence length.
            max_query_length: Maximum number of query tokens kept.

        Returns:
            A list of tensors: input_ids, attention_mask, token_type_ids,
            label_mask (1 on positions eligible as span bounds) and a
            [seq_len, seq_len] match_labels matrix with 1 at each gold
            (start, end) pair.
        """
        tokenizer = self.tokenizer
        query = data['query']
        context = data['context']
        start_positions = data['start_position']
        end_positions = data['end_position']
        tokenizer_type = type(tokenizer).__name__.replace('Tokenizer',
                                                          '').lower()
        # Number of special tokens inserted before the context segment;
        # multi-SEP tokenizers (roberta-like) add one extra separator.
        sequence_added_tokens = (
            tokenizer.model_max_length - tokenizer.max_len_single_sentence
            + 1 if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET else
            tokenizer.model_max_length - tokenizer.max_len_single_sentence)
        # Word-by-word tokenization so word-level span indices can be
        # mapped onto subword positions.
        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(context):
            orig_to_tok_index.append(len(all_doc_tokens))
            if tokenizer.__class__.__name__ in [
                    'RobertaTokenizer',
                    'LongformerTokenizer',
                    'BartTokenizer',
                    'RobertaTokenizerFast',
                    'LongformerTokenizerFast',
                    'BartTokenizerFast',
            ]:
                sub_tokens = tokenizer.tokenize(token, add_prefix_space=True)
            elif tokenizer.__class__.__name__ in ['BertTokenizer']:
                sub_tokens = tokenizer.tokenize(token)
            elif tokenizer.__class__.__name__ in ['BertWordPieceTokenizer']:
                sub_tokens = tokenizer.encode(
                    token, add_special_tokens=False).tokens
            else:
                sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)
        tok_start_positions = [orig_to_tok_index[x] for x in start_positions]
        tok_end_positions = []
        for x in end_positions:
            if x < len(context) - 1:
                # End of word x = position just before word x+1 starts.
                tok_end_positions.append(orig_to_tok_index[x + 1] - 1)
            else:
                tok_end_positions.append(len(all_doc_tokens) - 1)
        truncation = TruncationStrategy.ONLY_SECOND.value
        padding_strategy = 'do_not_pad'
        truncated_query = tokenizer.encode(
            query,
            add_special_tokens=False,
            truncation=True,
            max_length=max_query_length)
        encoded_dict = tokenizer.encode_plus(  # TODO(thom) update this logic
            truncated_query,
            all_doc_tokens,
            truncation=truncation,
            padding=padding_strategy,
            max_length=max_length,
            return_overflowing_tokens=True,
            return_token_type_ids=True,
        )
        tokens = encoded_dict['input_ids']
        type_ids = encoded_dict['token_type_ids']
        attn_mask = encoded_dict['attention_mask']
        # find new start_positions/end_positions, considering
        # 1. we add query tokens at the beginning
        # 2. special tokens
        doc_offset = len(truncated_query) + sequence_added_tokens
        new_start_positions = [
            x + doc_offset for x in tok_start_positions
            if (x + doc_offset) < max_length - 1
        ]
        new_end_positions = [
            x + doc_offset if
            (x + doc_offset) < max_length - 1 else max_length - 2
            for x in tok_end_positions
        ]
        new_end_positions = new_end_positions[:len(new_start_positions)]
        # 1 on context subword positions, 0 on query/special positions.
        label_mask = [0] * doc_offset + [1] * (len(tokens) - doc_offset
                                               - 1) + [0]
        assert all(label_mask[p] != 0 for p in new_start_positions)
        assert all(label_mask[p] != 0 for p in new_end_positions)
        assert len(label_mask) == len(tokens)
        seq_len = len(tokens)
        match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long)
        for start, end in zip(new_start_positions, new_end_positions):
            if start >= seq_len or end >= seq_len:
                continue
            match_labels[start, end] = 1
        return [
            torch.LongTensor(tokens),
            torch.LongTensor(attn_mask),
            torch.LongTensor(type_ids),
            torch.LongTensor(label_mask),
            match_labels,
        ]
def collate_to_max_length_roberta(batch):
    """Pad a batch of encoded samples to this batch's maximum length.

    Adapted from https://github.com/ShannonAI/mrc-for-flat-nested-ner.

    Args:
        batch: List of samples; each is the field list produced by
            ``MachineReadingComprehensionForNERPreprocessor.encode``:
            input_ids, attention_mask, token_type_ids, label_mask and a
            square match_labels matrix.

    Returns:
        A list of five batched tensors. The first four have shape
        [batch, max_length]; the last is [batch, max_length, max_length].
    """
    num_samples = len(batch)
    longest = max(sample[0].shape[0] for sample in batch)
    output = []
    for field_idx in range(4):
        # input_ids (field 0) are padded with RoBERTa's pad id (1); the
        # mask/type fields are padded with 0.
        fill_value = 1 if field_idx == 0 else 0
        padded = torch.full([num_samples, longest],
                            fill_value,
                            dtype=batch[0][field_idx].dtype)
        for row, sample in enumerate(batch):
            field = sample[field_idx]
            padded[row][:field.shape[0]] = field
        output.append(padded)
    padded_matches = torch.zeros([num_samples, longest, longest],
                                 dtype=torch.long)
    for row, sample in enumerate(batch):
        matches = sample[4]
        padded_matches[row, :matches.shape[1], :matches.shape[1]] = matches
    output.append(padded_matches)
    return output

View File

@@ -37,24 +37,31 @@ from modelscope.utils.torch_utils import is_dist
class CustomCheckpointProcessor(CheckpointProcessor):
def __init__(self, modifier_token, modifier_token_id):
def __init__(self,
modifier_token,
modifier_token_id,
torch_type=torch.float32):
"""Checkpoint processor for custom diffusion.
Args:
modifier_token: The token to use as a modifier for the concept.
modifier_token_id: The modifier token id for the concept.
torch_type: The torch type, default is float32.
"""
self.modifier_token = modifier_token
self.modifier_token_id = modifier_token_id
self.torch_type = torch_type
def save_checkpoints(self,
trainer,
checkpoint_path_prefix,
output_dir,
meta=None):
meta=None,
save_optimizers=True):
"""Save the state dict for custom diffusion model.
"""
trainer.model.unet = trainer.model.unet.to(torch.float32)
trainer.model.unet = trainer.model.unet.to(self.torch_type)
trainer.model.unet.save_attn_procs(output_dir)
learned_embeds = trainer.model.text_encoder.get_input_embeddings(
@@ -281,6 +288,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
instance_prompt = kwargs.pop('instance_prompt', 'a photo of sks dog')
class_prompt = kwargs.pop('class_prompt', 'dog')
class_data_dir = kwargs.pop('class_data_dir', '/tmp/class_data')
self.torch_type = kwargs.pop('torch_type', torch.float32)
self.real_prior = kwargs.pop('real_prior', False)
self.num_class_images = kwargs.pop('num_class_images', 200)
self.resolution = kwargs.pop('resolution', 512)
@@ -387,7 +395,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
self.hooks))[0]
ckpt_hook.set_processor(
CustomCheckpointProcessor(self.modifier_token,
self.modifier_token_id))
self.modifier_token_id, self.torch_type))
# Add new Custom Diffusion weights to the attention layers
attention_class = CustomDiffusionAttnProcessor
@@ -477,7 +485,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
size=self.resolution,
mask_size=self.model.vae.encode(
torch.randn(1, 3, self.resolution,
self.resolution).to(dtype=torch.float32).to(
self.resolution).to(dtype=self.torch_type).to(
self.device)).latent_dist.sample().size()[-1],
center_crop=self.center_crop,
num_class_images=self.num_class_images,
@@ -534,8 +542,8 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
if cur_class_images < self.num_class_images:
pipeline = DiffusionPipeline.from_pretrained(
self.model_dir,
torch_dtype=torch.float32,
safety_checker=None,
torch_dtype=self.torch_type,
revision=None,
)
pipeline.set_progress_bar_config(disable=True)
@@ -656,7 +664,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
batch = next(self.iter_train_dataloader)
# Convert images to latent space
latents = self.model.vae.encode(batch['pixel_values'].to(
dtype=torch.float32).to(self.device)).latent_dist.sample()
dtype=self.torch_type).to(self.device)).latent_dist.sample()
latents = latents * self.model.vae.config.scaling_factor
# Sample noise that we'll add to the latents

View File

@@ -32,8 +32,16 @@ from modelscope.utils.torch_utils import is_dist
class DreamboothCheckpointProcessor(CheckpointProcessor):
def __init__(self, model_dir):
def __init__(self, model_dir, torch_type=torch.float32):
"""Checkpoint processor for dreambooth diffusion.
Args:
model_dir: The model id or local model dir.
torch_type: The torch type, default is float32.
"""
self.model_dir = model_dir
self.torch_type = torch_type
def save_checkpoints(self,
trainer,
@@ -49,6 +57,7 @@ class DreamboothCheckpointProcessor(CheckpointProcessor):
pipeline = DiffusionPipeline.from_pretrained(
self.model_dir,
unet=trainer.model.unet,
torch_type=self.torch_type,
**pipeline_args,
)
scheduler_args = {}
@@ -174,6 +183,7 @@ class DreamboothDiffusionTrainer(EpochBasedTrainer):
prior_loss_weight: the weight of the prior loss.
"""
self.torch_type = kwargs.pop('torch_type', torch.float32)
self.with_prior_preservation = kwargs.pop('with_prior_preservation',
False)
self.instance_prompt = kwargs.pop('instance_prompt',
@@ -219,7 +229,7 @@ class DreamboothDiffusionTrainer(EpochBasedTrainer):
warnings.warn('Multiple GPU inference not yet supported.')
pipeline = DiffusionPipeline.from_pretrained(
self.model_dir,
torch_dtype=torch.float32,
torch_dtype=self.torch_type,
safety_checker=None,
revision=None,
)
@@ -309,7 +319,8 @@ class DreamboothDiffusionTrainer(EpochBasedTrainer):
input_ids = batch['input_ids'].to(self.device)
with torch.no_grad():
latents = self.model.vae.encode(
target_prior.to(dtype=torch.float32)).latent_dist.sample()
target_prior.to(
dtype=self.torch_type)).latent_dist.sample()
latents = latents * self.model.vae.config.scaling_factor
# Sample noise that we'll add to the latents

View File

@@ -17,6 +17,15 @@ from modelscope.utils.config import ConfigDict
class LoraDiffusionCheckpointProcessor(CheckpointProcessor):
def __init__(self, torch_type=torch.float32):
"""Checkpoint processor for lora diffusion.
Args:
torch_type: The torch type, default is float32.
"""
self.torch_type = torch_type
def save_checkpoints(self,
trainer,
checkpoint_path_prefix,
@@ -25,7 +34,7 @@ class LoraDiffusionCheckpointProcessor(CheckpointProcessor):
save_optimizers=True):
"""Save the state dict for lora tune model.
"""
trainer.model.unet = trainer.model.unet.to(torch.float32)
trainer.model.unet = trainer.model.unet.to(self.torch_type)
trainer.model.unet.save_attn_procs(output_dir)
@@ -38,15 +47,18 @@ class LoraDiffusionTrainer(EpochBasedTrainer):
Args:
lora_rank: The rank size of lora intermediate linear.
torch_type: The torch type, default is float32.
"""
lora_rank = kwargs.pop('lora_rank', 4)
torch_type = kwargs.pop('torch_type', torch.float32)
# set lora save checkpoint processor
ckpt_hook = list(
filter(lambda hook: isinstance(hook, CheckpointHook),
self.hooks))[0]
ckpt_hook.set_processor(LoraDiffusionCheckpointProcessor())
ckpt_hook.set_processor(
LoraDiffusionCheckpointProcessor(torch_type=torch_type))
# Set correct lora layers
lora_attn_procs = {}
for name in self.model.unet.attn_processors.keys():

View File

@@ -213,6 +213,7 @@ class NLPTasks(object):
document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval'
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
machine_reading_comprehension = 'machine-reading-comprehension'
class AudioTasks(object):
@@ -255,6 +256,8 @@ class MultiModalTasks(object):
text_to_video_synthesis = 'text-to-video-synthesis'
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
multimodal_dialogue = 'multimodal-dialogue'
image_to_video = 'image-to-video'
video_to_video = 'video-to-video'
class ScienceTasks(object):

View File

@@ -13,39 +13,45 @@ class EfficientDiffusionTuningTest(unittest.TestCase):
def setUp(self) -> None:
self.task = Tasks.efficient_diffusion_tuning
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_lora_run_pipeline(self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-lora'
model_revision = 'v1.0.2'
inputs = {'prompt': 'pale golden rod circle with old lace background'}
edt_pipeline = pipeline(self.task, model_id)
edt_pipeline = pipeline(
self.task, model_id, model_revision=model_revision)
result = edt_pipeline(inputs)
print(f'Efficient-diffusion-tuning-lora output: {result}.')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_lora_load_model_from_pretrained(self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-lora'
model = Model.from_pretrained(model_id)
model_revision = 'v1.0.2'
model = Model.from_pretrained(model_id, model_revision=model_revision)
self.assertTrue(model.__class__ == EfficientStableDiffusion)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_control_lora_run_pipeline(self):
# TODO: to be fixed in the future
model_id = 'damo/multi-modal_efficient-diffusion-tuning-control-lora'
model_revision = 'v1.0.2'
inputs = {
'prompt':
'pale golden rod circle with old lace background',
'cond':
'data/test/images/efficient_diffusion_tuning_sd_control_lora_source.png'
}
edt_pipeline = pipeline(self.task, model_id)
edt_pipeline = pipeline(
self.task, model_id, model_revision=model_revision)
result = edt_pipeline(inputs)
print(f'Efficient-diffusion-tuning-control-lora output: {result}.')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_control_lora_load_model_from_pretrained(
self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-control-lora'
model = Model.from_pretrained(model_id)
model_revision = 'v1.0.2'
model = Model.from_pretrained(model_id, model_revision=model_revision)
self.assertTrue(model.__class__ == EfficientStableDiffusion)

View File

@@ -16,14 +16,16 @@ class EfficientDiffusionTuningTestSwift(unittest.TestCase):
def setUp(self) -> None:
self.task = Tasks.efficient_diffusion_tuning
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_swift_lora_run_pipeline(self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-lora'
model_revision = 'v1.0.2'
inputs = {
'prompt':
'a street scene with a cafe and a restaurant sign in anime style'
}
sd_tuner_pipeline = pipeline(self.task, model_id)
sd_tuner_pipeline = pipeline(
self.task, model_id, model_revision=model_revision)
result = sd_tuner_pipeline(inputs, generator_seed=0)
output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name
cv2.imwrite(output_image_path, result['output_imgs'][0])
@@ -31,21 +33,24 @@ class EfficientDiffusionTuningTestSwift(unittest.TestCase):
f'Efficient-diffusion-tuning-swift-lora output: {output_image_path}'
)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_swift_lora_load_model_from_pretrained(
self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-lora'
model = Model.from_pretrained(model_id)
model_revision = 'v1.0.2'
model = Model.from_pretrained(model_id, model_revision=model_revision)
self.assertTrue(model.__class__ == EfficientStableDiffusion)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_swift_adapter_run_pipeline(self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-adapter'
model_revision = 'v1.0.2'
inputs = {
'prompt':
'a street scene with a cafe and a restaurant sign in anime style'
}
sd_tuner_pipeline = pipeline(self.task, model_id)
sd_tuner_pipeline = pipeline(
self.task, model_id, model_revision=model_revision)
result = sd_tuner_pipeline(inputs, generator_seed=0)
output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name
cv2.imwrite(output_image_path, result['output_imgs'][0])
@@ -53,21 +58,24 @@ class EfficientDiffusionTuningTestSwift(unittest.TestCase):
f'Efficient-diffusion-tuning-swift-adapter output: {output_image_path}'
)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_swift_adapter_load_model_from_pretrained(
self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-adapter'
model = Model.from_pretrained(model_id)
model_revision = 'v1.0.2'
model = Model.from_pretrained(model_id, model_revision=model_revision)
self.assertTrue(model.__class__ == EfficientStableDiffusion)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_swift_prompt_run_pipeline(self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-prompt'
model_revision = 'v1.0.2'
inputs = {
'prompt':
'a street scene with a cafe and a restaurant sign in anime style'
}
sd_tuner_pipeline = pipeline(self.task, model_id)
sd_tuner_pipeline = pipeline(
self.task, model_id, model_revision=model_revision)
result = sd_tuner_pipeline(inputs, generator_seed=0)
output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name
cv2.imwrite(output_image_path, result['output_imgs'][0])
@@ -75,11 +83,12 @@ class EfficientDiffusionTuningTestSwift(unittest.TestCase):
f'Efficient-diffusion-tuning-swift-prompt output: {output_image_path}'
)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_swift_prompt_load_model_from_pretrained(
self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-prompt'
model = Model.from_pretrained(model_id)
model_revision = 'v1.0.2'
model = Model.from_pretrained(model_id, model_revision=model_revision)
self.assertTrue(model.__class__ == EfficientStableDiffusion)

View File

@@ -0,0 +1,28 @@
import sys
import unittest
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import DownloadMode, Tasks
from modelscope.utils.test_utils import test_level
class Image2VideoTest(unittest.TestCase):
    """Smoke test for the image-to-video generation pipeline."""

    def setUp(self) -> None:
        self.task = Tasks.image_to_video
        self.model_id = 'damo/Image-to-Video'
        self.path = 'https://video-generation-wulanchabu.oss-cn-wulanchabu.aliyuncs.com/baishao/test.jpeg'

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        # Build the pipeline from the hub model and synthesize a video
        # from the remote test image.
        video_pipeline = pipeline(task=self.task, model=self.model_id)
        result = video_pipeline(self.path, output_video='./output.mp4')
        print(result[OutputKeys.OUTPUT_VIDEO])
# Allow running this test module directly.
if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,59 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import ModelForMachineReadingComprehension
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import MachineReadingComprehensionForNERPipeline
from modelscope.preprocessors import \
MachineReadingComprehensionForNERPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class MachineReadingComprehensionTest(unittest.TestCase):
    """Tests for the machine-reading-comprehension (NER) pipeline."""

    sentence = 'Soccer - Japan get lucky win , China in surprise defeat .'
    model_id = 'damo/nlp_roberta_machine-reading-comprehension_for-ner'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_mrc_for_ner_by_direct_model_download(self):
        # Download the checkpoint and wire model + preprocessor by hand,
        # then compare against the generic pipeline factory.
        cache_path = snapshot_download(self.model_id)
        preprocessor = MachineReadingComprehensionForNERPreprocessor(
            cache_path)
        model = ModelForMachineReadingComprehension.from_pretrained(
            cache_path)
        pipeline1 = MachineReadingComprehensionForNERPipeline(
            model, preprocessor=preprocessor)
        pipeline2 = pipeline(
            Tasks.machine_reading_comprehension,
            model=model,
            preprocessor=preprocessor)
        print(f'sentence: {self.sentence}\n'
              f'pipeline1:{pipeline1(input=self.sentence)}')
        print()
        print(f'pipeline2: {pipeline2(input=self.sentence)}')
        # {'ORG': [], 'PER': [], 'LOC': [' Japan', ' China'], 'MISC': []}

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_mrc_for_ner_with_model_from_modelhub(self):
        # Load the model through the hub API, reusing its local cache for
        # the preprocessor.
        model = Model.from_pretrained(self.model_id)
        preprocessor = MachineReadingComprehensionForNERPreprocessor(
            model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.machine_reading_comprehension,
            model=model,
            preprocessor=preprocessor)
        print(f'sentence: {self.sentence}\n'
              f'pipeline:{pipeline_ins(input=self.sentence)}')
        # {'ORG': [], 'PER': [], 'LOC': [' Japan', ' China'], 'MISC': []}

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_mrc_for_ner_with_model_name(self):
        # Fully implicit construction: the pipeline resolves model and
        # preprocessor from the model id alone.
        pipeline_ins = pipeline(
            task=Tasks.machine_reading_comprehension, model=self.model_id)
        print(pipeline_ins(input=self.sentence))
        # {'ORG': [], 'PER': [], 'LOC': [' Japan', ' China'], 'MISC': []}
# Allow running this test module directly.
if __name__ == '__main__':
    unittest.main()

View File

@@ -52,7 +52,7 @@ class Text2360PanoramaImageTest(unittest.TestCase):
print(
'pipeline: the output image path is {}'.format(output_image_path))
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name
pipeline_ins = pipeline(

View File

@@ -0,0 +1,32 @@
import sys
import unittest
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class Video2VideoTest(unittest.TestCase):
    """Smoke test for the video-to-video generation pipeline."""

    def setUp(self) -> None:
        self.task = Tasks.video_to_video
        self.model_id = 'damo/Video-to-Video'
        self.path = 'https://video-generation-wulanchabu.oss-cn-wulanchabu.aliyuncs.com/baishao/test.mp4'

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        # Build the pipeline and rewrite the remote test clip guided by
        # the text prompt.
        pipe = pipeline(task=self.task, model=self.model_id)
        request = {
            'video_path': self.path,
            'text': 'A panda is surfing on the sea'
        }
        result = pipe(request, output_video='./output.mp4')
        print(result[OutputKeys.OUTPUT_VIDEO])
# Allow running this test module directly.
if __name__ == '__main__':
    unittest.main()

View File

@@ -125,16 +125,21 @@ def get_current_branch():
def get_modified_files():
cmd = ['git', 'diff', '--name-only', 'origin/master...']
cmd_output = run_command_get_output(cmd)
logger.info('Modified files: ')
logger.info(cmd_output)
if 'PR_CHANGED_FILES' in os.environ and os.environ[
'PR_CHANGED_FILES'] != '':
logger.info('Getting PR modified files.')
# get modify file from environment
diff_files = os.environ['PR_CHANGED_FILES'].replace('#', '\n')
else:
cmd = ['git', 'diff', '--name-only', 'origin/master...']
diff_files = run_command_get_output(cmd)
logger.info('Diff files: ')
logger.info(diff_files)
modified_files = []
# remove the deleted file.
for diff_file in cmd_output.splitlines():
if os.path.exists(diff_file):
modified_files.append(diff_file)
for diff_file in diff_files.splitlines():
if os.path.exists(diff_file.strip()):
modified_files.append(diff_file.strip())
return modified_files

View File

@@ -42,6 +42,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_efficient_diffusion_tuning_lora_train(self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-lora'
model_revision = 'v1.0.2'
def cfg_modify_fn(cfg):
cfg.train.max_epochs = self.max_epochs
@@ -51,6 +52,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
kwargs = dict(
model=model_id,
model_revision=model_revision,
work_dir=self.tmp_dir,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
@@ -70,6 +72,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_efficient_diffusion_tuning_lora_eval(self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-lora'
model_revision = 'v1.0.2'
def cfg_modify_fn(cfg):
cfg.model.inference = False
@@ -77,6 +80,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
kwargs = dict(
model=model_id,
model_revision=model_revision,
work_dir=self.tmp_dir,
train_dataset=None,
eval_dataset=self.eval_dataset,
@@ -90,6 +94,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_efficient_diffusion_tuning_control_lora_train(self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-control-lora'
model_revision = 'v1.0.2'
def cfg_modify_fn(cfg):
cfg.train.max_epochs = self.max_epochs
@@ -99,6 +104,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
kwargs = dict(
model=model_id,
model_revision=model_revision,
work_dir=self.tmp_dir,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
@@ -119,6 +125,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_efficient_diffusion_tuning_control_lora_eval(self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-control-lora'
model_revision = 'v1.0.2'
def cfg_modify_fn(cfg):
cfg.model.inference = False
@@ -126,6 +133,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
kwargs = dict(
model=model_id,
model_revision=model_revision,
work_dir=self.tmp_dir,
train_dataset=None,
eval_dataset=self.eval_dataset,

View File

@@ -33,9 +33,10 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
shutil.rmtree(self.tmp_dir)
super().tearDown()
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_swift_lora_train(self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-lora'
model_revision = 'v1.0.2'
def cfg_modify_fn(cfg):
cfg.train.max_epochs = self.max_epochs
@@ -47,6 +48,7 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
kwargs = dict(
model=model_id,
model_revision=model_revision,
work_dir=self.tmp_dir,
train_dataset=self.train_dataset,
cfg_modify_fn=cfg_modify_fn)
@@ -60,9 +62,10 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
self.assertIn(f'epoch_{self.max_epochs}.pth', results_files)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_swift_adapter_train(self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-adapter'
model_revision = 'v1.0.2'
def cfg_modify_fn(cfg):
cfg.train.max_epochs = self.max_epochs
@@ -74,6 +77,7 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
kwargs = dict(
model=model_id,
model_revision=model_revision,
work_dir=self.tmp_dir,
train_dataset=self.train_dataset,
cfg_modify_fn=cfg_modify_fn)
@@ -87,9 +91,10 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
self.assertIn(f'epoch_{self.max_epochs}.pth', results_files)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_efficient_diffusion_tuning_swift_prompt_train(self):
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-prompt'
model_revision = 'v1.0.2'
def cfg_modify_fn(cfg):
cfg.train.max_epochs = self.max_epochs
@@ -101,6 +106,7 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
kwargs = dict(
model=model_id,
model_revision=model_revision,
work_dir=self.tmp_dir,
train_dataset=self.train_dataset,
cfg_modify_fn=cfg_modify_fn)