mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
merge master-github
This commit is contained in:
@@ -4,11 +4,15 @@ CODE_DIR=$PWD
|
||||
CODE_DIR_IN_CONTAINER=/Maas-lib
|
||||
echo "$USER"
|
||||
gpus='0,1 2,3 4,5 6,7'
|
||||
cpu_sets='45-58 31-44 16-30 0-15'
|
||||
cpu_sets='0-15 16-31 32-47 48-63'
|
||||
cpu_sets_arr=($cpu_sets)
|
||||
is_get_file_lock=false
|
||||
CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh python tests/run.py --parallel 2 --run_config tests/run_config.yaml}
|
||||
echo "ci command: $CI_COMMAND"
|
||||
PR_CHANGED_FILES="${PR_CHANGED_FILES:-''}"
|
||||
echo "PR modified files: $PR_CHANGED_FILES"
|
||||
PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#}
|
||||
echo "PR_CHANGED_FILES: $PR_CHANGED_FILES"
|
||||
idx=0
|
||||
for gpu in $gpus
|
||||
do
|
||||
@@ -42,6 +46,7 @@ do
|
||||
-e MODELSCOPE_ENVIRONMENT='ci' \
|
||||
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
|
||||
-e MODEL_TAG_URL=$MODEL_TAG_URL \
|
||||
-e PR_CHANGED_FILES=$PR_CHANGED_FILES \
|
||||
--workdir=$CODE_DIR_IN_CONTAINER \
|
||||
${IMAGE_NAME}:${IMAGE_VERSION} \
|
||||
$CI_COMMAND
|
||||
@@ -64,6 +69,7 @@ do
|
||||
-e MODELSCOPE_ENVIRONMENT='ci' \
|
||||
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
|
||||
-e MODEL_TAG_URL=$MODEL_TAG_URL \
|
||||
-e PR_CHANGED_FILES=$PR_CHANGED_FILES \
|
||||
--workdir=$CODE_DIR_IN_CONTAINER \
|
||||
${IMAGE_NAME}:${IMAGE_VERSION} \
|
||||
$CI_COMMAND
|
||||
|
||||
13
.github/workflows/citest.yaml
vendored
13
.github/workflows/citest.yaml
vendored
@@ -39,7 +39,7 @@ concurrency:
|
||||
jobs:
|
||||
unittest:
|
||||
# The type of runner that the job will run on
|
||||
runs-on: [modelscope-self-hosted]
|
||||
runs-on: [modelscope-self-hosted-us]
|
||||
timeout-minutes: 240
|
||||
steps:
|
||||
- name: ResetFileMode
|
||||
@@ -52,10 +52,19 @@ jobs:
|
||||
sudo chown -R $USER:$USER $ACTION_RUNNER_DIR
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
lfs: 'true'
|
||||
submodules: 'true'
|
||||
fetch-depth: ${{ github.event_name == 'pull_request' && 2 || 0 }}
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
run: |
|
||||
if ${{ github.event_name == 'pull_request' }}; then
|
||||
echo "PR_CHANGED_FILES=$(git diff --name-only -r HEAD^1 HEAD | xargs)" >> $GITHUB_ENV
|
||||
else
|
||||
echo "PR_CHANGED_FILES=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | xargs)" >> $GITHUB_ENV
|
||||
fi
|
||||
- name: Checkout LFS objects
|
||||
run: git lfs checkout
|
||||
- name: Run unittest
|
||||
|
||||
2
.github/workflows/daily_regression.yaml
vendored
2
.github/workflows/daily_regression.yaml
vendored
@@ -12,7 +12,7 @@ concurrency:
|
||||
jobs:
|
||||
unittest:
|
||||
# The type of runner that the job will run on
|
||||
runs-on: [modelscope-self-hosted]
|
||||
runs-on: [modelscope-self-hosted-us]
|
||||
steps:
|
||||
- name: ResetFileMode
|
||||
shell: bash
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
@@ -95,6 +96,12 @@ class StableDiffusionCustomArguments(TrainingArgs):
|
||||
'help': 'Path to json containing multiple concepts.',
|
||||
})
|
||||
|
||||
torch_type: str = field(
|
||||
default='float32',
|
||||
metadata={
|
||||
'help': ' The torch type, default is float32.',
|
||||
})
|
||||
|
||||
|
||||
training_args = StableDiffusionCustomArguments(
|
||||
task='text-to-image-synthesis').parse_cli()
|
||||
@@ -148,6 +155,8 @@ kwargs = dict(
|
||||
work_dir=training_args.work_dir,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=validation_dataset,
|
||||
torch_type=torch.float16
|
||||
if args.torch_type == 'float16' else torch.float32,
|
||||
cfg_modify_fn=cfg_modify_fn)
|
||||
|
||||
# build trainer and training
|
||||
@@ -159,7 +168,7 @@ pipe = pipeline(
|
||||
task=Tasks.text_to_image_synthesis,
|
||||
model=training_args.model,
|
||||
custom_dir=training_args.work_dir + '/output',
|
||||
modifier_token='<new1>+<new2>',
|
||||
modifier_token=args.modifier_token,
|
||||
model_revision=args.model_revision)
|
||||
|
||||
output = pipe({'text': args.instance_prompt})
|
||||
|
||||
@@ -7,11 +7,12 @@ PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/custom/finetune_stable_d
|
||||
--class_data_dir './tmp/class_data' \
|
||||
--train_dataset_name 'buptwq/lora-stable-diffusion-finetune-dog' \
|
||||
--max_epochs 250 \
|
||||
--modifier_token "<new1>+<new2>" \
|
||||
--modifier_token "<new1>" \
|
||||
--num_class_images=200 \
|
||||
--save_ckpt_strategy 'by_epoch' \
|
||||
--logging_interval 1 \
|
||||
--train.dataloader.workers_per_gpu 0 \
|
||||
--evaluation.dataloader.workers_per_gpu 0 \
|
||||
--train.optimizer.lr 1e-5 \
|
||||
--torch_type 'float32' \
|
||||
--use_model_config true
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
@@ -59,6 +60,12 @@ class StableDiffusionDreamboothArguments(TrainingArgs):
|
||||
'help': 'The pipeline prompt.',
|
||||
})
|
||||
|
||||
torch_type: str = field(
|
||||
default='float32',
|
||||
metadata={
|
||||
'help': ' The torch type, default is float32.',
|
||||
})
|
||||
|
||||
|
||||
training_args = StableDiffusionDreamboothArguments(
|
||||
task='text-to-image-synthesis').parse_cli()
|
||||
@@ -106,6 +113,8 @@ kwargs = dict(
|
||||
resolution=args.resolution,
|
||||
prior_loss_weight=args.prior_loss_weight,
|
||||
prompt=args.prompt,
|
||||
torch_type=torch.float16
|
||||
if args.torch_type == 'float16' else torch.float32,
|
||||
cfg_modify_fn=cfg_modify_fn)
|
||||
|
||||
# build trainer and training
|
||||
|
||||
@@ -17,4 +17,5 @@ PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/dreambooth/finetune_stab
|
||||
--train.dataloader.workers_per_gpu 0 \
|
||||
--evaluation.dataloader.workers_per_gpu 0 \
|
||||
--train.optimizer.lr 5e-6 \
|
||||
--torch_type 'float32' \
|
||||
--use_model_config true
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
@@ -25,6 +26,12 @@ class StableDiffusionLoraArguments(TrainingArgs):
|
||||
'help': 'The rank size of lora intermediate linear.',
|
||||
})
|
||||
|
||||
torch_type: str = field(
|
||||
default='float32',
|
||||
metadata={
|
||||
'help': ' The torch type, default is float32.',
|
||||
})
|
||||
|
||||
|
||||
training_args = StableDiffusionLoraArguments(
|
||||
task='text-to-image-synthesis').parse_cli()
|
||||
@@ -66,6 +73,8 @@ kwargs = dict(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=validation_dataset,
|
||||
lora_rank=args.lora_rank,
|
||||
torch_type=torch.float16
|
||||
if args.torch_type == 'float16' else torch.float32,
|
||||
cfg_modify_fn=cfg_modify_fn)
|
||||
|
||||
# build trainer and training
|
||||
|
||||
@@ -5,10 +5,11 @@ PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/lora/finetune_stable_dif
|
||||
--work_dir './tmp/lora_diffusion' \
|
||||
--train_dataset_name 'buptwq/lora-stable-diffusion-finetune' \
|
||||
--max_epochs 100 \
|
||||
--lora_rank 4 \
|
||||
--lora_rank 16 \
|
||||
--save_ckpt_strategy 'by_epoch' \
|
||||
--logging_interval 1 \
|
||||
--train.dataloader.workers_per_gpu 0 \
|
||||
--evaluation.dataloader.workers_per_gpu 0 \
|
||||
--train.optimizer.lr 1e-4 \
|
||||
--torch_type 'float16' \
|
||||
--use_model_config true
|
||||
|
||||
@@ -220,6 +220,8 @@ class Models(object):
|
||||
stable_diffusion = 'stable-diffusion'
|
||||
videocomposer = 'videocomposer'
|
||||
text_to_360panorama_image = 'text-to-360panorama-image'
|
||||
image_to_video_model = 'image-to-video-model'
|
||||
video_to_video_model = 'video-to-video-model'
|
||||
|
||||
# science models
|
||||
unifold = 'unifold'
|
||||
@@ -235,6 +237,7 @@ class TaskModels(object):
|
||||
feature_extraction = 'feature-extraction'
|
||||
text_generation = 'text-generation'
|
||||
text_ranking = 'text-ranking'
|
||||
machine_reading_comprehension = 'machine-reading-comprehension'
|
||||
|
||||
|
||||
class Heads(object):
|
||||
@@ -487,6 +490,7 @@ class Pipelines(object):
|
||||
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
|
||||
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
|
||||
language_identification = 'language_identification'
|
||||
machine_reading_comprehension_for_ner = 'machine-reading-comprehension-for-ner'
|
||||
|
||||
# audio tasks
|
||||
sambert_hifigan_tts = 'sambert-hifigan-tts'
|
||||
@@ -543,6 +547,8 @@ class Pipelines(object):
|
||||
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
|
||||
multimodal_dialogue = 'multimodal-dialogue'
|
||||
llama2_text_generation_pipeline = 'llama2-text-generation-pipeline'
|
||||
image_to_video_task_pipeline = 'image-to-video-task-pipeline'
|
||||
video_to_video_pipeline = 'video-to-video-pipeline'
|
||||
|
||||
# science tasks
|
||||
protein_structure = 'unifold-protein-structure'
|
||||
@@ -1062,6 +1068,7 @@ class Preprocessors(object):
|
||||
document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval'
|
||||
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
|
||||
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
|
||||
machine_reading_comprehension_for_ner = 'machine-reading-comprehension-for-ner'
|
||||
|
||||
# audio preprocessor
|
||||
linear_aec_fbank = 'linear-aec-fbank'
|
||||
|
||||
55
modelscope/models/cv/image_face_fusion/facegan/face_gan.py
Normal file
55
modelscope/models/cv/image_face_fusion/facegan/face_gan.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .gpen_model import FullGenerator
|
||||
|
||||
|
||||
class GPEN(object):
|
||||
|
||||
def __init__(self,
|
||||
model_path,
|
||||
size=512,
|
||||
channel_multiplier=2,
|
||||
device=torch.device('cpu')):
|
||||
self.mfile = model_path
|
||||
self.n_mlp = 8
|
||||
self.resolution = size
|
||||
self.device = device
|
||||
self.load_model(channel_multiplier)
|
||||
|
||||
def load_model(self, channel_multiplier=2):
|
||||
self.model = FullGenerator(self.resolution, 512, self.n_mlp,
|
||||
channel_multiplier).to(self.device)
|
||||
pretrained_dict = torch.load(self.mfile)
|
||||
self.model.load_state_dict(pretrained_dict)
|
||||
self.model.eval()
|
||||
|
||||
def process(self, im):
|
||||
preds = []
|
||||
imt = self.img2tensor(im)
|
||||
imt = F.interpolate(imt, (self.resolution, self.resolution))
|
||||
|
||||
with torch.no_grad():
|
||||
img_out, __ = self.model(imt)
|
||||
|
||||
face = self.tensor2img(img_out)
|
||||
|
||||
return face, preds
|
||||
|
||||
def img2tensor(self, img):
|
||||
img_t = torch.from_numpy(img).to(self.device)
|
||||
img_t = (img_t / 255. - 0.5) / 0.5
|
||||
img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) # BGR->RGB
|
||||
return img_t
|
||||
|
||||
def tensor2img(self, image_tensor, pmax=255.0, imtype=np.uint8):
|
||||
image_tensor = image_tensor * 0.5 + 0.5
|
||||
image_tensor = image_tensor.squeeze(0).permute(1, 2,
|
||||
0).flip(2) # RGB->BGR
|
||||
image_numpy = np.clip(image_tensor.float().cpu().numpy(), 0, 1) * pmax
|
||||
|
||||
return image_numpy.astype(imtype)
|
||||
@@ -1,93 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from .model import FullGenerator
|
||||
|
||||
|
||||
class GANWrap(object):
|
||||
|
||||
def __init__(self,
|
||||
model_path,
|
||||
size=256,
|
||||
channel_multiplier=1,
|
||||
device='cpu'):
|
||||
self.device = device
|
||||
self.mfile = model_path
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
|
||||
inplace=True),
|
||||
])
|
||||
self.batchSize = 2
|
||||
self.n_mlp = 8
|
||||
self.resolution = size
|
||||
self.load_model(channel_multiplier)
|
||||
|
||||
def load_model(self, channel_multiplier=2):
|
||||
self.model = FullGenerator(self.resolution, 512, self.n_mlp,
|
||||
channel_multiplier).to(self.device)
|
||||
pretrained_dict = torch.load(
|
||||
self.mfile, map_location=torch.device('cpu'))
|
||||
self.model.load_state_dict(pretrained_dict)
|
||||
self.model.eval()
|
||||
|
||||
def process_tensor(self, img_t, return_face=True):
|
||||
b, c, h, w = img_t.shape
|
||||
img_t = F.interpolate(img_t, (self.resolution, self.resolution))
|
||||
|
||||
with torch.no_grad():
|
||||
out, __ = self.model(img_t)
|
||||
|
||||
out = F.interpolate(out, (w, h))
|
||||
return out
|
||||
|
||||
def process(self, ims, return_face=True):
|
||||
res = []
|
||||
faces = []
|
||||
for i in range(0, len(ims), self.batchSize):
|
||||
sizes = []
|
||||
imt = None
|
||||
for im in ims[i:i + self.batchSize]:
|
||||
sizes.append(im.shape[0])
|
||||
im = cv2.resize(im, (self.resolution, self.resolution))
|
||||
im_pil = Image.fromarray(im)
|
||||
imt = self.img2tensor(im_pil) if imt is None else torch.cat(
|
||||
(imt, self.img2tensor(im_pil)), dim=0)
|
||||
|
||||
imt = torch.flip(imt, [1])
|
||||
with torch.no_grad():
|
||||
img_outs, __ = self.model(imt)
|
||||
|
||||
for sz, img_out in zip(sizes, img_outs):
|
||||
img = self.tensor2img(img_out)
|
||||
if return_face:
|
||||
faces.append(img)
|
||||
img = cv2.resize(img, (sz, sz), interpolation=cv2.INTER_AREA)
|
||||
res.append(img)
|
||||
|
||||
return res, faces
|
||||
|
||||
def img2tensor(self, img):
|
||||
img_t = self.transform(img).to(self.device)
|
||||
img_t = torch.unsqueeze(img_t, 0)
|
||||
return img_t
|
||||
|
||||
def tensor2img(self, image_tensor, bytes=255.0, imtype=np.uint8):
|
||||
if image_tensor.dim() == 3:
|
||||
image_numpy = image_tensor.cpu().float().numpy()
|
||||
else:
|
||||
image_numpy = image_tensor[0].cpu().float().numpy()
|
||||
image_numpy = np.transpose(image_numpy, (1, 2, 0))
|
||||
image_numpy = image_numpy[:, :, ::-1]
|
||||
image_numpy = np.clip(
|
||||
image_numpy * np.asarray([0.5, 0.5, 0.5])
|
||||
+ np.asarray([0.5, 0.5, 0.5]), 0, 1)
|
||||
image_numpy = image_numpy * bytes
|
||||
return image_numpy.astype(imtype)
|
||||
@@ -1,17 +1,19 @@
|
||||
# The implementation is adopted from stylegan2-pytorch,
|
||||
# made public available under the MIT License at https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
|
||||
# The implementation is adopted from InsightFace_Pytorch, made publicly available under the MIT License
|
||||
# at https://github.com/yangxy/GPEN
|
||||
import functools
|
||||
import itertools
|
||||
import math
|
||||
import operator
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
|
||||
|
||||
isconcat = True
|
||||
sss = 2 if isconcat else 1
|
||||
ratio = 2
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
@@ -306,6 +308,7 @@ class NoiseInjection(nn.Module):
|
||||
def forward(self, image, noise=None):
|
||||
|
||||
if noise is not None:
|
||||
# print(image.shape, noise.shape)
|
||||
if isconcat:
|
||||
return torch.cat((image, self.weight * noise), dim=1) # concat
|
||||
return image + self.weight * noise
|
||||
@@ -356,7 +359,8 @@ class StyledConv(nn.Module):
|
||||
)
|
||||
|
||||
self.noise = NoiseInjection()
|
||||
self.activate = FusedLeakyReLU(out_channel * sss)
|
||||
feat_multiplier = 2
|
||||
self.activate = FusedLeakyReLU(out_channel * feat_multiplier)
|
||||
|
||||
def forward(self, input, style, noise=None):
|
||||
out = self.conv(input, style)
|
||||
@@ -405,12 +409,14 @@ class Generator(nn.Module):
|
||||
channel_multiplier=2,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
lr_mlp=0.01,
|
||||
narrow=1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.size = size
|
||||
self.n_mlp = n_mlp
|
||||
self.style_dim = style_dim
|
||||
self.feat_multiplier = 2
|
||||
|
||||
layers = [PixelNorm()]
|
||||
|
||||
@@ -425,15 +431,16 @@ class Generator(nn.Module):
|
||||
self.style = nn.Sequential(*layers)
|
||||
|
||||
self.channels = {
|
||||
4: 512 // ratio,
|
||||
8: 512 // ratio,
|
||||
16: 512 // ratio,
|
||||
32: 512 // ratio,
|
||||
64: 256 // ratio * channel_multiplier,
|
||||
128: 128 // ratio * channel_multiplier,
|
||||
256: 64 // ratio * channel_multiplier,
|
||||
512: 32 // ratio * channel_multiplier,
|
||||
1024: 16 // ratio * channel_multiplier,
|
||||
4: int(512 * narrow),
|
||||
8: int(512 * narrow),
|
||||
16: int(512 * narrow),
|
||||
32: int(512 * narrow),
|
||||
64: int(256 * channel_multiplier * narrow),
|
||||
128: int(128 * channel_multiplier * narrow),
|
||||
256: int(64 * channel_multiplier * narrow),
|
||||
512: int(32 * channel_multiplier * narrow),
|
||||
1024: int(16 * channel_multiplier * narrow),
|
||||
2048: int(8 * channel_multiplier * narrow)
|
||||
}
|
||||
|
||||
self.input = ConstantInput(self.channels[4])
|
||||
@@ -443,7 +450,8 @@ class Generator(nn.Module):
|
||||
3,
|
||||
style_dim,
|
||||
blur_kernel=blur_kernel)
|
||||
self.to_rgb1 = ToRGB(self.channels[4] * sss, style_dim, upsample=False)
|
||||
self.to_rgb1 = ToRGB(
|
||||
self.channels[4] * self.feat_multiplier, style_dim, upsample=False)
|
||||
|
||||
self.log_size = int(math.log(size, 2))
|
||||
|
||||
@@ -458,23 +466,23 @@ class Generator(nn.Module):
|
||||
|
||||
self.convs.append(
|
||||
StyledConv(
|
||||
in_channel * sss,
|
||||
in_channel * self.feat_multiplier,
|
||||
out_channel,
|
||||
3,
|
||||
style_dim,
|
||||
upsample=True,
|
||||
blur_kernel=blur_kernel,
|
||||
))
|
||||
blur_kernel=blur_kernel))
|
||||
|
||||
self.convs.append(
|
||||
StyledConv(
|
||||
out_channel * sss,
|
||||
out_channel * self.feat_multiplier,
|
||||
out_channel,
|
||||
3,
|
||||
style_dim,
|
||||
blur_kernel=blur_kernel))
|
||||
|
||||
self.to_rgbs.append(ToRGB(out_channel * sss, style_dim))
|
||||
self.to_rgbs.append(
|
||||
ToRGB(out_channel * self.feat_multiplier, style_dim))
|
||||
|
||||
in_channel = out_channel
|
||||
|
||||
@@ -515,6 +523,9 @@ class Generator(nn.Module):
|
||||
styles = [self.style(s) for s in styles]
|
||||
|
||||
if noise is None:
|
||||
'''
|
||||
noise = [None] * (2 * (self.log_size - 2) + 1)
|
||||
'''
|
||||
noise = []
|
||||
batch = styles[0].shape[0]
|
||||
for i in range(self.n_mlp + 1):
|
||||
@@ -557,16 +568,14 @@ class Generator(nn.Module):
|
||||
skip = self.to_rgb1(out, latent[:, 1])
|
||||
|
||||
i = 1
|
||||
noise_i = 1
|
||||
|
||||
for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2],
|
||||
self.to_rgbs):
|
||||
out = conv1(out, latent[:, i], noise=noise[(noise_i + 1) // 2])
|
||||
out = conv2(out, latent[:, i + 1], noise=noise[(noise_i + 2) // 2])
|
||||
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
||||
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2],
|
||||
self.to_rgbs):
|
||||
out = conv1(out, latent[:, i], noise=noise1)
|
||||
out = conv2(out, latent[:, i + 1], noise=noise2)
|
||||
skip = to_rgb(out, latent[:, i + 2], skip)
|
||||
|
||||
i += 2
|
||||
noise_i += 2
|
||||
|
||||
image = skip
|
||||
|
||||
@@ -652,21 +661,106 @@ class ResBlock(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
class FullGenerator(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
style_dim,
|
||||
n_mlp,
|
||||
channel_multiplier=2,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
lr_mlp=0.01,
|
||||
narrow=1,
|
||||
):
|
||||
super().__init__()
|
||||
channels = {
|
||||
4: int(512 * narrow),
|
||||
8: int(512 * narrow),
|
||||
16: int(512 * narrow),
|
||||
32: int(512 * narrow),
|
||||
64: int(256 * channel_multiplier * narrow),
|
||||
128: int(128 * channel_multiplier * narrow),
|
||||
256: int(64 * channel_multiplier * narrow),
|
||||
512: int(32 * channel_multiplier * narrow),
|
||||
1024: int(16 * channel_multiplier * narrow),
|
||||
2048: int(8 * channel_multiplier * narrow)
|
||||
}
|
||||
|
||||
self.log_size = int(math.log(size, 2))
|
||||
self.generator = Generator(
|
||||
size,
|
||||
style_dim,
|
||||
n_mlp,
|
||||
channel_multiplier=channel_multiplier,
|
||||
blur_kernel=blur_kernel,
|
||||
lr_mlp=lr_mlp,
|
||||
narrow=narrow)
|
||||
|
||||
conv = [ConvLayer(3, channels[size], 1)]
|
||||
self.ecd0 = nn.Sequential(*conv)
|
||||
in_channel = channels[size]
|
||||
|
||||
self.names = ['ecd%d' % i for i in range(self.log_size - 1)]
|
||||
for i in range(self.log_size, 2, -1):
|
||||
out_channel = channels[2**(i - 1)]
|
||||
conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)]
|
||||
setattr(self, self.names[self.log_size - i + 1],
|
||||
nn.Sequential(*conv))
|
||||
in_channel = out_channel
|
||||
self.final_linear = nn.Sequential(
|
||||
EqualLinear(
|
||||
channels[4] * 4 * 4, style_dim, activation='fused_lrelu'))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs,
|
||||
return_latents=False,
|
||||
inject_index=None,
|
||||
truncation=1,
|
||||
truncation_latent=None,
|
||||
input_is_latent=False,
|
||||
):
|
||||
noise = []
|
||||
for i in range(self.log_size - 1):
|
||||
ecd = getattr(self, self.names[i])
|
||||
inputs = ecd(inputs)
|
||||
noise.append(inputs)
|
||||
inputs = inputs.view(inputs.shape[0], -1)
|
||||
outs = self.final_linear(inputs)
|
||||
noise = list(
|
||||
itertools.chain.from_iterable(
|
||||
itertools.repeat(x, 2) for x in noise))[::-1]
|
||||
outs = self.generator([outs],
|
||||
return_latents,
|
||||
inject_index,
|
||||
truncation,
|
||||
truncation_latent,
|
||||
input_is_latent,
|
||||
noise=noise[1:])
|
||||
return outs
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
|
||||
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
|
||||
def __init__(self,
|
||||
size,
|
||||
channel_multiplier=2,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
narrow=1):
|
||||
super().__init__()
|
||||
|
||||
channels = {
|
||||
4: 512,
|
||||
8: 512,
|
||||
16: 512,
|
||||
32: 512,
|
||||
64: 256 * channel_multiplier,
|
||||
128: 128 * channel_multiplier,
|
||||
256: 64 * channel_multiplier,
|
||||
512: 32 * channel_multiplier,
|
||||
1024: 16 * channel_multiplier,
|
||||
4: int(512 * narrow),
|
||||
8: int(512 * narrow),
|
||||
16: int(512 * narrow),
|
||||
32: int(512 * narrow),
|
||||
64: int(256 * channel_multiplier * narrow),
|
||||
128: int(128 * channel_multiplier * narrow),
|
||||
256: int(64 * channel_multiplier * narrow),
|
||||
512: int(32 * channel_multiplier * narrow),
|
||||
1024: int(16 * channel_multiplier * narrow),
|
||||
2048: int(8 * channel_multiplier * narrow)
|
||||
}
|
||||
|
||||
convs = [ConvLayer(3, channels[size], 1)]
|
||||
@@ -713,48 +807,53 @@ class Discriminator(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
class FullGenerator(nn.Module):
|
||||
class FullGenerator_SR(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
in_size,
|
||||
out_size,
|
||||
style_dim,
|
||||
n_mlp,
|
||||
channel_multiplier=2,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
lr_mlp=0.01,
|
||||
narrow=1,
|
||||
):
|
||||
super().__init__()
|
||||
channels = {
|
||||
4: 512 // ratio,
|
||||
8: 512 // ratio,
|
||||
16: 512 // ratio,
|
||||
32: 512 // ratio,
|
||||
64: 256 // ratio * channel_multiplier,
|
||||
128: 128 // ratio * channel_multiplier,
|
||||
256: 64 // ratio * channel_multiplier,
|
||||
512: 32 // ratio * channel_multiplier,
|
||||
1024: 16 // ratio * channel_multiplier,
|
||||
4: int(512 * narrow),
|
||||
8: int(512 * narrow),
|
||||
16: int(512 * narrow),
|
||||
32: int(512 * narrow),
|
||||
64: int(256 * channel_multiplier * narrow),
|
||||
128: int(128 * channel_multiplier * narrow),
|
||||
256: int(64 * channel_multiplier * narrow),
|
||||
512: int(32 * channel_multiplier * narrow),
|
||||
1024: int(16 * channel_multiplier * narrow),
|
||||
2048: int(8 * channel_multiplier * narrow),
|
||||
}
|
||||
|
||||
self.log_size = int(math.log(size, 2))
|
||||
self.log_insize = int(math.log(in_size, 2))
|
||||
self.log_outsize = int(math.log(out_size, 2))
|
||||
self.generator = Generator(
|
||||
size,
|
||||
out_size,
|
||||
style_dim,
|
||||
n_mlp,
|
||||
channel_multiplier=channel_multiplier,
|
||||
blur_kernel=blur_kernel,
|
||||
lr_mlp=lr_mlp)
|
||||
lr_mlp=lr_mlp,
|
||||
narrow=narrow)
|
||||
|
||||
conv = [ConvLayer(3, channels[size], 1)]
|
||||
conv = [ConvLayer(3, channels[in_size], 1)]
|
||||
self.ecd0 = nn.Sequential(*conv)
|
||||
in_channel = channels[size]
|
||||
in_channel = channels[in_size]
|
||||
|
||||
self.names = ['ecd%d' % i for i in range(self.log_size - 1)]
|
||||
for i in range(self.log_size, 2, -1):
|
||||
self.names = ['ecd%d' % i for i in range(self.log_insize - 1)]
|
||||
for i in range(self.log_insize, 2, -1):
|
||||
out_channel = channels[2**(i - 1)]
|
||||
conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)]
|
||||
setattr(self, self.names[self.log_size - i + 1],
|
||||
setattr(self, self.names[self.log_insize - i + 1],
|
||||
nn.Sequential(*conv))
|
||||
in_channel = out_channel
|
||||
self.final_linear = nn.Sequential(
|
||||
@@ -771,18 +870,22 @@ class FullGenerator(nn.Module):
|
||||
input_is_latent=False,
|
||||
):
|
||||
noise = []
|
||||
for i in range(self.log_size - 1):
|
||||
for i in range(self.log_outsize - self.log_insize):
|
||||
noise.append(None)
|
||||
for i in range(self.log_insize - 1):
|
||||
ecd = getattr(self, self.names[i])
|
||||
inputs = ecd(inputs)
|
||||
noise.append(inputs)
|
||||
inputs = inputs.view(inputs.shape[0], -1)
|
||||
outs = self.final_linear(inputs)
|
||||
outs = self.generator([outs],
|
||||
return_latents,
|
||||
inject_index,
|
||||
truncation,
|
||||
truncation_latent,
|
||||
input_is_latent,
|
||||
noise=noise[::-1])
|
||||
|
||||
return outs
|
||||
noise = list(
|
||||
itertools.chain.from_iterable(
|
||||
itertools.repeat(x, 2) for x in noise))[::-1]
|
||||
image, latent = self.generator([outs],
|
||||
return_latents,
|
||||
inject_index,
|
||||
truncation,
|
||||
truncation_latent,
|
||||
input_is_latent,
|
||||
noise=noise[1:])
|
||||
return image, latent
|
||||
@@ -1,4 +1,2 @@
|
||||
# The implementation is adopted from stylegan2-pytorch, made public available under the MIT License
|
||||
# at https://github.com/rosinality/stylegan2-pytorch
|
||||
from .fused_act import FusedLeakyReLU, fused_leaky_relu
|
||||
from .upfirdn2d import upfirdn2d
|
||||
|
||||
@@ -15,6 +15,77 @@ REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051],
|
||||
DEFAULT_CROP_SIZE = (96, 112)
|
||||
|
||||
|
||||
def _umeyama(src, dst, estimate_scale=True, scale=1.0):
|
||||
"""Estimate N-D similarity transformation with or without scaling.
|
||||
Parameters
|
||||
----------
|
||||
src : (M, N) array
|
||||
Source coordinates.
|
||||
dst : (M, N) array
|
||||
Destination coordinates.
|
||||
estimate_scale : bool
|
||||
Whether to estimate scaling factor.
|
||||
Returns
|
||||
-------
|
||||
T : (N + 1, N + 1)
|
||||
The homogeneous similarity transformation matrix. The matrix contains
|
||||
NaN values only if the problem is not well-conditioned.
|
||||
References
|
||||
----------
|
||||
.. [1] "Least-squares estimation of transformation parameters between two
|
||||
point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
|
||||
"""
|
||||
|
||||
num = src.shape[0]
|
||||
dim = src.shape[1]
|
||||
|
||||
# Compute mean of src and dst.
|
||||
src_mean = src.mean(axis=0)
|
||||
dst_mean = dst.mean(axis=0)
|
||||
|
||||
# Subtract mean from src and dst.
|
||||
src_demean = src - src_mean
|
||||
dst_demean = dst - dst_mean
|
||||
|
||||
# Eq. (38).
|
||||
A = dst_demean.T @ src_demean / num
|
||||
|
||||
# Eq. (39).
|
||||
d = np.ones((dim, ), dtype=np.double)
|
||||
if np.linalg.det(A) < 0:
|
||||
d[dim - 1] = -1
|
||||
|
||||
T = np.eye(dim + 1, dtype=np.double)
|
||||
|
||||
U, S, V = np.linalg.svd(A)
|
||||
|
||||
# Eq. (40) and (43).
|
||||
rank = np.linalg.matrix_rank(A)
|
||||
if rank == 0:
|
||||
return np.nan * T
|
||||
elif rank == dim - 1:
|
||||
if np.linalg.det(U) * np.linalg.det(V) > 0:
|
||||
T[:dim, :dim] = U @ V
|
||||
else:
|
||||
s = d[dim - 1]
|
||||
d[dim - 1] = -1
|
||||
T[:dim, :dim] = U @ np.diag(d) @ V
|
||||
d[dim - 1] = s
|
||||
else:
|
||||
T[:dim, :dim] = U @ np.diag(d) @ V
|
||||
|
||||
if estimate_scale:
|
||||
# Eq. (41) and (42).
|
||||
scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
|
||||
else:
|
||||
scale = scale
|
||||
|
||||
T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
|
||||
T[:dim, :dim] *= scale
|
||||
|
||||
return T, scale
|
||||
|
||||
|
||||
class FaceWarpException(Exception):
|
||||
|
||||
def __str__(self):
|
||||
@@ -246,6 +317,65 @@ def warp_and_crop_face(src_img,
|
||||
return face_img
|
||||
|
||||
|
||||
def warp_and_crop_face_enhance(src_img,
|
||||
facial_pts,
|
||||
reference_pts=None,
|
||||
crop_size=(96, 112),
|
||||
align_type='smilarity'):
|
||||
if reference_pts is None:
|
||||
if crop_size[0] == 96 and crop_size[1] == 112:
|
||||
reference_pts = REFERENCE_FACIAL_POINTS
|
||||
else:
|
||||
default_square = False
|
||||
inner_padding_factor = 0
|
||||
outer_padding = (0, 0)
|
||||
output_size = crop_size
|
||||
|
||||
reference_pts = get_reference_facial_points(
|
||||
output_size, inner_padding_factor, outer_padding,
|
||||
default_square)
|
||||
|
||||
ref_pts = np.float32(reference_pts)
|
||||
ref_pts_shp = ref_pts.shape
|
||||
if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
|
||||
raise FaceWarpException(
|
||||
'reference_pts.shape must be (K,2) or (2,K) and K>2')
|
||||
|
||||
if ref_pts_shp[0] == 2:
|
||||
ref_pts = ref_pts.T
|
||||
|
||||
src_pts = np.float32(facial_pts)
|
||||
src_pts_shp = src_pts.shape
|
||||
if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
|
||||
raise FaceWarpException(
|
||||
'facial_pts.shape must be (K,2) or (2,K) and K>2')
|
||||
|
||||
if src_pts_shp[0] == 2:
|
||||
src_pts = src_pts.T
|
||||
|
||||
if src_pts.shape != ref_pts.shape:
|
||||
raise FaceWarpException(
|
||||
'facial_pts and reference_pts must have the same shape')
|
||||
|
||||
if align_type == 'cv2_affine':
|
||||
tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
|
||||
tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
|
||||
elif align_type == 'affine':
|
||||
tfm = get_affine_transform_matrix(src_pts, ref_pts)
|
||||
tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
|
||||
else:
|
||||
params, scale = _umeyama(src_pts, ref_pts)
|
||||
tfm = params[:2, :]
|
||||
|
||||
params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale)
|
||||
tfm_inv = params[:2, :]
|
||||
|
||||
face_img = cv2.warpAffine(
|
||||
src_img, tfm, (crop_size[0], crop_size[1]), flags=3)
|
||||
|
||||
return face_img, tfm_inv
|
||||
|
||||
|
||||
def get_f5p(landmarks, np_img):
|
||||
eye_left = find_pupil(landmarks[36:41], np_img)
|
||||
eye_right = find_pupil(landmarks[42:47], np_img)
|
||||
|
||||
@@ -16,9 +16,10 @@ from modelscope.models.builder import MODELS
|
||||
from modelscope.models.cv.face_detection.peppa_pig_face.facer import FaceAna
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .facegan.gan_wrap import GANWrap
|
||||
from .facegan.face_gan import GPEN
|
||||
from .facelib.align_trans import (get_f5p, get_reference_facial_points,
|
||||
warp_and_crop_face)
|
||||
warp_and_crop_face,
|
||||
warp_and_crop_face_enhance)
|
||||
from .network.aei_flow_net import AEI_Net
|
||||
from .network.bfm import ParametricFaceModel
|
||||
from .network.facerecon_model import ReconNetWrapper
|
||||
@@ -78,14 +79,6 @@ class ImageFaceFusion(TorchModel):
|
||||
self.face_model = ParametricFaceModel(bfm_folder=bfm_dir)
|
||||
self.face_model.to(self.device)
|
||||
|
||||
face_enhance_path = os.path.join(model_dir, 'faceEnhance',
|
||||
'350000-Ns256.pt')
|
||||
self.ganwrap = GANWrap(
|
||||
model_path=face_enhance_path,
|
||||
size=256,
|
||||
channel_multiplier=1,
|
||||
device=self.device)
|
||||
|
||||
self.facer = FaceAna(model_dir)
|
||||
|
||||
logger.info('load facefusion models done')
|
||||
@@ -94,6 +87,27 @@ class ImageFaceFusion(TorchModel):
|
||||
self.mask_init = cv2.resize(self.mask_init, (256, 256))
|
||||
self.mask = self.image_transform(self.mask_init, is_norm=False)
|
||||
|
||||
face_enhance_path = os.path.join(model_dir, 'faceEnhance',
|
||||
'GPEN-BFR-1024.pth')
|
||||
|
||||
if not os.path.exists(face_enhance_path):
|
||||
logger.warning(
|
||||
'model path not found, please update the latest model!')
|
||||
|
||||
self.ganwrap_1024 = GPEN(face_enhance_path, 1024, 2, self.device)
|
||||
|
||||
self.mask_enhance = np.zeros((512, 512), np.float32)
|
||||
cv2.rectangle(self.mask_enhance, (26, 26), (486, 486), (1, 1, 1), -1,
|
||||
cv2.LINE_AA)
|
||||
self.mask_enhance = cv2.GaussianBlur(self.mask_enhance, (101, 101), 11)
|
||||
self.mask_enhance = cv2.GaussianBlur(self.mask_enhance, (101, 101), 11)
|
||||
|
||||
default_square = True
|
||||
inner_padding_factor = 0.25
|
||||
outer_padding = (0, 0)
|
||||
self.reference_5pts_1024 = get_reference_facial_points(
|
||||
(1024, 1024), inner_padding_factor, outer_padding, default_square)
|
||||
|
||||
self.test_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
@@ -157,7 +171,7 @@ class ImageFaceFusion(TorchModel):
|
||||
src_h, src_w, _ = img.shape
|
||||
boxes, landmarks, _ = self.facer.run(img)
|
||||
if boxes.shape[0] == 0:
|
||||
return None
|
||||
return None, None, None
|
||||
elif boxes.shape[0] > 1:
|
||||
max_area = 0
|
||||
max_index = 0
|
||||
@@ -168,9 +182,14 @@ class ImageFaceFusion(TorchModel):
|
||||
if area > max_area:
|
||||
max_index = i
|
||||
max_area = area
|
||||
return landmarks[max_index]
|
||||
|
||||
fw = boxes[max_index][2] - boxes[max_index][0]
|
||||
fh = boxes[max_index][3] - boxes[max_index][1]
|
||||
return landmarks[max_index], fw, fh
|
||||
else:
|
||||
return landmarks[0]
|
||||
fw = boxes[0][2] - boxes[0][0]
|
||||
fh = boxes[0][3] - boxes[0][1]
|
||||
return landmarks[0], fw, fh
|
||||
|
||||
def compute_3d_params(self, Xs, Xt):
|
||||
kp_fuse = {}
|
||||
@@ -198,6 +217,51 @@ class ImageFaceFusion(TorchModel):
|
||||
|
||||
return kp_fuse, kp_t
|
||||
|
||||
def process_enhance(self, im, f5p, fh, fw):
|
||||
height, width, _ = im.shape
|
||||
|
||||
of, tfm_inv = warp_and_crop_face_enhance(
|
||||
im,
|
||||
f5p,
|
||||
reference_pts=self.reference_5pts_1024,
|
||||
crop_size=(1024, 1024))
|
||||
ef, pred = self.ganwrap_1024.process(of)
|
||||
|
||||
tmp_mask = self.mask_enhance
|
||||
tmp_mask = cv2.resize(tmp_mask, ef.shape[:2])
|
||||
tmp_mask = cv2.warpAffine(tmp_mask, tfm_inv, (width, height), flags=3)
|
||||
|
||||
full_mask = np.zeros((height, width), dtype=np.float32)
|
||||
full_img = np.zeros(im.shape, dtype=np.uint8)
|
||||
|
||||
if min(fh, fw) < 40:
|
||||
ef = cv2.pyrDown(ef)
|
||||
ef = cv2.pyrDown(ef)
|
||||
ef = cv2.pyrUp(ef)
|
||||
ef = cv2.pyrUp(ef)
|
||||
elif min(fh, fw) < 60:
|
||||
ef = cv2.pyrDown(ef)
|
||||
ef = cv2.resize(ef, (0, 0), fx=2, fy=2)
|
||||
ef = cv2.resize(ef, (0, 0), fx=0.5, fy=0.5)
|
||||
ef = cv2.pyrUp(ef)
|
||||
elif min(fh, fw) < 80:
|
||||
ef = cv2.pyrDown(ef)
|
||||
ef = cv2.pyrUp(ef)
|
||||
elif min(fh, fw) < 100:
|
||||
ef = cv2.pyrDown(ef)
|
||||
ef = cv2.resize(ef, (0, 0), fx=2, fy=2)
|
||||
|
||||
tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3)
|
||||
|
||||
mask = tmp_mask - full_mask
|
||||
full_mask[np.where(mask > 0)] = tmp_mask[np.where(mask > 0)]
|
||||
full_img[np.where(mask > 0)] = tmp_img[np.where(mask > 0)]
|
||||
|
||||
full_mask = full_mask[:, :, np.newaxis]
|
||||
im = cv2.convertScaleAbs(im * (1 - full_mask) + full_img * full_mask)
|
||||
im = cv2.resize(im, (width, height))
|
||||
return im
|
||||
|
||||
def inference(self, template_img, user_img):
|
||||
ori_h, ori_w, _ = template_img.shape
|
||||
|
||||
@@ -205,14 +269,14 @@ class ImageFaceFusion(TorchModel):
|
||||
user_img = user_img.cpu().numpy()
|
||||
|
||||
user_img_bgr = user_img[:, :, ::-1]
|
||||
landmark_source = self.detect_face(user_img)
|
||||
landmark_source, _, _ = self.detect_face(user_img)
|
||||
if landmark_source is None:
|
||||
logger.warning('No face detected in user image!')
|
||||
return template_img
|
||||
f5p_user = get_f5p(landmark_source, user_img_bgr)
|
||||
|
||||
template_img_bgr = template_img[:, :, ::-1]
|
||||
landmark_template = self.detect_face(template_img)
|
||||
landmark_template, fw, fh = self.detect_face(template_img)
|
||||
if landmark_template is None:
|
||||
logger.warning('No face detected in template image!')
|
||||
return template_img
|
||||
@@ -235,7 +299,6 @@ class ImageFaceFusion(TorchModel):
|
||||
with torch.no_grad():
|
||||
kp_fuse, kp_t = self.compute_3d_params(Xs, Xt)
|
||||
Yt, _, _ = self.netG(Xt, Xs_embeds, kp_fuse, kp_t)
|
||||
Yt = self.ganwrap.process_tensor(Yt)
|
||||
Yt = Yt * 0.5 + 0.5
|
||||
Yt = torch.clamp(Yt, 0, 1)
|
||||
|
||||
@@ -247,6 +310,7 @@ class ImageFaceFusion(TorchModel):
|
||||
0).cpu().numpy()
|
||||
Yt_trans_inv = Yt_trans_inv.astype(np.float32)
|
||||
out_img = Yt_trans_inv[:, :, ::-1] * 255.
|
||||
out_img = self.process_enhance(out_img, f5p_template, fh, fw)
|
||||
|
||||
logger.info('model inference done')
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from diffusers.models import cross_attention
|
||||
from diffusers.utils import deprecation_utils
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
@@ -56,7 +57,10 @@ class EfficientStableDiffusion(TorchModel):
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
tuner_name = kwargs.pop('tuner_name', 'lora')
|
||||
pretrained_model_name_or_path = kwargs.pop(
|
||||
'pretrained_model_name_or_path', 'runwayml/stable-diffusion-v1-5')
|
||||
'pretrained_model_name_or_path',
|
||||
'AI-ModelScope/stable-diffusion-v1-5')
|
||||
pretrained_model_name_or_path = snapshot_download(
|
||||
pretrained_model_name_or_path)
|
||||
tuner_config = kwargs.pop('tuner_config', None)
|
||||
pretrained_tuner = kwargs.pop('pretrained_tuner', None)
|
||||
revision = kwargs.pop('revision', None)
|
||||
|
||||
24
modelscope/models/multi_modal/image_to_video/__init__.py
Normal file
24
modelscope/models/multi_modal/image_to_video/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from .image_to_video_model import ImageToVideo
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'image_to_video_model': ['ImageToVideo'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
215
modelscope/models/multi_modal/image_to_video/image_to_video_model.py
Executable file
215
modelscope/models/multi_modal/image_to_video/image_to_video_model.py
Executable file
@@ -0,0 +1,215 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
from copy import copy
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
|
||||
import modelscope.models.multi_modal.image_to_video.utils.transforms as data
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.multi_modal.image_to_video.modules import *
|
||||
from modelscope.models.multi_modal.image_to_video.modules import (
|
||||
AutoencoderKL, FrozenOpenCLIPVisualEmbedder, Img2VidSDUNet)
|
||||
from modelscope.models.multi_modal.image_to_video.utils.config import cfg
|
||||
from modelscope.models.multi_modal.image_to_video.utils.diffusion import \
|
||||
GaussianDiffusion
|
||||
from modelscope.models.multi_modal.image_to_video.utils.seed import setup_seed
|
||||
from modelscope.models.multi_modal.image_to_video.utils.shedule import \
|
||||
beta_schedule
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
__all__ = ['ImageToVideo']
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.image_to_video, module_name=Models.image_to_video_model)
|
||||
class ImageToVideo(TorchModel):
|
||||
r"""
|
||||
Image2Video aims to solve the task of generating high-definition videos based on input images.
|
||||
Image2Video is a video generation basic model developed by Alibaba Cloud, with a parameter size
|
||||
of approximately 2 billion. It has been pre trained on large-scale video and image data and
|
||||
fine-tuned on a small amount of high-quality data. The data is widely distributed and diverse
|
||||
in categories, and the model has good generalization ability for different types of data
|
||||
|
||||
Paper link: https://arxiv.org/abs/2306.02018
|
||||
|
||||
Attributes:
|
||||
diffusion: diffusion model for DDIM.
|
||||
autoencoder: decode the latent representation into visual space.
|
||||
clip_encoder: encode the image into image embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
r"""
|
||||
Args:
|
||||
model_dir (`str` or `os.PathLike`)
|
||||
Can be either:
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co
|
||||
or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`,
|
||||
or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- A path to a *directory* containing model weights saved using
|
||||
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
- A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
|
||||
this case, `from_tf` should be set to `True` and a configuration object should be provided as
|
||||
`config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
|
||||
PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
- A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
|
||||
`./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
|
||||
`True`.
|
||||
"""
|
||||
super().__init__(model_dir=model_dir, *args, **kwargs)
|
||||
|
||||
self.config = Config.from_file(
|
||||
osp.join(model_dir, ModelFile.CONFIGURATION))
|
||||
|
||||
# assign default value
|
||||
cfg.batch_size = self.config.model.model_cfg.batch_size
|
||||
cfg.target_fps = self.config.model.model_cfg.target_fps
|
||||
cfg.max_frames = self.config.model.model_cfg.max_frames
|
||||
cfg.latent_hei = self.config.model.model_cfg.latent_hei
|
||||
cfg.latent_wid = self.config.model.model_cfg.latent_wid
|
||||
cfg.model_path = osp.join(model_dir,
|
||||
self.config.model.model_args.ckpt_unet)
|
||||
|
||||
self.device = torch.device(
|
||||
'cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
if 'seed' in self.config.model.model_args.keys():
|
||||
cfg.seed = self.config.model.model_args.seed
|
||||
else:
|
||||
cfg.seed = random.randint(0, 99999)
|
||||
setup_seed(cfg.seed)
|
||||
|
||||
# transform
|
||||
vid_trans = data.Compose([
|
||||
data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])),
|
||||
data.Resize(cfg.vit_resolution),
|
||||
data.ToTensor(),
|
||||
data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)
|
||||
])
|
||||
self.vid_trans = vid_trans
|
||||
|
||||
cfg.embedder.pretrained = osp.join(
|
||||
model_dir, self.config.model.model_args.ckpt_clip)
|
||||
clip_encoder = FrozenOpenCLIPVisualEmbedder(**cfg.embedder)
|
||||
clip_encoder.model.to(self.device)
|
||||
self.clip_encoder = clip_encoder
|
||||
logger.info(f'Build encoder with {cfg.embedder.type}')
|
||||
|
||||
# [unet]
|
||||
generator = Img2VidSDUNet(**cfg.UNet)
|
||||
generator = generator.to(self.device)
|
||||
generator.eval()
|
||||
load_dict = torch.load(cfg.model_path, map_location='cpu')
|
||||
ret = generator.load_state_dict(load_dict['state_dict'], strict=True)
|
||||
self.generator = generator
|
||||
logger.info('Load model {} path {}, with local status {}'.format(
|
||||
cfg.UNet.type, cfg.model_path, ret))
|
||||
|
||||
# [diffusion]
|
||||
betas = beta_schedule(
|
||||
'linear_sd',
|
||||
cfg.num_timesteps,
|
||||
init_beta=0.00085,
|
||||
last_beta=0.0120)
|
||||
diffusion = GaussianDiffusion(
|
||||
betas=betas,
|
||||
mean_type=cfg.mean_type,
|
||||
var_type=cfg.var_type,
|
||||
loss_type=cfg.loss_type,
|
||||
rescale_timesteps=False,
|
||||
noise_strength=getattr(cfg, 'noise_strength', 0))
|
||||
self.diffusion = diffusion
|
||||
logger.info('Build diffusion with type of GaussianDiffusion')
|
||||
|
||||
# [auotoencoder]
|
||||
cfg.auto_encoder.pretrained = osp.join(
|
||||
model_dir, self.config.model.model_args.ckpt_autoencoder)
|
||||
autoencoder = AutoencoderKL(**cfg.auto_encoder)
|
||||
autoencoder.eval()
|
||||
for param in autoencoder.parameters():
|
||||
param.requires_grad = False
|
||||
autoencoder.to(self.device)
|
||||
self.autoencoder = autoencoder
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
zero_feature = torch.zeros(1, 1, cfg.UNet.input_dim).to(self.device)
|
||||
self.zero_feature = zero_feature
|
||||
self.fps_tensor = torch.tensor([cfg.target_fps],
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
self.cfg = cfg
|
||||
|
||||
def forward(self, input: Dict[str, Any]):
|
||||
r"""
|
||||
The entry function of image to video task.
|
||||
1. Using diffusion model to generate the video's latent representation.
|
||||
2. Using autoencoder to decode the video's latent representation to visual space.
|
||||
|
||||
Args:
|
||||
input (`Dict[Str, Any]`):
|
||||
The input of the task
|
||||
Returns:
|
||||
A generated video (as pytorch tensor).
|
||||
"""
|
||||
|
||||
vit_frame = input['vit_frame']
|
||||
cfg = self.cfg
|
||||
|
||||
img_embedding = self.clip_encoder(vit_frame).unsqueeze(1)
|
||||
|
||||
noise = self.build_noise()
|
||||
zero_feature = copy(self.zero_feature)
|
||||
with torch.no_grad():
|
||||
with amp.autocast(enabled=cfg.use_fp16):
|
||||
model_kwargs = [{
|
||||
'y': img_embedding,
|
||||
'fps': self.fps_tensor
|
||||
}, {
|
||||
'y': zero_feature.repeat(cfg.batch_size, 1, 1),
|
||||
'fps': self.fps_tensor
|
||||
}]
|
||||
gen_video = self.diffusion.ddim_sample_loop(
|
||||
noise=noise,
|
||||
model=self.generator,
|
||||
model_kwargs=model_kwargs,
|
||||
guide_scale=cfg.guide_scale,
|
||||
ddim_timesteps=cfg.ddim_timesteps,
|
||||
eta=0.0)
|
||||
|
||||
gen_video = 1. / cfg.scale_factor * gen_video
|
||||
gen_video = rearrange(gen_video, 'b c f h w -> (b f) c h w')
|
||||
chunk_size = min(cfg.decoder_bs, gen_video.shape[0])
|
||||
gen_video_list = torch.chunk(
|
||||
gen_video, gen_video.shape[0] // chunk_size, dim=0)
|
||||
decode_generator = []
|
||||
for vd_data in gen_video_list:
|
||||
gen_frames = self.autoencoder.decode(vd_data)
|
||||
decode_generator.append(gen_frames)
|
||||
|
||||
gen_video = torch.cat(decode_generator, dim=0)
|
||||
gen_video = rearrange(
|
||||
gen_video, '(b f) c h w -> b c f h w', b=cfg.batch_size)
|
||||
|
||||
return gen_video.type(torch.float32).cpu()
|
||||
|
||||
def build_noise(self):
|
||||
cfg = self.cfg
|
||||
noise = torch.randn(
|
||||
[1, 4, cfg.max_frames, cfg.latent_hei,
|
||||
cfg.latent_wid]).to(self.device)
|
||||
if cfg.noise_strength > 0:
|
||||
b, c, f, *_ = noise.shape
|
||||
offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device)
|
||||
noise = noise + cfg.noise_strength * offset_noise
|
||||
return noise.contiguous()
|
||||
5
modelscope/models/multi_modal/image_to_video/modules/__init__.py
Executable file
5
modelscope/models/multi_modal/image_to_video/modules/__init__.py
Executable file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .autoencoder import *
|
||||
from .embedder import *
|
||||
from .unet_i2v import *
|
||||
573
modelscope/models/multi_modal/image_to_video/modules/autoencoder.py
Executable file
573
modelscope/models/multi_modal/image_to_video/modules/autoencoder.py
Executable file
@@ -0,0 +1,573 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import collections
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(
|
||||
self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(
|
||||
self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1, 2, 3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar
|
||||
+ torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1)
|
||||
k = k.reshape(b, c, h * w)
|
||||
w_ = torch.bmm(q, k)
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1)
|
||||
h_ = torch.bmm(v, w_)
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(
|
||||
x, scale_factor=2.0, mode='nearest')
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
use_linear_attn=False,
|
||||
attn_type='vanilla',
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
use_linear_attn=False,
|
||||
attn_type='vanilla',
|
||||
**ignorekwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2**(self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
logging.info('Working with z of shape {} = {} dimensions.'.format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
class AutoencoderKL(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
embed_dim,
|
||||
pretrained=None,
|
||||
ignore_keys=[],
|
||||
image_key='image',
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
ema_decay=None,
|
||||
learn_logvar=False,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.learn_logvar = learn_logvar
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
assert ddconfig['double_z']
|
||||
self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
|
||||
2 * embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
|
||||
ddconfig['z_channels'], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
|
||||
self.use_ema = ema_decay is not None
|
||||
|
||||
if pretrained is not None:
|
||||
self.init_from_ckpt(pretrained, ignore_keys=ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
sd_new = collections.OrderedDict()
|
||||
for k in keys:
|
||||
if k.find('first_stage_model') >= 0:
|
||||
k_new = k.split('first_stage_model.')[-1]
|
||||
sd_new[k_new] = sd[k]
|
||||
self.load_state_dict(sd_new, strict=True)
|
||||
logging.info(f'Restored from {path}')
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1,
|
||||
2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log['reconstructions'] = xrec
|
||||
if log_ema or self.use_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, posterior_ema = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec_ema.shape[1] > 3
|
||||
xrec_ema = self.to_rgb(xrec_ema)
|
||||
log['samples_ema'] = self.decode(
|
||||
torch.randn_like(posterior_ema.sample()))
|
||||
log['reconstructions_ema'] = xrec_ema
|
||||
log['inputs'] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == 'segmentation'
|
||||
if not hasattr(self, 'colorize'):
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
82
modelscope/models/multi_modal/image_to_video/modules/embedder.py
Executable file
82
modelscope/models/multi_modal/image_to_video/modules/embedder.py
Executable file
@@ -0,0 +1,82 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import open_clip
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
class FrozenOpenCLIPVisualEmbedder(nn.Module):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = ['last', 'penultimate']
|
||||
|
||||
def __init__(self,
|
||||
pretrained,
|
||||
vit_resolution=(224, 224),
|
||||
arch='ViT-H-14',
|
||||
device='cuda',
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer='last',
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
arch, device=torch.device('cpu'), pretrained=pretrained)
|
||||
|
||||
del model.transformer
|
||||
self.model = model
|
||||
data_white = np.ones(
|
||||
(vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8) * 255
|
||||
self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0)
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == 'last':
|
||||
self.layer_idx = 0
|
||||
elif self.layer == 'penultimate':
|
||||
self.layer_idx = 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, image):
|
||||
z = self.model.encode_image(image.to(self.device))
|
||||
return z
|
||||
|
||||
def encode_with_transformer(self, text):
|
||||
x = self.model.token_embedding(text)
|
||||
x = x + self.model.positional_embedding
|
||||
x = x.permute(1, 0, 2)
|
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
x = x.permute(1, 0, 2)
|
||||
x = self.model.ln_final(x)
|
||||
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
|
||||
):
|
||||
x = checkpoint(r, x, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
1504
modelscope/models/multi_modal/image_to_video/modules/unet_i2v.py
Normal file
1504
modelscope/models/multi_modal/image_to_video/modules/unet_i2v.py
Normal file
File diff suppressed because it is too large
Load Diff
2
modelscope/models/multi_modal/image_to_video/utils/__init__.py
Executable file
2
modelscope/models/multi_modal/image_to_video/utils/__init__.py
Executable file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
161
modelscope/models/multi_modal/image_to_video/utils/config.py
Executable file
161
modelscope/models/multi_modal/image_to_video/utils/config.py
Executable file
@@ -0,0 +1,161 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
from easydict import EasyDict
|
||||
|
||||
cfg = EasyDict(__name__='Config: VideoLDM Decoder')
|
||||
|
||||
# ---------------------------work dir--------------------------
|
||||
cfg.work_dir = 'workspace/'
|
||||
|
||||
# ---------------------------Global Variable-----------------------------------
|
||||
cfg.resolution = [448, 256]
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------Dataset Parameter---------------------------------
|
||||
cfg.mean = [0.5, 0.5, 0.5]
|
||||
cfg.std = [0.5, 0.5, 0.5]
|
||||
cfg.max_words = 1000
|
||||
|
||||
# PlaceHolder
|
||||
cfg.vit_out_dim = 1024
|
||||
cfg.vit_resolution = [224, 224]
|
||||
cfg.depth_clamp = 10.0
|
||||
cfg.misc_size = 384
|
||||
cfg.depth_std = 20.0
|
||||
|
||||
cfg.frame_lens = 32
|
||||
cfg.sample_fps = 8
|
||||
|
||||
cfg.batch_sizes = 1
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------Mode Parameters-----------------------------------
|
||||
# Diffusion
|
||||
cfg.schedule = 'cosine'
|
||||
cfg.num_timesteps = 1000
|
||||
cfg.mean_type = 'v'
|
||||
cfg.var_type = 'fixed_small'
|
||||
cfg.loss_type = 'mse'
|
||||
cfg.ddim_timesteps = 50
|
||||
cfg.ddim_eta = 0.0
|
||||
cfg.clamp = 1.0
|
||||
cfg.share_noise = False
|
||||
cfg.use_div_loss = False
|
||||
cfg.noise_strength = 0.1
|
||||
|
||||
# classifier-free guidance
|
||||
cfg.p_zero = 0.1
|
||||
cfg.guide_scale = 3.0
|
||||
|
||||
# clip vision encoder
|
||||
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
|
||||
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
|
||||
|
||||
# Model
|
||||
cfg.scale_factor = 0.18215
|
||||
cfg.use_fp16 = True
|
||||
cfg.temporal_attention = True
|
||||
cfg.decoder_bs = 8
|
||||
|
||||
cfg.UNet = {
|
||||
'type': 'Img2VidSDUNet',
|
||||
'in_dim': 4,
|
||||
'dim': 320,
|
||||
'y_dim': cfg.vit_out_dim,
|
||||
'context_dim': 1024,
|
||||
'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
|
||||
'dim_mult': [1, 2, 4, 4],
|
||||
'num_heads': 8,
|
||||
'head_dim': 64,
|
||||
'num_res_blocks': 2,
|
||||
'attn_scales': [1 / 1, 1 / 2, 1 / 4],
|
||||
'dropout': 0.1,
|
||||
'temporal_attention': cfg.temporal_attention,
|
||||
'temporal_attn_times': 1,
|
||||
'use_checkpoint': False,
|
||||
'use_fps_condition': False,
|
||||
'use_sim_mask': False,
|
||||
'num_tokens': 4,
|
||||
'default_fps': 8,
|
||||
'input_dim': 1024
|
||||
}
|
||||
|
||||
cfg.guidances = []
|
||||
|
||||
# auotoencoder from stabel diffusion
|
||||
cfg.auto_encoder = {
|
||||
'type': 'AutoencoderKL',
|
||||
'ddconfig': {
|
||||
'double_z': True,
|
||||
'z_channels': 4,
|
||||
'resolution': 256,
|
||||
'in_channels': 3,
|
||||
'out_ch': 3,
|
||||
'ch': 128,
|
||||
'ch_mult': [1, 2, 4, 4],
|
||||
'num_res_blocks': 2,
|
||||
'attn_resolutions': [],
|
||||
'dropout': 0.0
|
||||
},
|
||||
'embed_dim': 4,
|
||||
'pretrained': 'v2-1_512-ema-pruned.ckpt'
|
||||
}
|
||||
# clip embedder
|
||||
cfg.embedder = {
|
||||
'type': 'FrozenOpenCLIPVisualEmbedder',
|
||||
'layer': 'penultimate',
|
||||
'vit_resolution': [224, 224],
|
||||
'pretrained': 'open_clip_pytorch_model.bin'
|
||||
}
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------Training Settings---------------------------------
|
||||
# training and optimizer
|
||||
cfg.ema_decay = 0.9999
|
||||
cfg.num_steps = 600000
|
||||
cfg.lr = 5e-5
|
||||
cfg.weight_decay = 0.0
|
||||
cfg.betas = (0.9, 0.999)
|
||||
cfg.eps = 1.0e-8
|
||||
cfg.chunk_size = 16
|
||||
cfg.alpha = 0.7
|
||||
cfg.save_ckp_interval = 1000
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ----------------------------Pretrain Settings---------------------------------
|
||||
# Default: load 2d pretrain
|
||||
cfg.fix_weight = False
|
||||
cfg.load_match = False
|
||||
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
|
||||
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
|
||||
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# -----------------------------Visual-------------------------------------------
|
||||
# Visual videos
|
||||
cfg.viz_interval = 1000
|
||||
cfg.visual_train = {
|
||||
'type': 'VisualVideoTextDuringTrain',
|
||||
}
|
||||
cfg.visual_inference = {
|
||||
'type': 'VisualGeneratedVideos',
|
||||
}
|
||||
cfg.inference_list_path = ''
|
||||
|
||||
# logging
|
||||
cfg.log_interval = 100
|
||||
|
||||
# Default log_dir
|
||||
cfg.log_dir = 'workspace/output_data'
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------Others--------------------------------------------
|
||||
# seed
|
||||
cfg.seed = 8888
|
||||
# -----------------------------------------------------------------------------
|
||||
511
modelscope/models/multi_modal/image_to_video/utils/diffusion.py
Executable file
511
modelscope/models/multi_modal/image_to_video/utils/diffusion.py
Executable file
@@ -0,0 +1,511 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ['GaussianDiffusion', 'beta_schedule']
|
||||
|
||||
|
||||
def _i(tensor, t, x):
|
||||
r"""Index tensor using t and format the output according to x.
|
||||
"""
|
||||
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
|
||||
if tensor.device != x.device:
|
||||
tensor = tensor.to(x.device)
|
||||
return tensor[t].view(shape).to(x)
|
||||
|
||||
|
||||
def fn(u):
|
||||
return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2
|
||||
|
||||
|
||||
def beta_schedule(schedule,
|
||||
num_timesteps=1000,
|
||||
init_beta=None,
|
||||
last_beta=None):
|
||||
if schedule == 'linear':
|
||||
scale = 1000.0 / num_timesteps
|
||||
init_beta = init_beta or scale * 0.0001
|
||||
last_beta = last_beta or scale * 0.02
|
||||
return torch.linspace(
|
||||
init_beta, last_beta, num_timesteps, dtype=torch.float64)
|
||||
elif schedule == 'quadratic':
|
||||
init_beta = init_beta or 0.0015
|
||||
last_beta = last_beta or 0.0195
|
||||
return torch.linspace(
|
||||
init_beta**0.5, last_beta**0.5, num_timesteps,
|
||||
dtype=torch.float64)**2
|
||||
elif schedule == 'cosine':
|
||||
betas = []
|
||||
for step in range(num_timesteps):
|
||||
t1 = step / num_timesteps
|
||||
t2 = (step + 1) / num_timesteps
|
||||
betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
|
||||
return torch.tensor(betas, dtype=torch.float64)
|
||||
else:
|
||||
raise ValueError(f'Unsupported schedule: {schedule}')
|
||||
|
||||
|
||||
class GaussianDiffusion(object):
|
||||
|
||||
def __init__(self,
|
||||
betas,
|
||||
mean_type='eps',
|
||||
var_type='learned_range',
|
||||
loss_type='mse',
|
||||
epsilon=1e-12,
|
||||
rescale_timesteps=False,
|
||||
noise_strength=0.0):
|
||||
# check input
|
||||
if not isinstance(betas, torch.DoubleTensor):
|
||||
betas = torch.tensor(betas, dtype=torch.float64)
|
||||
assert min(betas) > 0 and max(betas) <= 1
|
||||
assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v']
|
||||
assert var_type in [
|
||||
'learned', 'learned_range', 'fixed_large', 'fixed_small'
|
||||
]
|
||||
assert loss_type in [
|
||||
'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1',
|
||||
'charbonnier'
|
||||
]
|
||||
self.betas = betas
|
||||
self.num_timesteps = len(betas)
|
||||
self.mean_type = mean_type
|
||||
self.var_type = var_type
|
||||
self.loss_type = loss_type
|
||||
self.epsilon = epsilon
|
||||
self.rescale_timesteps = rescale_timesteps
|
||||
self.noise_strength = noise_strength
|
||||
|
||||
# alphas
|
||||
alphas = 1 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
self.alphas_cumprod_prev = torch.cat(
|
||||
[alphas.new_ones([1]), self.alphas_cumprod[:-1]])
|
||||
self.alphas_cumprod_next = torch.cat(
|
||||
[self.alphas_cumprod[1:],
|
||||
alphas.new_zeros([1])])
|
||||
|
||||
# q(x_t | x_{t-1})
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
|
||||
- self.alphas_cumprod)
|
||||
self.log_one_minus_alphas_cumprod = torch.log(1.0
|
||||
- self.alphas_cumprod)
|
||||
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
|
||||
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
|
||||
- 1)
|
||||
|
||||
# q(x_{t-1} | x_t, x_0)
|
||||
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
|
||||
1.0 - self.alphas_cumprod)
|
||||
self.posterior_log_variance_clipped = torch.log(
|
||||
self.posterior_variance.clamp(1e-20))
|
||||
self.posterior_mean_coef1 = betas * torch.sqrt(
|
||||
self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (
|
||||
1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
|
||||
1.0 - self.alphas_cumprod)
|
||||
|
||||
def sample_loss(self, x0, noise=None):
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
if self.noise_strength > 0:
|
||||
b, c, f, _, _ = x0.shape
|
||||
offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device)
|
||||
noise = noise + self.noise_strength * offset_noise
|
||||
return noise
|
||||
|
||||
def q_sample(self, x0, t, noise=None):
|
||||
r"""Sample from q(x_t | x_0).
|
||||
"""
|
||||
# noise = torch.randn_like(x0) if noise is None else noise
|
||||
noise = self.sample_loss(x0, noise)
|
||||
return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + (
|
||||
_i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise)
|
||||
|
||||
def q_mean_variance(self, x0, t):
|
||||
r"""Distribution of q(x_t | x_0).
|
||||
"""
|
||||
mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
|
||||
var = _i(1.0 - self.alphas_cumprod, t, x0)
|
||||
log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
|
||||
return mu, var, log_var
|
||||
|
||||
def q_posterior_mean_variance(self, x0, xt, t):
|
||||
r"""Distribution of q(x_{t-1} | x_t, x_0).
|
||||
"""
|
||||
mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
|
||||
self.posterior_mean_coef2, t, xt) * xt
|
||||
var = _i(self.posterior_variance, t, xt)
|
||||
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
return mu, var, log_var
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None):
|
||||
r"""Sample from p(x_{t-1} | x_t).
|
||||
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||
"""
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||
clamp, percentile,
|
||||
guide_scale)
|
||||
|
||||
# random sample (with optional conditional function)
|
||||
noise = torch.randn_like(xt)
|
||||
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
||||
if condition_fn is not None:
|
||||
grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
|
||||
mu = mu.float() + var * grad.float()
|
||||
xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
|
||||
return xt_1, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None):
|
||||
r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
|
||||
"""
|
||||
# prepare input
|
||||
b = noise.size(0)
|
||||
xt = noise
|
||||
|
||||
# diffusion process
|
||||
for step in torch.arange(self.num_timesteps).flip(0):
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn, guide_scale)
|
||||
return xt
|
||||
|
||||
def p_mean_variance(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None):
|
||||
r"""Distribution of p(x_{t-1} | x_t).
|
||||
"""
|
||||
# predict distribution
|
||||
if guide_scale is None:
|
||||
out = model(xt, self._scale_timesteps(t), **model_kwargs)
|
||||
else:
|
||||
# classifier-free guidance
|
||||
# (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
|
||||
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
|
||||
y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
|
||||
u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
|
||||
dim = y_out.size(1) if self.var_type.startswith(
|
||||
'fixed') else y_out.size(1) // 2
|
||||
out = torch.cat(
|
||||
[
|
||||
u_out[:, :dim] + guide_scale * # noqa
|
||||
(y_out[:, :dim] - u_out[:, :dim]),
|
||||
y_out[:, dim:]
|
||||
],
|
||||
dim=1)
|
||||
|
||||
# compute variance
|
||||
if self.var_type == 'learned':
|
||||
out, log_var = out.chunk(2, dim=1)
|
||||
var = torch.exp(log_var)
|
||||
elif self.var_type == 'learned_range':
|
||||
out, fraction = out.chunk(2, dim=1)
|
||||
min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
max_log_var = _i(torch.log(self.betas), t, xt)
|
||||
fraction = (fraction + 1) / 2.0
|
||||
log_var = fraction * max_log_var + (1 - fraction) * min_log_var
|
||||
var = torch.exp(log_var)
|
||||
elif self.var_type == 'fixed_large':
|
||||
var = _i(
|
||||
torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
|
||||
xt)
|
||||
log_var = torch.log(var)
|
||||
elif self.var_type == 'fixed_small':
|
||||
var = _i(self.posterior_variance, t, xt)
|
||||
log_var = _i(self.posterior_log_variance_clipped, t, xt)
|
||||
|
||||
# compute mean and x0
|
||||
if self.mean_type == 'x_{t-1}':
|
||||
mu = out
|
||||
x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - (
|
||||
_i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
|
||||
xt) * xt)
|
||||
elif self.mean_type == 'x0':
|
||||
x0 = out
|
||||
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||
elif self.mean_type == 'eps':
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out)
|
||||
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||
elif self.mean_type == 'v':
|
||||
x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - (
|
||||
_i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out)
|
||||
mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
|
||||
|
||||
# restrict the range of x0
|
||||
if percentile is not None:
|
||||
assert percentile > 0 and percentile <= 1
|
||||
s = torch.quantile(
|
||||
x0.flatten(1).abs(), percentile,
|
||||
dim=1).clamp_(1.0).view(-1, 1, 1, 1)
|
||||
x0 = torch.min(s, torch.max(-s, x0)) / s
|
||||
elif clamp is not None:
|
||||
x0 = x0.clamp(-clamp, clamp)
|
||||
return mu, var, log_var, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20,
|
||||
eta=0.0):
|
||||
r"""Sample from p(x_{t-1} | x_t) using DDIM.
|
||||
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||
"""
|
||||
stride = self.num_timesteps // ddim_timesteps
|
||||
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale)
|
||||
if condition_fn is not None:
|
||||
# x0 -> eps
|
||||
alpha = _i(self.alphas_cumprod, t, xt)
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
|
||||
eps = eps - (1 - alpha).sqrt() * condition_fn(
|
||||
xt, self._scale_timesteps(t), **model_kwargs)
|
||||
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps)
|
||||
|
||||
# derive variables
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
|
||||
alphas = _i(self.alphas_cumprod, t, xt)
|
||||
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
||||
sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * # noqa
|
||||
(1 - alphas / alphas_prev))
|
||||
|
||||
# random sample
|
||||
noise = torch.randn_like(xt)
|
||||
direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
|
||||
mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
|
||||
xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
|
||||
return xt_1, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20,
|
||||
eta=0.0):
|
||||
# prepare input
|
||||
b = noise.size(0)
|
||||
xt = noise
|
||||
|
||||
# diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
|
||||
steps = (1 + torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // ddim_timesteps)).clamp(
|
||||
0, self.num_timesteps - 1).flip(0)
|
||||
for step in steps:
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn, guide_scale,
|
||||
ddim_timesteps, eta)
|
||||
return xt
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_reverse_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20):
|
||||
r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
|
||||
"""
|
||||
stride = self.num_timesteps // ddim_timesteps
|
||||
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale)
|
||||
|
||||
# derive variables
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
|
||||
alphas_next = _i(
|
||||
torch.cat(
|
||||
[self.alphas_cumprod,
|
||||
self.alphas_cumprod.new_zeros([1])]),
|
||||
(t + stride).clamp(0, self.num_timesteps), xt)
|
||||
|
||||
# reverse sample
|
||||
mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
|
||||
return mu, x0
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_reverse_sample_loop(self,
|
||||
x0,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
guide_scale=None,
|
||||
ddim_timesteps=20):
|
||||
# prepare input
|
||||
b = x0.size(0)
|
||||
xt = x0
|
||||
|
||||
# reconstruction steps
|
||||
steps = torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // ddim_timesteps)
|
||||
for step in steps:
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, guide_scale,
|
||||
ddim_timesteps)
|
||||
return xt
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sample(self,
|
||||
xt,
|
||||
t,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
plms_timesteps=20):
|
||||
r"""Sample from p(x_{t-1} | x_t) using PLMS.
|
||||
- condition_fn: for classifier-based guidance (guided-diffusion).
|
||||
- guide_scale: for classifier-free guidance (glide/dalle-2).
|
||||
"""
|
||||
stride = self.num_timesteps // plms_timesteps
|
||||
|
||||
# function for compute eps
|
||||
def compute_eps(xt, t):
|
||||
# predict distribution of p(x_{t-1} | x_t)
|
||||
_, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
|
||||
clamp, percentile, guide_scale)
|
||||
|
||||
# condition
|
||||
if condition_fn is not None:
|
||||
# x0 -> eps
|
||||
alpha = _i(self.alphas_cumprod, t, xt)
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
|
||||
eps = eps - (1 - alpha).sqrt() * condition_fn(
|
||||
xt, self._scale_timesteps(t), **model_kwargs)
|
||||
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps)
|
||||
|
||||
# derive eps
|
||||
eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt))
|
||||
return eps
|
||||
|
||||
# function for compute x_0 and x_{t-1}
|
||||
def compute_x0(eps, t):
|
||||
# eps -> x0
|
||||
x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - (
|
||||
_i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps)
|
||||
|
||||
# deterministic sample
|
||||
alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
|
||||
direction = torch.sqrt(1 - alphas_prev) * eps
|
||||
xt_1 = torch.sqrt(alphas_prev) * x0 + direction
|
||||
return xt_1, x0
|
||||
|
||||
# PLMS sample
|
||||
eps = compute_eps(xt, t)
|
||||
if len(eps_cache) == 0:
|
||||
# 2nd order pseudo improved Euler
|
||||
xt_1, x0 = compute_x0(eps, t)
|
||||
eps_next = compute_eps(xt_1, (t - stride).clamp(0))
|
||||
eps_prime = (eps + eps_next) / 2.0
|
||||
elif len(eps_cache) == 1:
|
||||
# 2nd order pseudo linear multistep (Adams-Bashforth)
|
||||
eps_prime = (3 * eps - eps_cache[-1]) / 2.0
|
||||
elif len(eps_cache) == 2:
|
||||
# 3nd order pseudo linear multistep (Adams-Bashforth)
|
||||
eps_prime = (23 * eps - 16 * eps_cache[-1]
|
||||
+ 5 * eps_cache[-2]) / 12.0
|
||||
elif len(eps_cache) >= 3:
|
||||
# 4nd order pseudo linear multistep (Adams-Bashforth)
|
||||
eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
|
||||
- 9 * eps_cache[-3]) / 24.0
|
||||
xt_1, x0 = compute_x0(eps_prime, t)
|
||||
return xt_1, x0, eps
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sample_loop(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
plms_timesteps=20):
|
||||
# prepare input
|
||||
b = noise.size(0)
|
||||
xt = noise
|
||||
|
||||
# diffusion process
|
||||
steps = (1 + torch.arange(0, self.num_timesteps,
|
||||
self.num_timesteps // plms_timesteps)).clamp(
|
||||
0, self.num_timesteps - 1).flip(0)
|
||||
eps_cache = []
|
||||
for step in steps:
|
||||
# PLMS sampling step
|
||||
t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
|
||||
xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
|
||||
percentile, condition_fn,
|
||||
guide_scale, plms_timesteps,
|
||||
eps_cache)
|
||||
|
||||
# update eps cache
|
||||
eps_cache.append(eps)
|
||||
if len(eps_cache) >= 4:
|
||||
eps_cache.pop(0)
|
||||
return xt
|
||||
|
||||
def _scale_timesteps(self, t):
|
||||
if self.rescale_timesteps:
|
||||
return t.float() * 1000.0 / self.num_timesteps
|
||||
return t
|
||||
14
modelscope/models/multi_modal/image_to_video/utils/seed.py
Executable file
14
modelscope/models/multi_modal/image_to_video/utils/seed.py
Executable file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def fn(u):
|
||||
return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2
|
||||
|
||||
|
||||
def beta_schedule(schedule,
|
||||
num_timesteps=1000,
|
||||
init_beta=None,
|
||||
last_beta=None):
|
||||
'''
|
||||
This code defines a function beta_schedule that generates a sequence of beta values based on the given input
|
||||
parameters. These beta values can be used in video diffusion processes. The function has the following parameters:
|
||||
schedule(str): Determines the type of beta schedule to be generated. It can be 'linear', 'linear_sd',
|
||||
'quadratic', or 'cosine'.
|
||||
num_timesteps(int, optional): The number of timesteps for the generated beta schedule. Default is 1000.
|
||||
init_beta(float, optional): The initial beta value. If not provided, a default value is used based on the
|
||||
chosen schedule.
|
||||
last_beta(float, optional): The final beta value. If not provided, a default value is used based on the
|
||||
chosen schedule.
|
||||
The function returns a PyTorch tensor containing the generated beta values.
|
||||
The beta schedule is determined by the schedule parameter:
|
||||
1.Linear: Generates a linear sequence of beta values betweeninit_betaandlast_beta.
|
||||
2.Linear_sd: Generates a linear sequence of beta values between the square root of init_beta and the square root
|
||||
oflast_beta, and then squares the result.
|
||||
3.Quadratic: Similar to the 'linear_sd' schedule, but with different default values forinit_betaandlast_beta.
|
||||
4.Cosine: Generates a sequence of beta values based on a cosine function, ensuring the values are between 0
|
||||
and 0.999.
|
||||
If an unsupported schedule is provided, a ValueError is raised with a message indicating the issue.
|
||||
'''
|
||||
if schedule == 'linear':
|
||||
scale = 1000.0 / num_timesteps
|
||||
init_beta = init_beta or scale * 0.0001
|
||||
last_beta = last_beta or scale * 0.02
|
||||
return torch.linspace(
|
||||
init_beta, last_beta, num_timesteps, dtype=torch.float64)
|
||||
elif schedule == 'linear_sd':
|
||||
return torch.linspace(
|
||||
init_beta**0.5, last_beta**0.5, num_timesteps,
|
||||
dtype=torch.float64)**2
|
||||
elif schedule == 'quadratic':
|
||||
init_beta = init_beta or 0.0015
|
||||
last_beta = last_beta or 0.0195
|
||||
return torch.linspace(
|
||||
init_beta**0.5, last_beta**0.5, num_timesteps,
|
||||
dtype=torch.float64)**2
|
||||
elif schedule == 'cosine':
|
||||
betas = []
|
||||
for step in range(num_timesteps):
|
||||
t1 = step / num_timesteps
|
||||
t2 = (step + 1) / num_timesteps
|
||||
betas.append(min(1.0 - fn(t2) / fn(t1), 0.999))
|
||||
return torch.tensor(betas, dtype=torch.float64)
|
||||
else:
|
||||
raise ValueError(f'Unsupported schedule: {schedule}')
|
||||
404
modelscope/models/multi_modal/image_to_video/utils/transforms.py
Executable file
404
modelscope/models/multi_modal/image_to_video/utils/transforms.py
Executable file
@@ -0,0 +1,404 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
__all__ = [
|
||||
'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2',
|
||||
'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',
|
||||
'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize',
|
||||
'ResizeRandomCrop', 'ExtractResizeRandomCrop', 'ExtractResizeAssignCrop'
|
||||
]
|
||||
|
||||
|
||||
class Compose(object):
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __getitem__(self, index):
|
||||
if isinstance(index, slice):
|
||||
return Compose(self.transforms[index])
|
||||
else:
|
||||
return self.transforms[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.transforms)
|
||||
|
||||
def __call__(self, rgb):
|
||||
for t in self.transforms:
|
||||
rgb = t(rgb)
|
||||
return rgb
|
||||
|
||||
|
||||
class Resize(object):
|
||||
|
||||
def __init__(self, size=256):
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
self.size = size
|
||||
|
||||
def __call__(self, rgb):
|
||||
if isinstance(rgb, list):
|
||||
rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb]
|
||||
else:
|
||||
rgb = rgb.resize(self.size, Image.BILINEAR)
|
||||
return rgb
|
||||
|
||||
|
||||
class Rescale(object):
|
||||
|
||||
def __init__(self, size=256, interpolation=Image.BILINEAR):
|
||||
self.size = size
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, rgb):
|
||||
w, h = rgb[0].size
|
||||
scale = self.size / min(w, h)
|
||||
out_w, out_h = int(round(w * scale)), int(round(h * scale))
|
||||
rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class CenterCrop(object):
|
||||
|
||||
def __init__(self, size=224):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, rgb):
|
||||
w, h = rgb[0].size
|
||||
assert min(w, h) >= self.size
|
||||
x1 = (w - self.size) // 2
|
||||
y1 = (h - self.size) // 2
|
||||
rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class ResizeRandomCrop(object):
|
||||
|
||||
def __init__(self, size=256, size_short=292):
|
||||
self.size = size
|
||||
self.size_short = size_short
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
while min(rgb[0].size) >= 2 * self.size_short:
|
||||
rgb = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in rgb
|
||||
]
|
||||
scale = self.size_short / min(rgb[0].size)
|
||||
rgb = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in rgb
|
||||
]
|
||||
out_w = self.size
|
||||
out_h = self.size
|
||||
w, h = rgb[0].size
|
||||
x1 = random.randint(0, w - out_w)
|
||||
y1 = random.randint(0, h - out_h)
|
||||
|
||||
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class ExtractResizeRandomCrop(object):
|
||||
|
||||
def __init__(self, size=256, size_short=292):
|
||||
self.size = size
|
||||
self.size_short = size_short
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
while min(rgb[0].size) >= 2 * self.size_short:
|
||||
rgb = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in rgb
|
||||
]
|
||||
scale = self.size_short / min(rgb[0].size)
|
||||
rgb = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in rgb
|
||||
]
|
||||
out_w = self.size
|
||||
out_h = self.size
|
||||
w, h = rgb[0].size
|
||||
x1 = random.randint(0, w - out_w)
|
||||
y1 = random.randint(0, h - out_h)
|
||||
|
||||
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
|
||||
wh = [x1, y1, x1 + out_w, y1 + out_h]
|
||||
return rgb, wh
|
||||
|
||||
|
||||
class ExtractResizeAssignCrop(object):
|
||||
|
||||
def __init__(self, size=256, size_short=292):
|
||||
self.size = size
|
||||
self.size_short = size_short
|
||||
|
||||
def __call__(self, rgb, wh):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
while min(rgb[0].size) >= 2 * self.size_short:
|
||||
rgb = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in rgb
|
||||
]
|
||||
scale = self.size_short / min(rgb[0].size)
|
||||
rgb = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in rgb
|
||||
]
|
||||
|
||||
rgb = [u.crop(wh) for u in rgb]
|
||||
rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class CenterCropV2(object):
|
||||
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img):
|
||||
# fast resize
|
||||
while min(img[0].size) >= 2 * self.size:
|
||||
img = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in img
|
||||
]
|
||||
scale = self.size / min(img[0].size)
|
||||
img = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in img
|
||||
]
|
||||
|
||||
# center crop
|
||||
x1 = (img[0].width - self.size) // 2
|
||||
y1 = (img[0].height - self.size) // 2
|
||||
img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
|
||||
return img
|
||||
|
||||
|
||||
class CenterCropWide(object):
|
||||
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img):
|
||||
if isinstance(img, list):
|
||||
scale = min(img[0].size[0] / self.size[0],
|
||||
img[0].size[1] / self.size[1])
|
||||
img = [
|
||||
u.resize((round(u.width // scale), round(u.height // scale)),
|
||||
resample=Image.BOX) for u in img
|
||||
]
|
||||
|
||||
# center crop
|
||||
x1 = (img[0].width - self.size[0]) // 2
|
||||
y1 = (img[0].height - self.size[1]) // 2
|
||||
img = [
|
||||
u.crop((x1, y1, x1 + self.size[0], y1 + self.size[1]))
|
||||
for u in img
|
||||
]
|
||||
return img
|
||||
else:
|
||||
scale = min(img.size[0] / self.size[0], img.size[1] / self.size[1])
|
||||
img = img.resize(
|
||||
(round(img.width // scale), round(img.height // scale)),
|
||||
resample=Image.BOX)
|
||||
x1 = (img.width - self.size[0]) // 2
|
||||
y1 = (img.height - self.size[1]) // 2
|
||||
img = img.crop((x1, y1, x1 + self.size[0], y1 + self.size[1]))
|
||||
return img
|
||||
|
||||
|
||||
class RandomCrop(object):
|
||||
|
||||
def __init__(self, size=224, min_area=0.4):
|
||||
self.size = size
|
||||
self.min_area = min_area
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
w, h = rgb[0].size
|
||||
area = w * h
|
||||
out_w, out_h = float('inf'), float('inf')
|
||||
while out_w > w or out_h > h:
|
||||
target_area = random.uniform(self.min_area, 1.0) * area
|
||||
aspect_ratio = random.uniform(3. / 4., 4. / 3.)
|
||||
out_w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
out_h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
x1 = random.randint(0, w - out_w)
|
||||
y1 = random.randint(0, h - out_h)
|
||||
|
||||
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
|
||||
rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class RandomCropV2(object):
|
||||
|
||||
def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
|
||||
if isinstance(size, (tuple, list)):
|
||||
self.size = size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
self.min_area = min_area
|
||||
self.ratio = ratio
|
||||
|
||||
def _get_params(self, img):
|
||||
width, height = img.size
|
||||
area = height * width
|
||||
|
||||
for _ in range(10):
|
||||
target_area = random.uniform(self.min_area, 1.0) * area
|
||||
log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
|
||||
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
||||
|
||||
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
if 0 < w <= width and 0 < h <= height:
|
||||
i = random.randint(0, height - h)
|
||||
j = random.randint(0, width - w)
|
||||
return i, j, h, w
|
||||
|
||||
# Fallback to central crop
|
||||
in_ratio = float(width) / float(height)
|
||||
if (in_ratio < min(self.ratio)):
|
||||
w = width
|
||||
h = int(round(w / min(self.ratio)))
|
||||
elif (in_ratio > max(self.ratio)):
|
||||
h = height
|
||||
w = int(round(h * max(self.ratio)))
|
||||
else:
|
||||
w = width
|
||||
h = height
|
||||
i = (height - h) // 2
|
||||
j = (width - w) // 2
|
||||
return i, j, h, w
|
||||
|
||||
def __call__(self, rgb):
|
||||
i, j, h, w = self._get_params(rgb[0])
|
||||
rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class RandomHFlip(object):
|
||||
|
||||
def __init__(self, p=0.5):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class GaussianBlur(object):
|
||||
|
||||
def __init__(self, sigmas=[0.1, 2.0], p=0.5):
|
||||
self.sigmas = sigmas
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
sigma = random.uniform(*self.sigmas)
|
||||
rgb = [
|
||||
u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb
|
||||
]
|
||||
return rgb
|
||||
|
||||
|
||||
class ColorJitter(object):
|
||||
|
||||
def __init__(self,
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.1,
|
||||
p=0.5):
|
||||
self.brightness = brightness
|
||||
self.contrast = contrast
|
||||
self.saturation = saturation
|
||||
self.hue = hue
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
brightness, contrast, saturation, hue = self._random_params()
|
||||
transforms = [
|
||||
lambda f: F.adjust_brightness(f, brightness),
|
||||
lambda f: F.adjust_contrast(f, contrast),
|
||||
lambda f: F.adjust_saturation(f, saturation),
|
||||
lambda f: F.adjust_hue(f, hue)
|
||||
]
|
||||
random.shuffle(transforms)
|
||||
for t in transforms:
|
||||
rgb = [t(u) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
def _random_params(self):
|
||||
brightness = random.uniform(
|
||||
max(0, 1 - self.brightness), 1 + self.brightness)
|
||||
contrast = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
|
||||
saturation = random.uniform(
|
||||
max(0, 1 - self.saturation), 1 + self.saturation)
|
||||
hue = random.uniform(-self.hue, self.hue)
|
||||
return brightness, contrast, saturation, hue
|
||||
|
||||
|
||||
class RandomGray(object):
|
||||
|
||||
def __init__(self, p=0.2):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
rgb = [u.convert('L').convert('RGB') for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
|
||||
def __call__(self, rgb):
|
||||
if isinstance(rgb, list):
|
||||
rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0)
|
||||
else:
|
||||
rgb = F.to_tensor(rgb)
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class Normalize(object):
|
||||
|
||||
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, rgb):
|
||||
rgb = rgb.clone()
|
||||
rgb.clamp_(0, 1)
|
||||
if not isinstance(self.mean, torch.Tensor):
|
||||
self.mean = rgb.new_tensor(self.mean).view(-1)
|
||||
if not isinstance(self.std, torch.Tensor):
|
||||
self.std = rgb.new_tensor(self.std).view(-1)
|
||||
if rgb.dim() == 4:
|
||||
rgb.sub_(self.mean.view(1, -1, 1,
|
||||
1)).div_(self.std.view(1, -1, 1, 1))
|
||||
elif rgb.dim() == 3:
|
||||
rgb.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1))
|
||||
return rgb
|
||||
@@ -39,7 +39,7 @@ class StableDiffusion(TorchModel):
|
||||
self.lora_tune = kwargs.pop('lora_tune', False)
|
||||
self.dreambooth_tune = kwargs.pop('dreambooth_tune', False)
|
||||
|
||||
self.weight_dtype = torch.float32
|
||||
self.weight_dtype = kwargs.pop('torch_type', torch.float32)
|
||||
self.device = torch.device(
|
||||
'cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
@@ -59,14 +59,15 @@ class StableDiffusion(TorchModel):
|
||||
# Freeze gradient calculation and move to device
|
||||
if self.vae is not None:
|
||||
self.vae.requires_grad_(False)
|
||||
self.vae = self.vae.to(self.device)
|
||||
self.vae = self.vae.to(self.device, dtype=self.weight_dtype)
|
||||
if self.text_encoder is not None:
|
||||
self.text_encoder.requires_grad_(False)
|
||||
self.text_encoder = self.text_encoder.to(self.device)
|
||||
self.text_encoder = self.text_encoder.to(
|
||||
self.device, dtype=self.weight_dtype)
|
||||
if self.unet is not None:
|
||||
if self.lora_tune:
|
||||
self.unet.requires_grad_(False)
|
||||
self.unet = self.unet.to(self.device)
|
||||
self.unet = self.unet.to(self.device, dtype=self.weight_dtype)
|
||||
|
||||
# xformers accelerate memory efficient attention
|
||||
if xformers_enable:
|
||||
|
||||
24
modelscope/models/multi_modal/video_to_video/__init__.py
Normal file
24
modelscope/models/multi_modal/video_to_video/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from .video_to_video_model import VideoToVideo
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'video_to_video_model': ['VideoToVideo'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .autoencoder import *
|
||||
from .embedder import *
|
||||
from .unet_v2v import *
|
||||
@@ -0,0 +1,590 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_first_stage_encoding(encoder_posterior):
|
||||
scale_factor = 0.18215
|
||||
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
|
||||
z = encoder_posterior.sample()
|
||||
elif isinstance(encoder_posterior, torch.Tensor):
|
||||
z = encoder_posterior
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
|
||||
)
|
||||
return scale_factor * z
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(
|
||||
self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(
|
||||
self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1, 2, 3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar
|
||||
+ torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1)
|
||||
k = k.reshape(b, c, h * w)
|
||||
w_ = torch.bmm(q, k)
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1)
|
||||
h_ = torch.bmm(v, w_)
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(
|
||||
x, scale_factor=2.0, mode='nearest')
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
use_linear_attn=False,
|
||||
attn_type='vanilla',
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
use_linear_attn=False,
|
||||
attn_type='vanilla',
|
||||
**ignorekwargs):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2**(self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
logger.info('Working with z of shape {} = {} dimensions.'.format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(AttnBlock(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
class AutoencoderKL(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
embed_dim,
|
||||
pretrained=None,
|
||||
ignore_keys=[],
|
||||
image_key='image',
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
ema_decay=None,
|
||||
learn_logvar=False,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.learn_logvar = learn_logvar
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
assert ddconfig['double_z']
|
||||
self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
|
||||
2 * embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
|
||||
ddconfig['z_channels'], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
|
||||
self.use_ema = ema_decay is not None
|
||||
|
||||
if pretrained is not None:
|
||||
self.init_from_ckpt(pretrained, ignore_keys=ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
sd_new = collections.OrderedDict()
|
||||
for k in keys:
|
||||
if k.find('first_stage_model') >= 0:
|
||||
k_new = k.split('first_stage_model.')[-1]
|
||||
sd_new[k_new] = sd[k]
|
||||
self.load_state_dict(sd_new, strict=True)
|
||||
logger.info(f'Restored from {path}')
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1,
|
||||
2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log['reconstructions'] = xrec
|
||||
if log_ema or self.use_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, posterior_ema = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec_ema.shape[1] > 3
|
||||
xrec_ema = self.to_rgb(xrec_ema)
|
||||
log['samples_ema'] = self.decode(
|
||||
torch.randn_like(posterior_ema.sample()))
|
||||
log['reconstructions_ema'] = xrec_ema
|
||||
log['inputs'] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == 'segmentation'
|
||||
if not hasattr(self, 'colorize'):
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
@@ -0,0 +1,76 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import open_clip
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(nn.Module):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = ['last', 'penultimate']
|
||||
|
||||
def __init__(self,
|
||||
pretrained,
|
||||
arch='ViT-H-14',
|
||||
device='cuda',
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer='penultimate'):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
arch, device=torch.device('cpu'), pretrained=pretrained)
|
||||
|
||||
del model.visual
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == 'last':
|
||||
self.layer_idx = 0
|
||||
elif self.layer == 'penultimate':
|
||||
self.layer_idx = 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
tokens = open_clip.tokenize(text)
|
||||
z = self.encode_with_transformer(tokens.to(self.device))
|
||||
return z
|
||||
|
||||
def encode_with_transformer(self, text):
|
||||
x = self.model.token_embedding(text)
|
||||
x = x + self.model.positional_embedding
|
||||
x = x.permute(1, 0, 2)
|
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
x = x.permute(1, 0, 2)
|
||||
x = self.model.ln_final(x)
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
|
||||
):
|
||||
x = checkpoint(r, x, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
1530
modelscope/models/multi_modal/video_to_video/modules/unet_v2v.py
Normal file
1530
modelscope/models/multi_modal/video_to_video/modules/unet_v2v.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
171
modelscope/models/multi_modal/video_to_video/utils/config.py
Normal file
171
modelscope/models/multi_modal/video_to_video/utils/config.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
from easydict import EasyDict
|
||||
|
||||
cfg = EasyDict(__name__='Config: VideoLDM Decoder')
|
||||
|
||||
# ---------------------------work dir--------------------------
|
||||
cfg.work_dir = 'workspace/'
|
||||
|
||||
# ---------------------------Global Variable-----------------------------------
|
||||
cfg.resolution = [448, 256]
|
||||
cfg.max_frames = 32
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------Dataset Parameter---------------------------------
|
||||
cfg.mean = [0.5, 0.5, 0.5]
|
||||
cfg.std = [0.5, 0.5, 0.5]
|
||||
cfg.max_words = 1000
|
||||
|
||||
# PlaceHolder
|
||||
cfg.vit_out_dim = 1024
|
||||
cfg.vit_resolution = [224, 224]
|
||||
cfg.depth_clamp = 10.0
|
||||
cfg.misc_size = 384
|
||||
cfg.depth_std = 20.0
|
||||
|
||||
cfg.frame_lens = 32
|
||||
cfg.sample_fps = 8
|
||||
|
||||
cfg.batch_sizes = 1
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------Mode Parameters-----------------------------------
|
||||
# Diffusion
|
||||
cfg.schedule = 'cosine'
|
||||
cfg.num_timesteps = 1000
|
||||
cfg.mean_type = 'v'
|
||||
cfg.var_type = 'fixed_small'
|
||||
cfg.loss_type = 'mse'
|
||||
cfg.ddim_timesteps = 50
|
||||
cfg.ddim_eta = 0.0
|
||||
cfg.clamp = 1.0
|
||||
cfg.share_noise = False
|
||||
cfg.use_div_loss = False
|
||||
cfg.noise_strength = 0.1
|
||||
|
||||
# classifier-free guidance
|
||||
cfg.p_zero = 0.1
|
||||
cfg.guide_scale = 3.0
|
||||
|
||||
# clip vision encoder
|
||||
cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073]
|
||||
cfg.vit_std = [0.26862954, 0.26130258, 0.27577711]
|
||||
|
||||
# Model
|
||||
cfg.scale_factor = 0.18215
|
||||
cfg.use_fp16 = True
|
||||
cfg.temporal_attention = True
|
||||
cfg.decoder_bs = 8
|
||||
|
||||
cfg.UNet = {
|
||||
'type': 'Vid2VidSDUNet',
|
||||
'in_dim': 4,
|
||||
'dim': 320,
|
||||
'y_dim': cfg.vit_out_dim,
|
||||
'context_dim': 1024,
|
||||
'out_dim': 8 if cfg.var_type.startswith('learned') else 4,
|
||||
'dim_mult': [1, 2, 4, 4],
|
||||
'num_heads': 8,
|
||||
'head_dim': 64,
|
||||
'num_res_blocks': 2,
|
||||
'attn_scales': [1 / 1, 1 / 2, 1 / 4],
|
||||
'dropout': 0.1,
|
||||
'temporal_attention': cfg.temporal_attention,
|
||||
'temporal_attn_times': 1,
|
||||
'use_checkpoint': False,
|
||||
'use_fps_condition': False,
|
||||
'use_sim_mask': False,
|
||||
'num_tokens': 4,
|
||||
'default_fps': 8,
|
||||
'input_dim': 1024
|
||||
}
|
||||
|
||||
cfg.guidances = []
|
||||
|
||||
# auotoencoder from stabel diffusion
|
||||
cfg.auto_encoder = {
|
||||
'type': 'AutoencoderKL',
|
||||
'ddconfig': {
|
||||
'double_z': True,
|
||||
'z_channels': 4,
|
||||
'resolution': 256,
|
||||
'in_channels': 3,
|
||||
'out_ch': 3,
|
||||
'ch': 128,
|
||||
'ch_mult': [1, 2, 4, 4],
|
||||
'num_res_blocks': 2,
|
||||
'attn_resolutions': [],
|
||||
'dropout': 0.0
|
||||
},
|
||||
'embed_dim': 4,
|
||||
'pretrained': 'models/v2-1_512-ema-pruned.ckpt'
|
||||
}
|
||||
# clip embedder
|
||||
cfg.embedder = {
|
||||
'type': 'FrozenOpenCLIPEmbedder',
|
||||
'layer': 'penultimate',
|
||||
'vit_resolution': [224, 224],
|
||||
'pretrained': 'open_clip_pytorch_model.bin'
|
||||
}
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------Training Settings---------------------------------
|
||||
# training and optimizer
|
||||
cfg.ema_decay = 0.9999
|
||||
cfg.num_steps = 600000
|
||||
cfg.lr = 5e-5
|
||||
cfg.weight_decay = 0.0
|
||||
cfg.betas = (0.9, 0.999)
|
||||
cfg.eps = 1.0e-8
|
||||
cfg.chunk_size = 16
|
||||
cfg.alpha = 0.7
|
||||
cfg.save_ckp_interval = 1000
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ----------------------------Pretrain Settings---------------------------------
|
||||
# Default: load 2d pretrain
|
||||
cfg.fix_weight = False
|
||||
cfg.load_match = False
|
||||
cfg.pretrained_checkpoint = 'v2-1_512-ema-pruned.ckpt'
|
||||
cfg.pretrained_image_keys = 'stable_diffusion_image_key_temporal_attention_x1.json'
|
||||
cfg.resume_checkpoint = 'img2video_ldm_0779000.pth'
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# -----------------------------Visual-------------------------------------------
|
||||
# Visual videos
|
||||
cfg.viz_interval = 1000
|
||||
cfg.visual_train = {
|
||||
'type': 'VisualVideoTextDuringTrain',
|
||||
}
|
||||
cfg.visual_inference = {
|
||||
'type': 'VisualGeneratedVideos',
|
||||
}
|
||||
cfg.inference_list_path = ''
|
||||
|
||||
# logging
|
||||
cfg.log_interval = 100
|
||||
|
||||
# Default log_dir
|
||||
cfg.log_dir = 'workspace/output_data'
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------Others--------------------------------------------
|
||||
# seed
|
||||
cfg.seed = 8888
|
||||
cfg.negative_prompt = 'worst quality, normal quality, low quality, low res, blurry, text, \
|
||||
watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, \
|
||||
sketch ,duplicate, ugly, monochrome, horror, geometry, mutation, disgusting'
|
||||
|
||||
cfg.positive_prompt = ', cinematic, High Contrast, highly detailed, unreal engine, \
|
||||
taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, \
|
||||
32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, \
|
||||
hyper sharpness, perfect without deformations, Unreal Engine 5, 4k render'
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,247 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
from .schedules_sdedit import karras_schedule
|
||||
from .solvers_sdedit import sample_dpmpp_2m_sde, sample_heun
|
||||
|
||||
__all__ = ['GaussianDiffusion_SDEdit']
|
||||
|
||||
|
||||
def _i(tensor, t, x):
|
||||
shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
|
||||
return tensor[t.to(tensor.device)].view(shape).to(x.device)
|
||||
|
||||
|
||||
class GaussianDiffusion_SDEdit(object):
|
||||
|
||||
def __init__(self, sigmas, prediction_type='eps'):
|
||||
assert prediction_type in {'x0', 'eps', 'v'}
|
||||
self.sigmas = sigmas
|
||||
self.alphas = torch.sqrt(1 - sigmas**2)
|
||||
self.num_timesteps = len(sigmas)
|
||||
self.prediction_type = prediction_type
|
||||
|
||||
def diffuse(self, x0, t, noise=None):
|
||||
noise = torch.randn_like(x0) if noise is None else noise
|
||||
xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
|
||||
return xt
|
||||
|
||||
def denoise(self,
|
||||
xt,
|
||||
t,
|
||||
s,
|
||||
model,
|
||||
model_kwargs={},
|
||||
guide_scale=None,
|
||||
guide_rescale=None,
|
||||
clamp=None,
|
||||
percentile=None):
|
||||
s = t - 1 if s is None else s
|
||||
|
||||
# hyperparams
|
||||
sigmas = _i(self.sigmas, t, xt)
|
||||
alphas = _i(self.alphas, t, xt)
|
||||
alphas_s = _i(self.alphas, s.clamp(0), xt)
|
||||
alphas_s[s < 0] = 1.
|
||||
sigmas_s = torch.sqrt(1 - alphas_s**2)
|
||||
|
||||
# precompute variables
|
||||
betas = 1 - (alphas / alphas_s)**2
|
||||
coef1 = betas * alphas_s / sigmas**2
|
||||
coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
|
||||
var = betas * (sigmas_s / sigmas)**2
|
||||
log_var = torch.log(var).clamp_(-20, 20)
|
||||
|
||||
# prediction
|
||||
if guide_scale is None:
|
||||
assert isinstance(model_kwargs, dict)
|
||||
out = model(xt, t=t, **model_kwargs)
|
||||
else:
|
||||
# classifier-free guidance
|
||||
assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
|
||||
y_out = model(xt, t=t, **model_kwargs[0])
|
||||
if guide_scale == 1.:
|
||||
out = y_out
|
||||
else:
|
||||
u_out = model(xt, t=t, **model_kwargs[1])
|
||||
out = u_out + guide_scale * (y_out - u_out)
|
||||
|
||||
if guide_rescale is not None:
|
||||
assert guide_rescale >= 0 and guide_rescale <= 1
|
||||
ratio = (
|
||||
y_out.flatten(1).std(dim=1) / # noqa
|
||||
(out.flatten(1).std(dim=1) + 1e-12)
|
||||
).view((-1, ) + (1, ) * (y_out.ndim - 1))
|
||||
out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
|
||||
|
||||
# compute x0
|
||||
if self.prediction_type == 'x0':
|
||||
x0 = out
|
||||
elif self.prediction_type == 'eps':
|
||||
x0 = (xt - sigmas * out) / alphas
|
||||
elif self.prediction_type == 'v':
|
||||
x0 = alphas * xt - sigmas * out
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'prediction_type {self.prediction_type} not implemented')
|
||||
|
||||
# restrict the range of x0
|
||||
if percentile is not None:
|
||||
assert percentile > 0 and percentile <= 1
|
||||
s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
|
||||
s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
|
||||
x0 = torch.min(s, torch.max(-s, x0)) / s
|
||||
elif clamp is not None:
|
||||
x0 = x0.clamp(-clamp, clamp)
|
||||
|
||||
# recompute eps using the restricted x0
|
||||
eps = (xt - alphas * x0) / sigmas
|
||||
|
||||
# compute mu (mean of posterior distribution) using the restricted x0
|
||||
mu = coef1 * x0 + coef2 * xt
|
||||
return mu, var, log_var, x0, eps
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
noise,
|
||||
model,
|
||||
model_kwargs={},
|
||||
condition_fn=None,
|
||||
guide_scale=None,
|
||||
guide_rescale=None,
|
||||
clamp=None,
|
||||
percentile=None,
|
||||
solver='euler_a',
|
||||
steps=20,
|
||||
t_max=None,
|
||||
t_min=None,
|
||||
discretization=None,
|
||||
discard_penultimate_step=None,
|
||||
return_intermediate=None,
|
||||
show_progress=False,
|
||||
seed=-1,
|
||||
**kwargs):
|
||||
# sanity check
|
||||
assert isinstance(steps, (int, torch.LongTensor))
|
||||
assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
|
||||
assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
|
||||
assert discretization in (None, 'leading', 'linspace', 'trailing')
|
||||
assert discard_penultimate_step in (None, True, False)
|
||||
assert return_intermediate in (None, 'x0', 'xt')
|
||||
|
||||
# function of diffusion solver
|
||||
solver_fn = {
|
||||
'heun': sample_heun,
|
||||
'dpmpp_2m_sde': sample_dpmpp_2m_sde
|
||||
}[solver]
|
||||
|
||||
# options
|
||||
schedule = 'karras' if 'karras' in solver else None
|
||||
discretization = discretization or 'linspace'
|
||||
seed = seed if seed >= 0 else random.randint(0, 2**31)
|
||||
if isinstance(steps, torch.LongTensor):
|
||||
discard_penultimate_step = False
|
||||
if discard_penultimate_step is None:
|
||||
discard_penultimate_step = True if solver in (
|
||||
'dpm2', 'dpm2_ancestral', 'dpmpp_2m_sde', 'dpm2_karras',
|
||||
'dpm2_ancestral_karras', 'dpmpp_2m_sde_karras') else False
|
||||
|
||||
# function for denoising xt to get x0
|
||||
intermediates = []
|
||||
|
||||
def model_fn(xt, sigma):
|
||||
# denoising
|
||||
t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
|
||||
x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale,
|
||||
guide_rescale, clamp, percentile)[-2]
|
||||
|
||||
# collect intermediate outputs
|
||||
if return_intermediate == 'xt':
|
||||
intermediates.append(xt)
|
||||
elif return_intermediate == 'x0':
|
||||
intermediates.append(x0)
|
||||
return x0
|
||||
|
||||
# get timesteps
|
||||
if isinstance(steps, int):
|
||||
steps += 1 if discard_penultimate_step else 0
|
||||
t_max = self.num_timesteps - 1 if t_max is None else t_max
|
||||
t_min = 0 if t_min is None else t_min
|
||||
|
||||
# discretize timesteps
|
||||
if discretization == 'leading':
|
||||
steps = torch.arange(t_min, t_max + 1,
|
||||
(t_max - t_min + 1) / steps).flip(0)
|
||||
elif discretization == 'linspace':
|
||||
steps = torch.linspace(t_max, t_min, steps)
|
||||
elif discretization == 'trailing':
|
||||
steps = torch.arange(t_max, t_min - 1,
|
||||
-((t_max - t_min + 1) / steps))
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'{discretization} discretization not implemented')
|
||||
steps = steps.clamp_(t_min, t_max)
|
||||
steps = torch.as_tensor(
|
||||
steps, dtype=torch.float32, device=noise.device)
|
||||
|
||||
# get sigmas
|
||||
sigmas = self._t_to_sigma(steps)
|
||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||
if schedule == 'karras':
|
||||
if sigmas[0] == float('inf'):
|
||||
sigmas = karras_schedule(
|
||||
n=len(steps) - 1,
|
||||
sigma_min=sigmas[sigmas > 0].min().item(),
|
||||
sigma_max=sigmas[sigmas < float('inf')].max().item(),
|
||||
rho=7.).to(sigmas)
|
||||
sigmas = torch.cat([
|
||||
sigmas.new_tensor([float('inf')]), sigmas,
|
||||
sigmas.new_zeros([1])
|
||||
])
|
||||
else:
|
||||
sigmas = karras_schedule(
|
||||
n=len(steps),
|
||||
sigma_min=sigmas[sigmas > 0].min().item(),
|
||||
sigma_max=sigmas.max().item(),
|
||||
rho=7.).to(sigmas)
|
||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||
if discard_penultimate_step:
|
||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||
|
||||
# sampling
|
||||
x0 = solver_fn(
|
||||
noise, model_fn, sigmas, show_progress=show_progress, **kwargs)
|
||||
return (x0, intermediates) if return_intermediate is not None else x0
|
||||
|
||||
def _sigma_to_t(self, sigma):
|
||||
if sigma == float('inf'):
|
||||
t = torch.full_like(sigma, len(self.sigmas) - 1)
|
||||
else:
|
||||
log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
|
||||
(1 - self.sigmas**2)).log().to(sigma)
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma - log_sigmas[:, None]
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
|
||||
max=log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
low, high = log_sigmas[low_idx], log_sigmas[high_idx]
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = w.clamp(0, 1)
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.view(sigma.shape)
|
||||
if t.ndim == 0:
|
||||
t = t.unsqueeze(0)
|
||||
return t
|
||||
|
||||
def _t_to_sigma(self, t):
|
||||
t = t.float()
|
||||
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
|
||||
log_sigmas = torch.sqrt(self.sigmas**2 / # noqa
|
||||
(1 - self.sigmas**2)).log().to(t)
|
||||
log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
|
||||
log_sigma[torch.isnan(log_sigma)
|
||||
| torch.isinf(log_sigma)] = float('inf')
|
||||
return log_sigma.exp()
|
||||
@@ -0,0 +1,85 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def betas_to_sigmas(betas):
|
||||
return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
|
||||
|
||||
|
||||
def sigmas_to_betas(sigmas):
|
||||
square_alphas = 1 - sigmas**2
|
||||
betas = 1 - torch.cat(
|
||||
[square_alphas[:1], square_alphas[1:] / square_alphas[:-1]])
|
||||
return betas
|
||||
|
||||
|
||||
def logsnrs_to_sigmas(logsnrs):
|
||||
return torch.sqrt(torch.sigmoid(-logsnrs))
|
||||
|
||||
|
||||
def sigmas_to_logsnrs(sigmas):
|
||||
square_sigmas = sigmas**2
|
||||
return torch.log(square_sigmas / (1 - square_sigmas))
|
||||
|
||||
|
||||
def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
|
||||
t_min = math.atan(math.exp(-0.5 * logsnr_min))
|
||||
t_max = math.atan(math.exp(-0.5 * logsnr_max))
|
||||
t = torch.linspace(1, 0, n)
|
||||
logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
|
||||
return logsnrs
|
||||
|
||||
|
||||
def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
|
||||
logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
|
||||
logsnrs += 2 * math.log(1 / scale)
|
||||
return logsnrs
|
||||
|
||||
|
||||
def _logsnr_cosine_interp(n,
|
||||
logsnr_min=-15,
|
||||
logsnr_max=15,
|
||||
scale_min=2,
|
||||
scale_max=4):
|
||||
t = torch.linspace(1, 0, n)
|
||||
logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
|
||||
logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
|
||||
logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
|
||||
return logsnrs
|
||||
|
||||
|
||||
def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
|
||||
ramp = torch.linspace(1, 0, n)
|
||||
min_inv_rho = sigma_min**(1 / rho)
|
||||
max_inv_rho = sigma_max**(1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
|
||||
sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2))
|
||||
return sigmas
|
||||
|
||||
|
||||
def logsnr_cosine_interp_schedule(n,
|
||||
logsnr_min=-15,
|
||||
logsnr_max=15,
|
||||
scale_min=2,
|
||||
scale_max=4):
|
||||
return logsnrs_to_sigmas(
|
||||
_logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min, scale_max))
|
||||
|
||||
|
||||
def noise_schedule(schedule='logsnr_cosine_interp',
|
||||
n=1000,
|
||||
zero_terminal_snr=False,
|
||||
**kwargs):
|
||||
# compute sigmas
|
||||
sigmas = {
|
||||
'logsnr_cosine_interp': logsnr_cosine_interp_schedule
|
||||
}[schedule](n, **kwargs)
|
||||
|
||||
# post-processing
|
||||
if zero_terminal_snr and sigmas.max() != 1.0:
|
||||
scale = (1.0 - sigmas.min()) / (sigmas.max() - sigmas.min())
|
||||
sigmas = sigmas.min() + scale * (sigmas - sigmas.min())
|
||||
return sigmas
|
||||
14
modelscope/models/multi_modal/video_to_video/utils/seed.py
Normal file
14
modelscope/models/multi_modal/video_to_video/utils/seed.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
@@ -0,0 +1,194 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch
|
||||
import torchsde
|
||||
from tqdm.auto import trange
|
||||
|
||||
|
||||
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
|
||||
"""
|
||||
Calculates the noise level (sigma_down) to step down to and the amount
|
||||
of noise to add (sigma_up) when doing an ancestral sampling step.
|
||||
"""
|
||||
if not eta:
|
||||
return sigma_to, 0.
|
||||
sigma_up = min(
|
||||
sigma_to,
|
||||
eta * (
|
||||
sigma_to**2 * # noqa
|
||||
(sigma_from**2 - sigma_to**2) / sigma_from**2)**0.5)
|
||||
sigma_down = (sigma_to**2 - sigma_up**2)**0.5
|
||||
return sigma_down, sigma_up
|
||||
|
||||
|
||||
def get_scalings(sigma):
|
||||
c_out = -sigma
|
||||
c_in = 1 / (sigma**2 + 1.**2)**0.5
|
||||
return c_out, c_in
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(noise,
|
||||
model,
|
||||
sigmas,
|
||||
s_churn=0.,
|
||||
s_tmin=0.,
|
||||
s_tmax=float('inf'),
|
||||
s_noise=1.,
|
||||
show_progress=True):
|
||||
"""
|
||||
Implements Algorithm 2 (Heun steps) from Karras et al. (2022).
|
||||
"""
|
||||
x = noise * sigmas[0]
|
||||
for i in trange(len(sigmas) - 1, disable=not show_progress):
|
||||
gamma = 0.
|
||||
if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float('inf'):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5
|
||||
if sigmas[i] == float('inf'):
|
||||
# Euler method
|
||||
denoised = model(noise, sigma_hat)
|
||||
x = denoised + sigmas[i + 1] * (gamma + 1) * noise
|
||||
else:
|
||||
_, c_in = get_scalings(sigma_hat)
|
||||
denoised = model(x * c_in, sigma_hat)
|
||||
d = (x - denoised) / sigma_hat
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Heun's method
|
||||
x_2 = x + d * dt
|
||||
_, c_in = get_scalings(sigmas[i + 1])
|
||||
denoised_2 = model(x_2 * c_in, sigmas[i + 1])
|
||||
d_2 = (x_2 - denoised_2) / sigmas[i + 1]
|
||||
d_prime = (d + d_2) / 2
|
||||
x = x + d_prime * dt
|
||||
return x
|
||||
|
||||
|
||||
class BatchedBrownianTree:
|
||||
"""
|
||||
A wrapper around torchsde.BrownianTree that enables batches of entropy.
|
||||
"""
|
||||
|
||||
def __init__(self, x, t0, t1, seed=None, **kwargs):
|
||||
t0, t1, self.sign = self.sort(t0, t1)
|
||||
w0 = kwargs.get('w0', torch.zeros_like(x))
|
||||
if seed is None:
|
||||
seed = torch.randint(0, 2**63 - 1, []).item()
|
||||
self.batched = True
|
||||
try:
|
||||
assert len(seed) == x.shape[0]
|
||||
w0 = w0[0]
|
||||
except TypeError:
|
||||
seed = [seed]
|
||||
self.batched = False
|
||||
self.trees = [
|
||||
torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs)
|
||||
for s in seed
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def sort(a, b):
|
||||
return (a, b, 1) if a < b else (b, a, -1)
|
||||
|
||||
def __call__(self, t0, t1):
|
||||
t0, t1, sign = self.sort(t0, t1)
|
||||
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (
|
||||
self.sign * sign)
|
||||
return w if self.batched else w[0]
|
||||
|
||||
|
||||
class BrownianTreeNoiseSampler:
|
||||
"""
|
||||
A noise sampler backed by a torchsde.BrownianTree.
|
||||
|
||||
Args:
|
||||
x (Tensor): The tensor whose shape, device and dtype to use to generate
|
||||
random samples.
|
||||
sigma_min (float): The low end of the valid interval.
|
||||
sigma_max (float): The high end of the valid interval.
|
||||
seed (int or List[int]): The random seed. If a list of seeds is
|
||||
supplied instead of a single integer, then the noise sampler will
|
||||
use one BrownianTree per batch item, each with its own seed.
|
||||
transform (callable): A function that maps sigma to the sampler's
|
||||
internal timestep.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
x,
|
||||
sigma_min,
|
||||
sigma_max,
|
||||
seed=None,
|
||||
transform=lambda x: x):
|
||||
self.transform = transform
|
||||
t0 = self.transform(torch.as_tensor(sigma_min))
|
||||
t1 = self.transform(torch.as_tensor(sigma_max))
|
||||
self.tree = BatchedBrownianTree(x, t0, t1, seed)
|
||||
|
||||
def __call__(self, sigma, sigma_next):
|
||||
t0 = self.transform(torch.as_tensor(sigma))
|
||||
t1 = self.transform(torch.as_tensor(sigma_next))
|
||||
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2m_sde(noise,
|
||||
model,
|
||||
sigmas,
|
||||
eta=1.,
|
||||
s_noise=1.,
|
||||
solver_type='midpoint',
|
||||
show_progress=True):
|
||||
"""
|
||||
DPM-Solver++ (2M) SDE.
|
||||
"""
|
||||
assert solver_type in {'heun', 'midpoint'}
|
||||
|
||||
x = noise * sigmas[0]
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[
|
||||
sigmas < float('inf')].max()
|
||||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
|
||||
old_denoised = None
|
||||
h_last = None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=not show_progress):
|
||||
if sigmas[i] == float('inf'):
|
||||
# Euler method
|
||||
denoised = model(noise, sigmas[i])
|
||||
x = denoised + sigmas[i + 1] * noise
|
||||
else:
|
||||
_, c_in = get_scalings(sigmas[i])
|
||||
denoised = model(x * c_in, sigmas[i])
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
else:
|
||||
# DPM-Solver++(2M) SDE
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
h = s - t
|
||||
eta_h = eta * h
|
||||
|
||||
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
|
||||
(-h - eta_h).expm1().neg() * denoised
|
||||
|
||||
if old_denoised is not None:
|
||||
r = h_last / h
|
||||
if solver_type == 'heun':
|
||||
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \
|
||||
(1 / r) * (denoised - old_denoised)
|
||||
elif solver_type == 'midpoint':
|
||||
x = x + 0.5 * (-h - eta_h).expm1().neg() * \
|
||||
(1 / r) * (denoised - old_denoised)
|
||||
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[
|
||||
i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||
|
||||
old_denoised = denoised
|
||||
h_last = h
|
||||
return x
|
||||
404
modelscope/models/multi_modal/video_to_video/utils/transforms.py
Normal file
404
modelscope/models/multi_modal/video_to_video/utils/transforms.py
Normal file
@@ -0,0 +1,404 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
__all__ = [
|
||||
'Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2',
|
||||
'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',
|
||||
'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize',
|
||||
'ResizeRandomCrop', 'ExtractResizeRandomCrop', 'ExtractResizeAssignCrop'
|
||||
]
|
||||
|
||||
|
||||
class Compose(object):
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __getitem__(self, index):
|
||||
if isinstance(index, slice):
|
||||
return Compose(self.transforms[index])
|
||||
else:
|
||||
return self.transforms[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.transforms)
|
||||
|
||||
def __call__(self, rgb):
|
||||
for t in self.transforms:
|
||||
rgb = t(rgb)
|
||||
return rgb
|
||||
|
||||
|
||||
class Resize(object):
|
||||
|
||||
def __init__(self, size=256):
|
||||
if isinstance(size, int):
|
||||
size = (size, size)
|
||||
self.size = size
|
||||
|
||||
def __call__(self, rgb):
|
||||
if isinstance(rgb, list):
|
||||
rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb]
|
||||
else:
|
||||
rgb = rgb.resize(self.size, Image.BILINEAR)
|
||||
return rgb
|
||||
|
||||
|
||||
class Rescale(object):
|
||||
|
||||
def __init__(self, size=256, interpolation=Image.BILINEAR):
|
||||
self.size = size
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, rgb):
|
||||
w, h = rgb[0].size
|
||||
scale = self.size / min(w, h)
|
||||
out_w, out_h = int(round(w * scale)), int(round(h * scale))
|
||||
rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class CenterCrop(object):
|
||||
|
||||
def __init__(self, size=224):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, rgb):
|
||||
w, h = rgb[0].size
|
||||
assert min(w, h) >= self.size
|
||||
x1 = (w - self.size) // 2
|
||||
y1 = (h - self.size) // 2
|
||||
rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class ResizeRandomCrop(object):
|
||||
|
||||
def __init__(self, size=256, size_short=292):
|
||||
self.size = size
|
||||
self.size_short = size_short
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
while min(rgb[0].size) >= 2 * self.size_short:
|
||||
rgb = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in rgb
|
||||
]
|
||||
scale = self.size_short / min(rgb[0].size)
|
||||
rgb = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in rgb
|
||||
]
|
||||
out_w = self.size
|
||||
out_h = self.size
|
||||
w, h = rgb[0].size
|
||||
x1 = random.randint(0, w - out_w)
|
||||
y1 = random.randint(0, h - out_h)
|
||||
|
||||
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class ExtractResizeRandomCrop(object):
|
||||
|
||||
def __init__(self, size=256, size_short=292):
|
||||
self.size = size
|
||||
self.size_short = size_short
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
while min(rgb[0].size) >= 2 * self.size_short:
|
||||
rgb = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in rgb
|
||||
]
|
||||
scale = self.size_short / min(rgb[0].size)
|
||||
rgb = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in rgb
|
||||
]
|
||||
out_w = self.size
|
||||
out_h = self.size
|
||||
w, h = rgb[0].size
|
||||
x1 = random.randint(0, w - out_w)
|
||||
y1 = random.randint(0, h - out_h)
|
||||
|
||||
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
|
||||
wh = [x1, y1, x1 + out_w, y1 + out_h]
|
||||
return rgb, wh
|
||||
|
||||
|
||||
class ExtractResizeAssignCrop(object):
|
||||
|
||||
def __init__(self, size=256, size_short=292):
|
||||
self.size = size
|
||||
self.size_short = size_short
|
||||
|
||||
def __call__(self, rgb, wh):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
while min(rgb[0].size) >= 2 * self.size_short:
|
||||
rgb = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in rgb
|
||||
]
|
||||
scale = self.size_short / min(rgb[0].size)
|
||||
rgb = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in rgb
|
||||
]
|
||||
|
||||
rgb = [u.crop(wh) for u in rgb]
|
||||
rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class CenterCropV2(object):
|
||||
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img):
|
||||
# fast resize
|
||||
while min(img[0].size) >= 2 * self.size:
|
||||
img = [
|
||||
u.resize((u.width // 2, u.height // 2), resample=Image.BOX)
|
||||
for u in img
|
||||
]
|
||||
scale = self.size / min(img[0].size)
|
||||
img = [
|
||||
u.resize((round(scale * u.width), round(scale * u.height)),
|
||||
resample=Image.BICUBIC) for u in img
|
||||
]
|
||||
|
||||
# center crop
|
||||
x1 = (img[0].width - self.size) // 2
|
||||
y1 = (img[0].height - self.size) // 2
|
||||
img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img]
|
||||
return img
|
||||
|
||||
|
||||
class CenterCropWide(object):
|
||||
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img):
|
||||
if isinstance(img, list):
|
||||
scale = min(img[0].size[0] / self.size[0],
|
||||
img[0].size[1] / self.size[1])
|
||||
img = [
|
||||
u.resize((round(u.width // scale), round(u.height // scale)),
|
||||
resample=Image.BOX) for u in img
|
||||
]
|
||||
|
||||
# center crop
|
||||
x1 = (img[0].width - self.size[0]) // 2
|
||||
y1 = (img[0].height - self.size[1]) // 2
|
||||
img = [
|
||||
u.crop((x1, y1, x1 + self.size[0], y1 + self.size[1]))
|
||||
for u in img
|
||||
]
|
||||
return img
|
||||
else:
|
||||
scale = min(img.size[0] / self.size[0], img.size[1] / self.size[1])
|
||||
img = img.resize(
|
||||
(round(img.width // scale), round(img.height // scale)),
|
||||
resample=Image.BOX)
|
||||
x1 = (img.width - self.size[0]) // 2
|
||||
y1 = (img.height - self.size[1]) // 2
|
||||
img = img.crop((x1, y1, x1 + self.size[0], y1 + self.size[1]))
|
||||
return img
|
||||
|
||||
|
||||
class RandomCrop(object):
|
||||
|
||||
def __init__(self, size=224, min_area=0.4):
|
||||
self.size = size
|
||||
self.min_area = min_area
|
||||
|
||||
def __call__(self, rgb):
|
||||
|
||||
# consistent crop between rgb and m
|
||||
w, h = rgb[0].size
|
||||
area = w * h
|
||||
out_w, out_h = float('inf'), float('inf')
|
||||
while out_w > w or out_h > h:
|
||||
target_area = random.uniform(self.min_area, 1.0) * area
|
||||
aspect_ratio = random.uniform(3. / 4., 4. / 3.)
|
||||
out_w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
out_h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
x1 = random.randint(0, w - out_w)
|
||||
y1 = random.randint(0, h - out_h)
|
||||
|
||||
rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb]
|
||||
rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class RandomCropV2(object):
|
||||
|
||||
def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
|
||||
if isinstance(size, (tuple, list)):
|
||||
self.size = size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
self.min_area = min_area
|
||||
self.ratio = ratio
|
||||
|
||||
def _get_params(self, img):
|
||||
width, height = img.size
|
||||
area = height * width
|
||||
|
||||
for _ in range(10):
|
||||
target_area = random.uniform(self.min_area, 1.0) * area
|
||||
log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
|
||||
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
||||
|
||||
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
if 0 < w <= width and 0 < h <= height:
|
||||
i = random.randint(0, height - h)
|
||||
j = random.randint(0, width - w)
|
||||
return i, j, h, w
|
||||
|
||||
# Fallback to central crop
|
||||
in_ratio = float(width) / float(height)
|
||||
if (in_ratio < min(self.ratio)):
|
||||
w = width
|
||||
h = int(round(w / min(self.ratio)))
|
||||
elif (in_ratio > max(self.ratio)):
|
||||
h = height
|
||||
w = int(round(h * max(self.ratio)))
|
||||
else:
|
||||
w = width
|
||||
h = height
|
||||
i = (height - h) // 2
|
||||
j = (width - w) // 2
|
||||
return i, j, h, w
|
||||
|
||||
def __call__(self, rgb):
|
||||
i, j, h, w = self._get_params(rgb[0])
|
||||
rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class RandomHFlip(object):
|
||||
|
||||
def __init__(self, p=0.5):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class GaussianBlur(object):
|
||||
|
||||
def __init__(self, sigmas=[0.1, 2.0], p=0.5):
|
||||
self.sigmas = sigmas
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
sigma = random.uniform(*self.sigmas)
|
||||
rgb = [
|
||||
u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb
|
||||
]
|
||||
return rgb
|
||||
|
||||
|
||||
class ColorJitter(object):
|
||||
|
||||
def __init__(self,
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.1,
|
||||
p=0.5):
|
||||
self.brightness = brightness
|
||||
self.contrast = contrast
|
||||
self.saturation = saturation
|
||||
self.hue = hue
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
brightness, contrast, saturation, hue = self._random_params()
|
||||
transforms = [
|
||||
lambda f: F.adjust_brightness(f, brightness),
|
||||
lambda f: F.adjust_contrast(f, contrast),
|
||||
lambda f: F.adjust_saturation(f, saturation),
|
||||
lambda f: F.adjust_hue(f, hue)
|
||||
]
|
||||
random.shuffle(transforms)
|
||||
for t in transforms:
|
||||
rgb = [t(u) for u in rgb]
|
||||
|
||||
return rgb
|
||||
|
||||
def _random_params(self):
|
||||
brightness = random.uniform(
|
||||
max(0, 1 - self.brightness), 1 + self.brightness)
|
||||
contrast = random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
|
||||
saturation = random.uniform(
|
||||
max(0, 1 - self.saturation), 1 + self.saturation)
|
||||
hue = random.uniform(-self.hue, self.hue)
|
||||
return brightness, contrast, saturation, hue
|
||||
|
||||
|
||||
class RandomGray(object):
|
||||
|
||||
def __init__(self, p=0.2):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, rgb):
|
||||
if random.random() < self.p:
|
||||
rgb = [u.convert('L').convert('RGB') for u in rgb]
|
||||
return rgb
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
|
||||
def __call__(self, rgb):
|
||||
if isinstance(rgb, list):
|
||||
rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0)
|
||||
else:
|
||||
rgb = F.to_tensor(rgb)
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
class Normalize(object):
|
||||
|
||||
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def __call__(self, rgb):
|
||||
rgb = rgb.clone()
|
||||
rgb.clamp_(0, 1)
|
||||
if not isinstance(self.mean, torch.Tensor):
|
||||
self.mean = rgb.new_tensor(self.mean).view(-1)
|
||||
if not isinstance(self.std, torch.Tensor):
|
||||
self.std = rgb.new_tensor(self.std).view(-1)
|
||||
if rgb.dim() == 4:
|
||||
rgb.sub_(self.mean.view(1, -1, 1,
|
||||
1)).div_(self.std.view(1, -1, 1, 1))
|
||||
elif rgb.dim() == 3:
|
||||
rgb.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1))
|
||||
return rgb
|
||||
227
modelscope/models/multi_modal/video_to_video/video_to_video_model.py
Executable file
227
modelscope/models/multi_modal/video_to_video/video_to_video_model.py
Executable file
@@ -0,0 +1,227 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
from copy import copy
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn.functional as F
|
||||
|
||||
import modelscope.models.multi_modal.video_to_video.utils.transforms as data
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.multi_modal.video_to_video.modules import *
|
||||
from modelscope.models.multi_modal.video_to_video.modules import (
|
||||
AutoencoderKL, FrozenOpenCLIPEmbedder, Vid2VidSDUNet,
|
||||
get_first_stage_encoding)
|
||||
from modelscope.models.multi_modal.video_to_video.utils.config import cfg
|
||||
from modelscope.models.multi_modal.video_to_video.utils.diffusion_sdedit import \
|
||||
GaussianDiffusion_SDEdit
|
||||
from modelscope.models.multi_modal.video_to_video.utils.schedules_sdedit import \
|
||||
noise_schedule
|
||||
from modelscope.models.multi_modal.video_to_video.utils.seed import setup_seed
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
__all__ = ['VideoToVideo']
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.video_to_video, module_name=Models.video_to_video_model)
|
||||
class VideoToVideo(TorchModel):
|
||||
r"""
|
||||
Video2Video aims to solve the task of generating super-resolution videos based on input
|
||||
video and text, which is a video generation basic model developed by Alibaba Cloud.
|
||||
|
||||
Paper link: https://arxiv.org/abs/2306.02018
|
||||
|
||||
Attributes:
|
||||
diffusion: diffusion model for DDIM.
|
||||
autoencoder: decode the latent representation of input video into visual space.
|
||||
clip_encoder: encode the text into text embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
r"""
|
||||
Args:
|
||||
model_dir (`str` or `os.PathLike`)
|
||||
Can be either:
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co
|
||||
or modelscope.cn. Valid model ids can be located at the root-level, like `bert-base-uncased`,
|
||||
or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- A path to a *directory* containing model weights saved using
|
||||
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
- A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
|
||||
this case, `from_tf` should be set to `True` and a configuration object should be provided as
|
||||
`config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
|
||||
PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
- A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
|
||||
`./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
|
||||
`True`.
|
||||
"""
|
||||
super().__init__(model_dir=model_dir, *args, **kwargs)
|
||||
|
||||
self.config = Config.from_file(
|
||||
osp.join(model_dir, ModelFile.CONFIGURATION))
|
||||
|
||||
cfg.solver_mode = self.config.model.model_args.solver_mode
|
||||
|
||||
# assign default value
|
||||
cfg.batch_size = self.config.model.model_cfg.batch_size
|
||||
cfg.target_fps = self.config.model.model_cfg.target_fps
|
||||
cfg.max_frames = self.config.model.model_cfg.max_frames
|
||||
cfg.latent_hei = self.config.model.model_cfg.latent_hei
|
||||
cfg.latent_wid = self.config.model.model_cfg.latent_wid
|
||||
cfg.model_path = osp.join(model_dir,
|
||||
self.config.model.model_args.ckpt_unet)
|
||||
|
||||
self.device = torch.device(
|
||||
'cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
if 'seed' in self.config.model.model_args.keys():
|
||||
cfg.seed = self.config.model.model_args.seed
|
||||
else:
|
||||
cfg.seed = random.randint(0, 99999)
|
||||
setup_seed(cfg.seed)
|
||||
|
||||
# transform
|
||||
vid_trans = data.Compose(
|
||||
[data.ToTensor(),
|
||||
data.Normalize(mean=cfg.mean, std=cfg.std)])
|
||||
self.vid_trans = vid_trans
|
||||
|
||||
cfg.embedder.pretrained = osp.join(
|
||||
model_dir, self.config.model.model_args.ckpt_clip)
|
||||
clip_encoder = FrozenOpenCLIPEmbedder(
|
||||
pretrained=cfg.embedder.pretrained)
|
||||
clip_encoder.model.to(self.device)
|
||||
self.clip_encoder = clip_encoder
|
||||
logger.info(f'Build encoder with {cfg.embedder.type}')
|
||||
|
||||
# [unet]
|
||||
generator = Vid2VidSDUNet()
|
||||
generator = generator.to(self.device)
|
||||
generator.eval()
|
||||
load_dict = torch.load(cfg.model_path, map_location='cpu')
|
||||
ret = generator.load_state_dict(load_dict['state_dict'], strict=True)
|
||||
self.generator = generator
|
||||
logger.info('Load model {} path {}, with local status {}'.format(
|
||||
cfg.UNet.type, cfg.model_path, ret))
|
||||
|
||||
# [diffusion]
|
||||
sigmas = noise_schedule(
|
||||
schedule='logsnr_cosine_interp',
|
||||
n=1000,
|
||||
zero_terminal_snr=True,
|
||||
scale_min=2.0,
|
||||
scale_max=4.0)
|
||||
diffusion = GaussianDiffusion_SDEdit(
|
||||
sigmas=sigmas, prediction_type='v')
|
||||
self.diffusion = diffusion
|
||||
logger.info('Build diffusion with type of GaussianDiffusion_SDEdit')
|
||||
|
||||
# [auotoencoder]
|
||||
cfg.auto_encoder.pretrained = osp.join(
|
||||
model_dir, self.config.model.model_args.ckpt_autoencoder)
|
||||
autoencoder = AutoencoderKL(**cfg.auto_encoder)
|
||||
autoencoder.eval()
|
||||
for param in autoencoder.parameters():
|
||||
param.requires_grad = False
|
||||
autoencoder.to(self.device)
|
||||
self.autoencoder = autoencoder
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
negative_prompt = cfg.negative_prompt
|
||||
negative_y = clip_encoder(negative_prompt).detach()
|
||||
self.negative_y = negative_y
|
||||
|
||||
positive_prompt = cfg.positive_prompt
|
||||
self.positive_prompt = positive_prompt
|
||||
|
||||
self.cfg = cfg
|
||||
|
||||
def forward(self, input: Dict[str, Any]):
|
||||
r"""
|
||||
The entry function of video to video task.
|
||||
1. Using CLIP to encode text into embeddings.
|
||||
2. Using diffusion model to generate the video's latent representation.
|
||||
3. Using autoencoder to decode the video's latent representation to visual space.
|
||||
|
||||
Args:
|
||||
input (`Dict[Str, Any]`):
|
||||
The input of the task
|
||||
Returns:
|
||||
A generated video (as pytorch tensor).
|
||||
"""
|
||||
|
||||
video_data = input['video_data']
|
||||
y = input['y']
|
||||
cfg = self.cfg
|
||||
|
||||
video_data = F.interpolate(
|
||||
video_data, size=(720, 1280), mode='bilinear')
|
||||
video_data = video_data.unsqueeze(0)
|
||||
video_data = video_data.to(self.device)
|
||||
|
||||
batch_size, frames_num, _, _, _ = video_data.shape
|
||||
video_data = rearrange(video_data, 'b f c h w -> (b f) c h w')
|
||||
|
||||
video_data_list = torch.chunk(
|
||||
video_data, video_data.shape[0] // 2, dim=0)
|
||||
with torch.no_grad():
|
||||
decode_data = []
|
||||
for vd_data in video_data_list:
|
||||
encoder_posterior = self.autoencoder.encode(vd_data)
|
||||
tmp = get_first_stage_encoding(encoder_posterior).detach()
|
||||
decode_data.append(tmp)
|
||||
video_data_feature = torch.cat(decode_data, dim=0)
|
||||
video_data_feature = rearrange(
|
||||
video_data_feature, '(b f) c h w -> b c f h w', b=batch_size)
|
||||
|
||||
with amp.autocast(enabled=True):
|
||||
total_noise_levels = 600
|
||||
t = torch.randint(
|
||||
total_noise_levels - 1,
|
||||
total_noise_levels, (1, ),
|
||||
dtype=torch.long).to(self.device)
|
||||
|
||||
noise = torch.randn_like(video_data_feature)
|
||||
noised_lr = self.diffusion.diffuse(video_data_feature, t, noise)
|
||||
model_kwargs = [{'y': y}, {'y': self.negative_y}]
|
||||
|
||||
gen_vid = self.diffusion.sample(
|
||||
noise=noised_lr,
|
||||
model=self.generator,
|
||||
model_kwargs=model_kwargs,
|
||||
guide_scale=7.5,
|
||||
guide_rescale=0.2,
|
||||
solver='dpmpp_2m_sde' if cfg.solver_mode == 'fast' else 'heun',
|
||||
steps=30 if cfg.solver_mode == 'fast' else 50,
|
||||
t_max=total_noise_levels - 1,
|
||||
t_min=0,
|
||||
discretization='trailing')
|
||||
|
||||
scale_factor = 0.18215
|
||||
vid_tensor_feature = 1. / scale_factor * gen_vid
|
||||
|
||||
vid_tensor_feature = rearrange(vid_tensor_feature,
|
||||
'b c f h w -> (b f) c h w')
|
||||
vid_tensor_feature_list = torch.chunk(
|
||||
vid_tensor_feature, vid_tensor_feature.shape[0] // 2, dim=0)
|
||||
decode_data = []
|
||||
for vd_data in vid_tensor_feature_list:
|
||||
tmp = self.autoencoder.decode(vd_data)
|
||||
decode_data.append(tmp)
|
||||
vid_tensor_gen = torch.cat(decode_data, dim=0)
|
||||
|
||||
gen_video = rearrange(
|
||||
vid_tensor_gen, '(b f) c h w -> b c f h w', b=cfg.batch_size)
|
||||
|
||||
return gen_video.type(torch.float32).cpu()
|
||||
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
|
||||
ModelForTextRanking,
|
||||
ModelForTokenClassification,
|
||||
ModelForTokenClassificationWithCRF,
|
||||
ModelForMachineReadingComprehension,
|
||||
)
|
||||
from .unite import UniTEForTranslationEvaluation
|
||||
from .use import UserSatisfactionEstimation
|
||||
@@ -159,6 +160,7 @@ else:
|
||||
'ModelForTextRanking',
|
||||
'ModelForTokenClassification',
|
||||
'ModelForTokenClassificationWithCRF',
|
||||
'ModelForMachineReadingComprehension',
|
||||
],
|
||||
'sentence_embedding': ['SentenceEmbedding'],
|
||||
'T5': ['T5ForConditionalGeneration'],
|
||||
|
||||
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
|
||||
ModelForTokenClassificationWithCRF)
|
||||
from .text_generation import ModelForTextGeneration
|
||||
from .text_ranking import ModelForTextRanking
|
||||
from .machine_reading_comprehension import ModelForMachineReadingComprehension
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
@@ -25,6 +26,8 @@ else:
|
||||
['ModelForTokenClassification', 'ModelForTokenClassificationWithCRF'],
|
||||
'text_generation': ['ModelForTextGeneration'],
|
||||
'text_ranking': ['ModelForTextRanking'],
|
||||
'machine_reading_comprehension':
|
||||
['ModelForMachineReadingComprehension'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import AutoConfig
|
||||
from transformers.modeling_outputs import ModelOutput
|
||||
from transformers.models.roberta.modeling_roberta import (
|
||||
RobertaModel, RobertaPreTrainedModel)
|
||||
|
||||
from modelscope.metainfo import Heads, Models, TaskModels
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.nlp.task_models.task_model import EncoderModel
|
||||
from modelscope.outputs import MachineReadingComprehensionOutput, OutputKeys
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.hub import parse_label_mapping
|
||||
|
||||
__all__ = ['ModelForMachineReadingComprehension']
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.machine_reading_comprehension,
|
||||
module_name=TaskModels.machine_reading_comprehension)
|
||||
class ModelForMachineReadingComprehension(TorchModel):
|
||||
'''
|
||||
Pretrained Machine Reader (PMR) model (https://arxiv.org/pdf/2212.04755.pdf)
|
||||
|
||||
'''
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r'pooler']
|
||||
_keys_to_ignore_on_load_missing = [r'position_ids']
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
self.config = AutoConfig.from_pretrained(model_dir)
|
||||
self.num_labels = self.config.num_labels
|
||||
self.roberta = RobertaModel(self.config, add_pooling_layer=False)
|
||||
self.span_transfer = MultiNonLinearProjection(
|
||||
self.config.hidden_size,
|
||||
self.config.hidden_size,
|
||||
self.config.hidden_dropout_prob,
|
||||
intermediate_hidden_size=self.config.
|
||||
projection_intermediate_hidden_size)
|
||||
self.load_state_dict(
|
||||
torch.load(
|
||||
os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE)))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
label_mask=None,
|
||||
match_labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.roberta(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
# adapted from https://github.com/ShannonAI/mrc-for-flat-nested-ner
|
||||
# for every position $i$ in sequence, should concate $j$ to
|
||||
# predict if $i$ and $j$ are start_pos and end_pos for an entity.
|
||||
# [batch, seq_len, hidden]
|
||||
span_intermediate = self.span_transfer(sequence_output)
|
||||
# [batch, seq_len, seq_len]
|
||||
span_logits = torch.matmul(span_intermediate,
|
||||
sequence_output.transpose(-1, -2))
|
||||
|
||||
total_loss = None
|
||||
if match_labels is not None:
|
||||
match_loss = self.compute_loss(span_logits, match_labels,
|
||||
label_mask)
|
||||
total_loss = match_loss
|
||||
if not return_dict:
|
||||
output = (span_logits) + outputs[2:]
|
||||
return ((total_loss, )
|
||||
+ output) if total_loss is not None else output
|
||||
|
||||
return MachineReadingComprehensionOutput(
|
||||
loss=total_loss,
|
||||
span_logits=span_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class MultiNonLinearProjection(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
num_label,
|
||||
dropout_rate,
|
||||
act_func='gelu',
|
||||
intermediate_hidden_size=None):
|
||||
super(MultiNonLinearProjection, self).__init__()
|
||||
self.num_label = num_label
|
||||
self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size
|
||||
self.classifier1 = nn.Linear(hidden_size,
|
||||
self.intermediate_hidden_size)
|
||||
self.classifier2 = nn.Linear(self.intermediate_hidden_size,
|
||||
self.num_label)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.act_func = act_func
|
||||
|
||||
def forward(self, input_features):
|
||||
features_output1 = self.classifier1(input_features)
|
||||
if self.act_func == 'gelu':
|
||||
features_output1 = F.gelu(features_output1)
|
||||
elif self.act_func == 'relu':
|
||||
features_output1 = F.relu(features_output1)
|
||||
elif self.act_func == 'tanh':
|
||||
features_output1 = F.tanh(features_output1)
|
||||
else:
|
||||
raise ValueError
|
||||
features_output1 = self.dropout(features_output1)
|
||||
features_output2 = self.classifier2(features_output1)
|
||||
return features_output2
|
||||
@@ -464,3 +464,25 @@ class TranslationEvaluationOutput(ModelOutputBase):
|
||||
score: Tensor = None
|
||||
loss: Tensor = None
|
||||
input_format: List[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MachineReadingComprehensionOutput(ModelOutputBase):
|
||||
"""The output class for machine reading comprehension models.
|
||||
|
||||
Args:
|
||||
loss (`Tensor`, *optional*): The training loss of the current batch
|
||||
match_loss (`Tensor`, *optinal*): The match loss of the current batch
|
||||
span_logits (`Tensor`): The logits of the span matrix output by the model
|
||||
hidden_states (`Tuple[Tensor]`, *optinal*): The hidden states output by the model
|
||||
attentions (`Tuple[Tensor]`, *optinal*): The attention scores output by the model
|
||||
input_ids (`Tensor`): The token ids of the input sentence
|
||||
|
||||
"""
|
||||
|
||||
loss: Optional[Tensor] = None
|
||||
match_loss: Optional[Tensor] = None
|
||||
span_logits: Tensor = None
|
||||
hidden_states: Optional[Tuple[Tensor]] = None
|
||||
attentions: Optional[Tuple[Tensor]] = None
|
||||
input_ids: Tensor = None
|
||||
|
||||
@@ -323,6 +323,8 @@ TASK_INPUTS = {
|
||||
'positive': InputType.LIST,
|
||||
'negative': InputType.LIST
|
||||
},
|
||||
Tasks.machine_reading_comprehension:
|
||||
InputType.TEXT,
|
||||
|
||||
# ============ audio tasks ===================
|
||||
Tasks.auto_speech_recognition: # input can be audio, or audio and text.
|
||||
|
||||
@@ -81,28 +81,24 @@ class SpeakerVerificationPipeline(Pipeline):
|
||||
inputs: torch.Tensor,
|
||||
in_audios: Union[np.ndarray, list],
|
||||
save_dir=None):
|
||||
if isinstance(in_audios[0], str):
|
||||
if save_dir is not None:
|
||||
# save the embeddings
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
for i, p in enumerate(in_audios):
|
||||
save_path = os.path.join(
|
||||
save_dir, '%s.npy' %
|
||||
(os.path.basename(p).rsplit('.', 1)[0]))
|
||||
np.save(save_path, inputs[i].numpy())
|
||||
if isinstance(in_audios[0], str) and save_dir is not None:
|
||||
# save the embeddings
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
for i, p in enumerate(in_audios):
|
||||
save_path = os.path.join(
|
||||
save_dir, '%s.npy' %
|
||||
(os.path.basename(p).rsplit('.', 1)[0]))
|
||||
np.save(save_path, inputs[i].numpy())
|
||||
|
||||
if len(in_audios) == 2:
|
||||
# compute the score
|
||||
score = self.compute_cos_similarity(inputs[0], inputs[1])
|
||||
score = round(score, 5)
|
||||
if score >= self.thr:
|
||||
ans = 'yes'
|
||||
else:
|
||||
ans = 'no'
|
||||
output = {OutputKeys.SCORE: score, OutputKeys.TEXT: ans}
|
||||
if len(inputs) == 2:
|
||||
# compute the score
|
||||
score = self.compute_cos_similarity(inputs[0], inputs[1])
|
||||
score = round(score, 5)
|
||||
if score >= self.thr:
|
||||
ans = 'yes'
|
||||
else:
|
||||
output = {OutputKeys.TEXT: 'No similarity score output'}
|
||||
|
||||
ans = 'no'
|
||||
output = {OutputKeys.SCORE: score, OutputKeys.TEXT: ans}
|
||||
else:
|
||||
output = {OutputKeys.TEXT: 'No similarity score output'}
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ class StableDiffusionPipeline(DiffusersPipeline):
|
||||
use_safetensors: load safetensors weights.
|
||||
"""
|
||||
use_safetensors = kwargs.pop('use_safetensors', False)
|
||||
torch_type = kwargs.pop('torch_type', torch.float32)
|
||||
# check custom diffusion input value
|
||||
if custom_dir is None and modifier_token is not None:
|
||||
raise ValueError(
|
||||
@@ -50,7 +51,6 @@ class StableDiffusionPipeline(DiffusersPipeline):
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
# load pipeline
|
||||
torch_type = torch.float16 if self.device == 'cuda' else torch.float32
|
||||
self.pipeline = DiffusionPipeline.from_pretrained(
|
||||
model, use_safetensors=use_safetensors, torch_dtype=torch_type)
|
||||
self.pipeline = self.pipeline.to(self.device)
|
||||
|
||||
104
modelscope/pipelines/multi_modal/image_to_video_pipeline.py
Normal file
104
modelscope/pipelines/multi_modal/image_to_video_pipeline.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Model, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors.image import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_to_video, module_name=Pipelines.image_to_video_task_pipeline)
|
||||
class ImageToVideoPipeline(Pipeline):
|
||||
r""" Image To Video Pipeline.
|
||||
|
||||
Examples:
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> from modelscope.outputs import OutputKeys
|
||||
|
||||
>>> p = pipeline('image-to-video', 'damo/Image-to-Video')
|
||||
>>> input = 'path_to_image'
|
||||
>>> p(input,)
|
||||
|
||||
>>> {OutputKeys.OUTPUT_VIDEO: path-to-the-generated-video}
|
||||
>>>
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
use `model` to create a kws pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
|
||||
def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
|
||||
img_path = input
|
||||
|
||||
image = LoadImage.convert_to_img(img_path)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
vit_frame = self.model.vid_trans(image)
|
||||
vit_frame = vit_frame.unsqueeze(0)
|
||||
vit_frame = vit_frame.to(self.model.device)
|
||||
|
||||
return {'vit_frame': vit_frame}
|
||||
|
||||
def forward(self, input: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
video = self.model(input)
|
||||
return {'video': video}
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any],
|
||||
**post_params) -> Dict[str, Any]:
|
||||
video = tensor2vid(inputs['video'], self.model.cfg.mean,
|
||||
self.model.cfg.std)
|
||||
output_video_path = post_params.get('output_video', None)
|
||||
temp_video_file = False
|
||||
if output_video_path is None:
|
||||
output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
||||
temp_video_file = True
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
h, w, c = video[0].shape
|
||||
video_writer = cv2.VideoWriter(
|
||||
output_video_path, fourcc, fps=8, frameSize=(w, h))
|
||||
for i in range(len(video)):
|
||||
img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
|
||||
video_writer.write(img)
|
||||
video_writer.release()
|
||||
if temp_video_file:
|
||||
video_file_content = b''
|
||||
with open(output_video_path, 'rb') as f:
|
||||
video_file_content = f.read()
|
||||
os.remove(output_video_path)
|
||||
return {OutputKeys.OUTPUT_VIDEO: video_file_content}
|
||||
else:
|
||||
return {OutputKeys.OUTPUT_VIDEO: output_video_path}
|
||||
|
||||
|
||||
def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
|
||||
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
|
||||
video = video.mul_(std).add_(mean)
|
||||
video.clamp_(0, 1)
|
||||
video = video * 255.0
|
||||
|
||||
images = rearrange(video, 'b c f h w -> b f h w c')[0]
|
||||
images = [(img.numpy()).astype('uint8') for img in images]
|
||||
|
||||
return images
|
||||
140
modelscope/pipelines/multi_modal/video_to_video_pipeline.py
Normal file
140
modelscope/pipelines/multi_modal/video_to_video_pipeline.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Model, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors.image import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.video_to_video, module_name=Pipelines.video_to_video_pipeline)
|
||||
class VideoToVideoPipeline(Pipeline):
|
||||
r""" Video To Video Pipeline, generating super-resolution videos based on input
|
||||
video and text
|
||||
|
||||
Examples:
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> from modelscope.outputs import OutputKeys
|
||||
|
||||
>>> # YOUR_VIDEO_PATH: your video url or local position in low resolution
|
||||
>>> # INPUT_TEXT: when we do video super-resolution, we will add the text content
|
||||
>>> # into results
|
||||
>>> # output_video_path: path-to-the-generated-video
|
||||
|
||||
>>> p = pipeline('video-to-video', 'damo/Video-to-Video')
|
||||
>>> input = {"video_path":YOUR_VIDEO_PATH, "text": INPUT_TEXT}
|
||||
>>> output_video_path = p(input,output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
use `model` to create a kws pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
|
||||
def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
|
||||
vid_path = input['video_path']
|
||||
if 'text' in input.keys():
|
||||
text = input['text']
|
||||
else:
|
||||
text = ''
|
||||
|
||||
caption = text + self.model.positive_prompt
|
||||
y = self.model.clip_encoder(caption).detach()
|
||||
|
||||
max_frames = self.model.cfg.max_frames
|
||||
|
||||
capture = cv2.VideoCapture(vid_path)
|
||||
_fps = capture.get(cv2.CAP_PROP_FPS)
|
||||
sample_fps = _fps
|
||||
_total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
stride = round(_fps / sample_fps)
|
||||
start_frame = 0
|
||||
|
||||
pointer = 0
|
||||
frame_list = []
|
||||
while len(frame_list) < max_frames:
|
||||
ret, frame = capture.read()
|
||||
pointer += 1
|
||||
if (not ret) or (frame is None):
|
||||
break
|
||||
if pointer < start_frame:
|
||||
continue
|
||||
if pointer >= _total_frame_num + 1:
|
||||
break
|
||||
if (pointer - start_frame) % stride == 0:
|
||||
frame = LoadImage.convert_to_img(frame)
|
||||
frame_list.append(frame)
|
||||
capture.release()
|
||||
|
||||
video_data = self.model.vid_trans(frame_list)
|
||||
|
||||
return {'video_data': video_data, 'y': y}
|
||||
|
||||
def forward(self, input: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
video = self.model(input)
|
||||
return {'video': video}
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any],
|
||||
**post_params) -> Dict[str, Any]:
|
||||
video = tensor2vid(inputs['video'], self.model.cfg.mean,
|
||||
self.model.cfg.std)
|
||||
output_video_path = post_params.get('output_video', None)
|
||||
temp_video_file = False
|
||||
if output_video_path is None:
|
||||
output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
||||
temp_video_file = True
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
for fid, frame in enumerate(video):
|
||||
tpth = os.path.join(temp_dir, '%06d.png' % (fid + 1))
|
||||
cv2.imwrite(tpth, frame[:, :, ::-1],
|
||||
[int(cv2.IMWRITE_JPEG_QUALITY), 100])
|
||||
|
||||
cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8.0 -i {temp_dir}/%06d.png \
|
||||
-vcodec libx264 -crf 17 -pix_fmt yuv420p {output_video_path}'
|
||||
|
||||
status = os.system(cmd)
|
||||
if status != 0:
|
||||
logger.info('Save Video Error with {}'.format(status))
|
||||
os.system(f'rm -rf {temp_dir}')
|
||||
|
||||
if temp_video_file:
|
||||
video_file_content = b''
|
||||
with open(output_video_path, 'rb') as f:
|
||||
video_file_content = f.read()
|
||||
os.remove(output_video_path)
|
||||
return {OutputKeys.OUTPUT_VIDEO: video_file_content}
|
||||
else:
|
||||
return {OutputKeys.OUTPUT_VIDEO: output_video_path}
|
||||
|
||||
|
||||
def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
|
||||
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
|
||||
video = video.mul_(std).add_(mean)
|
||||
video.clamp_(0, 1)
|
||||
video = video * 255.0
|
||||
|
||||
images = rearrange(video, 'b c f h w -> b f h w c')[0]
|
||||
images = [(img.numpy()).astype('uint8') for img in images]
|
||||
|
||||
return images
|
||||
@@ -45,6 +45,7 @@ if TYPE_CHECKING:
|
||||
from .document_grounded_dialog_retrieval_pipeline import DocumentGroundedDialogRetrievalPipeline
|
||||
from .document_grounded_dialog_rerank_pipeline import DocumentGroundedDialogRerankPipeline
|
||||
from .language_identification_pipline import LanguageIdentificationPipeline
|
||||
from .machine_reading_comprehension_pipeline import MachineReadingComprehensionForNERPipeline
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
@@ -108,7 +109,10 @@ else:
|
||||
'document_grounded_dialog_retrieval_pipeline': [
|
||||
'DocumentGroundedDialogRetrievalPipeline'
|
||||
],
|
||||
'language_identification_pipline': ['LanguageIdentificationPipeline']
|
||||
'language_identification_pipline': ['LanguageIdentificationPipeline'],
|
||||
'machine_reading_comprehension_pipeline': [
|
||||
'MachineReadingComprehensionForNERPipeline'
|
||||
],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Pipelines, Preprocessors
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.outputs import MachineReadingComprehensionOutput, OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.utils.constant import Fields, Tasks
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.machine_reading_comprehension,
|
||||
module_name=Pipelines.machine_reading_comprehension_for_ner)
|
||||
class MachineReadingComprehensionForNERPipeline(Pipeline):
|
||||
'''
|
||||
Pipeline for Pretrained Machine Reader (PMR) finetuned on Named Entity Recognition (NER)
|
||||
|
||||
Examples:
|
||||
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> pipeline_ins = pipeline(
|
||||
>>> task=Tasks.machine_reading_comprehension,
|
||||
>>> model='damo/nlp_roberta_machine-reading-comprehension_for-ner')
|
||||
>>> pipeline_ins('Soccer - Japan get lucky win , China in surprise defeat .')
|
||||
>>> {'ORG': [], 'PER': [], 'LOC': [' Japan', ' China'], 'MISC': []}
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
model: Union[Model, str],
|
||||
preprocessor: Preprocessor = None,
|
||||
config_file: str = None,
|
||||
device: str = 'gpu',
|
||||
auto_collate=True,
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
model=model,
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
|
||||
assert isinstance(self.model, Model), \
|
||||
f'please check whether model config exists in {ModelFile.CONFIGURATION}'
|
||||
|
||||
if preprocessor is None:
|
||||
self.preprocessor = Preprocessor.from_pretrained(
|
||||
self.model.model_dir, **kwargs)
|
||||
self.labels = [label for label in self.preprocessor.label2query]
|
||||
|
||||
self.model.eval()
|
||||
|
||||
def forward(
|
||||
self, inputs, **forward_params
|
||||
) -> Union[Dict[str, Any], MachineReadingComprehensionOutput]:
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
span_logits = outputs['span_logits']
|
||||
|
||||
return MachineReadingComprehensionOutput(
|
||||
span_logits=span_logits,
|
||||
input_ids=inputs['input_ids'],
|
||||
)
|
||||
|
||||
def postprocess(
|
||||
self, inputs: Union[Dict[str, Any], MachineReadingComprehensionOutput]
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
span_preds = inputs['span_logits'] > 0
|
||||
extracted_indices = torch.nonzero(span_preds.long())
|
||||
|
||||
result = {label: [] for label in self.labels}
|
||||
for index in extracted_indices:
|
||||
label = self.labels[index[0]]
|
||||
start = index[1]
|
||||
end = index[2] + 1
|
||||
ids = inputs['input_ids'][index[0], start:end]
|
||||
entity = self.preprocessor.tokenizer.decode(ids)
|
||||
result[label].append(entity)
|
||||
|
||||
return result
|
||||
@@ -46,7 +46,8 @@ if TYPE_CHECKING:
|
||||
CanmtTranslationPreprocessor, DialogueClassificationUsePreprocessor,
|
||||
SiameseUiePreprocessor, DocumentGroundedDialogGeneratePreprocessor,
|
||||
DocumentGroundedDialogRetrievalPreprocessor,
|
||||
DocumentGroundedDialogRerankPreprocessor)
|
||||
DocumentGroundedDialogRerankPreprocessor,
|
||||
MachineReadingComprehensionForNERPreprocessor)
|
||||
from .video import ReadVideoData, MovieSceneSegmentationPreprocessor
|
||||
|
||||
else:
|
||||
@@ -77,7 +78,8 @@ else:
|
||||
'nlp': [
|
||||
'DocumentSegmentationTransformersPreprocessor',
|
||||
'FaqQuestionAnsweringTransformersPreprocessor',
|
||||
'FillMaskPoNetPreprocessor', 'FillMaskTransformersPreprocessor',
|
||||
'FillMaskPoNetPreprocessor',
|
||||
'FillMaskTransformersPreprocessor',
|
||||
'NLPTokenizerPreprocessorBase',
|
||||
'TextRankingTransformersPreprocessor',
|
||||
'RelationExtractionTransformersPreprocessor',
|
||||
@@ -85,26 +87,34 @@ else:
|
||||
'TextGenerationSentencePiecePreprocessor',
|
||||
'TextClassificationTransformersPreprocessor',
|
||||
'TokenClassificationTransformersPreprocessor',
|
||||
'TextErrorCorrectionPreprocessor', 'WordAlignmentPreprocessor',
|
||||
'TextGenerationTransformersPreprocessor', 'Tokenize',
|
||||
'TextErrorCorrectionPreprocessor',
|
||||
'WordAlignmentPreprocessor',
|
||||
'TextGenerationTransformersPreprocessor',
|
||||
'Tokenize',
|
||||
'TextGenerationT5Preprocessor',
|
||||
'WordSegmentationBlankSetToLabelPreprocessor',
|
||||
'MGLMSummarizationPreprocessor', 'CodeGeeXPreprocessor',
|
||||
'MGLMSummarizationPreprocessor',
|
||||
'CodeGeeXPreprocessor',
|
||||
'ZeroShotClassificationTransformersPreprocessor',
|
||||
'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor',
|
||||
'NERPreprocessorViet', 'NERPreprocessorThai',
|
||||
'TextGenerationJiebaPreprocessor',
|
||||
'SentencePiecePreprocessor',
|
||||
'NERPreprocessorViet',
|
||||
'NERPreprocessorThai',
|
||||
'WordSegmentationPreprocessorThai',
|
||||
'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
|
||||
'DialogIntentPredictionPreprocessor',
|
||||
'DialogModelingPreprocessor',
|
||||
'DialogStateTrackingPreprocessor',
|
||||
'ConversationalTextToSqlPreprocessor',
|
||||
'TableQuestionAnsweringPreprocessor',
|
||||
'TranslationEvaluationTransformersPreprocessor',
|
||||
'CanmtTranslationPreprocessor',
|
||||
'DialogueClassificationUsePreprocessor', 'SiameseUiePreprocessor',
|
||||
'DialogueClassificationUsePreprocessor',
|
||||
'SiameseUiePreprocessor',
|
||||
'DialogueClassificationUsePreprocessor',
|
||||
'DocumentGroundedDialogGeneratePreprocessor',
|
||||
'DocumentGroundedDialogRetrievalPreprocessor',
|
||||
'DocumentGroundedDialogRerankPreprocessor'
|
||||
'DocumentGroundedDialogRerankPreprocessor',
|
||||
'MachineReadingComprehensionForNERPreprocessor',
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ if TYPE_CHECKING:
|
||||
from .document_grounded_dialog_generate_preprocessor import DocumentGroundedDialogGeneratePreprocessor
|
||||
from .document_grounded_dialog_retrieval_preprocessor import DocumentGroundedDialogRetrievalPreprocessor
|
||||
from .document_grounded_dialog_rerank_preprocessor import DocumentGroundedDialogRerankPreprocessor
|
||||
from .machine_reading_comprehension_preprocessor import MachineReadingComprehensionForNERPreprocessor
|
||||
else:
|
||||
_import_structure = {
|
||||
'bert_seq_cls_tokenizer': ['Tokenize'],
|
||||
@@ -102,7 +103,10 @@ else:
|
||||
'document_grounded_dialog_retrieval_preprocessor':
|
||||
['DocumentGroundedDialogRetrievalPreprocessor'],
|
||||
'document_grounded_dialog_rerank_preprocessor':
|
||||
['DocumentGroundedDialogRerankPreprocessor']
|
||||
['DocumentGroundedDialogRerankPreprocessor'],
|
||||
'machine_reading_comprehension_preprocessor': [
|
||||
'MachineReadingComprehensionForNERPreprocessor'
|
||||
],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
@@ -0,0 +1,252 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset, SequentialSampler
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.tokenization_utils_base import TruncationStrategy
|
||||
from transformers.utils import logging
|
||||
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.preprocessors.builder import PREPROCESSORS
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ConfigFields, Fields, ModeKeys, ModelFile
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
MULTI_SEP_TOKENS_TOKENIZERS_SET = {'roberta', 'camembert', 'bart', 'mpnet'}
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(
|
||||
Fields.nlp,
|
||||
module_name=Preprocessors.machine_reading_comprehension_for_ner)
|
||||
class MachineReadingComprehensionForNERPreprocessor(Preprocessor):
|
||||
'''
|
||||
Preprocessor for Pretrained Machiner Reader (PMR) finetuned on Named Entity Recognition (NER)
|
||||
|
||||
'''
|
||||
|
||||
def __init__(self, model_dir, label2query=None):
|
||||
super().__init__()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_dir, use_fast=False)
|
||||
|
||||
if label2query is None:
|
||||
config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
|
||||
config = Config.from_file(config_path)
|
||||
self.label2query = config[ConfigFields.preprocessor].label2query
|
||||
else:
|
||||
self.label2query = label2query
|
||||
|
||||
def __call__(self, data: str):
|
||||
all_data = []
|
||||
for label in self.label2query:
|
||||
all_data.append({
|
||||
'context': data,
|
||||
'end_position': [],
|
||||
'entity_label': label,
|
||||
'impossible': False,
|
||||
'qas_id': '',
|
||||
'query': self.label2query[label],
|
||||
'span_position': [],
|
||||
'start_position': []
|
||||
})
|
||||
|
||||
all_data = self.prompt(all_data)
|
||||
output = []
|
||||
for data in all_data:
|
||||
output.append(self.encode(data))
|
||||
output = collate_to_max_length_roberta(output)
|
||||
|
||||
output = {
|
||||
'input_ids': output[0],
|
||||
'attention_mask': output[1],
|
||||
'token_type_ids': output[2],
|
||||
}
|
||||
|
||||
return output
|
||||
|
||||
def prompt(self, all_data, var=0):
|
||||
new_datas = []
|
||||
for data in all_data:
|
||||
label = data['entity_label']
|
||||
details = data['query']
|
||||
context = data['context']
|
||||
start_positions = data['start_position']
|
||||
end_positions = data['end_position']
|
||||
words = context.split()
|
||||
assert len(words) == len(context.split(' '))
|
||||
if var == 0:
|
||||
query = '"{}". {}'.format(label, details) # ori
|
||||
elif var == 1:
|
||||
query = 'What are the "{}" entity, where {}'.format(
|
||||
label, details) # variant 1
|
||||
elif var == 2:
|
||||
query = 'Identify the spans (if any) related to "{}" entity. Details: {}'.format(
|
||||
label, details) # variant 2
|
||||
span_positions = {
|
||||
'{};{}'.format(start_positions[i], end_positions[i]):
|
||||
' '.join(words[start_positions[i]:end_positions[i] + 1])
|
||||
for i in range(len(start_positions))
|
||||
}
|
||||
new_data = {
|
||||
'context': words,
|
||||
'end_position': end_positions,
|
||||
'entity_label': label,
|
||||
'impossible': data['impossible'],
|
||||
'qas_id': data['qas_id'],
|
||||
'query': query,
|
||||
'span_position': span_positions,
|
||||
'start_position': start_positions,
|
||||
}
|
||||
new_datas.append(new_data)
|
||||
return new_datas
|
||||
|
||||
def encode(self, data, max_length=512, max_query_length=64):
|
||||
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
query = data['query']
|
||||
context = data['context']
|
||||
start_positions = data['start_position']
|
||||
end_positions = data['end_position']
|
||||
|
||||
tokenizer_type = type(tokenizer).__name__.replace('Tokenizer',
|
||||
'').lower()
|
||||
sequence_added_tokens = (
|
||||
tokenizer.model_max_length - tokenizer.max_len_single_sentence
|
||||
+ 1 if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET else
|
||||
tokenizer.model_max_length - tokenizer.max_len_single_sentence)
|
||||
|
||||
tok_to_orig_index = []
|
||||
orig_to_tok_index = []
|
||||
all_doc_tokens = []
|
||||
for (i, token) in enumerate(context):
|
||||
orig_to_tok_index.append(len(all_doc_tokens))
|
||||
if tokenizer.__class__.__name__ in [
|
||||
'RobertaTokenizer',
|
||||
'LongformerTokenizer',
|
||||
'BartTokenizer',
|
||||
'RobertaTokenizerFast',
|
||||
'LongformerTokenizerFast',
|
||||
'BartTokenizerFast',
|
||||
]:
|
||||
sub_tokens = tokenizer.tokenize(token, add_prefix_space=True)
|
||||
elif tokenizer.__class__.__name__ in ['BertTokenizer']:
|
||||
sub_tokens = tokenizer.tokenize(token)
|
||||
elif tokenizer.__class__.__name__ in ['BertWordPieceTokenizer']:
|
||||
sub_tokens = tokenizer.encode(
|
||||
token, add_special_tokens=False).tokens
|
||||
else:
|
||||
sub_tokens = tokenizer.tokenize(token)
|
||||
for sub_token in sub_tokens:
|
||||
tok_to_orig_index.append(i)
|
||||
all_doc_tokens.append(sub_token)
|
||||
|
||||
tok_start_positions = [orig_to_tok_index[x] for x in start_positions]
|
||||
tok_end_positions = []
|
||||
for x in end_positions:
|
||||
if x < len(context) - 1:
|
||||
tok_end_positions.append(orig_to_tok_index[x + 1] - 1)
|
||||
else:
|
||||
tok_end_positions.append(len(all_doc_tokens) - 1)
|
||||
|
||||
truncation = TruncationStrategy.ONLY_SECOND.value
|
||||
padding_strategy = 'do_not_pad'
|
||||
|
||||
truncated_query = tokenizer.encode(
|
||||
query,
|
||||
add_special_tokens=False,
|
||||
truncation=True,
|
||||
max_length=max_query_length)
|
||||
encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic
|
||||
truncated_query,
|
||||
all_doc_tokens,
|
||||
truncation=truncation,
|
||||
padding=padding_strategy,
|
||||
max_length=max_length,
|
||||
return_overflowing_tokens=True,
|
||||
return_token_type_ids=True,
|
||||
)
|
||||
tokens = encoded_dict['input_ids']
|
||||
type_ids = encoded_dict['token_type_ids']
|
||||
attn_mask = encoded_dict['attention_mask']
|
||||
|
||||
# find new start_positions/end_positions, considering
|
||||
# 1. we add query tokens at the beginning
|
||||
# 2. special tokens
|
||||
doc_offset = len(truncated_query) + sequence_added_tokens
|
||||
new_start_positions = [
|
||||
x + doc_offset for x in tok_start_positions
|
||||
if (x + doc_offset) < max_length - 1
|
||||
]
|
||||
new_end_positions = [
|
||||
x + doc_offset if
|
||||
(x + doc_offset) < max_length - 1 else max_length - 2
|
||||
for x in tok_end_positions
|
||||
]
|
||||
new_end_positions = new_end_positions[:len(new_start_positions)]
|
||||
|
||||
label_mask = [0] * doc_offset + [1] * (len(tokens) - doc_offset
|
||||
- 1) + [0]
|
||||
|
||||
assert all(label_mask[p] != 0 for p in new_start_positions)
|
||||
assert all(label_mask[p] != 0 for p in new_end_positions)
|
||||
|
||||
assert len(label_mask) == len(tokens)
|
||||
|
||||
seq_len = len(tokens)
|
||||
match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long)
|
||||
for start, end in zip(new_start_positions, new_end_positions):
|
||||
if start >= seq_len or end >= seq_len:
|
||||
continue
|
||||
match_labels[start, end] = 1
|
||||
|
||||
return [
|
||||
torch.LongTensor(tokens),
|
||||
torch.LongTensor(attn_mask),
|
||||
torch.LongTensor(type_ids),
|
||||
torch.LongTensor(label_mask),
|
||||
match_labels,
|
||||
]
|
||||
|
||||
|
||||
def collate_to_max_length_roberta(batch):
|
||||
"""
|
||||
adapted form https://github.com/ShannonAI/mrc-for-flat-nested-ner
|
||||
pad to maximum length of this batch
|
||||
Args:
|
||||
batch: a batch of samples, each contains a list of field data(Tensor):
|
||||
tokens, token_type_ids, start_labels, end_labels, start_label_mask,
|
||||
end_label_mask, match_labels, sample_idx, label_idx
|
||||
Returns:
|
||||
output: list of field batched data, which shape is [batch, max_length]
|
||||
"""
|
||||
batch_size = len(batch)
|
||||
max_length = max(x[0].shape[0] for x in batch)
|
||||
output = []
|
||||
|
||||
for field_idx in range(4):
|
||||
if field_idx == 0:
|
||||
pad_output = torch.full([batch_size, max_length],
|
||||
1,
|
||||
dtype=batch[0][field_idx].dtype)
|
||||
else:
|
||||
pad_output = torch.full([batch_size, max_length],
|
||||
0,
|
||||
dtype=batch[0][field_idx].dtype)
|
||||
for sample_idx in range(batch_size):
|
||||
data = batch[sample_idx][field_idx]
|
||||
pad_output[sample_idx][:data.shape[0]] = data
|
||||
output.append(pad_output)
|
||||
|
||||
pad_match_labels = torch.zeros([batch_size, max_length, max_length],
|
||||
dtype=torch.long)
|
||||
for sample_idx in range(batch_size):
|
||||
data = batch[sample_idx][4]
|
||||
pad_match_labels[sample_idx, :data.shape[1], :data.shape[1]] = data
|
||||
output.append(pad_match_labels)
|
||||
|
||||
return output
|
||||
@@ -37,24 +37,31 @@ from modelscope.utils.torch_utils import is_dist
|
||||
|
||||
class CustomCheckpointProcessor(CheckpointProcessor):
|
||||
|
||||
def __init__(self, modifier_token, modifier_token_id):
|
||||
def __init__(self,
|
||||
modifier_token,
|
||||
modifier_token_id,
|
||||
torch_type=torch.float32):
|
||||
"""Checkpoint processor for custom diffusion.
|
||||
|
||||
Args:
|
||||
modifier_token: The token to use as a modifier for the concept.
|
||||
modifier_token_id: The modifier token id for the concept.
|
||||
torch_type: The torch type, default is float32.
|
||||
|
||||
"""
|
||||
self.modifier_token = modifier_token
|
||||
self.modifier_token_id = modifier_token_id
|
||||
self.torch_type = torch_type
|
||||
|
||||
def save_checkpoints(self,
|
||||
trainer,
|
||||
checkpoint_path_prefix,
|
||||
output_dir,
|
||||
meta=None):
|
||||
meta=None,
|
||||
save_optimizers=True):
|
||||
"""Save the state dict for custom diffusion model.
|
||||
"""
|
||||
trainer.model.unet = trainer.model.unet.to(torch.float32)
|
||||
trainer.model.unet = trainer.model.unet.to(self.torch_type)
|
||||
trainer.model.unet.save_attn_procs(output_dir)
|
||||
|
||||
learned_embeds = trainer.model.text_encoder.get_input_embeddings(
|
||||
@@ -281,6 +288,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
|
||||
instance_prompt = kwargs.pop('instance_prompt', 'a photo of sks dog')
|
||||
class_prompt = kwargs.pop('class_prompt', 'dog')
|
||||
class_data_dir = kwargs.pop('class_data_dir', '/tmp/class_data')
|
||||
self.torch_type = kwargs.pop('torch_type', torch.float32)
|
||||
self.real_prior = kwargs.pop('real_prior', False)
|
||||
self.num_class_images = kwargs.pop('num_class_images', 200)
|
||||
self.resolution = kwargs.pop('resolution', 512)
|
||||
@@ -387,7 +395,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
|
||||
self.hooks))[0]
|
||||
ckpt_hook.set_processor(
|
||||
CustomCheckpointProcessor(self.modifier_token,
|
||||
self.modifier_token_id))
|
||||
self.modifier_token_id, self.torch_type))
|
||||
|
||||
# Add new Custom Diffusion weights to the attention layers
|
||||
attention_class = CustomDiffusionAttnProcessor
|
||||
@@ -477,7 +485,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
|
||||
size=self.resolution,
|
||||
mask_size=self.model.vae.encode(
|
||||
torch.randn(1, 3, self.resolution,
|
||||
self.resolution).to(dtype=torch.float32).to(
|
||||
self.resolution).to(dtype=self.torch_type).to(
|
||||
self.device)).latent_dist.sample().size()[-1],
|
||||
center_crop=self.center_crop,
|
||||
num_class_images=self.num_class_images,
|
||||
@@ -534,8 +542,8 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
|
||||
if cur_class_images < self.num_class_images:
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
self.model_dir,
|
||||
torch_dtype=torch.float32,
|
||||
safety_checker=None,
|
||||
torch_dtype=self.torch_type,
|
||||
revision=None,
|
||||
)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
@@ -656,7 +664,7 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
|
||||
batch = next(self.iter_train_dataloader)
|
||||
# Convert images to latent space
|
||||
latents = self.model.vae.encode(batch['pixel_values'].to(
|
||||
dtype=torch.float32).to(self.device)).latent_dist.sample()
|
||||
dtype=self.torch_type).to(self.device)).latent_dist.sample()
|
||||
latents = latents * self.model.vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
|
||||
@@ -32,8 +32,16 @@ from modelscope.utils.torch_utils import is_dist
|
||||
|
||||
class DreamboothCheckpointProcessor(CheckpointProcessor):
|
||||
|
||||
def __init__(self, model_dir):
|
||||
def __init__(self, model_dir, torch_type=torch.float32):
|
||||
"""Checkpoint processor for dreambooth diffusion.
|
||||
|
||||
Args:
|
||||
model_dir: The model id or local model dir.
|
||||
torch_type: The torch type, default is float32.
|
||||
|
||||
"""
|
||||
self.model_dir = model_dir
|
||||
self.torch_type = torch_type
|
||||
|
||||
def save_checkpoints(self,
|
||||
trainer,
|
||||
@@ -49,6 +57,7 @@ class DreamboothCheckpointProcessor(CheckpointProcessor):
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
self.model_dir,
|
||||
unet=trainer.model.unet,
|
||||
torch_type=self.torch_type,
|
||||
**pipeline_args,
|
||||
)
|
||||
scheduler_args = {}
|
||||
@@ -174,6 +183,7 @@ class DreamboothDiffusionTrainer(EpochBasedTrainer):
|
||||
prior_loss_weight: the weight of the prior loss.
|
||||
|
||||
"""
|
||||
self.torch_type = kwargs.pop('torch_type', torch.float32)
|
||||
self.with_prior_preservation = kwargs.pop('with_prior_preservation',
|
||||
False)
|
||||
self.instance_prompt = kwargs.pop('instance_prompt',
|
||||
@@ -219,7 +229,7 @@ class DreamboothDiffusionTrainer(EpochBasedTrainer):
|
||||
warnings.warn('Multiple GPU inference not yet supported.')
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
self.model_dir,
|
||||
torch_dtype=torch.float32,
|
||||
torch_dtype=self.torch_type,
|
||||
safety_checker=None,
|
||||
revision=None,
|
||||
)
|
||||
@@ -309,7 +319,8 @@ class DreamboothDiffusionTrainer(EpochBasedTrainer):
|
||||
input_ids = batch['input_ids'].to(self.device)
|
||||
with torch.no_grad():
|
||||
latents = self.model.vae.encode(
|
||||
target_prior.to(dtype=torch.float32)).latent_dist.sample()
|
||||
target_prior.to(
|
||||
dtype=self.torch_type)).latent_dist.sample()
|
||||
latents = latents * self.model.vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
|
||||
@@ -17,6 +17,15 @@ from modelscope.utils.config import ConfigDict
|
||||
|
||||
class LoraDiffusionCheckpointProcessor(CheckpointProcessor):
|
||||
|
||||
def __init__(self, torch_type=torch.float32):
|
||||
"""Checkpoint processor for lora diffusion.
|
||||
|
||||
Args:
|
||||
torch_type: The torch type, default is float32.
|
||||
|
||||
"""
|
||||
self.torch_type = torch_type
|
||||
|
||||
def save_checkpoints(self,
|
||||
trainer,
|
||||
checkpoint_path_prefix,
|
||||
@@ -25,7 +34,7 @@ class LoraDiffusionCheckpointProcessor(CheckpointProcessor):
|
||||
save_optimizers=True):
|
||||
"""Save the state dict for lora tune model.
|
||||
"""
|
||||
trainer.model.unet = trainer.model.unet.to(torch.float32)
|
||||
trainer.model.unet = trainer.model.unet.to(self.torch_type)
|
||||
trainer.model.unet.save_attn_procs(output_dir)
|
||||
|
||||
|
||||
@@ -38,15 +47,18 @@ class LoraDiffusionTrainer(EpochBasedTrainer):
|
||||
|
||||
Args:
|
||||
lora_rank: The rank size of lora intermediate linear.
|
||||
torch_type: The torch type, default is float32.
|
||||
|
||||
"""
|
||||
lora_rank = kwargs.pop('lora_rank', 4)
|
||||
torch_type = kwargs.pop('torch_type', torch.float32)
|
||||
|
||||
# set lora save checkpoint processor
|
||||
ckpt_hook = list(
|
||||
filter(lambda hook: isinstance(hook, CheckpointHook),
|
||||
self.hooks))[0]
|
||||
ckpt_hook.set_processor(LoraDiffusionCheckpointProcessor())
|
||||
ckpt_hook.set_processor(
|
||||
LoraDiffusionCheckpointProcessor(torch_type=torch_type))
|
||||
# Set correct lora layers
|
||||
lora_attn_procs = {}
|
||||
for name in self.model.unet.attn_processors.keys():
|
||||
|
||||
@@ -213,6 +213,7 @@ class NLPTasks(object):
|
||||
document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval'
|
||||
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
|
||||
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
|
||||
machine_reading_comprehension = 'machine-reading-comprehension'
|
||||
|
||||
|
||||
class AudioTasks(object):
|
||||
@@ -255,6 +256,8 @@ class MultiModalTasks(object):
|
||||
text_to_video_synthesis = 'text-to-video-synthesis'
|
||||
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
|
||||
multimodal_dialogue = 'multimodal-dialogue'
|
||||
image_to_video = 'image-to-video'
|
||||
video_to_video = 'video-to-video'
|
||||
|
||||
|
||||
class ScienceTasks(object):
|
||||
|
||||
@@ -13,39 +13,45 @@ class EfficientDiffusionTuningTest(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.task = Tasks.efficient_diffusion_tuning
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_lora_run_pipeline(self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-lora'
|
||||
model_revision = 'v1.0.2'
|
||||
inputs = {'prompt': 'pale golden rod circle with old lace background'}
|
||||
edt_pipeline = pipeline(self.task, model_id)
|
||||
edt_pipeline = pipeline(
|
||||
self.task, model_id, model_revision=model_revision)
|
||||
result = edt_pipeline(inputs)
|
||||
print(f'Efficient-diffusion-tuning-lora output: {result}.')
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_lora_load_model_from_pretrained(self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-lora'
|
||||
model = Model.from_pretrained(model_id)
|
||||
model_revision = 'v1.0.2'
|
||||
model = Model.from_pretrained(model_id, model_revision=model_revision)
|
||||
self.assertTrue(model.__class__ == EfficientStableDiffusion)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_control_lora_run_pipeline(self):
|
||||
# TODO: to be fixed in the future
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-control-lora'
|
||||
model_revision = 'v1.0.2'
|
||||
inputs = {
|
||||
'prompt':
|
||||
'pale golden rod circle with old lace background',
|
||||
'cond':
|
||||
'data/test/images/efficient_diffusion_tuning_sd_control_lora_source.png'
|
||||
}
|
||||
edt_pipeline = pipeline(self.task, model_id)
|
||||
edt_pipeline = pipeline(
|
||||
self.task, model_id, model_revision=model_revision)
|
||||
result = edt_pipeline(inputs)
|
||||
print(f'Efficient-diffusion-tuning-control-lora output: {result}.')
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_control_lora_load_model_from_pretrained(
|
||||
self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-control-lora'
|
||||
model = Model.from_pretrained(model_id)
|
||||
model_revision = 'v1.0.2'
|
||||
model = Model.from_pretrained(model_id, model_revision=model_revision)
|
||||
self.assertTrue(model.__class__ == EfficientStableDiffusion)
|
||||
|
||||
|
||||
|
||||
@@ -16,14 +16,16 @@ class EfficientDiffusionTuningTestSwift(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.task = Tasks.efficient_diffusion_tuning
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_swift_lora_run_pipeline(self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-lora'
|
||||
model_revision = 'v1.0.2'
|
||||
inputs = {
|
||||
'prompt':
|
||||
'a street scene with a cafe and a restaurant sign in anime style'
|
||||
}
|
||||
sd_tuner_pipeline = pipeline(self.task, model_id)
|
||||
sd_tuner_pipeline = pipeline(
|
||||
self.task, model_id, model_revision=model_revision)
|
||||
result = sd_tuner_pipeline(inputs, generator_seed=0)
|
||||
output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name
|
||||
cv2.imwrite(output_image_path, result['output_imgs'][0])
|
||||
@@ -31,21 +33,24 @@ class EfficientDiffusionTuningTestSwift(unittest.TestCase):
|
||||
f'Efficient-diffusion-tuning-swift-lora output: {output_image_path}'
|
||||
)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_swift_lora_load_model_from_pretrained(
|
||||
self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-lora'
|
||||
model = Model.from_pretrained(model_id)
|
||||
model_revision = 'v1.0.2'
|
||||
model = Model.from_pretrained(model_id, model_revision=model_revision)
|
||||
self.assertTrue(model.__class__ == EfficientStableDiffusion)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_swift_adapter_run_pipeline(self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-adapter'
|
||||
model_revision = 'v1.0.2'
|
||||
inputs = {
|
||||
'prompt':
|
||||
'a street scene with a cafe and a restaurant sign in anime style'
|
||||
}
|
||||
sd_tuner_pipeline = pipeline(self.task, model_id)
|
||||
sd_tuner_pipeline = pipeline(
|
||||
self.task, model_id, model_revision=model_revision)
|
||||
result = sd_tuner_pipeline(inputs, generator_seed=0)
|
||||
output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name
|
||||
cv2.imwrite(output_image_path, result['output_imgs'][0])
|
||||
@@ -53,21 +58,24 @@ class EfficientDiffusionTuningTestSwift(unittest.TestCase):
|
||||
f'Efficient-diffusion-tuning-swift-adapter output: {output_image_path}'
|
||||
)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_swift_adapter_load_model_from_pretrained(
|
||||
self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-adapter'
|
||||
model = Model.from_pretrained(model_id)
|
||||
model_revision = 'v1.0.2'
|
||||
model = Model.from_pretrained(model_id, model_revision=model_revision)
|
||||
self.assertTrue(model.__class__ == EfficientStableDiffusion)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_swift_prompt_run_pipeline(self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-prompt'
|
||||
model_revision = 'v1.0.2'
|
||||
inputs = {
|
||||
'prompt':
|
||||
'a street scene with a cafe and a restaurant sign in anime style'
|
||||
}
|
||||
sd_tuner_pipeline = pipeline(self.task, model_id)
|
||||
sd_tuner_pipeline = pipeline(
|
||||
self.task, model_id, model_revision=model_revision)
|
||||
result = sd_tuner_pipeline(inputs, generator_seed=0)
|
||||
output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name
|
||||
cv2.imwrite(output_image_path, result['output_imgs'][0])
|
||||
@@ -75,11 +83,12 @@ class EfficientDiffusionTuningTestSwift(unittest.TestCase):
|
||||
f'Efficient-diffusion-tuning-swift-prompt output: {output_image_path}'
|
||||
)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_swift_prompt_load_model_from_pretrained(
|
||||
self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-prompt'
|
||||
model = Model.from_pretrained(model_id)
|
||||
model_revision = 'v1.0.2'
|
||||
model = Model.from_pretrained(model_id, model_revision=model_revision)
|
||||
self.assertTrue(model.__class__ == EfficientStableDiffusion)
|
||||
|
||||
|
||||
|
||||
28
tests/pipelines/test_image2video.py
Normal file
28
tests/pipelines/test_image2video.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import DownloadMode, Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class Image2VideoTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = Tasks.image_to_video
|
||||
self.model_id = 'damo/Image-to-Video'
|
||||
self.path = 'https://video-generation-wulanchabu.oss-cn-wulanchabu.aliyuncs.com/baishao/test.jpeg'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_with_model_from_modelhub(self):
|
||||
pipe = pipeline(task=self.task, model=self.model_id)
|
||||
|
||||
output_video_path = pipe(
|
||||
self.path, output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO]
|
||||
print(output_video_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
59
tests/pipelines/test_machine_reading_comprehension.py
Normal file
59
tests/pipelines/test_machine_reading_comprehension.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import unittest
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.nlp import ModelForMachineReadingComprehension
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.pipelines.nlp import MachineReadingComprehensionForNERPipeline
|
||||
from modelscope.preprocessors import \
|
||||
MachineReadingComprehensionForNERPreprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class MachineReadingComprehensionTest(unittest.TestCase):
|
||||
sentence = 'Soccer - Japan get lucky win , China in surprise defeat .'
|
||||
model_id = 'damo/nlp_roberta_machine-reading-comprehension_for-ner'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_mrc_for_ner_by_direct_model_download(self):
|
||||
cache_path = snapshot_download(self.model_id)
|
||||
tokenizer = MachineReadingComprehensionForNERPreprocessor(cache_path)
|
||||
model = ModelForMachineReadingComprehension.from_pretrained(cache_path)
|
||||
pipeline1 = MachineReadingComprehensionForNERPipeline(
|
||||
model, preprocessor=tokenizer)
|
||||
|
||||
pipeline2 = pipeline(
|
||||
Tasks.machine_reading_comprehension,
|
||||
model=model,
|
||||
preprocessor=tokenizer)
|
||||
print(f'sentence: {self.sentence}\n'
|
||||
f'pipeline1:{pipeline1(input=self.sentence)}')
|
||||
print()
|
||||
print(f'pipeline2: {pipeline2(input=self.sentence)}')
|
||||
# {'ORG': [], 'PER': [], 'LOC': [' Japan', ' China'], 'MISC': []}
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_mrc_for_ner_with_model_from_modelhub(self):
|
||||
model = Model.from_pretrained(self.model_id)
|
||||
tokenizer = MachineReadingComprehensionForNERPreprocessor(
|
||||
model.model_dir)
|
||||
pipeline_ins = pipeline(
|
||||
task=Tasks.machine_reading_comprehension,
|
||||
model=model,
|
||||
preprocessor=tokenizer)
|
||||
print(f'sentence: {self.sentence}\n'
|
||||
f'pipeline:{pipeline_ins(input=self.sentence)}')
|
||||
# {'ORG': [], 'PER': [], 'LOC': [' Japan', ' China'], 'MISC': []}
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_mrc_for_ner_with_model_name(self):
|
||||
pipeline_ins = pipeline(
|
||||
task=Tasks.machine_reading_comprehension, model=self.model_id)
|
||||
print(pipeline_ins(input=self.sentence))
|
||||
# {'ORG': [], 'PER': [], 'LOC': [' Japan', ' China'], 'MISC': []}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -52,7 +52,7 @@ class Text2360PanoramaImageTest(unittest.TestCase):
|
||||
print(
|
||||
'pipeline: the output image path is {}'.format(output_image_path))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_with_model_from_modelhub(self):
|
||||
output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name
|
||||
pipeline_ins = pipeline(
|
||||
|
||||
32
tests/pipelines/test_video2video.py
Normal file
32
tests/pipelines/test_video2video.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class Video2VideoTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = Tasks.video_to_video
|
||||
self.model_id = 'damo/Video-to-Video'
|
||||
self.path = 'https://video-generation-wulanchabu.oss-cn-wulanchabu.aliyuncs.com/baishao/test.mp4'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_with_model_from_modelhub(self):
|
||||
pipe = pipeline(task=self.task, model=self.model_id)
|
||||
p_input = {
|
||||
'video_path': self.path,
|
||||
'text': 'A panda is surfing on the sea'
|
||||
}
|
||||
|
||||
output_video_path = pipe(
|
||||
p_input, output_video='./output.mp4')[OutputKeys.OUTPUT_VIDEO]
|
||||
print(output_video_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -125,16 +125,21 @@ def get_current_branch():
|
||||
|
||||
|
||||
def get_modified_files():
|
||||
cmd = ['git', 'diff', '--name-only', 'origin/master...']
|
||||
cmd_output = run_command_get_output(cmd)
|
||||
logger.info('Modified files: ')
|
||||
logger.info(cmd_output)
|
||||
if 'PR_CHANGED_FILES' in os.environ and os.environ[
|
||||
'PR_CHANGED_FILES'] != '':
|
||||
logger.info('Getting PR modified files.')
|
||||
# get modify file from environment
|
||||
diff_files = os.environ['PR_CHANGED_FILES'].replace('#', '\n')
|
||||
else:
|
||||
cmd = ['git', 'diff', '--name-only', 'origin/master...']
|
||||
diff_files = run_command_get_output(cmd)
|
||||
logger.info('Diff files: ')
|
||||
logger.info(diff_files)
|
||||
modified_files = []
|
||||
# remove the deleted file.
|
||||
for diff_file in cmd_output.splitlines():
|
||||
if os.path.exists(diff_file):
|
||||
modified_files.append(diff_file)
|
||||
|
||||
for diff_file in diff_files.splitlines():
|
||||
if os.path.exists(diff_file.strip()):
|
||||
modified_files.append(diff_file.strip())
|
||||
return modified_files
|
||||
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_lora_train(self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-lora'
|
||||
model_revision = 'v1.0.2'
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
cfg.train.max_epochs = self.max_epochs
|
||||
@@ -51,6 +52,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
|
||||
|
||||
kwargs = dict(
|
||||
model=model_id,
|
||||
model_revision=model_revision,
|
||||
work_dir=self.tmp_dir,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
@@ -70,6 +72,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_lora_eval(self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-lora'
|
||||
model_revision = 'v1.0.2'
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
cfg.model.inference = False
|
||||
@@ -77,6 +80,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
|
||||
|
||||
kwargs = dict(
|
||||
model=model_id,
|
||||
model_revision=model_revision,
|
||||
work_dir=self.tmp_dir,
|
||||
train_dataset=None,
|
||||
eval_dataset=self.eval_dataset,
|
||||
@@ -90,6 +94,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_control_lora_train(self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-control-lora'
|
||||
model_revision = 'v1.0.2'
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
cfg.train.max_epochs = self.max_epochs
|
||||
@@ -99,6 +104,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
|
||||
|
||||
kwargs = dict(
|
||||
model=model_id,
|
||||
model_revision=model_revision,
|
||||
work_dir=self.tmp_dir,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
@@ -119,6 +125,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_control_lora_eval(self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-control-lora'
|
||||
model_revision = 'v1.0.2'
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
cfg.model.inference = False
|
||||
@@ -126,6 +133,7 @@ class TestEfficientDiffusionTuningTrainer(unittest.TestCase):
|
||||
|
||||
kwargs = dict(
|
||||
model=model_id,
|
||||
model_revision=model_revision,
|
||||
work_dir=self.tmp_dir,
|
||||
train_dataset=None,
|
||||
eval_dataset=self.eval_dataset,
|
||||
|
||||
@@ -33,9 +33,10 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
super().tearDown()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_swift_lora_train(self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-lora'
|
||||
model_revision = 'v1.0.2'
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
cfg.train.max_epochs = self.max_epochs
|
||||
@@ -47,6 +48,7 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
|
||||
|
||||
kwargs = dict(
|
||||
model=model_id,
|
||||
model_revision=model_revision,
|
||||
work_dir=self.tmp_dir,
|
||||
train_dataset=self.train_dataset,
|
||||
cfg_modify_fn=cfg_modify_fn)
|
||||
@@ -60,9 +62,10 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
|
||||
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
|
||||
self.assertIn(f'epoch_{self.max_epochs}.pth', results_files)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_swift_adapter_train(self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-adapter'
|
||||
model_revision = 'v1.0.2'
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
cfg.train.max_epochs = self.max_epochs
|
||||
@@ -74,6 +77,7 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
|
||||
|
||||
kwargs = dict(
|
||||
model=model_id,
|
||||
model_revision=model_revision,
|
||||
work_dir=self.tmp_dir,
|
||||
train_dataset=self.train_dataset,
|
||||
cfg_modify_fn=cfg_modify_fn)
|
||||
@@ -87,9 +91,10 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
|
||||
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
|
||||
self.assertIn(f'epoch_{self.max_epochs}.pth', results_files)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_efficient_diffusion_tuning_swift_prompt_train(self):
|
||||
model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-prompt'
|
||||
model_revision = 'v1.0.2'
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
cfg.train.max_epochs = self.max_epochs
|
||||
@@ -101,6 +106,7 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
|
||||
|
||||
kwargs = dict(
|
||||
model=model_id,
|
||||
model_revision=model_revision,
|
||||
work_dir=self.tmp_dir,
|
||||
train_dataset=self.train_dataset,
|
||||
cfg_modify_fn=cfg_modify_fn)
|
||||
|
||||
Reference in New Issue
Block a user