Merge branch 'release/1.9' of gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib into master-github

mulin.lyh
2023-09-05 17:11:09 +08:00
143 changed files with 12116 additions and 320 deletions

View File

@@ -3,6 +3,7 @@
BASE_CPU_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04
BASE_GPU_CUDA113_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04-cuda11.3.0-cudnn8-devel
BASE_GPU_CUDA117_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04-cuda11.7.1-cudnn8-devel
BASE_GPU_CUDA118_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04-cuda11.8.0-cudnn8-devel
MODELSCOPE_REPO_ADDRESS=reg.docker.alibaba-inc.com/modelscope/modelscope
python_version=3.7.13
torch_version=1.11.0
@@ -73,6 +74,10 @@ elif [ "$cuda_version" == 11.7.1 ]; then
echo "Building base image cuda11.7.1"
cudatoolkit_version=cu117
BASE_GPU_IMAGE=$BASE_GPU_CUDA117_IMAGE
elif [ "$cuda_version" == 11.8.0 ]; then
echo "Building base image cuda11.8.0"
cudatoolkit_version=cu118
BASE_GPU_IMAGE=$BASE_GPU_CUDA118_IMAGE
else
echo "Unsupport cuda version: $cuda_version"
exit 1

View File

@@ -42,6 +42,8 @@ for i in "$@"; do
cudatoolkit_version=11.3
elif [ "$cuda_version" == "11.7.1" ]; then
cudatoolkit_version=11.7
elif [ "$cuda_version" == "11.8.0" ]; then
cudatoolkit_version=11.8
else
echo "Unsupport cuda version $cuda_version"
exit 1

View File

@@ -9,7 +9,7 @@ cpu_sets_arr=($cpu_sets)
is_get_file_lock=false
CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh python tests/run.py --parallel 2 --run_config tests/run_config.yaml}
echo "ci command: $CI_COMMAND"
PR_CHANGED_FILES="${PR_CHANGED_FILES:-''}"
PR_CHANGED_FILES="${PR_CHANGED_FILES:-}"
echo "PR modified files: $PR_CHANGED_FILES"
PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#}
echo "PR_CHANGED_FILES: $PR_CHANGED_FILES"

View File

@@ -1,6 +1,9 @@
ARG BASE_IMAGE=reg.docker.alibaba-inc.com/modelscope/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-base
FROM $BASE_IMAGE
RUN apt-get update && apt-get install -y iputils-ping net-tools iproute2 && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# install modelscope
COPY requirements /var/modelscope
RUN pip install --no-cache-dir --upgrade pip && \
@@ -31,9 +34,9 @@ RUN pip install --no-cache-dir mpi4py paint_ldm \
# For the CPU image install the CPU build of faiss; faiss depends on a BLAS library, so we install libopenblas. TODO: rename the gpu/cpu faiss variants
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir funtextprocessing kwsbp==0.0.6 faiss==1.7.2 safetensors typeguard==2.13.3 scikit-learn 'pandas<1.4.0' librosa==0.9.2 funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
pip install --no-cache-dir funtextprocessing kwsbp==0.0.6 faiss==1.7.2 safetensors typeguard==2.13.3 scikit-learn librosa==0.9.2 funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
else \
pip install --no-cache-dir funtextprocessing kwsbp==0.0.6 https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/faiss-1.7.2-py37-none-linux_x86_64.whl safetensors typeguard==2.13.3 scikit-learn 'pandas<1.4.0' librosa==0.9.2 funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
pip install --no-cache-dir funtextprocessing kwsbp==0.0.6 https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/faiss-1.7.2-py37-none-linux_x86_64.whl safetensors typeguard==2.13.3 scikit-learn librosa==0.9.2 funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
fi
RUN pip install --no-cache-dir wenetruntime==1.11.0 adaseq --no-deps
@@ -44,5 +47,11 @@ ENV SETUPTOOLS_USE_DISTUTILS=stdlib
RUN CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" pip install --no-cache-dir 'git+https://github.com/facebookresearch/detectron2.git'
# add basicsr
RUN pip install --no-cache-dir basicsr
# torchmetrics==0.11.4 for ofa
RUN pip install --no-cache-dir tiktoken torchmetrics==0.11.4 'transformers<4.31.0' transformers_stream_generator 'protobuf<=3.20.0' bitsandbytes basicsr
COPY docker/scripts/install_flash_attension.sh /tmp/install_flash_attension.sh
RUN if [ "$USE_GPU" = "True" ] ; then \
bash /tmp/install_flash_attension.sh; \
else \
echo 'flash attention is not supported on cpu'; \
fi

View File

@@ -69,14 +69,20 @@ RUN if [ "$USE_GPU" = "True" ] ; then \
# install tensorflow
ARG TENSORFLOW_VERSION=1.15.5
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
if [ "$TENSORFLOW_VERSION" = "1.15.5" ] ; then \
pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
else \
pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION; \
fi \
else \
# only python 3.7 has tensorflow 1.15.5
if [ "$PYTHON_VERSION" = "3.7.13" ] ; then \
pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION; \
else \
elif [ "$TENSORFLOW_VERSION" = "1.15.5" ] ; then \
pip install --no-cache-dir numpy==1.18.5 https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/tensorflow-1.15.5-cp38-cp38-linux_x86_64.whl; \
fi \
else \
pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION; \
fi \
fi
# mmcv-full<=1.7.0 for mmdet3d compatibility
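For readability, a minimal Python sketch (illustrative, not part of the image build) of the tensorflow install branching in the RUN block above; the extra index URL and the cp38 wheel are the ones referenced in the Dockerfile, everything else is assumed:

def tf_pip_args(use_gpu: bool, tf_version: str, python_version: str) -> list:
    # Custom tensorflow 1.15.5 builds are hosted on this extra index.
    extra_index = '-f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html'
    py38_wheel = ('https://modelscope.oss-cn-beijing.aliyuncs.com/releases/'
                  'dependencies/tensorflow-1.15.5-cp38-cp38-linux_x86_64.whl')
    if use_gpu:
        if tf_version == '1.15.5':
            return [f'tensorflow=={tf_version}', extra_index]
        return [f'tensorflow=={tf_version}']
    if python_version == '3.7.13':
        # Only python 3.7 has a stock tensorflow 1.15.5 wheel.
        return [f'tensorflow=={tf_version}']
    if tf_version == '1.15.5':
        # python 3.8 needs the prebuilt cp38 wheel plus a pinned numpy.
        return ['numpy==1.18.5', py38_wheel]
    return [f'tensorflow=={tf_version}']

print(tf_pip_args(use_gpu=False, tf_version='1.15.5', python_version='3.8.16'))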

View File

@@ -0,0 +1,6 @@
git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention && \
cd flash-attention && pip install . && \
pip install csrc/layer_norm && \
pip install csrc/rotary && \
cd .. && \
rm -rf flash-attention
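A minimal runtime check (illustrative, not part of this commit) for whether the flash-attention build installed by the script above is importable, so code paths can fall back gracefully on CPU images:

import importlib.util

def flash_attn_available() -> bool:
    # The script above installs the 'flash_attn' package; CPU images skip it.
    return importlib.util.find_spec('flash_attn') is not None

print('flash-attention importable:', flash_attn_available())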

View File

@@ -1,14 +1,20 @@
export CMAKE_BUILD_PARALLEL_LEVEL=36 && export MAX_JOBS=36 && export CMAKE_CUDA_ARCHITECTURES="50;52;60;61;70;75;80;86" \
&& pip install --no-cache-dir fvcore iopath \
&& curl -LO https://github.com/NVIDIA/cub/archive/1.16.0.tar.gz \
&& tar xzf 1.16.0.tar.gz \
&& export CUB_HOME=$PWD/cub-1.16.0 \
export CMAKE_BUILD_PARALLEL_LEVEL=36 \
&& export MAX_JOBS=36 \
&& export CMAKE_CUDA_ARCHITECTURES="50;52;60;61;70;75;80;86" \
&& git clone --branch 2.1.0 --recursive https://github.com/NVIDIA/thrust.git \
&& cd thrust \
&& mkdir build \
&& cd build \
&& cmake -DCMAKE_INSTALL_PREFIX=/usr/local/cuda/ -DTHRUST_INCLUDE_CUB_CMAKE=ON .. \
&& make install \
&& cd ../.. \
&& rm -rf thrust \
&& pip install --no-cache-dir fvcore iopath \
&& pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" \
&& rm -fr 1.16.0.tar.gz cub-1.16.0 \
&& apt-get update \
&& apt-get install -y --no-install-recommends pkg-config libglvnd0 libgl1 libglx0 libegl1 libgles2 libglvnd-dev libgl1-mesa-dev libegl1-mesa-dev libgles2-mesa-dev -y \
&& apt-get install -y --no-install-recommends pkg-config libglvnd0 libgl1 libglx0 libegl1 libgles2 libglvnd-dev libgl1-mesa-dev libegl1-mesa-dev libgles2-mesa-dev -y \
&& git clone https://github.com/NVlabs/nvdiffrast.git \
&& cd nvdiffrast \
&& cd nvdiffrast \
&& pip install --no-cache-dir . \
&& cd .. \
&& rm -rf nvdiffrast

View File

@@ -10,10 +10,11 @@ import json
import torch
from swift import LoRAConfig, Swift
from modelscope import TrainingArgs
from modelscope import TrainingArgs, build_dataset_from_file
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.nlp.llama import LlamaForTextGeneration, LlamaTokenizer
from modelscope.msdatasets import MsDataset
from modelscope.msdatasets.dataset_cls.custom_datasets.torch_custom_dataset import \
TorchCustomDataset
from modelscope.trainers import build_trainer
@@ -38,6 +39,23 @@ PROMPT_DICT = {
@dataclass(init=False)
class TextGenerationArguments(TrainingArgs):
instruction: str = field(
default='instruction',
metadata={
'help': 'The instruction text key of dataset',
})
input: str = field(
default='input', metadata={
'help': 'The input text key of dataset',
})
output: str = field(
default='output',
metadata={
'help': 'The output text key of dataset',
})
src_txt: str = field(
default=None,
metadata={
@@ -145,12 +163,7 @@ def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer,
class SupervisedDataset(TorchCustomDataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_path: str, tokenizer):
logging.warning('Loading data...')
f = open(data_path, 'r')
list_data_dict = json.load(f)
f.close()
def __init__(self, list_data_dict, tokenizer):
logging.warning('Formatting inputs...')
prompt_input, prompt_no_input = PROMPT_DICT[
'prompt_input'], PROMPT_DICT['prompt_no_input']
@@ -173,6 +186,24 @@ class SupervisedDataset(TorchCustomDataset):
def __len__(self):
return len(self.input_ids)
def __getitem__(self, i):
if isinstance(i, int):
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
elif isinstance(i, slice):
return SliceSupervisedDataset(self.input_ids, self.labels, i)
else:
raise TypeError(f'Unsupported input type: {type(i)}')
class SliceSupervisedDataset(TorchCustomDataset):
def __init__(self, input_ids, labels, slice_):
self.input_ids = input_ids[slice_]
self.labels = labels[slice_]
def __len__(self):
return len(self.input_ids)
def __getitem__(self, i):
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
@@ -199,7 +230,9 @@ class DataCollatorForSupervisedDataset(object):
)
config, args = TextGenerationArguments().parse_cli().to_config()
training_args = TextGenerationArguments().parse_cli()
config, args = training_args.to_config()
print(args)
if __name__ == '__main__':
@@ -217,7 +250,7 @@ if __name__ == '__main__':
}
cfg.train.optimizer = {
'type': 'AdamW',
'lr': 2e-5,
'lr': training_args.lr,
'weight_decay': 0.0,
'options': {
'cumulative_iters': 8,
@@ -227,9 +260,15 @@ if __name__ == '__main__':
}
}
}
cfg.train.logging = {'interval': 8, 'by_epoch': False}
cfg.train.logging = {
'interval': training_args.logging_interval,
'by_epoch': False
}
cfg.train['bf16'] = True
cfg.train.dataloader = {'batch_size_per_gpu': 4, 'workers_per_gpu': 1}
cfg.train.dataloader = {
'batch_size_per_gpu': training_args.per_device_train_batch_size,
'workers_per_gpu': 1
}
if 'hooks' not in cfg.train:
cfg.train['hooks'] = []
if args.deepspeed is not None:
@@ -247,8 +286,49 @@ if __name__ == '__main__':
model_path = args.model if os.path.exists(
args.model) else snapshot_download(args.model)
data_path = args.src_txt if args.src_txt else os.path.join(
model_path, 'alpaca_data.json')
dataset_mapping_dict = {
args.instruction: 'instruction',
args.input: 'input',
args.output: 'output'
}
if args.dataset_json_file is None:
if args.train_dataset_name is not None and args.val_dataset_name is not None:
train_dataset = MsDataset.load(
args.train_dataset_name,
subset_name=args.train_subset_name,
split=args.train_split,
namespace=args.train_dataset_namespace).remap_columns(
dataset_mapping_dict)
validation_dataset = MsDataset.load(
args.val_dataset_name,
subset_name=args.val_subset_name,
split=args.val_split,
namespace=args.val_dataset_namespace).remap_columns(
dataset_mapping_dict)
elif args.train_dataset_name is not None and args.val_dataset_name is None:
ms_dataset = MsDataset.load(
args.train_dataset_name,
subset_name=args.train_subset_name,
split=args.train_split,
namespace=args.train_dataset_namespace).remap_columns(
dataset_mapping_dict).train_test_split(
test_size=0.02, seed=args.seed)
train_dataset = ms_dataset['train']
validation_dataset = ms_dataset['test']
else:
data_path = training_args.src_txt if training_args.src_txt else os.path.join(
model_path, 'alpaca_data.json')
ms_dataset = MsDataset.load(
'json', data_files=data_path).remap_columns(
dataset_mapping_dict).train_test_split(
test_size=0.02, seed=args.seed)
train_dataset = ms_dataset['train']
validation_dataset = ms_dataset['test']
else:
train_dataset, validation_dataset = build_dataset_from_file(
args.dataset_json_file)
model = LlamaForTextGeneration.from_pretrained(
model_path, device_map=args.device_map)
@@ -283,17 +363,19 @@ if __name__ == '__main__':
model=model,
)
train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_path)
train_dataset = SupervisedDataset(
tokenizer=tokenizer, list_data_dict=train_dataset)
validation_dataset = SupervisedDataset(
tokenizer=tokenizer, list_data_dict=validation_dataset)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
kwargs = dict(
model=model,
cfg_file=os.path.join(model_path, 'configuration.json'),
train_dataset=train_dataset,
eval_dataset=validation_dataset,
data_collator=data_collator,
max_epochs=3,
cfg_modify_fn=cfg_modify_fn,
device='cpu')
cfg_modify_fn=cfg_modify_fn)
# Construct trainer and train
trainer = build_trainer(
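A minimal sketch of the JSON fallback path added above, using only the MsDataset calls that appear in the diff; the file name is a stand-in for a local alpaca-style JSON file:

from modelscope.msdatasets import MsDataset

data_files = 'alpaca_data.json'  # hypothetical local path, for illustration only
dataset_mapping_dict = {'instruction': 'instruction', 'input': 'input', 'output': 'output'}

# Rename columns to the keys the script expects, then carve out a small
# validation split, mirroring the logic in the hunk above.
ms_dataset = MsDataset.load('json', data_files=data_files) \
    .remap_columns(dataset_mapping_dict) \
    .train_test_split(test_size=0.02, seed=42)
train_dataset, validation_dataset = ms_dataset['train'], ms_dataset['test']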

View File

@@ -6,4 +6,5 @@ torchrun --nproc_per_node $DATA_PARALLEL_SIZE examples/pytorch/llama/finetune_ll
--work_dir './tmp' \
--model 'skyline2006/llama-7b' \
--deepspeed 'default_offload_opt_param.json' \
--eval_interval 100
--eval_interval 100 \
--max_epochs 3 \

View File

@@ -2,6 +2,22 @@ export PYTHONPATH=$PYTHONPATH:./
torchrun examples/pytorch/llama/finetune_llama.py \
--work_dir './tmp' \
--model 'skyline2006/llama-7b' \
--eval_interval 100 \
--train_dataset_name 'alpaca-gpt4-data-zh' \
--train_subset_name 'default' \
--train_split 'train' \
--train_dataset_namespace 'AI-ModelScope' \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--eval_strategy 'by_epoch' \
--eval_interval 1 \
--eval_metrics 'ppl' \
--lr 2e-5 \
--save_strategy no \
--save_best true \
--metric_for_best_model ppl \
--metric_rule_for_best_model min \
--use_lora 1 \
--device_map 'auto' \
--task 'text-generation' \
--model.type 'llama' \
--max_epochs 3 \

View File

@@ -105,8 +105,8 @@ def llm_infer(args: InferArguments) -> None:
top_k=args.top_k,
top_p=args.top_p,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id)
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id)
logger.info(f'generation_config: {generation_config}')
if args.eval_human:
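The change above reuses the eos token as the pad token. A hedged sketch of the same pattern with plain transformers, using the gpt2 tokenizer purely as a stand-in:

from transformers import AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # stand-in tokenizer
generation_config = GenerationConfig(
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
    # Decoder-only models often ship without a pad token; falling back to
    # eos keeps batched generation well defined.
    pad_token_id=tokenizer.pad_token_id
    if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
)
print(generation_config)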

View File

@@ -1,5 +1,5 @@
CUDA_VISIBLE_DEVICES=0,1 \
python llm_infer.py \
--model_type qwen-7b \
--ckpt_path "runs/qwen-7b/vx_xxx/output_best/pytorch_model.bin" \
--model_type polylm-13b \
--ckpt_path "runs/polylm-13b/v0-20230802-172425/output_best/pytorch_model.bin" \
--eval_human true

View File

@@ -1,6 +1,6 @@
CUDA_VISIBLE_DEVICES=0,1 \
python llm_sft.py \
--model_type qwen-7b \
--model_type polylm-13b \
--output_dir runs \
--dataset alpaca-en,alpaca-zh \
--dataset alpaca-en,alpaca-zh,alpaca-multi \
--dataset_sample 20000

View File

@@ -141,6 +141,7 @@ class LoRATM(NamedTuple):
chatglm2 = ['query_key_value']
llama2 = ['q_proj', 'k_proj', 'v_proj']
qwen = ['c_attn']
polylm = ['c_attn']
# Reference: 'https://modelscope.cn/models/{model_id}/summary'

View File

@@ -0,0 +1,107 @@
import os
from dataclasses import dataclass, field
import cv2
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.training_args import TrainingArgs
from modelscope.utils.constant import DownloadMode, Tasks
# Load configuration file and dataset
@dataclass(init=False)
class StableDiffusionCones2Arguments(TrainingArgs):
instance_prompt: str = field(
default='a photo of sks dog',
metadata={
'help': 'The instance prompt for cones.',
})
resolution: int = field(
default=768, metadata={
'help': 'The class images resolution.',
})
train_batch_size: int = field(
default=4,
metadata={
'help': 'Batch size (per device) for the training dataloader.',
})
sample_batch_size: int = field(
default=4,
metadata={
'help': 'Batch size (per device) for sampling images.',
})
prompt: str = field(
default='dog', metadata={
'help': 'The pipeline prompt.',
})
training_args = StableDiffusionCones2Arguments(
task='text-to-image-synthesis').parse_cli()
config, args = training_args.to_config()
if os.path.exists(args.train_dataset_name):
# Load local dataset
train_dataset = MsDataset.load(args.train_dataset_name)
validation_dataset = MsDataset.load(args.train_dataset_name)
else:
# Load online dataset
train_dataset = MsDataset.load(
args.train_dataset_name,
split='train',
download_mode=DownloadMode.FORCE_REDOWNLOAD)
validation_dataset = MsDataset.load(
args.train_dataset_name,
split='validation',
download_mode=DownloadMode.FORCE_REDOWNLOAD)
def cfg_modify_fn(cfg):
if args.use_model_config:
cfg.merge_from_dict(config)
else:
cfg = config
cfg.train.lr_scheduler = {
'type': 'LambdaLR',
'lr_lambda': lambda _: 1,
'last_epoch': -1
}
return cfg
kwargs = dict(
model=training_args.model,
model_revision=args.model_revision,
work_dir=training_args.work_dir,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
cfg_modify_fn=cfg_modify_fn)
trainer = build_trainer(name=Trainers.cones2_inference, default_args=kwargs)
trainer.train()
# pipeline after training and save result
pipe = pipeline(
task=Tasks.text_to_image_synthesis,
model=training_args.work_dir + '/output',
model_revision=args.model_revision)
output = pipe({
'text': 'a mug and a dog on the beach',
'subject_list': [['mug', 2], ['dog', 5]],
'color_context': {
'255,192,0': ['mug', 2.5],
'255,0,0': ['dog', 2.5]
},
'layout': 'data/test/images/mask_example.png'
})
# visualize the result on ipynb and save it
output
cv2.imwrite('./cones2_result.png', output['output_imgs'][0])

View File

@@ -0,0 +1,13 @@
PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/cones2/finetune_stable_diffusion_cones2.py \
--model 'damo/Cones2' \
--model_revision 'v1.0.1' \
--instance_prompt="dog" \
--work_dir './tmp/cones2_diffusion' \
--train_dataset_name 'buptwq/lora-stable-diffusion-finetune-dog' \
--max_epochs 250 \
--save_ckpt_strategy 'by_epoch' \
--logging_interval 1 \
--train.dataloader.workers_per_gpu 0 \
--evaluation.dataloader.workers_per_gpu 0 \
--train.optimizer.lr 1e-5 \
--use_model_config true

View File

@@ -4,30 +4,29 @@
# TODO: handle environments without threads
# (Python compiled without thread support)
import numpy as np
import simplejson as json
from operator import attrgetter
from sortedcontainers import SortedList
from datetime import datetime, timedelta, date, time
from dateutil.parser import parse as parse_datetime
from functools import wraps, partial
from operator import methodcaller
from decimal import Decimal
from fractions import Fraction
from collections import namedtuple
import threading
import uuid
import numpy as np
from collections import namedtuple
from datetime import date, datetime, time, timedelta
from dateutil.parser import parse as parse_datetime
from decimal import Decimal
from fractions import Fraction
from functools import partial, wraps
from operator import attrgetter, methodcaller
from sortedcontainers import SortedList
try:
from moneyed import Money, Currency
from moneyed import Currency, Money
except ImportError:
# defer failing to actual (de-)serialization
pass
__all__ = ["loads", "dumps", "pretty",
"json_loads", "json_dumps", "json_prettydump",
"encoder", "decoder"]
__all__ = [
"loads", "dumps", "pretty", "json_loads", "json_dumps", "json_prettydump",
"encoder", "decoder"
]
# Should we aim for the *exact* reproduction of Python types,
# or for maximum *compatibility* when (de-)serializing?
@@ -59,12 +58,15 @@ CODING_DEFAULT = EXACT
_local = threading.local()
def prefer(coding):
_local.coding = coding
def prefer_exact():
prefer(EXACT)
def prefer_compat():
prefer(COMPAT)
@@ -103,15 +105,18 @@ def kwargified(constructor):
>>> test({'b': 3})
4
"""
@wraps(constructor)
def kwargs_constructor(kwargs):
return constructor(**kwargs)
return kwargs_constructor
_PredicatedEncoder = namedtuple('_PredicatedEncoder',
'priority predicate encoder typename')
def encoder(classname, predicate=None, priority=None, exact=True):
"""A decorator for registering a new encoder for object type
defined either by a `classname`, or detected via `predicate`.
@@ -182,14 +187,18 @@ def _json_default_exact(obj):
# first try predicate-based encoders
for handler in _encode_handlers['exact']['predicate']:
if handler.predicate(obj):
return {"__class__": handler.typename,
"__value__": handler.encoder(obj)}
return {
"__class__": handler.typename,
"__value__": handler.encoder(obj)
}
# then classname-based
classname = type(obj).__name__
if classname in _encode_handlers['exact']['classname']:
return {"__class__": classname,
"__value__": _encode_handlers['exact']['classname'][classname](obj)}
return {
"__class__": classname,
"__value__": _encode_handlers['exact']['classname'][classname](obj)
}
raise TypeError(repr(obj) + " is not JSON serializable")
@@ -217,8 +226,10 @@ def decoder(classname):
def mytype_decoder(value):
return mytype(value, reconstruct=True)
"""
def _decorator(f):
_decode_handlers.setdefault(classname, f)
return _decorator
@@ -235,25 +246,25 @@ def _json_object_hook(dict):
return dict
def _encoder_default_args(kw):
"""Shape default arguments for encoding functions."""
# manual override of the preferred coding with `exact=False`
if kw.pop('exact', getattr(_local, 'coding', CODING_DEFAULT) == EXACT):
# settings necessary for the "exact coding"
kw.update({
'default': _json_default_exact,
'use_decimal': False, # don't encode `Decimal` as JSON's `Number`
'tuple_as_array': False, # don't encode `tuple` as `Array`
'namedtuple_as_object': False # don't call `_asdict` on `namedtuple`
'use_decimal': False, # don't encode `Decimal` as JSON's `Number`
'tuple_as_array': False, # don't encode `tuple` as `Array`
'namedtuple_as_object':
False # don't call `_asdict` on `namedtuple`
})
else:
# settings for the "compatibility coding"
kw.update({
'default': _json_default_compat,
'ignore_nan': True # be compliant with the ECMA-262 specification:
# serialize nan/inf as null
'ignore_nan': True # be compliant with the ECMA-262 specification:
# serialize nan/inf as null
})
# NOTE: if called from ``simplejson.dumps()`` with ``cls=JSONEncoder``,
@@ -276,8 +287,8 @@ def _decoder_default_args(kw):
kw.update({'object_hook': _json_object_hook})
class JSONEncoder(json.JSONEncoder):
def __init__(self, **kw):
"""Constructor for simplejson.JSONEncoder, with defaults overriden
for jsonplus.
@@ -287,6 +298,7 @@ class JSONEncoder(json.JSONEncoder):
class JSONDecoder(json.JSONDecoder):
def __init__(self, **kw):
"""Constructor for simplejson.JSONDecoder, with defaults overriden
for jsonplus.
@@ -295,7 +307,6 @@ class JSONDecoder(json.JSONDecoder):
super(JSONDecoder, self).__init__(**kw)
def dumps(*pa, **kw):
_encoder_default_args(kw)
return json.dumps(*pa, **kw)
@@ -306,14 +317,13 @@ def loads(*pa, **kw):
return json.loads(*pa, **kw)
def pretty(x, sort_keys=True, indent=4*' ', separators=(',', ': '), **kw):
def pretty(x, sort_keys=True, indent=4 * ' ', separators=(',', ': '), **kw):
kw.setdefault('sort_keys', sort_keys)
kw.setdefault('indent', indent)
kw.setdefault('separators', separators)
return dumps(x, **kw)
json_dumps = dumps
json_loads = loads
json_prettydump = pretty
@@ -330,21 +340,36 @@ def generic_to_item(value):
_encode_handlers = {
'exact': {
'classname': {
'datetime': methodcaller('isoformat'),
'date': methodcaller('isoformat'),
'time': methodcaller('isoformat'),
'timedelta': partial(getattrs, attrs=['days', 'seconds', 'microseconds']),
'tuple': list,
'set': list,
'ndarray': np_to_list,
'float16': generic_to_item,
'float32': generic_to_item,
'frozenset': list,
'complex': partial(getattrs, attrs=['real', 'imag']),
'Decimal': str,
'Fraction': partial(getattrs, attrs=['numerator', 'denominator']),
'UUID': partial(getattrs, attrs=['hex']),
'Money': partial(getattrs, attrs=['amount', 'currency'])
'datetime':
methodcaller('isoformat'),
'date':
methodcaller('isoformat'),
'time':
methodcaller('isoformat'),
'timedelta':
partial(getattrs, attrs=['days', 'seconds', 'microseconds']),
'tuple':
list,
'set':
list,
'ndarray':
np_to_list,
'float16':
generic_to_item,
'float32':
generic_to_item,
'frozenset':
list,
'complex':
partial(getattrs, attrs=['real', 'imag']),
'Decimal':
str,
'Fraction':
partial(getattrs, attrs=['numerator', 'denominator']),
'UUID':
partial(getattrs, attrs=['hex']),
'Money':
partial(getattrs, attrs=['amount', 'currency'])
},
'predicate': SortedList(key=attrgetter('priority'))
},
@@ -368,7 +393,6 @@ _encode_handlers = {
}
}
# all decode handlers are for EXACT decoding BY CLASSNAME
_decode_handlers = {
'datetime': parse_datetime,
@@ -388,11 +412,14 @@ _decode_handlers = {
}
@encoder('namedtuple', lambda obj: isinstance(obj, tuple) and hasattr(obj, '_fields'))
@encoder('namedtuple',
lambda obj: isinstance(obj, tuple) and hasattr(obj, '_fields'))
def _dump_namedtuple(obj):
return {"name": type(obj).__name__,
"fields": list(obj._fields),
"values": list(obj)}
return {
"name": type(obj).__name__,
"fields": list(obj._fields),
"values": list(obj)
}
@decoder('namedtuple')
@@ -404,7 +431,8 @@ def _load_namedtuple(val):
@encoder('timedelta', exact=False)
def _timedelta_total_seconds(td):
# timedelta.total_seconds() is only available since python 2.7
return (td.microseconds + (td.seconds + td.days * 24 * 3600.0) * 10**6) / 10**6
return (td.microseconds +
(td.seconds + td.days * 24 * 3600.0) * 10**6) / 10**6
@encoder('Currency')
@@ -412,7 +440,7 @@ def _dump_currency(obj):
"""Serialize standard (ISO-defined) currencies to currency code only,
and non-standard (user-added) currencies in full.
"""
from moneyed import get_currency, CurrencyDoesNotExist
from moneyed import CurrencyDoesNotExist, get_currency
try:
get_currency(obj.code)
return obj.code
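A minimal round-trip sketch (illustrative, assuming the standalone jsonplus package that this module mirrors is installed): with the default exact coding, the per-class handlers above restore the original Python types.

from datetime import timedelta
from fractions import Fraction

import jsonplus

blob = jsonplus.dumps({'f': Fraction(1, 3), 'td': timedelta(days=1, seconds=5)})
restored = jsonplus.loads(blob)
assert restored['f'] == Fraction(1, 3)
assert restored['td'] == timedelta(days=1, seconds=5)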

View File

@@ -114,6 +114,7 @@ class Models(object):
nerf_recon_acc = 'nerf-recon-acc'
nerf_recon_4k = 'nerf-recon-4k'
nerf_recon_vq_compression = 'nerf-recon-vq-compression'
surface_recon_common = 'surface-recon-common'
bts_depth_estimation = 'bts-depth-estimation'
vision_efficient_tuning = 'vision-efficient-tuning'
bad_image_detecting = 'bad-image-detecting'
@@ -122,6 +123,7 @@ class Models(object):
fastinst = 'fastinst'
pedestrian_attribute_recognition = 'pedestrian-attribute-recognition'
image_try_on = 'image-try-on'
human_image_generation = 'human-image-generation'
# nlp models
bert = 'bert'
@@ -183,6 +185,7 @@ class Models(object):
speech_dfsmn_kws_char_farfield_iot = 'speech_dfsmn_kws_char_farfield_iot'
speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
speech_mossformer_separation_temporal_8k = 'speech_mossformer_separation_temporal_8k'
speech_mossformer2_separation_temporal_8k = 'speech_mossformer2_separation_temporal_8k'
kws_kwsbp = 'kws-kwsbp'
generic_asr = 'generic-asr'
wenet_asr = 'wenet-asr'
@@ -195,6 +198,7 @@ class Models(object):
eres2net_aug_sv = 'eres2net-aug-sv'
scl_sd = 'scl-sd'
campplus_lre = 'cam++-lre'
eres2net_lre = 'eres2net-lre'
cluster_backend = 'cluster-backend'
rdino_tdnn_sv = 'rdino_ecapa-tdnn-sv'
generic_lm = 'generic-lm'
@@ -210,12 +214,15 @@ class Models(object):
video_synthesis = 'latent-text-to-video-synthesis'
team = 'team-multi-modal-similarity'
video_clip = 'video-clip-multi-modal-embedding'
prost = 'prost-clip-text-video-retrieval'
mgeo = 'mgeo'
vldoc = 'vldoc'
hitea = 'hitea'
soonet = 'soonet'
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
cones2_inference = 'cones2-inference'
mplug_owl = 'mplug-owl'
clip_interrogator = 'clip-interrogator'
stable_diffusion = 'stable-diffusion'
stable_diffusion_xl = 'stable-diffusion-xl'
@@ -280,6 +287,7 @@ class Pipelines(object):
universal_matting = 'unet-universal-matting'
image_denoise = 'nafnet-image-denoise'
image_deblur = 'nafnet-image-deblur'
image_editing = 'masactrl-image-editing'
person_image_cartoon = 'unet-person-image-cartoon'
ocr_detection = 'resnet18-ocr-detection'
table_recognition = 'dla34-table-recognition'
@@ -420,6 +428,7 @@ class Pipelines(object):
nerf_recon_acc = 'nerf-recon-acc'
nerf_recon_4k = 'nerf-recon-4k'
nerf_recon_vq_compression = 'nerf-recon-vq-compression'
surface_recon_common = 'surface-recon-common'
bad_image_detecting = 'bad-image-detecting'
controllable_image_generation = 'controllable-image-generation'
fast_instance_segmentation = 'fast-instance-segmentation'
@@ -431,6 +440,7 @@ class Pipelines(object):
pedestrian_attribute_recognition = 'resnet50_pedestrian-attribute-recognition_image'
text_to_360panorama_image = 'text-to-360panorama-image'
image_try_on = 'image-try-on'
human_image_generation = 'human-image-generation'
# nlp tasks
automatic_post_editing = 'automatic-post-editing'
@@ -508,10 +518,12 @@ class Pipelines(object):
sv_inference = 'sv-inference'
speaker_diarization_inference = 'speaker-diarization-inference'
vad_inference = 'vad-inference'
funasr_speech_separation = 'funasr-speech-separation'
speaker_verification = 'speaker-verification'
speaker_verification_rdino = 'speaker-verification-rdino'
speaker_verification_eres2net = 'speaker-verification-eres2net'
speech_language_recognition = 'speech-language-recognition'
speech_language_recognition_eres2net = 'speech-language-recognition-eres2net'
speaker_change_locating = 'speaker-change-locating'
speaker_diarization_dialogue_detection = 'speaker-diarization-dialogue-detection'
speaker_diarization_semantic_speaker_turn_detection = 'speaker-diarization-semantic-speaker-turn-detection'
@@ -529,6 +541,7 @@ class Pipelines(object):
multi_modal_similarity = 'multi-modal-similarity'
text_to_image_synthesis = 'text-to-image-synthesis'
video_multi_modal_embedding = 'video-multi-modal-embedding'
prost_text_video_retrieval = 'prost-text-video-retrieval'
videocomposer = 'videocomposer'
image_text_retrieval = 'image-text-retrieval'
ofa_ocr_recognition = 'ofa-ocr-recognition'
@@ -541,6 +554,7 @@ class Pipelines(object):
disco_guided_diffusion = 'disco_guided_diffusion'
document_vl_embedding = 'document-vl-embedding'
chinese_stable_diffusion = 'chinese-stable-diffusion'
cones2_inference = 'cones2-inference'
text_to_video_synthesis = 'latent-text-to-video-synthesis' # latent-text-to-video-synthesis
gridvlp_multi_modal_classification = 'gridvlp-multi-modal-classification'
gridvlp_multi_modal_embedding = 'gridvlp-multi-modal-embedding'
@@ -605,6 +619,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_nafnet_image-denoise_sidd'),
Tasks.image_deblurring: (Pipelines.image_deblur,
'damo/cv_nafnet_image-deblur_gopro'),
Tasks.image_editing: (Pipelines.image_editing,
'damo/cv_masactrl_image-editing'),
Tasks.video_stabilization: (Pipelines.video_stabilization,
'damo/cv_dut-raft_video-stabilization_base'),
Tasks.video_super_resolution:
@@ -724,6 +740,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.video_multi_modal_embedding:
(Pipelines.video_multi_modal_embedding,
'damo/multi_modal_clip_vtretrival_msrvtt_53'),
Tasks.text_video_retrieval: (Pipelines.prost_text_video_retrieval,
'damo/multi_modal_clip_vtretrieval_prost'),
Tasks.image_color_enhancement:
(Pipelines.image_color_enhance,
'damo/cv_csrnet_image-color-enhance-models'),
@@ -875,6 +893,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.nerf_recon_vq_compression: (
Pipelines.nerf_recon_vq_compression,
'damo/cv_nerf-3d-reconstruction-vq-compression_damo'),
Tasks.surface_recon_common: (Pipelines.surface_recon_common,
'damo/cv_surface-reconstruction-common'),
Tasks.siamese_uie: (Pipelines.siamese_uie,
'damo/nlp_structbert_siamese-uie_chinese-base'),
Tasks.pedestrian_attribute_recognition: (
@@ -884,7 +904,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Pipelines.text_to_360panorama_image,
'damo/cv_diffusion_text-to-360panorama-image_generation'),
Tasks.image_try_on: (Pipelines.image_try_on,
'damo/cv_SAL-VTON_virtual-try-on')
'damo/cv_SAL-VTON_virtual-try-on'),
Tasks.human_image_generation: (Pipelines.human_image_generation,
'damo/cv_FreqHPT_human-image-generation')
}
@@ -942,6 +964,7 @@ class MultiModalTrainers(object):
lora_diffusion_xl = 'lora-diffusion-xl'
dreambooth_diffusion = 'dreambooth-diffusion'
custom_diffusion = 'custom-diffusion'
cones2_inference = 'cones2-inference'
class AudioTrainers(object):

View File

@@ -1,3 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import ans, asr, itn, kws, sv, tts
from . import ans, asr, itn, kws, separation, sv, tts

View File

@@ -15,6 +15,8 @@ __all__ = ['GenericAutomaticSpeechRecognition']
Tasks.auto_speech_recognition, module_name=Models.generic_asr)
@MODELS.register_module(
Tasks.voice_activity_detection, module_name=Models.generic_asr)
@MODELS.register_module(
Tasks.speech_separation, module_name=Models.generic_asr)
@MODELS.register_module(
Tasks.language_score_prediction, module_name=Models.generic_asr)
@MODELS.register_module(Tasks.speech_timestamp, module_name=Models.generic_asr)

View File

@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .mossformer import MossFormer
from .m2.mossformer import MossFormer2
else:
_import_structure = {
'mossformer': ['MossFormer'],
'm2.mossformer': ['MossFormer2'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,278 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.init as init
from torch import Tensor
EPS = 1e-8
class GlobalLayerNorm(nn.Module):
"""Calculate Global Layer Normalization.
Args:
dim : (int or list or torch.Size)
Input shape from an expected input of size.
eps : float
A value added to the denominator for numerical stability.
elementwise_affine : bool
A boolean value that when set to True,
this module has learnable per-element affine parameters
initialized to ones (for weights) and zeros (for biases).
Example:
-------
>>> x = torch.randn(5, 10, 20)
>>> GLN = GlobalLayerNorm(10, 3)
>>> x_norm = GLN(x)
"""
def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
super(GlobalLayerNorm, self).__init__()
self.dim = dim
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
if shape == 3:
self.weight = nn.Parameter(torch.ones(self.dim, 1))
self.bias = nn.Parameter(torch.zeros(self.dim, 1))
if shape == 4:
self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
def forward(self, x):
"""Returns the normalized tensor.
Args:
x : torch.Tensor
Tensor of size [N, C, K, S] or [N, C, L].
"""
# x = N x C x K x S or N x C x L
# N x 1 x 1
# cln: mean,var N x 1 x K x S
# gln: mean,var N x 1 x 1
if x.dim() == 3:
mean = torch.mean(x, (1, 2), keepdim=True)
var = torch.mean((x - mean)**2, (1, 2), keepdim=True)
if self.elementwise_affine:
# yapf: disable
x = (self.weight * (x - mean) / torch.sqrt(var + self.eps)
+ self.bias)
# yapf: enable
else:
x = (x - mean) / torch.sqrt(var + self.eps)
if x.dim() == 4:
mean = torch.mean(x, (1, 2, 3), keepdim=True)
var = torch.mean((x - mean)**2, (1, 2, 3), keepdim=True)
if self.elementwise_affine:
# yapf: disable
x = (self.weight * (x - mean) / torch.sqrt(var + self.eps)
+ self.bias)
# yapf: enable
else:
x = (x - mean) / torch.sqrt(var + self.eps)
return x
class CumulativeLayerNorm(nn.LayerNorm):
"""Calculate Cumulative Layer Normalization.
Args:
dim : int
Dimension that you want to normalize.
elementwise_affine : True
Learnable per-element affine parameters.
Example:
-------
>>> x = torch.randn(5, 10, 20)
>>> CLN = CumulativeLayerNorm(10)
>>> x_norm = CLN(x)
"""
def __init__(self, dim, elementwise_affine=True):
super(CumulativeLayerNorm, self).__init__(
dim, elementwise_affine=elementwise_affine, eps=1e-8)
def forward(self, x):
"""Returns the normalized tensor.
Args:
x : torch.Tensor
Tensor size [N, C, K, S] or [N, C, L]
"""
# x: N x C x K x S or N x C x L
# N x K x S x C
if x.dim() == 4:
x = x.permute(0, 2, 3, 1).contiguous()
# N x K x S x C == only channel norm
x = super().forward(x)
# N x C x K x S
x = x.permute(0, 3, 1, 2).contiguous()
if x.dim() == 3:
x = torch.transpose(x, 1, 2)
# N x L x C == only channel norm
x = super().forward(x)
# N x C x L
x = torch.transpose(x, 1, 2)
return x
class Transpose(nn.Module):
""" Wrapper class of torch.transpose() for Sequential module. """
def __init__(self, shape: tuple):
super(Transpose, self).__init__()
self.shape = shape
def forward(self, x: Tensor) -> Tensor:
return x.transpose(*self.shape)
class DepthwiseConv1d(nn.Module):
"""When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
this operation is termed in the literature as depthwise convolution.
Args:
in_channels (int): Number of channels in the input
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
bias (bool, optional): If True, adds a learnable bias to the output. Default: True
Inputs: inputs
- **inputs** (batch, in_channels, time): Tensor containing input vector
Returns: outputs
- **outputs** (batch, out_channels, time): Tensor produced by depthwise 1-D convolution.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
bias: bool = False,
) -> None:
super(DepthwiseConv1d, self).__init__()
assert out_channels % in_channels == 0, 'out_channels should be a constant multiple of in_channels'
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
groups=in_channels,
stride=stride,
padding=padding,
bias=bias,
)
def forward(self, inputs: Tensor) -> Tensor:
return self.conv(inputs)
class ConvModule(nn.Module):
"""
Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
to aid training deep models.
Args:
in_channels (int): Number of channels in the input
kernel_size (int or tuple, optional): Size of the convolving kernel Default: 17
dropout_p (float, optional): probability of dropout
Inputs: inputs
inputs (batch, time, dim): Tensor contains input sequences
Outputs: outputs
outputs (batch, time, dim): Tensor produced by the conformer convolution module.
"""
def __init__(
self,
in_channels: int,
kernel_size: int = 17,
expansion_factor: int = 2,
dropout_p: float = 0.1,
) -> None:
super(ConvModule, self).__init__()
assert (
kernel_size - 1
) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
assert expansion_factor == 2, 'Currently, only expansion_factor 2 is supported'
self.sequential = nn.Sequential(
Transpose(shape=(1, 2)),
DepthwiseConv1d(
in_channels,
in_channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2),
)
def forward(self, inputs: Tensor) -> Tensor:
return inputs + self.sequential(inputs).transpose(1, 2)
class DilatedDenseNet(nn.Module):
def __init__(self, depth=4, lorder=20, in_channels=64):
super(DilatedDenseNet, self).__init__()
self.depth = depth
self.in_channels = in_channels
self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)
self.twidth = lorder * 2 - 1
self.kernel_size = (self.twidth, 1)
for i in range(self.depth):
dil = 2**i
pad_length = lorder + (dil - 1) * (lorder - 1) - 1
setattr(self, 'pad{}'.format(i + 1),
nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.))
setattr(
self, 'conv{}'.format(i + 1),
nn.Conv2d(
self.in_channels * (i + 1),
self.in_channels,
kernel_size=self.kernel_size,
dilation=(dil, 1),
groups=self.in_channels,
bias=False))
setattr(self, 'norm{}'.format(i + 1),
nn.InstanceNorm2d(in_channels, affine=True))
setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))
def forward(self, x):
x = torch.unsqueeze(x, 1)
x_per = x.permute(0, 3, 2, 1)
skip = x_per
for i in range(self.depth):
out = getattr(self, 'pad{}'.format(i + 1))(skip)
out = getattr(self, 'conv{}'.format(i + 1))(out)
out = getattr(self, 'norm{}'.format(i + 1))(out)
out = getattr(self, 'prelu{}'.format(i + 1))(out)
skip = torch.cat([out, skip], dim=1)
out1 = out.permute(0, 3, 2, 1)
return out1.squeeze(1)
class FFConvMDilated(nn.Module):
def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
super().__init__()
self.mdl = nn.Sequential(
norm_klass(dim_in), nn.Linear(dim_in, dim_out), nn.SiLU(),
DilatedDenseNet(depth=2, lorder=17, in_channels=dim_out),
nn.Dropout(dropout))
def forward(
self,
x,
):
output = self.mdl(x)
return output
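A pure-torch sketch (independent of the classes above) of what GlobalLayerNorm computes for a 3-D input [N, C, L]: statistics are shared across the channel and time dimensions, followed by a per-channel affine transform.

import torch

x = torch.randn(5, 10, 20)                # [N, C, L]
mean = x.mean(dim=(1, 2), keepdim=True)   # [N, 1, 1]
var = ((x - mean) ** 2).mean(dim=(1, 2), keepdim=True)
weight = torch.ones(10, 1)                # learnable scale, shape [C, 1]
bias = torch.zeros(10, 1)                 # learnable shift, shape [C, 1]
x_norm = weight * (x - mean) / torch.sqrt(var + 1e-8) + bias
print(x_norm.shape)                       # torch.Size([5, 10, 20])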

View File

@@ -0,0 +1,144 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch as th
import torch.nn as nn
import torch.nn.functional as F
class UniDeepFsmn(nn.Module):
def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
super(UniDeepFsmn, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
if lorder is None:
return
self.lorder = lorder
self.hidden_size = hidden_size
self.linear = nn.Linear(input_dim, hidden_size)
self.project = nn.Linear(hidden_size, output_dim, bias=False)
self.conv1 = nn.Conv2d(
output_dim,
output_dim, [lorder + lorder - 1, 1], [1, 1],
groups=output_dim,
bias=False)
def forward(self, input):
f1 = F.relu(self.linear(input))
p1 = self.project(f1)
x = th.unsqueeze(p1, 1)
x_per = x.permute(0, 3, 2, 1)
y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])
out = x_per + self.conv1(y)
out1 = out.permute(0, 3, 2, 1)
return input + out1.squeeze()
class UniDeepFsmnDual(nn.Module):
def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
super(UniDeepFsmnDual, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
if lorder is None:
return
self.lorder = lorder
self.hidden_size = hidden_size
self.linear = nn.Linear(input_dim, hidden_size)
self.project = nn.Linear(hidden_size, output_dim, bias=False)
self.conv1 = nn.Conv2d(
output_dim,
output_dim, [lorder + lorder - 1, 1], [1, 1],
groups=output_dim,
bias=False)
self.conv2 = nn.Conv2d(
output_dim,
output_dim, [lorder + lorder - 1, 1], [1, 1],
groups=output_dim // 4,
bias=False)
def forward(self, input):
f1 = F.relu(self.linear(input))
p1 = self.project(f1)
x = th.unsqueeze(p1, 1)
x_per = x.permute(0, 3, 2, 1)
y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])
conv1_out = x_per + self.conv1(y)
z = F.pad(conv1_out, [0, 0, self.lorder - 1, self.lorder - 1])
out = conv1_out + self.conv2(z)
out1 = out.permute(0, 3, 2, 1)
return input + out1.squeeze()
class DilatedDenseNet(nn.Module):
def __init__(self, depth=4, lorder=20, in_channels=64):
super(DilatedDenseNet, self).__init__()
self.depth = depth
self.in_channels = in_channels
self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)
self.twidth = lorder * 2 - 1
self.kernel_size = (self.twidth, 1)
for i in range(self.depth):
dil = 2**i
pad_length = lorder + (dil - 1) * (lorder - 1) - 1
setattr(self, 'pad{}'.format(i + 1),
nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.))
setattr(
self, 'conv{}'.format(i + 1),
nn.Conv2d(
self.in_channels * (i + 1),
self.in_channels,
kernel_size=self.kernel_size,
dilation=(dil, 1),
groups=self.in_channels,
bias=False))
setattr(self, 'norm{}'.format(i + 1),
nn.InstanceNorm2d(in_channels, affine=True))
setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))
def forward(self, x):
skip = x
for i in range(self.depth):
out = getattr(self, 'pad{}'.format(i + 1))(skip)
out = getattr(self, 'conv{}'.format(i + 1))(out)
out = getattr(self, 'norm{}'.format(i + 1))(out)
out = getattr(self, 'prelu{}'.format(i + 1))(out)
skip = th.cat([out, skip], dim=1)
return out
class UniDeepFsmnDilated(nn.Module):
def __init__(self,
input_dim,
output_dim,
lorder=None,
hidden_size=None,
depth=2):
super(UniDeepFsmnDilated, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.depth = depth
if lorder is None:
return
self.lorder = lorder
self.hidden_size = hidden_size
self.linear = nn.Linear(input_dim, hidden_size)
self.project = nn.Linear(hidden_size, output_dim, bias=False)
self.conv = DilatedDenseNet(
depth=self.depth, lorder=lorder, in_channels=output_dim)
def forward(self, input):
f1 = F.relu(self.linear(input))
p1 = self.project(f1)
x = th.unsqueeze(p1, 1)
x_per = x.permute(0, 3, 2, 1)
out = self.conv(x_per)
out1 = out.permute(0, 3, 2, 1)
return input + out1.squeeze()
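A pure-torch sketch of the FSMN memory block used above: a depthwise (groups=D) temporal convolution over 2*lorder-1 frames with symmetric padding, added back as a residual. Shapes only; no weights are taken from the module.

import torch
import torch.nn.functional as F

B, T, D, lorder = 2, 50, 8, 3
p1 = torch.randn(B, T, D)                        # projected features
x = p1.unsqueeze(1).permute(0, 3, 2, 1)          # [B, D, T, 1]
y = F.pad(x, [0, 0, lorder - 1, lorder - 1])     # pad the time axis on both sides
weight = torch.randn(D, 1, 2 * lorder - 1, 1)    # depthwise kernel, groups=D
out = x + F.conv2d(y, weight, groups=D)          # residual memory block, [B, D, T, 1]
print(out.permute(0, 3, 2, 1).squeeze(1).shape)  # torch.Size([2, 50, 8])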

View File

@@ -0,0 +1,125 @@
# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
class CLayerNorm(nn.LayerNorm):
"""Channel-wise layer normalization."""
def __init__(self, *args, **kwargs):
super(CLayerNorm, self).__init__(*args, **kwargs)
def forward(self, sample):
"""Forward function.
Args:
sample: [batch_size, channels, length]
"""
if sample.dim() != 3:
raise RuntimeError('{} only accepts 3-D tensors as input'.format(
self.__class__.__name__))
# [N, C, T] -> [N, T, C]
sample = torch.transpose(sample, 1, 2)
# LayerNorm
sample = super().forward(sample)
# [N, T, C] -> [N, C, T]
sample = torch.transpose(sample, 1, 2)
return sample
class ILayerNorm(nn.InstanceNorm1d):
"""Channel-wise layer normalization."""
def __init__(self, *args, **kwargs):
super(ILayerNorm, self).__init__(*args, **kwargs)
def forward(self, sample):
"""Forward function.
Args:
sample: [batch_size, channels, length]
"""
if sample.dim() != 3:
raise RuntimeError('{} only accepts 3-D tensors as input'.format(
self.__class__.__name__))
# [N, C, T] -> [N, T, C]
sample = torch.transpose(sample, 1, 2)
# LayerNorm
sample = super().forward(sample)
# [N, T, C] -> [N, C, T]
sample = torch.transpose(sample, 1, 2)
return sample
class GLayerNorm(nn.Module):
"""Global Layer Normalization for TasNet."""
def __init__(self, channels, eps=1e-5):
super(GLayerNorm, self).__init__()
self.eps = eps
self.norm_dim = channels
self.gamma = nn.Parameter(torch.Tensor(channels))
self.beta = nn.Parameter(torch.Tensor(channels))
# self.register_parameter('weight', self.gamma)
# self.register_parameter('bias', self.beta)
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.gamma)
nn.init.zeros_(self.beta)
def forward(self, sample):
"""Forward function.
Args:
sample: [batch_size, channels, length]
"""
if sample.dim() != 3:
raise RuntimeError('{} only accepts 3-D tensors as input'.format(
self.__class__.__name__))
# [N, C, T] -> [N, T, C]
sample = torch.transpose(sample, 1, 2)
# Mean and variance [N, 1, 1]
mean = torch.mean(sample, (1, 2), keepdim=True)
var = torch.mean((sample - mean)**2, (1, 2), keepdim=True)
sample = (sample
- mean) / torch.sqrt(var + self.eps) * self.gamma + self.beta
# [N, T, C] -> [N, C, T]
sample = torch.transpose(sample, 1, 2)
return sample
class _LayerNorm(nn.Module):
"""Layer Normalization base class."""
def __init__(self, channel_size):
super(_LayerNorm, self).__init__()
self.channel_size = channel_size
self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)
def apply_gain_and_bias(self, normed_x):
""" Assumes input of size `[batch, chanel, *]`. """
return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(
1, -1)
class GlobLayerNorm(_LayerNorm):
"""Global Layer Normalization (globLN)."""
def forward(self, x):
""" Applies forward pass.
Works for any input size > 2D.
Args:
x (:class:`torch.Tensor`): Shape `[batch, chan, *]`
Returns:
:class:`torch.Tensor`: gLN_x `[batch, chan, *]`
"""
dims = list(range(1, len(x.shape)))
mean = x.mean(dim=dims, keepdim=True)
var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
return self.apply_gain_and_bias((x - mean) / (var + 1e-8).sqrt())

View File

@@ -0,0 +1,599 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Some code here is adapted from SpeechBrain; the original can be found on GitHub:
# https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/lobes/models/dual_path.py
"""Library to support dual-path speech separation.
Authors
* Cem Subakan 2020
* Mirco Ravanelli 2020
* Samuele Cornell 2020
* Mirko Bronzi 2020
* Jianyuan Zhong 2020
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.utils.constant import ModelFile, Tasks
from .mossformer_block import MossformerBlockGFSMN, ScaledSinuEmbedding
EPS = 1e-8
class GlobalLayerNorm(nn.Module):
"""Calculate Global Layer Normalization.
Args:
dim : (int or list or torch.Size)
Input shape from an expected input of size.
eps : float
A value added to the denominator for numerical stability.
elementwise_affine : bool
A boolean value that when set to True,
this module has learnable per-element affine parameters
initialized to ones (for weights) and zeros (for biases).
Example:
>>> x = torch.randn(5, 10, 20)
>>> GLN = GlobalLayerNorm(10, 3)
>>> x_norm = GLN(x)
"""
def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
super(GlobalLayerNorm, self).__init__()
self.dim = dim
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
if shape == 3:
self.weight = nn.Parameter(torch.ones(self.dim, 1))
self.bias = nn.Parameter(torch.zeros(self.dim, 1))
if shape == 4:
self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
def forward(self, x):
"""Returns the normalized tensor.
Args:
x : torch.Tensor
Tensor of size [N, C, K, S] or [N, C, L].
"""
# x = N x C x K x S or N x C x L
# N x 1 x 1
# cln: mean,var N x 1 x K x S
# gln: mean,var N x 1 x 1
if x.dim() == 3:
mean = torch.mean(x, (1, 2), keepdim=True)
var = torch.mean((x - mean)**2, (1, 2), keepdim=True)
if self.elementwise_affine:
# yapf: disable
x = (self.weight * (x - mean) / torch.sqrt(var + self.eps)
+ self.bias)
# yapf: enable
else:
x = (x - mean) / torch.sqrt(var + self.eps)
if x.dim() == 4:
mean = torch.mean(x, (1, 2, 3), keepdim=True)
var = torch.mean((x - mean)**2, (1, 2, 3), keepdim=True)
if self.elementwise_affine:
# yapf: disable
x = (self.weight * (x - mean) / torch.sqrt(var + self.eps)
+ self.bias)
# yapf: enable
else:
x = (x - mean) / torch.sqrt(var + self.eps)
return x
class CumulativeLayerNorm(nn.LayerNorm):
"""Calculate Cumulative Layer Normalization.
Args:
dim : int
Dimension that you want to normalize.
elementwise_affine : True
Learnable per-element affine parameters.
Example
-------
>>> x = torch.randn(5, 10, 20)
>>> CLN = CumulativeLayerNorm(10)
>>> x_norm = CLN(x)
"""
def __init__(self, dim, elementwise_affine=True):
super(CumulativeLayerNorm, self).__init__(
dim, elementwise_affine=elementwise_affine, eps=1e-8)
def forward(self, x):
"""Returns the normalized tensor.
Arguments
---------
x : torch.Tensor
Tensor size [N, C, K, S] or [N, C, L]
"""
# x: N x C x K x S or N x C x L
# N x K x S x C
if x.dim() == 4:
x = x.permute(0, 2, 3, 1).contiguous()
# N x K x S x C == only channel norm
x = super().forward(x)
# N x C x K x S
x = x.permute(0, 3, 1, 2).contiguous()
if x.dim() == 3:
x = torch.transpose(x, 1, 2)
# N x L x C == only channel norm
x = super().forward(x)
# N x C x L
x = torch.transpose(x, 1, 2)
return x
def select_norm(norm, dim, shape):
"""Just a wrapper to select the normalization type.
"""
if norm == 'gln':
return GlobalLayerNorm(dim, shape, elementwise_affine=True)
if norm == 'cln':
return CumulativeLayerNorm(dim, elementwise_affine=True)
if norm == 'ln':
return nn.GroupNorm(1, dim, eps=1e-8)
else:
return nn.BatchNorm1d(dim)
class Encoder(nn.Module):
"""Convolutional Encoder Layer.
Args:
kernel_size : int
Length of filters.
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
Example:
>>> x = torch.randn(2, 1000)
>>> encoder = Encoder(kernel_size=4, out_channels=64)
>>> h = encoder(x)
>>> h.shape
torch.Size([2, 64, 499])
"""
def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
super(Encoder, self).__init__()
self.conv1d = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=kernel_size // 2,
groups=1,
bias=False,
)
self.in_channels = in_channels
def forward(self, x):
"""Return the encoded output.
Args:
x : torch.Tensor
Input tensor with dimensionality [B, L].
Returns:
x : torch.Tensor
Encoded tensor with dimensionality [B, N, T_out].
where B = Batchsize
L = Number of timepoints
N = Number of filters
T_out = Number of timepoints at the output of the encoder
"""
# B x L -> B x 1 x L
if self.in_channels == 1:
x = torch.unsqueeze(x, dim=1)
# B x 1 x L -> B x N x T_out
x = self.conv1d(x)
x = F.relu(x)
return x
class Decoder(nn.ConvTranspose1d):
"""A decoder layer that consists of ConvTranspose1d.
Args:
kernel_size : int
Length of filters.
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
Example:
---------
>>> x = torch.randn(2, 100, 1000)
>>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
>>> h = decoder(x)
>>> h.shape
torch.Size([2, 1003])
"""
def __init__(self, *args, **kwargs):
super(Decoder, self).__init__(*args, **kwargs)
def forward(self, x):
"""Return the decoded output.
Args:
x : torch.Tensor
Input tensor with dimensionality [B, N, L].
where, B = Batchsize,
N = number of filters
L = time points
"""
if x.dim() not in [2, 3]:
raise RuntimeError('{} only accepts 2-D or 3-D tensors as input'.format(
self.__class__.__name__))
x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
if torch.squeeze(x).dim() == 1:
x = torch.squeeze(x, dim=1)
else:
x = torch.squeeze(x)
return x
class MossFormerM(nn.Module):
"""This class implements the transformer encoder.
Args:
num_blocks : int
Number of mossformer blocks to include.
d_model : int
The dimension of the input embedding.
attn_dropout : float
Dropout for the self-attention (Optional).
group_size: int
the chunk size
query_key_dim: int
the attention vector dimension
expansion_factor: int
the expansion factor for the linear projection in conv module
causal: bool
true for causal / false for non causal
Example:
-------
>>> import torch
>>> x = torch.rand((8, 60, 512))
>>> net = MossFormerM(num_blocks=8, d_model=512)
>>> output = net(x)
>>> output.shape
torch.Size([8, 60, 512])
"""
def __init__(self,
num_blocks,
d_model=None,
causal=False,
group_size=256,
query_key_dim=128,
expansion_factor=4.,
attn_dropout=0.1):
super().__init__()
self.mossformerM = MossformerBlockGFSMN(
dim=d_model,
depth=num_blocks,
group_size=group_size,
query_key_dim=query_key_dim,
expansion_factor=expansion_factor,
causal=causal,
attn_dropout=attn_dropout)
self.norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(self, src):
"""
Args:
src : torch.Tensor
Tensor shape [B, L, N],
where, B = Batchsize,
L = time points
N = number of filters
The sequence to the encoder layer (required).
"""
output = self.mossformerM(src)
output = self.norm(output)
return output
class ComputationBlock(nn.Module):
"""Computation block for dual-path processing.
Args:
num_blocks : int
Number of mossformer blocks to include.
out_channels : int
Dimensionality of inter/intra model.
norm : str
Normalization type.
skip_around_intra : bool
Skip connection around the intra layer.
Example:
---------
>>> comp_block = ComputationBlock(64)
>>> x = torch.randn(10, 64, 100)
>>> x = comp_block(x)
>>> x.shape
torch.Size([10, 64, 100])
"""
def __init__(
self,
num_blocks,
out_channels,
norm='ln',
skip_around_intra=True,
):
super(ComputationBlock, self).__init__()
# MossFormer+: MossFormer with recurrence
self.intra_mdl = MossFormerM(
num_blocks=num_blocks, d_model=out_channels)
self.skip_around_intra = skip_around_intra
# Norm
self.norm = norm
if norm is not None:
self.intra_norm = select_norm(norm, out_channels, 3)
def forward(self, x):
"""Returns the output tensor.
Args:
x : torch.Tensor
Input tensor of dimension [B, N, S].
Returns:
out: torch.Tensor
Output tensor of dimension [B, N, S].
where, B = Batchsize,
N = number of filters
S = sequence time index
"""
B, N, S = x.shape
# intra RNN
# [B, S, N]
intra = x.permute(0, 2, 1).contiguous()
intra = self.intra_mdl(intra)
# [B, N, S]
intra = intra.permute(0, 2, 1).contiguous()
if self.norm is not None:
intra = self.intra_norm(intra)
# [B, N, S]
if self.skip_around_intra:
intra = intra + x
out = intra
return out
class MossFormerMaskNet(nn.Module):
"""The dual path model which is the basis for dualpathrnn, sepformer, dptnet.
Args:
in_channels : int
Number of channels at the output of the encoder.
out_channels : int
Number of channels that would be inputted to the intra and inter blocks.
norm : str
Normalization type.
num_spks : int
Number of sources (speakers).
skip_around_intra : bool
Skip connection around intra.
use_global_pos_enc : bool
Global positional encodings.
max_length : int
Maximum sequence length.
Example:
---------
>>> mossformer_masknet = MossFormerMaskNet(64, 64, num_blocks=24, num_spks=2)
>>> x = torch.randn(10, 64, 2000)
>>> x = mossformer_masknet(x)
>>> x.shape
torch.Size([2, 10, 64, 2000])
"""
def __init__(
self,
in_channels,
out_channels,
num_blocks=24,
norm='ln',
num_spks=2,
skip_around_intra=True,
use_global_pos_enc=True,
max_length=20000,
):
super(MossFormerMaskNet, self).__init__()
self.num_spks = num_spks
self.num_blocks = num_blocks
self.norm = select_norm(norm, in_channels, 3)
self.conv1d_encoder = nn.Conv1d(
in_channels, out_channels, 1, bias=False)
self.use_global_pos_enc = use_global_pos_enc
if self.use_global_pos_enc:
self.pos_enc = ScaledSinuEmbedding(out_channels)
self.mdl = ComputationBlock(
num_blocks,
out_channels,
norm,
skip_around_intra=skip_around_intra,
)
self.conv1d_out = nn.Conv1d(
out_channels, out_channels * num_spks, kernel_size=1)
self.conv1_decoder = nn.Conv1d(
out_channels, in_channels, 1, bias=False)
self.prelu = nn.PReLU()
self.activation = nn.ReLU()
# gated output layer
self.output = nn.Sequential(
nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
self.output_gate = nn.Sequential(
nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())
def forward(self, x):
"""Returns the output tensor.
Args:
x : torch.Tensor
Input tensor of dimension [B, N, S].
Returns:
out : torch.Tensor
Output tensor of dimension [spks, B, N, S]
where, spks = Number of speakers
B = Batchsize,
N = number of filters
S = the number of time frames
"""
        # each comment shows the tensor shape after executing the line that follows it
# [B, N, L]
x = self.norm(x)
# [B, N, L]
x = self.conv1d_encoder(x)
if self.use_global_pos_enc:
base = x
x = x.transpose(1, -1)
emb = self.pos_enc(x)
emb = emb.transpose(0, -1)
x = base + emb
# [B, N, S]
x = self.mdl(x)
x = self.prelu(x)
# [B, N*spks, S]
x = self.conv1d_out(x)
B, _, S = x.shape
# [B*spks, N, S]
x = x.view(B * self.num_spks, -1, S)
# [B*spks, N, S]
x = self.output(x) * self.output_gate(x)
# [B*spks, N, S]
x = self.conv1_decoder(x)
# [B, spks, N, S]
_, N, L = x.shape
x = x.view(B, self.num_spks, N, L)
x = self.activation(x)
# [spks, B, N, S]
x = x.transpose(0, 1)
return x
@MODELS.register_module(
Tasks.speech_separation,
module_name=Models.speech_mossformer2_separation_temporal_8k)
class MossFormer2(TorchModel):
"""Library to support MossFormer speech separation.
Args:
model_dir (str): the model path.
"""
def __init__(self,
model_dir: str,
in_channels=512,
out_channels=512,
num_blocks=24,
kernel_size=16,
norm='ln',
num_spks=2,
skip_around_intra=True,
use_global_pos_enc=True,
max_length=20000,
*args,
**kwargs):
super().__init__(model_dir, *args, **kwargs)
self.num_spks = num_spks
self.enc = Encoder(
kernel_size=kernel_size, out_channels=in_channels, in_channels=1)
self.mask_net = MossFormerMaskNet(
in_channels=in_channels,
out_channels=out_channels,
num_blocks=num_blocks,
norm=norm,
num_spks=num_spks,
skip_around_intra=skip_around_intra,
use_global_pos_enc=use_global_pos_enc,
max_length=max_length,
)
self.dec = Decoder(
in_channels=out_channels,
out_channels=1,
kernel_size=kernel_size,
stride=kernel_size // 2,
bias=False)
def forward(self, input):
x = self.enc(input)
mask = self.mask_net(x)
x = torch.stack([x] * self.num_spks)
sep_x = x * mask
# Decoding
est_source = torch.cat(
[self.dec(sep_x[i]).unsqueeze(-1) for i in range(self.num_spks)],
dim=-1,
)
T_origin = input.size(1)
T_est = est_source.size(1)
if T_origin > T_est:
est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
else:
est_source = est_source[:, :T_origin, :]
return est_source
def load_check_point(self, load_path=None, device=None):
if not load_path:
load_path = self.model_dir
if not device:
device = torch.device('cpu')
self.load_state_dict(
torch.load(
os.path.join(load_path, ModelFile.TORCH_MODEL_FILE),
map_location=device),
strict=False)
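# Minimal usage sketch (the directory path and tensor sizes below are
# hypothetical; weights are loaded separately via load_check_point()):
#   model = MossFormer2(model_dir='/path/to/local/model', num_spks=2)
#   model.load_check_point()              # reads ModelFile.TORCH_MODEL_FILE from model_dir
#   mixture = torch.randn(1, 8000)        # [B, T] mono mixture, 8 kHz model
#   separated = model(mixture)            # [B, T, num_spks]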

View File

@@ -0,0 +1,548 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn.functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding
from torch import einsum, nn
from .conv_module import ConvModule, FFConvMDilated
from .fsmn import UniDeepFsmn, UniDeepFsmnDilated
from .layer_norm import CLayerNorm
# functions
def identity(t, *args, **kwargs):
return t
def append_dims(x, num_dims):
if num_dims <= 0:
return x
return x.view(*x.shape, *((1, ) * num_dims))
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def padding_to_multiple_of(n, mult):
remainder = n % mult
if remainder == 0:
return 0
return mult - remainder
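# For example, padding_to_multiple_of(300, 256) == 212, so a 300-step sequence
# is padded up to 512 (two groups of 256) before grouped attention.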
# scalenorm
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
# absolute positional encodings
class ScaledSinuEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = nn.Parameter(torch.ones(1, ))
inv_freq = 1. / (10000**(torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x):
n, device = x.shape[1], x.device
t = torch.arange(n, device=device).type_as(self.inv_freq)
sinu = einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1)
return emb * self.scale
class OffsetScale(nn.Module):
def __init__(self, dim, heads=1):
super().__init__()
self.gamma = nn.Parameter(torch.ones(heads, dim))
self.beta = nn.Parameter(torch.zeros(heads, dim))
nn.init.normal_(self.gamma, std=0.02)
def forward(self, x):
out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
return out.unbind(dim=-2)
class FFConvM(nn.Module):
def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
super().__init__()
self.mdl = nn.Sequential(
norm_klass(dim_in), nn.Linear(dim_in, dim_out), nn.SiLU(),
ConvModule(dim_out), nn.Dropout(dropout))
def forward(
self,
x,
):
output = self.mdl(x)
return output
class GroupLinear(nn.Module):
def __init__(self, dim_in, dim_out, K=4):
super().__init__()
hidden = dim_in // 2
self.group_conv = nn.Conv1d(
dim_in, hidden, groups=dim_in // K, kernel_size=1)
self.norm = nn.LayerNorm(hidden)
self.linear = nn.Linear(hidden, dim_out)
def forward(
self,
x,
):
x1 = x.transpose(2, 1)
conv_out = self.group_conv(x1)
x2 = self.norm(conv_out.transpose(2, 1))
x3 = self.linear(x2)
return x3
class FFM(nn.Module):
def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
super().__init__()
self.mdl = nn.Sequential(
norm_klass(dim_in), nn.Linear(dim_in, dim_out), nn.SiLU(),
nn.Dropout(dropout))
def forward(
self,
x,
):
output = self.mdl(x)
return output
# FLASH
class FLASH_ShareA_FFConvM(nn.Module):
def __init__(self,
*,
dim,
group_size=256,
query_key_dim=128,
expansion_factor=1.,
causal=False,
dropout=0.1,
rotary_pos_emb=None,
norm_klass=nn.LayerNorm,
shift_tokens=True):
super().__init__()
hidden_dim = int(dim * expansion_factor)
self.group_size = group_size
self.causal = causal
self.shift_tokens = shift_tokens
# positional embeddings
self.rotary_pos_emb = rotary_pos_emb
# norm
self.dropout = nn.Dropout(dropout)
# projections
self.to_hidden = FFConvM(
dim_in=dim,
dim_out=hidden_dim,
norm_klass=norm_klass,
dropout=dropout,
)
self.to_qk = FFConvM(
dim_in=dim,
dim_out=query_key_dim,
norm_klass=norm_klass,
dropout=dropout,
)
self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)
self.to_out = FFConvM(
dim_in=dim * 2,
dim_out=dim,
norm_klass=norm_klass,
dropout=dropout,
)
self.gateActivate = nn.Sigmoid()
def forward(self, x, *, mask=None):
"""
b - batch
n - sequence length (within groups)
g - group dimension
d - feature dimension (keys)
e - feature dimension (values)
i - sequence dimension (source)
j - sequence dimension (target)
"""
# prenorm
normed_x = x
if self.shift_tokens:
x_shift, x_pass = normed_x.chunk(2, dim=-1)
x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.)
normed_x = torch.cat((x_shift, x_pass), dim=-1)
# initial projections
v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
qk = self.to_qk(normed_x)
# offset and scale
quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v,
u)
out = (att_u * v) * self.gateActivate(att_v * u)
x = x + self.to_out(out)
return x
def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None):
b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
if exists(mask):
lin_mask = rearrange(mask, '... -> ... 1')
lin_k = lin_k.masked_fill(~lin_mask, 0.)
# rotate queries and keys
if exists(self.rotary_pos_emb):
quad_q, lin_q, quad_k, lin_k = map(
self.rotary_pos_emb.rotate_queries_or_keys,
(quad_q, lin_q, quad_k, lin_k))
# padding for groups
padding = padding_to_multiple_of(n, g)
if padding > 0:
quad_q, quad_k, lin_q, lin_k, v, u = map(
lambda t: F.pad(t, (0, 0, 0, padding), value=0.),
(quad_q, quad_k, lin_q, lin_k, v, u))
mask = default(mask,
torch.ones((b, n), device=device, dtype=torch.bool))
mask = F.pad(mask, (0, padding), value=False)
# group along sequence
quad_q, quad_k, lin_q, lin_k, v, u = map(
lambda t: rearrange(t, 'b (g n) d -> b g n d', n=self.group_size),
(quad_q, quad_k, lin_q, lin_k, v, u))
if exists(mask):
mask = rearrange(mask, 'b (g j) -> b g 1 j', j=g)
# calculate quadratic attention output
sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
attn = F.relu(sim)**2
attn = self.dropout(attn)
if exists(mask):
attn = attn.masked_fill(~mask, 0.)
if self.causal:
causal_mask = torch.ones((g, g), dtype=torch.bool,
device=device).triu(1)
attn = attn.masked_fill(causal_mask, 0.)
quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
# calculate linear attention output
if self.causal:
lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
# exclusive cumulative sum along group dimension
lin_kv = lin_kv.cumsum(dim=1)
lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.)
lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
# exclusive cumulative sum along group dimension
lin_ku = lin_ku.cumsum(dim=1)
lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.)
lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
else:
lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
# fold back groups into full sequence, and excise out padding
return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n],
(quad_out_v + lin_out_v, quad_out_u + lin_out_u))
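# Shape sketch for FLASH_ShareA_FFConvM.cal_attention (sizes are illustrative):
# with b=2, n=300, group_size=256 and value width e, the inputs are padded to
# length 512, regrouped to [2, 2, 256, .] for the quadratic and linear branches,
# and the two returned tensors are trimmed back to [2, 300, e], matching v and u.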
class GatedFSMNDilated(nn.Module):
def __init__(self, in_channels, out_channels, lorder, hidden_size):
super().__init__()
self.to_u = FFConvM(
dim_in=in_channels,
dim_out=hidden_size,
norm_klass=nn.LayerNorm,
dropout=0.1,
)
self.to_v = FFConvM(
dim_in=in_channels,
dim_out=hidden_size,
norm_klass=nn.LayerNorm,
dropout=0.1,
)
self.fsmn = UniDeepFsmnDilated(in_channels, out_channels, lorder,
hidden_size)
def forward(
self,
x,
):
input = x
x_u = self.to_u(x)
x_v = self.to_v(x)
x_u = self.fsmn(x_u)
x = x_v * x_u + input
return x
class GatedFSMNDilatedDual(nn.Module):
def __init__(self, in_channels, out_channels, lorder, hidden_size):
super().__init__()
self.to_u = FFConvMDilated(
dim_in=in_channels,
dim_out=hidden_size,
norm_klass=nn.LayerNorm,
dropout=0.1,
)
self.to_v = FFConvMDilated(
dim_in=in_channels,
dim_out=hidden_size,
norm_klass=nn.LayerNorm,
dropout=0.1,
)
self.fsmn = UniDeepFsmnDilated(in_channels, out_channels, lorder,
hidden_size)
def forward(
self,
x,
):
input = x
x_u = self.to_u(x)
x_v = self.to_v(x)
x_u = self.fsmn(x_u)
x = x_v * x_u + input
return x
class GatedFSMNBlockDilatedDual(nn.Module):
"""1-D convolutional block."""
def __init__(
self,
dim,
inner_channels=256,
):
super(GatedFSMNBlockDilatedDual, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv1d(dim, inner_channels, kernel_size=1),
nn.PReLU(),
)
self.norm1 = CLayerNorm(inner_channels)
self.gated_fsmn = GatedFSMNDilatedDual(
inner_channels,
inner_channels,
lorder=20,
hidden_size=inner_channels)
self.norm2 = CLayerNorm(inner_channels)
self.conv2 = nn.Conv1d(inner_channels, dim, kernel_size=1)
def forward(self, input):
conv1 = self.conv1(input.transpose(2, 1))
norm1 = self.norm1(conv1)
seq_out = self.gated_fsmn(norm1.transpose(2, 1))
norm2 = self.norm2(seq_out.transpose(2, 1))
conv2 = self.conv2(norm2)
return conv2.transpose(2, 1) + input
class GatedFSMNBlockDilated(nn.Module):
"""1-D convolutional block."""
def __init__(
self,
dim,
inner_channels=256,
group_size=256,
norm_type='scalenorm',
):
super(GatedFSMNBlockDilated, self).__init__()
self.group_size = group_size
self.conv1 = nn.Sequential(
nn.Conv1d(dim, inner_channels, kernel_size=1),
nn.PReLU(),
)
self.norm1 = CLayerNorm(inner_channels)
# block dilated with gating
self.gated_fsmn = GatedFSMNDilated(
inner_channels,
inner_channels,
lorder=20,
hidden_size=inner_channels)
self.norm2 = CLayerNorm(inner_channels)
self.conv2 = nn.Conv1d(inner_channels, dim, kernel_size=1)
def forward(self, input):
conv1 = self.conv1(input.transpose(2, 1))
norm1 = self.norm1(conv1)
seq_out = self.gated_fsmn(norm1.transpose(2, 1))
norm2 = self.norm2(seq_out.transpose(2, 1))
conv2 = self.conv2(norm2)
return conv2.transpose(2, 1) + input
class MossformerBlockGFSMN(nn.Module):
def __init__(self,
*,
dim,
depth,
group_size=256,
query_key_dim=128,
expansion_factor=4.,
causal=False,
attn_dropout=0.1,
norm_type='scalenorm',
shift_tokens=True):
super().__init__()
assert norm_type in (
'scalenorm',
'layernorm'), 'norm_type must be one of scalenorm or layernorm'
if norm_type == 'scalenorm':
norm_klass = ScaleNorm
elif norm_type == 'layernorm':
norm_klass = nn.LayerNorm
self.group_size = group_size
rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
# max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
self.fsmn = nn.ModuleList(
[GatedFSMNBlockDilated(dim) for _ in range(depth)])
self.layers = nn.ModuleList([
FLASH_ShareA_FFConvM(
dim=dim,
group_size=group_size,
query_key_dim=query_key_dim,
expansion_factor=expansion_factor,
causal=causal,
dropout=attn_dropout,
rotary_pos_emb=rotary_pos_emb,
norm_klass=norm_klass,
shift_tokens=shift_tokens) for _ in range(depth)
])
def _build_repeats(self,
in_channels,
out_channels,
lorder,
hidden_size,
repeats=1):
repeats = [
UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
for i in range(repeats)
]
return nn.Sequential(*repeats)
    def forward(self, x, *, mask=None):
        for flash, fsmn in zip(self.layers, self.fsmn):
            x = flash(x, mask=mask)
            x = fsmn(x)
        return x
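# Shape sketch (sizes are illustrative): each FLASH layer and its gated FSMN
# block preserve the feature dimension, so the whole block is shape-preserving:
#   blk = MossformerBlockGFSMN(dim=512, depth=2)
#   y = blk(torch.randn(2, 300, 512))   # -> torch.Size([2, 300, 512])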
class MossformerBlock(nn.Module):
def __init__(self,
*,
dim,
depth,
group_size=256,
query_key_dim=128,
expansion_factor=4.,
causal=False,
attn_dropout=0.1,
norm_type='scalenorm',
shift_tokens=True):
super().__init__()
assert norm_type in (
'scalenorm',
'layernorm'), 'norm_type must be one of scalenorm or layernorm'
if norm_type == 'scalenorm':
norm_klass = ScaleNorm
elif norm_type == 'layernorm':
norm_klass = nn.LayerNorm
self.group_size = group_size
rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
# max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
self.layers = nn.ModuleList([
FLASH_ShareA_FFConvM(
dim=dim,
group_size=group_size,
query_key_dim=query_key_dim,
expansion_factor=expansion_factor,
causal=causal,
dropout=attn_dropout,
rotary_pos_emb=rotary_pos_emb,
norm_klass=norm_klass,
shift_tokens=shift_tokens) for _ in range(depth)
])
def _build_repeats(self,
in_channels,
out_channels,
lorder,
hidden_size,
repeats=1):
repeats = [
UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
for i in range(repeats)
]
return nn.Sequential(*repeats)
    def forward(self, x, *, mask=None):
        for flash in self.layers:
            x = flash(x, mask=mask)
        return x

View File

@@ -0,0 +1,120 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict
import numpy as np
import torch
import torch.nn as nn
import torchaudio.compliance.kaldi as Kaldi
from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.models.audio.sv.DTDNN import CAMPPlus
from modelscope.models.audio.sv.DTDNN_layers import DenseLayer
from modelscope.models.audio.sv.ERes2Net import ERes2Net
from modelscope.utils.constant import Tasks
from modelscope.utils.device import create_device
class LinearClassifier(nn.Module):
def __init__(
self,
input_dim,
num_blocks=0,
inter_dim=512,
out_neurons=1000,
):
super().__init__()
self.blocks = nn.ModuleList()
self.nonlinear = nn.ReLU(inplace=True)
for _ in range(num_blocks):
self.blocks.append(DenseLayer(input_dim, inter_dim, bias=True))
input_dim = inter_dim
self.linear = nn.Linear(input_dim, out_neurons, bias=True)
def forward(self, x):
# x: [B, dim]
x = self.nonlinear(x)
for layer in self.blocks:
x = layer(x)
x = self.linear(x)
return x
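# Sketch (sizes are illustrative): with the default num_blocks=0 the classifier
# is just a ReLU followed by a linear head:
#   clf = LinearClassifier(192, out_neurons=5)
#   logits = clf(torch.randn(8, 192))   # -> torch.Size([8, 5])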
@MODELS.register_module(
Tasks.speech_language_recognition, module_name=Models.eres2net_lre)
class LanguageRecognitionERes2Net(TorchModel):
r"""A speech language recognition model using the ERes2Net architecture as the backbone.
Args:
model_dir: A model dir.
model_config: The model config.
"""
def __init__(self, model_dir, model_config: Dict[str, Any], *args,
**kwargs):
super().__init__(model_dir, model_config, *args, **kwargs)
self.model_config = model_config
self.embed_dim = self.model_config['embed_dim']
self.m_channels = self.model_config['channels']
self.feature_dim = self.model_config['fbank_dim']
self.sample_rate = self.model_config['sample_rate']
self.device = create_device(kwargs['device'])
self.encoder = ERes2Net(
embed_dim=self.embed_dim, m_channels=self.m_channels)
self.backend = LinearClassifier(
input_dim=self.embed_dim,
out_neurons=len(self.model_config['languages']))
pretrained_encoder = kwargs['pretrained_encoder']
pretrained_backend = kwargs['pretrained_backend']
self._load_check_point(pretrained_encoder, pretrained_backend)
self.encoder.to(self.device)
self.backend.to(self.device)
self.encoder.eval()
self.backend.eval()
def forward(self, audio):
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
assert len(audio.shape) == 2, \
'modelscope error: the shape of input audio to model needs to be [N, T]'
# audio shape: [N, T]
feature = self._extract_feature(audio)
embs = self.encoder(feature.to(self.device))
output = self.backend(embs)
output = output.detach().cpu().argmax(-1)
return output
def _extract_feature(self, audio):
features = []
for au in audio:
feature = Kaldi.fbank(
au.unsqueeze(0),
num_mel_bins=self.feature_dim,
sample_frequency=self.sample_rate)
feature = feature - feature.mean(dim=0, keepdim=True)
features.append(feature.unsqueeze(0))
features = torch.cat(features)
return features
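    # Note: Kaldi.fbank returns [num_frames, num_mel_bins] per utterance, so after
    # mean subtraction and stacking the batch the result is [N, num_frames, fbank_dim]
    # (the inputs in a batch are assumed to be of equal length).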
def _load_check_point(self, pretrained_encoder, pretrained_backend):
self.encoder.load_state_dict(
torch.load(
os.path.join(self.model_dir, pretrained_encoder),
map_location=torch.device('cpu')))
self.backend.load_state_dict(
torch.load(
os.path.join(self.model_dir, pretrained_backend),
map_location=torch.device('cpu')))

View File

@@ -61,6 +61,7 @@ class LanguageRecognitionCAMPPlus(TorchModel):
self.emb_size = self.model_config['emb_size']
self.feature_dim = self.model_config['fbank_dim']
self.sample_rate = self.model_config['sample_rate']
self.device = create_device(kwargs['device'])
self.encoder = CAMPPlus(self.feature_dim, self.emb_size)
@@ -96,7 +97,9 @@ class LanguageRecognitionCAMPPlus(TorchModel):
features = []
for au in audio:
feature = Kaldi.fbank(
au.unsqueeze(0), num_mel_bins=self.feature_dim)
au.unsqueeze(0),
num_mel_bins=self.feature_dim,
sample_frequency=self.sample_rate)
feature = feature - feature.mean(dim=0, keepdim=True)
features.append(feature.unsqueeze(0))
features = torch.cat(features)

View File

@@ -7,10 +7,11 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
crowd_counting, face_detection, face_generation,
face_reconstruction, human_reconstruction, image_classification,
image_color_enhance, image_colorization, image_defrcn_fewshot,
image_denoise, image_inpainting, image_instance_segmentation,
image_matching, image_mvs_depth_estimation,
image_panoptic_segmentation, image_portrait_enhancement,
image_probing_model, image_quality_assessment_degradation,
image_denoise, image_editing, image_inpainting,
image_instance_segmentation, image_matching,
image_mvs_depth_estimation, image_panoptic_segmentation,
image_portrait_enhancement, image_probing_model,
image_quality_assessment_degradation,
image_quality_assessment_man, image_quality_assessment_mos,
image_reid_person, image_restoration,
image_semantic_segmentation, image_to_image_generation,
@@ -21,10 +22,11 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
referring_video_object_segmentation,
robust_image_classification, salient_detection,
shop_segmentation, stream_yolo, super_resolution,
table_recognition, video_deinterlace, video_frame_interpolation,
video_object_segmentation, video_panoptic_segmentation,
video_single_object_tracking, video_stabilization,
video_summarization, video_super_resolution, vidt, virual_tryon,
vision_middleware, vop_retrieval)
surface_recon_common, table_recognition, video_deinterlace,
video_frame_interpolation, video_object_segmentation,
video_panoptic_segmentation, video_single_object_tracking,
video_stabilization, video_summarization,
video_super_resolution, vidt, virual_tryon, vision_middleware,
vop_retrieval)
# yapf: enable

View File

@@ -0,0 +1,22 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .human_image_generation_infer import FreqHPTForHumanImageGeneration
else:
_import_structure = {
'human_image_generation_infer': ['FreqHPTForHumanImageGeneration']
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,717 @@
import collections
import math
import sys
import torch
from pytorch_wavelets import DWTForward, DWTInverse
from torch import kl_div, nn
from torch.nn import functional as F
from modelscope.ops.human_image_generation.fused_act import (FusedLeakyReLU,
fused_leaky_relu)
from modelscope.ops.human_image_generation.upfirdn2d import upfirdn2d
from .conv2d_gradfix import conv2d, conv_transpose2d
from .wavelet_module import *
# add flow
class ExtractionOperation_flow(nn.Module):
def __init__(self, in_channel, num_label, match_kernel):
super(ExtractionOperation_flow, self).__init__()
self.value_conv = EqualConv2d(
in_channel,
in_channel,
match_kernel,
1,
match_kernel // 2,
bias=True)
self.semantic_extraction_filter = EqualConv2d(
in_channel,
num_label,
match_kernel,
1,
match_kernel // 2,
bias=False)
self.softmax = nn.Softmax(dim=-1)
self.num_label = num_label
def forward(self, value, recoder):
key = value
b, c, h, w = value.shape
key = self.semantic_extraction_filter(self.feature_norm(key))
extraction_softmax = self.softmax(key.view(b, -1, h * w))
values_flatten = self.value_conv(value).view(b, -1, h * w)
neural_textures = torch.einsum('bkm,bvm->bvk', extraction_softmax,
values_flatten)
recoder['extraction_softmax'].insert(0, extraction_softmax)
recoder['neural_textures'].insert(0, neural_textures)
return neural_textures, extraction_softmax
def feature_norm(self, input_tensor):
input_tensor = input_tensor - input_tensor.mean(dim=1, keepdim=True)
norm = torch.norm(
input_tensor, 2, 1, keepdim=True) + sys.float_info.epsilon
out = torch.div(input_tensor, norm)
return out
class DistributionOperation_flow(nn.Module):
def __init__(self, num_label, input_dim, match_kernel=3):
super(DistributionOperation_flow, self).__init__()
self.semantic_distribution_filter = EqualConv2d(
input_dim,
num_label,
kernel_size=match_kernel,
stride=1,
padding=match_kernel // 2)
self.num_label = num_label
def forward(self, query, extracted_feature, recoder):
b, c, h, w = query.shape
query = self.semantic_distribution_filter(query)
query_flatten = query.view(b, self.num_label, -1)
query_softmax = F.softmax(query_flatten, 1)
values_q = torch.einsum('bkm,bkv->bvm', query_softmax,
extracted_feature.permute(0, 2, 1))
attn_out = values_q.view(b, -1, h, w)
recoder['semantic_distribution'].append(query)
return attn_out
class EncoderLayer_flow(nn.Sequential):
def __init__(self,
in_channel,
out_channel,
kernel_size,
downsample=False,
blur_kernel=[1, 3, 3, 1],
bias=True,
activate=True,
use_extraction=False,
num_label=None,
match_kernel=None,
num_extractions=2):
super().__init__()
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
stride = 2
padding = 0
else:
self.blur = None
stride = 1
padding = kernel_size // 2
self.conv = EqualConv2d(
in_channel,
out_channel,
kernel_size,
padding=padding,
stride=stride,
bias=bias and not activate,
)
self.activate = FusedLeakyReLU(
out_channel, bias=bias) if activate else None
self.use_extraction = use_extraction
if self.use_extraction:
self.extraction_operations = nn.ModuleList()
for _ in range(num_extractions):
self.extraction_operations.append(
ExtractionOperation_flow(out_channel, num_label,
match_kernel))
def forward(self, input, recoder=None):
out = self.blur(input) if self.blur is not None else input
out = self.conv(out)
out = self.activate(out) if self.activate is not None else out
if self.use_extraction:
for extraction_operation in self.extraction_operations:
extraction_operation(out, recoder)
return out
class DecoderLayer_flow_wavelet_fuse24(nn.Module):
# add fft refinement and tps
def __init__(
self,
in_channel,
out_channel,
kernel_size,
upsample=False,
blur_kernel=[1, 3, 3, 1],
bias=True,
activate=True,
use_distribution=True,
num_label=16,
match_kernel=3,
wavelet_down_level=False,
window_size=8,
):
super().__init__()
if upsample:
factor = 2
p = (len(blur_kernel) - factor) - (kernel_size - 1)
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 + 1
self.blur = Blur(
blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
self.conv = EqualTransposeConv2d(
in_channel,
out_channel,
kernel_size,
stride=2,
padding=0,
bias=bias and not activate,
)
else:
self.conv = EqualConv2d(
in_channel,
out_channel,
kernel_size,
stride=1,
padding=kernel_size // 2,
bias=bias and not activate,
)
self.blur = None
self.distribution_operation = DistributionOperation_flow(
num_label, out_channel,
match_kernel=match_kernel) if use_distribution else None
self.activate = FusedLeakyReLU(
out_channel, bias=bias) if activate else None
self.use_distribution = use_distribution
# mask prediction network
if use_distribution:
self.conv_mask_lf = nn.Sequential(*[
EqualConv2d(
out_channel, 1, 3, stride=1, padding=3 // 2, bias=False),
nn.Sigmoid()
])
self.conv_mask_dict = nn.ModuleDict()
for level in range(wavelet_down_level):
conv_mask = nn.Sequential(*[
EqualConv2d(
out_channel,
1,
3,
stride=1,
padding=3 // 2,
bias=False),
nn.Sigmoid()
])
self.conv_mask_dict[str(level)] = conv_mask
self.wavelet_down_level = wavelet_down_level
if wavelet_down_level:
self.dwt = DWTForward(
J=self.wavelet_down_level, mode='zero', wave='haar')
self.idwt = DWTInverse(mode='zero', wave='haar')
# for mask input channel squeeze and expand
self.conv_l_squeeze = EqualConv2d(
2 * out_channel, out_channel, 1, 1, 0, bias=False)
self.conv_h_squeeze = EqualConv2d(
6 * out_channel, out_channel, 1, 1, 0, bias=False)
self.conv_l = EqualConv2d(
out_channel, out_channel, 3, 1, 3 // 2, bias=False)
self.hf_modules = nn.ModuleDict()
for level in range(wavelet_down_level):
hf_module = nn.Module()
prev_channel = out_channel if level == self.wavelet_down_level - 1 else 3 * out_channel
hf_module.conv_prev = EqualConv2d(
prev_channel, 3 * out_channel, 3, 1, 3 // 2, bias=False)
hf_module.conv_hf = GatedConv2dWithActivation(
3 * out_channel, 3 * out_channel, 3, 1, 3 // 2, bias=False)
hf_module.conv_out = GatedConv2dWithActivation(
3 * out_channel, 3 * out_channel, 3, 1, 3 // 2, bias=False)
self.hf_modules[str(level)] = hf_module
self.amp_fuse = nn.Sequential(
EqualConv2d(2 * out_channel, out_channel, 1, 1, 0),
FusedLeakyReLU(out_channel, bias=False),
EqualConv2d(out_channel, out_channel, 1, 1, 0))
self.pha_fuse = nn.Sequential(
EqualConv2d(2 * out_channel, out_channel, 1, 1, 0),
FusedLeakyReLU(out_channel, bias=False),
EqualConv2d(out_channel, out_channel, 1, 1, 0))
self.post = EqualConv2d(out_channel, out_channel, 1, 1, 0)
self.eps = 1e-8
def forward(self,
input,
neural_texture=None,
recoder=None,
warped_texture=None,
style_net=None,
gstyle=None):
out = self.conv(input)
out = self.blur(out) if self.blur is not None else out
mask_l, mask_h = None, None
out_attn = None
if self.use_distribution and neural_texture is not None:
out_ori = out
out_attn = self.distribution_operation(out, neural_texture,
recoder)
# wavelet fusion
if self.wavelet_down_level:
assert out.shape[2] % 2 == 0, \
f'out shape {out.shape} is not appropriate for processing'
b, c, h, w = out.shape
# wavelet decomposition
LF_attn, HF_attn = self.dwt(out_attn)
LF_warp, HF_warp = self.dwt(warped_texture)
LF_out, HF_out = self.dwt(out)
# generate mask
hf_dict = {}
l_mask_input = torch.cat([LF_attn, LF_warp], dim=1)
l_mask_input = self.conv_l_squeeze(l_mask_input)
l_mask_input = style_net(l_mask_input, gstyle)
ml = self.conv_mask_lf(l_mask_input)
mask_l = ml
for level in range(self.wavelet_down_level):
# level up, feature size down
scale = 2**(level + 1)
hfa = HF_attn[level].view(b, c * 3, h // scale, w // scale)
hfw = HF_warp[level].view(b, c * 3, h // scale, w // scale)
hfg = HF_out[level].view(b, c * 3, h // scale, w // scale)
h_mask_input = torch.cat([hfa, hfw], dim=1)
h_mask_input = self.conv_h_squeeze(h_mask_input)
h_mask_input = style_net(h_mask_input, gstyle)
mh = self.conv_mask_dict[str(level)](h_mask_input)
if level == 0:
mask_h = mh
# fuse high frequency
xh = (mh * hfa + (1 - mh) * hfw + hfg) / math.sqrt(2)
hf_dict[str(level)] = xh
temp_result = (1 - ml) * LF_warp + LF_out
out_l = (ml * LF_attn + temp_result) / math.sqrt(2)
out_h_list = []
for level in range(self.wavelet_down_level - 1, -1, -1):
xh = hf_dict[str(level)]
b, c, h, w = xh.shape
out_h_list.append(xh.view(b, c // 3, 3, h, w))
                out_h_list = out_h_list[::-1]  # reorder the HF list from largest to smallest scale
out = self.idwt((out_l, out_h_list))
else:
out = (out + out_attn) / math.sqrt(2)
# fourier refinement
_, _, H, W = out.shape
fuseF = torch.fft.rfft2(out + self.eps, norm='backward')
outF = torch.fft.rfft2(out_ori + self.eps, norm='backward')
amp = self.amp_fuse(
torch.cat([torch.abs(fuseF), torch.abs(outF)], 1))
pha = self.pha_fuse(
torch.cat(
[torch.angle(fuseF), torch.angle(outF)], 1))
out_fft = torch.fft.irfft2(
amp * torch.exp(1j * pha) + self.eps,
s=(H, W),
dim=(-2, -1),
norm='backward')
out = out + self.post(out_fft)
out = self.activate(
out.contiguous()) if self.activate is not None else out
return out, mask_h, mask_l
# base functions
class EqualConv2d(nn.Module):
def __init__(self,
in_channel,
out_channel,
kernel_size,
stride=1,
padding=0,
bias=True,
dilation=1):
super().__init__()
self.weight = nn.Parameter(
torch.randn(out_channel, in_channel, kernel_size, kernel_size))
self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
self.stride = stride
self.padding = padding
self.dilation = dilation
if bias:
self.bias = nn.Parameter(torch.zeros(out_channel))
else:
self.bias = None
def forward(self, input):
out = conv2d(
input,
self.weight * self.scale,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation)
return out
def __repr__(self):
return (
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
)
class EqualTransposeConv2d(nn.Module):
def __init__(self,
in_channel,
out_channel,
kernel_size,
stride=1,
padding=0,
bias=True):
super().__init__()
self.weight = nn.Parameter(
torch.randn(out_channel, in_channel, kernel_size, kernel_size))
self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
self.stride = stride
self.padding = padding
if bias:
self.bias = nn.Parameter(torch.zeros(out_channel))
else:
self.bias = None
def forward(self, input):
weight = self.weight.transpose(0, 1)
out = conv_transpose2d(
input,
weight * self.scale,
bias=self.bias,
stride=self.stride,
padding=self.padding,
)
return out
def __repr__(self):
return (
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
)
class ToRGB(nn.Module):
def __init__(self, in_channel, upsample=True, blur_kernel=[1, 3, 3, 1]):
super().__init__()
if upsample:
self.upsample = Upsample(blur_kernel)
self.conv = EqualConv2d(in_channel, 3, 3, stride=1, padding=1)
def forward(self, input, skip=None):
out = self.conv(input)
if skip is not None:
skip = self.upsample(skip)
out = out + skip
return out
class EqualLinear(nn.Module):
def __init__(self,
in_dim,
out_dim,
bias=True,
bias_init=0,
lr_mul=1,
activation=None):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.activation = activation
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
if self.activation:
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)
else:
out = F.linear(
input, self.weight * self.scale, bias=self.bias * self.lr_mul)
return out
def __repr__(self):
return (
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
)
class Upsample(nn.Module):
def __init__(self, kernel, factor=2):
super().__init__()
self.factor = factor
kernel = make_kernel(kernel) * (factor**2)
self.register_buffer('kernel', kernel)
p = kernel.shape[0] - factor
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2
self.pad = (pad0, pad1)
def forward(self, input):
out = upfirdn2d(
input, self.kernel, up=self.factor, down=1, pad=self.pad)
return out
class ResBlock(nn.Module):
def __init__(self,
in_channel,
out_channel,
blur_kernel=[1, 3, 3, 1],
downsample=True):
super().__init__()
self.conv1 = ConvLayer(in_channel, in_channel, 3)
self.conv2 = ConvLayer(
in_channel, out_channel, 3, downsample=downsample)
self.skip = ConvLayer(
in_channel,
out_channel,
1,
downsample=downsample,
activate=False,
bias=False)
def forward(self, input):
out = self.conv1(input)
out = self.conv2(out)
skip = self.skip(input)
out = (out + skip) / math.sqrt(2)
return out
class ConvLayer(nn.Sequential):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
downsample=False,
blur_kernel=[1, 3, 3, 1],
bias=True,
activate=True,
):
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
stride = 2
self.padding = 0
else:
stride = 1
self.padding = kernel_size // 2
layers.append(
EqualConv2d(
in_channel,
out_channel,
kernel_size,
padding=self.padding,
stride=stride,
bias=bias and not activate,
))
if activate:
layers.append(FusedLeakyReLU(out_channel, bias=bias))
super().__init__(*layers)
class Blur(nn.Module):
def __init__(self, kernel, pad, upsample_factor=1):
super().__init__()
kernel = make_kernel(kernel)
if upsample_factor > 1:
kernel = kernel * (upsample_factor**2)
self.register_buffer('kernel', kernel)
self.pad = pad
def forward(self, input):
out = upfirdn2d(input, self.kernel, pad=self.pad)
return out
class GatedConv2dWithActivation(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
bias=True,
activation=None):
super(GatedConv2dWithActivation, self).__init__()
self.activation = FusedLeakyReLU(out_channels, bias=False)
self.conv2d = EqualConv2d(in_channels, out_channels, kernel_size,
stride, padding, bias, dilation)
self.mask_conv2d = EqualConv2d(in_channels, out_channels, kernel_size,
stride, padding, bias, dilation)
self.sigmoid = nn.Sigmoid()
def gated(self, mask):
return self.sigmoid(mask)
def forward(self, input):
x = self.conv2d(input)
mask = self.mask_conv2d(input)
if self.activation is not None:
x = self.activation(x) * self.gated(mask)
else:
x = x * self.gated(mask)
return x
def make_kernel(k):
k = torch.tensor(k, dtype=torch.float32)
if k.ndim == 1:
k = k[None, :] * k[:, None]
k /= k.sum()
return k
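# For example, make_kernel([1, 3, 3, 1]) returns the 4x4 separable binomial blur
# kernel (the outer product of [1, 3, 3, 1] with itself, normalized to sum to 1).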
class SPDNorm(nn.Module):
def __init__(self,
norm_channel,
label_nc,
norm_type='position',
use_equal=False):
super().__init__()
param_free_norm_type = norm_type
ks = 3
if param_free_norm_type == 'instance':
self.param_free_norm = nn.InstanceNorm2d(
norm_channel, affine=False)
elif param_free_norm_type == 'batch':
self.param_free_norm = nn.BatchNorm2d(norm_channel, affine=False)
elif param_free_norm_type == 'position':
self.param_free_norm = PositionalNorm2d
else:
raise ValueError(
'%s is not a recognized param-free norm type in SPADE'
% param_free_norm_type)
# The dimension of the intermediate embedding space. Yes, hardcoded.
pw = ks // 2
nhidden = 128
if not use_equal:
self.mlp_activate = nn.Sequential(
nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
nn.ReLU())
self.mlp_gamma = nn.Conv2d(
nhidden, norm_channel, kernel_size=ks, padding=pw)
self.mlp_beta = nn.Conv2d(
nhidden, norm_channel, kernel_size=ks, padding=pw)
else:
self.mlp_activate = nn.Sequential(*[
EqualConv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
FusedLeakyReLU(nhidden, bias=False)
])
self.mlp_gamma = EqualConv2d(
nhidden, norm_channel, kernel_size=ks, padding=pw)
self.mlp_beta = EqualConv2d(
nhidden, norm_channel, kernel_size=ks, padding=pw)
def forward(self, x, prior_f, weight=1.0):
normalized = self.param_free_norm(x)
# Part 2. produce scaling and bias conditioned on condition feature
actv = self.mlp_activate(prior_f)
gamma = self.mlp_gamma(actv) * weight
beta = self.mlp_beta(actv) * weight
# apply scale and bias
out = normalized * (1 + gamma) + beta
return out
def PositionalNorm2d(x, epsilon=1e-5):
# x: B*C*W*H normalize in C dim
mean = x.mean(dim=1, keepdim=True)
std = x.var(dim=1, keepdim=True).add(epsilon).sqrt()
output = (x - mean) / std
return output

View File

@@ -0,0 +1,358 @@
import collections
import functools
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from .base_function import *
from .flow_module import MaskStyle, StyleFlow
from .tps import TPS
# adding flow version
class Encoder_wiflow(nn.Module):
def __init__(
self,
size,
input_dim,
channels,
num_labels=None,
match_kernels=None,
blur_kernel=[1, 3, 3, 1],
):
super().__init__()
self.first = EncoderLayer_flow(input_dim, channels[size], 1)
self.convs = nn.ModuleList()
self.num_labels = num_labels
self.match_kernels = match_kernels
log_size = int(math.log(size, 2))
self.log_size = log_size
in_channel = channels[size]
for i in range(log_size - 1, 3, -1):
out_channel = channels[2**i]
num_label = num_labels[2**i] if num_labels is not None else None
match_kernel = match_kernels[
2**i] if match_kernels is not None else None
use_extraction = num_label and match_kernel
conv = EncoderLayer_flow(
in_channel,
out_channel,
kernel_size=3,
downsample=True,
blur_kernel=blur_kernel,
use_extraction=use_extraction,
num_label=num_label,
match_kernel=match_kernel)
self.convs.append(conv)
in_channel = out_channel
def forward(self, input, recoder=None, out_list=None):
out = self.first(input)
for layer in self.convs:
out = layer(out, recoder)
if out_list is not None:
out_list.append(out)
return out
class Decoder_wiflow_wavelet_fuse25(nn.Module):
def __init__(
self,
size,
channels,
num_labels,
match_kernels,
blur_kernel=[1, 3, 3, 1],
wavelet_down_levels={'16': 3},
window_size=8,
):
super().__init__()
self.convs = nn.ModuleList()
# input at resolution 16*16
in_channel = channels[16]
self.log_size = int(math.log(size, 2))
self.conv_mask_dict = nn.ModuleDict()
self.conv_mask_fuse_dict = nn.ModuleDict()
flow_fusion = False
for i in range(4, self.log_size + 1):
out_channel = channels[2**i]
num_label, match_kernel = num_labels[2**i], match_kernels[2**i]
use_distribution = num_label and match_kernel
upsample = (i != 4)
wavelet_down_level = wavelet_down_levels[(2**i)]
base_layer = functools.partial(
DecoderLayer_flow_wavelet_fuse24,
out_channel=out_channel,
kernel_size=3,
blur_kernel=blur_kernel,
use_distribution=use_distribution,
num_label=num_label,
match_kernel=match_kernel,
wavelet_down_level=wavelet_down_level,
window_size=window_size)
# mask head for fusion
if use_distribution:
conv_mask = [
EqualConv2d(
2 * out_channel,
3,
3,
stride=1,
padding=3 // 2,
bias=False),
nn.Sigmoid()
]
conv_mask = nn.Sequential(*conv_mask)
self.conv_mask_dict[str(2**i)] = conv_mask
if not i == 4:
conv_mask_fuse = nn.Sequential(*[
EqualConv2d(
2, 1, 3, stride=1, padding=3 // 2, bias=False),
nn.Sigmoid()
])
self.conv_mask_fuse_dict[str(2**i)] = conv_mask_fuse
if not flow_fusion:
self.conv_flow_fusion = nn.Sequential(
EqualConv2d(
2 * out_channel,
1,
kernel_size=7,
stride=1,
padding=3,
bias=False), nn.Sigmoid())
flow_fusion = True
up = nn.Module()
up.conv0 = base_layer(in_channel=in_channel, upsample=upsample)
up.conv1 = base_layer(in_channel=out_channel, upsample=False)
up.to_rgb = ToRGB(out_channel, upsample=upsample)
self.convs.append(up)
in_channel = out_channel
style_in_channels = channels[16]
self.style_out_channel = 128
self.cond_style = nn.Sequential(
nn.Conv2d(
style_in_channels,
self.style_out_channel,
kernel_size=3,
stride=1,
padding=1), nn.LeakyReLU(inplace=False, negative_slope=0.1),
nn.AdaptiveAvgPool2d(1))
self.image_style = nn.Sequential(
nn.Conv2d(
style_in_channels,
self.style_out_channel,
kernel_size=3,
stride=1,
padding=1), nn.LeakyReLU(inplace=False, negative_slope=0.1),
nn.AdaptiveAvgPool2d(1))
self.flow_model = StyleFlow(
channels, self.log_size, style_in=2 * self.style_out_channel)
self.num_labels, self.match_kernels = num_labels, match_kernels
# for mask prediction
self.mask_style = MaskStyle(
channels,
self.log_size,
style_in=2 * self.style_out_channel,
channels_multiplier=1)
# tps transformation
self.tps = TPS()
def forward(self,
input,
neural_textures,
skeleton_features,
source_features,
kp_skeleton,
recoder,
add_nted=True):
source_features = source_features[::-1]
skeleton_features = skeleton_features[::-1]
counter = 0
out, skip = input, None
last_flow = None
mask_all_h, mask_all_l = [], []
delta_list = []
delta_x_all = []
delta_y_all = []
last_flow_all = []
filter_x = [[0, 0, 0], [1, -2, 1], [0, 0, 0]]
filter_y = [[0, 1, 0], [0, -2, 0], [0, 1, 0]]
filter_diag1 = [[1, 0, 0], [0, -2, 0], [0, 0, 1]]
filter_diag2 = [[0, 0, 1], [0, -2, 0], [1, 0, 0]]
weight_array = np.ones([3, 3, 1, 4])
weight_array[:, :, 0, 0] = filter_x
weight_array[:, :, 0, 1] = filter_y
weight_array[:, :, 0, 2] = filter_diag1
weight_array[:, :, 0, 3] = filter_diag2
weight_array = torch.FloatTensor(weight_array).permute(3, 2, 0, 1).to(
input.device)
self.weight = nn.Parameter(data=weight_array, requires_grad=False)
B = source_features[0].shape[0]
source_style = self.cond_style(source_features[0]).view(B, -1)
target_style = self.image_style(skeleton_features[0]).view(B, -1)
style = torch.cat([source_style, target_style], 1)
for i, up in enumerate(self.convs):
use_distribution = (
self.num_labels[2**(i + 4)] and self.match_kernels[2**(i + 4)])
if use_distribution:
# warp features with styleflow
source_feature = source_features[i]
skeleton_feature = skeleton_features[i]
if last_flow is not None:
last_flow = F.interpolate(
last_flow, scale_factor=2, mode='bilinear')
s_warp_after = F.grid_sample(
source_feature,
last_flow.detach().permute(0, 2, 3, 1),
mode='bilinear',
padding_mode='border')
else:
s_warp_after = source_feature
scale = str(2**(i + 4))
# use tps transformation to estimate flow at the very beginning
if last_flow is not None:
style_map = self.flow_model.netStyle[scale](s_warp_after,
style)
flow = self.flow_model.netF[scale](style_map, style)
flow = apply_offset(flow)
else:
style_map = self.flow_model.netStyle[scale](s_warp_after,
style)
flow = self.flow_model.netF[scale](style_map, style)
flow_dense = apply_offset(flow)
flow_tps = self.tps(source_feature, kp_skeleton)
warped_dense = F.grid_sample(
source_feature,
flow_dense,
mode='bilinear',
padding_mode='border')
warped_tps = F.grid_sample(
source_feature,
flow_tps,
mode='bilinear',
padding_mode='border')
contribution_map = self.conv_flow_fusion(
torch.cat([warped_dense, warped_tps], 1))
flow = contribution_map * flow_tps.permute(0, 3, 1, 2) + (
1 - contribution_map) * flow_dense.permute(0, 3, 1, 2)
flow = flow.permute(0, 2, 3, 1).contiguous()
if last_flow is not None:
# update flow according to the last scale flow
flow = F.grid_sample(
last_flow,
flow,
mode='bilinear',
padding_mode='border')
else:
flow = flow.permute(0, 3, 1, 2)
last_flow = flow
s_warp = F.grid_sample(
source_feature,
flow.permute(0, 2, 3, 1),
mode='bilinear',
padding_mode='border')
# refine flow according to the original flow
flow = self.flow_model.netRefine[scale](
torch.cat([s_warp, skeleton_feature], 1))
delta_list.append(flow)
flow = apply_offset(flow)
flow = F.grid_sample(
last_flow, flow, mode='bilinear', padding_mode='border')
last_flow_all.append(flow)
last_flow = flow
flow_x, flow_y = torch.split(last_flow, 1, dim=1)
delta_x = F.conv2d(flow_x, self.weight)
delta_y = F.conv2d(flow_y, self.weight)
delta_x_all.append(delta_x)
delta_y_all.append(delta_y)
s_warp = F.grid_sample(
source_feature,
last_flow.permute(0, 2, 3, 1),
mode='bilinear',
padding_mode='border')
# nted attention
neural_texture_conv0 = neural_textures[counter]
neural_texture_conv1 = neural_textures[counter + 1]
counter += 2
if not add_nted: # turn off the nted attention
neural_texture_conv0, neural_texture_conv1 = None, None
else:
neural_texture_conv0, neural_texture_conv1 = None, None
s_warp = None
mask_style_net = self.mask_style.netM[
scale] if use_distribution else None
out, mask_h, mask_l = up.conv0(
out,
neural_texture=neural_texture_conv0,
recoder=recoder,
warped_texture=s_warp,
style_net=mask_style_net,
gstyle=style)
out, mask_h, mask_l = up.conv1(
out,
neural_texture=neural_texture_conv1,
recoder=recoder,
warped_texture=s_warp,
style_net=mask_style_net,
gstyle=style)
if use_distribution:
if mask_h is not None:
mask_all_h.append(mask_h)
if mask_l is not None:
mask_all_l.append(mask_l)
skip = up.to_rgb(out, skip)
image = skip
return image, delta_x_all, delta_y_all, delta_list, last_flow_all, mask_all_h, mask_all_l
def apply_offset(offset):
sizes = list(offset.size()[2:])
grid_list = torch.meshgrid(
[torch.arange(size, device=offset.device) for size in sizes])
grid_list = reversed(grid_list)
# apply offset
grid_list = [
grid.float().unsqueeze(0) + offset[:, dim, ...]
for dim, grid in enumerate(grid_list)
]
# normalize
grid_list = [
grid / ((size - 1.0) / 2.0) - 1.0
for grid, size in zip(grid_list, reversed(sizes))
]
return torch.stack(grid_list, dim=-1)
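# Sketch (sizes are illustrative): apply_offset converts a dense offset field of
# shape [B, 2, H, W] into a sampling grid of shape [B, H, W, 2], normalized to
# the [-1, 1] convention expected by F.grid_sample.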

View File

@@ -0,0 +1,227 @@
import contextlib
import warnings
import torch
from torch import autograd
from torch.nn import functional as F
enabled = True
weight_gradients_disabled = False
@contextlib.contextmanager
def no_weight_gradients():
global weight_gradients_disabled
old = weight_gradients_disabled
weight_gradients_disabled = True
yield
weight_gradients_disabled = old
def conv2d(input,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1):
if could_use_op(input):
return conv2d_gradfix(
transpose=False,
weight_shape=weight.shape,
stride=stride,
padding=padding,
output_padding=0,
dilation=dilation,
groups=groups,
).apply(input, weight, bias)
return F.conv2d(
input=input,
weight=weight,
bias=bias,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
def conv_transpose2d(
input,
weight,
bias=None,
stride=1,
padding=0,
output_padding=0,
groups=1,
dilation=1,
):
if could_use_op(input):
return conv2d_gradfix(
transpose=True,
weight_shape=weight.shape,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
).apply(input, weight, bias)
return F.conv_transpose2d(
input=input,
weight=weight,
bias=bias,
stride=stride,
padding=padding,
output_padding=output_padding,
dilation=dilation,
groups=groups,
)
def could_use_op(input):
if (not enabled) or (not torch.backends.cudnn.enabled):
return False
if input.device.type != 'cuda':
return False
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.']):
return True
warnings.warn(
f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().'
)
return False
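# Note: the custom gradfix path is only taken for CUDA tensors on PyTorch 1.7/1.8
# with cuDNN enabled; on any other setup, conv2d() and conv_transpose2d() above
# fall back to the standard torch.nn.functional implementations.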
def ensure_tuple(xs, ndim):
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs, ) * ndim
return xs
conv2d_gradfix_cache = dict()
def conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding,
dilation, groups):
ndim = 2
weight_shape = tuple(weight_shape)
stride = ensure_tuple(stride, ndim)
padding = ensure_tuple(padding, ndim)
output_padding = ensure_tuple(output_padding, ndim)
dilation = ensure_tuple(dilation, ndim)
key = (transpose, weight_shape, stride, padding, output_padding, dilation,
groups)
if key in conv2d_gradfix_cache:
return conv2d_gradfix_cache[key]
common_kwargs = dict(
stride=stride, padding=padding, dilation=dilation, groups=groups)
    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0, 0]
        # solve the conv_transpose2d output-size relation for output_padding
        return [
            input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]
class Conv2d(autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias):
if not transpose:
out = F.conv2d(
input=input, weight=weight, bias=bias, **common_kwargs)
else:
out = F.conv_transpose2d(
input=input,
weight=weight,
bias=bias,
output_padding=output_padding,
**common_kwargs,
)
ctx.save_for_backward(input, weight)
return out
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input, grad_weight, grad_bias = None, None, None
if ctx.needs_input_grad[0]:
p = calc_output_padding(
input_shape=input.shape, output_shape=grad_output.shape)
grad_input = conv2d_gradfix(
transpose=(not transpose),
weight_shape=weight_shape,
output_padding=p,
**common_kwargs,
).apply(grad_output, weight, None)
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
grad_weight = Conv2dGradWeight.apply(grad_output, input)
if ctx.needs_input_grad[2]:
grad_bias = grad_output.sum((0, 2, 3))
return grad_input, grad_weight, grad_bias
class Conv2dGradWeight(autograd.Function):
@staticmethod
def forward(ctx, grad_output, input):
op = torch._C._jit_get_operation(
'aten::cudnn_convolution_backward_weight' if not transpose else
'aten::cudnn_convolution_transpose_backward_weight')
flags = [
torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic,
torch.backends.cudnn.allow_tf32,
]
grad_weight = op(
weight_shape,
grad_output,
input,
padding,
stride,
dilation,
groups,
*flags,
)
ctx.save_for_backward(grad_output, input)
return grad_weight
@staticmethod
def backward(ctx, grad_grad_weight):
grad_output, input = ctx.saved_tensors
grad_grad_output, grad_grad_input = None, None
if ctx.needs_input_grad[0]:
grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
if ctx.needs_input_grad[1]:
p = calc_output_padding(
input_shape=input.shape, output_shape=grad_output.shape)
grad_grad_input = conv2d_gradfix(
transpose=(not transpose),
weight_shape=weight_shape,
output_padding=p,
**common_kwargs,
).apply(grad_output, grad_grad_weight, None)
return grad_grad_output, grad_grad_input
conv2d_gradfix_cache[key] = Conv2d
return Conv2d

View File

@@ -0,0 +1,64 @@
import collections
import os
import sys
import torch
from torch import nn
from .base_module import *
sys.path.insert(0, os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
class Generator(nn.Module):
def __init__(
self,
size,
semantic_dim,
channels,
num_labels,
match_kernels,
blur_kernel=[1, 3, 3, 1],
wavelet_down_levels={'16': 3},
window_size=8,
):
super().__init__()
self.size = size
self.reference_encoder = Encoder_wiflow(size, 3, channels, num_labels,
match_kernels, blur_kernel)
self.skeleton_encoder = Encoder_wiflow(
size,
semantic_dim,
channels,
)
self.target_image_renderer = Decoder_wiflow_wavelet_fuse25(
size, channels, num_labels, match_kernels, blur_kernel,
wavelet_down_levels, window_size)
def _cal_temp(self, module):
return sum(p.numel() for p in module.parameters() if p.requires_grad)
def forward(self, source_image, skeleton, kp_skeleton):
output_dict = {}
recoder = collections.defaultdict(list)
skeleton_feature_list, source_feature_list = [], []
skeleton_feature = self.skeleton_encoder(
skeleton, out_list=skeleton_feature_list)
_ = self.reference_encoder(
source_image, recoder, out_list=source_feature_list)
neural_textures = recoder['neural_textures']
output_dict['fake_image'], delta_x_all, delta_y_all, delta_list, last_flow_all, mask_all_h, mask_all_l = \
self.target_image_renderer(skeleton_feature, neural_textures, skeleton_feature_list,
source_feature_list, kp_skeleton, recoder)
output_dict['info'] = recoder
output_dict['delta_x'] = delta_x_all
output_dict['delta_y'] = delta_y_all
output_dict['delta_list'] = delta_list
output_dict['last_flow_all'] = last_flow_all
output_dict['mask_all_h'] = mask_all_h
output_dict['mask_all_l'] = mask_all_l
return output_dict

View File

@@ -0,0 +1,346 @@
from math import sqrt
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_function import EqualConv2d, EqualLinear
def TVLoss(x):
tv_h = x[:, :, 1:, :] - x[:, :, :-1, :]
tv_w = x[:, :, :, 1:] - x[:, :, :, :-1]
return torch.mean(torch.abs(tv_h)) + torch.mean(torch.abs(tv_w))
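# TVLoss is an anisotropic (L1) total-variation penalty:
#   TV(x) = mean(|x[:, :, h+1, :] - x[:, :, h, :]|) + mean(|x[:, :, :, w+1] - x[:, :, :, w]|)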
class MaskStyle(nn.Module):
def __init__(self, channels, log_size, style_in, channels_multiplier=2):
super().__init__()
self.log_size = log_size
padding_type = 'zero'
actvn = 'lrelu'
normalize_mlp = False
modulated_conv = True
self.netM = nn.ModuleDict()
for i in range(4, self.log_size + 1):
out_channel = channels[2**i]
style_mask = StyledConvBlock(
channels_multiplier * out_channel,
channels_multiplier * out_channel,
latent_dim=style_in,
padding=padding_type,
actvn=actvn,
normalize_affine_output=normalize_mlp,
modulated_conv=modulated_conv)
scale = str(2**i)
self.netM[scale] = style_mask
class StyleFlow(nn.Module):
def __init__(self, channels, log_size, style_in):
super().__init__()
self.log_size = log_size
padding_type = 'zero'
actvn = 'lrelu'
normalize_mlp = False
modulated_conv = True
self.netRefine = nn.ModuleDict()
self.netStyle = nn.ModuleDict()
self.netF = nn.ModuleDict()
for i in range(4, self.log_size + 1):
out_channel = channels[2**i]
netRefine_layer = torch.nn.Sequential(
torch.nn.Conv2d(
2 * out_channel,
out_channels=128,
kernel_size=3,
stride=1,
padding=1),
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
torch.nn.Conv2d(
in_channels=128,
out_channels=64,
kernel_size=3,
stride=1,
padding=1),
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
torch.nn.Conv2d(
in_channels=64,
out_channels=32,
kernel_size=3,
stride=1,
padding=1),
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
torch.nn.Conv2d(
in_channels=32,
out_channels=2,
kernel_size=3,
stride=1,
padding=1))
style_block = StyledConvBlock(
out_channel,
49,
latent_dim=style_in,
padding=padding_type,
actvn=actvn,
normalize_affine_output=normalize_mlp,
modulated_conv=modulated_conv)
style_F_block = Styled_F_ConvBlock(
49,
2,
latent_dim=style_in,
padding=padding_type,
actvn=actvn,
normalize_affine_output=normalize_mlp,
modulated_conv=modulated_conv)
scale = str(2**i)
self.netRefine[scale] = (netRefine_layer)
self.netStyle[scale] = (style_block)
self.netF[scale] = (style_F_block)
class StyledConvBlock(nn.Module):
def __init__(self,
fin,
fout,
latent_dim=256,
padding='zero',
actvn='lrelu',
normalize_affine_output=False,
modulated_conv=False):
super(StyledConvBlock, self).__init__()
if not modulated_conv:
if padding == 'reflect':
padding_layer = nn.ReflectionPad2d
else:
padding_layer = nn.ZeroPad2d
if modulated_conv:
conv2d = ModulatedConv2d
else:
conv2d = EqualConv2d
if modulated_conv:
self.actvn_gain = sqrt(2)
else:
self.actvn_gain = 1.0
self.modulated_conv = modulated_conv
if actvn == 'relu':
activation = nn.ReLU(True)
else:
activation = nn.LeakyReLU(0.2, True)
if self.modulated_conv:
self.conv0 = conv2d(
fin,
fout,
kernel_size=3,
padding_type=padding,
upsample=False,
latent_dim=latent_dim,
normalize_mlp=normalize_affine_output)
else:
conv0 = conv2d(fin, fout, kernel_size=3)
seq0 = [padding_layer(1), conv0]
self.conv0 = nn.Sequential(*seq0)
self.actvn0 = activation
if self.modulated_conv:
self.conv1 = conv2d(
fout,
fout,
kernel_size=3,
padding_type=padding,
downsample=False,
latent_dim=latent_dim,
normalize_mlp=normalize_affine_output)
else:
conv1 = conv2d(fout, fout, kernel_size=3)
seq1 = [padding_layer(1), conv1]
self.conv1 = nn.Sequential(*seq1)
self.actvn1 = activation
def forward(self, input, latent=None):
if self.modulated_conv:
out = self.conv0(input, latent)
else:
out = self.conv0(input)
out = self.actvn0(out) * self.actvn_gain
if self.modulated_conv:
out = self.conv1(out, latent)
else:
out = self.conv1(out)
out = self.actvn1(out) * self.actvn_gain
return out
class Styled_F_ConvBlock(nn.Module):
def __init__(self,
fin,
fout,
latent_dim=256,
padding='zero',
actvn='lrelu',
normalize_affine_output=False,
modulated_conv=False):
super(Styled_F_ConvBlock, self).__init__()
if not modulated_conv:
if padding == 'reflect':
padding_layer = nn.ReflectionPad2d
else:
padding_layer = nn.ZeroPad2d
if modulated_conv:
conv2d = ModulatedConv2d
else:
conv2d = EqualConv2d
if modulated_conv:
self.actvn_gain = sqrt(2)
else:
self.actvn_gain = 1.0
self.modulated_conv = modulated_conv
if actvn == 'relu':
activation = nn.ReLU(True)
else:
activation = nn.LeakyReLU(0.2, True)
if self.modulated_conv:
self.conv0 = conv2d(
fin,
128,
kernel_size=3,
padding_type=padding,
upsample=False,
latent_dim=latent_dim,
normalize_mlp=normalize_affine_output)
else:
conv0 = conv2d(fin, 128, kernel_size=3)
seq0 = [padding_layer(1), conv0]
self.conv0 = nn.Sequential(*seq0)
self.actvn0 = activation
if self.modulated_conv:
self.conv1 = conv2d(
128,
fout,
kernel_size=3,
padding_type=padding,
downsample=False,
latent_dim=latent_dim,
normalize_mlp=normalize_affine_output)
else:
conv1 = conv2d(128, fout, kernel_size=3)
seq1 = [padding_layer(1), conv1]
self.conv1 = nn.Sequential(*seq1)
def forward(self, input, latent=None):
if self.modulated_conv:
out = self.conv0(input, latent)
else:
out = self.conv0(input)
out = self.actvn0(out) * self.actvn_gain
if self.modulated_conv:
out = self.conv1(out, latent)
else:
out = self.conv1(out)
return out
class ModulatedConv2d(nn.Module):
def __init__(self,
fin,
fout,
kernel_size,
padding_type='zero',
upsample=False,
downsample=False,
latent_dim=512,
normalize_mlp=False):
super(ModulatedConv2d, self).__init__()
self.in_channels = fin
self.out_channels = fout
self.kernel_size = kernel_size
padding_size = kernel_size // 2
if kernel_size == 1:
self.demudulate = False
else:
self.demudulate = True
self.weight = nn.Parameter(
torch.Tensor(fout, fin, kernel_size, kernel_size))
self.bias = nn.Parameter(torch.Tensor(1, fout, 1, 1))
if normalize_mlp:
self.mlp_class_std = nn.Sequential(
EqualLinear(latent_dim, fin), PixelNorm())
else:
self.mlp_class_std = EqualLinear(latent_dim, fin)
if padding_type == 'reflect':
self.padding = nn.ReflectionPad2d(padding_size)
else:
self.padding = nn.ZeroPad2d(padding_size)
self.weight.data.normal_()
self.bias.data.zero_()
def forward(self, input, latent):
fan_in = self.weight.data.size(1) * self.weight.data[0][0].numel()
weight = self.weight * sqrt(2 / fan_in)
weight = weight.view(1, self.out_channels, self.in_channels,
self.kernel_size, self.kernel_size)
s = self.mlp_class_std(latent).view(-1, 1, self.in_channels, 1, 1)
weight = s * weight
if self.demudulate:
d = torch.rsqrt((weight**2).sum(4).sum(3).sum(2) + 1e-5).view(
-1, self.out_channels, 1, 1, 1)
weight = (d * weight).view(-1, self.in_channels, self.kernel_size,
self.kernel_size)
else:
weight = weight.view(-1, self.in_channels, self.kernel_size,
self.kernel_size)
batch, _, height, width = input.shape
input = input.reshape(1, -1, height, width)
input = self.padding(input)
out = F.conv2d(
input, weight, groups=batch).view(batch, self.out_channels, height,
width) + self.bias
return out
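
For orientation, a minimal, self-contained sketch of the modulate/demodulate arithmetic that ModulatedConv2d.forward performs (the sizes and the `s` tensor are made-up stand-ins for the latent MLP output; this is an illustration, not the module itself):

import torch

# Illustration only (hypothetical sizes): one latent-derived scale per input
# channel rescales a shared kernel, each per-sample kernel is renormalized to
# roughly unit L2 norm, and the batch is folded into the channel dimension so
# a single grouped convolution handles all samples at once.
batch, fin, fout, k = 2, 8, 16, 3
x = torch.randn(batch, fin, 32, 32)
weight = torch.randn(fout, fin, k, k)                # shared kernel
s = torch.randn(batch, 1, fin, 1, 1)                 # stand-in for mlp_class_std(latent)
w = s * weight.unsqueeze(0)                          # (batch, fout, fin, k, k)
d = torch.rsqrt((w ** 2).sum(dim=(2, 3, 4)) + 1e-5)  # demodulation factor
w = (d.view(batch, fout, 1, 1, 1) * w).view(batch * fout, fin, k, k)
out = torch.nn.functional.conv2d(
    x.reshape(1, batch * fin, 32, 32), w, padding=1, groups=batch)
print(out.view(batch, fout, 32, 32).shape)           # torch.Size([2, 16, 32, 32])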

View File

@@ -0,0 +1,121 @@
import torch
import torch.nn.functional as F
from torch import nn
class TPS(nn.Module):
def __init__(self, mode='kp'):
super().__init__()
self.mode = mode
def trans(self, kp_1):
if self.mode == 'kp':
device = kp_1.device
kp_type = kp_1.type()
self.gs = kp_1.shape[1]
n = kp_1.shape[2]
K = torch.norm(
kp_1[:, :, :, None] - kp_1[:, :, None, :], dim=4, p=2)
K = K**2
K = K * torch.log(K + 1e-9)
one1 = torch.ones(self.bs, kp_1.shape[1], kp_1.shape[2],
1).to(device).type(kp_type)
kp_1p = torch.cat([kp_1, one1], 3)
zero = torch.zeros(self.bs, kp_1.shape[1], 3,
3).to(device).type(kp_type)
P = torch.cat([kp_1p, zero], 2)
L = torch.cat([K, kp_1p.permute(0, 1, 3, 2)], 2)
L = torch.cat([L, P], 3)
zero = torch.zeros(self.bs, kp_1.shape[1], 3,
2).to(device).type(kp_type)
kp_substitute = torch.zeros(kp_1.shape).to(device).type(kp_type)
Y = torch.cat([kp_substitute, zero], 2)
one = torch.eye(L.shape[2]).expand(
L.shape).to(device).type(kp_type) * 0.01
L = L + one
param = torch.matmul(torch.inverse(L), Y)
self.theta = param[:, :, n:, :].permute(0, 1, 3, 2)
self.control_points = kp_1
self.control_params = param[:, :, :n, :]
else:
raise Exception('Error TPS mode')
def transform_frame(self, frame):
grid = make_coordinate_grid(
frame.shape[2:], type=frame.type()).unsqueeze(0).to(frame.device)
grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
shape = [self.bs, frame.shape[2], frame.shape[3], 2]
if self.mode == 'kp':
shape.insert(1, self.gs)
grid = self.warp_coordinates(grid).view(*shape)
return grid
def warp_coordinates(self, coordinates):
theta = self.theta.type(coordinates.type()).to(coordinates.device)
control_points = self.control_points.type(coordinates.type()).to(
coordinates.device)
control_params = self.control_params.type(coordinates.type()).to(
coordinates.device)
if self.mode == 'kp':
transformed = torch.matmul(theta[:, :, :, :2],
coordinates.permute(
0, 2, 1)) + theta[:, :, :, 2:]
distances = coordinates.view(
coordinates.shape[0], 1, 1, -1, 2) - control_points.view(
self.bs, control_points.shape[1], -1, 1, 2)
distances = distances**2
result = distances.sum(-1)
result = result * torch.log(result + 1e-9)
result = torch.matmul(result.permute(0, 1, 3, 2), control_params)
transformed = transformed.permute(0, 1, 3, 2) + result
else:
raise Exception('Error TPS mode')
return transformed
def preprocess_kp(self, kp_1):
'''
kp_1: (b, ntps*nkp, 2)
'''
kp_mask = (kp_1 == -1)
num_keypoints = kp_1.shape[1]
kp_1 = kp_1.masked_fill(kp_mask, -1.)
return kp_1, num_keypoints
def forward(self, source_image, kp_driving):
bs, _, h, w = source_image.shape
self.bs = bs
kp_driving, num_keypoints = self.preprocess_kp(kp_driving)
kp_1 = kp_driving.view(bs, -1, num_keypoints, 2)
self.trans(kp_1)
grid = self.transform_frame(source_image)
grid = grid.view(bs, h, w, 2)
return grid
def make_coordinate_grid(spatial_size, type):
"""
Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
"""
h, w = spatial_size
x = torch.arange(w).type(type)
y = torch.arange(h).type(type)
x = (2 * (x / (w - 1)) - 1)
y = (2 * (y / (h - 1)) - 1)
yy = y.view(-1, 1).repeat(1, w)
xx = x.view(1, -1).repeat(h, 1)
meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
return meshed
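
As a quick reference, the snippet below restates make_coordinate_grid inline on toy sizes to show the (x, y) ordering and the [-1, 1] normalization it produces (values chosen only for illustration):

import torch

# Toy sizes; body restated inline so the check runs on its own.
h, w = 2, 3
x = 2 * (torch.arange(w).float() / (w - 1)) - 1   # tensor([-1., 0., 1.])
y = 2 * (torch.arange(h).float() / (h - 1)) - 1   # tensor([-1., 1.])
xx = x.view(1, -1).repeat(h, 1)                   # x varies along the width
yy = y.view(-1, 1).repeat(1, w)                   # y varies along the height
grid = torch.stack([xx, yy], dim=2)               # (h, w, 2), last dim = (x, y)
print(grid[0, 0], grid[-1, -1])                   # tensor([-1., -1.]) tensor([1., 1.])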

View File

@@ -0,0 +1,182 @@
import numpy as np
import torch
import torch.nn as nn
def get_wav(in_channels, pool=True):
harr_wav_L = 1 / np.sqrt(2) * np.ones((1, 2))
harr_wav_H = 1 / np.sqrt(2) * np.ones((1, 2))
harr_wav_H[0, 0] = -1 * harr_wav_H[0, 0]
harr_wav_LL = np.transpose(harr_wav_L) * harr_wav_L
harr_wav_LH = np.transpose(harr_wav_L) * harr_wav_H
harr_wav_HL = np.transpose(harr_wav_H) * harr_wav_L
harr_wav_HH = np.transpose(harr_wav_H) * harr_wav_H
filter_LL = torch.from_numpy(harr_wav_LL).unsqueeze(0)
filter_LH = torch.from_numpy(harr_wav_LH).unsqueeze(0)
filter_HL = torch.from_numpy(harr_wav_HL).unsqueeze(0)
filter_HH = torch.from_numpy(harr_wav_HH).unsqueeze(0)
if pool:
net = nn.Conv2d
else:
net = nn.ConvTranspose2d
LL = net(
in_channels,
in_channels * 2,
kernel_size=2,
stride=2,
padding=0,
bias=False,
groups=in_channels)
LH = net(
in_channels,
in_channels * 2,
kernel_size=2,
stride=2,
padding=0,
bias=False,
groups=in_channels)
HL = net(
in_channels,
in_channels * 2,
kernel_size=2,
stride=2,
padding=0,
bias=False,
groups=in_channels)
HH = net(
in_channels,
in_channels * 2,
kernel_size=2,
stride=2,
padding=0,
bias=False,
groups=in_channels)
LL.weight.requires_grad = False
LH.weight.requires_grad = False
HL.weight.requires_grad = False
HH.weight.requires_grad = False
LL.weight.data = filter_LL.float().unsqueeze(0).expand(
in_channels * 2, -1, -1, -1)
LH.weight.data = filter_LH.float().unsqueeze(0).expand(
in_channels * 2, -1, -1, -1)
HL.weight.data = filter_HL.float().unsqueeze(0).expand(
in_channels * 2, -1, -1, -1)
HH.weight.data = filter_HH.float().unsqueeze(0).expand(
in_channels * 2, -1, -1, -1)
return LL, LH, HL, HH
class WavePool(nn.Module):
def __init__(self, in_channels):
super(WavePool, self).__init__()
self.LL, self.LH, self.HL, self.HH = get_wav(in_channels)
def forward(self, x):
return self.LL(x), self.LH(x), self.HL(x), self.HH(x)
def get_wav_two(in_channels, out_channels=None, pool=True):
"""wavelet decomposition using conv2d"""
harr_wav_L = 1 / np.sqrt(2) * np.ones((1, 2))
harr_wav_H = 1 / np.sqrt(2) * np.ones((1, 2))
harr_wav_H[0, 0] = -1 * harr_wav_H[0, 0]
harr_wav_LL = np.transpose(harr_wav_L) * harr_wav_L
harr_wav_LH = np.transpose(harr_wav_L) * harr_wav_H
harr_wav_HL = np.transpose(harr_wav_H) * harr_wav_L
harr_wav_HH = np.transpose(harr_wav_H) * harr_wav_H
filter_LL = torch.from_numpy(harr_wav_LL).unsqueeze(0)
filter_LH = torch.from_numpy(harr_wav_LH).unsqueeze(0)
filter_HL = torch.from_numpy(harr_wav_HL).unsqueeze(0)
filter_HH = torch.from_numpy(harr_wav_HH).unsqueeze(0)
if pool:
net = nn.Conv2d
else:
net = nn.ConvTranspose2d
if out_channels is None:
out_channels = in_channels
LL = net(
in_channels,
out_channels,
kernel_size=2,
stride=2,
padding=0,
bias=False,
groups=in_channels)
LH = net(
in_channels,
out_channels,
kernel_size=2,
stride=2,
padding=0,
bias=False,
groups=in_channels)
HL = net(
in_channels,
out_channels,
kernel_size=2,
stride=2,
padding=0,
bias=False,
groups=in_channels)
HH = net(
in_channels,
out_channels,
kernel_size=2,
stride=2,
padding=0,
bias=False,
groups=in_channels)
LL.weight.requires_grad = False
LH.weight.requires_grad = False
HL.weight.requires_grad = False
HH.weight.requires_grad = False
LL.weight.data = filter_LL.float().unsqueeze(0).expand(
in_channels, -1, -1, -1)
LH.weight.data = filter_LH.float().unsqueeze(0).expand(
in_channels, -1, -1, -1)
HL.weight.data = filter_HL.float().unsqueeze(0).expand(
in_channels, -1, -1, -1)
HH.weight.data = filter_HH.float().unsqueeze(0).expand(
in_channels, -1, -1, -1)
return LL, LH, HL, HH
class WavePool2(nn.Module):
def __init__(self, in_channels, out_channels=None):
super(WavePool2, self).__init__()
self.LL, self.LH, self.HL, self.HH = get_wav_two(
in_channels, out_channels)
def forward(self, x):
return self.LL(x), self.LH(x), self.HL(x), self.HH(x)
class WaveUnpool(nn.Module):
def __init__(self, in_channels, out_channels=None, option_unpool='cat5'):
super(WaveUnpool, self).__init__()
self.in_channels = in_channels
self.option_unpool = option_unpool
self.LL, self.LH, self.HL, self.HH = get_wav_two(
self.in_channels, out_channels, pool=False)
def forward(self, LL, LH, HL, HH, original=None):
if self.option_unpool == 'sum':
return self.LL(LL) + self.LH(LH) + self.HL(HL) + self.HH(HH)
elif self.option_unpool == 'cat5' and original is not None:
return torch.cat(
[self.LL(LL),
self.LH(LH),
self.HL(HL),
self.HH(HH), original],
dim=1)
else:
raise NotImplementedError
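
A minimal sanity check of the Haar filters assembled by get_wav_two, run with plain conv2d on a toy single-channel input (sizes are arbitrary; this only illustrates that LL is the normalized 2x2 block sum and that every band halves the spatial resolution):

import torch
import torch.nn.functional as F

# Same outer-product construction as harr_wav_LL / harr_wav_HH above,
# single channel, toy 4x4 input.
L = torch.tensor([[1.0, 1.0]]) / 2 ** 0.5
H = torch.tensor([[-1.0, 1.0]]) / 2 ** 0.5
ll = (L.t() @ L).view(1, 1, 2, 2)
hh = (H.t() @ H).view(1, 1, 2, 2)
x = torch.arange(16.0).view(1, 1, 4, 4)
print(F.conv2d(x, ll, stride=2))   # (block sum)/2 per 2x2 block: [[5., 9.], [21., 25.]]
print(F.conv2d(x, hh, stride=2))   # all zeros: the ramp input has no diagonal detail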

View File

@@ -0,0 +1,268 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import math
import random
from ast import Global
from pickle import GLOBAL
import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from PIL import Image
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .generators.extraction_distribution_model_flow25 import \
Generator as Generator
tv_version = int(torchvision.__version__.split('.')[1])
if tv_version > 8:
from torchvision.transforms.functional import InterpolationMode
resize_method = InterpolationMode.BICUBIC
resize_nearest = InterpolationMode.NEAREST
else:
resize_method = Image.BICUBIC
resize_nearest = Image.NEAREST
logger = get_logger()
def get_random_params(size, scale_param, use_flip=False):
w, h = size
scale = random.random() * scale_param
if use_flip:
use_flip = random.random() > 0.9
new_w = int(w * (1.0 + scale))
new_h = int(h * (1.0 + scale))
x = random.randint(0, np.maximum(0, new_w - w))
y = random.randint(0, np.maximum(0, new_h - h))
return {
'crop_param': (x, y, w, h),
'scale_size': (new_h, new_w),
'use_flip': use_flip
}
def get_transform(param, method=resize_method, normalize=True, toTensor=True):
transform_list = []
if 'scale_size' in param and param['scale_size'] is not None:
osize = param['scale_size']
transform_list.append(transforms.Resize(osize, interpolation=method))
if 'crop_param' in param and param['crop_param'] is not None:
transform_list.append(
transforms.Lambda(lambda img: __crop(img, param['crop_param'])))
if param['use_flip']:
transform_list.append(transforms.Lambda(lambda img: __flip(img)))
if toTensor:
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
return transforms.Compose(transform_list)
def __crop(img, pos):
x1, y1, tw, th = pos
return img.crop((x1, y1, x1 + tw, y1 + th))
def __flip(img):
return F.hflip(img)
def normalize():
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
def load_checkpoint(model, checkpoint_path, device):
params = torch.load(checkpoint_path, map_location=device)
if 'target_image_renderer.weight' in params['net_G_ema'].keys():
params['net_G_ema'].pop('target_image_renderer.weight')
model.load_state_dict(params['net_G_ema'])
model.to(device)
model.eval()
return model
@MODELS.register_module(
Tasks.human_image_generation, module_name=Models.human_image_generation)
class FreqHPTForHumanImageGeneration(TorchModel):
"""initialize the human image generation model from the `model_dir` path.
Args:
model_dir (str): the model path.
"""
def __init__(self, model_dir, device_id=0, *args, **kwargs):
super().__init__(
model_dir=model_dir, device_id=device_id, *args, **kwargs)
if torch.cuda.is_available():
self.device = 'cuda'
logger.info('Use GPU')
else:
self.device = 'cpu'
logger.info('Use CPU')
size = 512
semantic_dim = 20
channels = {
16: 256,
32: 256,
64: 256,
128: 128,
256: 128,
512: 64,
1024: 32
}
num_labels = {16: 16, 32: 32, 64: 64, 128: 64, 256: 64, 512: False}
match_kernels = {16: False, 32: 3, 64: 3, 128: 3, 256: 3, 512: False}
wavelet_down_levels = {16: False, 32: 1, 64: 2, 128: 3, 256: 3, 512: 3}
self.model = Generator(
size,
semantic_dim,
channels,
num_labels,
match_kernels,
wavelet_down_levels=wavelet_down_levels)
self.model = load_checkpoint(
self.model, model_dir + '/' + ModelFile.TORCH_MODEL_BIN_FILE,
self.device)
def forward(self, x, y, z):
pred_result = self.model(x, y, z)
return pred_result
def trans_keypoins(keypoints, param, img_size, offset=None):
missing_keypoint_index = keypoints == -1
# crop the white line in the original dataset
if not offset == 40:
keypoints[:, 0] = (keypoints[:, 0] - 40)
# resize the dataset
img_h, img_w = img_size
scale_w = 1.0 / 176.0 * img_w
scale_h = 1.0 / 256.0 * img_h
if 'scale_size' in param and param['scale_size'] is not None:
new_h, new_w = param['scale_size']
scale_w = scale_w / img_w * new_w
scale_h = scale_h / img_h * new_h
if 'crop_param' in param and param['crop_param'] is not None:
w, h, _, _ = param['crop_param']
else:
w, h = 0, 0
keypoints[:, 0] = keypoints[:, 0] * scale_w - w
keypoints[:, 1] = keypoints[:, 1] * scale_h - h
normalized_kp = keypoints.copy()
normalized_kp[:, 0] = (normalized_kp[:, 0]) / img_w * 2 - 1
normalized_kp[:, 1] = (normalized_kp[:, 1]) / img_h * 2 - 1
normalized_kp[missing_keypoint_index] = -1
keypoints[missing_keypoint_index] = -1
return keypoints, normalized_kp
def get_label_tensor(path, img, param):
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15],
[15, 17], [1, 16], [16, 18], [3, 17], [6, 18]]
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
[170, 255, 0], [85, 255, 0], [0, 255, 0], [0, 255, 85],
[0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255],
[0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255],
[255, 0, 170], [255, 0, 85]]
canvas = np.zeros((img.shape[1], img.shape[2], 3)).astype(np.uint8)
keypoint = np.loadtxt(path)
keypoint, normalized_kp = trans_keypoins(keypoint, param, img.shape[1:])
stickwidth = 4
for i in range(18):
x, y = keypoint[i, 0:2]
if x == -1 or y == -1:
continue
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
joints = []
for i in range(17):
Y = keypoint[np.array(limbSeq[i]) - 1, 0]
X = keypoint[np.array(limbSeq[i]) - 1, 1]
cur_canvas = canvas.copy()
if -1 in Y or -1 in X:
joints.append(np.zeros_like(cur_canvas[:, :, 0]))
continue
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly(
(int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0,
360, 1)
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
joint = np.zeros_like(cur_canvas[:, :, 0])
cv2.fillConvexPoly(joint, polygon, 255)
joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
joints.append(joint)
pose = F.to_tensor(
Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)))
tensors_dist = 0
e = 1
for i in range(len(joints)):
im_dist = cv2.distanceTransform(255 - joints[i], cv2.DIST_L1, 3)
im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
tensor_dist = F.to_tensor(Image.fromarray(im_dist))
tensors_dist = tensor_dist if e == 1 else torch.cat(
[tensors_dist, tensor_dist])
e += 1
label_tensor = torch.cat((pose, tensors_dist), dim=0)
return label_tensor, normalized_kp
def get_image_tensor(path):
img = Image.open(path)
param = get_random_params(img.size, 0)
trans = get_transform(param, normalize=True, toTensor=True)
img = trans(img)
return img, param
def infer(genmodel, image_path, target_label_path, device):
ref_tensor, param = get_image_tensor(image_path)
target_label_tensor, target_kp = get_label_tensor(target_label_path,
ref_tensor, param)
ref_tensor = ref_tensor.unsqueeze(0).to(device)
target_label_tensor = target_label_tensor.unsqueeze(0).to(device)
target_kp = torch.from_numpy(target_kp).unsqueeze(0).to(device)
output_dict = genmodel(ref_tensor, target_label_tensor, target_kp)
output_image = output_dict['fake_image'][0]
output_image = output_image.clamp_(-1, 1)
image = (output_image + 1) * 0.5
image = image.detach().cpu().squeeze().numpy()
image = np.transpose(image, (1, 2, 0)) * 255
image = np.uint8(image)
bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
return bgr
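
A hedged usage sketch of the preprocessing helpers above, assuming get_random_params and get_transform are in scope: with scale_param=0 the pipeline reduces to a no-op crop plus ToTensor and Normalize, mapping pixel values into [-1, 1] (the all-white input is only for illustration):

import numpy as np
from PIL import Image

img = Image.fromarray(np.full((256, 176, 3), 255, dtype=np.uint8))  # H=256, W=176
param = get_random_params(img.size, 0)        # scale_param=0: no scaling, no flip
tensor = get_transform(param)(img)
print(tensor.shape, tensor.min().item())      # torch.Size([3, 256, 176]) 1.0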

View File

@@ -158,7 +158,7 @@ class Encoder(nn.Module):
return hooks
def forward(self, img):
return self.arch(img)
return self.arch.forward_features(img)
class MultiScaleColorDecoder(nn.Module):

View File

@@ -119,8 +119,8 @@ class ConvNeXt(nn.Module):
self.head_cls = nn.Linear(dims[-1], 4)
self.apply(self._init_weights)
self.head_cls.weight.data.mul_(head_init_scale)
self.head_cls.bias.data.mul_(head_init_scale)
# self.head_cls.weight.data.mul_(head_init_scale)
# self.head_cls.bias.data.mul_(head_init_scale)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):

View File

@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .masactrl import MutualSelfAttentionControl
from .masactrl_utils import regiter_attention_editor_diffusers
else:
_import_structure = {
'masactrl': ['MutualSelfAttentionControl'],
'masactrl_utils': ['regiter_attention_editor_diffusers']
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,77 @@
# ------------------------------------------------------------------------
# Modified from https://github.com/TencentARC/MasaCtrl/blob/main/masactrl/masactrl.py
# Copyright (c) 2023 TencentARC. All Rights Reserved.
# ------------------------------------------------------------------------
import torch
from einops import rearrange
from .masactrl_utils import AttentionBase
class MutualSelfAttentionControl(AttentionBase):
def __init__(self,
start_step=4,
start_layer=10,
layer_idx=None,
step_idx=None,
total_steps=50):
"""
Mutual self-attention control for Stable-Diffusion model
Args:
start_step: the step to start mutual self-attention control
start_layer: the layer to start mutual self-attention control
layer_idx: list of the layers to apply mutual self-attention control
            step_idx: list of the steps to apply mutual self-attention control
total_steps: the total number of steps
"""
super().__init__()
self.total_steps = total_steps
self.start_step = start_step
self.start_layer = start_layer
self.layer_idx = layer_idx if layer_idx is not None else list(
range(start_layer, 16))
self.step_idx = step_idx if step_idx is not None else list(
range(start_step, total_steps)) # denoise index
print('step_idx: ', self.step_idx)
print('layer_idx: ', self.layer_idx)
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet,
num_heads, **kwargs):
b = q.shape[0] // num_heads
q = rearrange(q, '(b h) n d -> h (b n) d', h=num_heads)
k = rearrange(k, '(b h) n d -> h (b n) d', h=num_heads)
v = rearrange(v, '(b h) n d -> h (b n) d', h=num_heads)
sim = torch.einsum('h i d, h j d -> h i j', q, k) * kwargs.get('scale')
attn = sim.softmax(-1)
out = torch.einsum('h i j, h j d -> h i d', attn, v)
out = rearrange(out, 'h (b n) d -> b n (h d)', b=b)
return out
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,
**kwargs):
"""
Attention forward function
"""
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet,
num_heads, **kwargs)
qu, qc = q.chunk(2) # uncond, cond
ku, kc = k.chunk(2)
vu, vc = v.chunk(2)
attnu, attnc = attn.chunk(2)
# uncond
# ku[:num_heads], vu[:num_heads] -> source
# qu -> [source, target]
out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads],
sim[:num_heads], attnu, is_cross,
place_in_unet, num_heads, **kwargs)
out_c = self.attn_batch(qc, kc[:num_heads], vc[:num_heads],
sim[:num_heads], attnc, is_cross,
place_in_unet, num_heads, **kwargs)
out = torch.cat([out_u, out_c], dim=0)
return out
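
A small illustrative instantiation, assuming the class above is in scope, showing the default control schedule (the step and layer indices are those printed by the constructor):

# Step/layer indices here match what the constructor prints.
editor = MutualSelfAttentionControl(start_step=4, start_layer=10, total_steps=50)
assert editor.step_idx == list(range(4, 50))          # control from step 4 onward
assert editor.layer_idx == [10, 11, 12, 13, 14, 15]   # only the deepest attention layers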

View File

@@ -0,0 +1,124 @@
# ------------------------------------------------------------------------
# Modified from https://github.com/TencentARC/MasaCtrl/blob/main/masactrl/masactrl_utils.py
# Copyright (c) 2023 TencentARC. All Rights Reserved.
# ------------------------------------------------------------------------
import torch
import torch.nn as nn
from einops import rearrange, repeat
class AttentionBase:
def __init__(self):
self.cur_step = 0
self.num_att_layers = -1
self.cur_att_layer = 0
def after_step(self):
pass
def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,
**kwargs):
out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet,
num_heads, **kwargs)
self.cur_att_layer += 1
if self.cur_att_layer == self.num_att_layers:
self.cur_att_layer = 0
self.cur_step += 1
# after step
self.after_step()
return out
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,
**kwargs):
out = torch.einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
return out
def reset(self):
self.cur_step = 0
self.cur_att_layer = 0
def regiter_attention_editor_diffusers(model, editor: AttentionBase):
"""
    Register an attention editor to a Diffusers pipeline; adapted from [Prompt-to-Prompt]
"""
def ca_forward(self, place_in_unet):
def forward(x,
encoder_hidden_states=None,
attention_mask=None,
context=None,
mask=None):
"""
            The attention computation follows the original LDM CrossAttention
            implementation, with some modifications applied to the attention map.
"""
if encoder_hidden_states is not None:
context = encoder_hidden_states
if attention_mask is not None:
mask = attention_mask
to_out = self.to_out
if isinstance(to_out, nn.modules.container.ModuleList):
to_out = self.to_out[0]
else:
to_out = self.to_out
h = self.heads
q = self.to_q(x)
is_cross = context is not None
context = context if is_cross else x
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
if mask is not None:
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
mask = mask[:, None, :].repeat(h, 1, 1)
sim.masked_fill_(~mask, max_neg_value)
attn = sim.softmax(dim=-1)
# the only difference
out = editor(
q,
k,
v,
sim,
attn,
is_cross,
place_in_unet,
self.heads,
scale=self.scale)
return to_out(out)
return forward
def register_editor(net, count, place_in_unet):
for name, subnet in net.named_children():
if net.__class__.__name__ == 'Attention': # spatial Transformer layer
net.forward = ca_forward(net, place_in_unet)
return count + 1
elif hasattr(net, 'children'):
count = register_editor(subnet, count, place_in_unet)
return count
cross_att_count = 0
for net_name, net in model.unet.named_children():
if 'down' in net_name:
cross_att_count += register_editor(net, 0, 'down')
elif 'mid' in net_name:
cross_att_count += register_editor(net, 0, 'mid')
elif 'up' in net_name:
cross_att_count += register_editor(net, 0, 'up')
editor.num_att_layers = cross_att_count
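
A hedged usage sketch, assuming a compatible diffusers version is installed; the model id, dtype, and prompts are illustrative only, and the misspelled function name is kept exactly as defined above:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative settings; whether the patched forward signature matches
# depends on the installed diffusers version.
pipe = StableDiffusionPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to('cuda')
editor = MutualSelfAttentionControl(start_step=4, start_layer=10)
regiter_attention_editor_diffusers(pipe, editor)  # function name kept as defined above
# source prompt first, edited/target prompt second, as MasaCtrl expects
images = pipe(['a photo of a cat', 'a photo of a cat, side view']).images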

View File

@@ -91,7 +91,7 @@ def infer(ourgen_model, model_path, person_img, garment_img, mask_img, device):
cm_array = (cm_array >= 128).astype(np.float32)
cm = torch.from_numpy(cm_array)
cm = cm.unsqueeze(0).unsqueeze(0)
cm = torch.FloatTensor((cm.numpy() > 0.5).astype(np.float)).to(device)
cm = torch.FloatTensor((cm.numpy() > 0.5).astype(float)).to(device)
im = person_img
h_ori, w_ori = im.shape[0:2]

View File

@@ -65,6 +65,7 @@ class MTTR(nn.Module):
# keep only the valid frames (frames which are annotated):
# (for example, in a2d-sentences only the center frame in each window is annotated).
for layer_out in backbone_out:
valid_indices = valid_indices.to(layer_out.tensors.device)
layer_out.tensors = layer_out.tensors.index_select(
0, valid_indices)
layer_out.mask = layer_out.mask.index_select(0, valid_indices)

View File

@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .surface_recon_common import SurfaceReconCommon
else:
_import_structure = {'surface_recon_common': ['SurfaceReconCommon']}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,289 @@
# ------------------------------------------------------------------------
# Modified from https://github.com/Totoro97/NeuS/blob/main/models/dataset.py
# Copyright (c) 2021 Peng Wang. All Rights Reserved.
# ------------------------------------------------------------------------
import os
from glob import glob
import cv2 as cv
import numpy as np
import torch
from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp
def load_K_Rt_from_P(filename, P=None):
if P is None:
lines = open(filename).read().splitlines()
if len(lines) == 4:
lines = lines[1:]
lines = [[x[0], x[1], x[2], x[3]]
for x in (x.split(' ') for x in lines)]
P = np.asarray(lines).astype(np.float32).squeeze()
out = cv.decomposeProjectionMatrix(P)
K = out[0]
R = out[1]
t = out[2]
K = K / K[2, 2]
intrinsics = np.eye(4)
intrinsics[:3, :3] = K
pose = np.eye(4, dtype=np.float32)
pose[:3, :3] = R.transpose()
pose[:3, 3] = (t[:3] / t[3])[:, 0]
return intrinsics, pose
class Dataset:
def __init__(self, data_dir, device):
super(Dataset, self).__init__()
print('Load data: Begin')
self.device = device
self.data_dir = data_dir
print('data_dir: ', self.data_dir)
camera_dict = np.load(
os.path.join(self.data_dir, 'cameras_sphere.npz'))
self.camera_dict = camera_dict
self.images_lis = sorted(
glob(os.path.join(self.data_dir, 'image/*.png')))
self.n_images = len(self.images_lis)
print('found %d images' % self.n_images)
self.world_mats_np = [
camera_dict['world_mat_%d' % idx].astype(np.float32)
for idx in range(self.n_images)
]
self.scale_mats_np = [
camera_dict['scale_mat_%d' % idx].astype(np.float32)
for idx in range(self.n_images)
]
self.intrinsics_all = []
self.pose_all = []
for scale_mat, world_mat in zip(self.scale_mats_np,
self.world_mats_np):
P = world_mat @ scale_mat
P = P[:3, :4]
intrinsics, pose = load_K_Rt_from_P(None, P)
self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
self.pose_all.append(torch.from_numpy(pose).float())
self.intrinsics_all = torch.stack(self.intrinsics_all).to(
self.device) # [n_images, 4, 4]
self.intrinsics_all_inv = torch.inverse(
self.intrinsics_all) # [n_images, 4, 4]
self.focal = self.intrinsics_all[0][0, 0]
self.pose_all = torch.stack(self.pose_all).to(
self.device) # [n_images, 4, 4]
object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0])
object_bbox_max = np.array([1.01, 1.01, 1.01, 1.0])
# Object scale mat: region of interest to **extract mesh**
object_scale_mat = np.load(
os.path.join(self.data_dir, 'cameras_sphere.npz'))['scale_mat_0']
object_bbox_min = np.linalg.inv(
self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:,
None]
object_bbox_max = np.linalg.inv(
self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:,
None]
self.object_bbox_min = object_bbox_min[:3, 0]
self.object_bbox_max = object_bbox_max[:3, 0]
print('Load data: End')
def gen_rays_at(self, img_idx, resolution_level=1):
"""
Generate rays at world space from one camera.
"""
level = resolution_level
tx = torch.linspace(0, self.W - 1, self.W // level)
ty = torch.linspace(0, self.H - 1, self.H // level)
pixels_x, pixels_y = torch.meshgrid(tx, ty)
p = torch.stack(
[pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3],
p[:, :, :, None]).squeeze() # W, H, 3
rays_v = p / torch.linalg.norm(
p, ord=2, dim=-1, keepdim=True) # W, H, 3
rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3],
rays_v[:, :, :, None]).squeeze() # W, H, 3
rays_o = self.pose_all[img_idx, None, None, :3,
3].expand(rays_v.shape) # W, H, 3
return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
def gen_rays_o_at(self, img_idx):
"""
Generate rays_o at world space from one camera.
"""
rays_o = self.pose_all[img_idx, :3, 3]
return rays_o
# add
def gen_rays_at_camera(self, pose, resolution_level=1):
"""
Generate rays at world space from one camera.
"""
level = resolution_level
tx = torch.linspace(0, self.W - 1, self.W // level)
ty = torch.linspace(0, self.H - 1, self.H // level)
pixels_x, pixels_y = torch.meshgrid(tx, ty)
p = torch.stack(
[pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3],
p[:, :, :, None]).squeeze() # W, H, 3
rays_v = p / torch.linalg.norm(
p, ord=2, dim=-1, keepdim=True) # W, H, 3
rays_v = torch.matmul(pose[:3, :3], rays_v[:, :, :,
None]).squeeze() # W, H, 3
rays_o = pose[:3, 3].expand(rays_v.shape) # W, H, 3
return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
def gen_random_rays_at(self, img_idx, batch_size):
"""
Generate random rays at world space from one camera.
"""
pixels_x = torch.randint(low=0, high=self.W, size=[batch_size]) # bs
pixels_y = torch.randint(low=0, high=self.H, size=[batch_size]) # bs
color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3
mask = self.masks[img_idx][(pixels_y, pixels_x)] # batch_size, 3
depth = self.depths[img_idx][(pixels_y, pixels_x)] # batch_size, 1
p = torch.stack(
[pixels_x, pixels_y, torch.ones_like(pixels_y)],
dim=-1).float() # batch_size, 3
p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3],
p[:, :, None]).squeeze() # batch_size, 3
rays_v = p / torch.linalg.norm(
p, ord=2, dim=-1, keepdim=True) # batch_size, 3
rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3],
rays_v[:, :, None]).squeeze() # batch_size, 3
rays_o = self.pose_all[img_idx, None, :3,
3].expand(rays_v.shape) # batch_size, 3
return torch.cat(
[rays_o.cpu(),
rays_v.cpu(), color, mask[:, :1], depth[:, None]],
dim=-1).cuda() # batch_size, 10
def gen_random_rays_at_mask(self, img_idx, batch_size):
"""
Generate random rays at world space from one camera.
"""
pixels_x = torch.randint(low=0, high=self.W, size=[batch_size]) # bs
pixels_y = torch.randint(low=0, high=self.H, size=[batch_size]) # bs
color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3
mask = self.masks[img_idx][(pixels_y, pixels_x)] # batch_size, 3
depth = self.depths[img_idx][(pixels_y, pixels_x)] # batch_size, 1
p = torch.stack(
[pixels_x, pixels_y, torch.ones_like(pixels_y)],
dim=-1).float() # batch_size, 3
p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3],
p[:, :, None]).squeeze() # batch_size, 3
rays_v = p / torch.linalg.norm(
p, ord=2, dim=-1, keepdim=True) # batch_size, 3
rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3],
rays_v[:, :, None]).squeeze() # batch_size, 3
rays_o = self.pose_all[img_idx, None, :3,
3].expand(rays_v.shape) # batch_size, 3
return torch.cat(
[rays_o.cpu(),
rays_v.cpu(), color, mask[:, :1], depth[:, None]],
dim=-1).cuda() # batch_size, 10
def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1):
"""
Interpolate pose between two cameras.
"""
level = resolution_level
tx = torch.linspace(0, self.W - 1, self.W // level)
ty = torch.linspace(0, self.H - 1, self.H // level)
pixels_x, pixels_y = torch.meshgrid(tx, ty)
p = torch.stack(
[pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3],
p[:, :, :, None]).squeeze() # W, H, 3
rays_v = p / torch.linalg.norm(
p, ord=2, dim=-1, keepdim=True) # W, H, 3
trans = self.pose_all[idx_0, :3, 3] * (
1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio
pose_0 = self.pose_all[idx_0].detach().cpu().numpy()
pose_1 = self.pose_all[idx_1].detach().cpu().numpy()
pose_0 = np.linalg.inv(pose_0)
pose_1 = np.linalg.inv(pose_1)
rot_0 = pose_0[:3, :3]
rot_1 = pose_1[:3, :3]
rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
key_times = [0, 1]
slerp = Slerp(key_times, rots)
rot = slerp(ratio)
pose = np.diag([1.0, 1.0, 1.0, 1.0])
pose = pose.astype(np.float32)
pose[:3, :3] = rot.as_matrix()
pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
pose = np.linalg.inv(pose)
rot = torch.from_numpy(pose[:3, :3]).cuda()
trans = torch.from_numpy(pose[:3, 3]).cuda()
rays_v = torch.matmul(rot[None, None, :3, :3],
rays_v[:, :, :, None]).squeeze() # W, H, 3
rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3
return rays_o.transpose(0, 1), rays_v.transpose(0, 1), pose
def gen_rays_across(self, idx_0, idx_1, ratio, resolution_level=1):
"""
Interpolate pose between two cameras.
"""
level = resolution_level
tx = torch.linspace(0, self.W - 1, self.W // level)
ty = torch.linspace(0, self.H - 1, self.H // level)
pixels_x, pixels_y = torch.meshgrid(tx, ty)
p = torch.stack(
[pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3],
p[:, :, :, None]).squeeze() # W, H, 3
rays_v = p / torch.linalg.norm(
p, ord=2, dim=-1, keepdim=True) # W, H, 3
trans = self.pose_all[idx_0, :3, 3] * (
1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio
pose_0 = self.pose_all[idx_0].detach().cpu().numpy()
pose_1 = self.pose_all[idx_1].detach().cpu().numpy()
pose_0 = np.linalg.inv(pose_0)
pose_1 = np.linalg.inv(pose_1)
rot_0 = pose_0[:3, :3]
rot_1 = pose_1[:3, :3]
rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
key_times = [0, 1]
slerp = Slerp(key_times, rots)
rot = slerp(ratio)
pose = np.diag([1.0, 1.0, 1.0, 1.0])
pose = pose.astype(np.float32)
pose[:3, :3] = rot.as_matrix()
pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
pose = np.linalg.inv(pose)
rot = torch.from_numpy(pose[:3, :3]).cuda()
trans = torch.from_numpy(pose[:3, 3]).cuda()
rays_v = torch.matmul(rot[None, None, :3, :3],
rays_v[:, :, :, None]).squeeze() # W, H, 3
rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3
return rays_o.transpose(0, 1), rays_v.transpose(0, 1), pose
def near_far_from_sphere(self, rays_o, rays_d):
a = torch.sum(rays_d**2, dim=-1, keepdim=True)
b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
mid = 0.5 * (-b) / a
near = mid - 1.0
far = mid + 1.0
return near, far
def image_at(self, idx, resolution_level):
img = cv.imread(self.images_lis[idx])
return (cv.resize(img, (self.W // resolution_level,
self.H // resolution_level))).clip(0, 255)
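
A hedged round-trip check of load_K_Rt_from_P, assuming the function above is importable: build P = K [R | -R C] from known intrinsics and a known camera center C, then confirm the decomposition returns the same intrinsics and a camera-to-world pose whose translation equals C (all numbers are arbitrary):

import cv2 as cv
import numpy as np

K = np.array([[500.0, 0.0, 128.0], [0.0, 500.0, 128.0], [0.0, 0.0, 1.0]])
R = cv.Rodrigues(np.array([0.1, 0.2, 0.3]))[0]    # world-to-camera rotation
C = np.array([0.5, -0.2, 2.0])                    # camera center in world coordinates
P = K @ np.hstack([R, (-R @ C).reshape(3, 1)])    # 3x4 projection matrix
intrinsics, pose = load_K_Rt_from_P(None, P.astype(np.float32))
assert np.allclose(intrinsics[:3, :3], K, atol=1e-3)
assert np.allclose(pose[:3, 3], C, atol=1e-4)     # pose is camera-to-world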

View File

@@ -0,0 +1,390 @@
# ------------------------------------------------------------------------
# Modified from https://github.com/Totoro97/NeuS/blob/main/models/fields.py
# Copyright (c) 2021 Peng Wang. All Rights Reserved.
# ------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
class SDFNetwork(nn.Module):
def __init__(self,
d_in,
d_out,
d_hidden,
n_layers,
skip_in=(4, ),
multires=0,
bias=0.5,
scale=1,
geometric_init=True,
weight_norm=True,
inside_outside=False):
super(SDFNetwork, self).__init__()
dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
self.embed_fn_fine = None
if multires > 0:
embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
self.embed_fn_fine = embed_fn
dims[0] = input_ch
self.num_layers = len(dims)
self.skip_in = skip_in
self.scale = scale
for layer in range(0, self.num_layers - 1):
if layer + 1 in self.skip_in:
out_dim = dims[layer + 1] - dims[0]
else:
out_dim = dims[layer + 1]
lin = nn.Linear(dims[layer], out_dim)
if geometric_init:
if layer == self.num_layers - 2:
if not inside_outside:
torch.nn.init.normal_(
lin.weight,
mean=np.sqrt(np.pi) / np.sqrt(dims[layer]),
std=0.0001)
torch.nn.init.constant_(lin.bias, -bias)
else:
torch.nn.init.normal_(
lin.weight,
mean=-np.sqrt(np.pi) / np.sqrt(dims[layer]),
std=0.0001)
torch.nn.init.constant_(lin.bias, bias)
elif multires > 0 and layer == 0:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
torch.nn.init.normal_(lin.weight[:, :3], 0.0,
np.sqrt(2) / np.sqrt(out_dim))
elif multires > 0 and layer in self.skip_in:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0,
np.sqrt(2) / np.sqrt(out_dim))
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):],
0.0)
else:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0,
np.sqrt(2) / np.sqrt(out_dim))
if weight_norm:
lin = nn.utils.weight_norm(lin)
setattr(self, 'lin' + str(layer), lin)
self.activation = nn.Softplus(beta=100)
def forward(self, inputs):
inputs = inputs * self.scale
if self.embed_fn_fine is not None:
inputs = self.embed_fn_fine(inputs)
x = inputs
for layer in range(0, self.num_layers - 1):
lin = getattr(self, 'lin' + str(layer))
if layer in self.skip_in:
x = torch.cat([x, inputs], 1) / np.sqrt(2)
x = lin(x)
if layer < self.num_layers - 2:
x = self.activation(x)
return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1)
def sdf(self, x):
return self.forward(x)[:, :1]
def sdf_hidden_appearance(self, x):
return self.forward(x)
def gradient(self, x):
x.requires_grad_(True)
with torch.enable_grad():
y = self.sdf(x)
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
gradients = torch.autograd.grad(
outputs=y,
inputs=x,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
return gradients.unsqueeze(1)
class RenderingNetwork(nn.Module):
def __init__(self,
d_feature,
mode,
d_in,
d_out,
d_hidden,
n_layers,
weight_norm=True,
multires_view=0,
squeeze_out=True):
super().__init__()
self.mode = mode
self.squeeze_out = squeeze_out
dims = [d_in + d_feature] + [d_hidden
for _ in range(n_layers)] + [d_out]
self.embedview_fn = None
if multires_view > 0:
embedview_fn, input_ch = get_embedder(multires_view)
self.embedview_fn = embedview_fn
dims[0] += (input_ch - 3)
self.num_layers = len(dims)
for layer in range(0, self.num_layers - 1):
out_dim = dims[layer + 1]
lin = nn.Linear(dims[layer], out_dim)
if weight_norm:
lin = nn.utils.weight_norm(lin)
setattr(self, 'lin' + str(layer), lin)
self.relu = nn.ReLU()
def forward(self, points, normals, view_dirs, feature_vectors):
if self.embedview_fn is not None:
view_dirs = self.embedview_fn(view_dirs)
rendering_input = None
if self.mode == 'idr':
rendering_input = torch.cat(
[points, view_dirs, normals, feature_vectors], dim=-1)
elif self.mode == 'no_view_dir':
rendering_input = torch.cat([points, normals, feature_vectors],
dim=-1)
elif self.mode == 'no_normal':
rendering_input = torch.cat([points, view_dirs, feature_vectors],
dim=-1)
x = rendering_input
for layer in range(0, self.num_layers - 1):
lin = getattr(self, 'lin' + str(layer))
x = lin(x)
if layer < self.num_layers - 2:
x = self.relu(x)
if self.squeeze_out:
x = torch.sigmoid(x)
return x
class NeRF(nn.Module):
def __init__(self,
D=8,
W=256,
d_in=3,
d_in_view=3,
multires=0,
multires_view=0,
output_ch=4,
skips=[4],
use_viewdirs=False):
super(NeRF, self).__init__()
self.D = D
self.W = W
self.d_in = d_in
self.d_in_view = d_in_view
self.input_ch = 3
self.input_ch_view = 3
self.embed_fn = None
self.embed_fn_view = None
if multires > 0:
embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
self.embed_fn = embed_fn
self.input_ch = input_ch
if multires_view > 0:
embed_fn_view, input_ch_view = get_embedder(
multires_view, input_dims=d_in_view)
self.embed_fn_view = embed_fn_view
self.input_ch_view = input_ch_view
self.skips = skips
self.use_viewdirs = use_viewdirs
self.pts_linears = nn.ModuleList([nn.Linear(self.input_ch, W)] + [
nn.Linear(W, W) if i not in
self.skips else nn.Linear(W + self.input_ch, W)
for i in range(D - 1)
])
self.views_linears = nn.ModuleList(
[nn.Linear(self.input_ch_view + W, W // 2)])
if use_viewdirs:
self.feature_linear = nn.Linear(W, W)
self.alpha_linear = nn.Linear(W, 1)
self.rgb_linear = nn.Linear(W // 2, 3)
else:
self.output_linear = nn.Linear(W, output_ch)
def forward(self, input_pts, input_views):
if self.embed_fn is not None:
input_pts = self.embed_fn(input_pts)
if self.embed_fn_view is not None:
input_views = self.embed_fn_view(input_views)
h = input_pts
for i, l in enumerate(self.pts_linears):
h = self.pts_linears[i](h)
h = F.relu(h)
if i in self.skips:
h = torch.cat([input_pts, h], -1)
if self.use_viewdirs:
alpha = self.alpha_linear(h)
feature = self.feature_linear(h)
h = torch.cat([feature, input_views], -1)
for i, l in enumerate(self.views_linears):
h = self.views_linears[i](h)
h = F.relu(h)
rgb = self.rgb_linear(h)
return alpha, rgb
else:
assert False
class SingleVarianceNetwork(nn.Module):
def __init__(self, init_val):
super(SingleVarianceNetwork, self).__init__()
self.register_parameter('variance',
nn.Parameter(torch.tensor(init_val)))
def forward(self, x):
return torch.ones([len(x), 1],
device=self.variance.device) * torch.exp(
self.variance * 10.0)
class Mean(nn.Module):
def __init__(self, dim: list, keepdim=False):
super().__init__()
self.dim = dim
self.keepdim = keepdim
def forward(self, x):
return torch.mean(x, self.dim, self.keepdim)
class Discriminator(nn.Module):
def __init__(self, channel=32, patch=True):
super().__init__()
self.imsize = 32
self.nc = 3
self.channel = channel
self.patch = patch
in_channel = 3
layer = []
for idx in range(3):
layer.extend([
spectral_norm(
nn.Conv2d(
in_channel, channel * (2**idx), 3, stride=2,
padding=1)),
nn.LeakyReLU(inplace=True),
spectral_norm(
nn.Conv2d(
channel * (2**idx),
channel * (2**idx),
3,
stride=1,
padding=1)),
nn.LeakyReLU(inplace=True),
])
in_channel = channel * (2**idx)
self.body = nn.Sequential(*layer)
if self.patch:
self.head = spectral_norm(nn.Conv2d(in_channel, 1, 1, padding=0))
else:
self.head = nn.Sequential(Mean([1, 2]), nn.Linear(in_channel, 1))
def forward(self, x):
x = x[:, :self.nc]
x = x.view(-1, self.imsize, self.imsize, self.nc).permute(0, 3, 1, 2)
x = self.body(x)
x = self.head(x)
return x
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x: x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
if self.kwargs['log_sampling']:
freq_bands = 2.**torch.linspace(0., max_freq, N_freqs)
else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(
lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, input_dims=3):
embed_kwargs = {
'include_input': True,
'input_dims': input_dims,
'max_freq_log2': multires - 1,
'num_freqs': multires,
'log_sampling': True,
'periodic_fns': [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
def embed(x, eo=embedder_obj):
return eo.embed(x)
return embed, embedder_obj.out_dim
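
A quick dimensionality check for get_embedder, assuming it is in scope: with include_input the output is the identity (3 channels) plus sin and cos at `multires` frequencies, i.e. 3 + 3 * 2 * 6 = 39 channels for multires=6:

import torch

embed_fn, out_dim = get_embedder(multires=6, input_dims=3)
pts = torch.rand(1024, 3)
print(out_dim, embed_fn(pts).shape)   # 39 torch.Size([1024, 39])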

View File

@@ -0,0 +1,388 @@
# ------------------------------------------------------------------------
# Modified from https://github.com/Totoro97/NeuS/blob/main/models/renderer.py
# Copyright (c) 2021 Peng Wang.
# Copyright (c) Alibaba, Inc. and its affiliates.
# ------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from .fields import RenderingNetwork, SDFNetwork, SingleVarianceNetwork
from .utils import extract_geometry, sample_pdf
class SurfaceRenderer(nn.Module):
def __init__(self, conf, device):
super().__init__()
self.conf = conf
self.device = device
self.sdf_network = SDFNetwork(**self.conf['sdf_network']).to(
self.device)
self.variance_network = SingleVarianceNetwork(
**self.conf['variance_network']).to(self.device)
self.color_network = RenderingNetwork(
**self.conf['rendering_network']).to(self.device)
self.light_network = RenderingNetwork(**self.conf['light_network']).to(
self.device)
self.n_samples = self.conf['neus_renderer']['n_samples']
self.n_importance = self.conf['neus_renderer']['n_importance']
self.n_outside = self.conf['neus_renderer']['n_outside']
self.up_sample_steps = self.conf['neus_renderer']['up_sample_steps']
self.perturb = self.conf['neus_renderer']['perturb']
def extract_geometry(self,
bound_min,
bound_max,
resolution,
threshold=0.0,
device='cuda'):
return extract_geometry(
bound_min,
bound_max,
resolution=resolution,
threshold=threshold,
query_func=lambda pts: -self.sdf_network.sdf(pts),
device=device)
def render_core_outside(self,
rays_o,
rays_d,
z_vals,
sample_dist,
nerf,
background_rgb=None):
batch_size, n_samples = z_vals.shape
dists = z_vals[..., 1:] - z_vals[..., :-1]
dists = torch.cat(
[dists,
torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
mid_z_vals = z_vals + dists * 0.5
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[
..., :, None] # batch_size, n_samples, 3
dis_to_center = torch.linalg.norm(
pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center],
dim=-1) # batch_size, n_samples, 4
dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
dirs = dirs.reshape(-1, 3)
density, sampled_color = nerf(pts, dirs)
alpha = 1.0 - torch.exp(
-F.softplus(density.reshape(batch_size, n_samples)) * dists)
alpha = alpha.reshape(batch_size, n_samples)
weights = alpha * torch.cumprod(
torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1),
-1)[:, :-1]
sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
color = (weights[:, :, None] * sampled_color).sum(dim=1)
if background_rgb is not None:
color = color + background_rgb * (
1.0 - weights.sum(dim=-1, keepdim=True))
return {
'color': color,
'sampled_color': sampled_color,
'alpha': alpha,
'weights': weights,
}
def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
batch_size, n_samples = z_vals.shape
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
sdf = sdf.reshape(batch_size, n_samples)
prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
mid_sdf = (prev_sdf + next_sdf) * 0.5
cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
prev_cos_val = torch.cat(
[torch.zeros([batch_size, 1]).to(self.device), cos_val[:, :-1]],
dim=-1)
cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
dist = (next_z_vals - prev_z_vals)
prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
next_esti_sdf = mid_sdf + cos_val * dist * 0.5
prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
weights = alpha * torch.cumprod(
torch.cat([
torch.ones([batch_size, 1]).to(self.device), 1. - alpha + 1e-7
], -1), -1)[:, :-1]
z_samples = sample_pdf(
z_vals, weights, n_importance, det=True,
device=self.device).detach()
return z_samples
def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
batch_size, n_samples = z_vals.shape
_, n_importance = new_z_vals.shape
pts = rays_o[:,
None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
z_vals, index = torch.sort(z_vals, dim=-1)
if not last:
new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(
batch_size, n_importance)
sdf = torch.cat([sdf, new_sdf], dim=-1)
xx = torch.arange(batch_size)[:, None].expand(
batch_size, n_samples + n_importance).reshape(-1)
index = index.reshape(-1)
sdf = sdf[(xx, index)].reshape(batch_size,
n_samples + n_importance)
return z_vals, sdf
def render_core(self,
rays_o,
rays_d,
z_vals,
sample_dist,
sdf_network,
deviation_network,
color_network,
light_network,
depth_z=None,
background_alpha=None,
bg_sampled_color=None,
background_rgb=None,
cos_anneal_ratio=0.0):
batch_size, n_samples = z_vals.shape
dists = z_vals[..., 1:] - z_vals[..., :-1]
dists = torch.cat([
dists,
torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(
self.device)
], -1)
mid_z_vals = z_vals + dists * 0.5
pts = rays_o[:,
None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]
dirs = rays_d[:, None, :].expand(pts.shape)
pts = pts.reshape(-1, 3)
dirs = dirs.reshape(-1, 3)
sdf_nn_output = sdf_network(pts)
sdf = sdf_nn_output[:, :1]
feature_vector = sdf_nn_output[:, 1:]
gradients = sdf_network.gradient(pts).squeeze()
sampled_albedo = color_network(pts, gradients, dirs,
feature_vector).reshape(
batch_size, n_samples, 3)
sampled_light = light_network(pts, gradients, dirs,
feature_vector).reshape(
batch_size, n_samples, 3)
sampled_color = sampled_albedo * sampled_light
inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6)
inv_s = inv_s.expand(batch_size * n_samples, 1)
true_cos = (dirs * gradients).sum(-1, keepdim=True)
iter_cos_p1 = F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio)
iter_cos = -(iter_cos_p1 + F.relu(-true_cos) * cos_anneal_ratio)
estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
p = prev_cdf - next_cdf
c = prev_cdf
alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size,
n_samples).clip(0.0, 1.0)
pts_norm = torch.linalg.norm(
pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
inside_sphere = (pts_norm < 1.0).float().detach()
relax_inside_sphere = (pts_norm < 1.2).float().detach()
if background_alpha is not None:
alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (
1.0 - inside_sphere)
alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
foreground_color = sampled_color * inside_sphere[:, :, None]
background_color = bg_sampled_color[:, :n_samples] * (
1.0 - inside_sphere)[:, :, None]
sampled_color = foreground_color + background_color
sampled_color = torch.cat(
[sampled_color, bg_sampled_color[:, n_samples:]], dim=1)
beta = torch.cat([
torch.ones([batch_size, 1], device=alpha.device), 1. - alpha + 1e-7
], -1)
weights = alpha * torch.cumprod(beta, -1)[:, :-1]
weights_sum = weights.sum(dim=-1, keepdim=True)
color = (sampled_color * weights[:, :, None]).sum(dim=1)
if background_rgb is not None:
color = color + background_rgb * (1.0 - weights_sum)
albedo = (sampled_albedo * weights[:, :, None]).sum(dim=1)
depth = (mid_z_vals * weights).sum(dim=1)
if depth_z is not None:
pts_depth = rays_o[:, None, :] + rays_d[:, None, :] * depth_z[
..., :, None] # n_rays, n_samples, 3
pts_depth = pts_depth.reshape(-1, 3)
sdf_depth = sdf_network(pts_depth)[:, :1]
else:
sdf_depth = None
gradients_norm = torch.linalg.norm(
gradients.reshape(batch_size, n_samples, 3), ord=2, dim=-1)
gradient_error = (gradients_norm - 1.0)**2
gradient_error = (relax_inside_sphere * gradient_error).sum()
gradient_error = gradient_error / (relax_inside_sphere.sum() + 1e-5)
return {
'color': color,
'albedo': albedo,
'depth': depth,
'sdf': sdf,
'sdf_depth': sdf_depth,
'dists': dists,
'gradients': gradients.reshape(batch_size, n_samples, 3),
's_val': 1.0 / inv_s,
'mid_z_vals': mid_z_vals,
'weights': weights,
'cdf': c.reshape(batch_size, n_samples),
'gradient_error': gradient_error,
'inside_sphere': inside_sphere
}
def render(self,
rays_o,
rays_d,
near,
far,
depth_z=None,
perturb_overwrite=-1,
background_rgb=None,
cos_anneal_ratio=0.0):
batch_size = len(rays_o)
sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere
z_vals = torch.linspace(0.0, 1.0, self.n_samples).to(self.device)
z_vals = near + (far - near) * z_vals[None, :]
z_vals_outside = None
if self.n_outside > 0:
z_vals_end = 1.0 - 1.0 / (self.n_outside + 1.0)
z_vals_outside = torch.linspace(1e-3, z_vals_end, self.n_outside)
n_samples = self.n_samples
perturb = self.perturb
if perturb_overwrite >= 0:
perturb = perturb_overwrite
if perturb > 0:
t_rand = (torch.rand([batch_size, 1]).to(self.device) - 0.5)
z_vals = z_vals + t_rand * 2.0 / self.n_samples
if self.n_outside > 0:
mids = .5 * (
z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
lower = torch.cat([z_vals_outside[..., :1], mids], -1)
t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
z_vals_outside = lower[None, :] + (upper
- lower)[None, :] * t_rand
if self.n_outside > 0:
z_vals_outside = far / torch.flip(
z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
background_alpha = None
background_sampled_color = None
# Up sample
if self.n_importance > 0:
with torch.no_grad():
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :,
None]
sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(
batch_size, self.n_samples)
for i in range(self.up_sample_steps):
new_z_vals = self.up_sample(
rays_o, rays_d, z_vals, sdf,
self.n_importance // self.up_sample_steps, 64 * 2**i)
z_vals, sdf = self.cat_z_vals(
rays_o,
rays_d,
z_vals,
new_z_vals,
sdf,
last=(i + 1 == self.up_sample_steps))
n_samples = self.n_samples + self.n_importance
if self.n_outside > 0:
z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed,
sample_dist, self.nerf)
background_sampled_color = ret_outside['sampled_color']
background_alpha = ret_outside['alpha']
ret_fine = self.render_core(
rays_o,
rays_d,
z_vals,
sample_dist,
self.sdf_network,
self.variance_network,
self.color_network,
self.light_network,
depth_z=depth_z,
background_rgb=background_rgb,
background_alpha=background_alpha,
            bg_sampled_color=background_sampled_color,
cos_anneal_ratio=cos_anneal_ratio)
color_fine = ret_fine['color']
albedo_fine = ret_fine['albedo']
depth_fine = ret_fine['depth']
sdf_depth = ret_fine['sdf_depth']
weights = ret_fine['weights']
weights_sum = weights.sum(dim=-1, keepdim=True)
gradients = ret_fine['gradients']
s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(
dim=-1, keepdim=True)
return {
'color_fine': color_fine,
'albedo_fine': albedo_fine,
'depth_fine': depth_fine,
'sdf_depth': sdf_depth,
's_val': s_val,
'cdf_fine': ret_fine['cdf'],
'weight_sum': weights_sum,
'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
'gradients': gradients,
'weights': weights,
'mid_z_vals': ret_fine['mid_z_vals'],
'gradient_error': ret_fine['gradient_error'],
'inside_sphere': ret_fine['inside_sphere']
}
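
To make the SDF-to-alpha conversion in render_core concrete, a toy numeric illustration (the SDF values and inv_s are arbitrary): alpha is the clipped relative drop of the sigmoid CDF of the SDF over one sampling interval, so it is negligible while the ray is still outside the surface and close to 1 from the zero crossing onward:

import torch

inv_s = 64.0                                   # arbitrary sharpness
sdf_prev = torch.tensor([0.30, 0.05, -0.10])   # SDF at the start of each interval
sdf_next = torch.tensor([0.25, -0.05, -0.15])  # SDF at the end of each interval
prev_cdf = torch.sigmoid(sdf_prev * inv_s)
next_cdf = torch.sigmoid(sdf_next * inv_s)
alpha = ((prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)).clip(0.0, 1.0)
print(alpha)   # ~[0.0000, 0.96, 0.96]: negligible outside, high from the crossing onward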

View File

@@ -0,0 +1,160 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import numpy as np
import torch
import trimesh
from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .dataset import Dataset
from .renderer import SurfaceRenderer
logger = get_logger()
__all__ = ['SurfaceReconCommon']
@MODELS.register_module(
Tasks.surface_recon_common, module_name=Models.surface_recon_common)
class SurfaceReconCommon(TorchModel):
def __init__(self, model_dir, network_cfg, **kwargs):
"""initialize the surface reconstruction model for common objects.
Args:
model_dir (str): the model path.
network_cfg (dict): args of network config
"""
super().__init__(model_dir, **kwargs)
logger.info('model params:{}'.format(kwargs))
if torch.cuda.is_available():
self.device = torch.device('cuda')
else:
raise Exception('GPU is required')
logger.info(network_cfg)
self.renderer = SurfaceRenderer(network_cfg, device=self.device)
self.ckpt_path = os.path.join(model_dir, 'model.pth')
if not os.path.exists(self.ckpt_path):
raise Exception('model path not found')
self.load_checkpoint(self.ckpt_path)
logger.info('load models from %s' % self.ckpt_path)
self.n_rays = network_cfg['n_rays']
def load_checkpoint(self, ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=self.device)
for name, module in self.renderer.named_modules():
saved_name = name + '_fine'
if saved_name in checkpoint:
module.load_state_dict(checkpoint[saved_name])
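    # Note: the released checkpoint stores each sub-network under a
    # '<module_name>_fine' key (e.g. 'sdf_network_fine'), so the loop above
    # maps the renderer's module names onto those keys before loading weights.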
def surface_reconstruction(self,
data_dir,
save_dir,
color=False,
n_directions=8):
self.dataset = Dataset(data_dir, self.device)
bound_min = torch.tensor(
self.dataset.object_bbox_min, dtype=torch.float32).to(self.device)
bound_max = torch.tensor(
self.dataset.object_bbox_max, dtype=torch.float32).to(self.device)
vertices, triangles = \
self.renderer.extract_geometry(bound_min, bound_max, resolution=512, threshold=0.0,
device=self.device)
if color:
pt_vertices = torch.from_numpy(vertices).cuda().reshape(-1, 1,
3).float()
idx_list = np.linspace(
0,
self.dataset.n_images,
n_directions,
endpoint=False,
dtype=int)
rays_o_list = []
for idx in idx_list:
rays_o = self.dataset.pose_all[idx, :3, 3]
rays_o_list.append(rays_o)
rgb_final = None
diff_final = None
for rays_o in rays_o_list:
rays_o = rays_o.reshape(1, 3).repeat(vertices.shape[0],
1).float()
rays_d = pt_vertices.reshape(-1, 3) - rays_o
rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1)
dist = torch.norm(pt_vertices.reshape(-1, 3) - rays_o, dim=-1)
rays_o = rays_o.reshape(-1, 3).split(self.n_rays)
rays_d = rays_d.reshape(-1, 3).split(self.n_rays)
dist = dist.reshape(-1).split(self.n_rays)
out_rgb_fine = []
depth_diff = []
for i, (rays_o_batch,
rays_d_batch) in enumerate(zip(rays_o, rays_d)):
near, far = self.dataset.near_far_from_sphere(
rays_o_batch, rays_d_batch)
render_out = self.renderer.render(
rays_o_batch,
rays_d_batch,
near,
far,
cos_anneal_ratio=1.0,
background_rgb=None)
# out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
out_rgb_fine.append(
render_out['albedo_fine'].detach().cpu().numpy())
weights = render_out['weights']
mid_z_vals = render_out['mid_z_vals']
n_samples = self.renderer.n_samples + self.renderer.n_importance
depth_batch = (mid_z_vals[:, :n_samples]
* weights[:, :n_samples]).sum(
dim=1).detach().cpu().numpy()
dist_batch = dist[i].detach().cpu().numpy()
depth_diff.append(np.abs(depth_batch - dist_batch))
del render_out
out_rgb_fine = np.concatenate(
out_rgb_fine, axis=0).reshape(vertices.shape[0], 3)
depth_diff = np.concatenate(
depth_diff, axis=0).reshape(vertices.shape[0])
if rgb_final is None:
rgb_final = out_rgb_fine.copy()
diff_final = depth_diff.copy()
else:
ind = diff_final > depth_diff
ind = ind.reshape(-1)
rgb_final[ind] = out_rgb_fine[ind]
diff_final[ind] = depth_diff[ind]
vertices = vertices * self.dataset.scale_mats_np[0][
0, 0] + self.dataset.scale_mats_np[0][:3, 3][None]
if color:
logger.info('save mesh with color')
vert_colors = (255 * np.clip(rgb_final[..., ::-1], 0, 1)).astype(
np.uint8)
mesh = trimesh.Trimesh(
vertices, triangles, vertex_colors=vert_colors)
else:
mesh = trimesh.Trimesh(vertices, triangles)
outpath = os.path.join(save_dir, 'mesh.ply')
mesh.export(outpath)
        logger.info('surface reconstruction done, export mesh to %s' % outpath)
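A hedged usage sketch for the model above; the paths are placeholders and, apart from 'n_rays', the `network_cfg` keys depend on the packaged configuration and are not spelled out here.

# Illustrative only -- placeholder paths, abbreviated network_cfg.
network_cfg = {'n_rays': 512}  # real configs also carry the renderer/network settings
model = SurfaceReconCommon('/path/to/model_dir', network_cfg)
model.surface_reconstruction('/path/to/data_dir', '/path/to/save_dir', color=True)
# writes save_dir/mesh.ply, with per-vertex colors when color=True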

View File

@@ -0,0 +1,85 @@
# ------------------------------------------------------------------------
# Modified from https://github.com/Totoro97/NeuS/blob/main/models/renderer.py
# Copyright (c) 2021 Peng Wang.
# ------------------------------------------------------------------------
import mcubes
import numpy as np
import torch
def extract_fields(bound_min,
bound_max,
resolution,
query_func,
device='cuda'):
N = 64
X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
with torch.no_grad():
for xi, xs in enumerate(X):
for yi, ys in enumerate(Y):
for zi, zs in enumerate(Z):
xx, yy, zz = torch.meshgrid(xs, ys, zs)
xx = xx.reshape(-1, 1)
yy = yy.reshape(-1, 1)
zz = zz.reshape(-1, 1)
pts = torch.cat([xx, yy, zz], dim=-1)
pts = pts.to(device)
val = query_func(pts).reshape(
len(xs), len(ys), len(zs)).detach().cpu().numpy()
u[xi * N:xi * N + len(xs), yi * N:yi * N + len(ys),
zi * N:zi * N + len(zs)] = val
return u
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func,
device):
print('threshold: {}'.format(threshold))
u = extract_fields(bound_min, bound_max, resolution, query_func, device)
vertices, triangles = mcubes.marching_cubes(u, threshold)
b_max_np = bound_max.detach().cpu().numpy()
b_min_np = bound_min.detach().cpu().numpy()
vertices = vertices / (resolution - 1.0) * (
b_max_np - b_min_np)[None, :] + b_min_np[None, :]
return vertices, triangles
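# Typical call (illustrative): query_func maps (N, 3) points to scalar field values
# and threshold=0.0 extracts the zero level set, e.g. something like
#   query_func = lambda pts: -sdf_network.sdf(pts)
# depending on the sign convention expected by marching cubes.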
def sample_pdf(bins, weights, n_samples, det=False, device='cuda'):
# This implementation is from NeRF
# Get pdf
weights = weights + 1e-5 # prevent nans
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
# Take uniform samples
if det:
u = torch.linspace(
0. + 0.5 / n_samples, 1. - 0.5 / n_samples,
steps=n_samples).to(device)
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
else:
u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(device)
# Invert CDF
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[..., 1] - cdf_g[..., 0])
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
return samples
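A small self-contained sketch of the inverse-CDF sampler above, with hypothetical shapes: `bins` play the role of z-value bin edges from a coarse pass and `weights` the per-interval densities.

import torch  # illustrative snippet, not part of the original file
bins = torch.linspace(0.0, 1.0, 65).repeat(4, 1)   # (batch=4, n_bins=65)
weights = torch.rand(4, 64)                        # (batch=4, n_bins - 1)
new_z = sample_pdf(bins, weights, n_samples=16, det=True, device='cpu')
print(new_z.shape)  # torch.Size([4, 16]); samples concentrate where weights are large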

View File

@@ -12,7 +12,7 @@ from modelscope.models.cv.video_depth_estimation.utils.misc import filter_dict
########################################################################################################################
def resize_image(image, shape, interpolation=Image.ANTIALIAS):
def resize_image(image, shape, interpolation=Image.Resampling.LANCZOS):
"""
Resizes input image.
@@ -57,7 +57,8 @@ def resize_depth(depth, shape):
def resize_sample_image_and_intrinsics(sample,
shape,
image_interpolation=Image.ANTIALIAS):
image_interpolation=Image.Resampling.
LANCZOS):
"""
Resizes the image and intrinsics of a sample
@@ -102,7 +103,7 @@ def resize_sample_image_and_intrinsics(sample,
return sample
def resize_sample(sample, shape, image_interpolation=Image.ANTIALIAS):
def resize_sample(sample, shape, image_interpolation=Image.Resampling.LANCZOS):
"""
Resizes a sample, including image, intrinsics and depth maps.
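The hunks above swap the deprecated Pillow constant for its modern spelling: `Image.ANTIALIAS` was removed in Pillow 10, and `Image.Resampling.LANCZOS` (available since Pillow 9.1) is the equivalent filter. A hedged compatibility sketch for code that must still import on older Pillow releases:

from PIL import Image

# Prefer the new enum, fall back to the legacy constant on old Pillow versions.
try:
    LANCZOS = Image.Resampling.LANCZOS
except AttributeError:  # Pillow < 9.1
    LANCZOS = Image.ANTIALIAS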

View File

@@ -6,22 +6,23 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .clip import CLIPForMultiModalEmbedding
from .gemm import GEMMForMultiModalEmbedding
from .rleg import RLEGForMultiModalEmbedding
from .team import TEAMForMultiModalSimilarity
from .clip_interrogator import CLIP_Interrogator
from .diffusion import DiffusionForTextToImageSynthesis
from .efficient_diffusion_tuning import EfficientStableDiffusion
from .gemm import GEMMForMultiModalEmbedding
from .mmr import VideoCLIPForMultiModalEmbedding
from .mplug_for_all_tasks import MPlugForAllTasks, HiTeAForAllTasks
from .mplug_for_all_tasks import HiTeAForAllTasks, MPlugForAllTasks
from .mplug_owl import MplugOwlForConditionalGeneration
from .multi_stage_diffusion import \
MultiStageDiffusionForTextToImageSynthesis
from .ofa_for_all_tasks import OfaForAllTasks
from .ofa_for_text_to_image_synthesis_model import \
OfaForTextToImageSynthesis
from .multi_stage_diffusion import \
MultiStageDiffusionForTextToImageSynthesis
from .vldoc import VLDocForDocVLEmbedding
from .prost import ProSTForTVRetrieval
from .rleg import RLEGForMultiModalEmbedding
from .team import TEAMForMultiModalSimilarity
from .video_synthesis import TextToVideoSynthesis
from .efficient_diffusion_tuning import EfficientStableDiffusion
from .mplug_owl import MplugOwlForConditionalGeneration
from .clip_interrogator import CLIP_Interrogator
from .vldoc import VLDocForDocVLEmbedding
from .videocomposer import VideoComposer
else:
@@ -32,6 +33,7 @@ else:
'rleg': ['RLEGForMultiModalEmbedding'],
'team': ['TEAMForMultiModalSimilarity'],
'mmr': ['VideoCLIPForMultiModalEmbedding'],
'prost': ['ProSTForTVRetrieval'],
'mplug_for_all_tasks': ['MPlugForAllTasks', 'HiTeAForAllTasks'],
'ofa_for_all_tasks': ['OfaForAllTasks'],
'ofa_for_text_to_image_synthesis_model':

View File

@@ -578,7 +578,7 @@ class CLIPForMultiModalEmbedding(TorchModel):
with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN):
image_features = self.clip_model.encode_image(image_tensor)
image_features /= image_features.norm(
image_features = image_features / image_features.norm(
dim=-1, keepdim=True) # l2-normalize
output[OutputKeys.IMG_EMBEDDING] = image_features
@@ -590,7 +590,7 @@ class CLIPForMultiModalEmbedding(TorchModel):
with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN):
text_features = self.clip_model.encode_text(text_tensor)
text_features /= text_features.norm(
text_features = text_features / text_features.norm(
dim=-1, keepdim=True) # l2-normalize
output[OutputKeys.TEXT_EMBEDDING] = text_features
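The two hunks above replace in-place normalization (`/=`) with an out-of-place division; presumably this keeps the tensors safe under autograd when `mode == ModeKeys.TRAIN`, since in-place ops can invalidate values the backward pass still needs. A minimal sketch of the pattern (the `encoder` call is illustrative):

features = encoder(inputs)
features = features / features.norm(dim=-1, keepdim=True)  # out-of-place l2-normalize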

View File

@@ -0,0 +1,3 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from .models import ProSTForTVRetrieval

View File

@@ -0,0 +1,117 @@
# The implementation is adopted from Huaishao Luo,
# made publicly available under the MIT License at https://github.com/ArrowLuo/CLIP4Clip
import cv2
import numpy as np
import torch as th
from PIL import Image
from torchvision.transforms import (CenterCrop, Compose, InterpolationMode,
Normalize, Resize, ToTensor)
from modelscope.utils.logger import get_logger
logger = get_logger()
class RawVideoExtractorCV2():
def __init__(
self,
centercrop=False,
size=224,
frame_rate=-1,
):
self.centercrop = centercrop
self.size = size
self.framerate = frame_rate
self.transform = self._transform(self.size)
def _transform(self, n_px):
return Compose([
Resize(n_px, interpolation=InterpolationMode.BICUBIC),
CenterCrop(n_px),
lambda image: image.convert('RGB'),
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
def video_to_tensor(self,
video_file,
preprocess,
sample_fp=0,
start_time=None,
end_time=None):
if start_time is not None or end_time is not None:
assert isinstance(start_time, int) and isinstance(end_time, int) \
and start_time > -1 and end_time > start_time
assert sample_fp > -1
        # Sample sample_fp frames per second (0 means sample at the native frame rate).
cap = cv2.VideoCapture(video_file)
frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = int(cap.get(cv2.CAP_PROP_FPS))
if fps == 0:
logger.info(f'{video_file} with fps 0!!!')
total_duration = (frameCount + fps - 1) // fps
start_sec, end_sec = 0, total_duration
if start_time is not None:
start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration
cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps))
interval = 1
if sample_fp > 0:
interval = fps // sample_fp
else:
sample_fp = fps
if interval == 0:
interval = 1
inds = [ind for ind in np.arange(0, fps, interval)]
assert len(inds) >= sample_fp
inds = inds[:sample_fp]
ret = True
images = []
for sec in np.arange(start_sec, end_sec + 1):
if not ret:
break
sec_base = int(sec * fps)
for ind in inds:
cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind)
ret, frame = cap.read()
if not ret:
break
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
images.append(
preprocess(Image.fromarray(frame_rgb).convert('RGB')))
cap.release()
if len(images) > 0:
video_data = th.tensor(np.stack(images))
else:
video_data = th.zeros(1)
return {'video': video_data}
def get_video_data(self, video_path, start_time=None, end_time=None):
image_input = self.video_to_tensor(
video_path,
self.transform,
sample_fp=self.framerate,
start_time=start_time,
end_time=end_time)
return image_input
def process_raw_data(self, raw_video_data):
tensor_size = raw_video_data.size()
tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2],
tensor_size[-1])
return tensor
# An ordinary video frame extractor based on OpenCV (cv2)
RawVideoExtractor = RawVideoExtractorCV2
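A hedged usage sketch for the extractor above; the video path is a placeholder and the shapes assume at least one frame was decoded.

extractor = RawVideoExtractorCV2(size=224, frame_rate=1)   # sample one frame per second
data = extractor.get_video_data('/path/to/video.mp4')      # placeholder path
video = data['video']                      # (num_frames, 3, 224, 224) float tensor
clips = extractor.process_raw_data(video)  # (num_frames, 1, 3, 224, 224)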

View File

@@ -0,0 +1,3 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from .prost_model import ProSTForTVRetrieval

View File

@@ -0,0 +1,704 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
import platform
from collections import OrderedDict
from types import SimpleNamespace
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from modelscope.models.multi_modal.prost.models.module_clip import (
_PT_NAME, CLIP, QuickGELU, convert_weights)
from modelscope.models.multi_modal.prost.models.module_cross import (
CrossConfig, CrossModel)
from modelscope.models.multi_modal.prost.models.module_cross import \
Transformer as TransformerClip
from modelscope.models.multi_modal.prost.models.until_module import (
AllGather, CrossEn, Event_decoder, Frame_decoder, LayerNorm,
PreTrainedModel, make_patch_shift)
from modelscope.utils.logger import get_logger
allgather = AllGather.apply
logger = get_logger()
__all__ = ['CLIP4Clip']
class MyObject:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
class CLIP4ClipPreTrainedModel(PreTrainedModel, nn.Module):
""" An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
"""
def __init__(self, cross_config, *inputs, **kwargs):
super(CLIP4ClipPreTrainedModel, self).__init__(cross_config)
self.cross_config = cross_config
self.clip = None
self.cross = None
@classmethod
def from_pretrained(cls,
cross_config,
state_dict=None,
cache_dir=None,
type_vocab_size=2,
*inputs,
**kwargs):
task_config = None
if 'task_config' in kwargs.keys():
task_config = kwargs['task_config']
if not hasattr(task_config, 'local_rank'):
task_config['local_rank'] = 0
elif task_config['local_rank'] == -1:
task_config['local_rank'] = 0
if state_dict is None:
state_dict = {}
# pretrained_clip_name = task_config['pretrained_clip_name']
clip_state_dict = CLIP.get_config(model_dir=task_config['model_dir'])
for key, val in clip_state_dict.items():
new_key = 'clip.' + key
if new_key not in state_dict:
state_dict[new_key] = val.clone()
# cross_config, _ = CrossConfig.get_config(
# cross_model_name,
# cache_dir,
# type_vocab_size,
# state_dict=None,
# task_config=task_config)
cross_config = CrossConfig.from_dict(cross_config)
cross_config.type_vocab_size = type_vocab_size
task_config = MyObject(**kwargs['task_config'])
model = cls(cross_config, clip_state_dict, *inputs, task_config)
# ===> Initialization trick [HARD CODE]
if model.linear_patch == '3d':
contain_conv2 = False
for key in state_dict.keys():
if key.find('visual.conv2.weight') > -1:
contain_conv2 = True
break
if contain_conv2 is False and hasattr(model.clip.visual, 'conv2'):
cp_weight = state_dict['clip.visual.conv1.weight'].clone()
kernel_size = model.clip.visual.conv2.weight.size(2)
conv2_size = model.clip.visual.conv2.weight.size()
conv2_size = list(conv2_size)
left_conv2_size = conv2_size.copy()
right_conv2_size = conv2_size.copy()
left_conv2_size[2] = (kernel_size - 1) // 2
right_conv2_size[2] = kernel_size - 1 - left_conv2_size[2]
left_zeros, right_zeros = None, None
if left_conv2_size[2] > 0:
left_zeros = torch.zeros(
*tuple(left_conv2_size),
dtype=cp_weight.dtype,
device=cp_weight.device)
if right_conv2_size[2] > 0:
right_zeros = torch.zeros(
*tuple(right_conv2_size),
dtype=cp_weight.dtype,
device=cp_weight.device)
cat_list = []
if left_zeros is not None:
cat_list.append(left_zeros)
cat_list.append(cp_weight.unsqueeze(2))
if right_zeros is not None:
cat_list.append(right_zeros)
cp_weight = torch.cat(cat_list, dim=2)
state_dict['clip.visual.conv2.weight'] = cp_weight
# if model.sim_header == 'tightTransf':
# contain_cross = False
# for key in state_dict.keys():
# if key.find('cross.transformer') > -1:
# contain_cross = True
# break
# if contain_cross is False:
# for key, val in clip_state_dict.items():
# if key == 'positional_embedding':
# state_dict[
# 'cross.embeddings.position_embeddings.weight'] = val.clone(
# )
# continue
# if key.find('transformer.resblocks') == 0:
# num_layer = int(key.split('.')[2])
# # cut from beginning
# if num_layer < task_config.cross_num_hidden_layers:
# state_dict['cross.' + key] = val.clone()
# continue
if model.sim_header == 'seqLSTM' or model.sim_header == 'seqTransf':
            # This step detects whether we are in train mode or test mode
contain_frame_position = False
for key in state_dict.keys():
if key.find('frame_position_embeddings') > -1:
contain_frame_position = True
break
# train mode
if contain_frame_position is False:
for key, val in clip_state_dict.items():
if key == 'positional_embedding':
state_dict[
'frame_position_embeddings.weight'] = val.clone()
# state_dict["text_prompt_encoder.pos_embedding"] = val[0:3].clone()
continue
if model.sim_header == 'seqTransf' and key.find(
'transformer.resblocks') == 0:
num_layer = int(key.split('.')[2])
# cut from beginning
if num_layer < task_config.cross_num_hidden_layers:
state_dict[key.replace(
'transformer.',
'transformerClip.')] = val.clone()
continue
else:
for key, val in state_dict.items():
# test mode
if key.find('clip.visual.transformer.resblocks') == 0:
num_layer = int(key.split('.')[4])
# shift layers 10-11
if num_layer >= 10 and num_layer < 12:
state_dict[key.replace('attn.net.',
'attn.')] = val.clone()
# <=== End of initialization trick
if state_dict is not None:
model = cls.init_preweight(
model, state_dict, task_config=task_config)
make_patch_shift(model, video_frame=task_config.max_frames, n_div=14)
return model
def show_log(task_config, info):
if task_config is None or task_config.local_rank == 0:
logger.warning(info)
def update_attr(target_name,
target_config,
target_attr_name,
source_config,
source_attr_name,
default_value=None):
if hasattr(source_config, source_attr_name):
if default_value is None or getattr(source_config,
source_attr_name) != default_value:
setattr(target_config, target_attr_name,
getattr(source_config, source_attr_name))
# show_log(
# source_config, "Set {}.{}: {}.".format(
# target_name, target_attr_name,
# getattr(target_config, target_attr_name)))
return target_config
def check_attr(target_name, task_config):
return hasattr(task_config,
target_name) and task_config.__dict__[target_name]
class CLIP4Clip(CLIP4ClipPreTrainedModel):
def __init__(self, cross_config, clip_state_dict, task_config):
super(CLIP4Clip, self).__init__(cross_config)
self.task_config = task_config
self.ignore_video_index = -1
assert self.task_config.max_words + self.task_config.max_frames <= cross_config.max_position_embeddings
self._stage_one = True
self._stage_two = False
# show_log(task_config, "Stage-One:{}, Stage-Two:{}".format(self._stage_one, self._stage_two))
self.loose_type = False
if self._stage_one and check_attr('loose_type', self.task_config):
self.loose_type = True
# show_log(task_config, "Test retrieval by loose type.")
# CLIP Encoders: From OpenAI: CLIP [https://github.com/openai/CLIP] ===>
vit = 'visual.proj' in clip_state_dict
assert vit
if vit:
vision_width = clip_state_dict['visual.conv1.weight'].shape[0]
vision_layers = len([
k for k in clip_state_dict.keys() if k.startswith('visual.')
and k.endswith('.attn.in_proj_weight')
])
vision_patch_size = clip_state_dict['visual.conv1.weight'].shape[
-1]
grid_size = round(
(clip_state_dict['visual.positional_embedding'].shape[0]
- 1)**0.5)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [
len(
set(
k.split('.')[2] for k in clip_state_dict
if k.startswith(f'visual.layer{b}')))
for b in [1, 2, 3, 4]
]
vision_layers = tuple(counts)
vision_width = clip_state_dict[
'visual.layer1.0.conv1.weight'].shape[0]
output_width = round(
(clip_state_dict['visual.attnpool.positional_embedding'].
shape[0] - 1)**0.5)
vision_patch_size = None
assert output_width**2 + 1 == clip_state_dict[
'visual.attnpool.positional_embedding'].shape[0]
image_resolution = output_width * 32
embed_dim = clip_state_dict['text_projection'].shape[1]
context_length = clip_state_dict['positional_embedding'].shape[0]
vocab_size = clip_state_dict['token_embedding.weight'].shape[0]
transformer_width = clip_state_dict['ln_final.weight'].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(
set(
k.split('.')[2] for k in clip_state_dict
if k.startswith('transformer.resblocks')))
# show_log(task_config, "\t embed_dim: {}".format(embed_dim))
# show_log(task_config, "\t image_resolution: {}".format(image_resolution))
# show_log(task_config, "\t vision_layers: {}".format(vision_layers))
# show_log(task_config, "\t vision_width: {}".format(vision_width))
# show_log(task_config, "\t vision_patch_size: {}".format(vision_patch_size))
# show_log(task_config, "\t context_length: {}".format(context_length))
# show_log(task_config, "\t vocab_size: {}".format(vocab_size))
# show_log(task_config, "\t transformer_width: {}".format(transformer_width))
# show_log(task_config, "\t transformer_heads: {}".format(transformer_heads))
# show_log(task_config, "\t transformer_layers: {}".format(transformer_layers))
self.linear_patch = '2d'
if hasattr(task_config, 'linear_patch'):
self.linear_patch = task_config.linear_patch
# show_log(task_config, "\t\t linear_patch: {}".format(self.linear_patch))
# use .float() to avoid overflow/underflow from fp16 weight. https://github.com/openai/CLIP/issues/40
cut_top_layer = 0
self.clip = CLIP(
embed_dim,
image_resolution,
vision_layers - cut_top_layer,
vision_width,
vision_patch_size,
context_length,
vocab_size,
transformer_width,
transformer_heads,
transformer_layers - cut_top_layer,
linear_patch=self.linear_patch).float()
for key in ['input_resolution', 'context_length', 'vocab_size']:
if key in clip_state_dict:
del clip_state_dict[key]
convert_weights(self.clip)
# <=== End of CLIP Encoders
self.sim_header = 'seqTransf'
if hasattr(task_config, 'sim_header'):
self.sim_header = task_config.sim_header
# show_log(task_config, "\t sim_header: {}".format(self.sim_header))
if self.sim_header == 'tightTransf':
assert self.loose_type is False
cross_config.max_position_embeddings = context_length
if self.loose_type is False:
# Cross Encoder ===>
cross_config = update_attr('cross_config', cross_config,
'num_hidden_layers', self.task_config,
'cross_num_hidden_layers')
self.cross = CrossModel(cross_config)
# <=== End of Cross Encoder
self.similarity_dense = nn.Linear(cross_config.hidden_size, 1)
if self.sim_header == 'seqLSTM' or self.sim_header == 'seqTransf':
self.frame_position_embeddings = nn.Embedding(
cross_config.max_position_embeddings, cross_config.hidden_size)
# self.frame_position_embeddings = nn.Embedding(600, cross_config.hidden_size)
if self.sim_header == 'seqTransf':
self.transformerClip = TransformerClip(
width=transformer_width,
layers=self.task_config.cross_num_hidden_layers,
heads=transformer_heads,
)
if self.sim_header == 'seqLSTM':
self.lstm_visual = nn.LSTM(
input_size=cross_config.hidden_size,
hidden_size=cross_config.hidden_size,
batch_first=True,
bidirectional=False,
num_layers=1)
self.loss_fct = CrossEn()
self.apply(self.init_weights)
self.set_dim = 512
self.patch_num = self.task_config.max_patch
if hasattr(self.task_config, 'max_word_pro'):
self.word_pro_num = self.task_config.max_word_pro
else:
self.word_pro_num = self.task_config.max_phrase
self.frame_num = self.task_config.max_frames
if hasattr(self.task_config, 'max_vfea'):
self.event_num = self.task_config.max_vfea
else:
self.event_num = self.task_config.max_event
self.patch_prototype_weight = nn.Sequential(
nn.Linear(self.set_dim, self.set_dim), nn.ReLU(inplace=True),
nn.Linear(self.set_dim, self.patch_num - 1), nn.ReLU(inplace=True))
self.word_prototype_weight = nn.Sequential(
nn.Linear(self.set_dim, self.set_dim), nn.ReLU(inplace=True),
nn.Linear(self.set_dim, self.word_pro_num), nn.ReLU(inplace=True))
self.frame_decoder = Frame_decoder(
num_attris=self.frame_num,
layers=2,
heads=1,
dim_ftr=512,
pos_emb=False,
length=1,
dim_feedforward=512,
without_init=False)
self.event_decoder = Event_decoder(
num_attris=self.event_num,
layers=2,
heads=1,
dim_ftr=512,
pos_emb=False,
length=1,
dim_feedforward=512,
without_init=False)
# -------------------------------------------------------------------------------------------------------
def forward(self,
input_ids,
token_type_ids,
attention_mask,
video,
video_mask=None):
input_ids = input_ids.view(-1, input_ids.shape[-1])
token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1])
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
video_mask = video_mask.view(-1, video_mask.shape[-1])
# T x 3 x H x W
video = torch.as_tensor(video).float()
bs, ts, channel, h, w = video.shape
video = video.view(bs * ts, channel, h, w)
video_frame = bs * ts
phr_feat, sen_feat, obj_feat, eve_feat = self.get_sequence_visual_output(
input_ids,
token_type_ids,
attention_mask,
video,
video_mask,
shaped=True,
video_frame=video_frame)
if self.training:
sim_matrix1, sim_matrix2, sim_matrix3, sim_matrix4 = self.get_max_similarity_logits(
phr_feat,
sen_feat,
obj_feat,
eve_feat,
attention_mask,
video_mask,
shaped=True)
sim_loss = (self.loss_fct(sim_matrix1) + self.loss_fct(sim_matrix2)
+ self.loss_fct(sim_matrix3)
+ self.loss_fct(sim_matrix4)) / 4.0
loss = sim_loss
return loss
else:
return None
def get_max_similarity_logits(self,
word_feat,
text_feat,
patch_feat,
video_feat,
text_mask,
video_mask,
shaped=False):
if shaped is False:
text_mask = text_mask.view(-1, text_mask.shape[-1])
video_mask = video_mask.view(-1, video_mask.shape[-1])
if self.training and torch.cuda.is_available(): # batch merge here
text_feat = allgather(text_feat, self.task_config)
video_feat = allgather(video_feat, self.task_config)
word_feat = allgather(word_feat, self.task_config)
patch_feat = allgather(patch_feat, self.task_config)
video_mask = allgather(video_mask, self.task_config)
torch.distributed.barrier() # force sync
# ESPM
text_feat = F.normalize(text_feat, p=2, dim=1)
video_feat = F.normalize(video_feat, p=2, dim=2)
retrieve_logits = torch.einsum('ad,bkd->abk', [text_feat, video_feat])
retrieve_logits = retrieve_logits.max(2)[0]
# OPPM
word_feat = F.normalize(word_feat, p=2, dim=2)
patch_feat = F.normalize(patch_feat, p=2, dim=3)
retrieve_logits_2 = torch.einsum('aid, bfjd->abfij',
[word_feat, patch_feat])
retrieve_logits_2 = retrieve_logits_2.max(3)[0]
retrieve_logits_2 = retrieve_logits_2.max(2)[0]
retrieve_logits_2 = retrieve_logits_2.sum(2) / self.patch_num
if self.training:
logit_scale = self.clip.logit_scale.exp()
retrieve_logits = logit_scale * retrieve_logits
retrieve_logits_2 = logit_scale * retrieve_logits_2
return retrieve_logits, retrieve_logits.t(
), retrieve_logits_2, retrieve_logits_2.t()
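    # Shape walkthrough of the two terms above: text_feat is (Bt, D) and
    # video_feat is (Bv, K, D), so the ESPM similarity (Bt, Bv, K) is reduced
    # with a max over events K.  word_feat is (Bt, I, D) and patch_feat is
    # (Bv, F, J, D); the OPPM similarity (Bt, Bv, F, I, J) is max-reduced over
    # word prototypes I and frames F, then averaged over patch prototypes J
    # via the division by self.patch_num.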
def get_sequence_output(self,
input_ids,
token_type_ids,
attention_mask,
shaped=False):
if shaped is False:
input_ids = input_ids.view(-1, input_ids.shape[-1])
token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1])
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
bs_pair = input_ids.size(0)
sequence_hidden = self.clip.encode_text(
input_ids, return_hidden=True)[1].float()
text_feat = sequence_hidden.view(bs_pair, -1, sequence_hidden.size(-1))
word_weights = self.word_prototype_weight(text_feat)
text_word_proto = torch.einsum('bmd,bmn->bnd', text_feat, word_weights)
cls_text_feat = text_feat.contiguous()
cls_text_feat = cls_text_feat[torch.arange(cls_text_feat.shape[0]),
torch.sum(attention_mask, dim=-1) - 1, :]
return text_word_proto, cls_text_feat
def get_visual_output(self,
video,
video_mask,
shaped=False,
video_frame=-1):
if shaped is False:
video_mask = video_mask.view(-1, video_mask.shape[-1])
video = torch.as_tensor(video).float()
bs, ts, channel, h, w = video.shape
video = video.view(bs * ts, channel, h, w)
# video_frame = bs * ts
bs_pair = video_mask.size(0)
cls_video_feat, video_patch_feat = self.clip.encode_image_tokens(
video, return_hidden=True)
cls_video_feat = cls_video_feat.float()
video_patch_feat = video_patch_feat.float()
# frame_num = video_patch_feat.shape[0]
patch_dim = video_patch_feat.shape[2]
patch_weights = self.patch_prototype_weight(video_patch_feat)
# cls_video_feat
video_patch_proto = torch.einsum('bmd,bmn->bnd', video_patch_feat,
patch_weights)
video_patch_proto = torch.cat(
(cls_video_feat.unsqueeze(1), video_patch_proto), 1)
video_patch_proto = video_patch_proto.reshape(
bs_pair, self.task_config.max_frames, self.patch_num, patch_dim)
video_frame_proto = video_patch_proto.reshape(
bs_pair, self.patch_num * self.task_config.max_frames, patch_dim)
video_frame_proto = self.frame_decoder(video_frame_proto)
video_frame_proto = 0.5 * video_frame_proto + 0.5 * cls_video_feat.reshape(
bs_pair, self.task_config.max_frames, patch_dim)
video_frame_proto = self.event_decoder(video_frame_proto)
video_frame_proto = 0.5 * video_frame_proto + 0.5 * cls_video_feat.reshape(
bs_pair, self.task_config.max_frames, patch_dim).mean(1).unsqueeze(
1).repeat(1, video_frame_proto.shape[1], 1)
return video_patch_proto, video_frame_proto
def get_sequence_visual_output(self,
input_ids,
token_type_ids,
attention_mask,
video,
video_mask,
shaped=False,
video_frame=-1):
if shaped is False:
input_ids = input_ids.view(-1, input_ids.shape[-1])
token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1])
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
video_mask = video_mask.view(-1, video_mask.shape[-1])
video = torch.as_tensor(video).float()
# import pdb;pdb.set_trace()
# b, pair,
bs, ts, channel, h, w = video.shape
video = video.view(bs * ts, channel, h, w)
video_frame = bs * ts
word_feat, text_feat = self.get_sequence_output(
input_ids, token_type_ids, attention_mask, shaped=True)
patch_feat, frame_feat = self.get_visual_output(
video, video_mask, shaped=True, video_frame=video_frame)
return word_feat, text_feat, patch_feat, frame_feat
def _get_cross_output(self, sequence_output, visual_output, attention_mask,
video_mask):
concat_features = torch.cat((sequence_output, visual_output),
                                    dim=1)  # concatenate tokens and frames
concat_mask = torch.cat((attention_mask, video_mask), dim=1)
text_type_ = torch.zeros_like(attention_mask)
video_type_ = torch.ones_like(video_mask)
concat_type = torch.cat((text_type_, video_type_), dim=1)
cross_layers, pooled_output = self.cross(
concat_features,
concat_type,
concat_mask,
output_all_encoded_layers=True)
cross_output = cross_layers[-1]
return cross_output, pooled_output, concat_mask
def _mean_pooling_for_similarity_sequence(self, sequence_output,
attention_mask):
attention_mask_un = attention_mask.to(dtype=torch.float).unsqueeze(-1)
attention_mask_un[:, 0, :] = 0.
sequence_output = sequence_output * attention_mask_un
text_out = torch.sum(
sequence_output, dim=1) / torch.sum(
attention_mask_un, dim=1, dtype=torch.float)
return text_out
def _mean_pooling_for_similarity_visual(
self,
visual_output,
video_mask,
):
video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
visual_output = visual_output * video_mask_un
video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
video_mask_un_sum[video_mask_un_sum == 0.] = 1.
video_out = torch.sum(visual_output, dim=1) / video_mask_un_sum
return video_out
def _mean_pooling_for_similarity(
self,
sequence_output,
visual_output,
attention_mask,
video_mask,
):
text_out = self._mean_pooling_for_similarity_sequence(
sequence_output, attention_mask)
video_out = self._mean_pooling_for_similarity_visual(
visual_output, video_mask)
return text_out, video_out
def get_global_similarity(self, sequence_output, visual_output,
attention_mask, video_mask):
visual_output = visual_output / visual_output.norm(
dim=-1, keepdim=True)
visual_output = self._mean_pooling_for_similarity_visual(
visual_output, video_mask)
visual_output = visual_output / visual_output.norm(
dim=-1, keepdim=True)
sequence_output = sequence_output.squeeze(1)
sequence_output = sequence_output / sequence_output.norm(
dim=-1, keepdim=True)
logit_scale = self.clip.logit_scale.exp()
# retrieve_logits = logit_scale * torch.matmul(sequence_output, visual_output.t())
sim_matrix_global = logit_scale * torch.matmul(sequence_output,
visual_output.t())
return sim_matrix_global
def _cross_similarity(self, sequence_output, visual_output, attention_mask,
video_mask):
sequence_output, visual_output = sequence_output.contiguous(
), visual_output.contiguous()
b_text, s_text, h_text = sequence_output.size()
b_visual, s_visual, h_visual = visual_output.size()
retrieve_logits_list = []
step_size = b_text # set smaller to reduce memory cost
split_size = [step_size] * (b_text // step_size)
release_size = b_text - sum(split_size)
if release_size > 0:
split_size += [release_size]
        # the CLIP text branch returns only the last hidden state
attention_mask = torch.ones(sequence_output.size(0), 1)\
.to(device=attention_mask.device, dtype=attention_mask.dtype)
sequence_output_splits = torch.split(
sequence_output, split_size, dim=0)
attention_mask_splits = torch.split(attention_mask, split_size, dim=0)
for i in range(len(split_size)):
sequence_output_row = sequence_output_splits[i]
attention_mask_row = attention_mask_splits[i]
sequence_output_l = sequence_output_row.unsqueeze(1).repeat(
1, b_visual, 1, 1)
sequence_output_l = sequence_output_l.view(-1, s_text, h_text)
attention_mask_l = attention_mask_row.unsqueeze(1).repeat(
1, b_visual, 1)
attention_mask_l = attention_mask_l.view(-1, s_text)
step_truth = sequence_output_row.size(0)
visual_output_r = visual_output.unsqueeze(0).repeat(
step_truth, 1, 1, 1)
visual_output_r = visual_output_r.view(-1, s_visual, h_visual)
video_mask_r = video_mask.unsqueeze(0).repeat(step_truth, 1, 1)
video_mask_r = video_mask_r.view(-1, s_visual)
cross_output, pooled_output, concat_mask = \
self._get_cross_output(sequence_output_l, visual_output_r, attention_mask_l, video_mask_r)
retrieve_logits_row = self.similarity_dense(pooled_output).squeeze(
-1).view(step_truth, b_visual)
retrieve_logits_list.append(retrieve_logits_row)
retrieve_logits = torch.cat(retrieve_logits_list, dim=0)
return retrieve_logits

View File

@@ -0,0 +1,538 @@
# The implementation is adopted from the CLIP4Clip implementation,
# made publicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip
import hashlib
import os
import urllib
import warnings
from collections import OrderedDict
from typing import Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch import nn
from tqdm import tqdm
_MODELS = {}
_PT_NAME = {'ViT-B/16': 'ViT-B-16.pt'}
def available_models():
"""Returns the names of available CLIP models"""
return list(_MODELS.keys())
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super(Bottleneck, self).__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(
OrderedDict([('-1', nn.AvgPool2d(stride)),
('0',
nn.Conv2d(
inplanes,
planes * self.expansion,
1,
stride=1,
bias=False)),
('1', nn.BatchNorm2d(planes * self.expansion))]))
def forward(self, x: torch.Tensor):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self,
spacial_dim: int,
embed_dim: int,
num_heads: int,
output_dim: int = None):
super(AttentionPool2d, self).__init__()
self.positional_embedding = nn.Parameter(
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1],
x.shape[2] * x.shape[3]).permute(2, 0,
1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x,
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self,
layers,
output_dim,
heads,
input_resolution=224,
width=64):
super(ModifiedResNet, self).__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.conv2 = nn.Conv2d(
width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.conv3 = nn.Conv2d(
width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.avgpool = nn.AvgPool2d(2)
self.relu = nn.ReLU(inplace=True)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
(self.conv3, self.bn3)]:
x = self.relu(bn(conv(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask=None):
super(ResidualAttentionBlock, self).__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model))]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
attn_mask_ = self.attn_mask
if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
attn_mask_ = self.attn_mask(x.size(0)) # LND
attn_mask_ = attn_mask_.to(
dtype=x.dtype, device=x.device) if attn_mask_ is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
def forward(self, x):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(self,
width: int,
layers: int,
heads: int,
attn_mask=None,
use_gc=0):
super(Transformer, self).__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[
ResidualAttentionBlock(width, heads, attn_mask)
for _ in range(layers)
])
self.use_gc = use_gc
def forward(self, x: torch.Tensor):
if self.use_gc > 0:
for blk in self.resblocks:
x = checkpoint.checkpoint(blk, x)
return x
else:
return self.resblocks(x)
class VisualTransformer(nn.Module):
def __init__(self,
input_resolution: int,
patch_size: int,
width: int,
layers: int,
heads: int,
output_dim: int,
linear_patch: str = '2d',
use_gc: int = 0):
super(VisualTransformer, self).__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False)
scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads, use_gc=use_gc)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
# For 3D
assert linear_patch in ['2d', '3d']
self.linear_patch = linear_patch
if self.linear_patch == '3d':
self.conv2 = nn.Conv3d(
in_channels=3,
out_channels=width,
kernel_size=(3, patch_size, patch_size),
stride=(1, patch_size, patch_size),
padding=(1, 0, 0),
bias=False)
def forward(self, x: torch.Tensor, video_frame=-1):
if self.linear_patch == '3d':
assert video_frame != -1
x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2],
x.shape[-1])
x_3d = x_3d.permute(0, 2, 1, 3, 4)
x_3d = self.conv2(x_3d) # shape = [*, width, frame, grid, grid]
x_3d = x_3d.permute(0, 2, 1, 3,
4) # shape = [*, frame, width, grid, grid]
x = x_3d.reshape(
-1, x_3d.shape[-3], x_3d.shape[-2],
x_3d.shape[-1]).contiguous() # shape = [*, width, grid, grid]
else:
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
_x = self.class_embedding.to(x.dtype) + torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
x = torch.cat([_x, x], dim=1)
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
return x
class CLIP(nn.Module):
def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
# vision linear of patch
linear_patch: str = '2d',
use_gc: int = 0):
super(CLIP, self).__init__()
self.context_length = context_length
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width)
else:
vision_heads = vision_width // 64
self.visual = VisualTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim,
linear_patch=linear_patch,
use_gc=use_gc)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(
torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]))
self.initialize_parameters()
def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features**-0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [
self.visual.layer1, self.visual.layer2, self.visual.layer3,
self.visual.layer4
]:
for name, param in resnet_block.named_parameters():
if name.endswith('bn3.weight'):
nn.init.zeros_(param)
proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers)**-0.5)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width)**-0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(
self.text_projection, std=self.transformer.width**-0.5)
def build_attention_mask(self, context_length):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.zeros(context_length, context_length)
mask.fill_(float('-inf'))
mask.triu_(1) # zero out the lower diagonal
return mask
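        # e.g. for context_length = 3 the additive mask is
        #   [[0., -inf, -inf],
        #    [0.,   0., -inf],
        #    [0.,   0.,   0.]]
        # so each token can attend only to itself and to earlier positions.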
@staticmethod
def get_config(model_dir):
model_path = '{}/ViT-B-16.pt'.format(model_dir)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location='cpu').eval()
state_dict = model.state_dict()
except RuntimeError:
state_dict = torch.load(model_path, map_location='cpu')
return state_dict
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image_tokens(self, image, return_hidden=False):
hidden = self.visual(image.type(self.dtype))
hidden = self.visual.ln_post(hidden) @ self.visual.proj
x = hidden[:, 0, :]
if return_hidden:
return x, hidden
return x
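    # In encode_image_tokens, `hidden` holds every projected token, shaped
    # (batch, 1 + grid**2, embed_dim) with the class token first, and `x` is
    # that class token alone, i.e. the usual CLIP image embedding.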
def encode_text(self, text, return_hidden=False, prompt=None):
x = self.token_embedding(text).type(
self.dtype) # [batch_size, n_ctx, d_model]
if prompt:
x = prompt(x)
pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype)
x = x + pos_emd
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
hidden = self.ln_final(x).type(self.dtype) @ self.text_projection
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)]
if return_hidden:
return x, hidden
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(
dim=-1, keepdim=True)
text_features = text_features / text_features.norm(
dim=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()
return logits_per_image, logits_per_text
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(lay):
# l = lay
if isinstance(lay, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
lay.weight.data = lay.weight.data.half()
if lay.bias is not None:
lay.bias.data = lay.bias.data.half()
if isinstance(lay, nn.MultiheadAttention):
for attr in [
*[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
'in_proj_bias', 'bias_k', 'bias_v'
]:
tensor = getattr(lay, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ['text_projection', 'proj']:
if hasattr(lay, name):
attr = getattr(lay, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
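As in the model constructor earlier in this change set (convert_weights(self.clip)), the helper is applied to a freshly built CLIP module to cast its weights to fp16; a minimal sketch with illustrative ViT-B/16-like sizes:

clip_model = CLIP(512, 224, 12, 768, 16, 77, 49408, 512, 8, 12)  # sizes are illustrative
convert_weights(clip_model)  # Linear / Conv / MultiheadAttention weights become fp16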

View File

@@ -0,0 +1,249 @@
from __future__ import absolute_import, division, print_function
import copy
import logging
import math
import os
import shutil
import tarfile
import tempfile
from collections import OrderedDict
import json
import torch
import torch.nn.functional as F
from torch import nn
from .until_config import PreCrossConfig
from .until_module import ACT2FN, LayerNorm, PreTrainedModel
# PRETRAINED_MODEL_ARCHIVE_MAP = {}
# CONFIG_NAME = 'cross_config.json'
# WEIGHTS_NAME = 'cross_pytorch_model.bin'
class CrossConfig(PreCrossConfig):
"""Configuration class to store the configuration of a `CrossModel`.
"""
# pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
# config_name = CONFIG_NAME
# weights_name = WEIGHTS_NAME
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act='gelu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02):
"""Constructs CrossConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`CrossModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(
vocab_size_or_config_json_file, 'r',
encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
else:
            raise ValueError(
                'First argument must be either a vocabulary size (int) '
                'or the path to a pretrained model config file (str)')
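# Example (illustrative): in this change set the cross encoder builds its config
# with CrossConfig.from_dict(...) from the packaged configuration file, but the
# class can also be constructed directly, e.g.
#   cfg = CrossConfig(vocab_size_or_config_json_file=512, num_hidden_layers=4)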
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model))]))
self.ln_2 = LayerNorm(d_model)
self.n_head = n_head
def attention(self, x: torch.Tensor, attn_mask: torch.Tensor):
attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0)
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
def forward(self, para_tuple: tuple):
# x: torch.Tensor, attn_mask: torch.Tensor
# print(para_tuple)
x, attn_mask = para_tuple
x = x + self.attention(self.ln_1(x), attn_mask)
x = x + self.mlp(self.ln_2(x))
return (x, attn_mask)
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(
*[ResidualAttentionBlock(width, heads) for _ in range(layers)])
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
return self.resblocks((x, attn_mask))[0]
class CrossEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(CrossEmbeddings, self).__init__()
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
config.hidden_size)
# self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, concat_embeddings, concat_type=None):
_, seq_length = concat_embeddings.size(0), concat_embeddings.size(1)
# if concat_type is None:
# concat_type = torch.zeros(batch_size, concat_type).to(concat_embeddings.device)
position_ids = torch.arange(
seq_length, dtype=torch.long, device=concat_embeddings.device)
position_ids = position_ids.unsqueeze(0).expand(
concat_embeddings.size(0), -1)
# token_type_embeddings = self.token_type_embeddings(concat_type)
position_embeddings = self.position_embeddings(position_ids)
embeddings = concat_embeddings + position_embeddings # + token_type_embeddings
# embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class CrossPooler(nn.Module):
def __init__(self, config):
super(CrossPooler, self).__init__()
self.ln_pool = LayerNorm(config.hidden_size)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = QuickGELU()
def forward(self, hidden_states, hidden_mask):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
hidden_states = self.ln_pool(hidden_states)
pooled_output = hidden_states[:, 0]
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output)
return pooled_output
class CrossModel(PreTrainedModel):
def initialize_parameters(self):
proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers)**-0.5)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width)**-0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
def __init__(self, config):
super(CrossModel, self).__init__(config)
self.embeddings = CrossEmbeddings(config)
transformer_width = config.hidden_size
transformer_layers = config.num_hidden_layers
transformer_heads = config.num_attention_heads
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
)
self.pooler = CrossPooler(config)
self.apply(self.init_weights)
def build_attention_mask(self, attention_mask):
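        # Turn the (N, L) padding mask into an additive bias of shape (N, L, L):
        # valid positions contribute 0 and padded positions a large negative
        # value, so they vanish after the attention softmax.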
extended_attention_mask = attention_mask.unsqueeze(1)
extended_attention_mask = extended_attention_mask.to(
dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -1000000.0
extended_attention_mask = extended_attention_mask.expand(
-1, attention_mask.size(1), -1)
return extended_attention_mask
def forward(self,
concat_input,
concat_type=None,
attention_mask=None,
output_all_encoded_layers=True):
if attention_mask is None:
attention_mask = torch.ones(
concat_input.size(0), concat_input.size(1))
if concat_type is None:
concat_type = torch.zeros_like(attention_mask)
extended_attention_mask = self.build_attention_mask(attention_mask)
embedding_output = self.embeddings(concat_input, concat_type)
embedding_output = embedding_output.permute(1, 0, 2) # NLD -> LND
embedding_output = self.transformer(embedding_output,
extended_attention_mask)
embedding_output = embedding_output.permute(1, 0, 2) # LND -> NLD
pooled_output = self.pooler(
embedding_output, hidden_mask=attention_mask)
return embedding_output, pooled_output

View File

@@ -0,0 +1,267 @@
# The implementation is adopted from the CLIP4Clip implementation,
# made publicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip
import os
import random
import uuid
from os.path import exists
from tempfile import TemporaryDirectory
from typing import Any, Dict
from urllib.parse import urlparse
import json
import numpy as np
import torch
from decord import VideoReader, cpu
from PIL import Image
from modelscope.hub.file_download import http_get_file
from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.prost.models.modeling import CLIP4Clip
from modelscope.models.multi_modal.prost.models.tokenization_clip import \
SimpleTokenizer as ClipTokenizer
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..dataloaders.rawvideo_util import RawVideoExtractor
logger = get_logger()
@MODELS.register_module(Tasks.text_video_retrieval, module_name=Models.prost)
class ProSTForTVRetrieval(TorchModel):
def __init__(self, model_dir, **kwargs):
super().__init__(model_dir=model_dir, **kwargs)
# model config parameters
with open(
f'{model_dir}/{ModelFile.CONFIGURATION}', 'r',
encoding='utf-8') as json_file:
all_model_config = json.load(json_file)
model_config = all_model_config['paras']
cross_model_config = all_model_config['crossbase']
# print(all_model_config)
# print(cross_model_config)
model_config['model_dir'] = model_dir
self.SPECIAL_TOKEN = {
'CLS_TOKEN': '<|startoftext|>',
'SEP_TOKEN': '<|endoftext|>',
'MASK_TOKEN': '[MASK]',
'UNK_TOKEN': '[UNK]',
'PAD_TOKEN': '[PAD]'
}
self.max_words = model_config['max_words']
self.max_frames = model_config['max_frames']
self.feature_framerate = model_config['feature_framerate']
self.image_resolution = 224
if torch.cuda.is_available():
self.device = model_config['device']
else:
self.device = 'cpu'
self.init_model = f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}'
self.tokenizer = ClipTokenizer(model_dir)
self.rawVideoExtractor = RawVideoExtractor(
frame_rate=self.feature_framerate, size=self.image_resolution)
self.local_transform = self.rawVideoExtractor.transform
self.model = CLIP4Clip.from_pretrained(
cross_config=cross_model_config, task_config=model_config)
if hasattr(self.model, 'module'):
self.model = self.model.module.to(self.device)
else:
self.model = self.model.to(self.device)
if self.init_model:
assert exists(self.init_model)
model_state_dict = torch.load(self.init_model, map_location='cpu')
self.model.load_state_dict(model_state_dict, strict=False)
self.model.to(self.device)
def _get_text(self, caption, tokenizer, enable_zh=False):
if type(caption) is str:
_caption_text, s, e = caption, None, None
elif type(caption) is tuple:
if len(caption) == 3:
_caption_text, s, e = caption
elif len(caption) == 4:
_caption_text, s, e, pos = caption
else:
                raise NotImplementedError(
                    'caption tuple must have 3 or 4 elements')
if isinstance(_caption_text, list):
caption_text = random.choice(_caption_text)
else:
caption_text = _caption_text
if enable_zh:
_token = tokenizer.encode(caption_text)
input_ids = _token.ids
input_mask = _token.attention_mask
segment_ids = _token.type_ids
else:
words = tokenizer.tokenize(caption_text)
words = [self.SPECIAL_TOKEN['CLS_TOKEN']] + words
total_length_with_CLS = self.max_words - 1
if len(words) > total_length_with_CLS:
words = words[:total_length_with_CLS]
words = words + [self.SPECIAL_TOKEN['SEP_TOKEN']]
input_ids = tokenizer.convert_tokens_to_ids(words)
input_mask = [1] * len(input_ids)
segment_ids = [0] * len(input_ids)
while len(input_ids) < self.max_words:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == self.max_words
assert len(input_mask) == self.max_words
assert len(segment_ids) == self.max_words
pairs_text = np.array(input_ids)
pairs_mask = np.array(input_mask)
pairs_segment = np.array(segment_ids)
return pairs_text, pairs_mask, pairs_segment, s, e
def _get_rawvideo_dec(self,
video_path,
rawVideoExtractor,
local_transform,
s=None,
e=None):
video_mask = np.zeros(self.max_frames, dtype=int)
max_video_length = 0
# T x 3 x H x W
video = np.zeros((self.max_frames, 3, rawVideoExtractor.size,
rawVideoExtractor.size),
dtype=float)
if s is None:
start_time, end_time = None, None
else:
start_time = int(s)
end_time = int(e)
start_time = start_time if start_time >= 0. else 0.
end_time = end_time if end_time >= 0. else 0.
if start_time > end_time:
start_time, end_time = end_time, start_time
elif start_time == end_time:
end_time = end_time + 1
url_parsed = urlparse(video_path)
if url_parsed.scheme in ('file', '') and exists(
url_parsed.path): # Possibly a local file
vreader = VideoReader(video_path, ctx=cpu(0))
else:
try:
with TemporaryDirectory() as temporary_cache_dir:
random_str = uuid.uuid4().hex
http_get_file(
url=video_path,
local_dir=temporary_cache_dir,
file_name=random_str,
cookies=None)
temp_file_path = os.path.join(temporary_cache_dir,
random_str)
vreader = VideoReader(temp_file_path, ctx=cpu(0))
except Exception as ex:
                logger.error('failed to load video input: {}'.format(ex))
return video, video_mask
fps = vreader.get_avg_fps()
f_start = 0 if start_time is None else int(start_time * fps)
f_end = int(
min(1000000000 if end_time is None else end_time * fps,
len(vreader) - 1))
num_frames = f_end - f_start + 1
if num_frames > 0:
# L x T x 3 x H x W
sample_fps = int(self.feature_framerate)
t_stride = int(round(float(fps) / sample_fps))
all_pos = list(range(f_start, f_end + 1, t_stride))
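            # Sample frames at roughly `feature_framerate` fps; if that still
            # yields more than `max_frames` frames, subsample them uniformly.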
if len(all_pos) > self.max_frames:
sample_pos = [
all_pos[_] for _ in np.linspace(
0, len(all_pos) - 1, num=self.max_frames, dtype=int)
]
else:
sample_pos = all_pos
patch_images = [
Image.fromarray(f)
for f in vreader.get_batch(sample_pos).asnumpy()
]
patch_images = torch.stack(
[local_transform(img) for img in patch_images])
slice_len = patch_images.shape[0]
max_video_length = max_video_length if max_video_length > slice_len else slice_len
if slice_len < 1:
pass
else:
video[:slice_len, ...] = patch_images
else:
            logger.error('video path: {} error.'.format(video_path))
video_mask[:max_video_length] = [1] * max_video_length
return video, video_mask
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
from modelscope.outputs import OutputKeys
output = {}
if 'video' in input and input['video'] is not None:
video_path = input['video']
video, video_mask = self._get_rawvideo_dec(video_path,
self.rawVideoExtractor,
self.local_transform)
video = torch.unsqueeze(
torch.from_numpy(video), dim=0).to(self.device)
video_mask = torch.unsqueeze(
torch.from_numpy(video_mask), dim=0).to(self.device)
if 'text' in input and input['text'] is not None:
caption = input['text']
pairs_text, pairs_mask, pairs_segment, s, e = self._get_text(
caption, self.tokenizer, enable_zh=False)
input_ids = torch.unsqueeze(
torch.from_numpy(pairs_text), dim=0).to(self.device)
input_mask = torch.unsqueeze(
torch.from_numpy(pairs_mask), dim=0).to(self.device)
segment_ids = torch.unsqueeze(
torch.from_numpy(pairs_segment), dim=0).to(self.device)
phr_feat, sen_feat, obj_feat, eve_feat = self.model.get_sequence_visual_output(
input_ids, segment_ids, input_mask, video, video_mask)
sim_espm, _, sim_oppm, _ = self.model.get_max_similarity_logits(
phr_feat,
sen_feat,
obj_feat,
eve_feat,
input_mask,
video_mask,
shaped=True)
# logger.info('sim: {}'.format(sim_espm))
# logger.info('sim: {}'.format(sim_oppm))
sim_tv = sim_espm + 1.5 * sim_oppm
# logger.info('phrase prototype: {}'.format(phr_feat.shape))
# logger.info('sentence prototype: {}'.format(sen_feat.shape))
# logger.info('object prototype: {}'.format(obj_feat.shape))
# logger.info('event prototype: {}'.format(eve_feat.shape))
output[OutputKeys.TEXTVIDEO_SIM] = sim_tv.cpu().detach().numpy()
output[OutputKeys.PHRASE_PROTOTYPE] = phr_feat.cpu().detach().numpy()
output[OutputKeys.SENTENCE_PROTOTYPE] = sen_feat.cpu().detach().numpy()
output[OutputKeys.OBJECT_PROTOTYPE] = obj_feat.cpu().detach().numpy()
output[OutputKeys.EVENT_PROTOTYPE] = eve_feat.cpu().detach().numpy()
return output
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
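
if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file). It assumes
    # `/path/to/prost_model` holds a downloaded ProST checkpoint containing
    # configuration.json, pytorch_model.bin and the CLIP BPE vocab, and that
    # `video.mp4` is a readable local video; adjust both paths as needed.
    model = ProSTForTVRetrieval(model_dir='/path/to/prost_model')
    result = model({'video': 'video.mp4', 'text': 'a man is skiing'})
    print(result['textvideo_sim'])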

View File

@@ -0,0 +1,161 @@
# The implementation is adopted from the CLIP4Clip implementation,
# made publicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def bytes_to_unicode():
"""
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord('!'),
ord('~') + 1)) + list(range(
ord('¡'),
ord('¬') + 1)) + list(range(ord('®'),
ord('ÿ') + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
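# For example, the space byte (0x20) is mapped to 'Ġ', so every byte has a
# printable stand-in and the mapping can be inverted exactly after BPE.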
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, model_dir):
bpe_path = '{}/bpe_simple_vocab_16e6.txt.gz'.format(model_dir)
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
merges = merges[1:49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + '</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {
'<|startoftext|>': '<|startoftext|>',
'<|endoftext|>': '<|endoftext|>'
}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE)
self.vocab = self.encoder
def bpe(self, token):
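        # Greedily merge the lowest-ranked adjacent symbol pair until no pair
        # in the word is present in the learned merge table.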
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + '</w>', )
pairs = get_pairs(word)
if not pairs:
return token + '</w>'
while True:
bigram = min(
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except Exception:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[
i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b]
for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token]
for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
'utf-8', errors='replace').replace('</w>', ' ')
return text
def tokenize(self, text):
tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b]
for b in token.encode('utf-8'))
tokens.extend(
bpe_token for bpe_token in self.bpe(token).split(' '))
return tokens
def convert_tokens_to_ids(self, tokens):
return [self.encoder[bpe_token] for bpe_token in tokens]
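
if __name__ == '__main__':
    # Minimal round-trip sketch (not part of the original file). It assumes
    # `model_dir` contains the gzipped BPE vocab `bpe_simple_vocab_16e6.txt.gz`
    # shipped with the ProST/CLIP checkpoints; adjust the path as needed.
    tokenizer = SimpleTokenizer(model_dir='/path/to/prost_model')
    ids = tokenizer.encode('a man is skiing')
    print(ids)
    print(tokenizer.decode(ids))  # prints 'a man is skiing ' (with trailing space)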

View File

@@ -0,0 +1,59 @@
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import, division, print_function
import copy
import logging
import os
import shutil
import tarfile
import tempfile
import json
import torch
# from modelscope.utils.logger import get_logger
# logger = get_logger(__name__)
class PreCrossConfig(object):
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = cls(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, 'r', encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n'

View File

@@ -0,0 +1,574 @@
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
import copy
import logging
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from modelscope.models.multi_modal.prost.models.until_config import \
PreCrossConfig
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish}
class LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(LayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class CrossEn(nn.Module):
def __init__(self, config=None):
super(CrossEn, self).__init__()
def forward(self, sim_matrix):
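        # sim_matrix is a (batch x batch) text-video similarity matrix; matched
        # pairs sit on the diagonal, so the loss is the mean negative
        # log-softmax of the diagonal entries (InfoNCE-style contrastive loss).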
logpt = F.log_softmax(sim_matrix, dim=-1)
logpt = torch.diag(logpt)
nce_loss = -logpt
sim_loss = nce_loss.mean()
return sim_loss
class AllGather(torch.autograd.Function):
"""An autograd function that performs allgather on a tensor."""
@staticmethod
def forward(ctx, tensor, args):
if args.world_size == 1:
ctx.rank = args.local_rank
ctx.batch_size = tensor.shape[0]
return tensor
else:
output = [torch.empty_like(tensor) for _ in range(args.world_size)]
torch.distributed.all_gather(output, tensor)
ctx.rank = args.local_rank
ctx.batch_size = tensor.shape[0]
return torch.cat(output, dim=0)
@staticmethod
def backward(ctx, grad_output):
return (
grad_output[ctx.batch_size * ctx.rank:ctx.batch_size
* (ctx.rank + 1)],
None,
)
class AllGather2(torch.autograd.Function):
"""An autograd function that performs allgather on a tensor."""
# https://github.com/PyTorchLightning/lightning-bolts/blob/8d3fbf7782e3d3937ab8a1775a7092d7567f2933/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
@staticmethod
def forward(ctx, tensor, args):
if args.world_size == 1:
ctx.rank = args.local_rank
ctx.batch_size = tensor.shape[0]
return tensor
else:
output = [torch.empty_like(tensor) for _ in range(args.world_size)]
torch.distributed.all_gather(output, tensor)
ctx.rank = args.local_rank
ctx.batch_size = tensor.shape[0]
return torch.cat(output, dim=0)
@staticmethod
def backward(ctx, grad_output):
grad_input = grad_output.clone()
torch.distributed.all_reduce(
grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
return (grad_input[ctx.rank * ctx.batch_size:(ctx.rank + 1)
* ctx.batch_size], None)
class PreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
if not isinstance(config, PreCrossConfig):
raise ValueError(
'Parameter config in `{}(config)` should be an instance of class `PreCrossConfig`. '
'To create a model from a Google pretrained model use '
'`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format(
self.__class__.__name__, self.__class__.__name__))
self.config = config
def init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(
mean=0.0, std=self.config.initializer_range)
elif isinstance(module, LayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def resize_token_embeddings(self, new_num_tokens=None):
raise NotImplementedError
@classmethod
def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
if prefix is not None:
old_keys = []
new_keys = []
for key in state_dict.keys():
old_keys.append(key)
new_keys.append(prefix + key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata,
True, missing_keys, unexpected_keys,
error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='')
# if prefix is None and (task_config is None or task_config.local_rank == 0):
# logger.info("-" * 20)
# if len(missing_keys) > 0:
# logger.info("Weights of {} not initialized from pretrained model: {}"
# .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
# if len(unexpected_keys) > 0:
# logger.info("Weights from pretrained model not used in {}: {}"
# .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
# if len(error_msgs) > 0:
# logger.error("Weights from pretrained model cause errors in {}: {}"
# .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
return model
@property
def dtype(self):
"""
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
"""
try:
return next(self.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: nn.Module):
tuples = [(k, v) for k, v in module.__dict__.items()
if torch.is_tensor(v)]
return tuples
gen = self._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
@classmethod
def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
"""
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
return model
model = cls.init_preweight(model, state_dict)
return model
class PatchShiftModule(nn.Module):
def __init__(self, net, video_frame, n_div):
super().__init__()
self.net = net
self.video_frame = video_frame
self.n_div = n_div
def forward(self,
query,
key,
value,
key_padding_mask=None,
need_weights=True,
attn_mask=None):
# here q == k == v, psm means patch shift output
x = query # shape here is LND, not NLD (50, 384, 768)
x = x.permute(1, 0, 2) # LND -> NLD
patch_len = x.shape[-2]
fold = patch_len // self.n_div
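        # Temporal patch shift: every `fold`-th patch position is shifted one
        # frame backward and another interleaved set one frame forward, so the
        # following self-attention mixes information across neighbouring frames.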
x = x.reshape(-1, self.video_frame, x.shape[-2],
x.shape[-1]) # shape = [bs, frame, grid ** 2, width]
psm = torch.zeros_like(x) # shape = [bs, frame, grid ** 2, width]
psm[:, :, :, :] = x[:, :, :, :]
lshift_indices = torch.arange(start=1, end=patch_len, step=fold)
psm[:, 1:, lshift_indices, :] = x[:, :-1,
lshift_indices, :] # f_t = f_t-1
rshift_indices = torch.arange(start=1 + 3, end=patch_len, step=fold)
psm[:, :-1, rshift_indices, :] = x[:, 1:,
rshift_indices, :] # f_t = f_t+1
x = psm.reshape(-1, patch_len, x.shape[-1])
x = x.permute(1, 0, 2) # NLD -> LND
return self.net(
x, x, x, need_weights=need_weights, attn_mask=attn_mask)
def make_patch_shift(net, video_frame=12, shift_layers=4, n_div=7):
'''
Args:
        net: the CLIP model whose visual transformer blocks are patched
        video_frame: number of frames per video clip (must be known in advance)
        shift_layers: number of transformer layers to apply patch shift to
'''
def make_trans_patch_shift(stage, shift_layers):
blocks = list(stage.children())
for i, b in enumerate(blocks):
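            # Note: only blocks 10 and 11 (the last two layers of a 12-layer
            # visual transformer) are wrapped; `shift_layers` is currently unused.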
if i >= 10 and i <= 11:
blocks[i].attn = PatchShiftModule(
b.attn, video_frame=video_frame, n_div=n_div)
return nn.Sequential(*blocks)
net.clip.visual.transformer.resblocks = make_trans_patch_shift(
net.clip.visual.transformer.resblocks, shift_layers=shift_layers)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class Event_Layer(nn.Module):
def __init__(self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation='relu',
normalize_before=False,
is_weights=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.self_attn_vis = nn.MultiheadAttention(
d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.norm4 = nn.LayerNorm(d_model)
self.norm5 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = nn.ReLU(inplace=True)
self.normalize_before = normalize_before
self.is_weights = is_weights
def forward(self, tgt, memory, pos=None, query_pos=None):
tgt = self.norm1(tgt)
memory = self.norm2(memory)
tgt = self.self_attn(tgt, tgt, tgt)[0]
tgt = self.norm3(tgt)
tgt2, atten_weights = self.multihead_attn(tgt, memory, memory)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm4(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm5(tgt)
return tgt, atten_weights
def adaptive_mask(aa, bb, ada_para):
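    # Build a banded boolean attention mask of shape (aa, bb): each of the aa
    # query rows may only attend to a contiguous window of the bb keys, widened
    # by an extra `ada_para` fraction of bb. True means "blocked", following the
    # nn.MultiheadAttention attn_mask convention.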
tensor = torch.zeros((aa, bb))
adaptive_num = int(bb * ada_para)
cc = int(bb / aa)
for i in range(aa):
start_col = i * cc
end_col = start_col + cc + adaptive_num
if end_col > bb - 1:
tmp = end_col - (bb - 1)
start_col = start_col - tmp
if start_col < 0:
start_col = 0
end_col = bb
tensor[i, start_col:end_col] = 1
tensor = ~tensor.bool()
return tensor
class Frame_Layer(nn.Module):
def __init__(self,
d_model,
nhead,
dim_feedforward=2048,
para=1.0,
dropout=0.1,
activation='relu',
normalize_before=False,
is_weights=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.self_attn_vis = nn.MultiheadAttention(
d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.norm4 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = nn.ReLU(inplace=True)
self.normalize_before = normalize_before
self.is_weights = is_weights
self.mask_para = para
def forward(self, tgt, memory, pos=None, query_pos=None):
tgt = self.norm1(tgt)
memory = self.norm2(memory)
mask_new = adaptive_mask(tgt.shape[0], memory.shape[0], ada_para=0.2)
tgt2, atten_weights = self.multihead_attn(
tgt, memory, memory, attn_mask=mask_new.cuda())
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm4(tgt)
return tgt, atten_weights
class TransDecoder(nn.Module):
def __init__(self,
decoder_layer,
num_layers,
norm=None,
return_intermediate=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(self, tgt, memory, pos=None, query_pos=None):
output = tgt
intermediate = []
all_weights = []
for layer in self.layers:
output, weights = layer(
output, memory, pos=pos, query_pos=query_pos)
if self.return_intermediate:
intermediate.append(self.norm(output))
all_weights.append(weights)
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate), torch.stack(all_weights)
return output.unsqueeze(0)
class Event_decoder(nn.Module):
def __init__(self,
num_attris=3,
layers=1,
heads=1,
dim_ftr=512,
pos_emb=False,
length=1,
dim_feedforward=512,
without_init=False):
super().__init__()
embedding_dim = dim_ftr
d_model = dim_ftr
dim_feedforward = dim_feedforward
self.V = nn.Parameter(
torch.Tensor(num_attris, dim_feedforward), requires_grad=True)
nn.init.xavier_uniform_(self.V)
decoder_layer = Event_Layer(
d_model=d_model, nhead=heads, dim_feedforward=dim_feedforward)
self.event_decoder = TransDecoder(
decoder_layer,
layers,
nn.LayerNorm(d_model),
return_intermediate=True)
self.use_pos_enc = pos_emb
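        # Note: `positionalencoding2d` is assumed to be provided elsewhere in
        # the package; this branch is only exercised when pos_emb=True.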
if self.use_pos_enc:
self.position_encoding_pre = positionalencoding2d(
embedding_dim, 14, 14).unsqueeze(0)
def forward(self, features):
batch_size = features.shape[0]
if self.use_pos_enc: # False
pos_encoding = self.position_encoding_pre(
features,
torch.zeros(features.shape[0], 14, 14,
dtype=torch.bool).cuda())
features = features + pos_encoding
enco_others = features.permute(1, 0, 2)
h_attr = self.V
h_attr_batch = h_attr.unsqueeze(0).repeat(batch_size, 1, 1)
h_attr_batch = h_attr_batch.permute(1, 0, 2)
hs, _ = self.event_decoder(h_attr_batch, enco_others)
hs = hs[-1].permute(1, 0, 2)
return hs
class Frame_decoder(nn.Module):
def __init__(self,
num_attris=3,
layers=1,
heads=1,
dim_ftr=512,
pos_emb=False,
length=1,
dim_feedforward=512,
without_init=False):
super().__init__()
embedding_dim = dim_ftr
d_model = dim_ftr
dim_feedforward = dim_feedforward
self.V = nn.Parameter(
torch.Tensor(num_attris, dim_feedforward), requires_grad=True)
nn.init.xavier_uniform_(self.V)
decoder_layer = Frame_Layer(
d_model=d_model, nhead=heads, dim_feedforward=dim_feedforward)
self.event_decoder = TransDecoder(
decoder_layer,
layers,
nn.LayerNorm(d_model),
return_intermediate=True)
self.use_pos_enc = pos_emb
if self.use_pos_enc:
self.position_encoding_pre = positionalencoding2d(
embedding_dim, 14, 14).unsqueeze(0)
def forward(self, features):
batch_size = features.shape[0]
if self.use_pos_enc:
pos_encoding = self.position_encoding_pre(
features,
torch.zeros(features.shape[0], 14, 14,
dtype=torch.bool).cuda())
features = features + pos_encoding
enco_others = features.permute(1, 0, 2)
h_attr = self.V
h_attr_batch = h_attr.unsqueeze(0).repeat(batch_size, 1, 1)
h_attr_batch = h_attr_batch.permute(1, 0, 2)
hs, _ = self.event_decoder(h_attr_batch, enco_others)
hs = hs[-1].permute(1, 0, 2)
return hs

View File

@@ -33,10 +33,12 @@ from .backbone import MsModelMixin
def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]],
max_length: int, tokenizer):
system_prompt = f'<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n'
system_ids = tokenizer(system_prompt, return_tensors='pt').input_ids
system_ids = tokenizer(
system_prompt, add_special_tokens=False, return_tensors='pt').input_ids
text_prompt = f'{text.strip()} [/INST]'
text_ids = tokenizer(text_prompt, return_tensors='pt').input_ids
text_ids = tokenizer(
text_prompt, add_special_tokens=False, return_tensors='pt').input_ids
prompt_length = system_ids.shape[-1] + text_ids.shape[-1]
if prompt_length > max_length:
@@ -51,7 +53,9 @@ def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]],
assert isinstance(user, str)
assert isinstance(bot, str)
round_prompt = f'{user.strip()} [/INST] {bot.strip()} </s><s>[INST] '
round_ids = tokenizer(round_prompt, return_tensors='pt').input_ids
round_ids = tokenizer(
round_prompt, add_special_tokens=False,
return_tensors='pt').input_ids
if prompt_length + round_ids.shape[-1] > max_length:
# excess history should not be appended to the prompt
break

View File

@@ -26,9 +26,9 @@ class PolyLMForTextGeneration(TorchModel, StreamingOutputMixin):
"""
super().__init__(model_dir, *args, **kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(
model_dir, use_fast=False)
model_dir, legacy=False, use_fast=False)
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, device_map='auto')
model_dir, device_map='auto', trust_remote_code=True)
self.model.eval()
def forward(self, input: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:

View File

@@ -113,7 +113,7 @@ class ReferringVideoObjectSegmentationDataset(TorchCustomDataset):
instance_masks = instance_masks[np.newaxis, ...]
instance_masks = torch.tensor(instance_masks).transpose(1, 2)
mask_rles = [encode(mask) for mask in instance_masks.numpy()]
mask_areas = area(mask_rles).astype(np.float)
mask_areas = area(mask_rles).astype(float)
f.close()
# create the target dict for the center frame:

View File

@@ -0,0 +1,118 @@
import os
import torch
from torch import nn
from torch.autograd import Function
from torch.nn import functional as F
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
fused = load(
'fused',
sources=[
os.path.join(module_path, 'fused_bias_act.cpp'),
os.path.join(module_path, 'fused_bias_act_kernel.cu'),
],
)
class FusedLeakyReLUFunctionBackward(Function):
@staticmethod
def forward(ctx, grad_output, out, bias, negative_slope, scale):
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
empty = grad_output.new_empty(0)
grad_input = fused.fused_bias_act(grad_output, empty, out, 3, 1,
negative_slope, scale)
dim = [0]
if grad_input.ndim > 2:
dim += list(range(2, grad_input.ndim))
if bias:
grad_bias = grad_input.sum(dim).detach()
else:
grad_bias = empty
return grad_input, grad_bias
@staticmethod
def backward(ctx, gradgrad_input, gradgrad_bias):
out, = ctx.saved_tensors
gradgrad_out = fused.fused_bias_act(gradgrad_input, gradgrad_bias, out,
3, 1, ctx.negative_slope,
ctx.scale)
return gradgrad_out, None, None, None, None
class FusedLeakyReLUFunction(Function):
@staticmethod
def forward(ctx, input, bias, negative_slope, scale):
empty = input.new_empty(0)
ctx.bias = bias is not None
if bias is None:
bias = empty
out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope,
scale)
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
return out
@staticmethod
def backward(ctx, grad_output):
out, = ctx.saved_tensors
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale)
if not ctx.bias:
grad_bias = None
return grad_input, grad_bias, None, None
class FusedLeakyReLU(nn.Module):
def __init__(self, channel, bias=True, negative_slope=0.2, scale=2**0.5):
super().__init__()
if bias:
self.bias = nn.Parameter(torch.zeros(channel))
else:
self.bias = None
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
return fused_leaky_relu(input, self.bias, self.negative_slope,
self.scale)
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
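    # On CPU fall back to a plain bias-add + leaky_relu * scale; on GPU use the
    # fused CUDA kernel, which also provides the custom double-backward path.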
if input.device.type == 'cpu':
if bias is not None:
rest_dim = [1] * (input.ndim - bias.ndim - 1)
return (F.leaky_relu(
input + bias.view(1, bias.shape[0], *rest_dim),
negative_slope=0.2) * scale)
else:
return F.leaky_relu(input, negative_slope=0.2) * scale
else:
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)

View File

@@ -0,0 +1,21 @@
#include <torch/extension.h>
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
int act, int grad, float alpha, float scale);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
int act, int grad, float alpha, float scale) {
CHECK_CUDA(input);
CHECK_CUDA(bias);
return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}

View File

@@ -0,0 +1,99 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
scalar_t zero = 0.0;
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
scalar_t x = p_x[xi];
if (use_bias) {
x += p_b[(xi / step_b) % size_b];
}
scalar_t ref = use_ref ? p_ref[xi] : zero;
scalar_t y;
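    // act * 10 + grad selects the operation: act 1 is linear, act 3 is leaky
    // ReLU; grad 0 computes the forward value, grad 1 the first derivative
    // (gated by the sign of the saved reference), grad 2 the second derivative.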
switch (act * 10 + grad) {
default:
case 10: y = x; break;
case 11: y = x; break;
case 12: y = 0.0; break;
case 30: y = (x > 0.0) ? x : x * alpha; break;
case 31: y = (ref > 0.0) ? x : x * alpha; break;
case 32: y = 0.0; break;
}
out[xi] = y * scale;
}
}
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
int act, int grad, float alpha, float scale) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
auto x = input.contiguous();
auto b = bias.contiguous();
auto ref = refer.contiguous();
int use_bias = b.numel() ? 1 : 0;
int use_ref = ref.numel() ? 1 : 0;
int size_x = x.numel();
int size_b = b.numel();
int step_b = 1;
for (int i = 1 + 1; i < x.dim(); i++) {
step_b *= x.size(i);
}
int loop_x = 4;
int block_size = 4 * 32;
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
auto y = torch::empty_like(x);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
y.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
b.data_ptr<scalar_t>(),
ref.data_ptr<scalar_t>(),
act,
grad,
alpha,
scale,
loop_x,
size_x,
step_b,
size_b,
use_bias,
use_ref
);
});
return y;
}

View File

@@ -0,0 +1,23 @@
#include <torch/extension.h>
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
CHECK_CUDA(input);
CHECK_CUDA(kernel);
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}

View File

@@ -0,0 +1,208 @@
import os
from collections import abc
import torch
from torch.autograd import Function
from torch.nn import functional as F
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
'upfirdn2d',
sources=[
os.path.join(module_path, 'upfirdn2d.cpp'),
os.path.join(module_path, 'upfirdn2d_kernel.cu'),
],
)
class UpFirDn2dBackward(Function):
@staticmethod
def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
in_size, out_size):
up_x, up_y = up
down_x, down_y = down
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
grad_input = upfirdn2d_op.upfirdn2d(
grad_output,
grad_kernel,
down_x,
down_y,
up_x,
up_y,
g_pad_x0,
g_pad_x1,
g_pad_y0,
g_pad_y1,
)
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
in_size[3])
ctx.save_for_backward(kernel)
pad_x0, pad_x1, pad_y0, pad_y1 = pad
ctx.up_x = up_x
ctx.up_y = up_y
ctx.down_x = down_x
ctx.down_y = down_y
ctx.pad_x0 = pad_x0
ctx.pad_x1 = pad_x1
ctx.pad_y0 = pad_y0
ctx.pad_y1 = pad_y1
ctx.in_size = in_size
ctx.out_size = out_size
return grad_input
@staticmethod
def backward(ctx, gradgrad_input):
kernel, = ctx.saved_tensors
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
ctx.in_size[3], 1)
gradgrad_out = upfirdn2d_op.upfirdn2d(
gradgrad_input,
kernel,
ctx.up_x,
ctx.up_y,
ctx.down_x,
ctx.down_y,
ctx.pad_x0,
ctx.pad_x1,
ctx.pad_y0,
ctx.pad_y1,
)
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
ctx.out_size[0], ctx.out_size[1])
return gradgrad_out, None, None, None, None, None, None, None, None
class UpFirDn2d(Function):
@staticmethod
def forward(ctx, input, kernel, up, down, pad):
up_x, up_y = up
down_x, down_y = down
pad_x0, pad_x1, pad_y0, pad_y1 = pad
kernel_h, kernel_w = kernel.shape
batch, channel, in_h, in_w = input.shape
ctx.in_size = input.shape
input = input.reshape(-1, in_h, in_w, 1)
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
ctx.out_size = (out_h, out_w)
ctx.up = (up_x, up_y)
ctx.down = (down_x, down_y)
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
g_pad_x0 = kernel_w - pad_x0 - 1
g_pad_y0 = kernel_h - pad_y0 - 1
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
out = upfirdn2d_op.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y,
pad_x0, pad_x1, pad_y0, pad_y1)
# out = out.view(major, out_h, out_w, minor)
out = out.view(-1, channel, out_h, out_w)
return out
@staticmethod
def backward(ctx, grad_output):
kernel, grad_kernel = ctx.saved_tensors
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = UpFirDn2dBackward.apply(
grad_output,
kernel,
grad_kernel,
ctx.up,
ctx.down,
ctx.pad,
ctx.g_pad,
ctx.in_size,
ctx.out_size,
)
return grad_input, None, None, None, None
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
if not isinstance(up, abc.Iterable):
up = (up, up)
if not isinstance(down, abc.Iterable):
down = (down, down)
if len(pad) == 2:
pad = (pad[0], pad[1], pad[0], pad[1])
if input.device.type == 'cpu':
out = upfirdn2d_native(input, kernel, *up, *down, *pad)
else:
out = UpFirDn2d.apply(input, kernel, up, down, pad)
return out
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
pad_y0, pad_y1):
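    # Pure-PyTorch reference path: zero-upsample by (up_x, up_y), zero-pad,
    # convolve with the flipped FIR kernel, then subsample by (down_x, down_y).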
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(
out,
[0, 0,
max(pad_x0, 0),
max(pad_x1, 0),
max(pad_y0, 0),
max(pad_y1, 0)])
out = out[:,
max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
out = out.permute(0, 3, 1, 2)
out = out.reshape(
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
return out.view(-1, channel, out_h, out_w)

View File

@@ -0,0 +1,369 @@
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
int c = a / b;
if (c * b > a) {
c--;
}
return c;
}
struct UpFirDn2DKernelParams {
int up_x;
int up_y;
int down_x;
int down_y;
int pad_x0;
int pad_x1;
int pad_y0;
int pad_y1;
int major_dim;
int in_h;
int in_w;
int minor_dim;
int kernel_h;
int kernel_w;
int out_h;
int out_w;
int loop_major;
int loop_x;
};
template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
int out_y = minor_idx / p.minor_dim;
minor_idx -= out_y * p.minor_dim;
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
int major_idx_base = blockIdx.z * p.loop_major;
if (out_x_base >= p.out_w || out_y >= p.out_h ||
major_idx_base >= p.major_dim) {
return;
}
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major && major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, out_x = out_x_base;
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
const scalar_t *x_p =
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
minor_idx];
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
int x_px = p.minor_dim;
int k_px = -p.up_x;
int x_py = p.in_w * p.minor_dim;
int k_py = -p.up_y * p.kernel_w;
scalar_t v = 0.0f;
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
x_p += x_px;
k_p += k_px;
}
x_p += x_py - w * x_px;
k_p += k_py - w * k_px;
}
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
__shared__ volatile float sk[kernel_h][kernel_w];
__shared__ volatile float sx[tile_in_h][tile_in_w];
int minor_idx = blockIdx.x;
int tile_out_y = minor_idx / p.minor_dim;
minor_idx -= tile_out_y * p.minor_dim;
tile_out_y *= tile_out_h;
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
int major_idx_base = blockIdx.z * p.loop_major;
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
major_idx_base >= p.major_dim) {
return;
}
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
tap_idx += blockDim.x) {
int ky = tap_idx / kernel_w;
int kx = tap_idx - ky * kernel_w;
scalar_t v = 0.0;
if (kx < p.kernel_w & ky < p.kernel_h) {
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
}
sk[ky][kx] = v;
}
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major & major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, tile_out_x = tile_out_x_base;
loop_x < p.loop_x & tile_out_x < p.out_w;
loop_x++, tile_out_x += tile_out_w) {
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
int tile_in_x = floor_div(tile_mid_x, up_x);
int tile_in_y = floor_div(tile_mid_y, up_y);
__syncthreads();
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
in_idx += blockDim.x) {
int rel_in_y = in_idx / tile_in_w;
int rel_in_x = in_idx - rel_in_y * tile_in_w;
int in_x = rel_in_x + tile_in_x;
int in_y = rel_in_y + tile_in_y;
scalar_t v = 0.0;
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
p.minor_dim +
minor_idx];
}
sx[rel_in_y][rel_in_x] = v;
}
__syncthreads();
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
out_idx += blockDim.x) {
int rel_out_y = out_idx / tile_out_w;
int rel_out_x = out_idx - rel_out_y * tile_out_w;
int out_x = rel_out_x + tile_out_x;
int out_y = rel_out_y + tile_out_y;
int mid_x = tile_mid_x + rel_out_x * down_x;
int mid_y = tile_mid_y + rel_out_y * down_y;
int in_x = floor_div(mid_x, up_x);
int in_y = floor_div(mid_y, up_y);
int rel_in_x = in_x - tile_in_x;
int rel_in_y = in_y - tile_in_y;
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
scalar_t v = 0.0;
#pragma unroll
for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
for (int x = 0; x < kernel_w / up_x; x++)
v += sx[rel_in_y + y][rel_in_x + x] *
sk[kernel_y + y * up_y][kernel_x + x * up_x];
if (out_x < p.out_w & out_y < p.out_h) {
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
}
}
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
const torch::Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
UpFirDn2DKernelParams p;
auto x = input.contiguous();
auto k = kernel.contiguous();
p.major_dim = x.size(0);
p.in_h = x.size(1);
p.in_w = x.size(2);
p.minor_dim = x.size(3);
p.kernel_h = k.size(0);
p.kernel_w = k.size(1);
p.up_x = up_x;
p.up_y = up_y;
p.down_x = down_x;
p.down_y = down_y;
p.pad_x0 = pad_x0;
p.pad_x1 = pad_x1;
p.pad_y0 = pad_y0;
p.pad_y1 = pad_y1;
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
p.down_y;
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
p.down_x;
auto out =
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
int mode = -1;
int tile_out_h = -1;
int tile_out_w = -1;
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 1;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 3 && p.kernel_w <= 3) {
mode = 2;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 3;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 4;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 5;
tile_out_h = 8;
tile_out_w = 32;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 6;
tile_out_h = 8;
tile_out_w = 32;
}
dim3 block_size;
dim3 grid_size;
if (tile_out_h > 0 && tile_out_w > 0) {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 1;
block_size = dim3(32 * 8, 1, 1);
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
(p.major_dim - 1) / p.loop_major + 1);
} else {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 4;
block_size = dim3(4, 32, 1);
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
(p.major_dim - 1) / p.loop_major + 1);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
switch (mode) {
case 1:
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 2:
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 3:
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 4:
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 5:
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 6:
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
default:
upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
}
});
return out;
}

View File

@@ -48,6 +48,11 @@ class OutputKeys(object):
PROBABILITIES = 'probabilities'
DIALOG_STATES = 'dialog_states'
VIDEO_EMBEDDING = 'video_embedding'
PHRASE_PROTOTYPE = 'phrase_prototype'
OBJECT_PROTOTYPE = 'object_prototype'
SENTENCE_PROTOTYPE = 'sentence_prototype'
EVENT_PROTOTYPE = 'event_prototype'
TEXTVIDEO_SIM = 'textvideo_sim'
UUID = 'uuid'
WORD = 'word'
KWS_LIST = 'kws_list'
@@ -90,9 +95,9 @@ OutputTypes = {
OutputKeys.OUTPUT_IMG: 'image', # checked
OutputKeys.OUTPUT_IMGS: List[np.ndarray], # checked
OutputKeys.OUTPUT_VIDEO: 'bytes',
OutputKeys.OUTPUT_PCM: np.ndarray,
OutputKeys.OUTPUT_PCM: 'pcm',
OutputKeys.OUTPUT_PCM_LIST: List[np.ndarray],
OutputKeys.OUTPUT_WAV: np.ndarray,
OutputKeys.OUTPUT_WAV: 'pcm',
OutputKeys.OUTPUT_OBJ: Dict,
OutputKeys.OUTPUT_MESH: np.ndarray,
OutputKeys.IMG_EMBEDDING: np.ndarray,
@@ -106,6 +111,11 @@ OutputTypes = {
OutputKeys.PROBABILITIES: np.ndarray,
OutputKeys.DIALOG_STATES: object,
OutputKeys.VIDEO_EMBEDDING: np.ndarray,
OutputKeys.PHRASE_PROTOTYPE: np.ndarray,
OutputKeys.OBJECT_PROTOTYPE: np.ndarray,
OutputKeys.SENTENCE_PROTOTYPE: np.ndarray,
OutputKeys.EVENT_PROTOTYPE: np.ndarray,
OutputKeys.TEXTVIDEO_SIM: np.ndarray,
OutputKeys.UUID: str,
OutputKeys.WORD: str,
OutputKeys.KWS_LIST: List[str],
@@ -329,6 +339,24 @@ OutputTypeSchema = {
'type': 'number'
}
},
OutputKeys.PHRASE_PROTOTYPE: {
'type': 'array',
'items': {
'type': 'number'
}
},
OutputKeys.OBJECT_PROTOTYPE: {
'type': 'array',
'items': {
'type': 'number'
}
},
OutputKeys.TEXTVIDEO_SIM: {
'type': 'array',
'items': {
'type': 'number'
}
},
OutputKeys.UUID: {
'type': 'string'
},
@@ -688,6 +716,8 @@ TASK_OUTPUTS = {
# }
Tasks.portrait_matting: [OutputKeys.OUTPUT_IMG],
Tasks.universal_matting: [OutputKeys.OUTPUT_IMG],
Tasks.image_deblurring: [OutputKeys.OUTPUT_IMG],
Tasks.image_face_fusion: [OutputKeys.OUTPUT_IMG],
# image_quality_assessment_mos result for a single image is a score in range [0, 1]
# {0.5}
@@ -700,6 +730,7 @@ TASK_OUTPUTS = {
Tasks.image_colorization: [OutputKeys.OUTPUT_IMG],
Tasks.image_color_enhancement: [OutputKeys.OUTPUT_IMG],
Tasks.image_denoising: [OutputKeys.OUTPUT_IMG],
Tasks.image_editing: [OutputKeys.OUTPUT_IMG],
Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG],
Tasks.crowd_counting: [OutputKeys.SCORES, OutputKeys.OUTPUT_IMG],
Tasks.image_inpainting: [OutputKeys.OUTPUT_IMG],
@@ -721,6 +752,7 @@ TASK_OUTPUTS = {
Tasks.video_deinterlace: [OutputKeys.OUTPUT_VIDEO],
Tasks.nerf_recon_acc: [OutputKeys.OUTPUT],
Tasks.nerf_recon_vq_compression: [OutputKeys.OUTPUT],
Tasks.surface_recon_common: [OutputKeys.OUTPUT],
Tasks.video_colorization: [OutputKeys.OUTPUT_VIDEO],
# image quality assessment degradation result for single image
@@ -914,6 +946,32 @@ TASK_OUTPUTS = {
# }
Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING],
# phrase prototype result for single sentence
# {
# "phrase_prototype": np.array with shape [K*D],
# }
# sentence prototype result for single sentence
# {
# "sentence_prototype": np.array with shape [1*D],
# }
# object prototype result for single video
# {
# "object_prototype": np.array with shape [N*K*D],
# }
# event prototype result for single video
# {
# "event_prototype": np.array with shape [N*M*D],
# }
# text search video result for single sentence
# {
# "textvideo_sim": np.array with shape [N*N],
# }
Tasks.text_video_retrieval: [
OutputKeys.PHRASE_PROTOTYPE, OutputKeys.SENTENCE_PROTOTYPE,
OutputKeys.OBJECT_PROTOTYPE, OutputKeys.EVENT_PROTOTYPE,
OutputKeys.TEXTVIDEO_SIM
],
# video stabilization task result for a single video
# {"output_video": "path_to_rendered_video"}
Tasks.video_stabilization: [OutputKeys.OUTPUT_VIDEO],
@@ -1512,6 +1570,11 @@ TASK_OUTPUTS = {
# "output_img": np.ndarray with shape [height, width, 3]
# }
Tasks.image_try_on: [OutputKeys.OUTPUT_IMG],
# Tasks.human_image_generation result for a single sample
# {
# "output_img": np.ndarray with shape [height, width, 3]
# }
Tasks.human_image_generation: [OutputKeys.OUTPUT_IMG],
}

View File

@@ -102,6 +102,18 @@ TASK_INPUTS = {
InputType.IMAGE,
Tasks.face_2d_keypoints:
InputType.IMAGE,
Tasks.face_liveness:
InputType.IMAGE,
Tasks.face_quality_assessment:
InputType.IMAGE,
Tasks.card_detection:
InputType.IMAGE,
Tasks.license_plate_detection:
InputType.IMAGE,
Tasks.lineless_table_recognition:
InputType.IMAGE,
Tasks.table_recognition:
InputType.IMAGE,
Tasks.face_detection:
InputType.IMAGE,
Tasks.facial_expression_recognition:
@@ -118,14 +130,30 @@ TASK_INPUTS = {
InputType.NUMBER,
Tasks.image_classification:
InputType.IMAGE,
Tasks.image_quality_assessment_mos:
InputType.IMAGE,
Tasks.image_quality_assessment_degradation:
InputType.IMAGE,
Tasks.image_object_detection:
InputType.IMAGE,
Tasks.domain_specific_object_detection:
InputType.IMAGE,
Tasks.human_wholebody_keypoint:
InputType.IMAGE,
Tasks.image_segmentation:
InputType.IMAGE,
Tasks.portrait_matting:
InputType.IMAGE,
Tasks.universal_matting:
InputType.IMAGE,
Tasks.product_segmentation:
InputType.IMAGE,
Tasks.semantic_segmentation:
InputType.IMAGE,
Tasks.face_human_hand_detection:
InputType.IMAGE,
Tasks.hand_static:
InputType.IMAGE,
Tasks.image_fewshot_detection:
InputType.IMAGE,
Tasks.open_vocabulary_detection: {
@@ -148,6 +176,8 @@ TASK_INPUTS = {
InputType.IMAGE,
Tasks.image_denoising:
InputType.IMAGE,
Tasks.image_body_reshaping:
InputType.IMAGE,
Tasks.image_portrait_enhancement:
InputType.IMAGE,
Tasks.crowd_counting:
@@ -169,6 +199,12 @@ TASK_INPUTS = {
'image': InputType.IMAGE,
'prompt': InputType.TEXT,
},
Tasks.image_face_fusion: {
'template': InputType.IMAGE,
'user': InputType.IMAGE,
},
Tasks.image_deblurring:
InputType.IMAGE,
Tasks.video_colorization:
InputType.VIDEO,
@@ -227,6 +263,10 @@ TASK_INPUTS = {
InputKeys.IMAGE: InputType.IMAGE,
InputKeys.IMAGE: InputType.IMAGE
},
Tasks.human_image_generation: {
InputKeys.IMAGE: InputType.IMAGE,
'target_pose_path': InputType.TEXT
},
# ============ nlp tasks ===================
Tasks.chat: [
@@ -254,11 +294,15 @@ TASK_INPUTS = {
Tasks.nli: (InputType.TEXT, InputType.TEXT),
Tasks.sentiment_classification:
InputType.TEXT,
Tasks.zero_shot_classification: InputType.TEXT,
Tasks.zero_shot_classification:
InputType.TEXT,
Tasks.relation_extraction:
InputType.TEXT,
Tasks.translation:
InputType.TEXT,
Tasks.text_summarization: [InputType.TEXT, {
'text': InputType.TEXT,
}],
Tasks.competency_aware_translation:
InputType.TEXT,
Tasks.word_segmentation: [InputType.TEXT, {
@@ -348,12 +392,17 @@ TASK_INPUTS = {
InputType.AUDIO,
Tasks.speaker_diarization_dialogue_detection:
InputType.TEXT,
Tasks.language_score_prediction:
InputType.TEXT,
Tasks.punctuation:
InputType.TEXT,
Tasks.speech_language_recognition:
InputType.AUDIO,
Tasks.speaker_diarization_semantic_speaker_turn_detection:
InputType.TEXT,
Tasks.inverse_text_processing:
InputType.TEXT,
Tasks.speaker_verification: [InputType.AUDIO, InputType.AUDIO],
# ============ multi-modal tasks ===================
Tasks.image_captioning: [InputType.IMAGE, {
@@ -384,6 +433,10 @@ TASK_INPUTS = {
'img': InputType.IMAGE,
'text': InputType.TEXT
},
Tasks.text_video_retrieval: {
'video': InputType.VIDEO,
'text': InputType.TEXT
},
Tasks.visual_question_answering: {
'image': InputType.IMAGE,
'text': InputType.TEXT
@@ -415,4 +468,8 @@ TASK_INPUTS = {
Tasks.text_to_360panorama_image: {
'prompt': InputType.TEXT,
},
Tasks.image_editing: {
'img': InputType.IMAGE,
'prompts': InputType.LIST
}
}

View File

@@ -11,6 +11,7 @@ if TYPE_CHECKING:
from .linear_aec_pipeline import LinearAECPipeline
from .text_to_speech_pipeline import TextToSpeechSambertHifiganPipeline
from .inverse_text_processing_pipeline import InverseTextProcessingPipeline
from .separation_pipeline import SeparationPipeline
from .speaker_verification_pipeline import SpeakerVerificationPipeline
else:
_import_structure = {
@@ -23,6 +24,7 @@ else:
'text_to_speech_pipeline': ['TextToSpeechSambertHifiganPipeline'],
'itn_inference_pipeline': ['InverseTextProcessingPipeline'],
'inverse_text_processing_pipeline': ['InverseTextProcessingPipeline'],
'separation_pipeline': ['SeparationPipeline'],
'speaker_verification_pipeline': ['SpeakerVerificationPipeline']
}

View File

@@ -0,0 +1,144 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import io
import os
from typing import Union
import numpy as np
import soundfile as sf
import torch
import torchaudio
from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import InputModel, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
__all__ = ['LanguageRecognitionPipeline']
@PIPELINES.register_module(
Tasks.speech_language_recognition,
module_name=Pipelines.speech_language_recognition_eres2net)
class LanguageRecognitionPipeline(Pipeline):
"""Language Recognition Inference Pipeline
use `model` to create a Language Recognition pipeline.
Args:
model (str or Model): A model instance, a local model dir, or a model id on the model hub.
kwargs (dict, `optional`):
Extra kwargs passed into the pipeline's constructor.
Example:
>>> from modelscope.pipelines import pipeline
>>> from modelscope.utils.constant import Tasks
>>> p = pipeline(
>>> task=Tasks.speech_language_recognition, model='damo/speech_eres2net_base_lre_en-cn_16k')
>>> print(p(audio_in))
"""
def __init__(self, model: InputModel, **kwargs):
"""use `model` to create a Language Recognition pipeline for prediction
Args:
model (str): a valid official model id
"""
super().__init__(model=model, **kwargs)
self.model_config = self.model.model_config
self.languages = self.model_config['languages']
def __call__(self,
in_audios: Union[str, list, np.ndarray],
out_file: str = None):
wavs = self.preprocess(in_audios)
results = self.forward(wavs)
outputs = self.postprocess(results, in_audios, out_file)
return outputs
def forward(self, inputs: list):
results = []
for x in inputs:
results.append(self.model(x).item())
return results
def postprocess(self,
inputs: list,
in_audios: Union[str, list, np.ndarray],
out_file=None):
if isinstance(in_audios, str):
output = {OutputKeys.TEXT: self.languages[inputs[0]]}
else:
output = {OutputKeys.TEXT: [self.languages[i] for i in inputs]}
if out_file is not None:
out_lines = []
for i, audio in enumerate(in_audios):
if isinstance(audio, str):
audio_id = os.path.basename(audio).rsplit('.', 1)[0]
else:
audio_id = i
out_lines.append('%s %s\n' %
(audio_id, self.languages[inputs[i]]))
with open(out_file, 'w') as f:
for i in out_lines:
f.write(i)
return output
def preprocess(self, inputs: Union[str, list, np.ndarray]):
output = []
if isinstance(inputs, str):
file_bytes = File.read(inputs)
data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
if len(data.shape) == 2:
data = data[:, 0]
data = torch.from_numpy(data).unsqueeze(0)
if fs != self.model_config['sample_rate']:
logger.warning(
'The sample rate of audio is not %d, resample it.'
% self.model_config['sample_rate'])
data, fs = torchaudio.sox_effects.apply_effects_tensor(
data,
fs,
effects=[['rate',
str(self.model_config['sample_rate'])]])
data = data.squeeze(0)
output.append(data)
else:
for i in range(len(inputs)):
if isinstance(inputs[i], str):
file_bytes = File.read(inputs[i])
data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
if len(data.shape) == 2:
data = data[:, 0]
data = torch.from_numpy(data).unsqueeze(0)
if fs != self.model_config['sample_rate']:
logger.warning(
'The sample rate of audio is not %d, resample it.'
% self.model_config['sample_rate'])
data, fs = torchaudio.sox_effects.apply_effects_tensor(
data,
fs,
effects=[[
'rate',
str(self.model_config['sample_rate'])
]])
data = data.squeeze(0)
elif isinstance(inputs[i], np.ndarray):
assert len(
inputs[i].shape
) == 1, 'modelscope error: Input array should be 1-D [T]'
data = inputs[i]
if data.dtype in ['int16', 'int32', 'int64']:
data = (data / (1 << 15)).astype('float32')
else:
data = data.astype('float32')
data = torch.from_numpy(data)
else:
raise ValueError(
'modelscope error: The input type is restricted to audio file paths and numpy arrays.'
)
output.append(data)
return output
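A minimal usage sketch (not part of the diff) for the language recognition pipeline above, showing the list-input and out_file branches handled in postprocess; the wav file names are hypothetical and the model id is taken from the class docstring:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# hypothetical local wav files; paths or 1-D numpy arrays are accepted
lre = pipeline(
    task=Tasks.speech_language_recognition,
    model='damo/speech_eres2net_base_lre_en-cn_16k')
result = lre(['speaker_a.wav', 'speaker_b.wav'], out_file='lre_result.txt')
print(result)  # {'text': ['<language of speaker_a>', '<language of speaker_b>']}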

View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import ast
import io
from typing import Any, Dict, List, Union
@@ -180,7 +181,14 @@ class SegmentationClusteringPipeline(Pipeline):
model=self.config['vad_model'])
vad_time = self.vad_pipeline(audio, audio_fs=self.fs)
vad_segments = []
for t in vad_time['text']:
if isinstance(vad_time['text'], str):
vad_time_list = ast.literal_eval(vad_time['text'])
elif isinstance(vad_time['text'], list):
vad_time_list = vad_time['text']
else:
raise ValueError('Incorrect vad result. Got %s' %
(type(vad_time['text'])))
for t in vad_time_list:
st = int(t[0]) / 1000
ed = int(t[1]) / 1000
vad_segments.append(

View File

@@ -8,7 +8,7 @@ import soundfile as sf
import torch
from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.metainfo import Models, Pipelines
from modelscope.models.base import Input
from modelscope.outputs import OutputKeys
from modelscope.pipelines import Pipeline
@@ -20,7 +20,11 @@ logger = get_logger()
@PIPELINES.register_module(
Tasks.speech_separation, module_name=Pipelines.speech_separation)
Tasks.speech_separation,
module_name=Models.speech_mossformer_separation_temporal_8k)
@PIPELINES.register_module(
Tasks.speech_separation,
module_name=Models.speech_mossformer2_separation_temporal_8k)
class SeparationPipeline(Pipeline):
def __init__(self, model, **kwargs):

View File

@@ -0,0 +1,243 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict, List, Sequence, Tuple, Union
import json
import yaml
from funasr.utils import asr_utils
from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
update_local_model)
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
__all__ = ['SeparationPipeline']
@PIPELINES.register_module(
Tasks.speech_separation, module_name=Pipelines.funasr_speech_separation)
class SeparationPipeline(Pipeline):
"""Speech Separation Inference Pipeline
use `model` to create a speech separation pipeline for prediction.
Args:
model: A model instance, or a model local dir, or a model id in the model hub.
kwargs (dict, `optional`):
Extra kwargs passed into the preprocessor's constructor.
Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline = pipeline(
>>> task=Tasks.speech_separation, model='damo/speech_separation_mossformer_8k_pytorch')
>>> audio_in = 'mix_speech.wav'
>>> print(pipeline(audio_in))
"""
def __init__(self,
model: Union[Model, str] = None,
ngpu: int = 1,
**kwargs):
"""use `model` to create an speech separation pipeline for prediction
"""
super().__init__(model=model, **kwargs)
config_path = os.path.join(model, ModelFile.CONFIGURATION)
self.cmd = self.get_cmd(config_path, kwargs, model)
from funasr.bin import ss_inference_launch
self.funasr_infer_modelscope = ss_inference_launch.inference_launch(
mode=self.cmd['mode'],
batch_size=self.cmd['batch_size'],
ngpu=ngpu,
log_level=self.cmd['log_level'],
ss_infer_config=self.cmd['ss_infer_config'],
ss_model_file=self.cmd['ss_model_file'],
output_dir=self.cmd['output_dir'],
dtype=self.cmd['dtype'],
seed=self.cmd['seed'],
num_workers=self.cmd['num_workers'],
num_spks=self.cmd['num_spks'],
param_dict=self.cmd['param_dict'],
**kwargs,
)
def __call__(self,
audio_in: Union[str, bytes],
audio_fs: int = None,
recog_type: str = None,
audio_format: str = None,
output_dir: str = None,
param_dict: dict = None,
**kwargs) -> Dict[str, Any]:
"""
Decoding the input audios
Args:
audio_in('str' or 'bytes'):
- A string containing a local path to a wav file
- A string containing a local path to a scp
- A string containing a wav url
- A bytes input
audio_fs('int'):
sampling frequency of the input audio
recog_type('str'):
recog type for wav file or datasets file ('wav', 'test', 'dev', 'train')
audio_format('str'):
audio format ('pcm', 'scp', 'kaldi_ark', 'tfrecord')
output_dir('str'):
output dir
param_dict('dict'):
extra kwargs
Return:
A dictionary of result or a list of dictionary of result.
The dictionary contains the following keys:
- **text** ('str') -- The separation result.
"""
self.audio_in = None
self.raw_inputs = None
self.recog_type = recog_type
self.audio_format = audio_format
self.audio_fs = None
checking_audio_fs = None
if output_dir is not None:
self.cmd['output_dir'] = output_dir
if param_dict is not None:
self.cmd['param_dict'] = param_dict
if isinstance(audio_in, str):
# for funasr code, generate wav.scp from url or local path
self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in)
elif isinstance(audio_in, bytes):
self.audio_in = audio_in
self.raw_inputs = None
else:
import numpy
import torch
if isinstance(audio_in, torch.Tensor):
self.audio_in = None
self.raw_inputs = audio_in
elif isinstance(audio_in, numpy.ndarray):
self.audio_in = None
self.raw_inputs = audio_in
# set the sample_rate of audio_in if checking_audio_fs is valid
if checking_audio_fs is not None:
self.audio_fs = checking_audio_fs
if recog_type is None or audio_format is None:
self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
audio_in=self.audio_in,
recog_type=recog_type,
audio_format=audio_format)
if hasattr(asr_utils,
'sample_rate_checking') and self.audio_in is not None:
checking_audio_fs = asr_utils.sample_rate_checking(
self.audio_in, self.audio_format)
if checking_audio_fs is not None:
self.audio_fs = checking_audio_fs
if audio_fs is not None:
self.cmd['fs']['audio_fs'] = audio_fs
else:
self.cmd['fs']['audio_fs'] = self.audio_fs
output = self.forward(self.audio_in, **kwargs)
return output
def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
model_cfg = json.loads(open(config_path).read())
model_dir = os.path.dirname(config_path)
# generate inference command
ss_model_path = os.path.join(
model_dir, model_cfg['model']['model_config']['ss_model_name'])
ss_model_config = os.path.join(
model_dir, model_cfg['model']['model_config']['ss_model_config'])
mode = model_cfg['model']['model_config']['mode']
frontend_conf = None
if os.path.exists(ss_model_config):
config_file = open(ss_model_config, encoding='utf-8')
root = yaml.full_load(config_file)
config_file.close()
if 'frontend_conf' in root:
frontend_conf = root['frontend_conf']
update_local_model(model_cfg['model']['model_config'], model_path,
extra_args)
cmd = {
'mode': mode,
'batch_size': 1,
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
'log_level': 'ERROR',
'ss_infer_config': ss_model_config,
'ss_model_file': ss_model_path,
'output_dir': None,
'dtype': 'float32',
'seed': 0,
'num_workers': 0,
'num_spks': 2,
'param_dict': None,
'fs': {
'model_fs': None,
'audio_fs': None
}
}
if frontend_conf is not None and 'fs' in frontend_conf:
cmd['fs']['model_fs'] = frontend_conf['fs']
user_args_dict = [
'output_dir', 'batch_size', 'mode', 'ngpu', 'param_dict',
'num_workers', 'fs'
]
for user_args in user_args_dict:
if user_args in extra_args:
if extra_args.get(user_args) is not None:
cmd[user_args] = extra_args[user_args]
del extra_args[user_args]
return cmd
def postprocess(self, inputs: Dict[str, Any],
**post_params) -> Dict[str, Any]:
return inputs
def forward(self, audio_in: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Decoding
"""
logger.info('Speech Separation Processing ...')
# generate inputs
data_cmd: Sequence[Tuple[str, str, str]]
if isinstance(self.audio_in, bytes):
data_cmd = [self.audio_in, 'speech', 'bytes']
elif isinstance(self.audio_in, str):
data_cmd = [self.audio_in, 'speech', 'sound']
elif self.raw_inputs is not None:
data_cmd = None
self.cmd['name_and_type'] = data_cmd
self.cmd['raw_inputs'] = self.raw_inputs
self.cmd['audio_in'] = self.audio_in
ss_result = self.run_inference(self.cmd, **kwargs)
return ss_result
def run_inference(self, cmd, **kwargs):
ss_result = []
if self.framework == Frameworks.torch:
ss_result = self.funasr_infer_modelscope(
data_path_and_name_and_type=cmd['name_and_type'],
raw_inputs=cmd['raw_inputs'],
output_dir_v2=cmd['output_dir'],
fs=cmd['fs'],
param_dict=cmd['param_dict'],
**kwargs)
else:
raise ValueError('model type is mismatching')
return ss_result
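A hedged usage sketch (not part of the diff) for the FunASR-based separation pipeline above, exercising the output_dir argument handled in __call__; the model id follows the class docstring and the paths are hypothetical:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# hypothetical mixture wav and output directory
separation = pipeline(
    task=Tasks.speech_separation,
    model='damo/speech_separation_mossformer_8k_pytorch')
result = separation('mix_speech.wav', output_dir='./separated')
print(result)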

View File

@@ -30,6 +30,7 @@ if TYPE_CHECKING:
from .image_colorization_pipeline import ImageColorizationPipeline
from .image_denoise_pipeline import ImageDenoisePipeline
from .image_deblur_pipeline import ImageDeblurPipeline
from .image_editing_pipeline import ImageEditingPipeline
from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline
from .image_matting_pipeline import ImageMattingPipeline
from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline
@@ -104,6 +105,7 @@ if TYPE_CHECKING:
from .image_human_parsing_pipeline import ImageHumanParsingPipeline
from .nerf_recon_acc_pipeline import NeRFReconAccPipeline
from .nerf_recon_4k_pipeline import NeRFRecon4KPipeline
from .surface_recon_common_pipeline import SurfaceReconCommonPipeline
from .controllable_image_generation_pipeline import ControllableImageGenerationPipeline
from .image_bts_depth_estimation_pipeline import ImageBTSDepthEstimationPipeline
from .pedestrian_attribute_recognition_pipeline import PedestrainAttributeRecognitionPipeline
@@ -136,6 +138,7 @@ else:
'image_cartoon_pipeline': ['ImageCartoonPipeline'],
'image_denoise_pipeline': ['ImageDenoisePipeline'],
'image_deblur_pipeline': ['ImageDeblurPipeline'],
'image_editing_pipeline': ['ImageEditingPipeline'],
'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'],
'image_colorization_pipeline': ['ImageColorizationPipeline'],
'image_instance_segmentation_pipeline':
@@ -256,6 +259,7 @@ else:
'image_human_parsing_pipeline': ['ImageHumanParsingPipeline'],
'nerf_recon_acc_pipeline': ['NeRFReconAccPipeline'],
'nerf_recon_4k_pipeline': ['NeRFRecon4KPipeline'],
'surface_recon_common_pipeline': ['SurfaceReconCommonPipeline'],
'controllable_image_generation_pipeline': [
'ControllableImageGenerationPipeline'
],

View File

@@ -163,7 +163,7 @@ class Body3DKeypointsPipeline(Pipeline):
box = kps_2d['boxes'][
0]  # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] selects the first detected bbox
pose = kps_2d['keypoints'][0] # keypoints: [15, 2]
score = kps_2d['scores'][0] # keypoints: [15, 2]
score = np.array(kps_2d['scores'][0]).max()
all_2d_poses.append(pose)
all_boxes_with_socre.append(
list(np.array(box).reshape(

View File

@@ -31,7 +31,7 @@ class FaceEmotionPipeline(Pipeline):
logger.info('load model done')
def preprocess(self, input: Input) -> Dict[str, Any]:
img = LoadImage.convert_to_ndarray(input['img_path'])
img = LoadImage.convert_to_ndarray(input)
return img
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

View File

@@ -32,14 +32,13 @@ class NanoDettForFaceHumanHandDetectionPipeline(Pipeline):
logger.info('load model done')
def preprocess(self, input: Input) -> Dict[str, Any]:
img = LoadImage.convert_to_ndarray(input['input_path'])
img = LoadImage.convert_to_ndarray(input)
return img
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
cls_list, bbox_list, score_list = det_infer.inference(
self.model, self.device, input)
logger.info('%s %s %s', cls_list, bbox_list, score_list)
return {
OutputKeys.LABELS: cls_list,
OutputKeys.BOXES: bbox_list,

View File

@@ -30,7 +30,7 @@ class HandStaticPipeline(Pipeline):
logger.info('load model done')
def preprocess(self, input: Input) -> Dict[str, Any]:
img = LoadImage.convert_to_ndarray(input['img_path'])
img = LoadImage.convert_to_ndarray(input)
return img
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

View File

@@ -0,0 +1,60 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import Any, Dict
import numpy as np
import torch
from modelscope.metainfo import Pipelines
from modelscope.models.cv.human_image_generation import \
human_image_generation_infer
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.human_image_generation, module_name=Pipelines.human_image_generation)
class FreqHPTForHumanImageGenerationPipeline(Pipeline):
""" Human Image Generation Pipeline.
Examples:
>>> human_image_generation = pipeline(Tasks.human_image_generation, model='damo/cv_FreqHPT_human-image-generation')
>>> input_images = {'source_img_path': '/your_path/source_img.jpg',
>>> 'target_pose_path': '/your_path/target_pose.txt'}
>>> result = human_image_generation(input_images)
>>> result[OutputKeys.OUTPUT_IMG]
"""
def __init__(self, model: str, **kwargs):
"""
use `model` to create a human image generation pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
self.model_path = model
logger.info('load model done')
if torch.cuda.is_available():
self.device = 'cuda'
logger.info('Use GPU')
else:
self.device = 'cpu'
logger.info('Use CPU')
def preprocess(self, input: Input) -> Dict[str, Any]:
return input
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
human_image_generation = human_image_generation_infer.infer(
self.model, input['source_img_path'], input['target_pose_path'],
self.device)
return {OutputKeys.OUTPUT_IMG: human_image_generation}
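To round out the docstring example above, a short sketch (not part of the diff) of persisting the generated image; the paths are hypothetical and the result is assumed to be an ndarray as declared for Tasks.human_image_generation in TASK_OUTPUTS:

import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

human_image_generation = pipeline(
    Tasks.human_image_generation, model='damo/cv_FreqHPT_human-image-generation')
# hypothetical local inputs, mirroring the docstring example
result = human_image_generation({'source_img_path': '/your_path/source_img.jpg',
                                 'target_pose_path': '/your_path/target_pose.txt'})
# the result image is an np.ndarray with shape [height, width, 3]
cv2.imwrite('./human_image_generation_result.png', result[OutputKeys.OUTPUT_IMG])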

View File

@@ -70,7 +70,7 @@ class ImageCartoonPipeline(Pipeline):
def preprocess(self, input: Input) -> Dict[str, Any]:
img = LoadImage.convert_to_ndarray(input)
img = img.astype(np.float)
img = img.astype(float)
result = {'img': img}
return result

View File

@@ -0,0 +1,365 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, StableDiffusionPipeline
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from modelscope.metainfo import Pipelines
from modelscope.models.cv.image_editing import (
MutualSelfAttentionControl, regiter_attention_editor_diffusers)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.multi_modal.diffusers_wrapped.diffusers_pipeline import \
DiffusersPipeline
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
__all__ = ['ImageEditingPipeline']
@PIPELINES.register_module(
Tasks.image_editing, module_name=Pipelines.image_editing)
class ImageEditingPipeline(DiffusersPipeline):
def __init__(self, model: str, preprocessor=None, **kwargs):
""" MasaCtrl Image Editing Pipeline.
Examples:
>>> import cv2
>>> from modelscope.pipelines import pipeline
>>> from modelscope.utils.constant import Tasks
>>> prompts = [
>>> "", # source prompt
>>> "a photo of a running corgi" # target prompt
>>> ]
>>> output_image_path = './result.png'
>>> img = 'https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/public/ModelScope/test/images/corgi.jpg'
>>> input = {'img': img, 'prompts': prompts}
>>>
>>> pipe = pipeline(
>>> Tasks.image_editing,
>>> model='damo/cv_masactrl_image-editing')
>>>
>>> output = pipe(input)['output_img']
>>> cv2.imwrite(output_image_path, output)
>>> print('pipeline: the output image path is {}'.format(output_image_path))
"""
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
torch_dtype = kwargs.get('torch_dtype', torch.float32)
self._device = kwargs.get(
'device',
torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
logger.info('load image editing pipeline done')
scheduler = DDIMScheduler.from_pretrained(
os.path.join(model, 'stable-diffusion-v1-4'),
subfolder='scheduler')
self.pipeline = _MasaCtrlPipeline.from_pretrained(
os.path.join(model, 'stable-diffusion-v1-4'),
scheduler=scheduler,
torch_dtype=torch_dtype,
use_safetensors=True).to(self._device)
def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
img = LoadImage.convert_to_img(input.get('img'))
test_transforms = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])]) # [-1, 1]
img = test_transforms(img).unsqueeze(0)
img = F.interpolate(img, (512, 512))
input['img'] = img.to(self._device)
return input
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
if not isinstance(input, dict):
raise ValueError(
f'Expected the input to be a dictionary, but got {type(input)}'
)
prompts = input.get('prompts')
start_code, latents_list = self.pipeline.invert(
input.get('img'),
prompts[0],
guidance_scale=7.5,
num_inference_steps=50,
return_intermediates=True)
start_code = start_code.expand(len(prompts), -1, -1, -1)
STEP, LAYER = 4, 10
editor = MutualSelfAttentionControl(STEP, LAYER)
regiter_attention_editor_diffusers(self.pipeline, editor)
# inference the synthesized image
output = self.pipeline(
prompts,
latents=start_code,
guidance_scale=input.get('guidance_scale', 7.5),
)[-1:]
return {'output_tensor': output}
def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute(
1, 2, 0).numpy().astype('uint8')
return {OutputKeys.OUTPUT_IMG: output_img[:, :, ::-1]}
class _MasaCtrlPipeline(StableDiffusionPipeline):
def next_step(
self,
model_output: torch.FloatTensor,
timestep: int,
x: torch.FloatTensor,
eta=0,
verbose=False,
):
"""
Inverse sampling for DDIM Inversion
x_t -> x_(t+1)
"""
if verbose:
print('timestep: ', timestep)
next_step = timestep
timestep = min(
timestep - self.scheduler.config.num_train_timesteps
// self.scheduler.num_inference_steps, 999)
alpha_prod_t = self.scheduler.alphas_cumprod[
timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
beta_prod_t = 1 - alpha_prod_t
pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
return x_next, pred_x0
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
x: torch.FloatTensor,
eta: float = 0.0,
verbose=False,
):
"""
predict the sample at the next step of the denoising process.
x_t -> x_(t-1)
"""
prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
alpha_prod_t_prev = self.scheduler.alphas_cumprod[
prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
return x_prev, pred_x0
@torch.no_grad()
def image2latent(self, image):
DEVICE = self._execution_device
if type(image) is Image:
image = np.array(image)
image = torch.from_numpy(image).float() / 127.5 - 1
image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
# input image density range [-1, 1]
latents = self.vae.encode(image)['latent_dist'].mean
latents = latents * 0.18215
return latents
@torch.no_grad()
def latent2image(self, latents, return_type='pt'):
latents = 1 / 0.18215 * latents.detach()
image = self.vae.decode(latents)['sample']
if return_type == 'np':
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
image = (image * 255).astype(np.uint8)
elif return_type == 'pt':
image = (image / 2 + 0.5).clamp(0, 1)
return image
@torch.no_grad()
def __call__(self,
prompt,
batch_size=1,
height=512,
width=512,
num_inference_steps=50,
guidance_scale=7.5,
eta=0.0,
latents=None,
unconditioning=None,
neg_prompt=None,
ref_intermediate_latents=None,
return_intermediates=False,
**kwds):
DEVICE = self._execution_device
if isinstance(prompt, list):
batch_size = len(prompt)
elif isinstance(prompt, str):
if batch_size > 1:
prompt = [prompt] * batch_size
# text embeddings
text_input = self.tokenizer(
prompt, padding='max_length', max_length=77, return_tensors='pt')
text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
print('input text embeddings :', text_embeddings.shape)
# define initial latents
latents_shape = (batch_size, self.unet.in_channels, height // 8,
width // 8)
if latents is None:
latents = torch.randn(latents_shape, device=DEVICE)
else:
assert latents.shape == latents_shape, f'The shape of input latent tensor {latents.shape} should equal ' \
f'to predefined one.'
# unconditional embedding for classifier free guidance
if guidance_scale > 1.:
if neg_prompt:
uc_text = neg_prompt
else:
uc_text = ''
unconditional_input = self.tokenizer(
[uc_text] * batch_size,
padding='max_length',
max_length=77,
return_tensors='pt')
unconditional_embeddings = self.text_encoder(
unconditional_input.input_ids.to(DEVICE))[0]
text_embeddings = torch.cat(
[unconditional_embeddings, text_embeddings], dim=0)
print('latents shape: ', latents.shape)
# iterative sampling
self.scheduler.set_timesteps(num_inference_steps)
latents_list = [latents]
pred_x0_list = [latents]
for i, t in enumerate(
tqdm(self.scheduler.timesteps, desc='DDIM Sampler')):
if ref_intermediate_latents is not None:
# note that the batch_size >= 2
latents_ref = ref_intermediate_latents[-1 - i]
_, latents_cur = latents.chunk(2)
latents = torch.cat([latents_ref, latents_cur])
if guidance_scale > 1.:
model_inputs = torch.cat([latents] * 2)
else:
model_inputs = latents
if unconditioning is not None and isinstance(unconditioning, list):
_, text_embeddings = text_embeddings.chunk(2)
text_embeddings = torch.cat([
unconditioning[i].expand(*text_embeddings.shape),
text_embeddings
])
# predict the noise
noise_pred = self.unet(
model_inputs, t, encoder_hidden_states=text_embeddings).sample
if guidance_scale > 1.:
noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncon + guidance_scale * (
noise_pred_con - noise_pred_uncon)
# compute the previous noise sample x_t -> x_t-1
latents, pred_x0 = self.step(noise_pred, t, latents)
latents_list.append(latents)
pred_x0_list.append(pred_x0)
image = self.latent2image(latents, return_type='pt')
if return_intermediates:
pred_x0_list = [
self.latent2image(img, return_type='pt')
for img in pred_x0_list
]
latents_list = [
self.latent2image(img, return_type='pt')
for img in latents_list
]
return image, pred_x0_list, latents_list
return image
@torch.no_grad()
def invert(self,
image: torch.Tensor,
prompt,
num_inference_steps=50,
guidance_scale=7.5,
eta=0.0,
return_intermediates=False,
**kwds):
"""
invert a real image into a noise map with deterministic DDIM inversion
"""
DEVICE = self._execution_device
batch_size = image.shape[0]
if isinstance(prompt, list):
if batch_size == 1:
image = image.expand(len(prompt), -1, -1, -1)
elif isinstance(prompt, str):
if batch_size > 1:
prompt = [prompt] * batch_size
# text embeddings
text_input = self.tokenizer(
prompt, padding='max_length', max_length=77, return_tensors='pt')
text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
print('input text embeddings :', text_embeddings.shape)
# define initial latents
latents = self.image2latent(image)
start_latents = latents
# unconditional embedding for classifier free guidance
if guidance_scale > 1.:
unconditional_input = self.tokenizer(
[''] * batch_size,
padding='max_length',
max_length=77,
return_tensors='pt')
unconditional_embeddings = self.text_encoder(
unconditional_input.input_ids.to(DEVICE))[0]
text_embeddings = torch.cat(
[unconditional_embeddings, text_embeddings], dim=0)
print('latents shape: ', latents.shape)
self.scheduler.set_timesteps(num_inference_steps)
print('Valid timesteps: ', reversed(self.scheduler.timesteps))
latents_list = [latents]
pred_x0_list = [latents]
for i, t in enumerate(
tqdm(
reversed(self.scheduler.timesteps),
desc='DDIM Inversion')):
if guidance_scale > 1.:
model_inputs = torch.cat([latents] * 2)
else:
model_inputs = latents
# predict the noise
noise_pred = self.unet(
model_inputs, t, encoder_hidden_states=text_embeddings).sample
if guidance_scale > 1.:
noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncon + guidance_scale * (
noise_pred_con - noise_pred_uncon)
# compute the next noise sample x_t -> x_(t+1)
latents, pred_x0 = self.next_step(noise_pred, t, latents)
latents_list.append(latents)
pred_x0_list.append(pred_x0)
if return_intermediates:
return latents, latents_list
return latents, start_latents
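A hedged sketch (not part of the diff) extending the ImageEditingPipeline docstring example: forward reads an optional guidance_scale from the input dict and falls back to 7.5, so callers can tune it without touching the pipeline; the inputs mirror the docstring and the output path is illustrative:

import cv2
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

prompts = ['', 'a photo of a running corgi']  # source prompt, target prompt
img = 'https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/public/ModelScope/test/images/corgi.jpg'
pipe = pipeline(Tasks.image_editing, model='damo/cv_masactrl_image-editing')
# guidance_scale is optional; forward() uses 7.5 when the key is absent
output = pipe({'img': img, 'prompts': prompts, 'guidance_scale': 7.5})['output_img']
cv2.imwrite('./result.png', output)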

View File

@@ -82,7 +82,7 @@ class ImagePanopticSegmentationPipeline(Pipeline):
ids = ids[legal_indices]
labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64)
segms = (pan_results[None] == ids[:, None, None])
masks = [it.astype(np.int) for it in segms]
masks = [it.astype(np.int32) for it in segms]
labels_txt = np.array(self.model.CLASSES)[labels].tolist()
outputs = {
OutputKeys.MASKS: masks,

View File

@@ -31,7 +31,8 @@ class F3NetForProductSegmentationPipeline(Pipeline):
logger.info('load model done')
def preprocess(self, input: Input) -> Dict[str, Any]:
img = LoadImage.convert_to_ndarray(input['input_path'])
img = LoadImage.convert_to_ndarray(input)
img = img.astype(np.float32)
return img

View File

@@ -0,0 +1,71 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.util import is_model, is_official_hub_path
from modelscope.utils.constant import Invoke, Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.surface_recon_common, module_name=Pipelines.surface_recon_common)
class SurfaceReconCommonPipeline(Pipeline):
""" Surface reconstruction common pipeline
Example:
```python
>>> from modelscope.pipelines import pipeline
>>> surface_recon_common = pipeline(Tasks.surface_recon_common,
'damo/cv_surface-reconstruction-common')
>>> surface_recon_common({
'data_dir': '/data/lego', # data dir path (str)
'save_dir': './output', # save dir path (str)
})
>>> #
```
"""
def __init__(self, model, device='gpu', **kwargs):
"""
use `model` to create a surface reconstruction pipeline for prediction
Args:
model (str or Model): model_id on modelscope hub
device (str): only gpu is supported
"""
model = Model.from_pretrained(
model,
device=device,
model_prefetched=True,
invoked_by=Invoke.PIPELINE) if is_model(model) else model
super().__init__(model=model, **kwargs)
if not isinstance(self.model, Model):
logger.error('model object is not initialized.')
raise Exception('model object is not initialized.')
logger.info('load model done')
def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
data_dir = input['data_dir']
save_dir = input['save_dir']
if 'color' in input:
color = input['color']
else:
color = False
if 'n_directions' in input:
n_directions = input['n_directions']
else:
n_directions = 8
self.model.surface_reconstruction(data_dir, save_dir, color,
n_directions)
return {OutputKeys.OUTPUT: 'Done'}
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
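The forward method above also honours optional 'color' and 'n_directions' keys (defaulting to False and 8). A sketch (not part of the diff) with illustrative paths, following the class docstring:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

surface_recon_common = pipeline(
    Tasks.surface_recon_common, model='damo/cv_surface-reconstruction-common')
result = surface_recon_common({
    'data_dir': '/data/lego',   # illustrative paths, as in the docstring
    'save_dir': './output',
    'color': True,              # optional, defaults to False
    'n_directions': 8,          # optional, defaults to 8
})
print(result)  # {'output': 'Done'}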

View File

@@ -4,24 +4,27 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline
from .asr_pipeline import AutomaticSpeechRecognitionPipeline
from .diffusers_wrapped import (ChineseStableDiffusionPipeline,
StableDiffusionPipeline)
from .document_vl_embedding_pipeline import DocumentVLEmbeddingPipeline
from .generative_multi_modal_embedding_pipeline import \
GEMMMultiModalEmbeddingPipeline
from .image_captioning_pipeline import ImageCaptioningPipeline
from .visual_entailment_pipeline import VisualEntailmentPipeline
from .visual_grounding_pipeline import VisualGroundingPipeline
from .mgeo_ranking_pipeline import MGeoRankingPipeline
from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline
from .multimodal_dialogue_pipeline import MultimodalDialoguePipeline
from .prost_text_video_retrieval_pipeline import \
ProSTTextVideoRetrievalPipeline
from .soonet_video_temporal_grounding_pipeline import \
SOONetVideoTemporalGroundingPipeline
from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline
from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline
from .video_captioning_pipeline import VideoCaptioningPipeline
from .video_multi_modal_embedding_pipeline import \
VideoMultiModalEmbeddingPipeline
from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline
from .asr_pipeline import AutomaticSpeechRecognitionPipeline
from .mgeo_ranking_pipeline import MGeoRankingPipeline
from .document_vl_embedding_pipeline import DocumentVLEmbeddingPipeline
from .video_captioning_pipeline import VideoCaptioningPipeline
from .video_question_answering_pipeline import VideoQuestionAnsweringPipeline
from .diffusers_wrapped import StableDiffusionPipeline, ChineseStableDiffusionPipeline
from .soonet_video_temporal_grounding_pipeline import SOONetVideoTemporalGroundingPipeline
from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline
from .multimodal_dialogue_pipeline import MultimodalDialoguePipeline
from .videocomposer_pipeline import VideoComposerPipeline
else:
_import_structure = {
@@ -29,6 +32,8 @@ else:
'visual_entailment_pipeline': ['VisualEntailmentPipeline'],
'visual_grounding_pipeline': ['VisualGroundingPipeline'],
'multi_modal_embedding_pipeline': ['MultiModalEmbeddingPipeline'],
'prost_text_video_retrieval_pipeline':
['ProSTTextVideoRetrievalPipeline'],
'text_to_image_synthesis_pipeline': ['TextToImageSynthesisPipeline'],
'visual_question_answering_pipeline':
['VisualQuestionAnsweringPipeline'],

View File

@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .cones2_inference_pipeline import Cones2InferencePipeline
else:
_import_structure = {
'cones2_inference_pipeline': ['Cones2InferencePipeline'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,494 @@
# Copyright 2023 The HuggingFace Team.
# Copyright 2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
# The implementation here is modified based on diffusers,
# originally Apache License, Copyright 2023 The HuggingFace Team
import math
from typing import Any, Dict
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline
from diffusers.models.cross_attention import CrossAttention
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import \
StableDiffusionPipelineOutput
from PIL import Image
from tqdm.auto import tqdm
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.multi_modal.diffusers_wrapped.diffusers_pipeline import \
DiffusersPipeline
from modelscope.utils.constant import Tasks
@PIPELINES.register_module(
Tasks.text_to_image_synthesis, module_name=Pipelines.cones2_inference)
class Cones2InferencePipeline(DiffusersPipeline):
r""" Cones2 Inference Pipeline.
Examples:
>>> from modelscope.pipelines import pipeline
>>> pipe = pipeline(task=Tasks.text_to_image_synthesis, model='damo/Cones2', model_revision='v1.0.1')
>>> {
>>> "text": 'a mug and a dog on the beach',
>>> "subject_list": [["mug", 2], ["dog", 5]],
>>> "color_context": {"255,192,0": ["mug", 2.5], "255,0,0": ["dog", 2.5]},
>>> "layout": 'data/test/images/mask_example.png'
>>> }
>>>
"""
def __init__(self, model: str, device: str = 'gpu', **kwargs):
"""
use `model` to create a stable diffusion pipeline
Args:
model: model id on modelscope hub.
device: str = 'gpu'
"""
super().__init__(model, device, **kwargs)
self.pipeline = StableDiffusionPipeline.from_pretrained(model)
self.pipeline.text_encoder.pooler = None
self.pipeline.to(self.device)
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
if not isinstance(inputs, dict):
raise ValueError(
f'Expected the input to be a dictionary, but got {type(input)}'
)
if 'text' not in inputs:
raise ValueError('input should contain "text", but not found')
return self.layout_guidance_sampling(
prompt=inputs.get('text'),
residual_dict=inputs.get('residual_dict', None),
subject_list=inputs.get('subject_list'),
color_context=inputs.get('color_context', None),
layout=inputs.get('layout', None),
)
@torch.no_grad()
def layout_guidance_sampling(
self,
prompt='',
residual_dict=None,
subject_list=None,
color_context=None,
layout=None,
cfg_scale=7.5,
inference_steps=50,
guidance_steps=50,
guidance_weight=0.05,
weight_negative=-1e8,
):
layout = Image.open(layout).resize((768, 768)).convert('RGB')
subject_color_dict = {
tuple(map(int, key.split(','))): value
for key, value in color_context.items()
}
vae = self.pipeline.vae
unet = self.pipeline.unet
text_encoder = self.pipeline.text_encoder
tokenizer = self.pipeline.tokenizer
unconditional_input_prompt = ''
scheduler = LMSDiscreteScheduler.from_config(
self.pipeline.scheduler.config)
scheduler.set_timesteps(inference_steps, device=self.device)
if guidance_steps > 0:
guidance_steps = min(guidance_steps, inference_steps)
scheduler_guidance = LMSDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule='scaled_linear',
num_train_timesteps=1000,
)
scheduler_guidance.set_timesteps(
guidance_steps, device=self.device)
# Process input prompt text
text_input = tokenizer(
[prompt],
padding='max_length',
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors='pt',
)
# Edit text embedding conditions with residual token embeddings.
cond_embeddings = text_encoder(text_input.input_ids.to(self.device))[0]
if residual_dict is not None:
for name, token in subject_list:
residual_token_embedding = torch.load(residual_dict[name])
cond_embeddings[0][token] += residual_token_embedding.reshape(
1024)
# Process unconditional input "" for classifier-free guidance.
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer([unconditional_input_prompt],
padding='max_length',
max_length=max_length,
return_tensors='pt')
uncond_embeddings = text_encoder(
uncond_input.input_ids.to(self.device))[0]
register_attention_control(unet)
# Calculate the hidden features for each cross attention layer.
hidden_states, uncond_hidden_states = _extract_cross_attention(
tokenizer, self.device, layout, subject_color_dict, text_input,
weight_negative)
hidden_states['CONDITION_TENSOR'] = cond_embeddings
uncond_hidden_states['CONDITION_TENSOR'] = uncond_embeddings
hidden_states['function'] = lambda w, sigma, qk: (
guidance_weight * w * math.log(1 + sigma**2)) * qk.std()
uncond_hidden_states['function'] = lambda w, sigma, qk: 0.0
# Sampling the initial latents.
latent_size = (1, unet.in_channels, 96, 96)
latents = torch.randn(latent_size).to(self.device)
latents = latents * scheduler.init_noise_sigma
for i, t in tqdm(
enumerate(scheduler.timesteps),
total=len(scheduler.timesteps)):
# Improve the harmony of generated images by self-recurrence.
if i < guidance_steps:
loop = 2
else:
loop = 1
for k in range(loop):
if i < guidance_steps:
sigma = scheduler_guidance.sigmas[i]
latent_model_input = scheduler.scale_model_input(
latents, t)
_t = t
hidden_states.update({'SIGMA': sigma})
noise_pred_text = unet(
latent_model_input,
_t,
encoder_hidden_states=hidden_states,
).sample
uncond_hidden_states.update({'SIGMA': sigma})
noise_pred_uncond = unet(
latent_model_input,
_t,
encoder_hidden_states=uncond_hidden_states,
).sample
noise_pred = noise_pred_uncond + cfg_scale * (
noise_pred_text - noise_pred_uncond)
latents = scheduler.step(noise_pred, t, latents,
1).prev_sample
# Self-recurrence.
if k < 1 and loop > 1:
noise_recurent = torch.randn(latents.shape).to(
self.device)
sigma_difference = scheduler.sigmas[
i]**2 - scheduler.sigmas[i + 1]**2
latents = latents + noise_recurent * (
sigma_difference**0.5)
else:
latent_model_input = scheduler.scale_model_input(
latents, t)
_t = t
noise_pred_text = unet(
latent_model_input,
_t,
encoder_hidden_states=cond_embeddings,
).sample
latent_model_input = scheduler.scale_model_input(
latents, t)
noise_pred_uncond = unet(
latent_model_input,
_t,
encoder_hidden_states=uncond_embeddings,
).sample
noise_pred = noise_pred_uncond + cfg_scale * (
noise_pred_text - noise_pred_uncond)
latents = scheduler.step(noise_pred, t, latents,
1).prev_sample
edited_images = _latents_to_images(vae, latents)
return StableDiffusionPipelineOutput(
images=edited_images, nsfw_content_detected=None)
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
images = []
for img in inputs.images:
if isinstance(img, Image.Image):
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
images.append(img)
return {OutputKeys.OUTPUT_IMGS: images}
class Cones2AttnProcessor:
def __init__(self):
super().__init__()
def __call__(self,
attn: CrossAttention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
query = attn.to_q(hidden_states)
is_dict_format = True
if encoder_hidden_states is not None:
if 'CONDITION_TENSOR' in encoder_hidden_states:
encoder_hidden = encoder_hidden_states['CONDITION_TENSOR']
else:
encoder_hidden = encoder_hidden_states
is_dict_format = False
else:
encoder_hidden = hidden_states
key = attn.to_k(encoder_hidden)
value = attn.to_v(encoder_hidden)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_scores = torch.matmul(query, key.transpose(-1, -2))
attention_size_of_img = attention_scores.size()[-2]
if attention_scores.size()[2] == 77:
if is_dict_format:
f = encoder_hidden_states['function']
try:
w = encoder_hidden_states[
f'CA_WEIGHT_{attention_size_of_img}']
except KeyError:
w = encoder_hidden_states['CA_WEIGHT_ORIG']
if not isinstance(w, int):
img_h, img_w, nc = w.shape
ratio = math.sqrt(img_h * img_w
/ attention_size_of_img)
w = F.interpolate(
w.permute(2, 0, 1).unsqueeze(0),
scale_factor=1 / ratio,
mode='bilinear',
align_corners=True)
w = F.interpolate(
w.reshape(1, nc, -1),
size=(attention_size_of_img, ),
mode='nearest').permute(2, 1, 0).squeeze()
else:
w = 0
if type(w) is int and w == 0:
sigma = encoder_hidden_states['SIGMA']
cross_attention_weight = f(w, sigma, attention_scores)
else:
bias = torch.zeros_like(w)
bias[torch.where(w > 0)] = attention_scores.std() * 0
sigma = encoder_hidden_states['SIGMA']
cross_attention_weight = f(w, sigma, attention_scores)
cross_attention_weight = cross_attention_weight + bias
else:
cross_attention_weight = 0.0
else:
cross_attention_weight = 0.0
attention_scores = (attention_scores
+ cross_attention_weight) * attn.scale
attention_probs = attention_scores.softmax(dim=-1)
hidden_states = torch.matmul(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
def register_attention_control(unet):
attn_procs = {}
for name in unet.attn_processors.keys():
attn_procs[name] = Cones2AttnProcessor()
unet.set_attn_processor(attn_procs)
def _tokens_img_attention_weight(img_context_seperated,
tokenized_texts,
ratio: int = 8,
original_shape=False):
token_lis = tokenized_texts['input_ids'][0].tolist()
w, h = img_context_seperated[0][1].shape
w_r, h_r = round(w / ratio), round(h / ratio)
ret_tensor = torch.zeros((w_r * h_r, len(token_lis)), dtype=torch.float32)
for v_as_tokens, img_where_color in img_context_seperated:
is_in = 0
for idx, tok in enumerate(token_lis):
if token_lis[idx:idx + len(v_as_tokens)] == v_as_tokens:
is_in = 1
ret_tensor[:, idx:idx + len(v_as_tokens)] += (
_downsampling(img_where_color, w_r,
h_r).reshape(-1,
1).repeat(1, len(v_as_tokens)))
if not is_in == 1:
print(
f'Warning (ratio {ratio}): tokens {v_as_tokens} not found in text'
)
if original_shape:
ret_tensor = ret_tensor.reshape((w_r, h_r, len(token_lis)))
return ret_tensor
def _image_context_seperator(img, color_context: dict, _tokenizer, neg: float):
ret_lists = []
if img is not None:
w, h = img.size
matrix = np.zeros((h, w))
for color, v in color_context.items():
color = tuple(color)
if len(color) > 3:
color = color[:3]
if isinstance(color, str):
r, g, b = color[1:3], color[3:5], color[5:7]
color = (int(r, 16), int(g, 16), int(b, 16))
img_where_color = (np.array(img) == color).all(axis=-1)
matrix[img_where_color] = 1
for color, (subject, weight_active) in color_context.items():
if len(color) > 3:
color = color[:3]
v_input = _tokenizer(
subject,
max_length=_tokenizer.model_max_length,
truncation=True,
)
v_as_tokens = v_input['input_ids'][1:-1]
if isinstance(color, str):
r, g, b = color[1:3], color[3:5], color[5:7]
color = (int(r, 16), int(g, 16), int(b, 16))
img_where_color = (np.array(img) == color).all(axis=-1)
matrix[img_where_color] = 1
if not img_where_color.sum() > 0:
print(
f'Warning: color {color} not found in image')
img_where_color_init = torch.where(
torch.tensor(img_where_color, dtype=torch.bool), weight_active,
neg)
img_where_color = torch.where(
torch.from_numpy(matrix == 1) & (img_where_color_init == 0.0),
torch.tensor(neg), img_where_color_init)
ret_lists.append((v_as_tokens, img_where_color))
else:
w, h = 768, 768
if len(ret_lists) == 0:
ret_lists.append(([-1], torch.zeros((w, h), dtype=torch.float32)))
return ret_lists, w, h
def _extract_cross_attention(tokenizer, device, color_map_image, color_context,
text_input, neg):
# Process color map image and context
seperated_word_contexts, width, height = _image_context_seperator(
color_map_image, color_context, tokenizer, neg)
# Compute cross-attention weights
cross_attention_weight_1 = _tokens_img_attention_weight(
seperated_word_contexts, text_input, ratio=1,
original_shape=True).to(device)
cross_attention_weight_8 = _tokens_img_attention_weight(
seperated_word_contexts, text_input, ratio=8).to(device)
cross_attention_weight_16 = _tokens_img_attention_weight(
seperated_word_contexts, text_input, ratio=16).to(device)
cross_attention_weight_32 = _tokens_img_attention_weight(
seperated_word_contexts, text_input, ratio=32).to(device)
cross_attention_weight_64 = _tokens_img_attention_weight(
seperated_word_contexts, text_input, ratio=64).to(device)
hidden_states = {
'CA_WEIGHT_ORIG': cross_attention_weight_1, # 768 x 768
'CA_WEIGHT_9216': cross_attention_weight_8, # 96 x 96
'CA_WEIGHT_2304': cross_attention_weight_16, # 48 x 48
'CA_WEIGHT_576': cross_attention_weight_32, # 24 x 24
'CA_WEIGHT_144': cross_attention_weight_64, # 12 x 12
}
uncond_hidden_states = {
'CA_WEIGHT_ORIG': 0,
'CA_WEIGHT_9216': 0,
'CA_WEIGHT_2304': 0,
'CA_WEIGHT_576': 0,
'CA_WEIGHT_144': 0,
}
return hidden_states, uncond_hidden_states
def _downsampling(img: torch.tensor, w: int, h: int) -> torch.tensor:
return F.interpolate(
img.unsqueeze(0).unsqueeze(1),
size=(w, h),
mode='bilinear',
align_corners=True,
).squeeze()
def _latents_to_images(vae, latents, scale_factor=0.18215):
"""Decode latents to PIL images."""
scaled_latents = 1.0 / scale_factor * latents.clone()
images = vae.decode(scaled_latents).sample
images = (images / 2 + 0.5).clamp(0, 1)
images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype('uint8')
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def _sanitize_parameters(self, **pipeline_parameters):
"""
this method should sanitize the keyword args to preprocessor params,
forward params and postprocess params on '__call__' or '_process_single' method
Returns:
Dict[str, str]: preprocess_params = {'image_resolution': self.model.get_resolution()}
Dict[str, str]: forward_params = pipeline_parameters
Dict[str, str]: postprocess_params = {}
"""
pipeline_parameters['image_resolution'] = self.model.get_resolution()
pipeline_parameters['modelsetting'] = self.model.get_config()
pipeline_parameters['model_dir'] = self.model.get_model_dir()
pipeline_parameters['control_type'] = self.init_control_type
pipeline_parameters['device'] = self.device
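A hedged end-to-end sketch (not part of the diff) for the Cones2 pipeline above, combining the docstring input with the residual_dict consumed in forward; the residual embedding paths are hypothetical:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(
    task=Tasks.text_to_image_synthesis, model='damo/Cones2', model_revision='v1.0.1')
result = pipe({
    'text': 'a mug and a dog on the beach',
    'subject_list': [['mug', 2], ['dog', 5]],
    'color_context': {'255,192,0': ['mug', 2.5], '255,0,0': ['dog', 2.5]},
    'layout': 'data/test/images/mask_example.png',
    # hypothetical paths to pre-computed residual token embeddings
    'residual_dict': {'mug': './residuals/mug.pt', 'dog': './residuals/dog.pt'},
})
# postprocess returns BGR images under OutputKeys.OUTPUT_IMGS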

View File

@@ -0,0 +1,56 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict
from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.device import device_placement
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.text_video_retrieval,
module_name=Pipelines.prost_text_video_retrieval)
class ProSTTextVideoRetrievalPipeline(Pipeline):
'''
https://www.modelscope.cn/models/damo/multi_modal_clip_vtretrieval_prost/summary
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
text_video_retrieval= pipeline(
Tasks.text_video_retrieval,
model='damo/multi_modal_clip_vtretrieval_prost')
video_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/multi_modal_test_video_9770.mp4'
caption = 'a person is connecting something to system'
_input = {'video': video_path, 'text': caption}
result = text_video_retrieval(_input)
'''
def __init__(self, model: str, **kwargs):
"""
use `model` to create a text_video_retrieval pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
self.model.eval()
def preprocess(self, input: Input) -> Dict[str, Any]:
return input
def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
with device_placement(self.framework, self.device_name):
out = self.forward(input)
self._check_output(out)
return out
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
return self.model(input)
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
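A sketch (not part of the diff) of consuming the retrieval result, tying it back to the Tasks.text_video_retrieval entry added to TASK_OUTPUTS earlier in this diff; it assumes ndarray outputs as declared in OutputTypes:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

text_video_retrieval = pipeline(
    Tasks.text_video_retrieval, model='damo/multi_modal_clip_vtretrieval_prost')
result = text_video_retrieval({
    'video': 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/multi_modal_test_video_9770.mp4',
    'text': 'a person is connecting something to system'})
# prototype embeddings and the text-video similarity matrix, keyed by OutputKeys
for key in ('phrase_prototype', 'sentence_prototype', 'object_prototype',
            'event_prototype', 'textvideo_sim'):
    print(key, result[key].shape)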

View File

@@ -27,7 +27,9 @@ if TYPE_CHECKING:
from .translation_quality_estimation_pipeline import TranslationQualityEstimationPipeline
from .text_error_correction_pipeline import TextErrorCorrectionPipeline
from .word_alignment_pipeline import WordAlignmentPipeline
from .text_generation_pipeline import TextGenerationPipeline, TextGenerationT5Pipeline, SeqGPTPipeline
from .text_generation_pipeline import TextGenerationPipeline, TextGenerationT5Pipeline, \
SeqGPTPipeline, ChatGLM6bTextGenerationPipeline, ChatGLM6bV2TextGenerationPipeline, \
QWenChatPipeline, QWenTextGenerationPipeline, Llama2TaskPipeline
from .fid_dialogue_pipeline import FidDialoguePipeline
from .token_classification_pipeline import TokenClassificationPipeline
from .translation_pipeline import TranslationPipeline
@@ -80,7 +82,10 @@ else:
'word_alignment_pipeline': ['WordAlignmentPipeline'],
'text_generation_pipeline': [
'TextGenerationPipeline', 'TextGenerationT5Pipeline',
'SeqGPTPipeline'
'ChatGLM6bTextGenerationPipeline',
'ChatGLM6bV2TextGenerationPipeline', 'QWenChatPipeline',
'QWenTextGenerationPipeline', 'SeqGPTPipeline',
'Llama2TaskPipeline'
],
'fid_dialogue_pipeline': ['FidDialoguePipeline'],
'token_classification_pipeline': ['TokenClassificationPipeline'],

Some files were not shown because too many files have changed in this diff.