Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-25 20:49:37 +01:00)

Merge branch 'release/1.9' of gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib into master-github
@@ -3,6 +3,7 @@
BASE_CPU_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04
BASE_GPU_CUDA113_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04-cuda11.3.0-cudnn8-devel
BASE_GPU_CUDA117_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04-cuda11.7.1-cudnn8-devel
BASE_GPU_CUDA118_IMAGE=reg.docker.alibaba-inc.com/modelscope/ubuntu:20.04-cuda11.8.0-cudnn8-devel
MODELSCOPE_REPO_ADDRESS=reg.docker.alibaba-inc.com/modelscope/modelscope
python_version=3.7.13
torch_version=1.11.0

@@ -73,6 +74,10 @@ elif [ "$cuda_version" == 11.7.1 ]; then
    echo "Building base image cuda11.7.1"
    cudatoolkit_version=cu117
    BASE_GPU_IMAGE=$BASE_GPU_CUDA117_IMAGE
elif [ "$cuda_version" == 11.8.0 ]; then
    echo "Building base image cuda11.8.0"
    cudatoolkit_version=cu118
    BASE_GPU_IMAGE=$BASE_GPU_CUDA118_IMAGE
else
    echo "Unsupported cuda version: $cuda_version"
    exit 1
@@ -42,6 +42,8 @@ for i in "$@"; do
    cudatoolkit_version=11.3
elif [ "$cuda_version" == "11.7.1" ]; then
    cudatoolkit_version=11.7
elif [ "$cuda_version" == "11.8.0" ]; then
    cudatoolkit_version=11.8
else
    echo "Unsupported cuda version $cuda_version"
    exit 1
@@ -9,7 +9,7 @@ cpu_sets_arr=($cpu_sets)
is_get_file_lock=false
CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh python tests/run.py --parallel 2 --run_config tests/run_config.yaml}
echo "ci command: $CI_COMMAND"
PR_CHANGED_FILES="${PR_CHANGED_FILES:-''}"
PR_CHANGED_FILES="${PR_CHANGED_FILES:-}"
echo "PR modified files: $PR_CHANGED_FILES"
PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#}
echo "PR_CHANGED_FILES: $PR_CHANGED_FILES"
@@ -1,6 +1,9 @@
ARG BASE_IMAGE=reg.docker.alibaba-inc.com/modelscope/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-base
FROM $BASE_IMAGE

RUN apt-get update && apt-get install -y iputils-ping net-tools iproute2 && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*
# install modelscope
COPY requirements /var/modelscope
RUN pip install --no-cache-dir --upgrade pip && \

@@ -31,9 +34,9 @@ RUN pip install --no-cache-dir mpi4py paint_ldm \

# for cpu, install the cpu version of faiss; faiss depends on a BLAS lib, so we install libopenblas. TODO: rename gpu or cpu version faiss
RUN if [ "$USE_GPU" = "True" ] ; then \
    pip install --no-cache-dir funtextprocessing kwsbp==0.0.6 faiss==1.7.2 safetensors typeguard==2.13.3 scikit-learn 'pandas<1.4.0' librosa==0.9.2 funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
    pip install --no-cache-dir funtextprocessing kwsbp==0.0.6 faiss==1.7.2 safetensors typeguard==2.13.3 scikit-learn librosa==0.9.2 funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
else \
    pip install --no-cache-dir funtextprocessing kwsbp==0.0.6 https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/faiss-1.7.2-py37-none-linux_x86_64.whl safetensors typeguard==2.13.3 scikit-learn 'pandas<1.4.0' librosa==0.9.2 funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
    pip install --no-cache-dir funtextprocessing kwsbp==0.0.6 https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/faiss-1.7.2-py37-none-linux_x86_64.whl safetensors typeguard==2.13.3 scikit-learn librosa==0.9.2 funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
fi

RUN pip install --no-cache-dir wenetruntime==1.11.0 adaseq --no-deps

@@ -44,5 +47,11 @@ ENV SETUPTOOLS_USE_DISTUTILS=stdlib

RUN CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" pip install --no-cache-dir 'git+https://github.com/facebookresearch/detectron2.git'

# add basicsr
RUN pip install --no-cache-dir basicsr
# torchmetrics==0.11.4 for ofa
RUN pip install --no-cache-dir tiktoken torchmetrics==0.11.4 'transformers<4.31.0' transformers_stream_generator 'protobuf<=3.20.0' bitsandbytes basicsr
COPY docker/scripts/install_flash_attension.sh /tmp/install_flash_attension.sh
RUN if [ "$USE_GPU" = "True" ] ; then \
    bash /tmp/install_flash_attension.sh; \
else \
    echo 'cpu does not support flash attention'; \
fi
@@ -69,14 +69,20 @@ RUN if [ "$USE_GPU" = "True" ] ; then \
# install tensorflow
ARG TENSORFLOW_VERSION=1.15.5
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
if [ "$TENSORFLOW_VERSION" = "1.15.5" ] ; then \
pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
else \
pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION; \
fi \
else \
# only python 3.7 has tensorflow 1.15.5
if [ "$PYTHON_VERSION" = "3.7.13" ] ; then \
pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION; \
else \
elif [ "$TENSORFLOW_VERSION" = "1.15.5" ] ; then \
pip install --no-cache-dir numpy==1.18.5 https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/tensorflow-1.15.5-cp38-cp38-linux_x86_64.whl; \
fi \
else \
pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION; \
fi \
fi

# mmcv-full<=1.7.0 for mmdet3d compatibility
6  docker/scripts/install_flash_attension.sh  Normal file

@@ -0,0 +1,6 @@
git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention && \
cd flash-attention && pip install . && \
pip install csrc/layer_norm && \
pip install csrc/rotary && \
cd .. && \
rm -rf flash-attention
@@ -1,14 +1,20 @@
export CMAKE_BUILD_PARALLEL_LEVEL=36 && export MAX_JOBS=36 && export CMAKE_CUDA_ARCHITECTURES="50;52;60;61;70;75;80;86" \
&& pip install --no-cache-dir fvcore iopath \
&& curl -LO https://github.com/NVIDIA/cub/archive/1.16.0.tar.gz \
&& tar xzf 1.16.0.tar.gz \
&& export CUB_HOME=$PWD/cub-1.16.0 \
export CMAKE_BUILD_PARALLEL_LEVEL=36 \
&& export MAX_JOBS=36 \
&& export CMAKE_CUDA_ARCHITECTURES="50;52;60;61;70;75;80;86" \
&& git clone --branch 2.1.0 --recursive https://github.com/NVIDIA/thrust.git \
&& cd thrust \
&& mkdir build \
&& cd build \
&& cmake -DCMAKE_INSTALL_PREFIX=/usr/local/cuda/ -DTHRUST_INCLUDE_CUB_CMAKE=ON .. \
&& make install \
&& cd ../.. \
&& rm -rf thrust \
&& pip install --no-cache-dir fvcore iopath \
&& pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" \
&& rm -fr 1.16.0.tar.gz cub-1.16.0 \
&& apt-get update \
&& apt-get install -y --no-install-recommends pkg-config libglvnd0 libgl1 libglx0 libegl1 libgles2 libglvnd-dev libgl1-mesa-dev libegl1-mesa-dev libgles2-mesa-dev -y \
&& git clone https://github.com/NVlabs/nvdiffrast.git \
&& cd nvdiffrast \
&& pip install --no-cache-dir . \
&& cd .. \
&& rm -rf nvdiffrast
@@ -10,10 +10,11 @@ import json
import torch
from swift import LoRAConfig, Swift

from modelscope import TrainingArgs
from modelscope import TrainingArgs, build_dataset_from_file
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.nlp.llama import LlamaForTextGeneration, LlamaTokenizer
from modelscope.msdatasets import MsDataset
from modelscope.msdatasets.dataset_cls.custom_datasets.torch_custom_dataset import \
    TorchCustomDataset
from modelscope.trainers import build_trainer
@@ -38,6 +39,23 @@ PROMPT_DICT = {

@dataclass(init=False)
class TextGenerationArguments(TrainingArgs):
    instruction: str = field(
        default='instruction',
        metadata={
            'help': 'The instruction text key of dataset',
        })

    input: str = field(
        default='input', metadata={
            'help': 'The input text key of dataset',
        })

    output: str = field(
        default='output',
        metadata={
            'help': 'The output text key of dataset',
        })

    src_txt: str = field(
        default=None,
        metadata={
@@ -145,12 +163,7 @@ def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer,
class SupervisedDataset(TorchCustomDataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer):
        logging.warning('Loading data...')
        f = open(data_path, 'r')
        list_data_dict = json.load(f)
        f.close()

    def __init__(self, list_data_dict, tokenizer):
        logging.warning('Formatting inputs...')
        prompt_input, prompt_no_input = PROMPT_DICT[
            'prompt_input'], PROMPT_DICT['prompt_no_input']
@@ -173,6 +186,24 @@ class SupervisedDataset(TorchCustomDataset):
    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i):
        if isinstance(i, int):
            return dict(input_ids=self.input_ids[i], labels=self.labels[i])
        elif isinstance(i, slice):
            return SliceSupervisedDataset(self.input_ids, self.labels, i)
        else:
            raise TypeError(f'Unsupported input type: {type(i)}')


class SliceSupervisedDataset(TorchCustomDataset):

    def __init__(self, input_ids, labels, slice_):
        self.input_ids = input_ids[slice_]
        self.labels = labels[slice_]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i):
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
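The new `__getitem__` accepts both integer and slice indices, so a tokenized dataset can be split into train/eval views without re-tokenizing. A minimal runnable sketch of that contract, using a toy stand-in class (not ModelScope code) and placeholder tensors:

```python
import torch


class TinySliceableDataset:
    """Toy stand-in mirroring the int/slice __getitem__ contract added above."""

    def __init__(self, input_ids, labels):
        self.input_ids = input_ids
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i):
        if isinstance(i, int):
            return dict(input_ids=self.input_ids[i], labels=self.labels[i])
        elif isinstance(i, slice):
            # A slice returns another dataset view rather than a list of dicts,
            # so a trainer can take cheap train/eval subsets.
            return TinySliceableDataset(self.input_ids[i], self.labels[i])
        raise TypeError(f'Unsupported input type: {type(i)}')


ids = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([4, 5, 6])]
ds = TinySliceableDataset(ids, [t.clone() for t in ids])
print(ds[0]['input_ids'].tolist())  # [1, 2]
print(len(ds[1:]))                  # 2 -- a view over the last two examples
```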
@@ -199,7 +230,9 @@ class DataCollatorForSupervisedDataset(object):
        )


config, args = TextGenerationArguments().parse_cli().to_config()
training_args = TextGenerationArguments().parse_cli()
config, args = training_args.to_config()
print(args)

if __name__ == '__main__':
@@ -217,7 +250,7 @@ if __name__ == '__main__':
    }
    cfg.train.optimizer = {
        'type': 'AdamW',
        'lr': 2e-5,
        'lr': training_args.lr,
        'weight_decay': 0.0,
        'options': {
            'cumulative_iters': 8,

@@ -227,9 +260,15 @@
            }
        }
    }
    cfg.train.logging = {'interval': 8, 'by_epoch': False}
    cfg.train.logging = {
        'interval': training_args.logging_interval,
        'by_epoch': False
    }
    cfg.train['bf16'] = True
    cfg.train.dataloader = {'batch_size_per_gpu': 4, 'workers_per_gpu': 1}
    cfg.train.dataloader = {
        'batch_size_per_gpu': training_args.per_device_train_batch_size,
        'workers_per_gpu': 1
    }
    if 'hooks' not in cfg.train:
        cfg.train['hooks'] = []
    if args.deepspeed is not None:
@@ -247,8 +286,49 @@ if __name__ == '__main__':

    model_path = args.model if os.path.exists(
        args.model) else snapshot_download(args.model)
    data_path = args.src_txt if args.src_txt else os.path.join(
        model_path, 'alpaca_data.json')

    dataset_mapping_dict = {
        args.instruction: 'instruction',
        args.input: 'input',
        args.output: 'output'
    }
    if args.dataset_json_file is None:
        if args.train_dataset_name is not None and args.val_dataset_name is not None:
            train_dataset = MsDataset.load(
                args.train_dataset_name,
                subset_name=args.train_subset_name,
                split=args.train_split,
                namespace=args.train_dataset_namespace).remap_columns(
                    dataset_mapping_dict)
            validation_dataset = MsDataset.load(
                args.val_dataset_name,
                subset_name=args.val_subset_name,
                split=args.val_split,
                namespace=args.val_dataset_namespace).remap_columns(
                    dataset_mapping_dict)
        elif args.train_dataset_name is not None and args.val_dataset_name is None:
            ms_dataset = MsDataset.load(
                args.train_dataset_name,
                subset_name=args.train_subset_name,
                split=args.train_split,
                namespace=args.train_dataset_namespace).remap_columns(
                    dataset_mapping_dict).train_test_split(
                        test_size=0.02, seed=args.seed)
            train_dataset = ms_dataset['train']
            validation_dataset = ms_dataset['test']
        else:
            data_path = training_args.src_txt if training_args.src_txt else os.path.join(
                model_path, 'alpaca_data.json')
            ms_dataset = MsDataset.load(
                'json', data_files=data_path).remap_columns(
                    dataset_mapping_dict).train_test_split(
                        test_size=0.02, seed=args.seed)
            train_dataset = ms_dataset['train']
            validation_dataset = ms_dataset['test']
    else:
        train_dataset, validation_dataset = build_dataset_from_file(
            args.dataset_json_file)

    model = LlamaForTextGeneration.from_pretrained(
        model_path, device_map=args.device_map)
@@ -283,17 +363,19 @@ if __name__ == '__main__':
        model=model,
    )

    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_path)
    train_dataset = SupervisedDataset(
        tokenizer=tokenizer, list_data_dict=train_dataset)
    validation_dataset = SupervisedDataset(
        tokenizer=tokenizer, list_data_dict=validation_dataset)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    kwargs = dict(
        model=model,
        cfg_file=os.path.join(model_path, 'configuration.json'),
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        data_collator=data_collator,
        max_epochs=3,
        cfg_modify_fn=cfg_modify_fn,
        device='cpu')
        cfg_modify_fn=cfg_modify_fn)

    # Construct trainer and train
    trainer = build_trainer(
@@ -6,4 +6,5 @@ torchrun --nproc_per_node $DATA_PARALLEL_SIZE examples/pytorch/llama/finetune_ll
    --work_dir './tmp' \
    --model 'skyline2006/llama-7b' \
    --deepspeed 'default_offload_opt_param.json' \
    --eval_interval 100
    --eval_interval 100 \
    --max_epochs 3 \
@@ -2,6 +2,22 @@ export PYTHONPATH=$PYTHONPATH:./
torchrun examples/pytorch/llama/finetune_llama.py \
    --work_dir './tmp' \
    --model 'skyline2006/llama-7b' \
    --eval_interval 100 \
    --train_dataset_name 'alpaca-gpt4-data-zh' \
    --train_subset_name 'default' \
    --train_split 'train' \
    --train_dataset_namespace 'AI-ModelScope' \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --eval_strategy 'by_epoch' \
    --eval_interval 1 \
    --eval_metrics 'ppl' \
    --lr 2e-5 \
    --save_strategy no \
    --save_best true \
    --metric_for_best_model ppl \
    --metric_rule_for_best_model min \
    --use_lora 1 \
    --device_map 'auto' \
    --task 'text-generation' \
    --model.type 'llama' \
    --max_epochs 3 \
@@ -105,8 +105,8 @@ def llm_infer(args: InferArguments) -> None:
        top_k=args.top_k,
        top_p=args.top_p,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id)
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id)
    logger.info(f'generation_config: {generation_config}')

    if args.eval_human:
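The reordered kwargs above also reuse the eos token as the pad token, a common workaround for chat models whose tokenizers define no dedicated pad token. A hedged, self-contained sketch of the same idea with the Hugging Face GenerationConfig API ('gpt2' is only a placeholder model id):

```python
from transformers import AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder model id

generation_config = GenerationConfig(
    max_new_tokens=128,
    do_sample=True,
    top_k=50,
    top_p=0.9,
    eos_token_id=tokenizer.eos_token_id,
    # Fall back to eos as pad when no pad token exists, mirroring the
    # pad_token_id=tokenizer.eos_token_id change above.
    pad_token_id=tokenizer.pad_token_id
    if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
)
print(generation_config)
```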
@@ -1,5 +1,5 @@
CUDA_VISIBLE_DEVICES=0,1 \
python llm_infer.py \
    --model_type qwen-7b \
    --ckpt_path "runs/qwen-7b/vx_xxx/output_best/pytorch_model.bin" \
    --model_type polylm-13b \
    --ckpt_path "runs/polylm-13b/v0-20230802-172425/output_best/pytorch_model.bin" \
    --eval_human true
@@ -1,6 +1,6 @@
CUDA_VISIBLE_DEVICES=0,1 \
python llm_sft.py \
    --model_type qwen-7b \
    --model_type polylm-13b \
    --output_dir runs \
    --dataset alpaca-en,alpaca-zh \
    --dataset alpaca-en,alpaca-zh,alpaca-multi \
    --dataset_sample 20000
@@ -141,6 +141,7 @@ class LoRATM(NamedTuple):
    chatglm2 = ['query_key_value']
    llama2 = ['q_proj', 'k_proj', 'v_proj']
    qwen = ['c_attn']
    polylm = ['c_attn']


# Reference: 'https://modelscope.cn/models/{model_id}/summary'
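LoRATM maps each model type to the attention projections that LoRA should wrap; polylm reuses the GPT-style fused c_attn projection. As a hedged illustration of how such a target-module list is consumed, here is the equivalent wiring with the peft library (a swapped-in example, not the swift-based path this repo uses; 'gpt2' is a placeholder base model with a fused c_attn):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('gpt2')  # placeholder base model

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    # Plays the same role as LoRATM.polylm above: which modules get adapters.
    target_modules=['c_attn'],
    task_type='CAUSAL_LM',
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the low-rank adapters are trainable
```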
@@ -0,0 +1,107 @@
import os
from dataclasses import dataclass, field

import cv2

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.training_args import TrainingArgs
from modelscope.utils.constant import DownloadMode, Tasks


# Load configuration file and dataset
@dataclass(init=False)
class StableDiffusionCones2Arguments(TrainingArgs):
    instance_prompt: str = field(
        default='a photo of sks dog',
        metadata={
            'help': 'The instance prompt for cones.',
        })

    resolution: int = field(
        default=768, metadata={
            'help': 'The class images resolution.',
        })

    train_batch_size: int = field(
        default=4,
        metadata={
            'help': 'Batch size (per device) for the training dataloader.',
        })

    sample_batch_size: int = field(
        default=4,
        metadata={
            'help': 'Batch size (per device) for sampling images.',
        })

    prompt: str = field(
        default='dog', metadata={
            'help': 'The pipeline prompt.',
        })


training_args = StableDiffusionCones2Arguments(
    task='text-to-image-synthesis').parse_cli()
config, args = training_args.to_config()

if os.path.exists(args.train_dataset_name):
    # Load local dataset
    train_dataset = MsDataset.load(args.train_dataset_name)
    validation_dataset = MsDataset.load(args.train_dataset_name)
else:
    # Load online dataset
    train_dataset = MsDataset.load(
        args.train_dataset_name,
        split='train',
        download_mode=DownloadMode.FORCE_REDOWNLOAD)
    validation_dataset = MsDataset.load(
        args.train_dataset_name,
        split='validation',
        download_mode=DownloadMode.FORCE_REDOWNLOAD)


def cfg_modify_fn(cfg):
    if args.use_model_config:
        cfg.merge_from_dict(config)
    else:
        cfg = config
    cfg.train.lr_scheduler = {
        'type': 'LambdaLR',
        'lr_lambda': lambda _: 1,
        'last_epoch': -1
    }
    return cfg


kwargs = dict(
    model=training_args.model,
    model_revision=args.model_revision,
    work_dir=training_args.work_dir,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    cfg_modify_fn=cfg_modify_fn)

trainer = build_trainer(name=Trainers.cones2_inference, default_args=kwargs)
trainer.train()

# pipeline after training and save result
pipe = pipeline(
    task=Tasks.text_to_image_synthesis,
    model=training_args.work_dir + '/output',
    model_revision=args.model_revision)

output = pipe({
    'text': 'a mug and a dog on the beach',
    'subject_list': [['mug', 2], ['dog', 5]],
    'color_context': {
        '255,192,0': ['mug', 2.5],
        '255,0,0': ['dog', 2.5]
    },
    'layout': 'data/test/images/mask_example.png'
})
# visualize the result on ipynb and save it
output
cv2.imwrite('./cones2_result.png', output['output_imgs'][0])
13  examples/pytorch/stable_diffusion/cones2/run_train_cones2.sh  Normal file

@@ -0,0 +1,13 @@
PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/cones2/finetune_stable_diffusion_cones2.py \
    --model 'damo/Cones2' \
    --model_revision 'v1.0.1' \
    --instance_prompt="dog" \
    --work_dir './tmp/cones2_diffusion' \
    --train_dataset_name 'buptwq/lora-stable-diffusion-finetune-dog' \
    --max_epochs 250 \
    --save_ckpt_strategy 'by_epoch' \
    --logging_interval 1 \
    --train.dataloader.workers_per_gpu 0 \
    --evaluation.dataloader.workers_per_gpu 0 \
    --train.optimizer.lr 1e-5 \
    --use_model_config true
@@ -4,30 +4,29 @@
# TODO: handle environments without threads
# (Python compiled without thread support)

import numpy as np
import simplejson as json
from operator import attrgetter
from sortedcontainers import SortedList
from datetime import datetime, timedelta, date, time
from dateutil.parser import parse as parse_datetime
from functools import wraps, partial
from operator import methodcaller
from decimal import Decimal
from fractions import Fraction
from collections import namedtuple
import threading
import uuid
import numpy as np
from collections import namedtuple
from datetime import date, datetime, time, timedelta
from dateutil.parser import parse as parse_datetime
from decimal import Decimal
from fractions import Fraction
from functools import partial, wraps
from operator import attrgetter, methodcaller
from sortedcontainers import SortedList

try:
    from moneyed import Money, Currency
    from moneyed import Currency, Money
except ImportError:
    # defer failing to actual (de-)serialization
    pass

__all__ = ["loads", "dumps", "pretty",
           "json_loads", "json_dumps", "json_prettydump",
           "encoder", "decoder"]

__all__ = [
    "loads", "dumps", "pretty", "json_loads", "json_dumps", "json_prettydump",
    "encoder", "decoder"
]

# Should we aim for the *exact* reproduction of Python types,
# or for maximum *compatibility* when (de-)serializing?
@@ -59,12 +58,15 @@ CODING_DEFAULT = EXACT

_local = threading.local()


def prefer(coding):
    _local.coding = coding


def prefer_exact():
    prefer(EXACT)


def prefer_compat():
    prefer(COMPAT)
@@ -103,15 +105,18 @@ def kwargified(constructor):
    >>> test({'b': 3})
    4
    """

    @wraps(constructor)
    def kwargs_constructor(kwargs):
        return constructor(**kwargs)

    return kwargs_constructor


_PredicatedEncoder = namedtuple('_PredicatedEncoder',
                                'priority predicate encoder typename')


def encoder(classname, predicate=None, priority=None, exact=True):
    """A decorator for registering a new encoder for object type
    defined either by a `classname`, or detected via `predicate`.
@@ -182,14 +187,18 @@ def _json_default_exact(obj):
    # first try predicate-based encoders
    for handler in _encode_handlers['exact']['predicate']:
        if handler.predicate(obj):
            return {"__class__": handler.typename,
                    "__value__": handler.encoder(obj)}
            return {
                "__class__": handler.typename,
                "__value__": handler.encoder(obj)
            }

    # then classname-based
    classname = type(obj).__name__
    if classname in _encode_handlers['exact']['classname']:
        return {"__class__": classname,
                "__value__": _encode_handlers['exact']['classname'][classname](obj)}
        return {
            "__class__": classname,
            "__value__": _encode_handlers['exact']['classname'][classname](obj)
        }

    raise TypeError(repr(obj) + " is not JSON serializable")
@@ -217,8 +226,10 @@ def decoder(classname):
        def mytype_decoder(value):
            return mytype(value, reconstruct=True)
    """

    def _decorator(f):
        _decode_handlers.setdefault(classname, f)

    return _decorator
@@ -235,25 +246,25 @@ def _json_object_hook(dict):
    return dict


def _encoder_default_args(kw):
    """Shape default arguments for encoding functions."""

    # manual override of the preferred coding with `exact=False`
    if kw.pop('exact', getattr(_local, 'coding', CODING_DEFAULT) == EXACT):
        # settings necessary for the "exact coding"
        kw.update({
            'default': _json_default_exact,
            'use_decimal': False,      # don't encode `Decimal` as JSON's `Number`
            'tuple_as_array': False,   # don't encode `tuple` as `Array`
            'namedtuple_as_object': False  # don't call `_asdict` on `namedtuple`
            'use_decimal': False,  # don't encode `Decimal` as JSON's `Number`
            'tuple_as_array': False,  # don't encode `tuple` as `Array`
            'namedtuple_as_object':
            False  # don't call `_asdict` on `namedtuple`
        })
    else:
        # settings for the "compatibility coding"
        kw.update({
            'default': _json_default_compat,
            'ignore_nan': True  # be compliant with the ECMA-262 specification:
                                # serialize nan/inf as null
            'ignore_nan': True  # be compliant with the ECMA-262 specification:
            # serialize nan/inf as null
        })

    # NOTE: if called from ``simplejson.dumps()`` with ``cls=JSONEncoder``,
@@ -276,8 +287,8 @@ def _decoder_default_args(kw):
    kw.update({'object_hook': _json_object_hook})


class JSONEncoder(json.JSONEncoder):

    def __init__(self, **kw):
        """Constructor for simplejson.JSONEncoder, with defaults overriden
        for jsonplus.
@@ -287,6 +298,7 @@ class JSONEncoder(json.JSONEncoder):


class JSONDecoder(json.JSONDecoder):

    def __init__(self, **kw):
        """Constructor for simplejson.JSONDecoder, with defaults overriden
        for jsonplus.

@@ -295,7 +307,6 @@ class JSONDecoder(json.JSONDecoder):
        super(JSONDecoder, self).__init__(**kw)


def dumps(*pa, **kw):
    _encoder_default_args(kw)
    return json.dumps(*pa, **kw)
@@ -306,14 +317,13 @@ def loads(*pa, **kw):
    return json.loads(*pa, **kw)


def pretty(x, sort_keys=True, indent=4*' ', separators=(',', ': '), **kw):
def pretty(x, sort_keys=True, indent=4 * ' ', separators=(',', ': '), **kw):
    kw.setdefault('sort_keys', sort_keys)
    kw.setdefault('indent', indent)
    kw.setdefault('separators', separators)
    return dumps(x, **kw)


json_dumps = dumps
json_loads = loads
json_prettydump = pretty
@@ -330,21 +340,36 @@ def generic_to_item(value):
_encode_handlers = {
    'exact': {
        'classname': {
            'datetime': methodcaller('isoformat'),
            'date': methodcaller('isoformat'),
            'time': methodcaller('isoformat'),
            'timedelta': partial(getattrs, attrs=['days', 'seconds', 'microseconds']),
            'tuple': list,
            'set': list,
            'ndarray': np_to_list,
            'float16': generic_to_item,
            'float32': generic_to_item,
            'frozenset': list,
            'complex': partial(getattrs, attrs=['real', 'imag']),
            'Decimal': str,
            'Fraction': partial(getattrs, attrs=['numerator', 'denominator']),
            'UUID': partial(getattrs, attrs=['hex']),
            'Money': partial(getattrs, attrs=['amount', 'currency'])
            'datetime':
            methodcaller('isoformat'),
            'date':
            methodcaller('isoformat'),
            'time':
            methodcaller('isoformat'),
            'timedelta':
            partial(getattrs, attrs=['days', 'seconds', 'microseconds']),
            'tuple':
            list,
            'set':
            list,
            'ndarray':
            np_to_list,
            'float16':
            generic_to_item,
            'float32':
            generic_to_item,
            'frozenset':
            list,
            'complex':
            partial(getattrs, attrs=['real', 'imag']),
            'Decimal':
            str,
            'Fraction':
            partial(getattrs, attrs=['numerator', 'denominator']),
            'UUID':
            partial(getattrs, attrs=['hex']),
            'Money':
            partial(getattrs, attrs=['amount', 'currency'])
        },
        'predicate': SortedList(key=attrgetter('priority'))
    },
@@ -368,7 +393,6 @@ _encode_handlers = {
    }
}


# all decode handlers are for EXACT decoding BY CLASSNAME
_decode_handlers = {
    'datetime': parse_datetime,

@@ -388,11 +412,14 @@ _decode_handlers = {
}


@encoder('namedtuple', lambda obj: isinstance(obj, tuple) and hasattr(obj, '_fields'))
@encoder('namedtuple',
         lambda obj: isinstance(obj, tuple) and hasattr(obj, '_fields'))
def _dump_namedtuple(obj):
    return {"name": type(obj).__name__,
            "fields": list(obj._fields),
            "values": list(obj)}
    return {
        "name": type(obj).__name__,
        "fields": list(obj._fields),
        "values": list(obj)
    }


@decoder('namedtuple')
@@ -404,7 +431,8 @@ def _load_namedtuple(val):
@encoder('timedelta', exact=False)
def _timedelta_total_seconds(td):
    # timedelta.total_seconds() is only available since python 2.7
    return (td.microseconds + (td.seconds + td.days * 24 * 3600.0) * 10**6) / 10**6
    return (td.microseconds +
            (td.seconds + td.days * 24 * 3600.0) * 10**6) / 10**6


@encoder('Currency')

@@ -412,7 +440,7 @@ def _dump_currency(obj):
    """Serialize standard (ISO-defined) currencies to currency code only,
    and non-standard (user-added) currencies in full.
    """
    from moneyed import get_currency, CurrencyDoesNotExist
    from moneyed import CurrencyDoesNotExist, get_currency
    try:
        get_currency(obj.code)
        return obj.code
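The jsonplus changes above are purely cosmetic (yapf-style reflow and import sorting); the encoder/decoder registration mechanism is unchanged. A hedged, self-contained illustration of that mechanism, with an invented Point class:

```python
import jsonplus


class Point:

    def __init__(self, x, y):
        self.x, self.y = x, y


@jsonplus.encoder('Point')
def _dump_point(obj):
    # Stored under {"__class__": "Point", "__value__": {...}} in exact mode.
    return {'x': obj.x, 'y': obj.y}


@jsonplus.decoder('Point')
def _load_point(value):
    return Point(value['x'], value['y'])


s = jsonplus.dumps(Point(1, 2))
p = jsonplus.loads(s)
print(p.x, p.y)  # 1 2
```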
@@ -114,6 +114,7 @@ class Models(object):
    nerf_recon_acc = 'nerf-recon-acc'
    nerf_recon_4k = 'nerf-recon-4k'
    nerf_recon_vq_compression = 'nerf-recon-vq-compression'
    surface_recon_common = 'surface-recon-common'
    bts_depth_estimation = 'bts-depth-estimation'
    vision_efficient_tuning = 'vision-efficient-tuning'
    bad_image_detecting = 'bad-image-detecting'

@@ -122,6 +123,7 @@ class Models(object):
    fastinst = 'fastinst'
    pedestrian_attribute_recognition = 'pedestrian-attribute-recognition'
    image_try_on = 'image-try-on'
    human_image_generation = 'human-image-generation'

    # nlp models
    bert = 'bert'

@@ -183,6 +185,7 @@ class Models(object):
    speech_dfsmn_kws_char_farfield_iot = 'speech_dfsmn_kws_char_farfield_iot'
    speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
    speech_mossformer_separation_temporal_8k = 'speech_mossformer_separation_temporal_8k'
    speech_mossformer2_separation_temporal_8k = 'speech_mossformer2_separation_temporal_8k'
    kws_kwsbp = 'kws-kwsbp'
    generic_asr = 'generic-asr'
    wenet_asr = 'wenet-asr'

@@ -195,6 +198,7 @@ class Models(object):
    eres2net_aug_sv = 'eres2net-aug-sv'
    scl_sd = 'scl-sd'
    campplus_lre = 'cam++-lre'
    eres2net_lre = 'eres2net-lre'
    cluster_backend = 'cluster-backend'
    rdino_tdnn_sv = 'rdino_ecapa-tdnn-sv'
    generic_lm = 'generic-lm'

@@ -210,12 +214,15 @@ class Models(object):
    video_synthesis = 'latent-text-to-video-synthesis'
    team = 'team-multi-modal-similarity'
    video_clip = 'video-clip-multi-modal-embedding'
    prost = 'prost-clip-text-video-retrieval'
    mgeo = 'mgeo'
    vldoc = 'vldoc'
    hitea = 'hitea'
    soonet = 'soonet'
    efficient_diffusion_tuning = 'efficient-diffusion-tuning'
    cones2_inference = 'cones2-inference'
    mplug_owl = 'mplug-owl'

    clip_interrogator = 'clip-interrogator'
    stable_diffusion = 'stable-diffusion'
    stable_diffusion_xl = 'stable-diffusion-xl'
@@ -280,6 +287,7 @@ class Pipelines(object):
    universal_matting = 'unet-universal-matting'
    image_denoise = 'nafnet-image-denoise'
    image_deblur = 'nafnet-image-deblur'
    image_editing = 'masactrl-image-editing'
    person_image_cartoon = 'unet-person-image-cartoon'
    ocr_detection = 'resnet18-ocr-detection'
    table_recognition = 'dla34-table-recognition'

@@ -420,6 +428,7 @@ class Pipelines(object):
    nerf_recon_acc = 'nerf-recon-acc'
    nerf_recon_4k = 'nerf-recon-4k'
    nerf_recon_vq_compression = 'nerf-recon-vq-compression'
    surface_recon_common = 'surface-recon-common'
    bad_image_detecting = 'bad-image-detecting'
    controllable_image_generation = 'controllable-image-generation'
    fast_instance_segmentation = 'fast-instance-segmentation'

@@ -431,6 +440,7 @@ class Pipelines(object):
    pedestrian_attribute_recognition = 'resnet50_pedestrian-attribute-recognition_image'
    text_to_360panorama_image = 'text-to-360panorama-image'
    image_try_on = 'image-try-on'
    human_image_generation = 'human-image-generation'

    # nlp tasks
    automatic_post_editing = 'automatic-post-editing'

@@ -508,10 +518,12 @@ class Pipelines(object):
    sv_inference = 'sv-inference'
    speaker_diarization_inference = 'speaker-diarization-inference'
    vad_inference = 'vad-inference'
    funasr_speech_separation = 'funasr-speech-separation'
    speaker_verification = 'speaker-verification'
    speaker_verification_rdino = 'speaker-verification-rdino'
    speaker_verification_eres2net = 'speaker-verification-eres2net'
    speech_language_recognition = 'speech-language-recognition'
    speech_language_recognition_eres2net = 'speech-language-recognition-eres2net'
    speaker_change_locating = 'speaker-change-locating'
    speaker_diarization_dialogue_detection = 'speaker-diarization-dialogue-detection'
    speaker_diarization_semantic_speaker_turn_detection = 'speaker-diarization-semantic-speaker-turn-detection'

@@ -529,6 +541,7 @@ class Pipelines(object):
    multi_modal_similarity = 'multi-modal-similarity'
    text_to_image_synthesis = 'text-to-image-synthesis'
    video_multi_modal_embedding = 'video-multi-modal-embedding'
    prost_text_video_retrieval = 'prost-text-video-retrieval'
    videocomposer = 'videocomposer'
    image_text_retrieval = 'image-text-retrieval'
    ofa_ocr_recognition = 'ofa-ocr-recognition'

@@ -541,6 +554,7 @@ class Pipelines(object):
    disco_guided_diffusion = 'disco_guided_diffusion'
    document_vl_embedding = 'document-vl-embedding'
    chinese_stable_diffusion = 'chinese-stable-diffusion'
    cones2_inference = 'cones2-inference'
    text_to_video_synthesis = 'latent-text-to-video-synthesis'  # latent-text-to-video-synthesis
    gridvlp_multi_modal_classification = 'gridvlp-multi-modal-classification'
    gridvlp_multi_modal_embedding = 'gridvlp-multi-modal-embedding'
@@ -605,6 +619,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                          'damo/cv_nafnet_image-denoise_sidd'),
    Tasks.image_deblurring: (Pipelines.image_deblur,
                             'damo/cv_nafnet_image-deblur_gopro'),
    Tasks.image_editing: (Pipelines.image_editing,
                          'damo/cv_masactrl_image-editing'),
    Tasks.video_stabilization: (Pipelines.video_stabilization,
                                'damo/cv_dut-raft_video-stabilization_base'),
    Tasks.video_super_resolution:

@@ -724,6 +740,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.video_multi_modal_embedding:
    (Pipelines.video_multi_modal_embedding,
     'damo/multi_modal_clip_vtretrival_msrvtt_53'),
    Tasks.text_video_retrieval: (Pipelines.prost_text_video_retrieval,
                                 'damo/multi_modal_clip_vtretrieval_prost'),
    Tasks.image_color_enhancement:
    (Pipelines.image_color_enhance,
     'damo/cv_csrnet_image-color-enhance-models'),

@@ -875,6 +893,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.nerf_recon_vq_compression: (
        Pipelines.nerf_recon_vq_compression,
        'damo/cv_nerf-3d-reconstruction-vq-compression_damo'),
    Tasks.surface_recon_common: (Pipelines.surface_recon_common,
                                 'damo/cv_surface-reconstruction-common'),
    Tasks.siamese_uie: (Pipelines.siamese_uie,
                        'damo/nlp_structbert_siamese-uie_chinese-base'),
    Tasks.pedestrian_attribute_recognition: (

@@ -884,7 +904,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
        Pipelines.text_to_360panorama_image,
        'damo/cv_diffusion_text-to-360panorama-image_generation'),
    Tasks.image_try_on: (Pipelines.image_try_on,
                         'damo/cv_SAL-VTON_virtual-try-on')
                         'damo/cv_SAL-VTON_virtual-try-on'),
    Tasks.human_image_generation: (Pipelines.human_image_generation,
                                   'damo/cv_FreqHPT_human-image-generation')
}
@@ -942,6 +964,7 @@ class MultiModalTrainers(object):
    lora_diffusion_xl = 'lora-diffusion-xl'
    dreambooth_diffusion = 'dreambooth-diffusion'
    custom_diffusion = 'custom-diffusion'
    cones2_inference = 'cones2-inference'


class AudioTrainers(object):
@@ -1,3 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from . import ans, asr, itn, kws, sv, tts
from . import ans, asr, itn, kws, separation, sv, tts
@@ -15,6 +15,8 @@ __all__ = ['GenericAutomaticSpeechRecognition']
    Tasks.auto_speech_recognition, module_name=Models.generic_asr)
@MODELS.register_module(
    Tasks.voice_activity_detection, module_name=Models.generic_asr)
@MODELS.register_module(
    Tasks.speech_separation, module_name=Models.generic_asr)
@MODELS.register_module(
    Tasks.language_score_prediction, module_name=Models.generic_asr)
@MODELS.register_module(Tasks.speech_timestamp, module_name=Models.generic_asr)
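Registering Models.generic_asr for Tasks.speech_separation (together with the new MossFormer2 model below) makes separation reachable through the standard pipeline API. A hedged usage sketch; the model id, input file, and output key are assumptions for illustration rather than facts from this diff:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Placeholder model card and local 8 kHz mono wav file.
separator = pipeline(
    Tasks.speech_separation,
    model='damo/speech_mossformer2_separation_temporal_8k')

result = separator('mixed_speech.wav')
# Assumed output layout: one PCM byte string per separated speaker.
for i, pcm in enumerate(result.get('output_pcm_list', [])):
    with open(f'speaker_{i}.pcm', 'wb') as f:
        f.write(pcm)
```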
@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .mossformer import MossFormer
    from .m2.mossformer import MossFormer2

else:
    _import_structure = {
        'mossformer': ['MossFormer'],
        'm2.mossformer': ['MossFormer2'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
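The LazyImportModule indirection above keeps `import modelscope.models.audio.separation` lightweight: the torch-heavy submodules listed in _import_structure are only imported on first attribute access. A hedged sketch of the observable behaviour, assuming modelscope is installed:

```python
import sys

# Importing the package installs a LazyImportModule proxy in sys.modules;
# mossformer / m2.mossformer are not imported yet.
import modelscope.models.audio.separation as separation

print(type(sys.modules['modelscope.models.audio.separation']).__name__)

# First attribute access resolves the entry declared in _import_structure
# ('m2.mossformer' -> MossFormer2) and triggers the real import.
MossFormer2 = separation.MossFormer2
print(MossFormer2.__module__)
```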
0  modelscope/models/audio/separation/m2/__init__.py  Normal file

278  modelscope/models/audio/separation/m2/conv_module.py  Normal file
@@ -0,0 +1,278 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
from torch import Tensor
|
||||
|
||||
EPS = 1e-8
|
||||
|
||||
|
||||
class GlobalLayerNorm(nn.Module):
|
||||
"""Calculate Global Layer Normalization.
|
||||
|
||||
Args:
|
||||
dim : (int or list or torch.Size)
|
||||
Input shape from an expected input of size.
|
||||
eps : float
|
||||
A value added to the denominator for numerical stability.
|
||||
elementwise_affine : bool
|
||||
A boolean value that when set to True,
|
||||
this module has learnable per-element affine parameters
|
||||
initialized to ones (for weights) and zeros (for biases).
|
||||
|
||||
Example:
|
||||
-------
|
||||
>>> x = torch.randn(5, 10, 20)
|
||||
>>> GLN = GlobalLayerNorm(10, 3)
|
||||
>>> x_norm = GLN(x)
|
||||
"""
|
||||
|
||||
def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
|
||||
super(GlobalLayerNorm, self).__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
|
||||
if self.elementwise_affine:
|
||||
if shape == 3:
|
||||
self.weight = nn.Parameter(torch.ones(self.dim, 1))
|
||||
self.bias = nn.Parameter(torch.zeros(self.dim, 1))
|
||||
if shape == 4:
|
||||
self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
|
||||
self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
|
||||
else:
|
||||
self.register_parameter('weight', None)
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def forward(self, x):
|
||||
"""Returns the normalized tensor.
|
||||
|
||||
Args:
|
||||
x : torch.Tensor
|
||||
Tensor of size [N, C, K, S] or [N, C, L].
|
||||
"""
|
||||
# x = N x C x K x S or N x C x L
|
||||
# N x 1 x 1
|
||||
# cln: mean,var N x 1 x K x S
|
||||
# gln: mean,var N x 1 x 1
|
||||
if x.dim() == 3:
|
||||
mean = torch.mean(x, (1, 2), keepdim=True)
|
||||
var = torch.mean((x - mean)**2, (1, 2), keepdim=True)
|
||||
if self.elementwise_affine:
|
||||
# yapf: disable
|
||||
x = (self.weight * (x - mean) / torch.sqrt(var + self.eps)
|
||||
+ self.bias)
|
||||
# yapf: enable
|
||||
else:
|
||||
x = (x - mean) / torch.sqrt(var + self.eps)
|
||||
|
||||
if x.dim() == 4:
|
||||
mean = torch.mean(x, (1, 2, 3), keepdim=True)
|
||||
var = torch.mean((x - mean)**2, (1, 2, 3), keepdim=True)
|
||||
if self.elementwise_affine:
|
||||
# yapf: disable
|
||||
x = (self.weight * (x - mean) / torch.sqrt(var + self.eps)
|
||||
+ self.bias)
|
||||
# yapf: enable
|
||||
else:
|
||||
x = (x - mean) / torch.sqrt(var + self.eps)
|
||||
return x
|
||||
|
||||
|
||||
class CumulativeLayerNorm(nn.LayerNorm):
|
||||
"""Calculate Cumulative Layer Normalization.
|
||||
|
||||
Args:
|
||||
dim : int
|
||||
Dimension that you want to normalize.
|
||||
elementwise_affine : True
|
||||
Learnable per-element affine parameters.
|
||||
|
||||
Example:
|
||||
-------
|
||||
>>> x = torch.randn(5, 10, 20)
|
||||
>>> CLN = CumulativeLayerNorm(10)
|
||||
>>> x_norm = CLN(x)
|
||||
"""
|
||||
|
||||
def __init__(self, dim, elementwise_affine=True):
|
||||
super(CumulativeLayerNorm, self).__init__(
|
||||
dim, elementwise_affine=elementwise_affine, eps=1e-8)
|
||||
|
||||
def forward(self, x):
|
||||
"""Returns the normalized tensor.
|
||||
|
||||
Args:
|
||||
x : torch.Tensor
|
||||
Tensor size [N, C, K, S] or [N, C, L]
|
||||
"""
|
||||
# x: N x C x K x S or N x C x L
|
||||
# N x K x S x C
|
||||
if x.dim() == 4:
|
||||
x = x.permute(0, 2, 3, 1).contiguous()
|
||||
# N x K x S x C == only channel norm
|
||||
x = super().forward(x)
|
||||
# N x C x K x S
|
||||
x = x.permute(0, 3, 1, 2).contiguous()
|
||||
if x.dim() == 3:
|
||||
x = torch.transpose(x, 1, 2)
|
||||
# N x L x C == only channel norm
|
||||
x = super().forward(x)
|
||||
# N x C x L
|
||||
x = torch.transpose(x, 1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class Transpose(nn.Module):
|
||||
""" Wrapper class of torch.transpose() for Sequential module. """
|
||||
|
||||
def __init__(self, shape: tuple):
|
||||
super(Transpose, self).__init__()
|
||||
self.shape = shape
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x.transpose(*self.shape)
|
||||
|
||||
|
||||
class DepthwiseConv1d(nn.Module):
|
||||
"""When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
|
||||
this operation is termed in literature as depthwise convolution.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the convolving kernel
|
||||
stride (int, optional): Stride of the convolution. Default: 1
|
||||
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
|
||||
bias (bool, optional): If True, adds a learnable bias to the output. Default: True
|
||||
Inputs: inputs
|
||||
- **inputs** (batch, in_channels, time): Tensor containing input vector
|
||||
Returns: outputs
|
||||
- **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
padding: int = 0,
|
||||
bias: bool = False,
|
||||
) -> None:
|
||||
super(DepthwiseConv1d, self).__init__()
|
||||
assert out_channels % in_channels == 0, 'out_channels should be constant multiple of in_channels'
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
groups=in_channels,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, inputs: Tensor) -> Tensor:
|
||||
return self.conv(inputs)
|
||||
|
||||
|
||||
class ConvModule(nn.Module):
|
||||
"""
|
||||
Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
|
||||
This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
|
||||
to aid training deep models.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input
|
||||
kernel_size (int or tuple, optional): Size of the convolving kernel Default: 17
|
||||
dropout_p (float, optional): probability of dropout
|
||||
Inputs: inputs
|
||||
inputs (batch, time, dim): Tensor contains input sequences
|
||||
Outputs: outputs
|
||||
outputs (batch, time, dim): Tensor produces by conformer convolution module.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
kernel_size: int = 17,
|
||||
expansion_factor: int = 2,
|
||||
dropout_p: float = 0.1,
|
||||
) -> None:
|
||||
super(ConvModule, self).__init__()
|
||||
assert (
|
||||
kernel_size - 1
|
||||
) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
|
||||
assert expansion_factor == 2, 'Currently, Only Supports expansion_factor 2'
|
||||
|
||||
self.sequential = nn.Sequential(
|
||||
Transpose(shape=(1, 2)),
|
||||
DepthwiseConv1d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2),
|
||||
)
|
||||
|
||||
def forward(self, inputs: Tensor) -> Tensor:
|
||||
return inputs + self.sequential(inputs).transpose(1, 2)
|
||||
|
||||
|
||||
class DilatedDenseNet(nn.Module):
|
||||
|
||||
def __init__(self, depth=4, lorder=20, in_channels=64):
|
||||
super(DilatedDenseNet, self).__init__()
|
||||
self.depth = depth
|
||||
self.in_channels = in_channels
|
||||
self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)
|
||||
self.twidth = lorder * 2 - 1
|
||||
self.kernel_size = (self.twidth, 1)
|
||||
for i in range(self.depth):
|
||||
dil = 2**i
|
||||
pad_length = lorder + (dil - 1) * (lorder - 1) - 1
|
||||
setattr(self, 'pad{}'.format(i + 1),
|
||||
nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.))
|
||||
setattr(
|
||||
self, 'conv{}'.format(i + 1),
|
||||
nn.Conv2d(
|
||||
self.in_channels * (i + 1),
|
||||
self.in_channels,
|
||||
kernel_size=self.kernel_size,
|
||||
dilation=(dil, 1),
|
||||
groups=self.in_channels,
|
||||
bias=False))
|
||||
setattr(self, 'norm{}'.format(i + 1),
|
||||
nn.InstanceNorm2d(in_channels, affine=True))
|
||||
setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.unsqueeze(x, 1)
|
||||
x_per = x.permute(0, 3, 2, 1)
|
||||
skip = x_per
|
||||
for i in range(self.depth):
|
||||
out = getattr(self, 'pad{}'.format(i + 1))(skip)
|
||||
out = getattr(self, 'conv{}'.format(i + 1))(out)
|
||||
out = getattr(self, 'norm{}'.format(i + 1))(out)
|
||||
out = getattr(self, 'prelu{}'.format(i + 1))(out)
|
||||
skip = torch.cat([out, skip], dim=1)
|
||||
out1 = out.permute(0, 3, 2, 1)
|
||||
return out1.squeeze(1)
|
||||
|
||||
|
||||
class FFConvMDilated(nn.Module):
|
||||
|
||||
def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
|
||||
super().__init__()
|
||||
self.mdl = nn.Sequential(
|
||||
norm_klass(dim_in), nn.Linear(dim_in, dim_out), nn.SiLU(),
|
||||
DilatedDenseNet(depth=2, lorder=17, in_channels=dim_out),
|
||||
nn.Dropout(dropout))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
):
|
||||
output = self.mdl(x)
|
||||
return output
|
||||
144  modelscope/models/audio/separation/m2/fsmn.py  Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class UniDeepFsmn(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
|
||||
super(UniDeepFsmn, self).__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
if lorder is None:
|
||||
return
|
||||
self.lorder = lorder
|
||||
self.hidden_size = hidden_size
|
||||
self.linear = nn.Linear(input_dim, hidden_size)
|
||||
self.project = nn.Linear(hidden_size, output_dim, bias=False)
|
||||
self.conv1 = nn.Conv2d(
|
||||
output_dim,
|
||||
output_dim, [lorder + lorder - 1, 1], [1, 1],
|
||||
groups=output_dim,
|
||||
bias=False)
|
||||
|
||||
def forward(self, input):
|
||||
f1 = F.relu(self.linear(input))
|
||||
p1 = self.project(f1)
|
||||
x = th.unsqueeze(p1, 1)
|
||||
x_per = x.permute(0, 3, 2, 1)
|
||||
y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])
|
||||
out = x_per + self.conv1(y)
|
||||
out1 = out.permute(0, 3, 2, 1)
|
||||
return input + out1.squeeze()
|
||||
|
||||
|
||||
class UniDeepFsmnDual(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
|
||||
super(UniDeepFsmnDual, self).__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
if lorder is None:
|
||||
return
|
||||
self.lorder = lorder
|
||||
self.hidden_size = hidden_size
|
||||
self.linear = nn.Linear(input_dim, hidden_size)
|
||||
self.project = nn.Linear(hidden_size, output_dim, bias=False)
|
||||
self.conv1 = nn.Conv2d(
|
||||
output_dim,
|
||||
output_dim, [lorder + lorder - 1, 1], [1, 1],
|
||||
groups=output_dim,
|
||||
bias=False)
|
||||
self.conv2 = nn.Conv2d(
|
||||
output_dim,
|
||||
output_dim, [lorder + lorder - 1, 1], [1, 1],
|
||||
groups=output_dim // 4,
|
||||
bias=False)
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
f1 = F.relu(self.linear(input))
|
||||
p1 = self.project(f1)
|
||||
x = th.unsqueeze(p1, 1)
|
||||
x_per = x.permute(0, 3, 2, 1)
|
||||
y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])
|
||||
conv1_out = x_per + self.conv1(y)
|
||||
z = F.pad(conv1_out, [0, 0, self.lorder - 1, self.lorder - 1])
|
||||
out = conv1_out + self.conv2(z)
|
||||
out1 = out.permute(0, 3, 2, 1)
|
||||
return input + out1.squeeze()
|
||||
|
||||
|
||||
class DilatedDenseNet(nn.Module):
|
||||
|
||||
def __init__(self, depth=4, lorder=20, in_channels=64):
|
||||
super(DilatedDenseNet, self).__init__()
|
||||
self.depth = depth
|
||||
self.in_channels = in_channels
|
||||
self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)
|
||||
self.twidth = lorder * 2 - 1
|
||||
self.kernel_size = (self.twidth, 1)
|
||||
for i in range(self.depth):
|
||||
dil = 2**i
|
||||
pad_length = lorder + (dil - 1) * (lorder - 1) - 1
|
||||
setattr(self, 'pad{}'.format(i + 1),
|
||||
nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.))
|
||||
setattr(
|
||||
self, 'conv{}'.format(i + 1),
|
||||
nn.Conv2d(
|
||||
self.in_channels * (i + 1),
|
||||
self.in_channels,
|
||||
kernel_size=self.kernel_size,
|
||||
dilation=(dil, 1),
|
||||
groups=self.in_channels,
|
||||
bias=False))
|
||||
setattr(self, 'norm{}'.format(i + 1),
|
||||
nn.InstanceNorm2d(in_channels, affine=True))
|
||||
setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))
|
||||
|
||||
def forward(self, x):
|
||||
skip = x
|
||||
for i in range(self.depth):
|
||||
out = getattr(self, 'pad{}'.format(i + 1))(skip)
|
||||
out = getattr(self, 'conv{}'.format(i + 1))(out)
|
||||
out = getattr(self, 'norm{}'.format(i + 1))(out)
|
||||
out = getattr(self, 'prelu{}'.format(i + 1))(out)
|
||||
skip = th.cat([out, skip], dim=1)
|
||||
return out
|
||||
|
||||
|
||||
class UniDeepFsmnDilated(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
input_dim,
|
||||
output_dim,
|
||||
lorder=None,
|
||||
hidden_size=None,
|
||||
depth=2):
|
||||
super(UniDeepFsmnDilated, self).__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.depth = depth
|
||||
if lorder is None:
|
||||
return
|
||||
self.lorder = lorder
|
||||
self.hidden_size = hidden_size
|
||||
self.linear = nn.Linear(input_dim, hidden_size)
|
||||
self.project = nn.Linear(hidden_size, output_dim, bias=False)
|
||||
self.conv = DilatedDenseNet(
|
||||
depth=self.depth, lorder=lorder, in_channels=output_dim)
|
||||
|
||||
def forward(self, input):
|
||||
f1 = F.relu(self.linear(input))
|
||||
p1 = self.project(f1)
|
||||
x = th.unsqueeze(p1, 1)
|
||||
x_per = x.permute(0, 3, 2, 1)
|
||||
out = self.conv(x_per)
|
||||
out1 = out.permute(0, 3, 2, 1)
|
||||
|
||||
return input + out1.squeeze()
|
||||
125  modelscope/models/audio/separation/m2/layer_norm.py  Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class CLayerNorm(nn.LayerNorm):
|
||||
"""Channel-wise layer normalization."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(CLayerNorm, self).__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, sample):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
sample: [batch_size, channels, length]
|
||||
"""
|
||||
if sample.dim() != 3:
|
||||
raise RuntimeError('{} only accept 3-D tensor as input'.format(
|
||||
self.__name__))
|
||||
# [N, C, T] -> [N, T, C]
|
||||
sample = torch.transpose(sample, 1, 2)
|
||||
# LayerNorm
|
||||
sample = super().forward(sample)
|
||||
# [N, T, C] -> [N, C, T]
|
||||
sample = torch.transpose(sample, 1, 2)
|
||||
return sample
|
||||
|
||||
|
||||
class ILayerNorm(nn.InstanceNorm1d):
|
||||
"""Channel-wise layer normalization."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ILayerNorm, self).__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, sample):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
sample: [batch_size, channels, length]
|
||||
"""
|
||||
if sample.dim() != 3:
|
||||
raise RuntimeError('{} only accept 3-D tensor as input'.format(
|
||||
self.__name__))
|
||||
# [N, C, T] -> [N, T, C]
|
||||
sample = torch.transpose(sample, 1, 2)
|
||||
# LayerNorm
|
||||
sample = super().forward(sample)
|
||||
# [N, T, C] -> [N, C, T]
|
||||
sample = torch.transpose(sample, 1, 2)
|
||||
return sample
|
||||
|
||||
|
||||
class GLayerNorm(nn.Module):
|
||||
"""Global Layer Normalization for TasNet."""
|
||||
|
||||
def __init__(self, channels, eps=1e-5):
|
||||
super(GLayerNorm, self).__init__()
|
||||
self.eps = eps
|
||||
self.norm_dim = channels
|
||||
self.gamma = nn.Parameter(torch.Tensor(channels))
|
||||
self.beta = nn.Parameter(torch.Tensor(channels))
|
||||
# self.register_parameter('weight', self.gamma)
|
||||
# self.register_parameter('bias', self.beta)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.ones_(self.gamma)
|
||||
nn.init.zeros_(self.beta)
|
||||
|
||||
def forward(self, sample):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
sample: [batch_size, channels, length]
|
||||
"""
|
||||
if sample.dim() != 3:
|
||||
raise RuntimeError('{} only accepts 3-D tensors as input'.format(
self.__class__.__name__))
|
||||
# [N, C, T] -> [N, T, C]
|
||||
sample = torch.transpose(sample, 1, 2)
|
||||
# Mean and variance [N, 1, 1]
|
||||
mean = torch.mean(sample, (1, 2), keepdim=True)
|
||||
var = torch.mean((sample - mean)**2, (1, 2), keepdim=True)
|
||||
sample = (sample
|
||||
- mean) / torch.sqrt(var + self.eps) * self.gamma + self.beta
|
||||
# [N, T, C] -> [N, C, T]
|
||||
sample = torch.transpose(sample, 1, 2)
|
||||
return sample
|
||||
|
||||
|
||||
class _LayerNorm(nn.Module):
|
||||
"""Layer Normalization base class."""
|
||||
|
||||
def __init__(self, channel_size):
|
||||
super(_LayerNorm, self).__init__()
|
||||
self.channel_size = channel_size
|
||||
self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
|
||||
self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)
|
||||
|
||||
def apply_gain_and_bias(self, normed_x):
|
||||
""" Assumes input of size `[batch, chanel, *]`. """
|
||||
return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(
|
||||
1, -1)
|
||||
|
||||
|
||||
class GlobLayerNorm(_LayerNorm):
|
||||
"""Global Layer Normalization (globLN)."""
|
||||
|
||||
def forward(self, x):
|
||||
""" Applies forward pass.
|
||||
Works for any input size > 2D.
|
||||
|
||||
Args:
|
||||
x (:class:`torch.Tensor`): Shape `[batch, chan, *]`
|
||||
Returns:
|
||||
:class:`torch.Tensor`: gLN_x `[batch, chan, *]`
|
||||
"""
|
||||
dims = list(range(1, len(x.shape)))
|
||||
mean = x.mean(dim=dims, keepdim=True)
|
||||
var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
|
||||
return self.apply_gain_and_bias((x - mean) / (var + 1e-8).sqrt())
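# Usage sketch (illustrative only, not part of the original file):
# >>> x = torch.randn(4, 64, 1000)   # [batch, channels, length]
# >>> cln = CLayerNorm(64)           # LayerNorm applied over the channel dim
# >>> gln = GLayerNorm(64)           # global mean/var over channels and time
# >>> cln(x).shape, gln(x).shape     # both stay (4, 64, 1000)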
|
||||
599
modelscope/models/audio/separation/m2/mossformer.py
Normal file
@@ -0,0 +1,599 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Some code here is modified based on speechbrain and can be found on github
|
||||
# https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/lobes/models/dual_path.py
|
||||
"""Library to support dual-path speech separation.
|
||||
|
||||
Authors
|
||||
* Cem Subakan 2020
|
||||
* Mirco Ravanelli 2020
|
||||
* Samuele Cornell 2020
|
||||
* Mirko Bronzi 2020
|
||||
* Jianyuan Zhong 2020
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import MODELS, TorchModel
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .mossformer_block import MossformerBlockGFSMN, ScaledSinuEmbedding
|
||||
|
||||
EPS = 1e-8
|
||||
|
||||
|
||||
class GlobalLayerNorm(nn.Module):
|
||||
"""Calculate Global Layer Normalization.
|
||||
|
||||
Args:
|
||||
dim : (int or list or torch.Size)
|
||||
Input shape from an expected input of size.
|
||||
eps : float
|
||||
A value added to the denominator for numerical stability.
|
||||
elementwise_affine : bool
|
||||
A boolean value that when set to True,
|
||||
this module has learnable per-element affine parameters
|
||||
initialized to ones (for weights) and zeros (for biases).
|
||||
|
||||
Example:
|
||||
>>> x = torch.randn(5, 10, 20)
|
||||
>>> GLN = GlobalLayerNorm(10, 3)
|
||||
>>> x_norm = GLN(x)
|
||||
"""
|
||||
|
||||
def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
|
||||
super(GlobalLayerNorm, self).__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
|
||||
if self.elementwise_affine:
|
||||
if shape == 3:
|
||||
self.weight = nn.Parameter(torch.ones(self.dim, 1))
|
||||
self.bias = nn.Parameter(torch.zeros(self.dim, 1))
|
||||
if shape == 4:
|
||||
self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
|
||||
self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
|
||||
else:
|
||||
self.register_parameter('weight', None)
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def forward(self, x):
|
||||
"""Returns the normalized tensor.
|
||||
|
||||
Args:
|
||||
x : torch.Tensor
|
||||
Tensor of size [N, C, K, S] or [N, C, L].
|
||||
"""
|
||||
# x = N x C x K x S or N x C x L
|
||||
# N x 1 x 1
|
||||
# cln: mean,var N x 1 x K x S
|
||||
# gln: mean,var N x 1 x 1
|
||||
if x.dim() == 3:
|
||||
mean = torch.mean(x, (1, 2), keepdim=True)
|
||||
var = torch.mean((x - mean)**2, (1, 2), keepdim=True)
|
||||
if self.elementwise_affine:
|
||||
# yapf: disable
|
||||
x = (self.weight * (x - mean) / torch.sqrt(var + self.eps)
|
||||
+ self.bias)
|
||||
# yapf: enable
|
||||
else:
|
||||
x = (x - mean) / torch.sqrt(var + self.eps)
|
||||
|
||||
if x.dim() == 4:
|
||||
mean = torch.mean(x, (1, 2, 3), keepdim=True)
|
||||
var = torch.mean((x - mean)**2, (1, 2, 3), keepdim=True)
|
||||
if self.elementwise_affine:
|
||||
# yapf: disable
|
||||
x = (self.weight * (x - mean) / torch.sqrt(var + self.eps)
|
||||
+ self.bias)
|
||||
# yapf: enable
|
||||
else:
|
||||
x = (x - mean) / torch.sqrt(var + self.eps)
|
||||
return x
|
||||
|
||||
|
||||
class CumulativeLayerNorm(nn.LayerNorm):
|
||||
"""Calculate Cumulative Layer Normalization.
|
||||
|
||||
Args:
|
||||
dim : int
|
||||
Dimension that you want to normalize.
|
||||
elementwise_affine : True
|
||||
Learnable per-element affine parameters.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> x = torch.randn(5, 10, 20)
|
||||
>>> CLN = CumulativeLayerNorm(10)
|
||||
>>> x_norm = CLN(x)
|
||||
"""
|
||||
|
||||
def __init__(self, dim, elementwise_affine=True):
|
||||
super(CumulativeLayerNorm, self).__init__(
|
||||
dim, elementwise_affine=elementwise_affine, eps=1e-8)
|
||||
|
||||
def forward(self, x):
|
||||
"""Returns the normalized tensor.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
x : torch.Tensor
|
||||
Tensor size [N, C, K, S] or [N, C, L]
|
||||
"""
|
||||
# x: N x C x K x S or N x C x L
|
||||
# N x K x S x C
|
||||
if x.dim() == 4:
|
||||
x = x.permute(0, 2, 3, 1).contiguous()
|
||||
# N x K x S x C == only channel norm
|
||||
x = super().forward(x)
|
||||
# N x C x K x S
|
||||
x = x.permute(0, 3, 1, 2).contiguous()
|
||||
if x.dim() == 3:
|
||||
x = torch.transpose(x, 1, 2)
|
||||
# N x L x C == only channel norm
|
||||
x = super().forward(x)
|
||||
# N x C x L
|
||||
x = torch.transpose(x, 1, 2)
|
||||
return x
|
||||
|
||||
|
||||
def select_norm(norm, dim, shape):
|
||||
"""Just a wrapper to select the normalization type.
|
||||
"""
|
||||
|
||||
if norm == 'gln':
|
||||
return GlobalLayerNorm(dim, shape, elementwise_affine=True)
|
||||
if norm == 'cln':
|
||||
return CumulativeLayerNorm(dim, elementwise_affine=True)
|
||||
if norm == 'ln':
|
||||
return nn.GroupNorm(1, dim, eps=1e-8)
|
||||
else:
|
||||
return nn.BatchNorm1d(dim)
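# Illustrative mapping (comment sketch, not in the original source):
#   select_norm('gln', 64, 3) -> GlobalLayerNorm(64, 3, elementwise_affine=True)
#   select_norm('cln', 64, 3) -> CumulativeLayerNorm(64, elementwise_affine=True)
#   select_norm('ln', 64, 3)  -> nn.GroupNorm(1, 64, eps=1e-8)
#   any other value           -> nn.BatchNorm1d(64)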
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Convolutional Encoder Layer.
|
||||
|
||||
Args:
|
||||
kernel_size : int
|
||||
Length of filters.
|
||||
in_channels : int
|
||||
Number of input channels.
|
||||
out_channels : int
|
||||
Number of output channels.
|
||||
|
||||
Example:
|
||||
>>> x = torch.randn(2, 1000)
|
||||
>>> encoder = Encoder(kernel_size=4, out_channels=64)
|
||||
>>> h = encoder(x)
|
||||
>>> h.shape
|
||||
torch.Size([2, 64, 499])
|
||||
"""
|
||||
|
||||
def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
|
||||
super(Encoder, self).__init__()
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size // 2,
|
||||
groups=1,
|
||||
bias=False,
|
||||
)
|
||||
self.in_channels = in_channels
|
||||
|
||||
def forward(self, x):
|
||||
"""Return the encoded output.
|
||||
|
||||
Args:
|
||||
x : torch.Tensor
|
||||
Input tensor with dimensionality [B, L].
|
||||
|
||||
Returns:
|
||||
x : torch.Tensor
|
||||
Encoded tensor with dimensionality [B, N, T_out].
|
||||
where B = Batchsize
|
||||
L = Number of timepoints
|
||||
N = Number of filters
|
||||
T_out = Number of timepoints at the output of the encoder
|
||||
"""
|
||||
# B x L -> B x 1 x L
|
||||
if self.in_channels == 1:
|
||||
x = torch.unsqueeze(x, dim=1)
|
||||
# B x 1 x L -> B x N x T_out
|
||||
x = self.conv1d(x)
|
||||
x = F.relu(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.ConvTranspose1d):
|
||||
"""A decoder layer that consists of ConvTranspose1d.
|
||||
|
||||
Args:
|
||||
kernel_size : int
|
||||
Length of filters.
|
||||
in_channels : int
|
||||
Number of input channels.
|
||||
out_channels : int
|
||||
Number of output channels.
|
||||
|
||||
|
||||
Example:
|
||||
---------
|
||||
>>> x = torch.randn(2, 100, 1000)
|
||||
>>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
|
||||
>>> h = decoder(x)
|
||||
>>> h.shape
|
||||
torch.Size([2, 1003])
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Decoder, self).__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
"""Return the decoded output.
|
||||
|
||||
Args:
|
||||
x : torch.Tensor
|
||||
Input tensor with dimensionality [B, N, L].
|
||||
where, B = Batchsize,
|
||||
N = number of filters
|
||||
L = time points
|
||||
"""
|
||||
|
||||
if x.dim() not in [2, 3]:
|
||||
raise RuntimeError('{} only accepts 2-D or 3-D tensors as input'.format(
self.__class__.__name__))
|
||||
x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
|
||||
|
||||
if torch.squeeze(x).dim() == 1:
|
||||
x = torch.squeeze(x, dim=1)
|
||||
else:
|
||||
x = torch.squeeze(x)
|
||||
return x
|
||||
|
||||
|
||||
class MossFormerM(nn.Module):
|
||||
"""This class implements the transformer encoder.
|
||||
|
||||
Args:
|
||||
num_blocks : int
|
||||
Number of mossformer blocks to include.
|
||||
d_model : int
|
||||
The dimension of the input embedding.
|
||||
attn_dropout : float
|
||||
Dropout for the self-attention (Optional).
|
||||
group_size: int
|
||||
the chunk size
|
||||
query_key_dim: int
|
||||
the attention vector dimension
|
||||
expansion_factor: int
|
||||
the expansion factor for the linear projection in conv module
|
||||
causal: bool
|
||||
true for causal / false for non causal
|
||||
|
||||
Example:
|
||||
-------
|
||||
>>> import torch
|
||||
>>> x = torch.rand((8, 60, 512))
|
||||
>>> net = MossFormerM(num_blocks=8, d_model=512)
>>> output = net(x)
|
||||
>>> output.shape
|
||||
torch.Size([8, 60, 512])
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_blocks,
|
||||
d_model=None,
|
||||
causal=False,
|
||||
group_size=256,
|
||||
query_key_dim=128,
|
||||
expansion_factor=4.,
|
||||
attn_dropout=0.1):
|
||||
super().__init__()
|
||||
|
||||
self.mossformerM = MossformerBlockGFSMN(
|
||||
dim=d_model,
|
||||
depth=num_blocks,
|
||||
group_size=group_size,
|
||||
query_key_dim=query_key_dim,
|
||||
expansion_factor=expansion_factor,
|
||||
causal=causal,
|
||||
attn_dropout=attn_dropout)
|
||||
self.norm = nn.LayerNorm(d_model, eps=1e-6)
|
||||
|
||||
def forward(self, src):
|
||||
"""
|
||||
Args:
|
||||
src : torch.Tensor
|
||||
Tensor shape [B, L, N],
|
||||
where, B = Batchsize,
|
||||
L = time points
|
||||
N = number of filters
|
||||
The sequence to the encoder layer (required).
|
||||
"""
|
||||
output = self.mossformerM(src)
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ComputationBlock(nn.Module):
|
||||
"""Computation block for dual-path processing.
|
||||
|
||||
Args:
|
||||
num_blocks : int
|
||||
Number of mossformer blocks to include.
|
||||
out_channels : int
|
||||
Dimensionality of inter/intra model.
|
||||
norm : str
|
||||
Normalization type.
|
||||
skip_around_intra : bool
|
||||
Skip connection around the intra layer.
|
||||
|
||||
Example:
|
||||
---------
|
||||
>>> comp_block = ComputationBlock(num_blocks=2, out_channels=64)
|
||||
>>> x = torch.randn(10, 64, 100)
|
||||
>>> x = comp_block(x)
|
||||
>>> x.shape
|
||||
torch.Size([10, 64, 100])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks,
|
||||
out_channels,
|
||||
norm='ln',
|
||||
skip_around_intra=True,
|
||||
):
|
||||
super(ComputationBlock, self).__init__()
|
||||
|
||||
# MossFormer+: MossFormer with recurrence
|
||||
self.intra_mdl = MossFormerM(
|
||||
num_blocks=num_blocks, d_model=out_channels)
|
||||
self.skip_around_intra = skip_around_intra
|
||||
|
||||
# Norm
|
||||
self.norm = norm
|
||||
if norm is not None:
|
||||
self.intra_norm = select_norm(norm, out_channels, 3)
|
||||
|
||||
def forward(self, x):
|
||||
"""Returns the output tensor.
|
||||
|
||||
Args:
|
||||
x : torch.Tensor
|
||||
Input tensor of dimension [B, N, S].
|
||||
|
||||
Returns:
|
||||
out: torch.Tensor
|
||||
Output tensor of dimension [B, N, S].
|
||||
where, B = Batchsize,
|
||||
N = number of filters
|
||||
S = sequence time index
|
||||
"""
|
||||
B, N, S = x.shape
|
||||
# intra RNN
|
||||
# [B, S, N]
|
||||
intra = x.permute(0, 2, 1).contiguous()
|
||||
|
||||
intra = self.intra_mdl(intra)
|
||||
|
||||
# [B, N, S]
|
||||
intra = intra.permute(0, 2, 1).contiguous()
|
||||
if self.norm is not None:
|
||||
intra = self.intra_norm(intra)
|
||||
|
||||
# [B, N, S]
|
||||
if self.skip_around_intra:
|
||||
intra = intra + x
|
||||
|
||||
out = intra
|
||||
return out
|
||||
|
||||
|
||||
class MossFormerMaskNet(nn.Module):
|
||||
"""The dual path model which is the basis for dualpathrnn, sepformer, dptnet.
|
||||
|
||||
Args:
|
||||
in_channels : int
|
||||
Number of channels at the output of the encoder.
|
||||
out_channels : int
|
||||
Number of channels that would be inputted to the intra and inter blocks.
|
||||
norm : str
|
||||
Normalization type.
|
||||
num_spks : int
|
||||
Number of sources (speakers).
|
||||
skip_around_intra : bool
|
||||
Skip connection around intra.
|
||||
use_global_pos_enc : bool
|
||||
Global positional encodings.
|
||||
max_length : int
|
||||
Maximum sequence length.
|
||||
|
||||
Example:
|
||||
---------
|
||||
>>> mossformer_masknet = MossFormerMaskNet(64, 64, num_blocks=24, num_spks=2)
|
||||
>>> x = torch.randn(10, 64, 2000)
|
||||
>>> x = mossformer_masknet(x)
|
||||
>>> x.shape
|
||||
torch.Size([2, 10, 64, 2000])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_blocks=24,
|
||||
norm='ln',
|
||||
num_spks=2,
|
||||
skip_around_intra=True,
|
||||
use_global_pos_enc=True,
|
||||
max_length=20000,
|
||||
):
|
||||
super(MossFormerMaskNet, self).__init__()
|
||||
self.num_spks = num_spks
|
||||
self.num_blocks = num_blocks
|
||||
self.norm = select_norm(norm, in_channels, 3)
|
||||
self.conv1d_encoder = nn.Conv1d(
|
||||
in_channels, out_channels, 1, bias=False)
|
||||
self.use_global_pos_enc = use_global_pos_enc
|
||||
|
||||
if self.use_global_pos_enc:
|
||||
self.pos_enc = ScaledSinuEmbedding(out_channels)
|
||||
|
||||
self.mdl = ComputationBlock(
|
||||
num_blocks,
|
||||
out_channels,
|
||||
norm,
|
||||
skip_around_intra=skip_around_intra,
|
||||
)
|
||||
|
||||
self.conv1d_out = nn.Conv1d(
|
||||
out_channels, out_channels * num_spks, kernel_size=1)
|
||||
self.conv1_decoder = nn.Conv1d(
|
||||
out_channels, in_channels, 1, bias=False)
|
||||
self.prelu = nn.PReLU()
|
||||
self.activation = nn.ReLU()
|
||||
# gated output layer
|
||||
self.output = nn.Sequential(
|
||||
nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
|
||||
self.output_gate = nn.Sequential(
|
||||
nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())
|
||||
|
||||
def forward(self, x):
|
||||
"""Returns the output tensor.
|
||||
|
||||
Args:
|
||||
x : torch.Tensor
|
||||
Input tensor of dimension [B, N, S].
|
||||
|
||||
Returns:
|
||||
out : torch.Tensor
|
||||
Output tensor of dimension [spks, B, N, S]
|
||||
where, spks = Number of speakers
|
||||
B = Batchsize,
|
||||
N = number of filters
|
||||
S = the number of time frames
|
||||
"""
|
||||
# before each line we indicate the shape after executing the line
|
||||
# [B, N, L]
|
||||
x = self.norm(x)
|
||||
|
||||
# [B, N, L]
|
||||
x = self.conv1d_encoder(x)
|
||||
if self.use_global_pos_enc:
|
||||
base = x
|
||||
x = x.transpose(1, -1)
|
||||
emb = self.pos_enc(x)
|
||||
emb = emb.transpose(0, -1)
|
||||
x = base + emb
|
||||
|
||||
# [B, N, S]
|
||||
x = self.mdl(x)
|
||||
x = self.prelu(x)
|
||||
|
||||
# [B, N*spks, S]
|
||||
x = self.conv1d_out(x)
|
||||
B, _, S = x.shape
|
||||
|
||||
# [B*spks, N, S]
|
||||
x = x.view(B * self.num_spks, -1, S)
|
||||
|
||||
# [B*spks, N, S]
|
||||
x = self.output(x) * self.output_gate(x)
|
||||
|
||||
# [B*spks, N, S]
|
||||
x = self.conv1_decoder(x)
|
||||
|
||||
# [B, spks, N, S]
|
||||
_, N, L = x.shape
|
||||
x = x.view(B, self.num_spks, N, L)
|
||||
x = self.activation(x)
|
||||
|
||||
# [spks, B, N, S]
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.speech_separation,
|
||||
module_name=Models.speech_mossformer2_separation_temporal_8k)
|
||||
class MossFormer2(TorchModel):
|
||||
"""Library to support MossFormer speech separation.
|
||||
|
||||
Args:
|
||||
model_dir (str): the model path.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_dir: str,
|
||||
in_channels=512,
|
||||
out_channels=512,
|
||||
num_blocks=24,
|
||||
kernel_size=16,
|
||||
norm='ln',
|
||||
num_spks=2,
|
||||
skip_around_intra=True,
|
||||
use_global_pos_enc=True,
|
||||
max_length=20000,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
self.num_spks = num_spks
|
||||
self.enc = Encoder(
|
||||
kernel_size=kernel_size, out_channels=in_channels, in_channels=1)
|
||||
self.mask_net = MossFormerMaskNet(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
num_blocks=num_blocks,
|
||||
norm=norm,
|
||||
num_spks=num_spks,
|
||||
skip_around_intra=skip_around_intra,
|
||||
use_global_pos_enc=use_global_pos_enc,
|
||||
max_length=max_length,
|
||||
)
|
||||
self.dec = Decoder(
|
||||
in_channels=out_channels,
|
||||
out_channels=1,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size // 2,
|
||||
bias=False)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.enc(input)
|
||||
mask = self.mask_net(x)
|
||||
x = torch.stack([x] * self.num_spks)
|
||||
sep_x = x * mask
|
||||
|
||||
# Decoding
|
||||
est_source = torch.cat(
|
||||
[self.dec(sep_x[i]).unsqueeze(-1) for i in range(self.num_spks)],
|
||||
dim=-1,
|
||||
)
|
||||
T_origin = input.size(1)
|
||||
T_est = est_source.size(1)
|
||||
if T_origin > T_est:
|
||||
est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
|
||||
else:
|
||||
est_source = est_source[:, :T_origin, :]
|
||||
return est_source
|
||||
|
||||
def load_check_point(self, load_path=None, device=None):
|
||||
if not load_path:
|
||||
load_path = self.model_dir
|
||||
if not device:
|
||||
device = torch.device('cpu')
|
||||
self.load_state_dict(
|
||||
torch.load(
|
||||
os.path.join(load_path, ModelFile.TORCH_MODEL_FILE),
|
||||
map_location=device),
|
||||
strict=False)
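# End-to-end shape sketch (hypothetical paths and sizes, illustrative only):
# >>> model = MossFormer2(model_dir='/path/to/model', num_spks=2)
# >>> model.load_check_point()        # expects ModelFile.TORCH_MODEL_FILE in model_dir
# >>> mix = torch.randn(1, 16000)     # [B, T] mono mixture (2 s at 8 kHz)
# >>> model(mix).shape                # torch.Size([1, 16000, 2]) -> [B, T, num_spks]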
|
||||
548
modelscope/models/audio/separation/m2/mossformer_block.py
Normal file
@@ -0,0 +1,548 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from rotary_embedding_torch import RotaryEmbedding
|
||||
from torch import einsum, nn
|
||||
|
||||
from .conv_module import ConvModule, FFConvMDilated
|
||||
from .fsmn import UniDeepFsmn, UniDeepFsmnDilated
|
||||
from .layer_norm import CLayerNorm
|
||||
|
||||
# functions
|
||||
|
||||
|
||||
def identity(t, *args, **kwargs):
|
||||
return t
|
||||
|
||||
|
||||
def append_dims(x, num_dims):
|
||||
if num_dims <= 0:
|
||||
return x
|
||||
return x.view(*x.shape, *((1, ) * num_dims))
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
def padding_to_multiple_of(n, mult):
|
||||
remainder = n % mult
|
||||
if remainder == 0:
|
||||
return 0
|
||||
return mult - remainder
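# e.g. padding_to_multiple_of(1000, 256) == 24, so a 1000-step sequence is
# padded to 1024 before being split into attention groups of 256.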
|
||||
|
||||
|
||||
# scalenorm
|
||||
|
||||
|
||||
class ScaleNorm(nn.Module):
|
||||
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.scale = dim**-0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1))
|
||||
|
||||
def forward(self, x):
|
||||
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
||||
return x / norm.clamp(min=self.eps) * self.g
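# ScaleNorm (descriptive comment): y = g * x / max(||x||_2 * dim**-0.5, eps),
# i.e. a single learned scalar g replaces LayerNorm's per-feature affine.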
|
||||
|
||||
|
||||
# absolute positional encodings
|
||||
|
||||
|
||||
class ScaledSinuEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.ones(1, ))
|
||||
inv_freq = 1. / (10000**(torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
|
||||
def forward(self, x):
|
||||
n, device = x.shape[1], x.device
|
||||
t = torch.arange(n, device=device).type_as(self.inv_freq)
|
||||
sinu = einsum('i , j -> i j', t, self.inv_freq)
|
||||
emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1)
|
||||
return emb * self.scale
|
||||
|
||||
|
||||
class OffsetScale(nn.Module):
|
||||
|
||||
def __init__(self, dim, heads=1):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim))
|
||||
self.beta = nn.Parameter(torch.zeros(heads, dim))
|
||||
nn.init.normal_(self.gamma, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
|
||||
return out.unbind(dim=-2)
|
||||
|
||||
|
||||
class FFConvM(nn.Module):
|
||||
|
||||
def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
|
||||
super().__init__()
|
||||
self.mdl = nn.Sequential(
|
||||
norm_klass(dim_in), nn.Linear(dim_in, dim_out), nn.SiLU(),
|
||||
ConvModule(dim_out), nn.Dropout(dropout))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
):
|
||||
output = self.mdl(x)
|
||||
return output
|
||||
|
||||
|
||||
class GroupLinear(nn.Module):
|
||||
|
||||
def __init__(self, dim_in, dim_out, K=4):
|
||||
super().__init__()
|
||||
hidden = dim_in // 2
|
||||
self.group_conv = nn.Conv1d(
|
||||
dim_in, hidden, groups=dim_in // K, kernel_size=1)
|
||||
self.norm = nn.LayerNorm(hidden)
|
||||
self.linear = nn.Linear(hidden, dim_out)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
):
|
||||
x1 = x.transpose(2, 1)
|
||||
conv_out = self.group_conv(x1)
|
||||
x2 = self.norm(conv_out.transpose(2, 1))
|
||||
x3 = self.linear(x2)
|
||||
return x3
|
||||
|
||||
|
||||
class FFM(nn.Module):
|
||||
|
||||
def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
|
||||
super().__init__()
|
||||
self.mdl = nn.Sequential(
|
||||
norm_klass(dim_in), nn.Linear(dim_in, dim_out), nn.SiLU(),
|
||||
nn.Dropout(dropout))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
):
|
||||
output = self.mdl(x)
|
||||
return output
|
||||
|
||||
|
||||
# FLASH
|
||||
class FLASH_ShareA_FFConvM(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
dim,
|
||||
group_size=256,
|
||||
query_key_dim=128,
|
||||
expansion_factor=1.,
|
||||
causal=False,
|
||||
dropout=0.1,
|
||||
rotary_pos_emb=None,
|
||||
norm_klass=nn.LayerNorm,
|
||||
shift_tokens=True):
|
||||
super().__init__()
|
||||
hidden_dim = int(dim * expansion_factor)
|
||||
self.group_size = group_size
|
||||
self.causal = causal
|
||||
self.shift_tokens = shift_tokens
|
||||
|
||||
# positional embeddings
|
||||
self.rotary_pos_emb = rotary_pos_emb
|
||||
# norm
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
# projections
|
||||
self.to_hidden = FFConvM(
|
||||
dim_in=dim,
|
||||
dim_out=hidden_dim,
|
||||
norm_klass=norm_klass,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.to_qk = FFConvM(
|
||||
dim_in=dim,
|
||||
dim_out=query_key_dim,
|
||||
norm_klass=norm_klass,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)
|
||||
|
||||
self.to_out = FFConvM(
|
||||
dim_in=dim * 2,
|
||||
dim_out=dim,
|
||||
norm_klass=norm_klass,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
self.gateActivate = nn.Sigmoid()
|
||||
|
||||
def forward(self, x, *, mask=None):
|
||||
"""
|
||||
b - batch
|
||||
n - sequence length (within groups)
|
||||
g - group dimension
|
||||
d - feature dimension (keys)
|
||||
e - feature dimension (values)
|
||||
i - sequence dimension (source)
|
||||
j - sequence dimension (target)
|
||||
"""
|
||||
# prenorm
|
||||
normed_x = x
|
||||
|
||||
if self.shift_tokens:
|
||||
x_shift, x_pass = normed_x.chunk(2, dim=-1)
|
||||
x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.)
|
||||
normed_x = torch.cat((x_shift, x_pass), dim=-1)
|
||||
|
||||
# initial projections
|
||||
v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
|
||||
qk = self.to_qk(normed_x)
|
||||
|
||||
# offset and scale
|
||||
quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
|
||||
att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v,
|
||||
u)
|
||||
out = (att_u * v) * self.gateActivate(att_v * u)
|
||||
|
||||
x = x + self.to_out(out)
|
||||
return x
|
||||
|
||||
def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None):
|
||||
b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
|
||||
|
||||
if exists(mask):
|
||||
lin_mask = rearrange(mask, '... -> ... 1')
|
||||
lin_k = lin_k.masked_fill(~lin_mask, 0.)
|
||||
|
||||
# rotate queries and keys
|
||||
if exists(self.rotary_pos_emb):
|
||||
quad_q, lin_q, quad_k, lin_k = map(
|
||||
self.rotary_pos_emb.rotate_queries_or_keys,
|
||||
(quad_q, lin_q, quad_k, lin_k))
|
||||
|
||||
# padding for groups
|
||||
padding = padding_to_multiple_of(n, g)
|
||||
|
||||
if padding > 0:
|
||||
quad_q, quad_k, lin_q, lin_k, v, u = map(
|
||||
lambda t: F.pad(t, (0, 0, 0, padding), value=0.),
|
||||
(quad_q, quad_k, lin_q, lin_k, v, u))
|
||||
|
||||
mask = default(mask,
|
||||
torch.ones((b, n), device=device, dtype=torch.bool))
|
||||
mask = F.pad(mask, (0, padding), value=False)
|
||||
|
||||
# group along sequence
|
||||
quad_q, quad_k, lin_q, lin_k, v, u = map(
|
||||
lambda t: rearrange(t, 'b (g n) d -> b g n d', n=self.group_size),
|
||||
(quad_q, quad_k, lin_q, lin_k, v, u))
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b (g j) -> b g 1 j', j=g)
|
||||
|
||||
# calculate quadratic attention output
|
||||
sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
|
||||
|
||||
attn = F.relu(sim)**2
|
||||
attn = self.dropout(attn)
|
||||
|
||||
if exists(mask):
|
||||
attn = attn.masked_fill(~mask, 0.)
|
||||
|
||||
if self.causal:
|
||||
causal_mask = torch.ones((g, g), dtype=torch.bool,
|
||||
device=device).triu(1)
|
||||
attn = attn.masked_fill(causal_mask, 0.)
|
||||
|
||||
quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
|
||||
quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
|
||||
|
||||
# calculate linear attention output
|
||||
if self.causal:
|
||||
lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
|
||||
# exclusive cumulative sum along group dimension
|
||||
lin_kv = lin_kv.cumsum(dim=1)
|
||||
lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.)
|
||||
lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
|
||||
|
||||
lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
|
||||
# exclusive cumulative sum along group dimension
|
||||
lin_ku = lin_ku.cumsum(dim=1)
|
||||
lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.)
|
||||
lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
|
||||
else:
|
||||
lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
|
||||
lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
|
||||
|
||||
lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
|
||||
lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
|
||||
|
||||
# fold back groups into full sequence, and excise out padding
|
||||
return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n],
|
||||
(quad_out_v + lin_out_v, quad_out_u + lin_out_u))
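# Complexity note (descriptive comment): after padding, tensors are grouped as
# [b, n/g, g, d]; the quadratic branch attends only within each group of size
# g, while the linear branch mixes information across groups through compact
# [d, e] summaries, so cost scales roughly linearly in sequence length n
# instead of quadratically.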
|
||||
|
||||
|
||||
class GatedFSMNDilated(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, lorder, hidden_size):
|
||||
super().__init__()
|
||||
self.to_u = FFConvM(
|
||||
dim_in=in_channels,
|
||||
dim_out=hidden_size,
|
||||
norm_klass=nn.LayerNorm,
|
||||
dropout=0.1,
|
||||
)
|
||||
self.to_v = FFConvM(
|
||||
dim_in=in_channels,
|
||||
dim_out=hidden_size,
|
||||
norm_klass=nn.LayerNorm,
|
||||
dropout=0.1,
|
||||
)
|
||||
self.fsmn = UniDeepFsmnDilated(in_channels, out_channels, lorder,
|
||||
hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
):
|
||||
input = x
|
||||
x_u = self.to_u(x)
|
||||
x_v = self.to_v(x)
|
||||
x_u = self.fsmn(x_u)
|
||||
x = x_v * x_u + input
|
||||
return x
|
||||
|
||||
|
||||
class GatedFSMNDilatedDual(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, lorder, hidden_size):
|
||||
super().__init__()
|
||||
self.to_u = FFConvMDilated(
|
||||
dim_in=in_channels,
|
||||
dim_out=hidden_size,
|
||||
norm_klass=nn.LayerNorm,
|
||||
dropout=0.1,
|
||||
)
|
||||
self.to_v = FFConvMDilated(
|
||||
dim_in=in_channels,
|
||||
dim_out=hidden_size,
|
||||
norm_klass=nn.LayerNorm,
|
||||
dropout=0.1,
|
||||
)
|
||||
self.fsmn = UniDeepFsmnDilated(in_channels, out_channels, lorder,
|
||||
hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
):
|
||||
input = x
|
||||
x_u = self.to_u(x)
|
||||
x_v = self.to_v(x)
|
||||
x_u = self.fsmn(x_u)
|
||||
x = x_v * x_u + input
|
||||
return x
|
||||
|
||||
|
||||
class GatedFSMNBlockDilatedDual(nn.Module):
|
||||
"""1-D convolutional block."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
inner_channels=256,
|
||||
):
|
||||
super(GatedFSMNBlockDilatedDual, self).__init__()
|
||||
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv1d(dim, inner_channels, kernel_size=1),
|
||||
nn.PReLU(),
|
||||
)
|
||||
self.norm1 = CLayerNorm(inner_channels)
|
||||
self.gated_fsmn = GatedFSMNDilatedDual(
|
||||
inner_channels,
|
||||
inner_channels,
|
||||
lorder=20,
|
||||
hidden_size=inner_channels)
|
||||
self.norm2 = CLayerNorm(inner_channels)
|
||||
self.conv2 = nn.Conv1d(inner_channels, dim, kernel_size=1)
|
||||
|
||||
def forward(self, input):
|
||||
conv1 = self.conv1(input.transpose(2, 1))
|
||||
norm1 = self.norm1(conv1)
|
||||
seq_out = self.gated_fsmn(norm1.transpose(2, 1))
|
||||
norm2 = self.norm2(seq_out.transpose(2, 1))
|
||||
conv2 = self.conv2(norm2)
|
||||
return conv2.transpose(2, 1) + input
|
||||
|
||||
|
||||
class GatedFSMNBlockDilated(nn.Module):
|
||||
"""1-D convolutional block."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
inner_channels=256,
|
||||
group_size=256,
|
||||
norm_type='scalenorm',
|
||||
):
|
||||
super(GatedFSMNBlockDilated, self).__init__()
|
||||
|
||||
self.group_size = group_size
|
||||
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv1d(dim, inner_channels, kernel_size=1),
|
||||
nn.PReLU(),
|
||||
)
|
||||
self.norm1 = CLayerNorm(inner_channels)
|
||||
# block dilated with gating
|
||||
self.gated_fsmn = GatedFSMNDilated(
|
||||
inner_channels,
|
||||
inner_channels,
|
||||
lorder=20,
|
||||
hidden_size=inner_channels)
|
||||
self.norm2 = CLayerNorm(inner_channels)
|
||||
self.conv2 = nn.Conv1d(inner_channels, dim, kernel_size=1)
|
||||
|
||||
def forward(self, input):
|
||||
conv1 = self.conv1(input.transpose(2, 1))
|
||||
norm1 = self.norm1(conv1)
|
||||
seq_out = self.gated_fsmn(norm1.transpose(2, 1))
|
||||
norm2 = self.norm2(seq_out.transpose(2, 1))
|
||||
conv2 = self.conv2(norm2)
|
||||
return conv2.transpose(2, 1) + input
|
||||
|
||||
|
||||
class MossformerBlockGFSMN(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
dim,
|
||||
depth,
|
||||
group_size=256,
|
||||
query_key_dim=128,
|
||||
expansion_factor=4.,
|
||||
causal=False,
|
||||
attn_dropout=0.1,
|
||||
norm_type='scalenorm',
|
||||
shift_tokens=True):
|
||||
super().__init__()
|
||||
assert norm_type in (
|
||||
'scalenorm',
|
||||
'layernorm'), 'norm_type must be one of scalenorm or layernorm'
|
||||
|
||||
if norm_type == 'scalenorm':
|
||||
norm_klass = ScaleNorm
|
||||
elif norm_type == 'layernorm':
|
||||
norm_klass = nn.LayerNorm
|
||||
|
||||
self.group_size = group_size
|
||||
|
||||
rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
|
||||
# max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
|
||||
self.fsmn = nn.ModuleList(
|
||||
[GatedFSMNBlockDilated(dim) for _ in range(depth)])
|
||||
self.layers = nn.ModuleList([
|
||||
FLASH_ShareA_FFConvM(
|
||||
dim=dim,
|
||||
group_size=group_size,
|
||||
query_key_dim=query_key_dim,
|
||||
expansion_factor=expansion_factor,
|
||||
causal=causal,
|
||||
dropout=attn_dropout,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
norm_klass=norm_klass,
|
||||
shift_tokens=shift_tokens) for _ in range(depth)
|
||||
])
|
||||
|
||||
def _build_repeats(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
lorder,
|
||||
hidden_size,
|
||||
repeats=1):
|
||||
repeats = [
|
||||
UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
|
||||
for i in range(repeats)
|
||||
]
|
||||
return nn.Sequential(*repeats)
|
||||
|
||||
def forward(self, x, *, mask=None):
|
||||
for flash, fsmn in zip(self.layers, self.fsmn):
x = flash(x, mask=mask)
x = fsmn(x)
|
||||
return x
|
||||
|
||||
|
||||
class MossformerBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
dim,
|
||||
depth,
|
||||
group_size=256,
|
||||
query_key_dim=128,
|
||||
expansion_factor=4.,
|
||||
causal=False,
|
||||
attn_dropout=0.1,
|
||||
norm_type='scalenorm',
|
||||
shift_tokens=True):
|
||||
super().__init__()
|
||||
assert norm_type in (
|
||||
'scalenorm',
|
||||
'layernorm'), 'norm_type must be one of scalenorm or layernorm'
|
||||
|
||||
if norm_type == 'scalenorm':
|
||||
norm_klass = ScaleNorm
|
||||
elif norm_type == 'layernorm':
|
||||
norm_klass = nn.LayerNorm
|
||||
|
||||
self.group_size = group_size
|
||||
|
||||
rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
|
||||
# max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
|
||||
self.layers = nn.ModuleList([
|
||||
FLASH_ShareA_FFConvM(
|
||||
dim=dim,
|
||||
group_size=group_size,
|
||||
query_key_dim=query_key_dim,
|
||||
expansion_factor=expansion_factor,
|
||||
causal=causal,
|
||||
dropout=attn_dropout,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
norm_klass=norm_klass,
|
||||
shift_tokens=shift_tokens) for _ in range(depth)
|
||||
])
|
||||
|
||||
def _build_repeats(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
lorder,
|
||||
hidden_size,
|
||||
repeats=1):
|
||||
repeats = [
|
||||
UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
|
||||
for i in range(repeats)
|
||||
]
|
||||
return nn.Sequential(*repeats)
|
||||
|
||||
def forward(self, x, *, mask=None):
|
||||
for flash in self.layers:
x = flash(x, mask=mask)
|
||||
return x
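# Usage sketch (illustrative only):
# >>> block = MossformerBlockGFSMN(dim=512, depth=2)
# >>> x = torch.rand(1, 1024, 512)    # [batch, sequence, dim]
# >>> block(x).shape                  # torch.Size([1, 1024, 512])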
|
||||
120
modelscope/models/audio/sv/lanuage_recognition_eres2net.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio.compliance.kaldi as Kaldi
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import MODELS, TorchModel
|
||||
from modelscope.models.audio.sv.DTDNN import CAMPPlus
|
||||
from modelscope.models.audio.sv.DTDNN_layers import DenseLayer
|
||||
from modelscope.models.audio.sv.ERes2Net import ERes2Net
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.device import create_device
|
||||
|
||||
|
||||
class LinearClassifier(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
num_blocks=0,
|
||||
inter_dim=512,
|
||||
out_neurons=1000,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList()
|
||||
|
||||
self.nonlinear = nn.ReLU(inplace=True)
|
||||
for _ in range(num_blocks):
|
||||
self.blocks.append(DenseLayer(input_dim, inter_dim, bias=True))
|
||||
input_dim = inter_dim
|
||||
|
||||
self.linear = nn.Linear(input_dim, out_neurons, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, dim]
|
||||
x = self.nonlinear(x)
|
||||
for layer in self.blocks:
|
||||
x = layer(x)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.speech_language_recognition, module_name=Models.eres2net_lre)
|
||||
class LanguageRecognitionERes2Net(TorchModel):
|
||||
r"""A speech language recognition model using the ERes2Net architecture as the backbone.
|
||||
Args:
|
||||
model_dir: A model dir.
|
||||
model_config: The model config.
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir, model_config: Dict[str, Any], *args,
|
||||
**kwargs):
|
||||
super().__init__(model_dir, model_config, *args, **kwargs)
|
||||
self.model_config = model_config
|
||||
|
||||
self.embed_dim = self.model_config['embed_dim']
|
||||
self.m_channels = self.model_config['channels']
|
||||
self.feature_dim = self.model_config['fbank_dim']
|
||||
self.sample_rate = self.model_config['sample_rate']
|
||||
self.device = create_device(kwargs['device'])
|
||||
|
||||
self.encoder = ERes2Net(
|
||||
embed_dim=self.embed_dim, m_channels=self.m_channels)
|
||||
self.backend = LinearClassifier(
|
||||
input_dim=self.embed_dim,
|
||||
out_neurons=len(self.model_config['languages']))
|
||||
|
||||
pretrained_encoder = kwargs['pretrained_encoder']
|
||||
pretrained_backend = kwargs['pretrained_backend']
|
||||
|
||||
self._load_check_point(pretrained_encoder, pretrained_backend)
|
||||
|
||||
self.encoder.to(self.device)
|
||||
self.backend.to(self.device)
|
||||
self.encoder.eval()
|
||||
self.backend.eval()
|
||||
|
||||
def forward(self, audio):
|
||||
if isinstance(audio, np.ndarray):
|
||||
audio = torch.from_numpy(audio)
|
||||
if len(audio.shape) == 1:
|
||||
audio = audio.unsqueeze(0)
|
||||
assert len(audio.shape) == 2, \
|
||||
'modelscope error: the shape of input audio to model needs to be [N, T]'
|
||||
# audio shape: [N, T]
|
||||
feature = self._extract_feature(audio)
|
||||
embs = self.encoder(feature.to(self.device))
|
||||
output = self.backend(embs)
|
||||
output = output.detach().cpu().argmax(-1)
|
||||
return output
|
||||
|
||||
def _extract_feature(self, audio):
|
||||
features = []
|
||||
for au in audio:
|
||||
feature = Kaldi.fbank(
|
||||
au.unsqueeze(0),
|
||||
num_mel_bins=self.feature_dim,
|
||||
sample_frequency=self.sample_rate)
|
||||
feature = feature - feature.mean(dim=0, keepdim=True)
|
||||
features.append(feature.unsqueeze(0))
|
||||
features = torch.cat(features)
|
||||
return features
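# Comment sketch (assumed defaults): with 16 kHz audio and 80 mel bins,
# Kaldi.fbank's default 25 ms window / 10 ms shift yields roughly [T / 160, 80]
# frames per utterance; the mean subtraction above is per bin over time.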
|
||||
|
||||
def _load_check_point(self, pretrained_encoder, pretrained_backend):
|
||||
self.encoder.load_state_dict(
|
||||
torch.load(
|
||||
os.path.join(self.model_dir, pretrained_encoder),
|
||||
map_location=torch.device('cpu')))
|
||||
|
||||
self.backend.load_state_dict(
|
||||
torch.load(
|
||||
os.path.join(self.model_dir, pretrained_backend),
|
||||
map_location=torch.device('cpu')))
|
||||
@@ -61,6 +61,7 @@ class LanguageRecognitionCAMPPlus(TorchModel):
|
||||
|
||||
self.emb_size = self.model_config['emb_size']
|
||||
self.feature_dim = self.model_config['fbank_dim']
|
||||
self.sample_rate = self.model_config['sample_rate']
|
||||
self.device = create_device(kwargs['device'])
|
||||
|
||||
self.encoder = CAMPPlus(self.feature_dim, self.emb_size)
|
||||
@@ -96,7 +97,9 @@ class LanguageRecognitionCAMPPlus(TorchModel):
|
||||
features = []
|
||||
for au in audio:
|
||||
feature = Kaldi.fbank(
|
||||
au.unsqueeze(0), num_mel_bins=self.feature_dim)
|
||||
au.unsqueeze(0),
|
||||
num_mel_bins=self.feature_dim,
|
||||
sample_frequency=self.sample_rate)
|
||||
feature = feature - feature.mean(dim=0, keepdim=True)
|
||||
features.append(feature.unsqueeze(0))
|
||||
features = torch.cat(features)
|
||||
|
||||
@@ -7,10 +7,11 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
|
||||
crowd_counting, face_detection, face_generation,
|
||||
face_reconstruction, human_reconstruction, image_classification,
|
||||
image_color_enhance, image_colorization, image_defrcn_fewshot,
|
||||
image_denoise, image_inpainting, image_instance_segmentation,
|
||||
image_matching, image_mvs_depth_estimation,
|
||||
image_panoptic_segmentation, image_portrait_enhancement,
|
||||
image_probing_model, image_quality_assessment_degradation,
|
||||
image_denoise, image_editing, image_inpainting,
|
||||
image_instance_segmentation, image_matching,
|
||||
image_mvs_depth_estimation, image_panoptic_segmentation,
|
||||
image_portrait_enhancement, image_probing_model,
|
||||
image_quality_assessment_degradation,
|
||||
image_quality_assessment_man, image_quality_assessment_mos,
|
||||
image_reid_person, image_restoration,
|
||||
image_semantic_segmentation, image_to_image_generation,
|
||||
@@ -21,10 +22,11 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
|
||||
referring_video_object_segmentation,
|
||||
robust_image_classification, salient_detection,
|
||||
shop_segmentation, stream_yolo, super_resolution,
|
||||
table_recognition, video_deinterlace, video_frame_interpolation,
|
||||
video_object_segmentation, video_panoptic_segmentation,
|
||||
video_single_object_tracking, video_stabilization,
|
||||
video_summarization, video_super_resolution, vidt, virual_tryon,
|
||||
vision_middleware, vop_retrieval)
|
||||
surface_recon_common, table_recognition, video_deinterlace,
|
||||
video_frame_interpolation, video_object_segmentation,
|
||||
video_panoptic_segmentation, video_single_object_tracking,
|
||||
video_stabilization, video_summarization,
|
||||
video_super_resolution, vidt, virual_tryon, vision_middleware,
|
||||
vop_retrieval)
|
||||
|
||||
# yapf: enable
|
||||
|
||||
22
modelscope/models/cv/human_image_generation/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .human_image_generation_infer import FreqHPTForHumanImageGeneration
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'human_image_generation_infer': ['FreqHPTForHumanImageGeneration']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
@@ -0,0 +1,717 @@
|
||||
import collections
|
||||
import math
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from pytorch_wavelets import DWTForward, DWTInverse
|
||||
from torch import kl_div, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from modelscope.ops.human_image_generation.fused_act import (FusedLeakyReLU,
|
||||
fused_leaky_relu)
|
||||
from modelscope.ops.human_image_generation.upfirdn2d import upfirdn2d
|
||||
from .conv2d_gradfix import conv2d, conv_transpose2d
|
||||
from .wavelet_module import *
|
||||
|
||||
|
||||
# add flow
|
||||
class ExtractionOperation_flow(nn.Module):
|
||||
|
||||
def __init__(self, in_channel, num_label, match_kernel):
|
||||
super(ExtractionOperation_flow, self).__init__()
|
||||
self.value_conv = EqualConv2d(
|
||||
in_channel,
|
||||
in_channel,
|
||||
match_kernel,
|
||||
1,
|
||||
match_kernel // 2,
|
||||
bias=True)
|
||||
self.semantic_extraction_filter = EqualConv2d(
|
||||
in_channel,
|
||||
num_label,
|
||||
match_kernel,
|
||||
1,
|
||||
match_kernel // 2,
|
||||
bias=False)
|
||||
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
self.num_label = num_label
|
||||
|
||||
def forward(self, value, recoder):
|
||||
key = value
|
||||
b, c, h, w = value.shape
|
||||
key = self.semantic_extraction_filter(self.feature_norm(key))
|
||||
extraction_softmax = self.softmax(key.view(b, -1, h * w))
|
||||
values_flatten = self.value_conv(value).view(b, -1, h * w)
|
||||
neural_textures = torch.einsum('bkm,bvm->bvk', extraction_softmax,
|
||||
values_flatten)
|
||||
recoder['extraction_softmax'].insert(0, extraction_softmax)
|
||||
recoder['neural_textures'].insert(0, neural_textures)
|
||||
return neural_textures, extraction_softmax
|
||||
|
||||
def feature_norm(self, input_tensor):
|
||||
input_tensor = input_tensor - input_tensor.mean(dim=1, keepdim=True)
|
||||
norm = torch.norm(
|
||||
input_tensor, 2, 1, keepdim=True) + sys.float_info.epsilon
|
||||
out = torch.div(input_tensor, norm)
|
||||
return out
|
||||
|
||||
|
||||
class DistributionOperation_flow(nn.Module):
|
||||
|
||||
def __init__(self, num_label, input_dim, match_kernel=3):
|
||||
super(DistributionOperation_flow, self).__init__()
|
||||
self.semantic_distribution_filter = EqualConv2d(
|
||||
input_dim,
|
||||
num_label,
|
||||
kernel_size=match_kernel,
|
||||
stride=1,
|
||||
padding=match_kernel // 2)
|
||||
self.num_label = num_label
|
||||
|
||||
def forward(self, query, extracted_feature, recoder):
|
||||
b, c, h, w = query.shape
|
||||
|
||||
query = self.semantic_distribution_filter(query)
|
||||
query_flatten = query.view(b, self.num_label, -1)
|
||||
query_softmax = F.softmax(query_flatten, 1)
|
||||
values_q = torch.einsum('bkm,bkv->bvm', query_softmax,
|
||||
extracted_feature.permute(0, 2, 1))
|
||||
attn_out = values_q.view(b, -1, h, w)
|
||||
recoder['semantic_distribution'].append(query)
|
||||
return attn_out
|
||||
|
||||
|
||||
class EncoderLayer_flow(nn.Sequential):
|
||||
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
downsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
bias=True,
|
||||
activate=True,
|
||||
use_extraction=False,
|
||||
num_label=None,
|
||||
match_kernel=None,
|
||||
num_extractions=2):
|
||||
super().__init__()
|
||||
|
||||
if downsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
||||
|
||||
stride = 2
|
||||
padding = 0
|
||||
|
||||
else:
|
||||
self.blur = None
|
||||
stride = 1
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.conv = EqualConv2d(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
padding=padding,
|
||||
stride=stride,
|
||||
bias=bias and not activate,
|
||||
)
|
||||
|
||||
self.activate = FusedLeakyReLU(
|
||||
out_channel, bias=bias) if activate else None
|
||||
self.use_extraction = use_extraction
|
||||
if self.use_extraction:
|
||||
self.extraction_operations = nn.ModuleList()
|
||||
for _ in range(num_extractions):
|
||||
self.extraction_operations.append(
|
||||
ExtractionOperation_flow(out_channel, num_label,
|
||||
match_kernel))
|
||||
|
||||
def forward(self, input, recoder=None):
|
||||
out = self.blur(input) if self.blur is not None else input
|
||||
out = self.conv(out)
|
||||
out = self.activate(out) if self.activate is not None else out
|
||||
if self.use_extraction:
|
||||
for extraction_operation in self.extraction_operations:
|
||||
extraction_operation(out, recoder)
|
||||
return out
|
||||
|
||||
|
||||
class DecoderLayer_flow_wavelet_fuse24(nn.Module):
|
||||
|
||||
# add fft refinement and tps
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
upsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
bias=True,
|
||||
activate=True,
|
||||
use_distribution=True,
|
||||
num_label=16,
|
||||
match_kernel=3,
|
||||
wavelet_down_level=False,
|
||||
window_size=8,
|
||||
):
|
||||
super().__init__()
|
||||
if upsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2 + factor - 1
|
||||
pad1 = p // 2 + 1
|
||||
|
||||
self.blur = Blur(
|
||||
blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
||||
self.conv = EqualTransposeConv2d(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=bias and not activate,
|
||||
)
|
||||
else:
|
||||
self.conv = EqualConv2d(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=kernel_size // 2,
|
||||
bias=bias and not activate,
|
||||
)
|
||||
self.blur = None
|
||||
|
||||
self.distribution_operation = DistributionOperation_flow(
|
||||
num_label, out_channel,
|
||||
match_kernel=match_kernel) if use_distribution else None
|
||||
self.activate = FusedLeakyReLU(
|
||||
out_channel, bias=bias) if activate else None
|
||||
self.use_distribution = use_distribution
|
||||
|
||||
# mask prediction network
|
||||
if use_distribution:
|
||||
self.conv_mask_lf = nn.Sequential(*[
|
||||
EqualConv2d(
|
||||
out_channel, 1, 3, stride=1, padding=3 // 2, bias=False),
|
||||
nn.Sigmoid()
|
||||
])
|
||||
self.conv_mask_dict = nn.ModuleDict()
|
||||
for level in range(wavelet_down_level):
|
||||
conv_mask = nn.Sequential(*[
|
||||
EqualConv2d(
|
||||
out_channel,
|
||||
1,
|
||||
3,
|
||||
stride=1,
|
||||
padding=3 // 2,
|
||||
bias=False),
|
||||
nn.Sigmoid()
|
||||
])
|
||||
self.conv_mask_dict[str(level)] = conv_mask
|
||||
|
||||
self.wavelet_down_level = wavelet_down_level
|
||||
if wavelet_down_level:
|
||||
self.dwt = DWTForward(
|
||||
J=self.wavelet_down_level, mode='zero', wave='haar')
|
||||
self.idwt = DWTInverse(mode='zero', wave='haar')
|
||||
|
||||
# for mask input channel squeeze and expand
|
||||
self.conv_l_squeeze = EqualConv2d(
|
||||
2 * out_channel, out_channel, 1, 1, 0, bias=False)
|
||||
self.conv_h_squeeze = EqualConv2d(
|
||||
6 * out_channel, out_channel, 1, 1, 0, bias=False)
|
||||
|
||||
self.conv_l = EqualConv2d(
|
||||
out_channel, out_channel, 3, 1, 3 // 2, bias=False)
|
||||
|
||||
self.hf_modules = nn.ModuleDict()
|
||||
for level in range(wavelet_down_level):
|
||||
hf_module = nn.Module()
|
||||
prev_channel = out_channel if level == self.wavelet_down_level - 1 else 3 * out_channel
|
||||
hf_module.conv_prev = EqualConv2d(
|
||||
prev_channel, 3 * out_channel, 3, 1, 3 // 2, bias=False)
|
||||
hf_module.conv_hf = GatedConv2dWithActivation(
|
||||
3 * out_channel, 3 * out_channel, 3, 1, 3 // 2, bias=False)
|
||||
hf_module.conv_out = GatedConv2dWithActivation(
|
||||
3 * out_channel, 3 * out_channel, 3, 1, 3 // 2, bias=False)
|
||||
self.hf_modules[str(level)] = hf_module
|
||||
|
||||
self.amp_fuse = nn.Sequential(
|
||||
EqualConv2d(2 * out_channel, out_channel, 1, 1, 0),
|
||||
FusedLeakyReLU(out_channel, bias=False),
|
||||
EqualConv2d(out_channel, out_channel, 1, 1, 0))
|
||||
self.pha_fuse = nn.Sequential(
|
||||
EqualConv2d(2 * out_channel, out_channel, 1, 1, 0),
|
||||
FusedLeakyReLU(out_channel, bias=False),
|
||||
EqualConv2d(out_channel, out_channel, 1, 1, 0))
|
||||
self.post = EqualConv2d(out_channel, out_channel, 1, 1, 0)
|
||||
self.eps = 1e-8
|
||||
|
||||
def forward(self,
|
||||
input,
|
||||
neural_texture=None,
|
||||
recoder=None,
|
||||
warped_texture=None,
|
||||
style_net=None,
|
||||
gstyle=None):
|
||||
out = self.conv(input)
|
||||
out = self.blur(out) if self.blur is not None else out
|
||||
|
||||
mask_l, mask_h = None, None
|
||||
out_attn = None
|
||||
if self.use_distribution and neural_texture is not None:
|
||||
out_ori = out
|
||||
out_attn = self.distribution_operation(out, neural_texture,
|
||||
recoder)
|
||||
# wavelet fusion
|
||||
if self.wavelet_down_level:
|
||||
assert out.shape[2] % 2 == 0, \
|
||||
f'out shape {out.shape} is not appropriate for processing'
|
||||
b, c, h, w = out.shape
|
||||
|
||||
# wavelet decomposition
|
||||
LF_attn, HF_attn = self.dwt(out_attn)
|
||||
LF_warp, HF_warp = self.dwt(warped_texture)
|
||||
LF_out, HF_out = self.dwt(out)
|
||||
|
||||
# generate mask
|
||||
hf_dict = {}
|
||||
l_mask_input = torch.cat([LF_attn, LF_warp], dim=1)
|
||||
l_mask_input = self.conv_l_squeeze(l_mask_input)
|
||||
l_mask_input = style_net(l_mask_input, gstyle)
|
||||
ml = self.conv_mask_lf(l_mask_input)
|
||||
mask_l = ml
|
||||
|
||||
for level in range(self.wavelet_down_level):
|
||||
# level up, feature size down
|
||||
scale = 2**(level + 1)
|
||||
hfa = HF_attn[level].view(b, c * 3, h // scale, w // scale)
|
||||
hfw = HF_warp[level].view(b, c * 3, h // scale, w // scale)
|
||||
hfg = HF_out[level].view(b, c * 3, h // scale, w // scale)
|
||||
|
||||
h_mask_input = torch.cat([hfa, hfw], dim=1)
|
||||
h_mask_input = self.conv_h_squeeze(h_mask_input)
|
||||
h_mask_input = style_net(h_mask_input, gstyle)
|
||||
mh = self.conv_mask_dict[str(level)](h_mask_input)
|
||||
if level == 0:
|
||||
mask_h = mh
|
||||
|
||||
# fuse high frequency
|
||||
xh = (mh * hfa + (1 - mh) * hfw + hfg) / math.sqrt(2)
|
||||
hf_dict[str(level)] = xh
|
||||
|
||||
temp_result = (1 - ml) * LF_warp + LF_out
|
||||
out_l = (ml * LF_attn + temp_result) / math.sqrt(2)
|
||||
out_h_list = []
|
||||
for level in range(self.wavelet_down_level - 1, -1, -1):
|
||||
xh = hf_dict[str(level)]
|
||||
b, c, h, w = xh.shape
|
||||
out_h_list.append(xh.view(b, c // 3, 3, h, w))
|
||||
out_h_list = (
|
||||
out_h_list)[::-1] # the h list from large to small size
|
||||
#
|
||||
out = self.idwt((out_l, out_h_list))
|
||||
else:
|
||||
out = (out + out_attn) / math.sqrt(2)
|
||||
|
||||
# fourier refinement
|
||||
_, _, H, W = out.shape
|
||||
fuseF = torch.fft.rfft2(out + self.eps, norm='backward')
|
||||
outF = torch.fft.rfft2(out_ori + self.eps, norm='backward')
|
||||
amp = self.amp_fuse(
|
||||
torch.cat([torch.abs(fuseF), torch.abs(outF)], 1))
|
||||
pha = self.pha_fuse(
|
||||
torch.cat(
|
||||
[torch.angle(fuseF), torch.angle(outF)], 1))
|
||||
out_fft = torch.fft.irfft2(
|
||||
amp * torch.exp(1j * pha) + self.eps,
|
||||
s=(H, W),
|
||||
dim=(-2, -1),
|
||||
norm='backward')
|
||||
|
||||
out = out + self.post(out_fft)
|
||||
|
||||
out = self.activate(
|
||||
out.contiguous()) if self.activate is not None else out
|
||||
return out, mask_h, mask_l
|
||||
|
||||
|
||||
# base functions
|
||||
|
||||
|
||||
class EqualConv2d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.randn(out_channel, in_channel, kernel_size, kernel_size))
|
||||
self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
|
||||
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_channel))
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
out = conv2d(
|
||||
input,
|
||||
self.weight * self.scale,
|
||||
bias=self.bias,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.dilation)
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
||||
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
||||
)
|
||||
|
||||
|
||||
class EqualTransposeConv2d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=True):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.randn(out_channel, in_channel, kernel_size, kernel_size))
|
||||
self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
|
||||
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_channel))
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
weight = self.weight.transpose(0, 1)
|
||||
out = conv_transpose2d(
|
||||
input,
|
||||
weight * self.scale,
|
||||
bias=self.bias,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
||||
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
||||
)
|
||||
|
||||
|
||||
class ToRGB(nn.Module):
|
||||
|
||||
def __init__(self, in_channel, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
||||
super().__init__()
|
||||
|
||||
if upsample:
|
||||
self.upsample = Upsample(blur_kernel)
|
||||
self.conv = EqualConv2d(in_channel, 3, 3, stride=1, padding=1)
|
||||
|
||||
def forward(self, input, skip=None):
|
||||
out = self.conv(input)
|
||||
if skip is not None:
|
||||
skip = self.upsample(skip)
|
||||
out = out + skip
|
||||
return out
|
||||
|
||||
|
||||
class EqualLinear(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_dim,
|
||||
out_dim,
|
||||
bias=True,
|
||||
bias_init=0,
|
||||
lr_mul=1,
|
||||
activation=None):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
||||
self.lr_mul = lr_mul
|
||||
|
||||
def forward(self, input):
|
||||
if self.activation:
|
||||
out = F.linear(input, self.weight * self.scale)
|
||||
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
||||
|
||||
else:
|
||||
out = F.linear(
|
||||
input, self.weight * self.scale, bias=self.bias * self.lr_mul)
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
|
||||
)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
|
||||
def __init__(self, kernel, factor=2):
|
||||
super().__init__()
|
||||
|
||||
self.factor = factor
|
||||
kernel = make_kernel(kernel) * (factor**2)
|
||||
self.register_buffer('kernel', kernel)
|
||||
|
||||
p = kernel.shape[0] - factor
|
||||
|
||||
pad0 = (p + 1) // 2 + factor - 1
|
||||
pad1 = p // 2
|
||||
|
||||
self.pad = (pad0, pad1)
|
||||
|
||||
def forward(self, input):
|
||||
out = upfirdn2d(
|
||||
input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
downsample=True):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
||||
self.conv2 = ConvLayer(
|
||||
in_channel, out_channel, 3, downsample=downsample)
|
||||
|
||||
self.skip = ConvLayer(
|
||||
in_channel,
|
||||
out_channel,
|
||||
1,
|
||||
downsample=downsample,
|
||||
activate=False,
|
||||
bias=False)
|
||||
|
||||
def forward(self, input):
|
||||
out = self.conv1(input)
|
||||
out = self.conv2(out)
|
||||
|
||||
skip = self.skip(input)
|
||||
out = (out + skip) / math.sqrt(2)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ConvLayer(nn.Sequential):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
downsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
bias=True,
|
||||
activate=True,
|
||||
):
|
||||
layers = []
|
||||
|
||||
if downsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
|
||||
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
||||
|
||||
stride = 2
|
||||
self.padding = 0
|
||||
|
||||
else:
|
||||
stride = 1
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
layers.append(
|
||||
EqualConv2d(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
padding=self.padding,
|
||||
stride=stride,
|
||||
bias=bias and not activate,
|
||||
))
|
||||
|
||||
if activate:
|
||||
layers.append(FusedLeakyReLU(out_channel, bias=bias))
|
||||
|
||||
super().__init__(*layers)
|
||||
|
||||
|
||||
class Blur(nn.Module):
|
||||
|
||||
def __init__(self, kernel, pad, upsample_factor=1):
|
||||
super().__init__()
|
||||
|
||||
kernel = make_kernel(kernel)
|
||||
|
||||
if upsample_factor > 1:
|
||||
kernel = kernel * (upsample_factor**2)
|
||||
|
||||
self.register_buffer('kernel', kernel)
|
||||
|
||||
self.pad = pad
|
||||
|
||||
def forward(self, input):
|
||||
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class GatedConv2dWithActivation(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
bias=True,
|
||||
activation=None):
|
||||
super(GatedConv2dWithActivation, self).__init__()
|
||||
self.activation = FusedLeakyReLU(out_channels, bias=False)
|
||||
self.conv2d = EqualConv2d(in_channels, out_channels, kernel_size,
|
||||
stride, padding, bias, dilation)
|
||||
self.mask_conv2d = EqualConv2d(in_channels, out_channels, kernel_size,
|
||||
stride, padding, bias, dilation)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def gated(self, mask):
|
||||
return self.sigmoid(mask)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.conv2d(input)
|
||||
mask = self.mask_conv2d(input)
|
||||
if self.activation is not None:
|
||||
x = self.activation(x) * self.gated(mask)
|
||||
else:
|
||||
x = x * self.gated(mask)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k

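
# Quick check of make_kernel (illustrative; `_demo_make_kernel` is a
# hypothetical helper, not part of the original file): a 1-D [1, 3, 3, 1] tap
# list becomes a separable 4x4 blur kernel normalized to sum to 1, which is
# what Blur and Upsample register as their FIR filter.
def _demo_make_kernel():
    k = make_kernel([1, 3, 3, 1])
    print(k.shape)         # torch.Size([4, 4])
    print(float(k.sum()))  # 1.0

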
class SPDNorm(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
norm_channel,
|
||||
label_nc,
|
||||
norm_type='position',
|
||||
use_equal=False):
|
||||
super().__init__()
|
||||
param_free_norm_type = norm_type
|
||||
ks = 3
|
||||
if param_free_norm_type == 'instance':
|
||||
self.param_free_norm = nn.InstanceNorm2d(
|
||||
norm_channel, affine=False)
|
||||
elif param_free_norm_type == 'batch':
|
||||
self.param_free_norm = nn.BatchNorm2d(norm_channel, affine=False)
|
||||
elif param_free_norm_type == 'position':
|
||||
self.param_free_norm = PositionalNorm2d
|
||||
else:
|
||||
raise ValueError(
|
||||
'%s is not a recognized param-free norm type in SPADE'
|
||||
% param_free_norm_type)
|
||||
|
||||
# The dimension of the intermediate embedding space. Yes, hardcoded.
|
||||
pw = ks // 2
|
||||
nhidden = 128
|
||||
if not use_equal:
|
||||
self.mlp_activate = nn.Sequential(
|
||||
nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
|
||||
nn.ReLU())
|
||||
self.mlp_gamma = nn.Conv2d(
|
||||
nhidden, norm_channel, kernel_size=ks, padding=pw)
|
||||
self.mlp_beta = nn.Conv2d(
|
||||
nhidden, norm_channel, kernel_size=ks, padding=pw)
|
||||
else:
|
||||
self.mlp_activate = nn.Sequential(*[
|
||||
EqualConv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
|
||||
FusedLeakyReLU(nhidden, bias=False)
|
||||
])
|
||||
self.mlp_gamma = EqualConv2d(
|
||||
nhidden, norm_channel, kernel_size=ks, padding=pw)
|
||||
self.mlp_beta = EqualConv2d(
|
||||
nhidden, norm_channel, kernel_size=ks, padding=pw)
|
||||
|
||||
def forward(self, x, prior_f, weight=1.0):
|
||||
normalized = self.param_free_norm(x)
|
||||
# Part 2. produce scaling and bias conditioned on condition feature
|
||||
actv = self.mlp_activate(prior_f)
|
||||
gamma = self.mlp_gamma(actv) * weight
|
||||
beta = self.mlp_beta(actv) * weight
|
||||
# apply scale and bias
|
||||
out = normalized * (1 + gamma) + beta
|
||||
return out
|
||||
|
||||
|
||||
def PositionalNorm2d(x, epsilon=1e-5):
    # x: B*C*W*H normalize in C dim
    mean = x.mean(dim=1, keepdim=True)
    std = x.var(dim=1, keepdim=True).add(epsilon).sqrt()
    output = (x - mean) / std
    return output

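
# Illustrative check of PositionalNorm2d (the `_demo_positional_norm` helper is
# hypothetical): it standardizes every spatial position across the channel
# dimension, which is the parameter-free 'position' norm used by SPDNorm above.
def _demo_positional_norm():
    x = torch.randn(2, 64, 8, 8) * 5 + 3
    y = PositionalNorm2d(x)
    print(y.mean(dim=1).abs().max())  # close to 0
    print(y.std(dim=1).mean())        # close to 1
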
@@ -0,0 +1,358 @@
import collections
import functools
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_function import *
from .flow_module import MaskStyle, StyleFlow
from .tps import TPS

# adding flow version
|
||||
class Encoder_wiflow(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
input_dim,
|
||||
channels,
|
||||
num_labels=None,
|
||||
match_kernels=None,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
):
|
||||
super().__init__()
|
||||
self.first = EncoderLayer_flow(input_dim, channels[size], 1)
|
||||
self.convs = nn.ModuleList()
|
||||
self.num_labels = num_labels
|
||||
self.match_kernels = match_kernels
|
||||
|
||||
log_size = int(math.log(size, 2))
|
||||
self.log_size = log_size
|
||||
|
||||
in_channel = channels[size]
|
||||
for i in range(log_size - 1, 3, -1):
|
||||
out_channel = channels[2**i]
|
||||
num_label = num_labels[2**i] if num_labels is not None else None
|
||||
match_kernel = match_kernels[
|
||||
2**i] if match_kernels is not None else None
|
||||
use_extraction = num_label and match_kernel
|
||||
conv = EncoderLayer_flow(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size=3,
|
||||
downsample=True,
|
||||
blur_kernel=blur_kernel,
|
||||
use_extraction=use_extraction,
|
||||
num_label=num_label,
|
||||
match_kernel=match_kernel)
|
||||
|
||||
self.convs.append(conv)
|
||||
in_channel = out_channel
|
||||
|
||||
def forward(self, input, recoder=None, out_list=None):
|
||||
out = self.first(input)
|
||||
for layer in self.convs:
|
||||
out = layer(out, recoder)
|
||||
if out_list is not None:
|
||||
out_list.append(out)
|
||||
return out
|
||||
|
||||
|
||||
class Decoder_wiflow_wavelet_fuse25(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
channels,
|
||||
num_labels,
|
||||
match_kernels,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
wavelet_down_levels={'16': 3},
|
||||
window_size=8,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
# input at resolution 16*16
|
||||
in_channel = channels[16]
|
||||
self.log_size = int(math.log(size, 2))
|
||||
self.conv_mask_dict = nn.ModuleDict()
|
||||
self.conv_mask_fuse_dict = nn.ModuleDict()
|
||||
|
||||
flow_fusion = False
|
||||
|
||||
for i in range(4, self.log_size + 1):
|
||||
out_channel = channels[2**i]
|
||||
num_label, match_kernel = num_labels[2**i], match_kernels[2**i]
|
||||
use_distribution = num_label and match_kernel
|
||||
upsample = (i != 4)
|
||||
wavelet_down_level = wavelet_down_levels[(2**i)]
|
||||
base_layer = functools.partial(
|
||||
DecoderLayer_flow_wavelet_fuse24,
|
||||
out_channel=out_channel,
|
||||
kernel_size=3,
|
||||
blur_kernel=blur_kernel,
|
||||
use_distribution=use_distribution,
|
||||
num_label=num_label,
|
||||
match_kernel=match_kernel,
|
||||
wavelet_down_level=wavelet_down_level,
|
||||
window_size=window_size)
|
||||
# mask head for fusion
|
||||
if use_distribution:
|
||||
conv_mask = [
|
||||
EqualConv2d(
|
||||
2 * out_channel,
|
||||
3,
|
||||
3,
|
||||
stride=1,
|
||||
padding=3 // 2,
|
||||
bias=False),
|
||||
nn.Sigmoid()
|
||||
]
|
||||
conv_mask = nn.Sequential(*conv_mask)
|
||||
self.conv_mask_dict[str(2**i)] = conv_mask
|
||||
|
||||
if not i == 4:
|
||||
conv_mask_fuse = nn.Sequential(*[
|
||||
EqualConv2d(
|
||||
2, 1, 3, stride=1, padding=3 // 2, bias=False),
|
||||
nn.Sigmoid()
|
||||
])
|
||||
self.conv_mask_fuse_dict[str(2**i)] = conv_mask_fuse
|
||||
|
||||
if not flow_fusion:
|
||||
self.conv_flow_fusion = nn.Sequential(
|
||||
EqualConv2d(
|
||||
2 * out_channel,
|
||||
1,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
bias=False), nn.Sigmoid())
|
||||
flow_fusion = True
|
||||
|
||||
up = nn.Module()
|
||||
up.conv0 = base_layer(in_channel=in_channel, upsample=upsample)
|
||||
up.conv1 = base_layer(in_channel=out_channel, upsample=False)
|
||||
up.to_rgb = ToRGB(out_channel, upsample=upsample)
|
||||
self.convs.append(up)
|
||||
in_channel = out_channel
|
||||
|
||||
style_in_channels = channels[16]
|
||||
self.style_out_channel = 128
|
||||
self.cond_style = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
style_in_channels,
|
||||
self.style_out_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1), nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
nn.AdaptiveAvgPool2d(1))
|
||||
self.image_style = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
style_in_channels,
|
||||
self.style_out_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1), nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
nn.AdaptiveAvgPool2d(1))
|
||||
self.flow_model = StyleFlow(
|
||||
channels, self.log_size, style_in=2 * self.style_out_channel)
|
||||
|
||||
self.num_labels, self.match_kernels = num_labels, match_kernels
|
||||
|
||||
# for mask prediction
|
||||
self.mask_style = MaskStyle(
|
||||
channels,
|
||||
self.log_size,
|
||||
style_in=2 * self.style_out_channel,
|
||||
channels_multiplier=1)
|
||||
|
||||
# tps transformation
|
||||
self.tps = TPS()
|
||||
|
||||
def forward(self,
|
||||
input,
|
||||
neural_textures,
|
||||
skeleton_features,
|
||||
source_features,
|
||||
kp_skeleton,
|
||||
recoder,
|
||||
add_nted=True):
|
||||
source_features = source_features[::-1]
|
||||
skeleton_features = skeleton_features[::-1]
|
||||
|
||||
counter = 0
|
||||
out, skip = input, None
|
||||
|
||||
last_flow = None
|
||||
mask_all_h, mask_all_l = [], []
|
||||
delta_list = []
|
||||
delta_x_all = []
|
||||
delta_y_all = []
|
||||
last_flow_all = []
|
||||
filter_x = [[0, 0, 0], [1, -2, 1], [0, 0, 0]]
|
||||
filter_y = [[0, 1, 0], [0, -2, 0], [0, 1, 0]]
|
||||
filter_diag1 = [[1, 0, 0], [0, -2, 0], [0, 0, 1]]
|
||||
filter_diag2 = [[0, 0, 1], [0, -2, 0], [1, 0, 0]]
|
||||
weight_array = np.ones([3, 3, 1, 4])
|
||||
weight_array[:, :, 0, 0] = filter_x
|
||||
weight_array[:, :, 0, 1] = filter_y
|
||||
weight_array[:, :, 0, 2] = filter_diag1
|
||||
weight_array[:, :, 0, 3] = filter_diag2
|
||||
weight_array = torch.FloatTensor(weight_array).permute(3, 2, 0, 1).to(
|
||||
input.device)
|
||||
self.weight = nn.Parameter(data=weight_array, requires_grad=False)
|
||||
|
||||
B = source_features[0].shape[0]
|
||||
source_style = self.cond_style(source_features[0]).view(B, -1)
|
||||
target_style = self.image_style(skeleton_features[0]).view(B, -1)
|
||||
style = torch.cat([source_style, target_style], 1)
|
||||
|
||||
for i, up in enumerate(self.convs):
|
||||
use_distribution = (
|
||||
self.num_labels[2**(i + 4)] and self.match_kernels[2**(i + 4)])
|
||||
if use_distribution:
|
||||
# warp features with styleflow
|
||||
source_feature = source_features[i]
|
||||
skeleton_feature = skeleton_features[i]
|
||||
if last_flow is not None:
|
||||
last_flow = F.interpolate(
|
||||
last_flow, scale_factor=2, mode='bilinear')
|
||||
s_warp_after = F.grid_sample(
|
||||
source_feature,
|
||||
last_flow.detach().permute(0, 2, 3, 1),
|
||||
mode='bilinear',
|
||||
padding_mode='border')
|
||||
else:
|
||||
s_warp_after = source_feature
|
||||
scale = str(2**(i + 4))
|
||||
|
||||
# use tps transformation to estimate flow at the very beginning
|
||||
if last_flow is not None:
|
||||
style_map = self.flow_model.netStyle[scale](s_warp_after,
|
||||
style)
|
||||
flow = self.flow_model.netF[scale](style_map, style)
|
||||
flow = apply_offset(flow)
|
||||
|
||||
else:
|
||||
style_map = self.flow_model.netStyle[scale](s_warp_after,
|
||||
style)
|
||||
flow = self.flow_model.netF[scale](style_map, style)
|
||||
flow_dense = apply_offset(flow)
|
||||
flow_tps = self.tps(source_feature, kp_skeleton)
|
||||
warped_dense = F.grid_sample(
|
||||
source_feature,
|
||||
flow_dense,
|
||||
mode='bilinear',
|
||||
padding_mode='border')
|
||||
warped_tps = F.grid_sample(
|
||||
source_feature,
|
||||
flow_tps,
|
||||
mode='bilinear',
|
||||
padding_mode='border')
|
||||
contribution_map = self.conv_flow_fusion(
|
||||
torch.cat([warped_dense, warped_tps], 1))
|
||||
flow = contribution_map * flow_tps.permute(0, 3, 1, 2) + (
|
||||
1 - contribution_map) * flow_dense.permute(0, 3, 1, 2)
|
||||
flow = flow.permute(0, 2, 3, 1).contiguous()
|
||||
|
||||
if last_flow is not None:
|
||||
# update flow according to the last scale flow
|
||||
flow = F.grid_sample(
|
||||
last_flow,
|
||||
flow,
|
||||
mode='bilinear',
|
||||
padding_mode='border')
|
||||
else:
|
||||
flow = flow.permute(0, 3, 1, 2)
|
||||
|
||||
last_flow = flow
|
||||
s_warp = F.grid_sample(
|
||||
source_feature,
|
||||
flow.permute(0, 2, 3, 1),
|
||||
mode='bilinear',
|
||||
padding_mode='border')
|
||||
|
||||
# refine flow according to the original flow
|
||||
flow = self.flow_model.netRefine[scale](
|
||||
torch.cat([s_warp, skeleton_feature], 1))
|
||||
|
||||
delta_list.append(flow)
|
||||
flow = apply_offset(flow)
|
||||
flow = F.grid_sample(
|
||||
last_flow, flow, mode='bilinear', padding_mode='border')
|
||||
last_flow_all.append(flow)
|
||||
|
||||
last_flow = flow
|
||||
flow_x, flow_y = torch.split(last_flow, 1, dim=1)
|
||||
delta_x = F.conv2d(flow_x, self.weight)
|
||||
delta_y = F.conv2d(flow_y, self.weight)
|
||||
delta_x_all.append(delta_x)
|
||||
delta_y_all.append(delta_y)
|
||||
|
||||
s_warp = F.grid_sample(
|
||||
source_feature,
|
||||
last_flow.permute(0, 2, 3, 1),
|
||||
mode='bilinear',
|
||||
padding_mode='border')
|
||||
|
||||
# nted attention
|
||||
neural_texture_conv0 = neural_textures[counter]
|
||||
neural_texture_conv1 = neural_textures[counter + 1]
|
||||
counter += 2
|
||||
|
||||
if not add_nted: # turn off the nted attention
|
||||
neural_texture_conv0, neural_texture_conv1 = None, None
|
||||
else:
|
||||
neural_texture_conv0, neural_texture_conv1 = None, None
|
||||
s_warp = None
|
||||
|
||||
mask_style_net = self.mask_style.netM[
|
||||
scale] if use_distribution else None
|
||||
out, mask_h, mask_l = up.conv0(
|
||||
out,
|
||||
neural_texture=neural_texture_conv0,
|
||||
recoder=recoder,
|
||||
warped_texture=s_warp,
|
||||
style_net=mask_style_net,
|
||||
gstyle=style)
|
||||
out, mask_h, mask_l = up.conv1(
|
||||
out,
|
||||
neural_texture=neural_texture_conv1,
|
||||
recoder=recoder,
|
||||
warped_texture=s_warp,
|
||||
style_net=mask_style_net,
|
||||
gstyle=style)
|
||||
if use_distribution:
|
||||
if mask_h is not None:
|
||||
mask_all_h.append(mask_h)
|
||||
if mask_l is not None:
|
||||
mask_all_l.append(mask_l)
|
||||
skip = up.to_rgb(out, skip)
|
||||
|
||||
image = skip
|
||||
return image, delta_x_all, delta_y_all, delta_list, last_flow_all, mask_all_h, mask_all_l
|
||||
|
||||
|
||||
def apply_offset(offset):
    sizes = list(offset.size()[2:])
    grid_list = torch.meshgrid(
        [torch.arange(size, device=offset.device) for size in sizes])
    grid_list = reversed(grid_list)
    # apply offset
    grid_list = [
        grid.float().unsqueeze(0) + offset[:, dim, ...]
        for dim, grid in enumerate(grid_list)
    ]
    # normalize
    grid_list = [
        grid / ((size - 1.0) / 2.0) - 1.0
        for grid, size in zip(grid_list, reversed(sizes))
    ]

    return torch.stack(grid_list, dim=-1)

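
# Illustrative sketch of apply_offset (the `_demo_apply_offset` helper is
# hypothetical): given a (B, 2, H, W) offset field it returns a (B, H, W, 2)
# sampling grid with x/y normalized to roughly [-1, 1], in the layout that
# F.grid_sample expects for warping the source features.
def _demo_apply_offset():
    offset = torch.zeros(1, 2, 16, 16)
    grid = apply_offset(offset)
    print(grid.shape)                            # torch.Size([1, 16, 16, 2])
    print(grid.min().item(), grid.max().item())  # -1.0 1.0
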
@@ -0,0 +1,227 @@
|
||||
import contextlib
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import autograd
|
||||
from torch.nn import functional as F
|
||||
|
||||
enabled = True
|
||||
weight_gradients_disabled = False
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_weight_gradients():
|
||||
global weight_gradients_disabled
|
||||
|
||||
old = weight_gradients_disabled
|
||||
weight_gradients_disabled = True
|
||||
yield
|
||||
weight_gradients_disabled = old
|
||||
|
||||
|
||||
def conv2d(input,
           weight,
           bias=None,
           stride=1,
           padding=0,
           dilation=1,
           groups=1):
    if could_use_op(input):
        return conv2d_gradfix(
            transpose=False,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=0,
            dilation=dilation,
            groups=groups,
        ).apply(input, weight, bias)

    return F.conv2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )

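
# Illustrative sketch (the `_demo_no_weight_gradients` helper is hypothetical):
# the conv2d() wrapper behaves like F.conv2d, and wrapping a call in
# no_weight_gradients() only changes behaviour on the CUDA fast path, where the
# weight gradient is skipped; on CPU it simply falls back to F.conv2d.
def _demo_no_weight_gradients():
    x = torch.randn(1, 3, 8, 8, requires_grad=True)
    w = torch.randn(4, 3, 3, 3, requires_grad=True)
    with no_weight_gradients():
        y = conv2d(x, w, padding=1)
    print(y.shape)  # torch.Size([1, 4, 8, 8])

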
def conv_transpose2d(
|
||||
input,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=1,
|
||||
padding=0,
|
||||
output_padding=0,
|
||||
groups=1,
|
||||
dilation=1,
|
||||
):
|
||||
if could_use_op(input):
|
||||
return conv2d_gradfix(
|
||||
transpose=True,
|
||||
weight_shape=weight.shape,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
groups=groups,
|
||||
dilation=dilation,
|
||||
).apply(input, weight, bias)
|
||||
|
||||
return F.conv_transpose2d(
|
||||
input=input,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
|
||||
def could_use_op(input):
|
||||
if (not enabled) or (not torch.backends.cudnn.enabled):
|
||||
return False
|
||||
|
||||
if input.device.type != 'cuda':
|
||||
return False
|
||||
|
||||
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.']):
|
||||
return True
|
||||
|
||||
warnings.warn(
|
||||
f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().'
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def ensure_tuple(xs, ndim):
|
||||
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs, ) * ndim
|
||||
|
||||
return xs
|
||||
|
||||
|
||||
conv2d_gradfix_cache = dict()
|
||||
|
||||
|
||||
def conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding,
|
||||
dilation, groups):
|
||||
ndim = 2
|
||||
weight_shape = tuple(weight_shape)
|
||||
stride = ensure_tuple(stride, ndim)
|
||||
padding = ensure_tuple(padding, ndim)
|
||||
output_padding = ensure_tuple(output_padding, ndim)
|
||||
dilation = ensure_tuple(dilation, ndim)
|
||||
|
||||
key = (transpose, weight_shape, stride, padding, output_padding, dilation,
|
||||
groups)
|
||||
if key in conv2d_gradfix_cache:
|
||||
return conv2d_gradfix_cache[key]
|
||||
|
||||
common_kwargs = dict(
|
||||
stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
|
||||
    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0, 0]

        return [
            input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]
|
||||
|
||||
class Conv2d(autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias):
|
||||
if not transpose:
|
||||
out = F.conv2d(
|
||||
input=input, weight=weight, bias=bias, **common_kwargs)
|
||||
|
||||
else:
|
||||
out = F.conv_transpose2d(
|
||||
input=input,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
output_padding=output_padding,
|
||||
**common_kwargs,
|
||||
)
|
||||
|
||||
ctx.save_for_backward(input, weight)
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
grad_input, grad_weight, grad_bias = None, None, None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
p = calc_output_padding(
|
||||
input_shape=input.shape, output_shape=grad_output.shape)
|
||||
grad_input = conv2d_gradfix(
|
||||
transpose=(not transpose),
|
||||
weight_shape=weight_shape,
|
||||
output_padding=p,
|
||||
**common_kwargs,
|
||||
).apply(grad_output, weight, None)
|
||||
|
||||
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
||||
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
||||
|
||||
if ctx.needs_input_grad[2]:
|
||||
grad_bias = grad_output.sum((0, 2, 3))
|
||||
|
||||
return grad_input, grad_weight, grad_bias
|
||||
|
||||
class Conv2dGradWeight(autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, input):
|
||||
op = torch._C._jit_get_operation(
|
||||
'aten::cudnn_convolution_backward_weight' if not transpose else
|
||||
'aten::cudnn_convolution_transpose_backward_weight')
|
||||
flags = [
|
||||
torch.backends.cudnn.benchmark,
|
||||
torch.backends.cudnn.deterministic,
|
||||
torch.backends.cudnn.allow_tf32,
|
||||
]
|
||||
grad_weight = op(
|
||||
weight_shape,
|
||||
grad_output,
|
||||
input,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
*flags,
|
||||
)
|
||||
ctx.save_for_backward(grad_output, input)
|
||||
|
||||
return grad_weight
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_grad_weight):
|
||||
grad_output, input = ctx.saved_tensors
|
||||
grad_grad_output, grad_grad_input = None, None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
p = calc_output_padding(
|
||||
input_shape=input.shape, output_shape=grad_output.shape)
|
||||
grad_grad_input = conv2d_gradfix(
|
||||
transpose=(not transpose),
|
||||
weight_shape=weight_shape,
|
||||
output_padding=p,
|
||||
**common_kwargs,
|
||||
).apply(grad_output, grad_grad_weight, None)
|
||||
|
||||
return grad_grad_output, grad_grad_input
|
||||
|
||||
conv2d_gradfix_cache[key] = Conv2d
|
||||
|
||||
return Conv2d
|
||||
@@ -0,0 +1,64 @@
|
||||
import collections
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .base_module import *
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
semantic_dim,
|
||||
channels,
|
||||
num_labels,
|
||||
match_kernels,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
wavelet_down_levels={'16': 3},
|
||||
window_size=8,
|
||||
):
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.reference_encoder = Encoder_wiflow(size, 3, channels, num_labels,
|
||||
match_kernels, blur_kernel)
|
||||
|
||||
self.skeleton_encoder = Encoder_wiflow(
|
||||
size,
|
||||
semantic_dim,
|
||||
channels,
|
||||
)
|
||||
|
||||
self.target_image_renderer = Decoder_wiflow_wavelet_fuse25(
|
||||
size, channels, num_labels, match_kernels, blur_kernel,
|
||||
wavelet_down_levels, window_size)
|
||||
|
||||
def _cal_temp(self, module):
|
||||
return sum(p.numel() for p in module.parameters() if p.requires_grad)
|
||||
|
||||
def forward(self, source_image, skeleton, kp_skeleton):
|
||||
output_dict = {}
|
||||
recoder = collections.defaultdict(list)
|
||||
skeleton_feature_list, source_feature_list = [], []
|
||||
skeleton_feature = self.skeleton_encoder(
|
||||
skeleton, out_list=skeleton_feature_list)
|
||||
_ = self.reference_encoder(
|
||||
source_image, recoder, out_list=source_feature_list)
|
||||
neural_textures = recoder['neural_textures']
|
||||
|
||||
output_dict['fake_image'], delta_x_all, delta_y_all, delta_list, last_flow_all, mask_all_h, mask_all_l = \
|
||||
self.target_image_renderer(skeleton_feature, neural_textures, skeleton_feature_list,
|
||||
source_feature_list, kp_skeleton, recoder)
|
||||
output_dict['info'] = recoder
|
||||
output_dict['delta_x'] = delta_x_all
|
||||
output_dict['delta_y'] = delta_y_all
|
||||
output_dict['delta_list'] = delta_list
|
||||
output_dict['last_flow_all'] = last_flow_all
|
||||
output_dict['mask_all_h'] = mask_all_h
|
||||
output_dict['mask_all_l'] = mask_all_l
|
||||
return output_dict
|
||||
@@ -0,0 +1,346 @@
|
||||
from math import sqrt
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base_function import EqualConv2d, EqualLinear
|
||||
|
||||
|
||||
def TVLoss(x):
    tv_h = x[:, :, 1:, :] - x[:, :, :-1, :]
    tv_w = x[:, :, :, 1:] - x[:, :, :, :-1]

    return torch.mean(torch.abs(tv_h)) + torch.mean(torch.abs(tv_w))

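
# Illustrative check of TVLoss (the `_demo_tv_loss` helper is hypothetical): it
# is a total-variation penalty on neighbouring pixel differences, so a constant
# image scores 0 while a noisy one scores higher.
def _demo_tv_loss():
    flat = torch.ones(1, 3, 32, 32)
    noisy = torch.rand(1, 3, 32, 32)
    print(TVLoss(flat).item())   # 0.0
    print(TVLoss(noisy).item())  # > 0

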
class MaskStyle(nn.Module):
|
||||
|
||||
def __init__(self, channels, log_size, style_in, channels_multiplier=2):
|
||||
super().__init__()
|
||||
self.log_size = log_size
|
||||
padding_type = 'zero'
|
||||
actvn = 'lrelu'
|
||||
normalize_mlp = False
|
||||
modulated_conv = True
|
||||
|
||||
self.netM = nn.ModuleDict()
|
||||
|
||||
for i in range(4, self.log_size + 1):
|
||||
out_channel = channels[2**i]
|
||||
|
||||
style_mask = StyledConvBlock(
|
||||
channels_multiplier * out_channel,
|
||||
channels_multiplier * out_channel,
|
||||
latent_dim=style_in,
|
||||
padding=padding_type,
|
||||
actvn=actvn,
|
||||
normalize_affine_output=normalize_mlp,
|
||||
modulated_conv=modulated_conv)
|
||||
|
||||
scale = str(2**i)
|
||||
self.netM[scale] = style_mask
|
||||
|
||||
|
||||
class StyleFlow(nn.Module):
|
||||
|
||||
def __init__(self, channels, log_size, style_in):
|
||||
super().__init__()
|
||||
self.log_size = log_size
|
||||
padding_type = 'zero'
|
||||
actvn = 'lrelu'
|
||||
normalize_mlp = False
|
||||
modulated_conv = True
|
||||
|
||||
self.netRefine = nn.ModuleDict()
|
||||
self.netStyle = nn.ModuleDict()
|
||||
self.netF = nn.ModuleDict()
|
||||
|
||||
for i in range(4, self.log_size + 1):
|
||||
out_channel = channels[2**i]
|
||||
|
||||
netRefine_layer = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(
|
||||
2 * out_channel,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(
|
||||
in_channels=128,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(
|
||||
in_channels=64,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1),
|
||||
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
||||
torch.nn.Conv2d(
|
||||
in_channels=32,
|
||||
out_channels=2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1))
|
||||
|
||||
style_block = StyledConvBlock(
|
||||
out_channel,
|
||||
49,
|
||||
latent_dim=style_in,
|
||||
padding=padding_type,
|
||||
actvn=actvn,
|
||||
normalize_affine_output=normalize_mlp,
|
||||
modulated_conv=modulated_conv)
|
||||
|
||||
style_F_block = Styled_F_ConvBlock(
|
||||
49,
|
||||
2,
|
||||
latent_dim=style_in,
|
||||
padding=padding_type,
|
||||
actvn=actvn,
|
||||
normalize_affine_output=normalize_mlp,
|
||||
modulated_conv=modulated_conv)
|
||||
|
||||
scale = str(2**i)
|
||||
self.netRefine[scale] = (netRefine_layer)
|
||||
self.netStyle[scale] = (style_block)
|
||||
self.netF[scale] = (style_F_block)
|
||||
|
||||
|
||||
class StyledConvBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
fin,
|
||||
fout,
|
||||
latent_dim=256,
|
||||
padding='zero',
|
||||
actvn='lrelu',
|
||||
normalize_affine_output=False,
|
||||
modulated_conv=False):
|
||||
super(StyledConvBlock, self).__init__()
|
||||
if not modulated_conv:
|
||||
if padding == 'reflect':
|
||||
padding_layer = nn.ReflectionPad2d
|
||||
else:
|
||||
padding_layer = nn.ZeroPad2d
|
||||
|
||||
if modulated_conv:
|
||||
conv2d = ModulatedConv2d
|
||||
else:
|
||||
conv2d = EqualConv2d
|
||||
|
||||
if modulated_conv:
|
||||
self.actvn_gain = sqrt(2)
|
||||
else:
|
||||
self.actvn_gain = 1.0
|
||||
|
||||
self.modulated_conv = modulated_conv
|
||||
|
||||
if actvn == 'relu':
|
||||
activation = nn.ReLU(True)
|
||||
else:
|
||||
activation = nn.LeakyReLU(0.2, True)
|
||||
|
||||
if self.modulated_conv:
|
||||
self.conv0 = conv2d(
|
||||
fin,
|
||||
fout,
|
||||
kernel_size=3,
|
||||
padding_type=padding,
|
||||
upsample=False,
|
||||
latent_dim=latent_dim,
|
||||
normalize_mlp=normalize_affine_output)
|
||||
else:
|
||||
conv0 = conv2d(fin, fout, kernel_size=3)
|
||||
|
||||
seq0 = [padding_layer(1), conv0]
|
||||
self.conv0 = nn.Sequential(*seq0)
|
||||
|
||||
self.actvn0 = activation
|
||||
|
||||
if self.modulated_conv:
|
||||
self.conv1 = conv2d(
|
||||
fout,
|
||||
fout,
|
||||
kernel_size=3,
|
||||
padding_type=padding,
|
||||
downsample=False,
|
||||
latent_dim=latent_dim,
|
||||
normalize_mlp=normalize_affine_output)
|
||||
else:
|
||||
conv1 = conv2d(fout, fout, kernel_size=3)
|
||||
seq1 = [padding_layer(1), conv1]
|
||||
self.conv1 = nn.Sequential(*seq1)
|
||||
|
||||
self.actvn1 = activation
|
||||
|
||||
def forward(self, input, latent=None):
|
||||
if self.modulated_conv:
|
||||
out = self.conv0(input, latent)
|
||||
else:
|
||||
out = self.conv0(input)
|
||||
|
||||
out = self.actvn0(out) * self.actvn_gain
|
||||
|
||||
if self.modulated_conv:
|
||||
out = self.conv1(out, latent)
|
||||
else:
|
||||
out = self.conv1(out)
|
||||
|
||||
out = self.actvn1(out) * self.actvn_gain
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Styled_F_ConvBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
fin,
|
||||
fout,
|
||||
latent_dim=256,
|
||||
padding='zero',
|
||||
actvn='lrelu',
|
||||
normalize_affine_output=False,
|
||||
modulated_conv=False):
|
||||
super(Styled_F_ConvBlock, self).__init__()
|
||||
if not modulated_conv:
|
||||
if padding == 'reflect':
|
||||
padding_layer = nn.ReflectionPad2d
|
||||
else:
|
||||
padding_layer = nn.ZeroPad2d
|
||||
|
||||
if modulated_conv:
|
||||
conv2d = ModulatedConv2d
|
||||
else:
|
||||
conv2d = EqualConv2d
|
||||
|
||||
if modulated_conv:
|
||||
self.actvn_gain = sqrt(2)
|
||||
else:
|
||||
self.actvn_gain = 1.0
|
||||
|
||||
self.modulated_conv = modulated_conv
|
||||
|
||||
if actvn == 'relu':
|
||||
activation = nn.ReLU(True)
|
||||
else:
|
||||
activation = nn.LeakyReLU(0.2, True)
|
||||
|
||||
if self.modulated_conv:
|
||||
self.conv0 = conv2d(
|
||||
fin,
|
||||
128,
|
||||
kernel_size=3,
|
||||
padding_type=padding,
|
||||
upsample=False,
|
||||
latent_dim=latent_dim,
|
||||
normalize_mlp=normalize_affine_output)
|
||||
else:
|
||||
conv0 = conv2d(fin, 128, kernel_size=3)
|
||||
|
||||
seq0 = [padding_layer(1), conv0]
|
||||
self.conv0 = nn.Sequential(*seq0)
|
||||
|
||||
self.actvn0 = activation
|
||||
|
||||
if self.modulated_conv:
|
||||
self.conv1 = conv2d(
|
||||
128,
|
||||
fout,
|
||||
kernel_size=3,
|
||||
padding_type=padding,
|
||||
downsample=False,
|
||||
latent_dim=latent_dim,
|
||||
normalize_mlp=normalize_affine_output)
|
||||
else:
|
||||
conv1 = conv2d(128, fout, kernel_size=3)
|
||||
seq1 = [padding_layer(1), conv1]
|
||||
self.conv1 = nn.Sequential(*seq1)
|
||||
|
||||
def forward(self, input, latent=None):
|
||||
if self.modulated_conv:
|
||||
out = self.conv0(input, latent)
|
||||
else:
|
||||
out = self.conv0(input)
|
||||
|
||||
out = self.actvn0(out) * self.actvn_gain
|
||||
|
||||
if self.modulated_conv:
|
||||
out = self.conv1(out, latent)
|
||||
else:
|
||||
out = self.conv1(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ModulatedConv2d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
fin,
|
||||
fout,
|
||||
kernel_size,
|
||||
padding_type='zero',
|
||||
upsample=False,
|
||||
downsample=False,
|
||||
latent_dim=512,
|
||||
normalize_mlp=False):
|
||||
super(ModulatedConv2d, self).__init__()
|
||||
self.in_channels = fin
|
||||
self.out_channels = fout
|
||||
self.kernel_size = kernel_size
|
||||
padding_size = kernel_size // 2
|
||||
|
||||
if kernel_size == 1:
|
||||
self.demudulate = False
|
||||
else:
|
||||
self.demudulate = True
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.Tensor(fout, fin, kernel_size, kernel_size))
|
||||
self.bias = nn.Parameter(torch.Tensor(1, fout, 1, 1))
|
||||
|
||||
if normalize_mlp:
|
||||
self.mlp_class_std = nn.Sequential(
|
||||
EqualLinear(latent_dim, fin), PixelNorm())
|
||||
else:
|
||||
self.mlp_class_std = EqualLinear(latent_dim, fin)
|
||||
|
||||
if padding_type == 'reflect':
|
||||
self.padding = nn.ReflectionPad2d(padding_size)
|
||||
else:
|
||||
self.padding = nn.ZeroPad2d(padding_size)
|
||||
|
||||
self.weight.data.normal_()
|
||||
self.bias.data.zero_()
|
||||
|
||||
def forward(self, input, latent):
|
||||
fan_in = self.weight.data.size(1) * self.weight.data[0][0].numel()
|
||||
weight = self.weight * sqrt(2 / fan_in)
|
||||
weight = weight.view(1, self.out_channels, self.in_channels,
|
||||
self.kernel_size, self.kernel_size)
|
||||
|
||||
s = self.mlp_class_std(latent).view(-1, 1, self.in_channels, 1, 1)
|
||||
weight = s * weight
|
||||
if self.demudulate:
|
||||
d = torch.rsqrt((weight**2).sum(4).sum(3).sum(2) + 1e-5).view(
|
||||
-1, self.out_channels, 1, 1, 1)
|
||||
weight = (d * weight).view(-1, self.in_channels, self.kernel_size,
|
||||
self.kernel_size)
|
||||
else:
|
||||
weight = weight.view(-1, self.in_channels, self.kernel_size,
|
||||
self.kernel_size)
|
||||
|
||||
batch, _, height, width = input.shape
|
||||
|
||||
input = input.reshape(1, -1, height, width)
|
||||
input = self.padding(input)
|
||||
out = F.conv2d(
|
||||
input, weight, groups=batch).view(batch, self.out_channels, height,
|
||||
width) + self.bias
|
||||
|
||||
return out
|
||||
modelscope/models/cv/human_image_generation/generators/tps.py (new file, 121 lines)
@@ -0,0 +1,121 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class TPS(nn.Module):
|
||||
|
||||
def __init__(self, mode='kp'):
|
||||
super().__init__()
|
||||
self.mode = mode
|
||||
|
||||
def trans(self, kp_1):
|
||||
if self.mode == 'kp':
|
||||
device = kp_1.device
|
||||
kp_type = kp_1.type()
|
||||
self.gs = kp_1.shape[1]
|
||||
n = kp_1.shape[2]
|
||||
K = torch.norm(
|
||||
kp_1[:, :, :, None] - kp_1[:, :, None, :], dim=4, p=2)
|
||||
K = K**2
|
||||
K = K * torch.log(K + 1e-9)
|
||||
|
||||
one1 = torch.ones(self.bs, kp_1.shape[1], kp_1.shape[2],
|
||||
1).to(device).type(kp_type)
|
||||
kp_1p = torch.cat([kp_1, one1], 3)
|
||||
|
||||
zero = torch.zeros(self.bs, kp_1.shape[1], 3,
|
||||
3).to(device).type(kp_type)
|
||||
P = torch.cat([kp_1p, zero], 2)
|
||||
L = torch.cat([K, kp_1p.permute(0, 1, 3, 2)], 2)
|
||||
L = torch.cat([L, P], 3)
|
||||
|
||||
zero = torch.zeros(self.bs, kp_1.shape[1], 3,
|
||||
2).to(device).type(kp_type)
|
||||
kp_substitute = torch.zeros(kp_1.shape).to(device).type(kp_type)
|
||||
Y = torch.cat([kp_substitute, zero], 2)
|
||||
one = torch.eye(L.shape[2]).expand(
|
||||
L.shape).to(device).type(kp_type) * 0.01
|
||||
L = L + one
|
||||
|
||||
param = torch.matmul(torch.inverse(L), Y)
|
||||
self.theta = param[:, :, n:, :].permute(0, 1, 3, 2)
|
||||
|
||||
self.control_points = kp_1
|
||||
self.control_params = param[:, :, :n, :]
|
||||
else:
|
||||
raise Exception('Error TPS mode')
|
||||
|
||||
def transform_frame(self, frame):
|
||||
grid = make_coordinate_grid(
|
||||
frame.shape[2:], type=frame.type()).unsqueeze(0).to(frame.device)
|
||||
grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
|
||||
shape = [self.bs, frame.shape[2], frame.shape[3], 2]
|
||||
if self.mode == 'kp':
|
||||
shape.insert(1, self.gs)
|
||||
grid = self.warp_coordinates(grid).view(*shape)
|
||||
return grid
|
||||
|
||||
def warp_coordinates(self, coordinates):
|
||||
theta = self.theta.type(coordinates.type()).to(coordinates.device)
|
||||
control_points = self.control_points.type(coordinates.type()).to(
|
||||
coordinates.device)
|
||||
control_params = self.control_params.type(coordinates.type()).to(
|
||||
coordinates.device)
|
||||
|
||||
if self.mode == 'kp':
|
||||
transformed = torch.matmul(theta[:, :, :, :2],
|
||||
coordinates.permute(
|
||||
0, 2, 1)) + theta[:, :, :, 2:]
|
||||
|
||||
distances = coordinates.view(
|
||||
coordinates.shape[0], 1, 1, -1, 2) - control_points.view(
|
||||
self.bs, control_points.shape[1], -1, 1, 2)
|
||||
distances = distances**2
|
||||
result = distances.sum(-1)
|
||||
result = result * torch.log(result + 1e-9)
|
||||
result = torch.matmul(result.permute(0, 1, 3, 2), control_params)
|
||||
transformed = transformed.permute(0, 1, 3, 2) + result
|
||||
|
||||
else:
|
||||
raise Exception('Error TPS mode')
|
||||
|
||||
return transformed
|
||||
|
||||
def preprocess_kp(self, kp_1):
|
||||
'''
|
||||
kp_1: (b, ntps*nkp, 2)
|
||||
'''
|
||||
kp_mask = (kp_1 == -1)
|
||||
num_keypoints = kp_1.shape[1]
|
||||
kp_1 = kp_1.masked_fill(kp_mask, -1.)
|
||||
return kp_1, num_keypoints
|
||||
|
||||
def forward(self, source_image, kp_driving):
|
||||
bs, _, h, w = source_image.shape
|
||||
self.bs = bs
|
||||
kp_driving, num_keypoints = self.preprocess_kp(kp_driving)
|
||||
kp_1 = kp_driving.view(bs, -1, num_keypoints, 2)
|
||||
self.trans(kp_1)
|
||||
grid = self.transform_frame(source_image)
|
||||
grid = grid.view(bs, h, w, 2)
|
||||
return grid
|
||||
|
||||
|
||||
def make_coordinate_grid(spatial_size, type):
    """
    Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
    """
    h, w = spatial_size
    x = torch.arange(w).type(type)
    y = torch.arange(h).type(type)

    x = (2 * (x / (w - 1)) - 1)
    y = (2 * (y / (h - 1)) - 1)

    yy = y.view(-1, 1).repeat(1, w)
    xx = x.view(1, -1).repeat(h, 1)

    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)

    return meshed

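
# Illustrative check of make_coordinate_grid (the `_demo_coordinate_grid`
# helper is hypothetical): it returns an (H, W, 2) grid whose x/y values run
# from -1 to 1, which TPS.transform_frame then warps per keypoint group.
def _demo_coordinate_grid():
    grid = make_coordinate_grid((4, 6), type=torch.FloatTensor().type())
    print(grid.shape)                # torch.Size([4, 6, 2])
    print(grid[0, 0], grid[-1, -1])  # tensor([-1., -1.]) tensor([1., 1.])
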
@@ -0,0 +1,182 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def get_wav(in_channels, pool=True):
|
||||
harr_wav_L = 1 / np.sqrt(2) * np.ones((1, 2))
|
||||
harr_wav_H = 1 / np.sqrt(2) * np.ones((1, 2))
|
||||
harr_wav_H[0, 0] = -1 * harr_wav_H[0, 0]
|
||||
harr_wav_LL = np.transpose(harr_wav_L) * harr_wav_L
|
||||
harr_wav_LH = np.transpose(harr_wav_L) * harr_wav_H
|
||||
harr_wav_HL = np.transpose(harr_wav_H) * harr_wav_L
|
||||
harr_wav_HH = np.transpose(harr_wav_H) * harr_wav_H
|
||||
filter_LL = torch.from_numpy(harr_wav_LL).unsqueeze(0)
|
||||
filter_LH = torch.from_numpy(harr_wav_LH).unsqueeze(0)
|
||||
filter_HL = torch.from_numpy(harr_wav_HL).unsqueeze(0)
|
||||
filter_HH = torch.from_numpy(harr_wav_HH).unsqueeze(0)
|
||||
if pool:
|
||||
net = nn.Conv2d
|
||||
else:
|
||||
net = nn.ConvTranspose2d
|
||||
LL = net(
|
||||
in_channels,
|
||||
in_channels * 2,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=False,
|
||||
groups=in_channels)
|
||||
LH = net(
|
||||
in_channels,
|
||||
in_channels * 2,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=False,
|
||||
groups=in_channels)
|
||||
HL = net(
|
||||
in_channels,
|
||||
in_channels * 2,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=False,
|
||||
groups=in_channels)
|
||||
HH = net(
|
||||
in_channels,
|
||||
in_channels * 2,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=False,
|
||||
groups=in_channels)
|
||||
LL.weight.requires_grad = False
|
||||
LH.weight.requires_grad = False
|
||||
HL.weight.requires_grad = False
|
||||
HH.weight.requires_grad = False
|
||||
LL.weight.data = filter_LL.float().unsqueeze(0).expand(
|
||||
in_channels * 2, -1, -1, -1)
|
||||
LH.weight.data = filter_LH.float().unsqueeze(0).expand(
|
||||
in_channels * 2, -1, -1, -1)
|
||||
HL.weight.data = filter_HL.float().unsqueeze(0).expand(
|
||||
in_channels * 2, -1, -1, -1)
|
||||
HH.weight.data = filter_HH.float().unsqueeze(0).expand(
|
||||
in_channels * 2, -1, -1, -1)
|
||||
return LL, LH, HL, HH
|
||||
|
||||
|
||||
class WavePool(nn.Module):

    def __init__(self, in_channels):
        super(WavePool, self).__init__()
        self.LL, self.LH, self.HL, self.HH = get_wav(in_channels)

    def forward(self, x):
        return self.LL(x), self.LH(x), self.HL(x), self.HH(x)

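
# Illustrative sketch of WavePool (the `_demo_wave_pool` helper is
# hypothetical): forward() returns the LL/LH/HL/HH Haar subbands at half the
# input resolution; with this particular get_wav() weight layout each subband
# comes back with twice the input channel count.
def _demo_wave_pool():
    pool = WavePool(in_channels=8)
    x = torch.randn(1, 8, 32, 32)
    ll, lh, hl, hh = pool(x)
    print(ll.shape)  # torch.Size([1, 16, 16, 16])

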
def get_wav_two(in_channels, out_channels=None, pool=True):
|
||||
"""wavelet decomposition using conv2d"""
|
||||
harr_wav_L = 1 / np.sqrt(2) * np.ones((1, 2))
|
||||
harr_wav_H = 1 / np.sqrt(2) * np.ones((1, 2))
|
||||
harr_wav_H[0, 0] = -1 * harr_wav_H[0, 0]
|
||||
|
||||
harr_wav_LL = np.transpose(harr_wav_L) * harr_wav_L
|
||||
harr_wav_LH = np.transpose(harr_wav_L) * harr_wav_H
|
||||
harr_wav_HL = np.transpose(harr_wav_H) * harr_wav_L
|
||||
harr_wav_HH = np.transpose(harr_wav_H) * harr_wav_H
|
||||
|
||||
filter_LL = torch.from_numpy(harr_wav_LL).unsqueeze(0)
|
||||
filter_LH = torch.from_numpy(harr_wav_LH).unsqueeze(0)
|
||||
filter_HL = torch.from_numpy(harr_wav_HL).unsqueeze(0)
|
||||
filter_HH = torch.from_numpy(harr_wav_HH).unsqueeze(0)
|
||||
|
||||
if pool:
|
||||
net = nn.Conv2d
|
||||
else:
|
||||
net = nn.ConvTranspose2d
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
LL = net(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=False,
|
||||
groups=in_channels)
|
||||
LH = net(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=False,
|
||||
groups=in_channels)
|
||||
HL = net(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=False,
|
||||
groups=in_channels)
|
||||
HH = net(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=False,
|
||||
groups=in_channels)
|
||||
|
||||
LL.weight.requires_grad = False
|
||||
LH.weight.requires_grad = False
|
||||
HL.weight.requires_grad = False
|
||||
HH.weight.requires_grad = False
|
||||
|
||||
LL.weight.data = filter_LL.float().unsqueeze(0).expand(
|
||||
in_channels, -1, -1, -1)
|
||||
LH.weight.data = filter_LH.float().unsqueeze(0).expand(
|
||||
in_channels, -1, -1, -1)
|
||||
HL.weight.data = filter_HL.float().unsqueeze(0).expand(
|
||||
in_channels, -1, -1, -1)
|
||||
HH.weight.data = filter_HH.float().unsqueeze(0).expand(
|
||||
in_channels, -1, -1, -1)
|
||||
|
||||
return LL, LH, HL, HH
|
||||
|
||||
|
||||
class WavePool2(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels=None):
|
||||
super(WavePool2, self).__init__()
|
||||
self.LL, self.LH, self.HL, self.HH = get_wav_two(
|
||||
in_channels, out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
return self.LL(x), self.LH(x), self.HL(x), self.HH(x)
|
||||
|
||||
|
||||
class WaveUnpool(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels=None, option_unpool='cat5'):
|
||||
super(WaveUnpool, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.option_unpool = option_unpool
|
||||
self.LL, self.LH, self.HL, self.HH = get_wav_two(
|
||||
self.in_channels, out_channels, pool=False)
|
||||
|
||||
def forward(self, LL, LH, HL, HH, original=None):
|
||||
if self.option_unpool == 'sum':
|
||||
return self.LL(LL) + self.LH(LH) + self.HL(HL) + self.HH(HH)
|
||||
elif self.option_unpool == 'cat5' and original is not None:
|
||||
return torch.cat(
|
||||
[self.LL(LL),
|
||||
self.LH(LH),
|
||||
self.HL(HL),
|
||||
self.HH(HH), original],
|
||||
dim=1)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,268 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import math
import random

import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from PIL import Image

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .generators.extraction_distribution_model_flow25 import \
    Generator as Generator

tv_version = int(torchvision.__version__.split('.')[1])
|
||||
if tv_version > 8:
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
resize_method = InterpolationMode.BICUBIC
|
||||
resize_nearest = InterpolationMode.NEAREST
|
||||
else:
|
||||
resize_method = Image.BICUBIC
|
||||
resize_nearest = Image.NEAREST
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def get_random_params(size, scale_param, use_flip=False):
|
||||
w, h = size
|
||||
scale = random.random() * scale_param
|
||||
|
||||
if use_flip:
|
||||
use_flip = random.random() > 0.9
|
||||
|
||||
new_w = int(w * (1.0 + scale))
|
||||
new_h = int(h * (1.0 + scale))
|
||||
x = random.randint(0, np.maximum(0, new_w - w))
|
||||
y = random.randint(0, np.maximum(0, new_h - h))
|
||||
return {
|
||||
'crop_param': (x, y, w, h),
|
||||
'scale_size': (new_h, new_w),
|
||||
'use_flip': use_flip
|
||||
}
|
||||
|
||||
|
||||
def get_transform(param, method=resize_method, normalize=True, toTensor=True):
|
||||
transform_list = []
|
||||
if 'scale_size' in param and param['scale_size'] is not None:
|
||||
osize = param['scale_size']
|
||||
transform_list.append(transforms.Resize(osize, interpolation=method))
|
||||
|
||||
if 'crop_param' in param and param['crop_param'] is not None:
|
||||
transform_list.append(
|
||||
transforms.Lambda(lambda img: __crop(img, param['crop_param'])))
|
||||
|
||||
if param['use_flip']:
|
||||
transform_list.append(transforms.Lambda(lambda img: __flip(img)))
|
||||
|
||||
if toTensor:
|
||||
transform_list += [transforms.ToTensor()]
|
||||
|
||||
if normalize:
|
||||
transform_list += [
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
]
|
||||
return transforms.Compose(transform_list)
|
||||
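
# Illustrative sketch (the `_demo_transform` helper and 'source.jpg' path are
# hypothetical): get_random_params() draws a random scale/crop/flip setting and
# get_transform() turns it into a torchvision pipeline, which is how
# get_image_tensor() below prepares the reference image.
def _demo_transform():
    img = Image.open('source.jpg')
    param = get_random_params(img.size, scale_param=0.05)
    trans = get_transform(param, normalize=True, toTensor=True)
    tensor = trans(img)  # CHW float tensor normalized to [-1, 1]
    print(tensor.shape, param['use_flip'])
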
|
||||
|
||||
def __crop(img, pos):
|
||||
x1, y1, tw, th = pos
|
||||
return img.crop((x1, y1, x1 + tw, y1 + th))
|
||||
|
||||
|
||||
def __flip(img):
|
||||
return F.hflip(img)
|
||||
|
||||
|
||||
def normalize():
|
||||
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, device):
|
||||
params = torch.load(checkpoint_path, map_location=device)
|
||||
if 'target_image_renderer.weight' in params['net_G_ema'].keys():
|
||||
params['net_G_ema'].pop('target_image_renderer.weight')
|
||||
model.load_state_dict(params['net_G_ema'])
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.human_image_generation, module_name=Models.human_image_generation)
|
||||
class FreqHPTForHumanImageGeneration(TorchModel):
|
||||
"""initialize the human image generation model from the `model_dir` path.
|
||||
|
||||
Args:
|
||||
model_dir (str): the model path.
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir, device_id=0, *args, **kwargs):
|
||||
|
||||
super().__init__(
|
||||
model_dir=model_dir, device_id=device_id, *args, **kwargs)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.device = 'cuda'
|
||||
logger.info('Use GPU')
|
||||
else:
|
||||
self.device = 'cpu'
|
||||
logger.info('Use CPU')
|
||||
|
||||
size = 512
|
||||
semantic_dim = 20
|
||||
channels = {
|
||||
16: 256,
|
||||
32: 256,
|
||||
64: 256,
|
||||
128: 128,
|
||||
256: 128,
|
||||
512: 64,
|
||||
1024: 32
|
||||
}
|
||||
num_labels = {16: 16, 32: 32, 64: 64, 128: 64, 256: 64, 512: False}
|
||||
match_kernels = {16: False, 32: 3, 64: 3, 128: 3, 256: 3, 512: False}
|
||||
wavelet_down_levels = {16: False, 32: 1, 64: 2, 128: 3, 256: 3, 512: 3}
|
||||
self.model = Generator(
|
||||
size,
|
||||
semantic_dim,
|
||||
channels,
|
||||
num_labels,
|
||||
match_kernels,
|
||||
wavelet_down_levels=wavelet_down_levels)
|
||||
self.model = load_checkpoint(
|
||||
self.model, model_dir + '/' + ModelFile.TORCH_MODEL_BIN_FILE,
|
||||
self.device)
|
||||
|
||||
def forward(self, x, y, z):
|
||||
pred_result = self.model(x, y, z)
|
||||
return pred_result
|
||||
|
||||
|
||||
def trans_keypoins(keypoints, param, img_size, offset=None):
|
||||
missing_keypoint_index = keypoints == -1
|
||||
|
||||
# crop the white line in the original dataset
|
||||
if not offset == 40:
|
||||
keypoints[:, 0] = (keypoints[:, 0] - 40)
|
||||
|
||||
# resize the dataset
|
||||
img_h, img_w = img_size
|
||||
scale_w = 1.0 / 176.0 * img_w
|
||||
scale_h = 1.0 / 256.0 * img_h
|
||||
|
||||
if 'scale_size' in param and param['scale_size'] is not None:
|
||||
new_h, new_w = param['scale_size']
|
||||
scale_w = scale_w / img_w * new_w
|
||||
scale_h = scale_h / img_h * new_h
|
||||
|
||||
if 'crop_param' in param and param['crop_param'] is not None:
|
||||
w, h, _, _ = param['crop_param']
|
||||
else:
|
||||
w, h = 0, 0
|
||||
|
||||
keypoints[:, 0] = keypoints[:, 0] * scale_w - w
|
||||
keypoints[:, 1] = keypoints[:, 1] * scale_h - h
|
||||
|
||||
normalized_kp = keypoints.copy()
|
||||
normalized_kp[:, 0] = (normalized_kp[:, 0]) / img_w * 2 - 1
|
||||
normalized_kp[:, 1] = (normalized_kp[:, 1]) / img_h * 2 - 1
|
||||
normalized_kp[missing_keypoint_index] = -1
|
||||
|
||||
keypoints[missing_keypoint_index] = -1
|
||||
return keypoints, normalized_kp
|
||||
|
||||
|
||||
def get_label_tensor(path, img, param):
|
||||
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
|
||||
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15],
|
||||
[15, 17], [1, 16], [16, 18], [3, 17], [6, 18]]
|
||||
|
||||
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
|
||||
[170, 255, 0], [85, 255, 0], [0, 255, 0], [0, 255, 85],
|
||||
[0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255],
|
||||
[0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255],
|
||||
[255, 0, 170], [255, 0, 85]]
|
||||
canvas = np.zeros((img.shape[1], img.shape[2], 3)).astype(np.uint8)
|
||||
keypoint = np.loadtxt(path)
|
||||
keypoint, normalized_kp = trans_keypoins(keypoint, param, img.shape[1:])
|
||||
stickwidth = 4
|
||||
for i in range(18):
|
||||
x, y = keypoint[i, 0:2]
|
||||
if x == -1 or y == -1:
|
||||
continue
|
||||
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
|
||||
joints = []
|
||||
for i in range(17):
|
||||
Y = keypoint[np.array(limbSeq[i]) - 1, 0]
|
||||
X = keypoint[np.array(limbSeq[i]) - 1, 1]
|
||||
cur_canvas = canvas.copy()
|
||||
if -1 in Y or -1 in X:
|
||||
joints.append(np.zeros_like(cur_canvas[:, :, 0]))
|
||||
continue
|
||||
mX = np.mean(X)
|
||||
mY = np.mean(Y)
|
||||
length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
|
||||
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
||||
polygon = cv2.ellipse2Poly(
|
||||
(int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0,
|
||||
360, 1)
|
||||
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
|
||||
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
|
||||
|
||||
joint = np.zeros_like(cur_canvas[:, :, 0])
|
||||
cv2.fillConvexPoly(joint, polygon, 255)
|
||||
joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
|
||||
joints.append(joint)
|
||||
pose = F.to_tensor(
|
||||
Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)))
|
||||
|
||||
tensors_dist = 0
|
||||
e = 1
|
||||
for i in range(len(joints)):
|
||||
im_dist = cv2.distanceTransform(255 - joints[i], cv2.DIST_L1, 3)
|
||||
im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
|
||||
tensor_dist = F.to_tensor(Image.fromarray(im_dist))
|
||||
tensors_dist = tensor_dist if e == 1 else torch.cat(
|
||||
[tensors_dist, tensor_dist])
|
||||
e += 1
|
||||
|
||||
label_tensor = torch.cat((pose, tensors_dist), dim=0)
|
||||
return label_tensor, normalized_kp
|
||||
|
||||
|
||||
def get_image_tensor(path):
|
||||
img = Image.open(path)
|
||||
param = get_random_params(img.size, 0)
|
||||
trans = get_transform(param, normalize=True, toTensor=True)
|
||||
img = trans(img)
|
||||
return img, param
|
||||
|
||||
|
||||
def infer(genmodel, image_path, target_label_path, device):
|
||||
ref_tensor, param = get_image_tensor(image_path)
|
||||
target_label_tensor, target_kp = get_label_tensor(target_label_path,
|
||||
ref_tensor, param)
|
||||
|
||||
ref_tensor = ref_tensor.unsqueeze(0).to(device)
|
||||
target_label_tensor = target_label_tensor.unsqueeze(0).to(device)
|
||||
target_kp = torch.from_numpy(target_kp).unsqueeze(0).to(device)
|
||||
output_dict = genmodel(ref_tensor, target_label_tensor, target_kp)
|
||||
output_image = output_dict['fake_image'][0]
|
||||
|
||||
output_image = output_image.clamp_(-1, 1)
|
||||
image = (output_image + 1) * 0.5
|
||||
image = image.detach().cpu().squeeze().numpy()
|
||||
image = np.transpose(image, (1, 2, 0)) * 255
|
||||
image = np.uint8(image)
|
||||
bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
return bgr
|
||||
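
# End-to-end usage sketch (the `_demo_infer` helper and all paths are
# hypothetical): build the registered model from a local model directory, then
# render the target pose for a reference image with infer().
def _demo_infer(model_dir, image_path, target_label_path):
    genmodel = FreqHPTForHumanImageGeneration(model_dir)
    bgr = infer(genmodel, image_path, target_label_path, genmodel.device)
    cv2.imwrite('generated.png', bgr)

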
@@ -158,7 +158,7 @@ class Encoder(nn.Module):
         return hooks

     def forward(self, img):
-        return self.arch(img)
+        return self.arch.forward_features(img)


 class MultiScaleColorDecoder(nn.Module):

@@ -119,8 +119,8 @@ class ConvNeXt(nn.Module):
         self.head_cls = nn.Linear(dims[-1], 4)

         self.apply(self._init_weights)
-        self.head_cls.weight.data.mul_(head_init_scale)
-        self.head_cls.bias.data.mul_(head_init_scale)
+        # self.head_cls.weight.data.mul_(head_init_scale)
+        # self.head_cls.bias.data.mul_(head_init_scale)

     def _init_weights(self, m):
         if isinstance(m, (nn.Conv2d, nn.Linear)):

modelscope/models/cv/image_editing/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .masactrl import MutualSelfAttentionControl
|
||||
from .masactrl_utils import regiter_attention_editor_diffusers
|
||||
else:
|
||||
_import_structure = {
|
||||
'masactrl': ['MutualSelfAttentionControl'],
|
||||
'masactrl_utils': ['regiter_attention_editor_diffusers']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
modelscope/models/cv/image_editing/masactrl.py (new file)
@@ -0,0 +1,77 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from https://github.com/TencentARC/MasaCtrl/blob/main/masactrl/masactrl.py
|
||||
# Copyright (c) 2023 TencentARC. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from .masactrl_utils import AttentionBase
|
||||
|
||||
|
||||
class MutualSelfAttentionControl(AttentionBase):
|
||||
|
||||
def __init__(self,
|
||||
start_step=4,
|
||||
start_layer=10,
|
||||
layer_idx=None,
|
||||
step_idx=None,
|
||||
total_steps=50):
|
||||
"""
|
||||
Mutual self-attention control for Stable-Diffusion model
|
||||
Args:
|
||||
start_step: the step to start mutual self-attention control
|
||||
start_layer: the layer to start mutual self-attention control
|
||||
layer_idx: list of the layers to apply mutual self-attention control
|
||||
step_idx: list of the steps to apply mutual self-attention control
|
||||
total_steps: the total number of steps
|
||||
"""
|
||||
super().__init__()
|
||||
self.total_steps = total_steps
|
||||
self.start_step = start_step
|
||||
self.start_layer = start_layer
|
||||
self.layer_idx = layer_idx if layer_idx is not None else list(
|
||||
range(start_layer, 16))
|
||||
self.step_idx = step_idx if step_idx is not None else list(
|
||||
range(start_step, total_steps)) # denoise index
|
||||
print('step_idx: ', self.step_idx)
|
||||
print('layer_idx: ', self.layer_idx)
|
||||
|
||||
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet,
|
||||
num_heads, **kwargs):
|
||||
b = q.shape[0] // num_heads
|
||||
q = rearrange(q, '(b h) n d -> h (b n) d', h=num_heads)
|
||||
k = rearrange(k, '(b h) n d -> h (b n) d', h=num_heads)
|
||||
v = rearrange(v, '(b h) n d -> h (b n) d', h=num_heads)
|
||||
|
||||
sim = torch.einsum('h i d, h j d -> h i j', q, k) * kwargs.get('scale')
|
||||
attn = sim.softmax(-1)
|
||||
out = torch.einsum('h i j, h j d -> h i d', attn, v)
|
||||
out = rearrange(out, 'h (b n) d -> b n (h d)', b=b)
|
||||
return out
|
||||
|
||||
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,
|
||||
**kwargs):
|
||||
"""
|
||||
Attention forward function
|
||||
"""
|
||||
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
|
||||
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet,
|
||||
num_heads, **kwargs)
|
||||
|
||||
qu, qc = q.chunk(2) # uncond, cond
|
||||
ku, kc = k.chunk(2)
|
||||
vu, vc = v.chunk(2)
|
||||
attnu, attnc = attn.chunk(2)
|
||||
# uncond
|
||||
# ku[:num_heads], vu[:num_heads] -> source
|
||||
# qu -> [source, target]
|
||||
out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads],
|
||||
sim[:num_heads], attnu, is_cross,
|
||||
place_in_unet, num_heads, **kwargs)
|
||||
out_c = self.attn_batch(qc, kc[:num_heads], vc[:num_heads],
|
||||
sim[:num_heads], attnc, is_cross,
|
||||
place_in_unet, num_heads, **kwargs)
|
||||
out = torch.cat([out_u, out_c], dim=0)
|
||||
|
||||
return out
|
||||
modelscope/models/cv/image_editing/masactrl_utils.py (new file)
@@ -0,0 +1,124 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from https://github.com/TencentARC/MasaCtrl/blob/main/masactrl/masactrl_utils.py
|
||||
# Copyright (c) 2023 TencentARC. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
class AttentionBase:
|
||||
|
||||
def __init__(self):
|
||||
self.cur_step = 0
|
||||
self.num_att_layers = -1
|
||||
self.cur_att_layer = 0
|
||||
|
||||
def after_step(self):
|
||||
pass
|
||||
|
||||
def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,
|
||||
**kwargs):
|
||||
out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet,
|
||||
num_heads, **kwargs)
|
||||
self.cur_att_layer += 1
|
||||
if self.cur_att_layer == self.num_att_layers:
|
||||
self.cur_att_layer = 0
|
||||
self.cur_step += 1
|
||||
# after step
|
||||
self.after_step()
|
||||
return out
|
||||
|
||||
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,
|
||||
**kwargs):
|
||||
out = torch.einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
|
||||
return out
|
||||
|
||||
def reset(self):
|
||||
self.cur_step = 0
|
||||
self.cur_att_layer = 0
|
||||
|
||||
|
||||
def regiter_attention_editor_diffusers(model, editor: AttentionBase):
|
||||
"""
|
||||
Register an attention editor to a Diffusers pipeline, adapted from [Prompt-to-Prompt]
|
||||
"""
|
||||
|
||||
def ca_forward(self, place_in_unet):
|
||||
|
||||
def forward(x,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
context=None,
|
||||
mask=None):
|
||||
"""
|
||||
The attention here follows the original LDM CrossAttention implementation,
with a few modifications applied to the attention computation
|
||||
"""
|
||||
if encoder_hidden_states is not None:
|
||||
context = encoder_hidden_states
|
||||
if attention_mask is not None:
|
||||
mask = attention_mask
|
||||
|
||||
to_out = self.to_out
|
||||
if isinstance(to_out, nn.modules.container.ModuleList):
|
||||
to_out = self.to_out[0]
|
||||
else:
|
||||
to_out = self.to_out
|
||||
|
||||
h = self.heads
|
||||
q = self.to_q(x)
|
||||
is_cross = context is not None
|
||||
context = context if is_cross else x
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(q, k, v))
|
||||
|
||||
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
if mask is not None:
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
mask = mask[:, None, :].repeat(h, 1, 1)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
attn = sim.softmax(dim=-1)
|
||||
# the only difference
|
||||
out = editor(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sim,
|
||||
attn,
|
||||
is_cross,
|
||||
place_in_unet,
|
||||
self.heads,
|
||||
scale=self.scale)
|
||||
|
||||
return to_out(out)
|
||||
|
||||
return forward
|
||||
|
||||
def register_editor(net, count, place_in_unet):
|
||||
for name, subnet in net.named_children():
|
||||
if net.__class__.__name__ == 'Attention': # spatial Transformer layer
|
||||
net.forward = ca_forward(net, place_in_unet)
|
||||
return count + 1
|
||||
elif hasattr(net, 'children'):
|
||||
count = register_editor(subnet, count, place_in_unet)
|
||||
return count
|
||||
|
||||
cross_att_count = 0
|
||||
for net_name, net in model.unet.named_children():
|
||||
if 'down' in net_name:
|
||||
cross_att_count += register_editor(net, 0, 'down')
|
||||
elif 'mid' in net_name:
|
||||
cross_att_count += register_editor(net, 0, 'mid')
|
||||
elif 'up' in net_name:
|
||||
cross_att_count += register_editor(net, 0, 'up')
|
||||
editor.num_att_layers = cross_att_count
|
||||
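A sketch of how the two pieces above are typically wired together; it assumes a diffusers Stable Diffusion pipeline is already available (the checkpoint name below is illustrative):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to('cuda')
editor = MutualSelfAttentionControl(start_step=4, start_layer=10, total_steps=50)
regiter_attention_editor_diffusers(pipe, editor)
# Every module named 'Attention' inside pipe.unet now routes its attention through the
# editor, so the selected denoising steps/layers reuse the source branch's keys and values.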
@@ -91,7 +91,7 @@ def infer(ourgen_model, model_path, person_img, garment_img, mask_img, device):
|
||||
cm_array = (cm_array >= 128).astype(np.float32)
|
||||
cm = torch.from_numpy(cm_array)
|
||||
cm = cm.unsqueeze(0).unsqueeze(0)
|
||||
cm = torch.FloatTensor((cm.numpy() > 0.5).astype(np.float)).to(device)
|
||||
cm = torch.FloatTensor((cm.numpy() > 0.5).astype(float)).to(device)
|
||||
|
||||
im = person_img
|
||||
h_ori, w_ori = im.shape[0:2]
|
||||
|
||||
@@ -65,6 +65,7 @@ class MTTR(nn.Module):
|
||||
# keep only the valid frames (frames which are annotated):
|
||||
# (for example, in a2d-sentences only the center frame in each window is annotated).
|
||||
for layer_out in backbone_out:
|
||||
valid_indices = valid_indices.to(layer_out.tensors.device)
|
||||
layer_out.tensors = layer_out.tensors.index_select(
|
||||
0, valid_indices)
|
||||
layer_out.mask = layer_out.mask.index_select(0, valid_indices)
|
||||
|
||||
modelscope/models/cv/surface_recon_common/__init__.py (new file)
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .surface_recon_common import SurfaceReconCommon
|
||||
|
||||
else:
|
||||
_import_structure = {'surface_recon_common': ['SurfaceReconCommon']}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
modelscope/models/cv/surface_recon_common/dataset.py (new file)
@@ -0,0 +1,289 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from https://github.com/Totoro97/NeuS/blob/main/models/dataset.py
|
||||
# Copyright (c) 2021 Peng Wang. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.spatial.transform import Rotation as Rot
|
||||
from scipy.spatial.transform import Slerp
|
||||
|
||||
|
||||
def load_K_Rt_from_P(filename, P=None):
|
||||
if P is None:
|
||||
lines = open(filename).read().splitlines()
|
||||
if len(lines) == 4:
|
||||
lines = lines[1:]
|
||||
lines = [[x[0], x[1], x[2], x[3]]
|
||||
for x in (x.split(' ') for x in lines)]
|
||||
P = np.asarray(lines).astype(np.float32).squeeze()
|
||||
|
||||
out = cv.decomposeProjectionMatrix(P)
|
||||
K = out[0]
|
||||
R = out[1]
|
||||
t = out[2]
|
||||
|
||||
K = K / K[2, 2]
|
||||
intrinsics = np.eye(4)
|
||||
intrinsics[:3, :3] = K
|
||||
|
||||
pose = np.eye(4, dtype=np.float32)
|
||||
pose[:3, :3] = R.transpose()
|
||||
pose[:3, 3] = (t[:3] / t[3])[:, 0]
|
||||
|
||||
return intrinsics, pose
|
||||
|
||||
|
||||
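A quick round-trip check of load_K_Rt_from_P with a synthetic projection matrix P = K [R | t] (all numbers arbitrary):

import numpy as np

K = np.array([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])
R = np.eye(3)
t = np.array([[0.1], [0.0], [2.0]])
P = K @ np.hstack([R, t])                     # 3x4 world-to-image projection
intrinsics, pose = load_K_Rt_from_P(None, P)
# intrinsics[:3, :3] recovers K (up to scale), pose[:3, :3] is R.T and
# pose[:3, 3] is the camera centre -R.T @ t, i.e. pose is the camera-to-world transform.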
class Dataset:
|
||||
|
||||
def __init__(self, data_dir, device):
|
||||
super(Dataset, self).__init__()
|
||||
print('Load data: Begin')
|
||||
self.device = device
|
||||
self.data_dir = data_dir
|
||||
print('data_dir: ', self.data_dir)
|
||||
|
||||
camera_dict = np.load(
|
||||
os.path.join(self.data_dir, 'cameras_sphere.npz'))
|
||||
self.camera_dict = camera_dict
|
||||
self.images_lis = sorted(
|
||||
glob(os.path.join(self.data_dir, 'image/*.png')))
|
||||
self.n_images = len(self.images_lis)
|
||||
print('found %d images' % self.n_images)
|
||||
|
||||
self.world_mats_np = [
|
||||
camera_dict['world_mat_%d' % idx].astype(np.float32)
|
||||
for idx in range(self.n_images)
|
||||
]
|
||||
self.scale_mats_np = [
|
||||
camera_dict['scale_mat_%d' % idx].astype(np.float32)
|
||||
for idx in range(self.n_images)
|
||||
]
|
||||
|
||||
self.intrinsics_all = []
|
||||
self.pose_all = []
|
||||
for scale_mat, world_mat in zip(self.scale_mats_np,
|
||||
self.world_mats_np):
|
||||
P = world_mat @ scale_mat
|
||||
P = P[:3, :4]
|
||||
intrinsics, pose = load_K_Rt_from_P(None, P)
|
||||
self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
|
||||
self.pose_all.append(torch.from_numpy(pose).float())
|
||||
|
||||
self.intrinsics_all = torch.stack(self.intrinsics_all).to(
|
||||
self.device) # [n_images, 4, 4]
|
||||
self.intrinsics_all_inv = torch.inverse(
|
||||
self.intrinsics_all) # [n_images, 4, 4]
|
||||
self.focal = self.intrinsics_all[0][0, 0]
|
||||
self.pose_all = torch.stack(self.pose_all).to(
|
||||
self.device) # [n_images, 4, 4]
|
||||
|
||||
object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0])
|
||||
object_bbox_max = np.array([1.01, 1.01, 1.01, 1.0])
|
||||
# Object scale mat: region of interest to **extract mesh**
|
||||
object_scale_mat = np.load(
|
||||
os.path.join(self.data_dir, 'cameras_sphere.npz'))['scale_mat_0']
|
||||
object_bbox_min = np.linalg.inv(
|
||||
self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:,
|
||||
None]
|
||||
object_bbox_max = np.linalg.inv(
|
||||
self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:,
|
||||
None]
|
||||
self.object_bbox_min = object_bbox_min[:3, 0]
|
||||
self.object_bbox_max = object_bbox_max[:3, 0]
|
||||
|
||||
print('Load data: End')
|
||||
|
||||
def gen_rays_at(self, img_idx, resolution_level=1):
|
||||
"""
|
||||
Generate rays in world space from one camera.
|
||||
"""
|
||||
level = resolution_level
|
||||
tx = torch.linspace(0, self.W - 1, self.W // level)
|
||||
ty = torch.linspace(0, self.H - 1, self.H // level)
|
||||
pixels_x, pixels_y = torch.meshgrid(tx, ty)
|
||||
p = torch.stack(
|
||||
[pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
||||
p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3],
|
||||
p[:, :, :, None]).squeeze() # W, H, 3
|
||||
rays_v = p / torch.linalg.norm(
|
||||
p, ord=2, dim=-1, keepdim=True) # W, H, 3
|
||||
rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3],
|
||||
rays_v[:, :, :, None]).squeeze() # W, H, 3
|
||||
rays_o = self.pose_all[img_idx, None, None, :3,
|
||||
3].expand(rays_v.shape) # W, H, 3
|
||||
return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
|
||||
|
||||
def gen_rays_o_at(self, img_idx):
|
||||
"""
|
||||
Generate rays_o at world space from one camera.
|
||||
"""
|
||||
rays_o = self.pose_all[img_idx, :3, 3]
|
||||
return rays_o
|
||||
|
||||
# add
|
||||
def gen_rays_at_camera(self, pose, resolution_level=1):
|
||||
"""
|
||||
Generate rays in world space from a given camera pose.
|
||||
"""
|
||||
level = resolution_level
|
||||
tx = torch.linspace(0, self.W - 1, self.W // level)
|
||||
ty = torch.linspace(0, self.H - 1, self.H // level)
|
||||
pixels_x, pixels_y = torch.meshgrid(tx, ty)
|
||||
p = torch.stack(
|
||||
[pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
||||
p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3],
|
||||
p[:, :, :, None]).squeeze() # W, H, 3
|
||||
rays_v = p / torch.linalg.norm(
|
||||
p, ord=2, dim=-1, keepdim=True) # W, H, 3
|
||||
rays_v = torch.matmul(pose[:3, :3], rays_v[:, :, :,
|
||||
None]).squeeze() # W, H, 3
|
||||
rays_o = pose[:3, 3].expand(rays_v.shape) # W, H, 3
|
||||
return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
|
||||
|
||||
def gen_random_rays_at(self, img_idx, batch_size):
|
||||
"""
|
||||
Generate random rays in world space from one camera.
|
||||
"""
|
||||
pixels_x = torch.randint(low=0, high=self.W, size=[batch_size]) # bs
|
||||
pixels_y = torch.randint(low=0, high=self.H, size=[batch_size]) # bs
|
||||
color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3
|
||||
mask = self.masks[img_idx][(pixels_y, pixels_x)] # batch_size, 3
|
||||
|
||||
depth = self.depths[img_idx][(pixels_y, pixels_x)] # batch_size, 1
|
||||
|
||||
p = torch.stack(
|
||||
[pixels_x, pixels_y, torch.ones_like(pixels_y)],
|
||||
dim=-1).float() # batch_size, 3
|
||||
p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3],
|
||||
p[:, :, None]).squeeze() # batch_size, 3
|
||||
rays_v = p / torch.linalg.norm(
|
||||
p, ord=2, dim=-1, keepdim=True) # batch_size, 3
|
||||
rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3],
|
||||
rays_v[:, :, None]).squeeze() # batch_size, 3
|
||||
rays_o = self.pose_all[img_idx, None, :3,
|
||||
3].expand(rays_v.shape) # batch_size, 3
|
||||
return torch.cat(
|
||||
[rays_o.cpu(),
|
||||
rays_v.cpu(), color, mask[:, :1], depth[:, None]],
|
||||
dim=-1).cuda() # batch_size, 10
|
||||
|
||||
def gen_random_rays_at_mask(self, img_idx, batch_size):
|
||||
"""
|
||||
Generate random rays in world space from one camera.
|
||||
"""
|
||||
pixels_x = torch.randint(low=0, high=self.W, size=[batch_size]) # bs
|
||||
pixels_y = torch.randint(low=0, high=self.H, size=[batch_size]) # bs
|
||||
color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3
|
||||
mask = self.masks[img_idx][(pixels_y, pixels_x)] # batch_size, 3
|
||||
|
||||
depth = self.depths[img_idx][(pixels_y, pixels_x)] # batch_size, 1
|
||||
|
||||
p = torch.stack(
|
||||
[pixels_x, pixels_y, torch.ones_like(pixels_y)],
|
||||
dim=-1).float() # batch_size, 3
|
||||
p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3],
|
||||
p[:, :, None]).squeeze() # batch_size, 3
|
||||
rays_v = p / torch.linalg.norm(
|
||||
p, ord=2, dim=-1, keepdim=True) # batch_size, 3
|
||||
rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3],
|
||||
rays_v[:, :, None]).squeeze() # batch_size, 3
|
||||
rays_o = self.pose_all[img_idx, None, :3,
|
||||
3].expand(rays_v.shape) # batch_size, 3
|
||||
return torch.cat(
|
||||
[rays_o.cpu(),
|
||||
rays_v.cpu(), color, mask[:, :1], depth[:, None]],
|
||||
dim=-1).cuda() # batch_size, 10
|
||||
|
||||
def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1):
|
||||
"""
|
||||
Interpolate pose between two cameras.
|
||||
"""
|
||||
level = resolution_level
|
||||
tx = torch.linspace(0, self.W - 1, self.W // level)
|
||||
ty = torch.linspace(0, self.H - 1, self.H // level)
|
||||
pixels_x, pixels_y = torch.meshgrid(tx, ty)
|
||||
p = torch.stack(
|
||||
[pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
||||
p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3],
|
||||
p[:, :, :, None]).squeeze() # W, H, 3
|
||||
rays_v = p / torch.linalg.norm(
|
||||
p, ord=2, dim=-1, keepdim=True) # W, H, 3
|
||||
trans = self.pose_all[idx_0, :3, 3] * (
|
||||
1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio
|
||||
pose_0 = self.pose_all[idx_0].detach().cpu().numpy()
|
||||
pose_1 = self.pose_all[idx_1].detach().cpu().numpy()
|
||||
pose_0 = np.linalg.inv(pose_0)
|
||||
pose_1 = np.linalg.inv(pose_1)
|
||||
rot_0 = pose_0[:3, :3]
|
||||
rot_1 = pose_1[:3, :3]
|
||||
rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
|
||||
key_times = [0, 1]
|
||||
slerp = Slerp(key_times, rots)
|
||||
rot = slerp(ratio)
|
||||
pose = np.diag([1.0, 1.0, 1.0, 1.0])
|
||||
pose = pose.astype(np.float32)
|
||||
pose[:3, :3] = rot.as_matrix()
|
||||
pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
|
||||
pose = np.linalg.inv(pose)
|
||||
rot = torch.from_numpy(pose[:3, :3]).cuda()
|
||||
trans = torch.from_numpy(pose[:3, 3]).cuda()
|
||||
rays_v = torch.matmul(rot[None, None, :3, :3],
|
||||
rays_v[:, :, :, None]).squeeze() # W, H, 3
|
||||
rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3
|
||||
return rays_o.transpose(0, 1), rays_v.transpose(0, 1), pose
|
||||
|
||||
def gen_rays_across(self, idx_0, idx_1, ratio, resolution_level=1):
|
||||
"""
|
||||
Interpolate pose between two cameras.
|
||||
"""
|
||||
level = resolution_level
|
||||
tx = torch.linspace(0, self.W - 1, self.W // level)
|
||||
ty = torch.linspace(0, self.H - 1, self.H // level)
|
||||
pixels_x, pixels_y = torch.meshgrid(tx, ty)
|
||||
p = torch.stack(
|
||||
[pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
||||
p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3],
|
||||
p[:, :, :, None]).squeeze() # W, H, 3
|
||||
rays_v = p / torch.linalg.norm(
|
||||
p, ord=2, dim=-1, keepdim=True) # W, H, 3
|
||||
trans = self.pose_all[idx_0, :3, 3] * (
|
||||
1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio
|
||||
pose_0 = self.pose_all[idx_0].detach().cpu().numpy()
|
||||
pose_1 = self.pose_all[idx_1].detach().cpu().numpy()
|
||||
pose_0 = np.linalg.inv(pose_0)
|
||||
pose_1 = np.linalg.inv(pose_1)
|
||||
rot_0 = pose_0[:3, :3]
|
||||
rot_1 = pose_1[:3, :3]
|
||||
rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
|
||||
key_times = [0, 1]
|
||||
slerp = Slerp(key_times, rots)
|
||||
rot = slerp(ratio)
|
||||
pose = np.diag([1.0, 1.0, 1.0, 1.0])
|
||||
pose = pose.astype(np.float32)
|
||||
pose[:3, :3] = rot.as_matrix()
|
||||
pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
|
||||
pose = np.linalg.inv(pose)
|
||||
rot = torch.from_numpy(pose[:3, :3]).cuda()
|
||||
trans = torch.from_numpy(pose[:3, 3]).cuda()
|
||||
rays_v = torch.matmul(rot[None, None, :3, :3],
|
||||
rays_v[:, :, :, None]).squeeze() # W, H, 3
|
||||
rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3
|
||||
return rays_o.transpose(0, 1), rays_v.transpose(0, 1), pose
|
||||
|
||||
def near_far_from_sphere(self, rays_o, rays_d):
|
||||
a = torch.sum(rays_d**2, dim=-1, keepdim=True)
|
||||
b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
|
||||
mid = 0.5 * (-b) / a
|
||||
near = mid - 1.0
|
||||
far = mid + 1.0
|
||||
return near, far
|
||||
|
||||
def image_at(self, idx, resolution_level):
|
||||
img = cv.imread(self.images_lis[idx])
|
||||
return (cv.resize(img, (self.W // resolution_level,
|
||||
self.H // resolution_level))).clip(0, 255)
|
||||
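A quick sanity check of near_far_from_sphere above: mid is the ray parameter of the point closest to the origin, and near/far bracket the unit sphere at mid ± 1 (numbers synthetic):

import torch

rays_o = torch.tensor([[0.0, 0.0, 3.0]])
rays_d = torch.tensor([[0.0, 0.0, -1.0]])
a = torch.sum(rays_d ** 2, dim=-1, keepdim=True)
b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
mid = 0.5 * (-b) / a               # closest approach to the origin at t = 3.0
near, far = mid - 1.0, mid + 1.0   # 2.0 and 4.0: where this ray enters/exits the unit sphere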
modelscope/models/cv/surface_recon_common/fields.py (new file)
@@ -0,0 +1,390 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from https://github.com/Totoro97/NeuS/blob/main/models/fields.py
|
||||
# Copyright (c) 2021 Peng Wang. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
|
||||
|
||||
class SDFNetwork(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
d_in,
|
||||
d_out,
|
||||
d_hidden,
|
||||
n_layers,
|
||||
skip_in=(4, ),
|
||||
multires=0,
|
||||
bias=0.5,
|
||||
scale=1,
|
||||
geometric_init=True,
|
||||
weight_norm=True,
|
||||
inside_outside=False):
|
||||
super(SDFNetwork, self).__init__()
|
||||
|
||||
dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
|
||||
|
||||
self.embed_fn_fine = None
|
||||
|
||||
if multires > 0:
|
||||
embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
|
||||
self.embed_fn_fine = embed_fn
|
||||
dims[0] = input_ch
|
||||
|
||||
self.num_layers = len(dims)
|
||||
self.skip_in = skip_in
|
||||
self.scale = scale
|
||||
|
||||
for layer in range(0, self.num_layers - 1):
|
||||
if layer + 1 in self.skip_in:
|
||||
out_dim = dims[layer + 1] - dims[0]
|
||||
else:
|
||||
out_dim = dims[layer + 1]
|
||||
|
||||
lin = nn.Linear(dims[layer], out_dim)
|
||||
|
||||
if geometric_init:
|
||||
if layer == self.num_layers - 2:
|
||||
if not inside_outside:
|
||||
torch.nn.init.normal_(
|
||||
lin.weight,
|
||||
mean=np.sqrt(np.pi) / np.sqrt(dims[layer]),
|
||||
std=0.0001)
|
||||
torch.nn.init.constant_(lin.bias, -bias)
|
||||
else:
|
||||
torch.nn.init.normal_(
|
||||
lin.weight,
|
||||
mean=-np.sqrt(np.pi) / np.sqrt(dims[layer]),
|
||||
std=0.0001)
|
||||
torch.nn.init.constant_(lin.bias, bias)
|
||||
elif multires > 0 and layer == 0:
|
||||
torch.nn.init.constant_(lin.bias, 0.0)
|
||||
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
|
||||
torch.nn.init.normal_(lin.weight[:, :3], 0.0,
|
||||
np.sqrt(2) / np.sqrt(out_dim))
|
||||
elif multires > 0 and layer in self.skip_in:
|
||||
torch.nn.init.constant_(lin.bias, 0.0)
|
||||
torch.nn.init.normal_(lin.weight, 0.0,
|
||||
np.sqrt(2) / np.sqrt(out_dim))
|
||||
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):],
|
||||
0.0)
|
||||
else:
|
||||
torch.nn.init.constant_(lin.bias, 0.0)
|
||||
torch.nn.init.normal_(lin.weight, 0.0,
|
||||
np.sqrt(2) / np.sqrt(out_dim))
|
||||
|
||||
if weight_norm:
|
||||
lin = nn.utils.weight_norm(lin)
|
||||
|
||||
setattr(self, 'lin' + str(layer), lin)
|
||||
|
||||
self.activation = nn.Softplus(beta=100)
|
||||
|
||||
def forward(self, inputs):
|
||||
inputs = inputs * self.scale
|
||||
|
||||
if self.embed_fn_fine is not None:
|
||||
inputs = self.embed_fn_fine(inputs)
|
||||
|
||||
x = inputs
|
||||
for layer in range(0, self.num_layers - 1):
|
||||
lin = getattr(self, 'lin' + str(layer))
|
||||
|
||||
if layer in self.skip_in:
|
||||
x = torch.cat([x, inputs], 1) / np.sqrt(2)
|
||||
|
||||
x = lin(x)
|
||||
|
||||
if layer < self.num_layers - 2:
|
||||
x = self.activation(x)
|
||||
return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1)
|
||||
|
||||
def sdf(self, x):
|
||||
return self.forward(x)[:, :1]
|
||||
|
||||
def sdf_hidden_appearance(self, x):
|
||||
return self.forward(x)
|
||||
|
||||
def gradient(self, x):
|
||||
x.requires_grad_(True)
|
||||
with torch.enable_grad():
|
||||
y = self.sdf(x)
|
||||
|
||||
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
||||
gradients = torch.autograd.grad(
|
||||
outputs=y,
|
||||
inputs=x,
|
||||
grad_outputs=d_output,
|
||||
create_graph=True,
|
||||
retain_graph=True,
|
||||
only_inputs=True)[0]
|
||||
return gradients.unsqueeze(1)
|
||||
|
||||
|
||||
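A short sketch of querying the SDF network above; the hyper-parameters are illustrative (roughly NeuS-style), not the ones shipped with any particular model:

import torch

sdf_net = SDFNetwork(
    d_in=3, d_out=257, d_hidden=256, n_layers=8, skip_in=(4, ), multires=6)
pts = torch.rand(128, 3) * 2.0 - 1.0                    # query points inside [-1, 1]^3
sdf_vals = sdf_net.sdf(pts)                             # (128, 1) signed distances
features = sdf_net.sdf_hidden_appearance(pts)[:, 1:]    # (128, 256) appearance features
normals = sdf_net.gradient(pts).squeeze(1)              # (128, 3) analytic gradients via autograd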
class RenderingNetwork(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
d_feature,
|
||||
mode,
|
||||
d_in,
|
||||
d_out,
|
||||
d_hidden,
|
||||
n_layers,
|
||||
weight_norm=True,
|
||||
multires_view=0,
|
||||
squeeze_out=True):
|
||||
super().__init__()
|
||||
|
||||
self.mode = mode
|
||||
self.squeeze_out = squeeze_out
|
||||
dims = [d_in + d_feature] + [d_hidden
|
||||
for _ in range(n_layers)] + [d_out]
|
||||
|
||||
self.embedview_fn = None
|
||||
if multires_view > 0:
|
||||
embedview_fn, input_ch = get_embedder(multires_view)
|
||||
self.embedview_fn = embedview_fn
|
||||
dims[0] += (input_ch - 3)
|
||||
|
||||
self.num_layers = len(dims)
|
||||
|
||||
for layer in range(0, self.num_layers - 1):
|
||||
out_dim = dims[layer + 1]
|
||||
lin = nn.Linear(dims[layer], out_dim)
|
||||
|
||||
if weight_norm:
|
||||
lin = nn.utils.weight_norm(lin)
|
||||
|
||||
setattr(self, 'lin' + str(layer), lin)
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, points, normals, view_dirs, feature_vectors):
|
||||
if self.embedview_fn is not None:
|
||||
view_dirs = self.embedview_fn(view_dirs)
|
||||
|
||||
rendering_input = None
|
||||
|
||||
if self.mode == 'idr':
|
||||
rendering_input = torch.cat(
|
||||
[points, view_dirs, normals, feature_vectors], dim=-1)
|
||||
elif self.mode == 'no_view_dir':
|
||||
rendering_input = torch.cat([points, normals, feature_vectors],
|
||||
dim=-1)
|
||||
elif self.mode == 'no_normal':
|
||||
rendering_input = torch.cat([points, view_dirs, feature_vectors],
|
||||
dim=-1)
|
||||
|
||||
x = rendering_input
|
||||
|
||||
for layer in range(0, self.num_layers - 1):
|
||||
lin = getattr(self, 'lin' + str(layer))
|
||||
x = lin(x)
|
||||
if layer < self.num_layers - 2:
|
||||
x = self.relu(x)
|
||||
if self.squeeze_out:
|
||||
x = torch.sigmoid(x)
|
||||
return x
|
||||
|
||||
|
||||
class NeRF(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
D=8,
|
||||
W=256,
|
||||
d_in=3,
|
||||
d_in_view=3,
|
||||
multires=0,
|
||||
multires_view=0,
|
||||
output_ch=4,
|
||||
skips=[4],
|
||||
use_viewdirs=False):
|
||||
super(NeRF, self).__init__()
|
||||
self.D = D
|
||||
self.W = W
|
||||
self.d_in = d_in
|
||||
self.d_in_view = d_in_view
|
||||
self.input_ch = 3
|
||||
self.input_ch_view = 3
|
||||
self.embed_fn = None
|
||||
self.embed_fn_view = None
|
||||
|
||||
if multires > 0:
|
||||
embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
|
||||
self.embed_fn = embed_fn
|
||||
self.input_ch = input_ch
|
||||
|
||||
if multires_view > 0:
|
||||
embed_fn_view, input_ch_view = get_embedder(
|
||||
multires_view, input_dims=d_in_view)
|
||||
self.embed_fn_view = embed_fn_view
|
||||
self.input_ch_view = input_ch_view
|
||||
|
||||
self.skips = skips
|
||||
self.use_viewdirs = use_viewdirs
|
||||
|
||||
self.pts_linears = nn.ModuleList([nn.Linear(self.input_ch, W)] + [
|
||||
nn.Linear(W, W) if i not in
|
||||
self.skips else nn.Linear(W + self.input_ch, W)
|
||||
for i in range(D - 1)
|
||||
])
|
||||
|
||||
self.views_linears = nn.ModuleList(
|
||||
[nn.Linear(self.input_ch_view + W, W // 2)])
|
||||
|
||||
if use_viewdirs:
|
||||
self.feature_linear = nn.Linear(W, W)
|
||||
self.alpha_linear = nn.Linear(W, 1)
|
||||
self.rgb_linear = nn.Linear(W // 2, 3)
|
||||
else:
|
||||
self.output_linear = nn.Linear(W, output_ch)
|
||||
|
||||
def forward(self, input_pts, input_views):
|
||||
if self.embed_fn is not None:
|
||||
input_pts = self.embed_fn(input_pts)
|
||||
if self.embed_fn_view is not None:
|
||||
input_views = self.embed_fn_view(input_views)
|
||||
|
||||
h = input_pts
|
||||
for i, l in enumerate(self.pts_linears):
|
||||
h = self.pts_linears[i](h)
|
||||
h = F.relu(h)
|
||||
if i in self.skips:
|
||||
h = torch.cat([input_pts, h], -1)
|
||||
|
||||
if self.use_viewdirs:
|
||||
alpha = self.alpha_linear(h)
|
||||
feature = self.feature_linear(h)
|
||||
h = torch.cat([feature, input_views], -1)
|
||||
|
||||
for i, l in enumerate(self.views_linears):
|
||||
h = self.views_linears[i](h)
|
||||
h = F.relu(h)
|
||||
|
||||
rgb = self.rgb_linear(h)
|
||||
return alpha, rgb
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
class SingleVarianceNetwork(nn.Module):
|
||||
|
||||
def __init__(self, init_val):
|
||||
super(SingleVarianceNetwork, self).__init__()
|
||||
self.register_parameter('variance',
|
||||
nn.Parameter(torch.tensor(init_val)))
|
||||
|
||||
def forward(self, x):
|
||||
return torch.ones([len(x), 1],
|
||||
device=self.variance.device) * torch.exp(
|
||||
self.variance * 10.0)
|
||||
|
||||
|
||||
class Mean(nn.Module):
|
||||
|
||||
def __init__(self, dim: list, keepdim=False):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.keepdim = keepdim
|
||||
|
||||
def forward(self, x):
|
||||
return torch.mean(x, self.dim, self.keepdim)
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
|
||||
def __init__(self, channel=32, patch=True):
|
||||
super().__init__()
|
||||
self.imsize = 32
|
||||
self.nc = 3
|
||||
|
||||
self.channel = channel
|
||||
self.patch = patch
|
||||
in_channel = 3
|
||||
layer = []
|
||||
for idx in range(3):
|
||||
layer.extend([
|
||||
spectral_norm(
|
||||
nn.Conv2d(
|
||||
in_channel, channel * (2**idx), 3, stride=2,
|
||||
padding=1)),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
spectral_norm(
|
||||
nn.Conv2d(
|
||||
channel * (2**idx),
|
||||
channel * (2**idx),
|
||||
3,
|
||||
stride=1,
|
||||
padding=1)),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
])
|
||||
in_channel = channel * (2**idx)
|
||||
self.body = nn.Sequential(*layer)
|
||||
if self.patch:
|
||||
self.head = spectral_norm(nn.Conv2d(in_channel, 1, 1, padding=0))
|
||||
else:
|
||||
self.head = nn.Sequential(Mean([1, 2]), nn.Linear(in_channel, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = x[:, :self.nc]
|
||||
x = x.view(-1, self.imsize, self.imsize, self.nc).permute(0, 3, 1, 2)
|
||||
x = self.body(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
class Embedder:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.create_embedding_fn()
|
||||
|
||||
def create_embedding_fn(self):
|
||||
embed_fns = []
|
||||
d = self.kwargs['input_dims']
|
||||
out_dim = 0
|
||||
if self.kwargs['include_input']:
|
||||
embed_fns.append(lambda x: x)
|
||||
out_dim += d
|
||||
|
||||
max_freq = self.kwargs['max_freq_log2']
|
||||
N_freqs = self.kwargs['num_freqs']
|
||||
|
||||
if self.kwargs['log_sampling']:
|
||||
freq_bands = 2.**torch.linspace(0., max_freq, N_freqs)
|
||||
else:
|
||||
freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
|
||||
|
||||
for freq in freq_bands:
|
||||
for p_fn in self.kwargs['periodic_fns']:
|
||||
embed_fns.append(
|
||||
lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
|
||||
out_dim += d
|
||||
|
||||
self.embed_fns = embed_fns
|
||||
self.out_dim = out_dim
|
||||
|
||||
def embed(self, inputs):
|
||||
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
|
||||
|
||||
|
||||
def get_embedder(multires, input_dims=3):
|
||||
embed_kwargs = {
|
||||
'include_input': True,
|
||||
'input_dims': input_dims,
|
||||
'max_freq_log2': multires - 1,
|
||||
'num_freqs': multires,
|
||||
'log_sampling': True,
|
||||
'periodic_fns': [torch.sin, torch.cos],
|
||||
}
|
||||
|
||||
embedder_obj = Embedder(**embed_kwargs)
|
||||
|
||||
def embed(x, eo=embedder_obj):
|
||||
return eo.embed(x)
|
||||
|
||||
return embed, embedder_obj.out_dim
|
||||
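A quick shape check for the positional embedding helpers above: with include_input and multires = 6 on 3-D points, the output dimension is 3 + 3 * 2 * 6 = 39.

import torch

embed_fn, out_dim = get_embedder(6, input_dims=3)
x = torch.rand(1024, 3)
print(out_dim, embed_fn(x).shape)   # 39 torch.Size([1024, 39])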
modelscope/models/cv/surface_recon_common/renderer.py (new file)
@@ -0,0 +1,388 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from https://github.com/Totoro97/NeuS/blob/main/models/renderer.py
|
||||
# Copyright (c) 2021 Peng Wang.
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .fields import RenderingNetwork, SDFNetwork, SingleVarianceNetwork
|
||||
from .utils import extract_geometry, sample_pdf
|
||||
|
||||
|
||||
class SurfaceRenderer(nn.Module):
|
||||
|
||||
def __init__(self, conf, device):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
self.device = device
|
||||
self.sdf_network = SDFNetwork(**self.conf['sdf_network']).to(
|
||||
self.device)
|
||||
self.variance_network = SingleVarianceNetwork(
|
||||
**self.conf['variance_network']).to(self.device)
|
||||
self.color_network = RenderingNetwork(
|
||||
**self.conf['rendering_network']).to(self.device)
|
||||
self.light_network = RenderingNetwork(**self.conf['light_network']).to(
|
||||
self.device)
|
||||
self.n_samples = self.conf['neus_renderer']['n_samples']
|
||||
self.n_importance = self.conf['neus_renderer']['n_importance']
|
||||
self.n_outside = self.conf['neus_renderer']['n_outside']
|
||||
self.up_sample_steps = self.conf['neus_renderer']['up_sample_steps']
|
||||
self.perturb = self.conf['neus_renderer']['perturb']
|
||||
|
||||
def extract_geometry(self,
|
||||
bound_min,
|
||||
bound_max,
|
||||
resolution,
|
||||
threshold=0.0,
|
||||
device='cuda'):
|
||||
return extract_geometry(
|
||||
bound_min,
|
||||
bound_max,
|
||||
resolution=resolution,
|
||||
threshold=threshold,
|
||||
query_func=lambda pts: -self.sdf_network.sdf(pts),
|
||||
device=device)
|
||||
|
||||
def render_core_outside(self,
|
||||
rays_o,
|
||||
rays_d,
|
||||
z_vals,
|
||||
sample_dist,
|
||||
nerf,
|
||||
background_rgb=None):
|
||||
batch_size, n_samples = z_vals.shape
|
||||
|
||||
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
||||
dists = torch.cat(
|
||||
[dists,
|
||||
torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
|
||||
mid_z_vals = z_vals + dists * 0.5
|
||||
|
||||
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[
|
||||
..., :, None] # batch_size, n_samples, 3
|
||||
|
||||
dis_to_center = torch.linalg.norm(
|
||||
pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
|
||||
pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center],
|
||||
dim=-1) # batch_size, n_samples, 4
|
||||
|
||||
dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
|
||||
|
||||
pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
|
||||
dirs = dirs.reshape(-1, 3)
|
||||
|
||||
density, sampled_color = nerf(pts, dirs)
|
||||
alpha = 1.0 - torch.exp(
|
||||
-F.softplus(density.reshape(batch_size, n_samples)) * dists)
|
||||
alpha = alpha.reshape(batch_size, n_samples)
|
||||
weights = alpha * torch.cumprod(
|
||||
torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1),
|
||||
-1)[:, :-1]
|
||||
sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
|
||||
color = (weights[:, :, None] * sampled_color).sum(dim=1)
|
||||
if background_rgb is not None:
|
||||
color = color + background_rgb * (
|
||||
1.0 - weights.sum(dim=-1, keepdim=True))
|
||||
|
||||
return {
|
||||
'color': color,
|
||||
'sampled_color': sampled_color,
|
||||
'alpha': alpha,
|
||||
'weights': weights,
|
||||
}
|
||||
|
||||
def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
|
||||
batch_size, n_samples = z_vals.shape
|
||||
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
|
||||
radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
|
||||
inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
|
||||
sdf = sdf.reshape(batch_size, n_samples)
|
||||
prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
|
||||
prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
|
||||
mid_sdf = (prev_sdf + next_sdf) * 0.5
|
||||
cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
|
||||
|
||||
prev_cos_val = torch.cat(
|
||||
[torch.zeros([batch_size, 1]).to(self.device), cos_val[:, :-1]],
|
||||
dim=-1)
|
||||
cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
|
||||
cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
|
||||
cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
|
||||
|
||||
dist = (next_z_vals - prev_z_vals)
|
||||
prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
|
||||
next_esti_sdf = mid_sdf + cos_val * dist * 0.5
|
||||
prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
|
||||
next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
|
||||
alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
|
||||
weights = alpha * torch.cumprod(
|
||||
torch.cat([
|
||||
torch.ones([batch_size, 1]).to(self.device), 1. - alpha + 1e-7
|
||||
], -1), -1)[:, :-1]
|
||||
|
||||
z_samples = sample_pdf(
|
||||
z_vals, weights, n_importance, det=True,
|
||||
device=self.device).detach()
|
||||
return z_samples
|
||||
|
||||
def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
|
||||
batch_size, n_samples = z_vals.shape
|
||||
_, n_importance = new_z_vals.shape
|
||||
pts = rays_o[:,
|
||||
None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
|
||||
z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
|
||||
z_vals, index = torch.sort(z_vals, dim=-1)
|
||||
|
||||
if not last:
|
||||
new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(
|
||||
batch_size, n_importance)
|
||||
sdf = torch.cat([sdf, new_sdf], dim=-1)
|
||||
xx = torch.arange(batch_size)[:, None].expand(
|
||||
batch_size, n_samples + n_importance).reshape(-1)
|
||||
index = index.reshape(-1)
|
||||
sdf = sdf[(xx, index)].reshape(batch_size,
|
||||
n_samples + n_importance)
|
||||
|
||||
return z_vals, sdf
|
||||
|
||||
def render_core(self,
|
||||
rays_o,
|
||||
rays_d,
|
||||
z_vals,
|
||||
sample_dist,
|
||||
sdf_network,
|
||||
deviation_network,
|
||||
color_network,
|
||||
light_network,
|
||||
depth_z=None,
|
||||
background_alpha=None,
|
||||
bg_sampled_color=None,
|
||||
background_rgb=None,
|
||||
cos_anneal_ratio=0.0):
|
||||
batch_size, n_samples = z_vals.shape
|
||||
|
||||
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
||||
dists = torch.cat([
|
||||
dists,
|
||||
torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(
|
||||
self.device)
|
||||
], -1)
|
||||
mid_z_vals = z_vals + dists * 0.5
|
||||
|
||||
pts = rays_o[:,
|
||||
None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]
|
||||
dirs = rays_d[:, None, :].expand(pts.shape)
|
||||
|
||||
pts = pts.reshape(-1, 3)
|
||||
dirs = dirs.reshape(-1, 3)
|
||||
|
||||
sdf_nn_output = sdf_network(pts)
|
||||
sdf = sdf_nn_output[:, :1]
|
||||
feature_vector = sdf_nn_output[:, 1:]
|
||||
|
||||
gradients = sdf_network.gradient(pts).squeeze()
|
||||
sampled_albedo = color_network(pts, gradients, dirs,
|
||||
feature_vector).reshape(
|
||||
batch_size, n_samples, 3)
|
||||
sampled_light = light_network(pts, gradients, dirs,
|
||||
feature_vector).reshape(
|
||||
batch_size, n_samples, 3)
|
||||
sampled_color = sampled_albedo * sampled_light
|
||||
|
||||
inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6)
|
||||
inv_s = inv_s.expand(batch_size * n_samples, 1)
|
||||
|
||||
true_cos = (dirs * gradients).sum(-1, keepdim=True)
|
||||
iter_cos_p1 = F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio)
|
||||
iter_cos = -(iter_cos_p1 + F.relu(-true_cos) * cos_anneal_ratio)
|
||||
|
||||
estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
|
||||
estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
|
||||
|
||||
prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
|
||||
next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
|
||||
|
||||
p = prev_cdf - next_cdf
|
||||
c = prev_cdf
|
||||
|
||||
alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size,
|
||||
n_samples).clip(0.0, 1.0)
|
||||
|
||||
pts_norm = torch.linalg.norm(
|
||||
pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
|
||||
inside_sphere = (pts_norm < 1.0).float().detach()
|
||||
relax_inside_sphere = (pts_norm < 1.2).float().detach()
|
||||
|
||||
if background_alpha is not None:
|
||||
alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (
|
||||
1.0 - inside_sphere)
|
||||
alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
|
||||
foreground_color = sampled_color * inside_sphere[:, :, None]
|
||||
background_color = bg_sampled_color[:, :n_samples] * (
|
||||
1.0 - inside_sphere)[:, :, None]
|
||||
sampled_color = foreground_color + background_color
|
||||
|
||||
sampled_color = torch.cat(
|
||||
[sampled_color, bg_sampled_color[:, n_samples:]], dim=1)
|
||||
|
||||
beta = torch.cat([
|
||||
torch.ones([batch_size, 1], device=alpha.device), 1. - alpha + 1e-7
|
||||
], -1)
|
||||
weights = alpha * torch.cumprod(beta, -1)[:, :-1]
|
||||
weights_sum = weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
color = (sampled_color * weights[:, :, None]).sum(dim=1)
|
||||
if background_rgb is not None:
|
||||
color = color + background_rgb * (1.0 - weights_sum)
|
||||
|
||||
albedo = (sampled_albedo * weights[:, :, None]).sum(dim=1)
|
||||
|
||||
depth = (mid_z_vals * weights).sum(dim=1)
|
||||
if depth_z is not None:
|
||||
pts_depth = rays_o[:, None, :] + rays_d[:, None, :] * depth_z[
|
||||
..., :, None] # n_rays, n_samples, 3
|
||||
pts_depth = pts_depth.reshape(-1, 3)
|
||||
sdf_depth = sdf_network(pts_depth)[:, :1]
|
||||
else:
|
||||
sdf_depth = None
|
||||
|
||||
gradients_norm = torch.linalg.norm(
|
||||
gradients.reshape(batch_size, n_samples, 3), ord=2, dim=-1)
|
||||
gradient_error = (gradients_norm - 1.0)**2
|
||||
gradient_error = (relax_inside_sphere * gradient_error).sum()
|
||||
gradient_error = gradient_error / (relax_inside_sphere.sum() + 1e-5)
|
||||
|
||||
return {
|
||||
'color': color,
|
||||
'albedo': albedo,
|
||||
'depth': depth,
|
||||
'sdf': sdf,
|
||||
'sdf_depth': sdf_depth,
|
||||
'dists': dists,
|
||||
'gradients': gradients.reshape(batch_size, n_samples, 3),
|
||||
's_val': 1.0 / inv_s,
|
||||
'mid_z_vals': mid_z_vals,
|
||||
'weights': weights,
|
||||
'cdf': c.reshape(batch_size, n_samples),
|
||||
'gradient_error': gradient_error,
|
||||
'inside_sphere': inside_sphere
|
||||
}
|
||||
|
||||
def render(self,
|
||||
rays_o,
|
||||
rays_d,
|
||||
near,
|
||||
far,
|
||||
depth_z=None,
|
||||
perturb_overwrite=-1,
|
||||
background_rgb=None,
|
||||
cos_anneal_ratio=0.0):
|
||||
batch_size = len(rays_o)
|
||||
sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere
|
||||
z_vals = torch.linspace(0.0, 1.0, self.n_samples).to(self.device)
|
||||
z_vals = near + (far - near) * z_vals[None, :]
|
||||
|
||||
z_vals_outside = None
|
||||
if self.n_outside > 0:
|
||||
z_vals_end = 1.0 - 1.0 / (self.n_outside + 1.0)
|
||||
z_vals_outside = torch.linspace(1e-3, z_vals_end, self.n_outside)
|
||||
|
||||
n_samples = self.n_samples
|
||||
perturb = self.perturb
|
||||
|
||||
if perturb_overwrite >= 0:
|
||||
perturb = perturb_overwrite
|
||||
if perturb > 0:
|
||||
t_rand = (torch.rand([batch_size, 1]).to(self.device) - 0.5)
|
||||
z_vals = z_vals + t_rand * 2.0 / self.n_samples
|
||||
|
||||
if self.n_outside > 0:
|
||||
mids = .5 * (
|
||||
z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
|
||||
upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
|
||||
lower = torch.cat([z_vals_outside[..., :1], mids], -1)
|
||||
t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
|
||||
z_vals_outside = lower[None, :] + (upper
|
||||
- lower)[None, :] * t_rand
|
||||
|
||||
if self.n_outside > 0:
|
||||
z_vals_outside = far / torch.flip(
|
||||
z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
|
||||
|
||||
background_alpha = None
|
||||
background_sampled_color = None
|
||||
|
||||
# Up sample
|
||||
if self.n_importance > 0:
|
||||
with torch.no_grad():
|
||||
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :,
|
||||
None]
|
||||
sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(
|
||||
batch_size, self.n_samples)
|
||||
|
||||
for i in range(self.up_sample_steps):
|
||||
new_z_vals = self.up_sample(
|
||||
rays_o, rays_d, z_vals, sdf,
|
||||
self.n_importance // self.up_sample_steps, 64 * 2**i)
|
||||
z_vals, sdf = self.cat_z_vals(
|
||||
rays_o,
|
||||
rays_d,
|
||||
z_vals,
|
||||
new_z_vals,
|
||||
sdf,
|
||||
last=(i + 1 == self.up_sample_steps))
|
||||
|
||||
n_samples = self.n_samples + self.n_importance
|
||||
|
||||
if self.n_outside > 0:
|
||||
z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
|
||||
z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
|
||||
ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed,
|
||||
sample_dist, self.nerf)
|
||||
|
||||
background_sampled_color = ret_outside['sampled_color']
|
||||
background_alpha = ret_outside['alpha']
|
||||
|
||||
ret_fine = self.render_core(
|
||||
rays_o,
|
||||
rays_d,
|
||||
z_vals,
|
||||
sample_dist,
|
||||
self.sdf_network,
|
||||
self.variance_network,
|
||||
self.color_network,
|
||||
self.light_network,
|
||||
depth_z=depth_z,
|
||||
background_rgb=background_rgb,
|
||||
background_alpha=background_alpha,
|
||||
bg_sampled_color=background_sampled_color,
|
||||
cos_anneal_ratio=cos_anneal_ratio)
|
||||
|
||||
color_fine = ret_fine['color']
|
||||
albedo_fine = ret_fine['albedo']
|
||||
depth_fine = ret_fine['depth']
|
||||
sdf_depth = ret_fine['sdf_depth']
|
||||
weights = ret_fine['weights']
|
||||
weights_sum = weights.sum(dim=-1, keepdim=True)
|
||||
gradients = ret_fine['gradients']
|
||||
s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(
|
||||
dim=-1, keepdim=True)
|
||||
|
||||
return {
|
||||
'color_fine': color_fine,
|
||||
'albedo_fine': albedo_fine,
|
||||
'depth_fine': depth_fine,
|
||||
'sdf_depth': sdf_depth,
|
||||
's_val': s_val,
|
||||
'cdf_fine': ret_fine['cdf'],
|
||||
'weight_sum': weights_sum,
|
||||
'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
|
||||
'gradients': gradients,
|
||||
'weights': weights,
|
||||
'mid_z_vals': ret_fine['mid_z_vals'],
|
||||
'gradient_error': ret_fine['gradient_error'],
|
||||
'inside_sphere': ret_fine['inside_sphere']
|
||||
}
|
||||
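A standalone check of the alpha computation used in render_core above: alpha is the normalised drop of the logistic CDF of the SDF across a ray segment, so it is large where the SDF crosses zero and near zero elsewhere (numbers synthetic):

import torch

inv_s = torch.tensor(64.0)                                     # sharpness, exp(10 * variance)
prev_sdf, next_sdf = torch.tensor(0.05), torch.tensor(-0.05)   # segment straddles the surface
prev_cdf = torch.sigmoid(prev_sdf * inv_s)
next_cdf = torch.sigmoid(next_sdf * inv_s)
alpha = ((prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)).clamp(0.0, 1.0)
print(alpha)   # ~0.96 here; a segment far from the surface gives alpha close to 0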
@@ -0,0 +1,160 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import trimesh
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Tensor, TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .dataset import Dataset
|
||||
from .renderer import SurfaceRenderer
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = ['SurfaceReconCommon']
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.surface_recon_common, module_name=Models.surface_recon_common)
|
||||
class SurfaceReconCommon(TorchModel):
|
||||
|
||||
def __init__(self, model_dir, network_cfg, **kwargs):
|
||||
"""initialize the surface reconstruction model for common objects.
|
||||
|
||||
Args:
|
||||
model_dir (str): the model path.
|
||||
network_cfg (dict): args of network config
|
||||
"""
|
||||
super().__init__(model_dir, **kwargs)
|
||||
logger.info('model params:{}'.format(kwargs))
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.device = torch.device('cuda')
|
||||
else:
|
||||
raise Exception('GPU is required')
|
||||
|
||||
logger.info(network_cfg)
|
||||
|
||||
self.renderer = SurfaceRenderer(network_cfg, device=self.device)
|
||||
self.ckpt_path = os.path.join(model_dir, 'model.pth')
|
||||
if not os.path.exists(self.ckpt_path):
|
||||
raise Exception('model path not found')
|
||||
self.load_checkpoint(self.ckpt_path)
|
||||
logger.info('load models from %s' % self.ckpt_path)
|
||||
|
||||
self.n_rays = network_cfg['n_rays']
|
||||
|
||||
def load_checkpoint(self, ckpt_path):
|
||||
checkpoint = torch.load(ckpt_path, map_location=self.device)
|
||||
for name, module in self.renderer.named_modules():
|
||||
saved_name = name + '_fine'
|
||||
if saved_name in checkpoint:
|
||||
module.load_state_dict(checkpoint[saved_name])
|
||||
|
||||
def surface_reconstruction(self,
|
||||
data_dir,
|
||||
save_dir,
|
||||
color=False,
|
||||
n_directions=8):
|
||||
|
||||
self.dataset = Dataset(data_dir, self.device)
|
||||
|
||||
bound_min = torch.tensor(
|
||||
self.dataset.object_bbox_min, dtype=torch.float32).to(self.device)
|
||||
bound_max = torch.tensor(
|
||||
self.dataset.object_bbox_max, dtype=torch.float32).to(self.device)
|
||||
|
||||
vertices, triangles = \
|
||||
self.renderer.extract_geometry(bound_min, bound_max, resolution=512, threshold=0.0,
|
||||
device=self.device)
|
||||
if color:
|
||||
pt_vertices = torch.from_numpy(vertices).cuda().reshape(-1, 1,
|
||||
3).float()
|
||||
idx_list = np.linspace(
|
||||
0,
|
||||
self.dataset.n_images,
|
||||
n_directions,
|
||||
endpoint=False,
|
||||
dtype=int)
|
||||
rays_o_list = []
|
||||
for idx in idx_list:
|
||||
rays_o = self.dataset.pose_all[idx, :3, 3]
|
||||
rays_o_list.append(rays_o)
|
||||
|
||||
rgb_final = None
|
||||
diff_final = None
|
||||
for rays_o in rays_o_list:
|
||||
rays_o = rays_o.reshape(1, 3).repeat(vertices.shape[0],
|
||||
1).float()
|
||||
|
||||
rays_d = pt_vertices.reshape(-1, 3) - rays_o
|
||||
rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1)
|
||||
dist = torch.norm(pt_vertices.reshape(-1, 3) - rays_o, dim=-1)
|
||||
|
||||
rays_o = rays_o.reshape(-1, 3).split(self.n_rays)
|
||||
rays_d = rays_d.reshape(-1, 3).split(self.n_rays)
|
||||
dist = dist.reshape(-1).split(self.n_rays)
|
||||
out_rgb_fine = []
|
||||
depth_diff = []
|
||||
for i, (rays_o_batch,
|
||||
rays_d_batch) in enumerate(zip(rays_o, rays_d)):
|
||||
near, far = self.dataset.near_far_from_sphere(
|
||||
rays_o_batch, rays_d_batch)
|
||||
render_out = self.renderer.render(
|
||||
rays_o_batch,
|
||||
rays_d_batch,
|
||||
near,
|
||||
far,
|
||||
cos_anneal_ratio=1.0,
|
||||
background_rgb=None)
|
||||
|
||||
# out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
|
||||
out_rgb_fine.append(
|
||||
render_out['albedo_fine'].detach().cpu().numpy())
|
||||
|
||||
weights = render_out['weights']
|
||||
mid_z_vals = render_out['mid_z_vals']
|
||||
n_samples = self.renderer.n_samples + self.renderer.n_importance
|
||||
depth_batch = (mid_z_vals[:, :n_samples]
|
||||
* weights[:, :n_samples]).sum(
|
||||
dim=1).detach().cpu().numpy()
|
||||
dist_batch = dist[i].detach().cpu().numpy()
|
||||
depth_diff.append(np.abs(depth_batch - dist_batch))
|
||||
|
||||
del render_out
|
||||
|
||||
out_rgb_fine = np.concatenate(
|
||||
out_rgb_fine, axis=0).reshape(vertices.shape[0], 3)
|
||||
depth_diff = np.concatenate(
|
||||
depth_diff, axis=0).reshape(vertices.shape[0])
|
||||
|
||||
if rgb_final is None:
|
||||
rgb_final = out_rgb_fine.copy()
|
||||
diff_final = depth_diff.copy()
|
||||
else:
|
||||
ind = diff_final > depth_diff
|
||||
ind = ind.reshape(-1)
|
||||
rgb_final[ind] = out_rgb_fine[ind]
|
||||
diff_final[ind] = depth_diff[ind]
|
||||
|
||||
vertices = vertices * self.dataset.scale_mats_np[0][
|
||||
0, 0] + self.dataset.scale_mats_np[0][:3, 3][None]
|
||||
|
||||
if color:
|
||||
logger.info('save mesh with color')
|
||||
vert_colors = (255 * np.clip(rgb_final[..., ::-1], 0, 1)).astype(
|
||||
np.uint8)
|
||||
mesh = trimesh.Trimesh(
|
||||
vertices, triangles, vertex_colors=vert_colors)
|
||||
else:
|
||||
mesh = trimesh.Trimesh(vertices, triangles)
|
||||
|
||||
outpath = os.path.join(save_dir, 'mesh.ply')
|
||||
mesh.export(outpath)
|
||||
|
||||
logger.info('surface reconstruction done, export mesh to %s' % outpath)
|
||||
modelscope/models/cv/surface_recon_common/utils.py (new file)
@@ -0,0 +1,85 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from https://github.com/Totoro97/NeuS/blob/main/models/renderer.py
|
||||
# Copyright (c) 2021 Peng Wang.
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
import mcubes
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def extract_fields(bound_min,
|
||||
bound_max,
|
||||
resolution,
|
||||
query_func,
|
||||
device='cuda'):
|
||||
|
||||
N = 64
|
||||
X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
|
||||
Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
|
||||
Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
|
||||
|
||||
u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
|
||||
with torch.no_grad():
|
||||
for xi, xs in enumerate(X):
|
||||
for yi, ys in enumerate(Y):
|
||||
for zi, zs in enumerate(Z):
|
||||
xx, yy, zz = torch.meshgrid(xs, ys, zs)
|
||||
xx = xx.reshape(-1, 1)
|
||||
yy = yy.reshape(-1, 1)
|
||||
zz = zz.reshape(-1, 1)
|
||||
pts = torch.cat([xx, yy, zz], dim=-1)
|
||||
pts = pts.to(device)
|
||||
val = query_func(pts).reshape(
|
||||
len(xs), len(ys), len(zs)).detach().cpu().numpy()
|
||||
u[xi * N:xi * N + len(xs), yi * N:yi * N + len(ys),
|
||||
zi * N:zi * N + len(zs)] = val
|
||||
return u
|
||||
|
||||
|
||||
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func,
|
||||
device):
|
||||
print('threshold: {}'.format(threshold))
|
||||
u = extract_fields(bound_min, bound_max, resolution, query_func, device)
|
||||
vertices, triangles = mcubes.marching_cubes(u, threshold)
|
||||
b_max_np = bound_max.detach().cpu().numpy()
|
||||
b_min_np = bound_min.detach().cpu().numpy()
|
||||
|
||||
vertices = vertices / (resolution - 1.0) * (
|
||||
b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
||||
return vertices, triangles
|
||||
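A small sketch of extract_geometry above with an analytic query function; note that, as in the renderer, query_func is expected to return the negated SDF (positive inside the surface):

import torch

bound_min = torch.tensor([-1.2, -1.2, -1.2])
bound_max = torch.tensor([1.2, 1.2, 1.2])
verts, tris = extract_geometry(
    bound_min, bound_max, resolution=128, threshold=0.0,
    query_func=lambda pts: 1.0 - torch.linalg.norm(pts, dim=-1, keepdim=True),
    device='cpu')
print(verts.shape, tris.shape)   # a triangle mesh approximating the unit sphere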
|
||||
|
||||
def sample_pdf(bins, weights, n_samples, det=False, device='cuda'):
|
||||
# This implementation is from NeRF
|
||||
# Get pdf
|
||||
weights = weights + 1e-5 # prevent nans
|
||||
pdf = weights / torch.sum(weights, -1, keepdim=True)
|
||||
cdf = torch.cumsum(pdf, -1)
|
||||
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
|
||||
# Take uniform samples
|
||||
if det:
|
||||
u = torch.linspace(
|
||||
0. + 0.5 / n_samples, 1. - 0.5 / n_samples,
|
||||
steps=n_samples).to(device)
|
||||
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
|
||||
else:
|
||||
u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(device)
|
||||
|
||||
# Invert CDF
|
||||
u = u.contiguous()
|
||||
inds = torch.searchsorted(cdf, u, right=True)
|
||||
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
|
||||
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
|
||||
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
|
||||
|
||||
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
|
||||
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
||||
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
||||
|
||||
denom = (cdf_g[..., 1] - cdf_g[..., 0])
|
||||
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
|
||||
t = (u - cdf_g[..., 0]) / denom
|
||||
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
||||
|
||||
return samples
|
||||
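A small sketch of sample_pdf above (the NeRF-style inverse-CDF sampler): given 64 coarse intervals whose weight mass sits in a single interval, the 32 fine samples cluster inside it (values synthetic):

import torch

bins = torch.linspace(0.0, 1.0, steps=65).expand(4, 65)   # 4 rays, 64 coarse intervals
weights = torch.zeros(4, 64)
weights[:, 5] = 1.0                                        # all the mass in the 6th interval
fine_z = sample_pdf(bins, weights, n_samples=32, det=True, device='cpu')
print(fine_z.shape)   # torch.Size([4, 32]); values fall between bins[:, 5] and bins[:, 6]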
@@ -12,7 +12,7 @@ from modelscope.models.cv.video_depth_estimation.utils.misc import filter_dict
########################################################################################################################


def resize_image(image, shape, interpolation=Image.ANTIALIAS):
def resize_image(image, shape, interpolation=Image.Resampling.LANCZOS):
    """
    Resizes input image.

@@ -57,7 +57,8 @@ def resize_depth(depth, shape):

def resize_sample_image_and_intrinsics(sample,
                                       shape,
                                       image_interpolation=Image.ANTIALIAS):
                                       image_interpolation=Image.Resampling.
                                       LANCZOS):
    """
    Resizes the image and intrinsics of a sample

@@ -102,7 +103,7 @@ def resize_sample_image_and_intrinsics(sample,
    return sample


def resize_sample(sample, shape, image_interpolation=Image.ANTIALIAS):
def resize_sample(sample, shape, image_interpolation=Image.Resampling.LANCZOS):
    """
    Resizes a sample, including image, intrinsics and depth maps.

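These hunks swap the deprecated Pillow constant Image.ANTIALIAS for Image.Resampling.LANCZOS (the Resampling enum arrived in Pillow 9.1 and the old alias was removed in Pillow 10). A hedged compatibility sketch for code that must still run on older Pillow releases; the helper name is illustrative and not part of this commit.

from PIL import Image

# Pillow >= 9.1 exposes the Resampling enum; older releases only have the flat constants.
try:
    LANCZOS = Image.Resampling.LANCZOS
except AttributeError:  # Pillow < 9.1
    LANCZOS = Image.ANTIALIAS


def resize_image_compat(image, shape, interpolation=LANCZOS):
    """Resize a PIL image to (height, width) using the resolved resampling filter."""
    height, width = shape
    return image.resize((width, height), interpolation)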
@@ -6,22 +6,23 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:

    from .clip import CLIPForMultiModalEmbedding
    from .gemm import GEMMForMultiModalEmbedding
    from .rleg import RLEGForMultiModalEmbedding
    from .team import TEAMForMultiModalSimilarity
    from .clip_interrogator import CLIP_Interrogator
    from .diffusion import DiffusionForTextToImageSynthesis
    from .efficient_diffusion_tuning import EfficientStableDiffusion
    from .gemm import GEMMForMultiModalEmbedding
    from .mmr import VideoCLIPForMultiModalEmbedding
    from .mplug_for_all_tasks import MPlugForAllTasks, HiTeAForAllTasks
    from .mplug_for_all_tasks import HiTeAForAllTasks, MPlugForAllTasks
    from .mplug_owl import MplugOwlForConditionalGeneration
    from .multi_stage_diffusion import \
        MultiStageDiffusionForTextToImageSynthesis
    from .ofa_for_all_tasks import OfaForAllTasks
    from .ofa_for_text_to_image_synthesis_model import \
        OfaForTextToImageSynthesis
    from .multi_stage_diffusion import \
        MultiStageDiffusionForTextToImageSynthesis
    from .vldoc import VLDocForDocVLEmbedding
    from .prost import ProSTForTVRetrieval
    from .rleg import RLEGForMultiModalEmbedding
    from .team import TEAMForMultiModalSimilarity
    from .video_synthesis import TextToVideoSynthesis
    from .efficient_diffusion_tuning import EfficientStableDiffusion
    from .mplug_owl import MplugOwlForConditionalGeneration
    from .clip_interrogator import CLIP_Interrogator
    from .vldoc import VLDocForDocVLEmbedding
    from .videocomposer import VideoComposer

else:
@@ -32,6 +33,7 @@ else:
        'rleg': ['RLEGForMultiModalEmbedding'],
        'team': ['TEAMForMultiModalSimilarity'],
        'mmr': ['VideoCLIPForMultiModalEmbedding'],
        'prost': ['ProSTForTVRetrieval'],
        'mplug_for_all_tasks': ['MPlugForAllTasks', 'HiTeAForAllTasks'],
        'ofa_for_all_tasks': ['OfaForAllTasks'],
        'ofa_for_text_to_image_synthesis_model':

@@ -578,7 +578,7 @@ class CLIPForMultiModalEmbedding(TorchModel):

        with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN):
            image_features = self.clip_model.encode_image(image_tensor)
            image_features /= image_features.norm(
            image_features = image_features / image_features.norm(
                dim=-1, keepdim=True)  # l2-normalize

            output[OutputKeys.IMG_EMBEDDING] = image_features
@@ -590,7 +590,7 @@ class CLIPForMultiModalEmbedding(TorchModel):

        with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN):
            text_features = self.clip_model.encode_text(text_tensor)
            text_features /= text_features.norm(
            text_features = text_features / text_features.norm(
                dim=-1, keepdim=True)  # l2-normalize
            output[OutputKeys.TEXT_EMBEDDING] = text_features

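These two hunks swap in-place normalization for its out-of-place form. When gradients are enabled (the ModeKeys.TRAIN branch), dividing a tensor in place overwrites values that autograd saved for the backward pass of norm and typically raises a RuntimeError; the out-of-place division avoids that. A small self-contained illustration, where the tensor is a stand-in for the model's features rather than anything from this commit:

import torch

feats = torch.randn(4, 512, requires_grad=True).relu()  # non-leaf tensor, as inside the model
norm = feats.norm(dim=-1, keepdim=True)

# Out-of-place: `feats` is left untouched, so autograd can still use it.
embedding = feats / norm

# In-place (`feats /= norm`) would modify a tensor that the backward pass of
# `norm` still needs, and backward() would then typically raise a RuntimeError.
embedding.sum().backward()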
modelscope/models/multi_modal/prost/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
|
||||
from .models import ProSTForTVRetrieval
|
||||
modelscope/models/multi_modal/prost/dataloaders/rawvideo_util.py (new file, 117 lines)
@@ -0,0 +1,117 @@
|
||||
# The implementation is adopted from Huaishao Luo,
|
||||
# made publicly available under the MIT License at https://github.com/ArrowLuo/CLIP4Clip
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from PIL import Image
|
||||
from torchvision.transforms import (CenterCrop, Compose, InterpolationMode,
|
||||
Normalize, Resize, ToTensor)
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class RawVideoExtractorCV2():
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
centercrop=False,
|
||||
size=224,
|
||||
frame_rate=-1,
|
||||
):
|
||||
self.centercrop = centercrop
|
||||
self.size = size
|
||||
self.framerate = frame_rate
|
||||
self.transform = self._transform(self.size)
|
||||
|
||||
def _transform(self, n_px):
|
||||
return Compose([
|
||||
Resize(n_px, interpolation=InterpolationMode.BICUBIC),
|
||||
CenterCrop(n_px),
|
||||
lambda image: image.convert('RGB'),
|
||||
ToTensor(),
|
||||
Normalize((0.48145466, 0.4578275, 0.40821073),
|
||||
(0.26862954, 0.26130258, 0.27577711)),
|
||||
])
|
||||
|
||||
def video_to_tensor(self,
|
||||
video_file,
|
||||
preprocess,
|
||||
sample_fp=0,
|
||||
start_time=None,
|
||||
end_time=None):
|
||||
if start_time is not None or end_time is not None:
|
||||
assert isinstance(start_time, int) and isinstance(end_time, int) \
|
||||
and start_time > -1 and end_time > start_time
|
||||
assert sample_fp > -1
|
||||
|
||||
# Sample sample_fp frames from each second of the video.
|
||||
cap = cv2.VideoCapture(video_file)
|
||||
frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
|
||||
if fps == 0:
|
||||
logger.info(f'{video_file} with fps 0!!!')
|
||||
total_duration = (frameCount + fps - 1) // fps
|
||||
start_sec, end_sec = 0, total_duration
|
||||
|
||||
if start_time is not None:
|
||||
start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps))
|
||||
|
||||
interval = 1
|
||||
if sample_fp > 0:
|
||||
interval = fps // sample_fp
|
||||
else:
|
||||
sample_fp = fps
|
||||
if interval == 0:
|
||||
interval = 1
|
||||
|
||||
inds = [ind for ind in np.arange(0, fps, interval)]
|
||||
assert len(inds) >= sample_fp
|
||||
inds = inds[:sample_fp]
|
||||
|
||||
ret = True
|
||||
images = []
|
||||
|
||||
for sec in np.arange(start_sec, end_sec + 1):
|
||||
if not ret:
|
||||
break
|
||||
sec_base = int(sec * fps)
|
||||
for ind in inds:
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind)
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
images.append(
|
||||
preprocess(Image.fromarray(frame_rgb).convert('RGB')))
|
||||
|
||||
cap.release()
|
||||
|
||||
if len(images) > 0:
|
||||
video_data = th.tensor(np.stack(images))
|
||||
else:
|
||||
video_data = th.zeros(1)
|
||||
return {'video': video_data}
|
||||
|
||||
def get_video_data(self, video_path, start_time=None, end_time=None):
|
||||
image_input = self.video_to_tensor(
|
||||
video_path,
|
||||
self.transform,
|
||||
sample_fp=self.framerate,
|
||||
start_time=start_time,
|
||||
end_time=end_time)
|
||||
return image_input
|
||||
|
||||
def process_raw_data(self, raw_video_data):
|
||||
tensor_size = raw_video_data.size()
|
||||
tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2],
|
||||
tensor_size[-1])
|
||||
return tensor
|
||||
|
||||
|
||||
# An ordinary video frame extractor based on OpenCV (cv2)
|
||||
RawVideoExtractor = RawVideoExtractorCV2
|
||||
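A minimal usage sketch for the extractor defined above; the video path and frame rate are placeholders, not values used by this commit.

extractor = RawVideoExtractorCV2(centercrop=True, size=224, frame_rate=1)
out = extractor.get_video_data('/path/to/clip.mp4')  # placeholder path to a readable video
video = out['video']                                  # (n_frames, 3, 224, 224), or zeros(1) if no frame was decoded
patches = extractor.process_raw_data(video)           # reshaped to (n_frames, 1, 3, 224, 224)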
modelscope/models/multi_modal/prost/models/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
|
||||
from .prost_model import ProSTForTVRetrieval
|
||||
modelscope/models/multi_modal/prost/models/modeling.py (new file, 704 lines)
@@ -0,0 +1,704 @@
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
|
||||
import os
|
||||
import platform
|
||||
from collections import OrderedDict
|
||||
from types import SimpleNamespace
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
||||
|
||||
from modelscope.models.multi_modal.prost.models.module_clip import (
|
||||
_PT_NAME, CLIP, QuickGELU, convert_weights)
|
||||
from modelscope.models.multi_modal.prost.models.module_cross import (
|
||||
CrossConfig, CrossModel)
|
||||
from modelscope.models.multi_modal.prost.models.module_cross import \
|
||||
Transformer as TransformerClip
|
||||
from modelscope.models.multi_modal.prost.models.until_module import (
|
||||
AllGather, CrossEn, Event_decoder, Frame_decoder, LayerNorm,
|
||||
PreTrainedModel, make_patch_shift)
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
allgather = AllGather.apply
|
||||
|
||||
logger = get_logger()
|
||||
__all__ = ['CLIP4Clip']
|
||||
|
||||
|
||||
class MyObject:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class CLIP4ClipPreTrainedModel(PreTrainedModel, nn.Module):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
def __init__(self, cross_config, *inputs, **kwargs):
|
||||
super(CLIP4ClipPreTrainedModel, self).__init__(cross_config)
|
||||
self.cross_config = cross_config
|
||||
self.clip = None
|
||||
self.cross = None
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,
|
||||
cross_config,
|
||||
state_dict=None,
|
||||
cache_dir=None,
|
||||
type_vocab_size=2,
|
||||
*inputs,
|
||||
**kwargs):
|
||||
|
||||
task_config = None
|
||||
if 'task_config' in kwargs.keys():
|
||||
task_config = kwargs['task_config']
|
||||
if not hasattr(task_config, 'local_rank'):
|
||||
task_config['local_rank'] = 0
|
||||
elif task_config['local_rank'] == -1:
|
||||
task_config['local_rank'] = 0
|
||||
|
||||
if state_dict is None:
|
||||
state_dict = {}
|
||||
# pretrained_clip_name = task_config['pretrained_clip_name']
|
||||
clip_state_dict = CLIP.get_config(model_dir=task_config['model_dir'])
|
||||
for key, val in clip_state_dict.items():
|
||||
new_key = 'clip.' + key
|
||||
if new_key not in state_dict:
|
||||
state_dict[new_key] = val.clone()
|
||||
|
||||
# cross_config, _ = CrossConfig.get_config(
|
||||
# cross_model_name,
|
||||
# cache_dir,
|
||||
# type_vocab_size,
|
||||
# state_dict=None,
|
||||
# task_config=task_config)
|
||||
cross_config = CrossConfig.from_dict(cross_config)
|
||||
cross_config.type_vocab_size = type_vocab_size
|
||||
task_config = MyObject(**kwargs['task_config'])
|
||||
model = cls(cross_config, clip_state_dict, *inputs, task_config)
|
||||
|
||||
# ===> Initialization trick [HARD CODE]
|
||||
if model.linear_patch == '3d':
|
||||
contain_conv2 = False
|
||||
for key in state_dict.keys():
|
||||
if key.find('visual.conv2.weight') > -1:
|
||||
contain_conv2 = True
|
||||
break
|
||||
if contain_conv2 is False and hasattr(model.clip.visual, 'conv2'):
|
||||
cp_weight = state_dict['clip.visual.conv1.weight'].clone()
|
||||
kernel_size = model.clip.visual.conv2.weight.size(2)
|
||||
conv2_size = model.clip.visual.conv2.weight.size()
|
||||
conv2_size = list(conv2_size)
|
||||
|
||||
left_conv2_size = conv2_size.copy()
|
||||
right_conv2_size = conv2_size.copy()
|
||||
left_conv2_size[2] = (kernel_size - 1) // 2
|
||||
right_conv2_size[2] = kernel_size - 1 - left_conv2_size[2]
|
||||
|
||||
left_zeros, right_zeros = None, None
|
||||
if left_conv2_size[2] > 0:
|
||||
left_zeros = torch.zeros(
|
||||
*tuple(left_conv2_size),
|
||||
dtype=cp_weight.dtype,
|
||||
device=cp_weight.device)
|
||||
if right_conv2_size[2] > 0:
|
||||
right_zeros = torch.zeros(
|
||||
*tuple(right_conv2_size),
|
||||
dtype=cp_weight.dtype,
|
||||
device=cp_weight.device)
|
||||
|
||||
cat_list = []
|
||||
if left_zeros is not None:
|
||||
cat_list.append(left_zeros)
|
||||
cat_list.append(cp_weight.unsqueeze(2))
|
||||
if right_zeros is not None:
|
||||
cat_list.append(right_zeros)
|
||||
cp_weight = torch.cat(cat_list, dim=2)
|
||||
|
||||
state_dict['clip.visual.conv2.weight'] = cp_weight
|
||||
|
||||
# if model.sim_header == 'tightTransf':
|
||||
# contain_cross = False
|
||||
# for key in state_dict.keys():
|
||||
# if key.find('cross.transformer') > -1:
|
||||
# contain_cross = True
|
||||
# break
|
||||
# if contain_cross is False:
|
||||
# for key, val in clip_state_dict.items():
|
||||
# if key == 'positional_embedding':
|
||||
# state_dict[
|
||||
# 'cross.embeddings.position_embeddings.weight'] = val.clone(
|
||||
# )
|
||||
# continue
|
||||
# if key.find('transformer.resblocks') == 0:
|
||||
# num_layer = int(key.split('.')[2])
|
||||
|
||||
# # cut from beginning
|
||||
# if num_layer < task_config.cross_num_hidden_layers:
|
||||
# state_dict['cross.' + key] = val.clone()
|
||||
# continue
|
||||
|
||||
if model.sim_header == 'seqLSTM' or model.sim_header == 'seqTransf':
|
||||
# This step detects whether we are in train mode or test mode
|
||||
contain_frame_position = False
|
||||
for key in state_dict.keys():
|
||||
if key.find('frame_position_embeddings') > -1:
|
||||
contain_frame_position = True
|
||||
break
|
||||
|
||||
# train mode
|
||||
if contain_frame_position is False:
|
||||
for key, val in clip_state_dict.items():
|
||||
if key == 'positional_embedding':
|
||||
state_dict[
|
||||
'frame_position_embeddings.weight'] = val.clone()
|
||||
# state_dict["text_prompt_encoder.pos_embedding"] = val[0:3].clone()
|
||||
continue
|
||||
if model.sim_header == 'seqTransf' and key.find(
|
||||
'transformer.resblocks') == 0:
|
||||
num_layer = int(key.split('.')[2])
|
||||
# cut from beginning
|
||||
if num_layer < task_config.cross_num_hidden_layers:
|
||||
state_dict[key.replace(
|
||||
'transformer.',
|
||||
'transformerClip.')] = val.clone()
|
||||
continue
|
||||
|
||||
else:
|
||||
for key, val in state_dict.items():
|
||||
# test mode
|
||||
if key.find('clip.visual.transformer.resblocks') == 0:
|
||||
num_layer = int(key.split('.')[4])
|
||||
# shift layers 10-11
|
||||
if num_layer >= 10 and num_layer < 12:
|
||||
state_dict[key.replace('attn.net.',
|
||||
'attn.')] = val.clone()
|
||||
# <=== End of initialization trick
|
||||
|
||||
if state_dict is not None:
|
||||
model = cls.init_preweight(
|
||||
model, state_dict, task_config=task_config)
|
||||
make_patch_shift(model, video_frame=task_config.max_frames, n_div=14)
|
||||
return model
|
||||
|
||||
|
||||
def show_log(task_config, info):
|
||||
if task_config is None or task_config.local_rank == 0:
|
||||
logger.warning(info)
|
||||
|
||||
|
||||
def update_attr(target_name,
|
||||
target_config,
|
||||
target_attr_name,
|
||||
source_config,
|
||||
source_attr_name,
|
||||
default_value=None):
|
||||
if hasattr(source_config, source_attr_name):
|
||||
if default_value is None or getattr(source_config,
|
||||
source_attr_name) != default_value:
|
||||
setattr(target_config, target_attr_name,
|
||||
getattr(source_config, source_attr_name))
|
||||
# show_log(
|
||||
# source_config, "Set {}.{}: {}.".format(
|
||||
# target_name, target_attr_name,
|
||||
# getattr(target_config, target_attr_name)))
|
||||
return target_config
|
||||
|
||||
|
||||
def check_attr(target_name, task_config):
|
||||
return hasattr(task_config,
|
||||
target_name) and task_config.__dict__[target_name]
|
||||
|
||||
|
||||
class CLIP4Clip(CLIP4ClipPreTrainedModel):
|
||||
|
||||
def __init__(self, cross_config, clip_state_dict, task_config):
|
||||
super(CLIP4Clip, self).__init__(cross_config)
|
||||
self.task_config = task_config
|
||||
self.ignore_video_index = -1
|
||||
|
||||
assert self.task_config.max_words + self.task_config.max_frames <= cross_config.max_position_embeddings
|
||||
|
||||
self._stage_one = True
|
||||
self._stage_two = False
|
||||
|
||||
# show_log(task_config, "Stage-One:{}, Stage-Two:{}".format(self._stage_one, self._stage_two))
|
||||
|
||||
self.loose_type = False
|
||||
if self._stage_one and check_attr('loose_type', self.task_config):
|
||||
self.loose_type = True
|
||||
# show_log(task_config, "Test retrieval by loose type.")
|
||||
|
||||
# CLIP Encoders: From OpenAI: CLIP [https://github.com/openai/CLIP] ===>
|
||||
vit = 'visual.proj' in clip_state_dict
|
||||
assert vit
|
||||
if vit:
|
||||
vision_width = clip_state_dict['visual.conv1.weight'].shape[0]
|
||||
vision_layers = len([
|
||||
k for k in clip_state_dict.keys() if k.startswith('visual.')
|
||||
and k.endswith('.attn.in_proj_weight')
|
||||
])
|
||||
vision_patch_size = clip_state_dict['visual.conv1.weight'].shape[
|
||||
-1]
|
||||
grid_size = round(
|
||||
(clip_state_dict['visual.positional_embedding'].shape[0]
|
||||
- 1)**0.5)
|
||||
image_resolution = vision_patch_size * grid_size
|
||||
else:
|
||||
counts: list = [
|
||||
len(
|
||||
set(
|
||||
k.split('.')[2] for k in clip_state_dict
|
||||
if k.startswith(f'visual.layer{b}')))
|
||||
for b in [1, 2, 3, 4]
|
||||
]
|
||||
vision_layers = tuple(counts)
|
||||
vision_width = clip_state_dict[
|
||||
'visual.layer1.0.conv1.weight'].shape[0]
|
||||
output_width = round(
|
||||
(clip_state_dict['visual.attnpool.positional_embedding'].
|
||||
shape[0] - 1)**0.5)
|
||||
vision_patch_size = None
|
||||
assert output_width**2 + 1 == clip_state_dict[
|
||||
'visual.attnpool.positional_embedding'].shape[0]
|
||||
image_resolution = output_width * 32
|
||||
|
||||
embed_dim = clip_state_dict['text_projection'].shape[1]
|
||||
context_length = clip_state_dict['positional_embedding'].shape[0]
|
||||
vocab_size = clip_state_dict['token_embedding.weight'].shape[0]
|
||||
transformer_width = clip_state_dict['ln_final.weight'].shape[0]
|
||||
transformer_heads = transformer_width // 64
|
||||
transformer_layers = len(
|
||||
set(
|
||||
k.split('.')[2] for k in clip_state_dict
|
||||
if k.startswith('transformer.resblocks')))
|
||||
|
||||
# show_log(task_config, "\t embed_dim: {}".format(embed_dim))
|
||||
# show_log(task_config, "\t image_resolution: {}".format(image_resolution))
|
||||
# show_log(task_config, "\t vision_layers: {}".format(vision_layers))
|
||||
# show_log(task_config, "\t vision_width: {}".format(vision_width))
|
||||
# show_log(task_config, "\t vision_patch_size: {}".format(vision_patch_size))
|
||||
# show_log(task_config, "\t context_length: {}".format(context_length))
|
||||
# show_log(task_config, "\t vocab_size: {}".format(vocab_size))
|
||||
# show_log(task_config, "\t transformer_width: {}".format(transformer_width))
|
||||
# show_log(task_config, "\t transformer_heads: {}".format(transformer_heads))
|
||||
# show_log(task_config, "\t transformer_layers: {}".format(transformer_layers))
|
||||
|
||||
self.linear_patch = '2d'
|
||||
if hasattr(task_config, 'linear_patch'):
|
||||
self.linear_patch = task_config.linear_patch
|
||||
# show_log(task_config, "\t\t linear_patch: {}".format(self.linear_patch))
|
||||
|
||||
# use .float() to avoid overflow/underflow from fp16 weight. https://github.com/openai/CLIP/issues/40
|
||||
cut_top_layer = 0
|
||||
|
||||
self.clip = CLIP(
|
||||
embed_dim,
|
||||
image_resolution,
|
||||
vision_layers - cut_top_layer,
|
||||
vision_width,
|
||||
vision_patch_size,
|
||||
context_length,
|
||||
vocab_size,
|
||||
transformer_width,
|
||||
transformer_heads,
|
||||
transformer_layers - cut_top_layer,
|
||||
linear_patch=self.linear_patch).float()
|
||||
|
||||
for key in ['input_resolution', 'context_length', 'vocab_size']:
|
||||
if key in clip_state_dict:
|
||||
del clip_state_dict[key]
|
||||
|
||||
convert_weights(self.clip)
|
||||
# <=== End of CLIP Encoders
|
||||
|
||||
self.sim_header = 'seqTransf'
|
||||
if hasattr(task_config, 'sim_header'):
|
||||
self.sim_header = task_config.sim_header
|
||||
# show_log(task_config, "\t sim_header: {}".format(self.sim_header))
|
||||
if self.sim_header == 'tightTransf':
|
||||
assert self.loose_type is False
|
||||
|
||||
cross_config.max_position_embeddings = context_length
|
||||
if self.loose_type is False:
|
||||
# Cross Encoder ===>
|
||||
cross_config = update_attr('cross_config', cross_config,
|
||||
'num_hidden_layers', self.task_config,
|
||||
'cross_num_hidden_layers')
|
||||
self.cross = CrossModel(cross_config)
|
||||
# <=== End of Cross Encoder
|
||||
self.similarity_dense = nn.Linear(cross_config.hidden_size, 1)
|
||||
|
||||
if self.sim_header == 'seqLSTM' or self.sim_header == 'seqTransf':
|
||||
self.frame_position_embeddings = nn.Embedding(
|
||||
cross_config.max_position_embeddings, cross_config.hidden_size)
|
||||
# self.frame_position_embeddings = nn.Embedding(600, cross_config.hidden_size)
|
||||
if self.sim_header == 'seqTransf':
|
||||
self.transformerClip = TransformerClip(
|
||||
width=transformer_width,
|
||||
layers=self.task_config.cross_num_hidden_layers,
|
||||
heads=transformer_heads,
|
||||
)
|
||||
if self.sim_header == 'seqLSTM':
|
||||
self.lstm_visual = nn.LSTM(
|
||||
input_size=cross_config.hidden_size,
|
||||
hidden_size=cross_config.hidden_size,
|
||||
batch_first=True,
|
||||
bidirectional=False,
|
||||
num_layers=1)
|
||||
|
||||
self.loss_fct = CrossEn()
|
||||
self.apply(self.init_weights)
|
||||
|
||||
self.set_dim = 512
|
||||
self.patch_num = self.task_config.max_patch
|
||||
if hasattr(self.task_config, 'max_word_pro'):
|
||||
self.word_pro_num = self.task_config.max_word_pro
|
||||
else:
|
||||
self.word_pro_num = self.task_config.max_phrase
|
||||
|
||||
self.frame_num = self.task_config.max_frames
|
||||
if hasattr(self.task_config, 'max_vfea'):
|
||||
self.event_num = self.task_config.max_vfea
|
||||
else:
|
||||
self.event_num = self.task_config.max_event
|
||||
|
||||
self.patch_prototype_weight = nn.Sequential(
|
||||
nn.Linear(self.set_dim, self.set_dim), nn.ReLU(inplace=True),
|
||||
nn.Linear(self.set_dim, self.patch_num - 1), nn.ReLU(inplace=True))
|
||||
|
||||
self.word_prototype_weight = nn.Sequential(
|
||||
nn.Linear(self.set_dim, self.set_dim), nn.ReLU(inplace=True),
|
||||
nn.Linear(self.set_dim, self.word_pro_num), nn.ReLU(inplace=True))
|
||||
|
||||
self.frame_decoder = Frame_decoder(
|
||||
num_attris=self.frame_num,
|
||||
layers=2,
|
||||
heads=1,
|
||||
dim_ftr=512,
|
||||
pos_emb=False,
|
||||
length=1,
|
||||
dim_feedforward=512,
|
||||
without_init=False)
|
||||
self.event_decoder = Event_decoder(
|
||||
num_attris=self.event_num,
|
||||
layers=2,
|
||||
heads=1,
|
||||
dim_ftr=512,
|
||||
pos_emb=False,
|
||||
length=1,
|
||||
dim_feedforward=512,
|
||||
without_init=False)
|
||||
# -------------------------------------------------------------------------------------------------------
|
||||
|
||||
def forward(self,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
attention_mask,
|
||||
video,
|
||||
video_mask=None):
|
||||
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1])
|
||||
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
|
||||
video_mask = video_mask.view(-1, video_mask.shape[-1])
|
||||
|
||||
# T x 3 x H x W
|
||||
video = torch.as_tensor(video).float()
|
||||
bs, ts, channel, h, w = video.shape
|
||||
video = video.view(bs * ts, channel, h, w)
|
||||
video_frame = bs * ts
|
||||
phr_feat, sen_feat, obj_feat, eve_feat = self.get_sequence_visual_output(
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
attention_mask,
|
||||
video,
|
||||
video_mask,
|
||||
shaped=True,
|
||||
video_frame=video_frame)
|
||||
|
||||
if self.training:
|
||||
sim_matrix1, sim_matrix2, sim_matrix3, sim_matrix4 = self.get_max_similarity_logits(
|
||||
phr_feat,
|
||||
sen_feat,
|
||||
obj_feat,
|
||||
eve_feat,
|
||||
attention_mask,
|
||||
video_mask,
|
||||
shaped=True)
|
||||
sim_loss = (self.loss_fct(sim_matrix1) + self.loss_fct(sim_matrix2)
|
||||
+ self.loss_fct(sim_matrix3)
|
||||
+ self.loss_fct(sim_matrix4)) / 4.0
|
||||
|
||||
loss = sim_loss
|
||||
|
||||
return loss
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_max_similarity_logits(self,
|
||||
word_feat,
|
||||
text_feat,
|
||||
patch_feat,
|
||||
video_feat,
|
||||
text_mask,
|
||||
video_mask,
|
||||
shaped=False):
|
||||
if shaped is False:
|
||||
text_mask = text_mask.view(-1, text_mask.shape[-1])
|
||||
video_mask = video_mask.view(-1, video_mask.shape[-1])
|
||||
|
||||
if self.training and torch.cuda.is_available(): # batch merge here
|
||||
text_feat = allgather(text_feat, self.task_config)
|
||||
video_feat = allgather(video_feat, self.task_config)
|
||||
word_feat = allgather(word_feat, self.task_config)
|
||||
patch_feat = allgather(patch_feat, self.task_config)
|
||||
|
||||
video_mask = allgather(video_mask, self.task_config)
|
||||
torch.distributed.barrier() # force sync
|
||||
|
||||
# ESPM
|
||||
text_feat = F.normalize(text_feat, p=2, dim=1)
|
||||
video_feat = F.normalize(video_feat, p=2, dim=2)
|
||||
retrieve_logits = torch.einsum('ad,bkd->abk', [text_feat, video_feat])
|
||||
retrieve_logits = retrieve_logits.max(2)[0]
|
||||
|
||||
# OPPM
|
||||
word_feat = F.normalize(word_feat, p=2, dim=2)
|
||||
patch_feat = F.normalize(patch_feat, p=2, dim=3)
|
||||
retrieve_logits_2 = torch.einsum('aid, bfjd->abfij',
|
||||
[word_feat, patch_feat])
|
||||
|
||||
retrieve_logits_2 = retrieve_logits_2.max(3)[0]
|
||||
retrieve_logits_2 = retrieve_logits_2.max(2)[0]
|
||||
retrieve_logits_2 = retrieve_logits_2.sum(2) / self.patch_num
|
||||
|
||||
if self.training:
|
||||
logit_scale = self.clip.logit_scale.exp()
|
||||
retrieve_logits = logit_scale * retrieve_logits
|
||||
retrieve_logits_2 = logit_scale * retrieve_logits_2
|
||||
return retrieve_logits, retrieve_logits.t(
|
||||
), retrieve_logits_2, retrieve_logits_2.t()
|
||||
|
||||
def get_sequence_output(self,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
attention_mask,
|
||||
shaped=False):
|
||||
if shaped is False:
|
||||
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1])
|
||||
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
|
||||
bs_pair = input_ids.size(0)
|
||||
sequence_hidden = self.clip.encode_text(
|
||||
input_ids, return_hidden=True)[1].float()
|
||||
text_feat = sequence_hidden.view(bs_pair, -1, sequence_hidden.size(-1))
|
||||
|
||||
word_weights = self.word_prototype_weight(text_feat)
|
||||
text_word_proto = torch.einsum('bmd,bmn->bnd', text_feat, word_weights)
|
||||
|
||||
cls_text_feat = text_feat.contiguous()
|
||||
cls_text_feat = cls_text_feat[torch.arange(cls_text_feat.shape[0]),
|
||||
torch.sum(attention_mask, dim=-1) - 1, :]
|
||||
|
||||
return text_word_proto, cls_text_feat
|
||||
|
||||
def get_visual_output(self,
|
||||
video,
|
||||
video_mask,
|
||||
shaped=False,
|
||||
video_frame=-1):
|
||||
if shaped is False:
|
||||
video_mask = video_mask.view(-1, video_mask.shape[-1])
|
||||
video = torch.as_tensor(video).float()
|
||||
bs, ts, channel, h, w = video.shape
|
||||
video = video.view(bs * ts, channel, h, w)
|
||||
# video_frame = bs * ts
|
||||
|
||||
bs_pair = video_mask.size(0)
|
||||
|
||||
cls_video_feat, video_patch_feat = self.clip.encode_image_tokens(
|
||||
video, return_hidden=True)
|
||||
cls_video_feat = cls_video_feat.float()
|
||||
video_patch_feat = video_patch_feat.float()
|
||||
# frame_num = video_patch_feat.shape[0]
|
||||
patch_dim = video_patch_feat.shape[2]
|
||||
|
||||
patch_weights = self.patch_prototype_weight(video_patch_feat)
|
||||
# cls_video_feat
|
||||
video_patch_proto = torch.einsum('bmd,bmn->bnd', video_patch_feat,
|
||||
patch_weights)
|
||||
video_patch_proto = torch.cat(
|
||||
(cls_video_feat.unsqueeze(1), video_patch_proto), 1)
|
||||
video_patch_proto = video_patch_proto.reshape(
|
||||
bs_pair, self.task_config.max_frames, self.patch_num, patch_dim)
|
||||
|
||||
video_frame_proto = video_patch_proto.reshape(
|
||||
bs_pair, self.patch_num * self.task_config.max_frames, patch_dim)
|
||||
video_frame_proto = self.frame_decoder(video_frame_proto)
|
||||
|
||||
video_frame_proto = 0.5 * video_frame_proto + 0.5 * cls_video_feat.reshape(
|
||||
bs_pair, self.task_config.max_frames, patch_dim)
|
||||
video_frame_proto = self.event_decoder(video_frame_proto)
|
||||
video_frame_proto = 0.5 * video_frame_proto + 0.5 * cls_video_feat.reshape(
|
||||
bs_pair, self.task_config.max_frames, patch_dim).mean(1).unsqueeze(
|
||||
1).repeat(1, video_frame_proto.shape[1], 1)
|
||||
return video_patch_proto, video_frame_proto
|
||||
|
||||
def get_sequence_visual_output(self,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
attention_mask,
|
||||
video,
|
||||
video_mask,
|
||||
shaped=False,
|
||||
video_frame=-1):
|
||||
if shaped is False:
|
||||
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1])
|
||||
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
|
||||
video_mask = video_mask.view(-1, video_mask.shape[-1])
|
||||
|
||||
video = torch.as_tensor(video).float()
|
||||
|
||||
# import pdb;pdb.set_trace()
|
||||
# b, pair,
|
||||
bs, ts, channel, h, w = video.shape
|
||||
video = video.view(bs * ts, channel, h, w)
|
||||
video_frame = bs * ts
|
||||
|
||||
word_feat, text_feat = self.get_sequence_output(
|
||||
input_ids, token_type_ids, attention_mask, shaped=True)
|
||||
|
||||
patch_feat, frame_feat = self.get_visual_output(
|
||||
video, video_mask, shaped=True, video_frame=video_frame)
|
||||
|
||||
return word_feat, text_feat, patch_feat, frame_feat
|
||||
|
||||
def _get_cross_output(self, sequence_output, visual_output, attention_mask,
|
||||
video_mask):
|
||||
|
||||
concat_features = torch.cat((sequence_output, visual_output),
|
||||
dim=1)  # concatenate tokens and frames
|
||||
concat_mask = torch.cat((attention_mask, video_mask), dim=1)
|
||||
text_type_ = torch.zeros_like(attention_mask)
|
||||
video_type_ = torch.ones_like(video_mask)
|
||||
concat_type = torch.cat((text_type_, video_type_), dim=1)
|
||||
|
||||
cross_layers, pooled_output = self.cross(
|
||||
concat_features,
|
||||
concat_type,
|
||||
concat_mask,
|
||||
output_all_encoded_layers=True)
|
||||
cross_output = cross_layers[-1]
|
||||
|
||||
return cross_output, pooled_output, concat_mask
|
||||
|
||||
def _mean_pooling_for_similarity_sequence(self, sequence_output,
|
||||
attention_mask):
|
||||
attention_mask_un = attention_mask.to(dtype=torch.float).unsqueeze(-1)
|
||||
attention_mask_un[:, 0, :] = 0.
|
||||
sequence_output = sequence_output * attention_mask_un
|
||||
text_out = torch.sum(
|
||||
sequence_output, dim=1) / torch.sum(
|
||||
attention_mask_un, dim=1, dtype=torch.float)
|
||||
return text_out
|
||||
|
||||
def _mean_pooling_for_similarity_visual(
|
||||
self,
|
||||
visual_output,
|
||||
video_mask,
|
||||
):
|
||||
video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
|
||||
visual_output = visual_output * video_mask_un
|
||||
video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
|
||||
video_mask_un_sum[video_mask_un_sum == 0.] = 1.
|
||||
video_out = torch.sum(visual_output, dim=1) / video_mask_un_sum
|
||||
return video_out
|
||||
|
||||
def _mean_pooling_for_similarity(
|
||||
self,
|
||||
sequence_output,
|
||||
visual_output,
|
||||
attention_mask,
|
||||
video_mask,
|
||||
):
|
||||
text_out = self._mean_pooling_for_similarity_sequence(
|
||||
sequence_output, attention_mask)
|
||||
video_out = self._mean_pooling_for_similarity_visual(
|
||||
visual_output, video_mask)
|
||||
|
||||
return text_out, video_out
|
||||
|
||||
def get_global_similarity(self, sequence_output, visual_output,
|
||||
attention_mask, video_mask):
|
||||
visual_output = visual_output / visual_output.norm(
|
||||
dim=-1, keepdim=True)
|
||||
visual_output = self._mean_pooling_for_similarity_visual(
|
||||
visual_output, video_mask)
|
||||
visual_output = visual_output / visual_output.norm(
|
||||
dim=-1, keepdim=True)
|
||||
|
||||
sequence_output = sequence_output.squeeze(1)
|
||||
sequence_output = sequence_output / sequence_output.norm(
|
||||
dim=-1, keepdim=True)
|
||||
|
||||
logit_scale = self.clip.logit_scale.exp()
|
||||
# retrieve_logits = logit_scale * torch.matmul(sequence_output, visual_output.t())
|
||||
sim_matrix_global = logit_scale * torch.matmul(sequence_output,
|
||||
visual_output.t())
|
||||
return sim_matrix_global
|
||||
|
||||
def _cross_similarity(self, sequence_output, visual_output, attention_mask,
|
||||
video_mask):
|
||||
sequence_output, visual_output = sequence_output.contiguous(
|
||||
), visual_output.contiguous()
|
||||
|
||||
b_text, s_text, h_text = sequence_output.size()
|
||||
b_visual, s_visual, h_visual = visual_output.size()
|
||||
|
||||
retrieve_logits_list = []
|
||||
|
||||
step_size = b_text # set smaller to reduce memory cost
|
||||
split_size = [step_size] * (b_text // step_size)
|
||||
release_size = b_text - sum(split_size)
|
||||
if release_size > 0:
|
||||
split_size += [release_size]
|
||||
|
||||
# due to clip text branch retrun the last hidden
|
||||
attention_mask = torch.ones(sequence_output.size(0), 1)\
|
||||
.to(device=attention_mask.device, dtype=attention_mask.dtype)
|
||||
|
||||
sequence_output_splits = torch.split(
|
||||
sequence_output, split_size, dim=0)
|
||||
attention_mask_splits = torch.split(attention_mask, split_size, dim=0)
|
||||
for i in range(len(split_size)):
|
||||
sequence_output_row = sequence_output_splits[i]
|
||||
attention_mask_row = attention_mask_splits[i]
|
||||
sequence_output_l = sequence_output_row.unsqueeze(1).repeat(
|
||||
1, b_visual, 1, 1)
|
||||
sequence_output_l = sequence_output_l.view(-1, s_text, h_text)
|
||||
attention_mask_l = attention_mask_row.unsqueeze(1).repeat(
|
||||
1, b_visual, 1)
|
||||
attention_mask_l = attention_mask_l.view(-1, s_text)
|
||||
|
||||
step_truth = sequence_output_row.size(0)
|
||||
visual_output_r = visual_output.unsqueeze(0).repeat(
|
||||
step_truth, 1, 1, 1)
|
||||
visual_output_r = visual_output_r.view(-1, s_visual, h_visual)
|
||||
video_mask_r = video_mask.unsqueeze(0).repeat(step_truth, 1, 1)
|
||||
video_mask_r = video_mask_r.view(-1, s_visual)
|
||||
|
||||
cross_output, pooled_output, concat_mask = \
|
||||
self._get_cross_output(sequence_output_l, visual_output_r, attention_mask_l, video_mask_r)
|
||||
retrieve_logits_row = self.similarity_dense(pooled_output).squeeze(
|
||||
-1).view(step_truth, b_visual)
|
||||
|
||||
retrieve_logits_list.append(retrieve_logits_row)
|
||||
|
||||
retrieve_logits = torch.cat(retrieve_logits_list, dim=0)
|
||||
return retrieve_logits
|
||||
modelscope/models/multi_modal/prost/models/module_clip.py (new file, 538 lines)
@@ -0,0 +1,538 @@
|
||||
# The implementation is adopted from the CLIP4Clip implementation,
|
||||
# made publicly available under the Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
|
||||
_MODELS = {}
|
||||
_PT_NAME = {'ViT-B/16': 'ViT-B-16.pt'}
|
||||
|
||||
|
||||
def available_models():
|
||||
"""Returns the names of available CLIP models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1):
|
||||
super(Bottleneck, self).__init__()
|
||||
|
||||
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
|
||||
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
||||
|
||||
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = None
|
||||
self.stride = stride
|
||||
|
||||
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
||||
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
||||
self.downsample = nn.Sequential(
|
||||
OrderedDict([('-1', nn.AvgPool2d(stride)),
|
||||
('0',
|
||||
nn.Conv2d(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
stride=1,
|
||||
bias=False)),
|
||||
('1', nn.BatchNorm2d(planes * self.expansion))]))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
identity = x
|
||||
|
||||
out = self.relu(self.bn1(self.conv1(x)))
|
||||
out = self.relu(self.bn2(self.conv2(out)))
|
||||
out = self.avgpool(out)
|
||||
out = self.bn3(self.conv3(out))
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
spacial_dim: int,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
output_dim: int = None):
|
||||
super(AttentionPool2d, self).__init__()
|
||||
self.positional_embedding = nn.Parameter(
|
||||
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, x):
|
||||
x = x.reshape(x.shape[0], x.shape[1],
|
||||
x.shape[2] * x.shape[3]).permute(2, 0,
|
||||
1) # NCHW -> (HW)NC
|
||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
||||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
||||
x, _ = F.multi_head_attention_forward(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
embed_dim_to_check=x.shape[-1],
|
||||
num_heads=self.num_heads,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
in_proj_weight=None,
|
||||
in_proj_bias=torch.cat(
|
||||
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||
bias_k=None,
|
||||
bias_v=None,
|
||||
add_zero_attn=False,
|
||||
dropout_p=0,
|
||||
out_proj_weight=self.c_proj.weight,
|
||||
out_proj_bias=self.c_proj.bias,
|
||||
use_separate_proj_weight=True,
|
||||
training=self.training,
|
||||
need_weights=False)
|
||||
|
||||
return x[0]
|
||||
|
||||
|
||||
class ModifiedResNet(nn.Module):
|
||||
"""
|
||||
A ResNet class that is similar to torchvision's but contains the following changes:
|
||||
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
||||
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
||||
- The final pooling layer is a QKV attention instead of an average pool
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
layers,
|
||||
output_dim,
|
||||
heads,
|
||||
input_resolution=224,
|
||||
width=64):
|
||||
super(ModifiedResNet, self).__init__()
|
||||
self.output_dim = output_dim
|
||||
self.input_resolution = input_resolution
|
||||
|
||||
# the 3-layer stem
|
||||
self.conv1 = nn.Conv2d(
|
||||
3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width // 2)
|
||||
self.conv2 = nn.Conv2d(
|
||||
width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(width // 2)
|
||||
self.conv3 = nn.Conv2d(
|
||||
width // 2, width, kernel_size=3, padding=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(width)
|
||||
self.avgpool = nn.AvgPool2d(2)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
# residual layers
|
||||
self._inplanes = width # this is a *mutable* variable used during construction
|
||||
self.layer1 = self._make_layer(width, layers[0])
|
||||
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
||||
|
||||
embed_dim = width * 32 # the ResNet feature dimension
|
||||
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
|
||||
heads, output_dim)
|
||||
|
||||
def _make_layer(self, planes, blocks, stride=1):
|
||||
layers = [Bottleneck(self._inplanes, planes, stride)]
|
||||
|
||||
self._inplanes = planes * Bottleneck.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(Bottleneck(self._inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def stem(x):
|
||||
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
|
||||
(self.conv3, self.bn3)]:
|
||||
x = self.relu(bn(conv(x)))
|
||||
x = self.avgpool(x)
|
||||
return x
|
||||
|
||||
x = x.type(self.conv1.weight.dtype)
|
||||
x = stem(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.attnpool(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self, d_model: int, n_head: int, attn_mask=None):
|
||||
super(ResidualAttentionBlock, self).__init__()
|
||||
|
||||
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = nn.Sequential(
|
||||
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
|
||||
('gelu', QuickGELU()),
|
||||
('c_proj', nn.Linear(d_model * 4, d_model))]))
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
def attention(self, x: torch.Tensor):
|
||||
attn_mask_ = self.attn_mask
|
||||
if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
|
||||
attn_mask_ = self.attn_mask(x.size(0)) # LND
|
||||
|
||||
attn_mask_ = attn_mask_.to(
|
||||
dtype=x.dtype, device=x.device) if attn_mask_ is not None else None
|
||||
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
attn_mask=None,
|
||||
use_gc=0):
|
||||
super(Transformer, self).__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.Sequential(*[
|
||||
ResidualAttentionBlock(width, heads, attn_mask)
|
||||
for _ in range(layers)
|
||||
])
|
||||
|
||||
self.use_gc = use_gc
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.use_gc > 0:
|
||||
for blk in self.resblocks:
|
||||
x = checkpoint.checkpoint(blk, x)
|
||||
return x
|
||||
else:
|
||||
return self.resblocks(x)
|
||||
|
||||
|
||||
class VisualTransformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
input_resolution: int,
|
||||
patch_size: int,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
output_dim: int,
|
||||
linear_patch: str = '2d',
|
||||
use_gc: int = 0):
|
||||
super(VisualTransformer, self).__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.output_dim = output_dim
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels=3,
|
||||
out_channels=width,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False)
|
||||
|
||||
scale = width**-0.5
|
||||
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
||||
self.positional_embedding = nn.Parameter(scale * torch.randn(
|
||||
(input_resolution // patch_size)**2 + 1, width))
|
||||
self.ln_pre = LayerNorm(width)
|
||||
|
||||
self.transformer = Transformer(width, layers, heads, use_gc=use_gc)
|
||||
|
||||
self.ln_post = LayerNorm(width)
|
||||
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
||||
|
||||
# For 3D
|
||||
assert linear_patch in ['2d', '3d']
|
||||
self.linear_patch = linear_patch
|
||||
if self.linear_patch == '3d':
|
||||
self.conv2 = nn.Conv3d(
|
||||
in_channels=3,
|
||||
out_channels=width,
|
||||
kernel_size=(3, patch_size, patch_size),
|
||||
stride=(1, patch_size, patch_size),
|
||||
padding=(1, 0, 0),
|
||||
bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor, video_frame=-1):
|
||||
|
||||
if self.linear_patch == '3d':
|
||||
assert video_frame != -1
|
||||
x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2],
|
||||
x.shape[-1])
|
||||
x_3d = x_3d.permute(0, 2, 1, 3, 4)
|
||||
x_3d = self.conv2(x_3d) # shape = [*, width, frame, grid, grid]
|
||||
x_3d = x_3d.permute(0, 2, 1, 3,
|
||||
4) # shape = [*, frame, width, grid, grid]
|
||||
x = x_3d.reshape(
|
||||
-1, x_3d.shape[-3], x_3d.shape[-2],
|
||||
x_3d.shape[-1]).contiguous() # shape = [*, width, grid, grid]
|
||||
else:
|
||||
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||
|
||||
x = x.reshape(x.shape[0], x.shape[1],
|
||||
-1) # shape = [*, width, grid ** 2]
|
||||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||
|
||||
_x = self.class_embedding.to(x.dtype) + torch.zeros(
|
||||
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
|
||||
x = torch.cat([_x, x], dim=1)
|
||||
x = x + self.positional_embedding.to(x.dtype)
|
||||
x = self.ln_pre(x)
|
||||
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class CLIP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
# vision
|
||||
image_resolution: int,
|
||||
vision_layers: Union[Tuple[int, int, int, int], int],
|
||||
vision_width: int,
|
||||
vision_patch_size: int,
|
||||
# text
|
||||
context_length: int,
|
||||
vocab_size: int,
|
||||
transformer_width: int,
|
||||
transformer_heads: int,
|
||||
transformer_layers: int,
|
||||
# vision linear of patch
|
||||
linear_patch: str = '2d',
|
||||
use_gc: int = 0):
|
||||
super(CLIP, self).__init__()
|
||||
|
||||
self.context_length = context_length
|
||||
|
||||
if isinstance(vision_layers, (tuple, list)):
|
||||
vision_heads = vision_width * 32 // 64
|
||||
self.visual = ModifiedResNet(
|
||||
layers=vision_layers,
|
||||
output_dim=embed_dim,
|
||||
heads=vision_heads,
|
||||
input_resolution=image_resolution,
|
||||
width=vision_width)
|
||||
else:
|
||||
vision_heads = vision_width // 64
|
||||
self.visual = VisualTransformer(
|
||||
input_resolution=image_resolution,
|
||||
patch_size=vision_patch_size,
|
||||
width=vision_width,
|
||||
layers=vision_layers,
|
||||
heads=vision_heads,
|
||||
output_dim=embed_dim,
|
||||
linear_patch=linear_patch,
|
||||
use_gc=use_gc)
|
||||
|
||||
self.transformer = Transformer(
|
||||
width=transformer_width,
|
||||
layers=transformer_layers,
|
||||
heads=transformer_heads,
|
||||
attn_mask=self.build_attention_mask)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
||||
self.positional_embedding = nn.Parameter(
|
||||
torch.empty(self.context_length, transformer_width))
|
||||
self.ln_final = LayerNorm(transformer_width)
|
||||
|
||||
self.text_projection = nn.Parameter(
|
||||
torch.empty(transformer_width, embed_dim))
|
||||
self.logit_scale = nn.Parameter(torch.ones([]))
|
||||
|
||||
self.initialize_parameters()
|
||||
|
||||
def initialize_parameters(self):
|
||||
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
||||
nn.init.normal_(self.positional_embedding, std=0.01)
|
||||
|
||||
if isinstance(self.visual, ModifiedResNet):
|
||||
if self.visual.attnpool is not None:
|
||||
std = self.visual.attnpool.c_proj.in_features**-0.5
|
||||
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
||||
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
||||
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
||||
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
||||
|
||||
for resnet_block in [
|
||||
self.visual.layer1, self.visual.layer2, self.visual.layer3,
|
||||
self.visual.layer4
|
||||
]:
|
||||
for name, param in resnet_block.named_parameters():
|
||||
if name.endswith('bn3.weight'):
|
||||
nn.init.zeros_(param)
|
||||
|
||||
proj_std = (self.transformer.width**-0.5) * (
|
||||
(2 * self.transformer.layers)**-0.5)
|
||||
attn_std = self.transformer.width**-0.5
|
||||
fc_std = (2 * self.transformer.width)**-0.5
|
||||
for block in self.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
|
||||
if self.text_projection is not None:
|
||||
nn.init.normal_(
|
||||
self.text_projection, std=self.transformer.width**-0.5)
|
||||
|
||||
def build_attention_mask(self, context_length):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.zeros(context_length, context_length)
|
||||
mask.fill_(float('-inf'))
|
||||
mask.triu_(1)  # keep -inf only strictly above the diagonal (causal mask)
|
||||
return mask
|
||||
|
||||
@staticmethod
|
||||
def get_config(model_dir):
|
||||
model_path = '{}/ViT-B-16.pt'.format(model_dir)
|
||||
try:
|
||||
# loading JIT archive
|
||||
model = torch.jit.load(model_path, map_location='cpu').eval()
|
||||
state_dict = model.state_dict()
|
||||
except RuntimeError:
|
||||
state_dict = torch.load(model_path, map_location='cpu')
|
||||
return state_dict
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.visual.conv1.weight.dtype
|
||||
|
||||
def encode_image_tokens(self, image, return_hidden=False):
|
||||
hidden = self.visual(image.type(self.dtype))
|
||||
hidden = self.visual.ln_post(hidden) @ self.visual.proj
|
||||
|
||||
x = hidden[:, 0, :]
|
||||
|
||||
if return_hidden:
|
||||
return x, hidden
|
||||
|
||||
return x
|
||||
|
||||
def encode_text(self, text, return_hidden=False, prompt=None):
|
||||
x = self.token_embedding(text).type(
|
||||
self.dtype) # [batch_size, n_ctx, d_model]
|
||||
if prompt:
|
||||
x = prompt(x)
|
||||
|
||||
pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype)
|
||||
x = x + pos_emd
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
hidden = self.ln_final(x).type(self.dtype) @ self.text_projection
|
||||
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)]
|
||||
|
||||
if return_hidden:
|
||||
return x, hidden
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, image, text):
|
||||
image_features = self.encode_image(image)
|
||||
text_features = self.encode_text(text)
|
||||
|
||||
# normalized features
|
||||
image_features = image_features / image_features.norm(
|
||||
dim=-1, keepdim=True)
|
||||
text_features = text_features / text_features.norm(
|
||||
dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_image = logit_scale * image_features @ text_features.t()
|
||||
logits_per_text = logit_scale * text_features @ image_features.t()
|
||||
|
||||
return logits_per_image, logits_per_text
|
||||
|
||||
|
||||
def convert_weights(model: nn.Module):
|
||||
"""Convert applicable model parameters to fp16"""
|
||||
|
||||
def _convert_weights_to_fp16(lay):
|
||||
# l = lay
|
||||
if isinstance(lay, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
|
||||
lay.weight.data = lay.weight.data.half()
|
||||
if lay.bias is not None:
|
||||
lay.bias.data = lay.bias.data.half()
|
||||
|
||||
if isinstance(lay, nn.MultiheadAttention):
|
||||
for attr in [
|
||||
*[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
|
||||
'in_proj_bias', 'bias_k', 'bias_v'
|
||||
]:
|
||||
tensor = getattr(lay, attr)
|
||||
if tensor is not None:
|
||||
tensor.data = tensor.data.half()
|
||||
|
||||
for name in ['text_projection', 'proj']:
|
||||
if hasattr(lay, name):
|
||||
attr = getattr(lay, name)
|
||||
if attr is not None:
|
||||
attr.data = attr.data.half()
|
||||
|
||||
model.apply(_convert_weights_to_fp16)
|
||||
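For orientation, a minimal sketch of how this helper is typically applied right after constructing the CLIP backbone defined above; the dimensions follow the usual ViT-B/16 configuration and are written out here purely for illustration.

clip = CLIP(
    embed_dim=512, image_resolution=224, vision_layers=12, vision_width=768,
    vision_patch_size=16, context_length=77, vocab_size=49408,
    transformer_width=512, transformer_heads=8, transformer_layers=12)
convert_weights(clip)  # Conv/Linear/MultiheadAttention weights are now fp16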
modelscope/models/multi_modal/prost/models/module_cross.py (new file, 249 lines)
@@ -0,0 +1,249 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
|
||||
import json
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .until_config import PreCrossConfig
|
||||
from .until_module import ACT2FN, LayerNorm, PreTrainedModel
|
||||
|
||||
|
||||
# PRETRAINED_MODEL_ARCHIVE_MAP = {}
|
||||
# CONFIG_NAME = 'cross_config.json'
|
||||
# WEIGHTS_NAME = 'cross_pytorch_model.bin'
|
||||
class CrossConfig(PreCrossConfig):
|
||||
"""Configuration class to store the configuration of a `CrossModel`.
|
||||
"""
|
||||
|
||||
# pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
# config_name = CONFIG_NAME
|
||||
# weights_name = WEIGHTS_NAME
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act='gelu',
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02):
|
||||
"""Constructs CrossConfig.
|
||||
|
||||
Args:
|
||||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`.
|
||||
hidden_size: Size of the encoder layers and the pooler layer.
|
||||
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads: Number of attention heads for each attention layer in
|
||||
the Transformer encoder.
|
||||
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
||||
layer in the Transformer encoder.
|
||||
hidden_act: The non-linear activation function (function or string) in the
|
||||
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
|
||||
hidden_dropout_prob: The dropout probability for all fully connected
|
||||
layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob: The dropout ratio for the attention
|
||||
probabilities.
|
||||
max_position_embeddings: The maximum sequence length that this model might
|
||||
ever be used with. Typically set this to something large just in case
|
||||
(e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
||||
`CrossModel`.
|
||||
initializer_range: The stddev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
"""
|
||||
if isinstance(vocab_size_or_config_json_file, str):
|
||||
with open(
|
||||
vocab_size_or_config_json_file, 'r',
|
||||
encoding='utf-8') as reader:
|
||||
json_config = json.loads(reader.read())
|
||||
for key, value in json_config.items():
|
||||
self.__dict__[key] = value
|
||||
elif isinstance(vocab_size_or_config_json_file, int):
|
||||
self.vocab_size = vocab_size_or_config_json_file
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
else:
|
||||
raise ValueError(
|
||||
'First argument must be either a vocabulary size (int) '
|
||||
'or the path to a pretrained model config file (str)')
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self, d_model: int, n_head: int):
|
||||
super().__init__()
|
||||
|
||||
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = nn.Sequential(
|
||||
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
|
||||
('gelu', QuickGELU()),
|
||||
('c_proj', nn.Linear(d_model * 4, d_model))]))
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
self.n_head = n_head
|
||||
|
||||
def attention(self, x: torch.Tensor, attn_mask: torch.Tensor):
|
||||
attn_mask_ = attn_mask.repeat_interleave(self.n_head, dim=0)
|
||||
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
|
||||
|
||||
def forward(self, para_tuple: tuple):
|
||||
# x: torch.Tensor, attn_mask: torch.Tensor
|
||||
# print(para_tuple)
|
||||
x, attn_mask = para_tuple
|
||||
x = x + self.attention(self.ln_1(x), attn_mask)
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return (x, attn_mask)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
def __init__(self, width: int, layers: int, heads: int):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.Sequential(
|
||||
*[ResidualAttentionBlock(width, heads) for _ in range(layers)])
|
||||
|
||||
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
|
||||
return self.resblocks((x, attn_mask))[0]
|
||||
|
||||
|
||||
class CrossEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(CrossEmbeddings, self).__init__()
|
||||
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
|
||||
config.hidden_size)
|
||||
# self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
||||
# self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, concat_embeddings, concat_type=None):
|
||||
|
||||
_, seq_length = concat_embeddings.size(0), concat_embeddings.size(1)
|
||||
# if concat_type is None:
|
||||
# concat_type = torch.zeros(batch_size, concat_type).to(concat_embeddings.device)
|
||||
|
||||
position_ids = torch.arange(
|
||||
seq_length, dtype=torch.long, device=concat_embeddings.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(
|
||||
concat_embeddings.size(0), -1)
|
||||
|
||||
# token_type_embeddings = self.token_type_embeddings(concat_type)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
|
||||
embeddings = concat_embeddings + position_embeddings # + token_type_embeddings
|
||||
# embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class CrossPooler(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(CrossPooler, self).__init__()
|
||||
self.ln_pool = LayerNorm(config.hidden_size)
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = QuickGELU()
|
||||
|
||||
def forward(self, hidden_states, hidden_mask):
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
hidden_states = self.ln_pool(hidden_states)
|
||||
pooled_output = hidden_states[:, 0]
|
||||
pooled_output = self.dense(pooled_output)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
|
||||
class CrossModel(PreTrainedModel):
|
||||
|
||||
def initialize_parameters(self):
|
||||
proj_std = (self.transformer.width**-0.5) * (
|
||||
(2 * self.transformer.layers)**-0.5)
|
||||
attn_std = self.transformer.width**-0.5
|
||||
fc_std = (2 * self.transformer.width)**-0.5
|
||||
for block in self.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
|
||||
def __init__(self, config):
|
||||
super(CrossModel, self).__init__(config)
|
||||
|
||||
self.embeddings = CrossEmbeddings(config)
|
||||
|
||||
transformer_width = config.hidden_size
|
||||
transformer_layers = config.num_hidden_layers
|
||||
transformer_heads = config.num_attention_heads
|
||||
self.transformer = Transformer(
|
||||
width=transformer_width,
|
||||
layers=transformer_layers,
|
||||
heads=transformer_heads,
|
||||
)
|
||||
self.pooler = CrossPooler(config)
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def build_attention_mask(self, attention_mask):
|
||||
extended_attention_mask = attention_mask.unsqueeze(1)
|
||||
extended_attention_mask = extended_attention_mask.to(
|
||||
dtype=self.dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -1000000.0
|
||||
extended_attention_mask = extended_attention_mask.expand(
|
||||
-1, attention_mask.size(1), -1)
|
||||
return extended_attention_mask
|
||||
|
||||
def forward(self,
|
||||
concat_input,
|
||||
concat_type=None,
|
||||
attention_mask=None,
|
||||
output_all_encoded_layers=True):
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
concat_input.size(0), concat_input.size(1))
|
||||
if concat_type is None:
|
||||
concat_type = torch.zeros_like(attention_mask)
|
||||
|
||||
extended_attention_mask = self.build_attention_mask(attention_mask)
|
||||
|
||||
embedding_output = self.embeddings(concat_input, concat_type)
|
||||
embedding_output = embedding_output.permute(1, 0, 2) # NLD -> LND
|
||||
embedding_output = self.transformer(embedding_output,
|
||||
extended_attention_mask)
|
||||
embedding_output = embedding_output.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
pooled_output = self.pooler(
|
||||
embedding_output, hidden_mask=attention_mask)
|
||||
|
||||
return embedding_output, pooled_output
|
||||
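A minimal smoke test of the cross encoder defined above. The hidden sizes, layer count, and sequence length are illustrative values rather than settings from a released ProST configuration, and the snippet assumes CrossConfig and CrossModel are in scope (i.e. it runs inside this module).

import torch

config = CrossConfig(
    vocab_size_or_config_json_file=512,  # illustrative; any int takes the keyword path
    hidden_size=768,
    num_hidden_layers=2,
    num_attention_heads=12)
model = CrossModel(config)
concat_input = torch.randn(2, 12, 768)   # (batch, fused sequence, hidden)
attention_mask = torch.ones(2, 12)       # 1 = attend, 0 = padding
seq_out, pooled = model(concat_input, attention_mask=attention_mask)
# seq_out: (2, 12, 768) token-level features; pooled: (2, 768) first-token pooling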
267
modelscope/models/multi_modal/prost/models/prost_model.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# The implementation is adopted from the CLIP4Clip implementation,
|
||||
# made publicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip
|
||||
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from os.path import exists
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Dict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.hub.file_download import http_get_file
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.multi_modal.prost.models.modeling import CLIP4Clip
|
||||
from modelscope.models.multi_modal.prost.models.tokenization_clip import \
|
||||
SimpleTokenizer as ClipTokenizer
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from ..dataloaders.rawvideo_util import RawVideoExtractor
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.text_video_retrieval, module_name=Models.prost)
|
||||
class ProSTForTVRetrieval(TorchModel):
|
||||
|
||||
def __init__(self, model_dir, **kwargs):
|
||||
super().__init__(model_dir=model_dir, **kwargs)
|
||||
# model config parameters
|
||||
with open(
|
||||
f'{model_dir}/{ModelFile.CONFIGURATION}', 'r',
|
||||
encoding='utf-8') as json_file:
|
||||
all_model_config = json.load(json_file)
|
||||
model_config = all_model_config['paras']
|
||||
|
||||
cross_model_config = all_model_config['crossbase']
|
||||
# print(all_model_config)
|
||||
# print(cross_model_config)
|
||||
model_config['model_dir'] = model_dir
|
||||
self.SPECIAL_TOKEN = {
|
||||
'CLS_TOKEN': '<|startoftext|>',
|
||||
'SEP_TOKEN': '<|endoftext|>',
|
||||
'MASK_TOKEN': '[MASK]',
|
||||
'UNK_TOKEN': '[UNK]',
|
||||
'PAD_TOKEN': '[PAD]'
|
||||
}
|
||||
self.max_words = model_config['max_words']
|
||||
self.max_frames = model_config['max_frames']
|
||||
self.feature_framerate = model_config['feature_framerate']
|
||||
self.image_resolution = 224
|
||||
if torch.cuda.is_available():
|
||||
self.device = model_config['device']
|
||||
else:
|
||||
self.device = 'cpu'
|
||||
self.init_model = f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}'
|
||||
|
||||
self.tokenizer = ClipTokenizer(model_dir)
|
||||
self.rawVideoExtractor = RawVideoExtractor(
|
||||
frame_rate=self.feature_framerate, size=self.image_resolution)
|
||||
self.local_transform = self.rawVideoExtractor.transform
|
||||
self.model = CLIP4Clip.from_pretrained(
|
||||
cross_config=cross_model_config, task_config=model_config)
|
||||
if hasattr(self.model, 'module'):
|
||||
self.model = self.model.module.to(self.device)
|
||||
else:
|
||||
self.model = self.model.to(self.device)
|
||||
if self.init_model:
|
||||
assert exists(self.init_model)
|
||||
model_state_dict = torch.load(self.init_model, map_location='cpu')
|
||||
self.model.load_state_dict(model_state_dict, strict=False)
|
||||
self.model.to(self.device)
|
||||
|
||||
def _get_text(self, caption, tokenizer, enable_zh=False):
|
||||
|
||||
if type(caption) is str:
|
||||
_caption_text, s, e = caption, None, None
|
||||
elif type(caption) is tuple:
|
||||
if len(caption) == 3:
|
||||
_caption_text, s, e = caption
|
||||
elif len(caption) == 4:
|
||||
_caption_text, s, e, pos = caption
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if isinstance(_caption_text, list):
|
||||
caption_text = random.choice(_caption_text)
|
||||
else:
|
||||
caption_text = _caption_text
|
||||
if enable_zh:
|
||||
_token = tokenizer.encode(caption_text)
|
||||
input_ids = _token.ids
|
||||
input_mask = _token.attention_mask
|
||||
segment_ids = _token.type_ids
|
||||
else:
|
||||
words = tokenizer.tokenize(caption_text)
|
||||
|
||||
words = [self.SPECIAL_TOKEN['CLS_TOKEN']] + words
|
||||
total_length_with_CLS = self.max_words - 1
|
||||
if len(words) > total_length_with_CLS:
|
||||
words = words[:total_length_with_CLS]
|
||||
words = words + [self.SPECIAL_TOKEN['SEP_TOKEN']]
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(words)
|
||||
input_mask = [1] * len(input_ids)
|
||||
segment_ids = [0] * len(input_ids)
|
||||
|
||||
while len(input_ids) < self.max_words:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
assert len(input_ids) == self.max_words
|
||||
assert len(input_mask) == self.max_words
|
||||
assert len(segment_ids) == self.max_words
|
||||
|
||||
pairs_text = np.array(input_ids)
|
||||
pairs_mask = np.array(input_mask)
|
||||
pairs_segment = np.array(segment_ids)
|
||||
|
||||
return pairs_text, pairs_mask, pairs_segment, s, e
|
||||
|
||||
def _get_rawvideo_dec(self,
|
||||
video_path,
|
||||
rawVideoExtractor,
|
||||
local_transform,
|
||||
s=None,
|
||||
e=None):
|
||||
video_mask = np.zeros(self.max_frames, dtype=int)
|
||||
max_video_length = 0
|
||||
|
||||
# T x 3 x H x W
|
||||
video = np.zeros((self.max_frames, 3, rawVideoExtractor.size,
|
||||
rawVideoExtractor.size),
|
||||
dtype=float)
|
||||
|
||||
if s is None:
|
||||
start_time, end_time = None, None
|
||||
else:
|
||||
start_time = int(s)
|
||||
end_time = int(e)
|
||||
start_time = start_time if start_time >= 0. else 0.
|
||||
end_time = end_time if end_time >= 0. else 0.
|
||||
if start_time > end_time:
|
||||
start_time, end_time = end_time, start_time
|
||||
elif start_time == end_time:
|
||||
end_time = end_time + 1
|
||||
|
||||
url_parsed = urlparse(video_path)
|
||||
if url_parsed.scheme in ('file', '') and exists(
|
||||
url_parsed.path): # Possibly a local file
|
||||
vreader = VideoReader(video_path, ctx=cpu(0))
|
||||
else:
|
||||
try:
|
||||
with TemporaryDirectory() as temporary_cache_dir:
|
||||
random_str = uuid.uuid4().hex
|
||||
http_get_file(
|
||||
url=video_path,
|
||||
local_dir=temporary_cache_dir,
|
||||
file_name=random_str,
|
||||
cookies=None)
|
||||
temp_file_path = os.path.join(temporary_cache_dir,
|
||||
random_str)
|
||||
vreader = VideoReader(temp_file_path, ctx=cpu(0))
|
||||
except Exception as ex:
|
||||
logger.error('Failed to read video input: {}'.format(ex))
|
||||
return video, video_mask
|
||||
|
||||
fps = vreader.get_avg_fps()
|
||||
f_start = 0 if start_time is None else int(start_time * fps)
|
||||
f_end = int(
|
||||
min(1000000000 if end_time is None else end_time * fps,
|
||||
len(vreader) - 1))
|
||||
num_frames = f_end - f_start + 1
|
||||
if num_frames > 0:
|
||||
# L x T x 3 x H x W
|
||||
sample_fps = int(self.feature_framerate)
|
||||
t_stride = int(round(float(fps) / sample_fps))
|
||||
|
||||
all_pos = list(range(f_start, f_end + 1, t_stride))
|
||||
if len(all_pos) > self.max_frames:
|
||||
sample_pos = [
|
||||
all_pos[_] for _ in np.linspace(
|
||||
0, len(all_pos) - 1, num=self.max_frames, dtype=int)
|
||||
]
|
||||
else:
|
||||
sample_pos = all_pos
|
||||
patch_images = [
|
||||
Image.fromarray(f)
|
||||
for f in vreader.get_batch(sample_pos).asnumpy()
|
||||
]
|
||||
patch_images = torch.stack(
|
||||
[local_transform(img) for img in patch_images])
|
||||
slice_len = patch_images.shape[0]
|
||||
max_video_length = max_video_length if max_video_length > slice_len else slice_len
|
||||
if slice_len < 1:
|
||||
pass
|
||||
else:
|
||||
video[:slice_len, ...] = patch_images
|
||||
else:
|
||||
logger.error('video path: {} error.'.format(video_path))
|
||||
|
||||
video_mask[:max_video_length] = [1] * max_video_length
|
||||
|
||||
return video, video_mask
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
from modelscope.outputs import OutputKeys
|
||||
output = {}
|
||||
|
||||
if 'video' in input and input['video'] is not None:
|
||||
video_path = input['video']
|
||||
video, video_mask = self._get_rawvideo_dec(video_path,
|
||||
self.rawVideoExtractor,
|
||||
self.local_transform)
|
||||
video = torch.unsqueeze(
|
||||
torch.from_numpy(video), dim=0).to(self.device)
|
||||
video_mask = torch.unsqueeze(
|
||||
torch.from_numpy(video_mask), dim=0).to(self.device)
|
||||
|
||||
if 'text' in input and input['text'] is not None:
|
||||
caption = input['text']
|
||||
pairs_text, pairs_mask, pairs_segment, s, e = self._get_text(
|
||||
caption, self.tokenizer, enable_zh=False)
|
||||
input_ids = torch.unsqueeze(
|
||||
torch.from_numpy(pairs_text), dim=0).to(self.device)
|
||||
input_mask = torch.unsqueeze(
|
||||
torch.from_numpy(pairs_mask), dim=0).to(self.device)
|
||||
segment_ids = torch.unsqueeze(
|
||||
torch.from_numpy(pairs_segment), dim=0).to(self.device)
|
||||
|
||||
phr_feat, sen_feat, obj_feat, eve_feat = self.model.get_sequence_visual_output(
|
||||
input_ids, segment_ids, input_mask, video, video_mask)
|
||||
|
||||
sim_espm, _, sim_oppm, _ = self.model.get_max_similarity_logits(
|
||||
phr_feat,
|
||||
sen_feat,
|
||||
obj_feat,
|
||||
eve_feat,
|
||||
input_mask,
|
||||
video_mask,
|
||||
shaped=True)
|
||||
# logger.info('sim: {}'.format(sim_espm))
|
||||
# logger.info('sim: {}'.format(sim_oppm))
|
||||
sim_tv = sim_espm + 1.5 * sim_oppm
|
||||
|
||||
# logger.info('phrase prototype: {}'.format(phr_feat.shape))
|
||||
# logger.info('sentence prototype: {}'.format(sen_feat.shape))
|
||||
# logger.info('object prototype: {}'.format(obj_feat.shape))
|
||||
# logger.info('event prototype: {}'.format(eve_feat.shape))
|
||||
output[OutputKeys.TEXTVIDEO_SIM] = sim_tv.cpu().detach().numpy()
|
||||
output[OutputKeys.PHRASE_PROTOTYPE] = phr_feat.cpu().detach().numpy()
|
||||
output[OutputKeys.SENTENCE_PROTOTYPE] = sen_feat.cpu().detach().numpy()
|
||||
output[OutputKeys.OBJECT_PROTOTYPE] = obj_feat.cpu().detach().numpy()
|
||||
output[OutputKeys.EVENT_PROTOTYPE] = eve_feat.cpu().detach().numpy()
|
||||
return output
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return inputs
|
||||
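A sketch of how this model would typically be driven through the ModelScope pipeline API once it is on the hub; the model id below is a placeholder, not a real checkpoint name.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# '<prost-model-id>' is a placeholder; substitute the published ProST model id.
tv_retrieval = pipeline(Tasks.text_video_retrieval, model='<prost-model-id>')
result = tv_retrieval({'video': 'test.mp4', 'text': 'a person is cooking'})
print(result[OutputKeys.TEXTVIDEO_SIM])   # text-video similarity score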
161
modelscope/models/multi_modal/prost/models/tokenization_clip.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# The implementation is adopted from the CLIP4Clip implementation,
|
||||
# made publicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip
|
||||
|
||||
import gzip
|
||||
import html
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
It also avoids mapping to whitespace/control characters that the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord('!'),
|
||||
ord('~') + 1)) + list(range(
|
||||
ord('¡'),
|
||||
ord('¬') + 1)) + list(range(ord('®'),
|
||||
ord('ÿ') + 1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
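A quick illustration of get_pairs, which feeds the merge loop in SimpleTokenizer.bpe below:

# Symbols are variable-length strings; "low" with an end-of-word marker
# yields the two adjacent bigrams that the BPE loop ranks against bpe_ranks.
print(get_pairs(('l', 'o', 'w</w>')))   # {('l', 'o'), ('o', 'w</w>')}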
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
class SimpleTokenizer(object):
|
||||
|
||||
def __init__(self, model_dir):
|
||||
bpe_path = '{}/bpe_simple_vocab_16e6.txt.gz'.format(model_dir)
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
|
||||
merges = merges[1:49152 - 256 - 2 + 1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
vocab = list(bytes_to_unicode().values())
|
||||
vocab = vocab + [v + '</w>' for v in vocab]
|
||||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {
|
||||
'<|startoftext|>': '<|startoftext|>',
|
||||
'<|endoftext|>': '<|endoftext|>'
|
||||
}
|
||||
self.pat = re.compile(
|
||||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||
re.IGNORECASE)
|
||||
|
||||
self.vocab = self.encoder
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token[:-1]) + (token[-1] + '</w>', )
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token + '</w>'
|
||||
|
||||
while True:
|
||||
bigram = min(
|
||||
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except Exception:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[
|
||||
i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b]
|
||||
for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token]
|
||||
for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode(
|
||||
'utf-8', errors='replace').replace('</w>', ' ')
|
||||
return text
|
||||
|
||||
def tokenize(self, text):
|
||||
tokens = []
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b]
|
||||
for b in token.encode('utf-8'))
|
||||
tokens.extend(
|
||||
bpe_token for bpe_token in self.bpe(token).split(' '))
|
||||
return tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return [self.encoder[bpe_token] for bpe_token in tokens]
|
||||
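A small usage sketch of the tokenizer above; it assumes the bpe_simple_vocab_16e6.txt.gz merges file ships inside model_dir, as __init__ expects, and the path itself is hypothetical.

tokenizer = SimpleTokenizer(model_dir='/path/to/prost_model')   # hypothetical local dir
tokens = tokenizer.tokenize('a person rides a horse on the beach')
ids = tokenizer.convert_tokens_to_ids(tokens)
roundtrip = tokenizer.decode(tokenizer.encode('a person rides a horse on the beach'))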
59
modelscope/models/multi_modal/prost/models/until_config.py
Executable file
@@ -0,0 +1,59 @@
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import absolute_import, division, print_function
import copy
import logging
import os
import shutil
import tarfile
import tempfile

import json
import torch

# from modelscope.utils.logger import get_logger
# logger = get_logger(__name__)


class PreCrossConfig(object):

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `PreCrossConfig` from a Python dictionary of parameters."""
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `PreCrossConfig` from a json file of parameters."""
        with open(json_file, 'r', encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n'
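For reference, a round-trip through the helpers above with a hypothetical minimal subclass (the real subclass used by ProST is CrossConfig in module_cross.py):

class DemoConfig(PreCrossConfig):
    # Hypothetical subclass; it only exists to satisfy from_dict's constructor call.
    def __init__(self, vocab_size_or_config_json_file=-1):
        self.vocab_size = vocab_size_or_config_json_file

cfg = DemoConfig.from_dict({'hidden_size': 768, 'num_hidden_layers': 12})
print(cfg.to_dict()['hidden_size'])   # 768
print(cfg.to_json_string())           # serialized, sorted JSON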
574
modelscope/models/multi_modal/prost/models/until_module.py
Normal file
@@ -0,0 +1,574 @@
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch BERT model."""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from modelscope.models.multi_modal.prost.models.until_config import \
|
||||
PreCrossConfig
|
||||
|
||||
|
||||
def gelu(x):
|
||||
"""Implementation of the gelu activation function.
|
||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
||||
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
ACT2FN = {'gelu': gelu, 'relu': torch.nn.functional.relu, 'swish': swish}
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-12):
|
||||
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
||||
"""
|
||||
super(LayerNorm, self).__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.bias = nn.Parameter(torch.zeros(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
return self.weight * x + self.bias
|
||||
|
||||
|
||||
class CrossEn(nn.Module):
|
||||
|
||||
def __init__(self, config=None):
|
||||
super(CrossEn, self).__init__()
|
||||
|
||||
def forward(self, sim_matrix):
|
||||
logpt = F.log_softmax(sim_matrix, dim=-1)
|
||||
logpt = torch.diag(logpt)
|
||||
nce_loss = -logpt
|
||||
sim_loss = nce_loss.mean()
|
||||
return sim_loss
|
||||
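CrossEn treats the diagonal of a text-video similarity matrix as the positive pairs; a quick sanity check on random scores:

import torch

sim_matrix = torch.randn(4, 4)    # rows: texts, columns: videos in the batch
loss = CrossEn()(sim_matrix)      # mean of -log_softmax over the diagonal
print(float(loss))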
|
||||
|
||||
class AllGather(torch.autograd.Function):
|
||||
"""An autograd function that performs allgather on a tensor."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, tensor, args):
|
||||
if args.world_size == 1:
|
||||
ctx.rank = args.local_rank
|
||||
ctx.batch_size = tensor.shape[0]
|
||||
return tensor
|
||||
else:
|
||||
output = [torch.empty_like(tensor) for _ in range(args.world_size)]
|
||||
torch.distributed.all_gather(output, tensor)
|
||||
ctx.rank = args.local_rank
|
||||
ctx.batch_size = tensor.shape[0]
|
||||
return torch.cat(output, dim=0)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return (
|
||||
grad_output[ctx.batch_size * ctx.rank:ctx.batch_size
|
||||
* (ctx.rank + 1)],
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class AllGather2(torch.autograd.Function):
|
||||
"""An autograd function that performs allgather on a tensor."""
|
||||
# https://github.com/PyTorchLightning/lightning-bolts/blob/8d3fbf7782e3d3937ab8a1775a7092d7567f2933/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
|
||||
@staticmethod
|
||||
def forward(ctx, tensor, args):
|
||||
if args.world_size == 1:
|
||||
ctx.rank = args.local_rank
|
||||
ctx.batch_size = tensor.shape[0]
|
||||
return tensor
|
||||
else:
|
||||
output = [torch.empty_like(tensor) for _ in range(args.world_size)]
|
||||
torch.distributed.all_gather(output, tensor)
|
||||
ctx.rank = args.local_rank
|
||||
ctx.batch_size = tensor.shape[0]
|
||||
return torch.cat(output, dim=0)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_input = grad_output.clone()
|
||||
torch.distributed.all_reduce(
|
||||
grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
|
||||
return (grad_input[ctx.rank * ctx.batch_size:(ctx.rank + 1)
|
||||
* ctx.batch_size], None)
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(PreTrainedModel, self).__init__()
|
||||
if not isinstance(config, PreCrossConfig):
|
||||
raise ValueError(
|
||||
'Parameter config in `{}(config)` should be an instance of class `PreCrossConfig`. '
|
||||
'To create a model from a Google pretrained model use '
|
||||
'`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`'.format(
|
||||
self.__class__.__name__, self.__class__.__name__))
|
||||
self.config = config
|
||||
|
||||
def init_weights(self, module):
|
||||
""" Initialize the weights.
|
||||
"""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(
|
||||
mean=0.0, std=self.config.initializer_range)
|
||||
elif isinstance(module, LayerNorm):
|
||||
if 'beta' in dir(module) and 'gamma' in dir(module):
|
||||
module.beta.data.zero_()
|
||||
module.gamma.data.fill_(1.0)
|
||||
else:
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens=None):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if 'gamma' in key:
|
||||
new_key = key.replace('gamma', 'weight')
|
||||
if 'beta' in key:
|
||||
new_key = key.replace('beta', 'bias')
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
if prefix is not None:
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
old_keys.append(key)
|
||||
new_keys.append(prefix + key)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module, prefix=''):
|
||||
local_metadata = {} if metadata is None else metadata.get(
|
||||
prefix[:-1], {})
|
||||
module._load_from_state_dict(state_dict, prefix, local_metadata,
|
||||
True, missing_keys, unexpected_keys,
|
||||
error_msgs)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + '.')
|
||||
|
||||
load(model, prefix='')
|
||||
|
||||
# if prefix is None and (task_config is None or task_config.local_rank == 0):
|
||||
# logger.info("-" * 20)
|
||||
# if len(missing_keys) > 0:
|
||||
# logger.info("Weights of {} not initialized from pretrained model: {}"
|
||||
# .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
|
||||
# if len(unexpected_keys) > 0:
|
||||
# logger.info("Weights from pretrained model not used in {}: {}"
|
||||
# .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
|
||||
# if len(error_msgs) > 0:
|
||||
# logger.error("Weights from pretrained model cause errors in {}: {}"
|
||||
# .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
|
||||
|
||||
return model
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""
|
||||
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
||||
"""
|
||||
try:
|
||||
return next(self.parameters()).dtype
|
||||
except StopIteration:
|
||||
# For nn.DataParallel compatibility in PyTorch 1.5
|
||||
def find_tensor_attributes(module: nn.Module):
|
||||
tuples = [(k, v) for k, v in module.__dict__.items()
|
||||
if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
gen = self._named_members(get_members_fn=find_tensor_attributes)
|
||||
first_tuple = next(gen)
|
||||
return first_tuple[1].dtype
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
"""
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None:
|
||||
return model
|
||||
model = cls.init_preweight(model, state_dict)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class PatchShiftModule(nn.Module):
|
||||
|
||||
def __init__(self, net, video_frame, n_div):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
self.video_frame = video_frame
|
||||
self.n_div = n_div
|
||||
|
||||
def forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_padding_mask=None,
|
||||
need_weights=True,
|
||||
attn_mask=None):
|
||||
# here q == k == v, psm means patch shift output
|
||||
x = query # shape here is LND, not NLD (50, 384, 768)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
patch_len = x.shape[-2]
|
||||
fold = patch_len // self.n_div
|
||||
x = x.reshape(-1, self.video_frame, x.shape[-2],
|
||||
x.shape[-1]) # shape = [bs, frame, grid ** 2, width]
|
||||
psm = torch.zeros_like(x) # shape = [bs, frame, grid ** 2, width]
|
||||
psm[:, :, :, :] = x[:, :, :, :]
|
||||
lshift_indices = torch.arange(start=1, end=patch_len, step=fold)
|
||||
psm[:, 1:, lshift_indices, :] = x[:, :-1,
|
||||
lshift_indices, :] # f_t = f_t-1
|
||||
rshift_indices = torch.arange(start=1 + 3, end=patch_len, step=fold)
|
||||
psm[:, :-1, rshift_indices, :] = x[:, 1:,
|
||||
rshift_indices, :] # f_t = f_t+1
|
||||
x = psm.reshape(-1, patch_len, x.shape[-1])
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
|
||||
return self.net(
|
||||
x, x, x, need_weights=need_weights, attn_mask=attn_mask)
|
||||
|
||||
|
||||
def make_patch_shift(net, video_frame=12, shift_layers=4, n_div=7):
|
||||
'''
|
||||
Args:
|
||||
net: CLIP
|
||||
video_frame: number of video frames; must be predefined here
|
||||
shift_layers: layers to be shifted
n_div: divisor used to compute the stride (patch_len // n_div) between shifted patch positions
|
||||
'''
|
||||
|
||||
def make_trans_patch_shift(stage, shift_layers):
|
||||
blocks = list(stage.children())
|
||||
for i, b in enumerate(blocks):
|
||||
if i >= 10 and i <= 11:
|
||||
blocks[i].attn = PatchShiftModule(
|
||||
b.attn, video_frame=video_frame, n_div=n_div)
|
||||
return nn.Sequential(*blocks)
|
||||
|
||||
net.clip.visual.transformer.resblocks = make_trans_patch_shift(
|
||||
net.clip.visual.transformer.resblocks, shift_layers=shift_layers)
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
class Event_Layer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation='relu',
|
||||
normalize_before=False,
|
||||
is_weights=False):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
self.self_attn_vis = nn.MultiheadAttention(
|
||||
d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(
|
||||
d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.norm4 = nn.LayerNorm(d_model)
|
||||
self.norm5 = nn.LayerNorm(d_model)
|
||||
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = nn.ReLU(inplace=True)
|
||||
self.normalize_before = normalize_before
|
||||
self.is_weights = is_weights
|
||||
|
||||
def forward(self, tgt, memory, pos=None, query_pos=None):
|
||||
|
||||
tgt = self.norm1(tgt)
|
||||
memory = self.norm2(memory)
|
||||
tgt = self.self_attn(tgt, tgt, tgt)[0]
|
||||
tgt = self.norm3(tgt)
|
||||
|
||||
tgt2, atten_weights = self.multihead_attn(tgt, memory, memory)
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm4(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm5(tgt)
|
||||
|
||||
return tgt, atten_weights
|
||||
|
||||
|
||||
def adaptive_mask(aa, bb, ada_para):
|
||||
tensor = torch.zeros((aa, bb))
|
||||
adaptive_num = int(bb * ada_para)
|
||||
cc = int(bb / aa)
|
||||
for i in range(aa):
|
||||
start_col = i * cc
|
||||
end_col = start_col + cc + adaptive_num
|
||||
if end_col > bb - 1:
|
||||
tmp = end_col - (bb - 1)
|
||||
start_col = start_col - tmp
|
||||
if start_col < 0:
|
||||
start_col = 0
|
||||
end_col = bb
|
||||
tensor[i, start_col:end_col] = 1
|
||||
tensor = ~tensor.bool()
|
||||
return tensor
|
||||
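adaptive_mask marks blocked positions with True: each of the aa query rows sees a contiguous window of roughly bb / aa + ada_para * bb memory positions, clamped to the right edge. A tiny example:

mask = adaptive_mask(4, 12, ada_para=0.2)   # 4 queries attending over 12 memory slots
print((~mask).int())                        # 1 marks the visible window per query row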
|
||||
|
||||
class Frame_Layer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=2048,
|
||||
para=1.0,
|
||||
dropout=0.1,
|
||||
activation='relu',
|
||||
normalize_before=False,
|
||||
is_weights=False):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
self.self_attn_vis = nn.MultiheadAttention(
|
||||
d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(
|
||||
d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.norm4 = nn.LayerNorm(d_model)
|
||||
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = nn.ReLU(inplace=True)
|
||||
self.normalize_before = normalize_before
|
||||
self.is_weights = is_weights
|
||||
self.mask_para = para
|
||||
|
||||
def forward(self, tgt, memory, pos=None, query_pos=None):
|
||||
tgt = self.norm1(tgt)
|
||||
memory = self.norm2(memory)
|
||||
mask_new = adaptive_mask(tgt.shape[0], memory.shape[0], ada_para=0.2)
|
||||
tgt2, atten_weights = self.multihead_attn(
|
||||
tgt, memory, memory, attn_mask=mask_new.cuda())
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
|
||||
tgt = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm4(tgt)
|
||||
|
||||
return tgt, atten_weights
|
||||
|
||||
|
||||
class TransDecoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
decoder_layer,
|
||||
num_layers,
|
||||
norm=None,
|
||||
return_intermediate=False):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
self.return_intermediate = return_intermediate
|
||||
|
||||
def forward(self, tgt, memory, pos=None, query_pos=None):
|
||||
output = tgt
|
||||
|
||||
intermediate = []
|
||||
all_weights = []
|
||||
|
||||
for layer in self.layers:
|
||||
output, weights = layer(
|
||||
output, memory, pos=pos, query_pos=query_pos)
|
||||
if self.return_intermediate:
|
||||
intermediate.append(self.norm(output))
|
||||
all_weights.append(weights)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
if self.return_intermediate:
|
||||
intermediate.pop()
|
||||
intermediate.append(output)
|
||||
|
||||
if self.return_intermediate:
|
||||
return torch.stack(intermediate), torch.stack(all_weights)
|
||||
return output.unsqueeze(0)
|
||||
|
||||
|
||||
class Event_decoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
num_attris=3,
|
||||
layers=1,
|
||||
heads=1,
|
||||
dim_ftr=512,
|
||||
pos_emb=False,
|
||||
length=1,
|
||||
dim_feedforward=512,
|
||||
without_init=False):
|
||||
super().__init__()
|
||||
embedding_dim = dim_ftr
|
||||
|
||||
d_model = dim_ftr
|
||||
dim_feedforward = dim_feedforward
|
||||
|
||||
self.V = nn.Parameter(
|
||||
torch.Tensor(num_attris, dim_feedforward), requires_grad=True)
|
||||
nn.init.xavier_uniform_(self.V)
|
||||
decoder_layer = Event_Layer(
|
||||
d_model=d_model, nhead=heads, dim_feedforward=dim_feedforward)
|
||||
self.event_decoder = TransDecoder(
|
||||
decoder_layer,
|
||||
layers,
|
||||
nn.LayerNorm(d_model),
|
||||
return_intermediate=True)
|
||||
self.use_pos_enc = pos_emb
|
||||
|
||||
if self.use_pos_enc:
|
||||
self.position_encoding_pre = positionalencoding2d(
|
||||
embedding_dim, 14, 14).unsqueeze(0)
|
||||
|
||||
def forward(self, features):
|
||||
batch_size = features.shape[0]
|
||||
if self.use_pos_enc: # False
|
||||
pos_encoding = self.position_encoding_pre(
|
||||
features,
|
||||
torch.zeros(features.shape[0], 14, 14,
|
||||
dtype=torch.bool).cuda())
|
||||
features = features + pos_encoding
|
||||
|
||||
enco_others = features.permute(1, 0, 2)
|
||||
h_attr = self.V
|
||||
h_attr_batch = h_attr.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
h_attr_batch = h_attr_batch.permute(1, 0, 2)
|
||||
|
||||
hs, _ = self.event_decoder(h_attr_batch, enco_others)
|
||||
hs = hs[-1].permute(1, 0, 2)
|
||||
return hs
|
||||
|
||||
|
||||
class Frame_decoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
num_attris=3,
|
||||
layers=1,
|
||||
heads=1,
|
||||
dim_ftr=512,
|
||||
pos_emb=False,
|
||||
length=1,
|
||||
dim_feedforward=512,
|
||||
without_init=False):
|
||||
super().__init__()
|
||||
embedding_dim = dim_ftr
|
||||
d_model = dim_ftr
|
||||
dim_feedforward = dim_feedforward
|
||||
|
||||
self.V = nn.Parameter(
|
||||
torch.Tensor(num_attris, dim_feedforward), requires_grad=True)
|
||||
nn.init.xavier_uniform_(self.V)
|
||||
decoder_layer = Frame_Layer(
|
||||
d_model=d_model, nhead=heads, dim_feedforward=dim_feedforward)
|
||||
self.event_decoder = TransDecoder(
|
||||
decoder_layer,
|
||||
layers,
|
||||
nn.LayerNorm(d_model),
|
||||
return_intermediate=True)
|
||||
self.use_pos_enc = pos_emb
|
||||
|
||||
if self.use_pos_enc:
|
||||
self.position_encoding_pre = positionalencoding2d(
|
||||
embedding_dim, 14, 14).unsqueeze(0)
|
||||
|
||||
def forward(self, features):
|
||||
batch_size = features.shape[0]
|
||||
if self.use_pos_enc:
|
||||
pos_encoding = self.position_encoding_pre(
|
||||
features,
|
||||
torch.zeros(features.shape[0], 14, 14,
|
||||
dtype=torch.bool).cuda())
|
||||
features = features + pos_encoding
|
||||
|
||||
enco_others = features.permute(1, 0, 2)
|
||||
h_attr = self.V
|
||||
h_attr_batch = h_attr.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
h_attr_batch = h_attr_batch.permute(1, 0, 2)
|
||||
|
||||
hs, _ = self.event_decoder(h_attr_batch, enco_others)
|
||||
hs = hs[-1].permute(1, 0, 2)
|
||||
|
||||
return hs
|
||||
@@ -33,10 +33,12 @@ from .backbone import MsModelMixin
|
||||
def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]],
|
||||
max_length: int, tokenizer):
|
||||
system_prompt = f'<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n'
|
||||
system_ids = tokenizer(system_prompt, return_tensors='pt').input_ids
|
||||
system_ids = tokenizer(
|
||||
system_prompt, add_special_tokens=False, return_tensors='pt').input_ids
|
||||
|
||||
text_prompt = f'{text.strip()} [/INST]'
|
||||
text_ids = tokenizer(text_prompt, return_tensors='pt').input_ids
|
||||
text_ids = tokenizer(
|
||||
text_prompt, add_special_tokens=False, return_tensors='pt').input_ids
|
||||
|
||||
prompt_length = system_ids.shape[-1] + text_ids.shape[-1]
|
||||
if prompt_length > max_length:
|
||||
@@ -51,7 +53,9 @@ def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]],
|
||||
assert isinstance(user, str)
|
||||
assert isinstance(bot, str)
|
||||
round_prompt = f'{user.strip()} [/INST] {bot.strip()} </s><s>[INST] '
|
||||
round_ids = tokenizer(round_prompt, return_tensors='pt').input_ids
|
||||
round_ids = tokenizer(
|
||||
round_prompt, add_special_tokens=False,
|
||||
return_tensors='pt').input_ids
|
||||
if prompt_length + round_ids.shape[-1] > max_length:
|
||||
# excess history should not be appended to the prompt
|
||||
break
|
||||
|
||||
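The add_special_tokens=False changes above matter because the Llama-2 chat template already spells out <s> and </s> inside the prompt strings, so letting the tokenizer prepend its own BOS would duplicate it and inflate prompt_length. A minimal sketch, assuming a Llama-2-style tokenizer is available locally at model_dir:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_dir)   # hypothetical local checkpoint
prompt = '<s>[INST] <<SYS>>\nYou are helpful.\n<</SYS>>\n\nHello [/INST]'
with_bos = tokenizer(prompt, return_tensors='pt').input_ids
without_bos = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids
print(with_bos.shape[-1] - without_bos.shape[-1])      # the extra BOS token(s)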
@@ -26,9 +26,9 @@ class PolyLMForTextGeneration(TorchModel, StreamingOutputMixin):
|
||||
"""
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_dir, use_fast=False)
|
||||
model_dir, legacy=False, use_fast=False)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_dir, device_map='auto')
|
||||
model_dir, device_map='auto', trust_remote_code=True)
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, input: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
|
||||
|
||||
@@ -113,7 +113,7 @@ class ReferringVideoObjectSegmentationDataset(TorchCustomDataset):
|
||||
instance_masks = instance_masks[np.newaxis, ...]
|
||||
instance_masks = torch.tensor(instance_masks).transpose(1, 2)
|
||||
mask_rles = [encode(mask) for mask in instance_masks.numpy()]
|
||||
mask_areas = area(mask_rles).astype(np.float)
|
||||
mask_areas = area(mask_rles).astype(float)
|
||||
f.close()
|
||||
|
||||
# create the target dict for the center frame:
|
||||
|
||||
0
modelscope/ops/human_image_generation/__init__.py
Normal file
118
modelscope/ops/human_image_generation/fused_act.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
module_path = os.path.dirname(__file__)
|
||||
fused = load(
|
||||
'fused',
|
||||
sources=[
|
||||
os.path.join(module_path, 'fused_bias_act.cpp'),
|
||||
os.path.join(module_path, 'fused_bias_act_kernel.cu'),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class FusedLeakyReLUFunctionBackward(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, out, bias, negative_slope, scale):
|
||||
ctx.save_for_backward(out)
|
||||
ctx.negative_slope = negative_slope
|
||||
ctx.scale = scale
|
||||
|
||||
empty = grad_output.new_empty(0)
|
||||
|
||||
grad_input = fused.fused_bias_act(grad_output, empty, out, 3, 1,
|
||||
negative_slope, scale)
|
||||
|
||||
dim = [0]
|
||||
|
||||
if grad_input.ndim > 2:
|
||||
dim += list(range(2, grad_input.ndim))
|
||||
|
||||
if bias:
|
||||
grad_bias = grad_input.sum(dim).detach()
|
||||
|
||||
else:
|
||||
grad_bias = empty
|
||||
|
||||
return grad_input, grad_bias
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gradgrad_input, gradgrad_bias):
|
||||
out, = ctx.saved_tensors
|
||||
gradgrad_out = fused.fused_bias_act(gradgrad_input, gradgrad_bias, out,
|
||||
3, 1, ctx.negative_slope,
|
||||
ctx.scale)
|
||||
|
||||
return gradgrad_out, None, None, None, None
|
||||
|
||||
|
||||
class FusedLeakyReLUFunction(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, bias, negative_slope, scale):
|
||||
empty = input.new_empty(0)
|
||||
|
||||
ctx.bias = bias is not None
|
||||
|
||||
if bias is None:
|
||||
bias = empty
|
||||
|
||||
out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope,
|
||||
scale)
|
||||
ctx.save_for_backward(out)
|
||||
ctx.negative_slope = negative_slope
|
||||
ctx.scale = scale
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
out, = ctx.saved_tensors
|
||||
|
||||
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
|
||||
grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale)
|
||||
|
||||
if not ctx.bias:
|
||||
grad_bias = None
|
||||
|
||||
return grad_input, grad_bias, None, None
|
||||
|
||||
|
||||
class FusedLeakyReLU(nn.Module):
|
||||
|
||||
def __init__(self, channel, bias=True, negative_slope=0.2, scale=2**0.5):
|
||||
super().__init__()
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(channel))
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.negative_slope = negative_slope
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, input):
|
||||
return fused_leaky_relu(input, self.bias, self.negative_slope,
|
||||
self.scale)
|
||||
|
||||
|
||||
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
|
||||
if input.device.type == 'cpu':
|
||||
if bias is not None:
|
||||
rest_dim = [1] * (input.ndim - bias.ndim - 1)
|
||||
return (F.leaky_relu(
|
||||
input + bias.view(1, bias.shape[0], *rest_dim),
|
||||
negative_slope=negative_slope) * scale)
|
||||
|
||||
else:
|
||||
return F.leaky_relu(input, negative_slope=negative_slope) * scale
|
||||
|
||||
else:
|
||||
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
|
||||
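A minimal sketch of the CPU fallback in fused_leaky_relu above; the CUDA branch requires building the fused extension, so this only exercises the pure-PyTorch path.

import torch

x = torch.randn(2, 8, 4, 4)      # NCHW activations on the CPU
bias = torch.zeros(8)            # one bias value per channel
y = fused_leaky_relu(x, bias)    # leaky_relu(x + bias) scaled by sqrt(2)
print(y.shape)                   # torch.Size([2, 8, 4, 4])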
21
modelscope/ops/human_image_generation/fused_bias_act.cpp
Normal file
@@ -0,0 +1,21 @@
#include <torch/extension.h>


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                                int act, int grad, float alpha, float scale);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
                             int act, int grad, float alpha, float scale) {
    CHECK_CUDA(input);
    CHECK_CUDA(bias);

    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
||||
//
|
||||
// This work is made available under the Nvidia Source Code License-NC.
|
||||
// To view a copy of this license, visit
|
||||
// https://nvlabs.github.io/stylegan2/license.html
|
||||
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
|
||||
template <typename scalar_t>
|
||||
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
|
||||
int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
|
||||
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
|
||||
|
||||
scalar_t zero = 0.0;
|
||||
|
||||
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
|
||||
scalar_t x = p_x[xi];
|
||||
|
||||
if (use_bias) {
|
||||
x += p_b[(xi / step_b) % size_b];
|
||||
}
|
||||
|
||||
scalar_t ref = use_ref ? p_ref[xi] : zero;
|
||||
|
||||
scalar_t y;
|
||||
|
||||
switch (act * 10 + grad) {
|
||||
default:
|
||||
case 10: y = x; break;
|
||||
case 11: y = x; break;
|
||||
case 12: y = 0.0; break;
|
||||
|
||||
case 30: y = (x > 0.0) ? x : x * alpha; break;
|
||||
case 31: y = (ref > 0.0) ? x : x * alpha; break;
|
||||
case 32: y = 0.0; break;
|
||||
}
|
||||
|
||||
out[xi] = y * scale;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
||||
int act, int grad, float alpha, float scale) {
|
||||
int curDevice = -1;
|
||||
cudaGetDevice(&curDevice);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
||||
|
||||
auto x = input.contiguous();
|
||||
auto b = bias.contiguous();
|
||||
auto ref = refer.contiguous();
|
||||
|
||||
int use_bias = b.numel() ? 1 : 0;
|
||||
int use_ref = ref.numel() ? 1 : 0;
|
||||
|
||||
int size_x = x.numel();
|
||||
int size_b = b.numel();
|
||||
int step_b = 1;
|
||||
|
||||
for (int i = 1 + 1; i < x.dim(); i++) {
|
||||
step_b *= x.size(i);
|
||||
}
|
||||
|
||||
int loop_x = 4;
|
||||
int block_size = 4 * 32;
|
||||
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
|
||||
|
||||
auto y = torch::empty_like(x);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
|
||||
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
||||
y.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
b.data_ptr<scalar_t>(),
|
||||
ref.data_ptr<scalar_t>(),
|
||||
act,
|
||||
grad,
|
||||
alpha,
|
||||
scale,
|
||||
loop_x,
|
||||
size_x,
|
||||
step_b,
|
||||
size_b,
|
||||
use_bias,
|
||||
use_ref
|
||||
);
|
||||
});
|
||||
|
||||
return y;
|
||||
}
|
||||
23
modelscope/ops/human_image_generation/upfirdn2d.cpp
Normal file
@@ -0,0 +1,23 @@
#include <torch/extension.h>


torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
                           int up_x, int up_y, int down_x, int down_y,
                           int pad_x0, int pad_x1, int pad_y0, int pad_y1);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    CHECK_CUDA(input);
    CHECK_CUDA(kernel);

    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
208
modelscope/ops/human_image_generation/upfirdn2d.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import os
|
||||
from collections import abc
|
||||
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
module_path = os.path.dirname(__file__)
|
||||
upfirdn2d_op = load(
|
||||
'upfirdn2d',
|
||||
sources=[
|
||||
os.path.join(module_path, 'upfirdn2d.cpp'),
|
||||
os.path.join(module_path, 'upfirdn2d_kernel.cu'),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class UpFirDn2dBackward(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
|
||||
in_size, out_size):
|
||||
|
||||
up_x, up_y = up
|
||||
down_x, down_y = down
|
||||
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
|
||||
|
||||
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
|
||||
|
||||
grad_input = upfirdn2d_op.upfirdn2d(
|
||||
grad_output,
|
||||
grad_kernel,
|
||||
down_x,
|
||||
down_y,
|
||||
up_x,
|
||||
up_y,
|
||||
g_pad_x0,
|
||||
g_pad_x1,
|
||||
g_pad_y0,
|
||||
g_pad_y1,
|
||||
)
|
||||
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
|
||||
in_size[3])
|
||||
|
||||
ctx.save_for_backward(kernel)
|
||||
|
||||
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
||||
|
||||
ctx.up_x = up_x
|
||||
ctx.up_y = up_y
|
||||
ctx.down_x = down_x
|
||||
ctx.down_y = down_y
|
||||
ctx.pad_x0 = pad_x0
|
||||
ctx.pad_x1 = pad_x1
|
||||
ctx.pad_y0 = pad_y0
|
||||
ctx.pad_y1 = pad_y1
|
||||
ctx.in_size = in_size
|
||||
ctx.out_size = out_size
|
||||
|
||||
return grad_input
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gradgrad_input):
|
||||
kernel, = ctx.saved_tensors
|
||||
|
||||
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
|
||||
ctx.in_size[3], 1)
|
||||
|
||||
gradgrad_out = upfirdn2d_op.upfirdn2d(
|
||||
gradgrad_input,
|
||||
kernel,
|
||||
ctx.up_x,
|
||||
ctx.up_y,
|
||||
ctx.down_x,
|
||||
ctx.down_y,
|
||||
ctx.pad_x0,
|
||||
ctx.pad_x1,
|
||||
ctx.pad_y0,
|
||||
ctx.pad_y1,
|
||||
)
|
||||
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
|
||||
gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
|
||||
ctx.out_size[0], ctx.out_size[1])
|
||||
|
||||
return gradgrad_out, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class UpFirDn2d(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, kernel, up, down, pad):
|
||||
up_x, up_y = up
|
||||
down_x, down_y = down
|
||||
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
||||
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
batch, channel, in_h, in_w = input.shape
|
||||
ctx.in_size = input.shape
|
||||
|
||||
input = input.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
|
||||
|
||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
|
||||
ctx.out_size = (out_h, out_w)
|
||||
|
||||
ctx.up = (up_x, up_y)
|
||||
ctx.down = (down_x, down_y)
|
||||
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
|
||||
|
||||
g_pad_x0 = kernel_w - pad_x0 - 1
|
||||
g_pad_y0 = kernel_h - pad_y0 - 1
|
||||
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
|
||||
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
|
||||
|
||||
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
|
||||
|
||||
out = upfirdn2d_op.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y,
|
||||
pad_x0, pad_x1, pad_y0, pad_y1)
|
||||
# out = out.view(major, out_h, out_w, minor)
|
||||
out = out.view(-1, channel, out_h, out_w)
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
kernel, grad_kernel = ctx.saved_tensors
|
||||
|
||||
grad_input = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = UpFirDn2dBackward.apply(
|
||||
grad_output,
|
||||
kernel,
|
||||
grad_kernel,
|
||||
ctx.up,
|
||||
ctx.down,
|
||||
ctx.pad,
|
||||
ctx.g_pad,
|
||||
ctx.in_size,
|
||||
ctx.out_size,
|
||||
)
|
||||
|
||||
return grad_input, None, None, None, None
|
||||
|
||||
|
||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
||||
if not isinstance(up, abc.Iterable):
|
||||
up = (up, up)
|
||||
|
||||
if not isinstance(down, abc.Iterable):
|
||||
down = (down, down)
|
||||
|
||||
if len(pad) == 2:
|
||||
pad = (pad[0], pad[1], pad[0], pad[1])
|
||||
|
||||
if input.device.type == 'cpu':
|
||||
out = upfirdn2d_native(input, kernel, *up, *down, *pad)
|
||||
|
||||
else:
|
||||
out = UpFirDn2d.apply(input, kernel, up, down, pad)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
|
||||
pad_y0, pad_y1):
|
||||
_, channel, in_h, in_w = input.shape
|
||||
input = input.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
_, in_h, in_w, minor = input.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
||||
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
||||
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
||||
|
||||
out = F.pad(
|
||||
out,
|
||||
[0, 0,
|
||||
max(pad_x0, 0),
|
||||
max(pad_x1, 0),
|
||||
max(pad_y0, 0),
|
||||
max(pad_y1, 0)])
|
||||
out = out[:,
|
||||
max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
|
||||
|
||||
out = out.permute(0, 3, 1, 2)
|
||||
out = out.reshape(
|
||||
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(
|
||||
-1,
|
||||
minor,
|
||||
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
||||
)
|
||||
out = out.permute(0, 2, 3, 1)
|
||||
out = out[:, ::down_y, ::down_x, :]
|
||||
|
||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
|
||||
|
||||
return out.view(-1, channel, out_h, out_w)
|
||||
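As a quick sanity check of the `upfirdn2d` helper defined in this file (a sketch with an arbitrary 3x3 box kernel; `pad=(1, 1)` keeps the spatial size when up=down=1, and running on CPU exercises `upfirdn2d_native`):

import torch

img = torch.randn(1, 3, 32, 32)                       # NCHW input
kern = torch.ones(3, 3) / 9.0                         # 3x3 box filter
out = upfirdn2d(img, kern, up=1, down=1, pad=(1, 1))  # out_h = (32 + 1 + 1 - 3 + 1) // 1 = 32
print(out.shape)                                      # torch.Size([1, 3, 32, 32])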
369  modelscope/ops/human_image_generation/upfirdn2d_kernel.cu  Normal file
@@ -0,0 +1,369 @@
|
||||
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
||||
//
|
||||
// This work is made available under the Nvidia Source Code License-NC.
|
||||
// To view a copy of this license, visit
|
||||
// https://nvlabs.github.io/stylegan2/license.html
|
||||
|
||||
#include <torch/types.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
|
||||
int c = a / b;
|
||||
|
||||
if (c * b > a) {
|
||||
c--;
|
||||
}
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
struct UpFirDn2DKernelParams {
|
||||
int up_x;
|
||||
int up_y;
|
||||
int down_x;
|
||||
int down_y;
|
||||
int pad_x0;
|
||||
int pad_x1;
|
||||
int pad_y0;
|
||||
int pad_y1;
|
||||
|
||||
int major_dim;
|
||||
int in_h;
|
||||
int in_w;
|
||||
int minor_dim;
|
||||
int kernel_h;
|
||||
int kernel_w;
|
||||
int out_h;
|
||||
int out_w;
|
||||
int loop_major;
|
||||
int loop_x;
|
||||
};
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
|
||||
const scalar_t *kernel,
|
||||
const UpFirDn2DKernelParams p) {
|
||||
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int out_y = minor_idx / p.minor_dim;
|
||||
minor_idx -= out_y * p.minor_dim;
|
||||
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
|
||||
int major_idx_base = blockIdx.z * p.loop_major;
|
||||
|
||||
if (out_x_base >= p.out_w || out_y >= p.out_h ||
|
||||
major_idx_base >= p.major_dim) {
|
||||
return;
|
||||
}
|
||||
|
||||
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
|
||||
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
|
||||
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
|
||||
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
|
||||
|
||||
for (int loop_major = 0, major_idx = major_idx_base;
|
||||
loop_major < p.loop_major && major_idx < p.major_dim;
|
||||
loop_major++, major_idx++) {
|
||||
for (int loop_x = 0, out_x = out_x_base;
|
||||
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
|
||||
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
|
||||
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
|
||||
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
|
||||
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
|
||||
|
||||
const scalar_t *x_p =
|
||||
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
|
||||
minor_idx];
|
||||
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
|
||||
int x_px = p.minor_dim;
|
||||
int k_px = -p.up_x;
|
||||
int x_py = p.in_w * p.minor_dim;
|
||||
int k_py = -p.up_y * p.kernel_w;
|
||||
|
||||
scalar_t v = 0.0f;
|
||||
|
||||
for (int y = 0; y < h; y++) {
|
||||
for (int x = 0; x < w; x++) {
|
||||
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
|
||||
x_p += x_px;
|
||||
k_p += k_px;
|
||||
}
|
||||
|
||||
x_p += x_py - w * x_px;
|
||||
k_p += k_py - w * k_px;
|
||||
}
|
||||
|
||||
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
|
||||
minor_idx] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
|
||||
int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
|
||||
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
|
||||
const scalar_t *kernel,
|
||||
const UpFirDn2DKernelParams p) {
|
||||
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
|
||||
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
|
||||
|
||||
__shared__ volatile float sk[kernel_h][kernel_w];
|
||||
__shared__ volatile float sx[tile_in_h][tile_in_w];
|
||||
|
||||
int minor_idx = blockIdx.x;
|
||||
int tile_out_y = minor_idx / p.minor_dim;
|
||||
minor_idx -= tile_out_y * p.minor_dim;
|
||||
tile_out_y *= tile_out_h;
|
||||
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
|
||||
int major_idx_base = blockIdx.z * p.loop_major;
|
||||
|
||||
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
|
||||
major_idx_base >= p.major_dim) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
|
||||
tap_idx += blockDim.x) {
|
||||
int ky = tap_idx / kernel_w;
|
||||
int kx = tap_idx - ky * kernel_w;
|
||||
scalar_t v = 0.0;
|
||||
|
||||
if (kx < p.kernel_w & ky < p.kernel_h) {
|
||||
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
|
||||
}
|
||||
|
||||
sk[ky][kx] = v;
|
||||
}
|
||||
|
||||
for (int loop_major = 0, major_idx = major_idx_base;
|
||||
loop_major < p.loop_major & major_idx < p.major_dim;
|
||||
loop_major++, major_idx++) {
|
||||
for (int loop_x = 0, tile_out_x = tile_out_x_base;
|
||||
loop_x < p.loop_x & tile_out_x < p.out_w;
|
||||
loop_x++, tile_out_x += tile_out_w) {
|
||||
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
|
||||
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
|
||||
int tile_in_x = floor_div(tile_mid_x, up_x);
|
||||
int tile_in_y = floor_div(tile_mid_y, up_y);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
|
||||
in_idx += blockDim.x) {
|
||||
int rel_in_y = in_idx / tile_in_w;
|
||||
int rel_in_x = in_idx - rel_in_y * tile_in_w;
|
||||
int in_x = rel_in_x + tile_in_x;
|
||||
int in_y = rel_in_y + tile_in_y;
|
||||
|
||||
scalar_t v = 0.0;
|
||||
|
||||
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
|
||||
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
|
||||
p.minor_dim +
|
||||
minor_idx];
|
||||
}
|
||||
|
||||
sx[rel_in_y][rel_in_x] = v;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
|
||||
out_idx += blockDim.x) {
|
||||
int rel_out_y = out_idx / tile_out_w;
|
||||
int rel_out_x = out_idx - rel_out_y * tile_out_w;
|
||||
int out_x = rel_out_x + tile_out_x;
|
||||
int out_y = rel_out_y + tile_out_y;
|
||||
|
||||
int mid_x = tile_mid_x + rel_out_x * down_x;
|
||||
int mid_y = tile_mid_y + rel_out_y * down_y;
|
||||
int in_x = floor_div(mid_x, up_x);
|
||||
int in_y = floor_div(mid_y, up_y);
|
||||
int rel_in_x = in_x - tile_in_x;
|
||||
int rel_in_y = in_y - tile_in_y;
|
||||
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
|
||||
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
|
||||
|
||||
scalar_t v = 0.0;
|
||||
|
||||
#pragma unroll
|
||||
for (int y = 0; y < kernel_h / up_y; y++)
|
||||
#pragma unroll
|
||||
for (int x = 0; x < kernel_w / up_x; x++)
|
||||
v += sx[rel_in_y + y][rel_in_x + x] *
|
||||
sk[kernel_y + y * up_y][kernel_x + x * up_x];
|
||||
|
||||
if (out_x < p.out_w & out_y < p.out_h) {
|
||||
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
|
||||
minor_idx] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
|
||||
const torch::Tensor &kernel, int up_x, int up_y,
|
||||
int down_x, int down_y, int pad_x0, int pad_x1,
|
||||
int pad_y0, int pad_y1) {
|
||||
int curDevice = -1;
|
||||
cudaGetDevice(&curDevice);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
||||
|
||||
UpFirDn2DKernelParams p;
|
||||
|
||||
auto x = input.contiguous();
|
||||
auto k = kernel.contiguous();
|
||||
|
||||
p.major_dim = x.size(0);
|
||||
p.in_h = x.size(1);
|
||||
p.in_w = x.size(2);
|
||||
p.minor_dim = x.size(3);
|
||||
p.kernel_h = k.size(0);
|
||||
p.kernel_w = k.size(1);
|
||||
p.up_x = up_x;
|
||||
p.up_y = up_y;
|
||||
p.down_x = down_x;
|
||||
p.down_y = down_y;
|
||||
p.pad_x0 = pad_x0;
|
||||
p.pad_x1 = pad_x1;
|
||||
p.pad_y0 = pad_y0;
|
||||
p.pad_y1 = pad_y1;
|
||||
|
||||
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
|
||||
p.down_y;
|
||||
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
|
||||
p.down_x;
|
||||
|
||||
auto out =
|
||||
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
|
||||
|
||||
int mode = -1;
|
||||
|
||||
int tile_out_h = -1;
|
||||
int tile_out_w = -1;
|
||||
|
||||
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
|
||||
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
||||
mode = 1;
|
||||
tile_out_h = 16;
|
||||
tile_out_w = 64;
|
||||
}
|
||||
|
||||
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
|
||||
p.kernel_h <= 3 && p.kernel_w <= 3) {
|
||||
mode = 2;
|
||||
tile_out_h = 16;
|
||||
tile_out_w = 64;
|
||||
}
|
||||
|
||||
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
|
||||
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
||||
mode = 3;
|
||||
tile_out_h = 16;
|
||||
tile_out_w = 64;
|
||||
}
|
||||
|
||||
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
|
||||
p.kernel_h <= 2 && p.kernel_w <= 2) {
|
||||
mode = 4;
|
||||
tile_out_h = 16;
|
||||
tile_out_w = 64;
|
||||
}
|
||||
|
||||
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
|
||||
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
||||
mode = 5;
|
||||
tile_out_h = 8;
|
||||
tile_out_w = 32;
|
||||
}
|
||||
|
||||
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
|
||||
p.kernel_h <= 2 && p.kernel_w <= 2) {
|
||||
mode = 6;
|
||||
tile_out_h = 8;
|
||||
tile_out_w = 32;
|
||||
}
|
||||
|
||||
dim3 block_size;
|
||||
dim3 grid_size;
|
||||
|
||||
if (tile_out_h > 0 && tile_out_w > 0) {
|
||||
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
||||
p.loop_x = 1;
|
||||
block_size = dim3(32 * 8, 1, 1);
|
||||
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
|
||||
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
|
||||
(p.major_dim - 1) / p.loop_major + 1);
|
||||
} else {
|
||||
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
||||
p.loop_x = 4;
|
||||
block_size = dim3(4, 32, 1);
|
||||
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
|
||||
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
|
||||
(p.major_dim - 1) / p.loop_major + 1);
|
||||
}
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
|
||||
switch (mode) {
|
||||
case 1:
|
||||
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
|
||||
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
|
||||
break;
|
||||
|
||||
case 2:
|
||||
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
|
||||
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
|
||||
break;
|
||||
|
||||
case 3:
|
||||
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
|
||||
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
|
||||
break;
|
||||
|
||||
case 4:
|
||||
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
|
||||
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
|
||||
break;
|
||||
|
||||
case 5:
|
||||
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
|
||||
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
|
||||
break;
|
||||
|
||||
case 6:
|
||||
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
|
||||
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
||||
x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
|
||||
break;
|
||||
|
||||
default:
|
||||
upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(), p);
|
||||
}
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
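The kernel and host code above size everything from the same per-axis formula, out = (in * up + pad0 + pad1 - kernel + down) // down. A tiny sketch that makes the relationship explicit (values are arbitrary):

def upfirdn2d_out_size(in_size, up, down, pad0, pad1, kernel):
    # Same arithmetic as p.out_h / p.out_w in upfirdn2d_op above.
    return (in_size * up + pad0 + pad1 - kernel + down) // down

assert upfirdn2d_out_size(32, up=2, down=1, pad0=2, pad1=1, kernel=4) == 64   # 2x upsample, 4-tap kernel
assert upfirdn2d_out_size(32, up=1, down=2, pad0=1, pad1=1, kernel=4) == 16   # 2x downsample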
@@ -48,6 +48,11 @@ class OutputKeys(object):
    PROBABILITIES = 'probabilities'
    DIALOG_STATES = 'dialog_states'
    VIDEO_EMBEDDING = 'video_embedding'
    PHRASE_PROTOTYPE = 'phrase_prototype'
    OBJECT_PROTOTYPE = 'object_prototype'
    SENTENCE_PROTOTYPE = 'sentence_prototype'
    EVENT_PROTOTYPE = 'event_prototype'
    TEXTVIDEO_SIM = 'textvideo_sim'
    UUID = 'uuid'
    WORD = 'word'
    KWS_LIST = 'kws_list'
@@ -90,9 +95,9 @@ OutputTypes = {
    OutputKeys.OUTPUT_IMG: 'image',  # checked
    OutputKeys.OUTPUT_IMGS: List[np.ndarray],  # checked
    OutputKeys.OUTPUT_VIDEO: 'bytes',
    OutputKeys.OUTPUT_PCM: np.ndarray,
    OutputKeys.OUTPUT_PCM: 'pcm',
    OutputKeys.OUTPUT_PCM_LIST: List[np.ndarray],
    OutputKeys.OUTPUT_WAV: np.ndarray,
    OutputKeys.OUTPUT_WAV: 'pcm',
    OutputKeys.OUTPUT_OBJ: Dict,
    OutputKeys.OUTPUT_MESH: np.ndarray,
    OutputKeys.IMG_EMBEDDING: np.ndarray,
@@ -106,6 +111,11 @@ OutputTypes = {
    OutputKeys.PROBABILITIES: np.ndarray,
    OutputKeys.DIALOG_STATES: object,
    OutputKeys.VIDEO_EMBEDDING: np.ndarray,
    OutputKeys.PHRASE_PROTOTYPE: np.ndarray,
    OutputKeys.OBJECT_PROTOTYPE: np.ndarray,
    OutputKeys.SENTENCE_PROTOTYPE: np.ndarray,
    OutputKeys.EVENT_PROTOTYPE: np.ndarray,
    OutputKeys.TEXTVIDEO_SIM: np.ndarray,
    OutputKeys.UUID: str,
    OutputKeys.WORD: str,
    OutputKeys.KWS_LIST: List[str],
@@ -329,6 +339,24 @@ OutputTypeSchema = {
            'type': 'number'
        }
    },
    OutputKeys.PHRASE_PROTOTYPE: {
        'type': 'array',
        'items': {
            'type': 'number'
        }
    },
    OutputKeys.OBJECT_PROTOTYPE: {
        'type': 'array',
        'items': {
            'type': 'number'
        }
    },
    OutputKeys.TEXTVIDEO_SIM: {
        'type': 'array',
        'items': {
            'type': 'number'
        }
    },
    OutputKeys.UUID: {
        'type': 'string'
    },
@@ -688,6 +716,8 @@ TASK_OUTPUTS = {
    # }
    Tasks.portrait_matting: [OutputKeys.OUTPUT_IMG],
    Tasks.universal_matting: [OutputKeys.OUTPUT_IMG],
    Tasks.image_deblurring: [OutputKeys.OUTPUT_IMG],
    Tasks.image_face_fusion: [OutputKeys.OUTPUT_IMG],

    # image_quality_assessment_mos result for a single image is a score in range [0, 1]
    # {0.5}
@@ -700,6 +730,7 @@ TASK_OUTPUTS = {
    Tasks.image_colorization: [OutputKeys.OUTPUT_IMG],
    Tasks.image_color_enhancement: [OutputKeys.OUTPUT_IMG],
    Tasks.image_denoising: [OutputKeys.OUTPUT_IMG],
    Tasks.image_editing: [OutputKeys.OUTPUT_IMG],
    Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG],
    Tasks.crowd_counting: [OutputKeys.SCORES, OutputKeys.OUTPUT_IMG],
    Tasks.image_inpainting: [OutputKeys.OUTPUT_IMG],
@@ -721,6 +752,7 @@ TASK_OUTPUTS = {
    Tasks.video_deinterlace: [OutputKeys.OUTPUT_VIDEO],
    Tasks.nerf_recon_acc: [OutputKeys.OUTPUT],
    Tasks.nerf_recon_vq_compression: [OutputKeys.OUTPUT],
    Tasks.surface_recon_common: [OutputKeys.OUTPUT],
    Tasks.video_colorization: [OutputKeys.OUTPUT_VIDEO],

    # image quality assessment degradation result for single image
@@ -914,6 +946,32 @@ TASK_OUTPUTS = {
    # }
    Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING],

    # phrase prototype result for single sentence
    # {
    #     "phrase_prototype": np.array with shape [K*D],
    # }
    # sentence prototype result for single sentence
    # {
    #     "sentence_prototype": np.array with shape [1*D],
    # }
    # object prototype result for single video
    # {
    #     "object_prototype": np.array with shape [N*K*D],
    # }
    # event prototype result for single video
    # {
    #     "event_prototype": np.array with shape [N*M*D],
    # }
    # text search video result for single sentence
    # {
    #     "textvideo_sim": np.array with shape [N*N],
    # }
    Tasks.text_video_retrieval: [
        OutputKeys.PHRASE_PROTOTYPE, OutputKeys.SENTENCE_PROTOTYPE,
        OutputKeys.OBJECT_PROTOTYPE, OutputKeys.EVENT_PROTOTYPE,
        OutputKeys.TEXTVIDEO_SIM
    ],

    # video stabilization task result for a single video
    # {"output_video": "path_to_rendered_video"}
    Tasks.video_stabilization: [OutputKeys.OUTPUT_VIDEO],
@@ -1512,6 +1570,11 @@ TASK_OUTPUTS = {
    #     "output_img": np.ndarray with shape [height, width, 3]
    # }
    Tasks.image_try_on: [OutputKeys.OUTPUT_IMG],
    # Tasks.human_image_generation result for a single sample
    # {
    #     "output_img": np.ndarray with shape [height, width, 3]
    # }
    Tasks.human_image_generation: [OutputKeys.OUTPUT_IMG],
}
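A short sketch of how the task outputs registered above are consumed on the caller side (the model id and input dict are the ones used by the image editing pipeline added later in this commit):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Tasks.image_editing is declared above to return OutputKeys.OUTPUT_IMG
pipe = pipeline(Tasks.image_editing, model='damo/cv_masactrl_image-editing')
result = pipe({'img': 'corgi.jpg', 'prompts': ['', 'a photo of a running corgi']})
edited = result[OutputKeys.OUTPUT_IMG]   # np.ndarray, BGR order per that pipeline's postprocess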
@@ -102,6 +102,18 @@ TASK_INPUTS = {
|
||||
InputType.IMAGE,
|
||||
Tasks.face_2d_keypoints:
|
||||
InputType.IMAGE,
|
||||
Tasks.face_liveness:
|
||||
InputType.IMAGE,
|
||||
Tasks.face_quality_assessment:
|
||||
InputType.IMAGE,
|
||||
Tasks.card_detection:
|
||||
InputType.IMAGE,
|
||||
Tasks.license_plate_detection:
|
||||
InputType.IMAGE,
|
||||
Tasks.lineless_table_recognition:
|
||||
InputType.IMAGE,
|
||||
Tasks.table_recognition:
|
||||
InputType.IMAGE,
|
||||
Tasks.face_detection:
|
||||
InputType.IMAGE,
|
||||
Tasks.facial_expression_recognition:
|
||||
@@ -118,14 +130,30 @@ TASK_INPUTS = {
|
||||
InputType.NUMBER,
|
||||
Tasks.image_classification:
|
||||
InputType.IMAGE,
|
||||
Tasks.image_quality_assessment_mos:
|
||||
InputType.IMAGE,
|
||||
Tasks.image_quality_assessment_degradation:
|
||||
InputType.IMAGE,
|
||||
Tasks.image_object_detection:
|
||||
InputType.IMAGE,
|
||||
Tasks.domain_specific_object_detection:
|
||||
InputType.IMAGE,
|
||||
Tasks.human_wholebody_keypoint:
|
||||
InputType.IMAGE,
|
||||
Tasks.image_segmentation:
|
||||
InputType.IMAGE,
|
||||
Tasks.portrait_matting:
|
||||
InputType.IMAGE,
|
||||
Tasks.universal_matting:
|
||||
InputType.IMAGE,
|
||||
Tasks.product_segmentation:
|
||||
InputType.IMAGE,
|
||||
Tasks.semantic_segmentation:
|
||||
InputType.IMAGE,
|
||||
Tasks.face_human_hand_detection:
|
||||
InputType.IMAGE,
|
||||
Tasks.hand_static:
|
||||
InputType.IMAGE,
|
||||
Tasks.image_fewshot_detection:
|
||||
InputType.IMAGE,
|
||||
Tasks.open_vocabulary_detection: {
|
||||
@@ -148,6 +176,8 @@ TASK_INPUTS = {
|
||||
InputType.IMAGE,
|
||||
Tasks.image_denoising:
|
||||
InputType.IMAGE,
|
||||
Tasks.image_body_reshaping:
|
||||
InputType.IMAGE,
|
||||
Tasks.image_portrait_enhancement:
|
||||
InputType.IMAGE,
|
||||
Tasks.crowd_counting:
|
||||
@@ -169,6 +199,12 @@ TASK_INPUTS = {
|
||||
'image': InputType.IMAGE,
|
||||
'prompt': InputType.TEXT,
|
||||
},
|
||||
Tasks.image_face_fusion: {
|
||||
'template': InputType.IMAGE,
|
||||
'user': InputType.IMAGE,
|
||||
},
|
||||
Tasks.image_deblurring:
|
||||
InputType.IMAGE,
|
||||
Tasks.video_colorization:
|
||||
InputType.VIDEO,
|
||||
|
||||
@@ -227,6 +263,10 @@ TASK_INPUTS = {
|
||||
InputKeys.IMAGE: InputType.IMAGE,
|
||||
InputKeys.IMAGE: InputType.IMAGE
|
||||
},
|
||||
Tasks.human_image_generation: {
|
||||
InputKeys.IMAGE: InputType.IMAGE,
|
||||
'target_pose_path': InputType.TEXT
|
||||
},
|
||||
|
||||
# ============ nlp tasks ===================
|
||||
Tasks.chat: [
|
||||
@@ -254,11 +294,15 @@ TASK_INPUTS = {
|
||||
Tasks.nli: (InputType.TEXT, InputType.TEXT),
|
||||
Tasks.sentiment_classification:
|
||||
InputType.TEXT,
|
||||
Tasks.zero_shot_classification: InputType.TEXT,
|
||||
Tasks.zero_shot_classification:
|
||||
InputType.TEXT,
|
||||
Tasks.relation_extraction:
|
||||
InputType.TEXT,
|
||||
Tasks.translation:
|
||||
InputType.TEXT,
|
||||
Tasks.text_summarization: [InputType.TEXT, {
|
||||
'text': InputType.TEXT,
|
||||
}],
|
||||
Tasks.competency_aware_translation:
|
||||
InputType.TEXT,
|
||||
Tasks.word_segmentation: [InputType.TEXT, {
|
||||
@@ -348,12 +392,17 @@ TASK_INPUTS = {
|
||||
InputType.AUDIO,
|
||||
Tasks.speaker_diarization_dialogue_detection:
|
||||
InputType.TEXT,
|
||||
Tasks.language_score_prediction:
|
||||
InputType.TEXT,
|
||||
Tasks.punctuation:
|
||||
InputType.TEXT,
|
||||
Tasks.speech_language_recognition:
|
||||
InputType.AUDIO,
|
||||
Tasks.speaker_diarization_semantic_speaker_turn_detection:
|
||||
InputType.TEXT,
|
||||
Tasks.inverse_text_processing:
|
||||
InputType.TEXT,
|
||||
Tasks.speaker_verification: [InputType.AUDIO, InputType.AUDIO],
|
||||
|
||||
# ============ multi-modal tasks ===================
|
||||
Tasks.image_captioning: [InputType.IMAGE, {
|
||||
@@ -384,6 +433,10 @@ TASK_INPUTS = {
|
||||
'img': InputType.IMAGE,
|
||||
'text': InputType.TEXT
|
||||
},
|
||||
Tasks.text_video_retrieval: {
|
||||
'video': InputType.VIDEO,
|
||||
'text': InputType.TEXT
|
||||
},
|
||||
Tasks.visual_question_answering: {
|
||||
'image': InputType.IMAGE,
|
||||
'text': InputType.TEXT
|
||||
@@ -415,4 +468,8 @@ TASK_INPUTS = {
|
||||
Tasks.text_to_360panorama_image: {
|
||||
'prompt': InputType.TEXT,
|
||||
},
|
||||
Tasks.image_editing: {
|
||||
'img': InputType.IMAGE,
|
||||
'prompts': InputType.LIST
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
    from .linear_aec_pipeline import LinearAECPipeline
    from .text_to_speech_pipeline import TextToSpeechSambertHifiganPipeline
    from .inverse_text_processing_pipeline import InverseTextProcessingPipeline
    from .separation_pipeline import SeparationPipeline
    from .speaker_verification_pipeline import SpeakerVerificationPipeline
else:
    _import_structure = {
@@ -23,6 +24,7 @@ else:
        'text_to_speech_pipeline': ['TextToSpeechSambertHifiganPipeline'],
        'itn_inference_pipeline': ['InverseTextProcessingPipeline'],
        'inverse_text_processing_pipeline': ['InverseTextProcessingPipeline'],
        'separation_pipeline': ['SeparationPipeline'],
        'speaker_verification_pipeline': ['SpeakerVerificationPipeline']
    }
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
from modelscope.fileio import File
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import InputModel, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = ['LanguageRecognitionPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.speech_language_recognition,
|
||||
module_name=Pipelines.speech_language_recognition_eres2net)
|
||||
class LanguageRecognitionPipeline(Pipeline):
|
||||
"""Language Recognition Inference Pipeline
|
||||
use `model` to create a Language Recognition pipeline.
|
||||
|
||||
Args:
|
||||
model (LanguageRecognitionPipeline): A model instance, or a model local dir, or a model id in the model hub.
|
||||
kwargs (dict, `optional`):
|
||||
Extra kwargs passed into the pipeline's constructor.
|
||||
Example:
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> from modelscope.utils.constant import Tasks
|
||||
>>> p = pipeline(
|
||||
>>> task=Tasks.speech_language_recognition, model='damo/speech_eres2net_base_lre_en-cn_16k')
|
||||
>>> print(p(audio_in))
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model: InputModel, **kwargs):
|
||||
"""use `model` to create a Language Recognition pipeline for prediction
|
||||
Args:
|
||||
model (str): a valid offical model id
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
self.model_config = self.model.model_config
|
||||
self.languages = self.model_config['languages']
|
||||
|
||||
def __call__(self,
|
||||
in_audios: Union[str, list, np.ndarray],
|
||||
out_file: str = None):
|
||||
wavs = self.preprocess(in_audios)
|
||||
results = self.forward(wavs)
|
||||
outputs = self.postprocess(results, in_audios, out_file)
|
||||
return outputs
|
||||
|
||||
def forward(self, inputs: list):
|
||||
results = []
|
||||
for x in inputs:
|
||||
results.append(self.model(x).item())
|
||||
return results
|
||||
|
||||
def postprocess(self,
|
||||
inputs: list,
|
||||
in_audios: Union[str, list, np.ndarray],
|
||||
out_file=None):
|
||||
if isinstance(in_audios, str):
|
||||
output = {OutputKeys.TEXT: self.languages[inputs[0]]}
|
||||
else:
|
||||
output = {OutputKeys.TEXT: [self.languages[i] for i in inputs]}
|
||||
if out_file is not None:
|
||||
out_lines = []
|
||||
for i, audio in enumerate(in_audios):
|
||||
if isinstance(audio, str):
|
||||
audio_id = os.path.basename(audio).rsplit('.', 1)[0]
|
||||
else:
|
||||
audio_id = i
|
||||
out_lines.append('%s %s\n' %
|
||||
(audio_id, self.languages[inputs[i]]))
|
||||
with open(out_file, 'w') as f:
|
||||
for i in out_lines:
|
||||
f.write(i)
|
||||
return output
|
||||
|
||||
def preprocess(self, inputs: Union[str, list, np.ndarray]):
|
||||
output = []
|
||||
if isinstance(inputs, str):
|
||||
file_bytes = File.read(inputs)
|
||||
data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
|
||||
if len(data.shape) == 2:
|
||||
data = data[:, 0]
|
||||
data = torch.from_numpy(data).unsqueeze(0)
|
||||
if fs != self.model_config['sample_rate']:
|
||||
logger.warning(
|
||||
'The sample rate of audio is not %d, resample it.'
|
||||
% self.model_config['sample_rate'])
|
||||
data, fs = torchaudio.sox_effects.apply_effects_tensor(
|
||||
data,
|
||||
fs,
|
||||
effects=[['rate',
|
||||
str(self.model_config['sample_rate'])]])
|
||||
data = data.squeeze(0)
|
||||
output.append(data)
|
||||
else:
|
||||
for i in range(len(inputs)):
|
||||
if isinstance(inputs[i], str):
|
||||
file_bytes = File.read(inputs[i])
|
||||
data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
|
||||
if len(data.shape) == 2:
|
||||
data = data[:, 0]
|
||||
data = torch.from_numpy(data).unsqueeze(0)
|
||||
if fs != self.model_config['sample_rate']:
|
||||
logger.warning(
|
||||
'The sample rate of audio is not %d, resample it.'
|
||||
% self.model_config['sample_rate'])
|
||||
data, fs = torchaudio.sox_effects.apply_effects_tensor(
|
||||
data,
|
||||
fs,
|
||||
effects=[[
|
||||
'rate',
|
||||
str(self.model_config['sample_rate'])
|
||||
]])
|
||||
data = data.squeeze(0)
|
||||
elif isinstance(inputs[i], np.ndarray):
|
||||
assert len(
|
||||
inputs[i].shape
|
||||
) == 1, 'modelscope error: Input array should be [N, T]'
|
||||
data = inputs[i]
|
||||
if data.dtype in ['int16', 'int32', 'int64']:
|
||||
data = (data / (1 << 15)).astype('float32')
|
||||
else:
|
||||
data = data.astype('float32')
|
||||
data = torch.from_numpy(data)
|
||||
else:
|
||||
raise ValueError(
|
||||
'modelscope error: The input type is restricted to audio address and nump array.'
|
||||
)
|
||||
output.append(data)
|
||||
return output
|
||||
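The `out_file` branch in `postprocess` above writes one `<audio_id> <language>` line per input. A hedged batch usage sketch (the model id comes from the class docstring; the returned labels depend on the model's configured language set):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

lre = pipeline(
    task=Tasks.speech_language_recognition,
    model='damo/speech_eres2net_base_lre_en-cn_16k')
result = lre(['utt1.wav', 'utt2.wav'], out_file='lang_results.txt')
print(result['text'])   # list of predicted languages, also dumped to lang_results.txt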
@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import ast
import io
from typing import Any, Dict, List, Union

@@ -180,7 +181,14 @@ class SegmentationClusteringPipeline(Pipeline):
                model=self.config['vad_model'])
            vad_time = self.vad_pipeline(audio, audio_fs=self.fs)
            vad_segments = []
            for t in vad_time['text']:
            if isinstance(vad_time['text'], str):
                vad_time_list = ast.literal_eval(vad_time['text'])
            elif isinstance(vad_time['text'], list):
                vad_time_list = vad_time['text']
            else:
                raise ValueError('Incorrect vad result. Get %s' %
                                 (type(vad_time['text'])))
            for t in vad_time_list:
                st = int(t[0]) / 1000
                ed = int(t[1]) / 1000
                vad_segments.append(
@@ -8,7 +8,7 @@ import soundfile as sf
import torch

from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.metainfo import Models, Pipelines
from modelscope.models.base import Input
from modelscope.outputs import OutputKeys
from modelscope.pipelines import Pipeline
@@ -20,7 +20,11 @@ logger = get_logger()


@PIPELINES.register_module(
    Tasks.speech_separation, module_name=Pipelines.speech_separation)
    Tasks.speech_separation,
    module_name=Models.speech_mossformer_separation_temporal_8k)
@PIPELINES.register_module(
    Tasks.speech_separation,
    module_name=Models.speech_mossformer2_separation_temporal_8k)
class SeparationPipeline(Pipeline):

    def __init__(self, model, **kwargs):
243  modelscope/pipelines/audio/speech_separation_pipeline.py  Normal file
@@ -0,0 +1,243 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from typing import Any, Dict, List, Sequence, Tuple, Union
|
||||
|
||||
import json
|
||||
import yaml
|
||||
from funasr.utils import asr_utils
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
|
||||
update_local_model)
|
||||
from modelscope.utils.constant import Frameworks, ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = ['SeparationPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.speech_separation, module_name=Pipelines.funasr_speech_separation)
|
||||
class SeparationPipeline(Pipeline):
|
||||
"""Speech Separation Inference Pipeline
|
||||
use `model` to create a speech separation pipeline for prediction.
|
||||
|
||||
Args:
|
||||
model: A model instance, or a model local dir, or a model id in the model hub.
|
||||
kwargs (dict, `optional`):
|
||||
Extra kwargs passed into the preprocessor's constructor.
|
||||
|
||||
Example:
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> pipeline = pipeline(
|
||||
>>> task=Tasks.speech_separation, model='damo/speech_separation_mossformer_8k_pytorch')
|
||||
>>> audio_in = 'mix_speech.wav'
|
||||
>>> print(pipeline(audio_in))
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: Union[Model, str] = None,
|
||||
ngpu: int = 1,
|
||||
**kwargs):
|
||||
"""use `model` to create an speech separation pipeline for prediction
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
config_path = os.path.join(model, ModelFile.CONFIGURATION)
|
||||
self.cmd = self.get_cmd(config_path, kwargs, model)
|
||||
|
||||
from funasr.bin import ss_inference_launch
|
||||
self.funasr_infer_modelscope = ss_inference_launch.inference_launch(
|
||||
mode=self.cmd['mode'],
|
||||
batch_size=self.cmd['batch_size'],
|
||||
ngpu=ngpu,
|
||||
log_level=self.cmd['log_level'],
|
||||
ss_infer_config=self.cmd['ss_infer_config'],
|
||||
ss_model_file=self.cmd['ss_model_file'],
|
||||
output_dir=self.cmd['output_dir'],
|
||||
dtype=self.cmd['dtype'],
|
||||
seed=self.cmd['seed'],
|
||||
num_workers=self.cmd['num_workers'],
|
||||
num_spks=self.cmd['num_spks'],
|
||||
param_dict=self.cmd['param_dict'],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __call__(self,
|
||||
audio_in: Union[str, bytes],
|
||||
audio_fs: int = None,
|
||||
recog_type: str = None,
|
||||
audio_format: str = None,
|
||||
output_dir: str = None,
|
||||
param_dict: dict = None,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
Decoding the input audios
|
||||
Args:
|
||||
audio_in('str' or 'bytes'):
|
||||
- A string containing a local path to a wav file
|
||||
- A string containing a local path to a scp
|
||||
- A string containing a wav url
|
||||
- A bytes input
|
||||
audio_fs('int'):
|
||||
frequency of sample
|
||||
recog_type('str'):
|
||||
recog type for wav file or datasets file ('wav', 'test', 'dev', 'train')
|
||||
audio_format('str'):
|
||||
audio format ('pcm', 'scp', 'kaldi_ark', 'tfrecord')
|
||||
output_dir('str'):
|
||||
output dir
|
||||
param_dict('dict'):
|
||||
extra kwargs
|
||||
Return:
|
||||
A dictionary of result or a list of dictionary of result.
|
||||
|
||||
The dictionary contain the following keys:
|
||||
- **text** ('str') --The vad result.
|
||||
"""
|
||||
self.audio_in = None
|
||||
self.raw_inputs = None
|
||||
self.recog_type = recog_type
|
||||
self.audio_format = audio_format
|
||||
self.audio_fs = None
|
||||
checking_audio_fs = None
|
||||
if output_dir is not None:
|
||||
self.cmd['output_dir'] = output_dir
|
||||
if param_dict is not None:
|
||||
self.cmd['param_dict'] = param_dict
|
||||
if isinstance(audio_in, str):
|
||||
# for funasr code, generate wav.scp from url or local path
|
||||
self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in)
|
||||
elif isinstance(audio_in, bytes):
|
||||
self.audio_in = audio_in
|
||||
self.raw_inputs = None
|
||||
else:
|
||||
import numpy
|
||||
import torch
|
||||
if isinstance(audio_in, torch.Tensor):
|
||||
self.audio_in = None
|
||||
self.raw_inputs = audio_in
|
||||
elif isinstance(audio_in, numpy.ndarray):
|
||||
self.audio_in = None
|
||||
self.raw_inputs = audio_in
|
||||
|
||||
# set the sample_rate of audio_in if checking_audio_fs is valid
|
||||
if checking_audio_fs is not None:
|
||||
self.audio_fs = checking_audio_fs
|
||||
|
||||
if recog_type is None or audio_format is None:
|
||||
self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
|
||||
audio_in=self.audio_in,
|
||||
recog_type=recog_type,
|
||||
audio_format=audio_format)
|
||||
|
||||
if hasattr(asr_utils,
|
||||
'sample_rate_checking') and self.audio_in is not None:
|
||||
checking_audio_fs = asr_utils.sample_rate_checking(
|
||||
self.audio_in, self.audio_format)
|
||||
if checking_audio_fs is not None:
|
||||
self.audio_fs = checking_audio_fs
|
||||
if audio_fs is not None:
|
||||
self.cmd['fs']['audio_fs'] = audio_fs
|
||||
else:
|
||||
self.cmd['fs']['audio_fs'] = self.audio_fs
|
||||
|
||||
output = self.forward(self.audio_in, **kwargs)
|
||||
return output
|
||||
|
||||
def get_cmd(self, config_path, extra_args, model_path) -> Dict[str, Any]:
|
||||
model_cfg = json.loads(open(config_path).read())
|
||||
model_dir = os.path.dirname(config_path)
|
||||
# generate inference command
|
||||
ss_model_path = os.path.join(
|
||||
model_dir, model_cfg['model']['model_config']['ss_model_name'])
|
||||
ss_model_config = os.path.join(
|
||||
model_dir, model_cfg['model']['model_config']['ss_model_config'])
|
||||
mode = model_cfg['model']['model_config']['mode']
|
||||
frontend_conf = None
|
||||
if os.path.exists(ss_model_config):
|
||||
config_file = open(ss_model_config, encoding='utf-8')
|
||||
root = yaml.full_load(config_file)
|
||||
config_file.close()
|
||||
if 'frontend_conf' in root:
|
||||
frontend_conf = root['frontend_conf']
|
||||
update_local_model(model_cfg['model']['model_config'], model_path,
|
||||
extra_args)
|
||||
|
||||
cmd = {
|
||||
'mode': mode,
|
||||
'batch_size': 1,
|
||||
'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
|
||||
'log_level': 'ERROR',
|
||||
'ss_infer_config': ss_model_config,
|
||||
'ss_model_file': ss_model_path,
|
||||
'output_dir': None,
|
||||
'dtype': 'float32',
|
||||
'seed': 0,
|
||||
'num_workers': 0,
|
||||
'num_spks': 2,
|
||||
'param_dict': None,
|
||||
'fs': {
|
||||
'model_fs': None,
|
||||
'audio_fs': None
|
||||
}
|
||||
}
|
||||
if frontend_conf is not None and 'fs' in frontend_conf:
|
||||
cmd['fs']['model_fs'] = frontend_conf['fs']
|
||||
|
||||
user_args_dict = [
|
||||
'output_dir', 'batch_size', 'mode', 'ngpu', 'param_dict',
|
||||
'num_workers', 'fs'
|
||||
]
|
||||
|
||||
for user_args in user_args_dict:
|
||||
if user_args in extra_args:
|
||||
if extra_args.get(user_args) is not None:
|
||||
cmd[user_args] = extra_args[user_args]
|
||||
del extra_args[user_args]
|
||||
|
||||
return cmd
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any],
|
||||
**post_params) -> Dict[str, Any]:
|
||||
return inputs
|
||||
|
||||
def forward(self, audio_in: Dict[str, Any], **kwargs) -> Dict[str, Any]:
|
||||
"""Decoding
|
||||
"""
|
||||
logger.info('Speech Separation Processing ...')
|
||||
# generate inputs
|
||||
data_cmd: Sequence[Tuple[str, str, str]]
|
||||
if isinstance(self.audio_in, bytes):
|
||||
data_cmd = [self.audio_in, 'speech', 'bytes']
|
||||
elif isinstance(self.audio_in, str):
|
||||
data_cmd = [self.audio_in, 'speech', 'sound']
|
||||
elif self.raw_inputs is not None:
|
||||
data_cmd = None
|
||||
self.cmd['name_and_type'] = data_cmd
|
||||
self.cmd['raw_inputs'] = self.raw_inputs
|
||||
self.cmd['audio_in'] = self.audio_in
|
||||
|
||||
ss_result = self.run_inference(self.cmd, **kwargs)
|
||||
|
||||
return ss_result
|
||||
|
||||
def run_inference(self, cmd, **kwargs):
|
||||
ss_result = []
|
||||
if self.framework == Frameworks.torch:
|
||||
ss_result = self.funasr_infer_modelscope(
|
||||
data_path_and_name_and_type=cmd['name_and_type'],
|
||||
raw_inputs=cmd['raw_inputs'],
|
||||
output_dir_v2=cmd['output_dir'],
|
||||
fs=cmd['fs'],
|
||||
param_dict=cmd['param_dict'],
|
||||
**kwargs)
|
||||
else:
|
||||
raise ValueError('model type is mismatching')
|
||||
|
||||
return ss_result
|
||||
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
    from .image_colorization_pipeline import ImageColorizationPipeline
    from .image_denoise_pipeline import ImageDenoisePipeline
    from .image_deblur_pipeline import ImageDeblurPipeline
    from .image_editing_pipeline import ImageEditingPipeline
    from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline
    from .image_matting_pipeline import ImageMattingPipeline
    from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline
@@ -104,6 +105,7 @@ if TYPE_CHECKING:
    from .image_human_parsing_pipeline import ImageHumanParsingPipeline
    from .nerf_recon_acc_pipeline import NeRFReconAccPipeline
    from .nerf_recon_4k_pipeline import NeRFRecon4KPipeline
    from .surface_recon_common_pipeline import SurfaceReconCommonPipeline
    from .controllable_image_generation_pipeline import ControllableImageGenerationPipeline
    from .image_bts_depth_estimation_pipeline import ImageBTSDepthEstimationPipeline
    from .pedestrian_attribute_recognition_pipeline import PedestrainAttributeRecognitionPipeline
@@ -136,6 +138,7 @@ else:
        'image_cartoon_pipeline': ['ImageCartoonPipeline'],
        'image_denoise_pipeline': ['ImageDenoisePipeline'],
        'image_deblur_pipeline': ['ImageDeblurPipeline'],
        'image_editing_pipeline': ['ImageEditingPipeline'],
        'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'],
        'image_colorization_pipeline': ['ImageColorizationPipeline'],
        'image_instance_segmentation_pipeline':
@@ -256,6 +259,7 @@ else:
        'image_human_parsing_pipeline': ['ImageHumanParsingPipeline'],
        'nerf_recon_acc_pipeline': ['NeRFReconAccPipeline'],
        'nerf_recon_4k_pipeline': ['NeRFRecon4KPipeline'],
        'surface_recon_common_pipeline': ['SurfaceReconCommonPipeline'],
        'controllable_image_generation_pipeline': [
            'ControllableImageGenerationPipeline'
        ],
@@ -163,7 +163,7 @@ class Body3DKeypointsPipeline(Pipeline):
                box = kps_2d['boxes'][
                    0]  # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox
                pose = kps_2d['keypoints'][0]  # keypoints: [15, 2]
                score = kps_2d['scores'][0]  # keypoints: [15, 2]
                score = np.array(kps_2d['scores'][0]).max()
                all_2d_poses.append(pose)
                all_boxes_with_socre.append(
                    list(np.array(box).reshape(
@@ -31,7 +31,7 @@ class FaceEmotionPipeline(Pipeline):
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input['img_path'])
        img = LoadImage.convert_to_ndarray(input)
        return img

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
@@ -32,14 +32,13 @@ class NanoDettForFaceHumanHandDetectionPipeline(Pipeline):
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input['input_path'])
        img = LoadImage.convert_to_ndarray(input)
        return img

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

        cls_list, bbox_list, score_list = det_infer.inference(
            self.model, self.device, input)
        logger.info(cls_list, bbox_list, score_list)
        return {
            OutputKeys.LABELS: cls_list,
            OutputKeys.BOXES: bbox_list,
@@ -30,7 +30,7 @@ class HandStaticPipeline(Pipeline):
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input['img_path'])
        img = LoadImage.convert_to_ndarray(input)
        return img

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
60  modelscope/pipelines/cv/human_image_generation_pipeline.py  Normal file
@@ -0,0 +1,60 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

from typing import Any, Dict

import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.human_image_generation import \
    human_image_generation_infer
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.human_image_generation, module_name=Pipelines.human_image_generation)
class FreqHPTForHumanImageGenerationPipeline(Pipeline):
    """ Human Image Generation Pipeline.
    Examples:
    >>> human_image_generation = pipeline(Tasks.human_image_generation, model='damo/cv_FreqHPT_human-image-generation')
    >>> input_images = {'source_img_path': '/your_path/source_img.jpg',
    >>>                 'target_pose_path': '/your_path/target_pose.txt'}
    >>> result = human_image_generation(input_images)
    >>> result[OutputKeys.OUTPUT_IMG]
    """

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create human image generation pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """

        super().__init__(model=model, **kwargs)
        self.model_path = model
        logger.info('load model done')
        if torch.cuda.is_available():
            self.device = 'cuda'
            logger.info('Use GPU')
        else:
            self.device = 'cpu'
            logger.info('Use CPU')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        human_image_generation = human_image_generation_infer.infer(
            self.model, input['source_img_path'], input['target_pose_path'],
            self.device)
        return {OutputKeys.OUTPUT_IMG: human_image_generation}
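Putting the pieces of this new pipeline together, a small end-to-end sketch that saves the generated image (paths are placeholders, as in the docstring above):

import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

human_image_generation = pipeline(
    Tasks.human_image_generation, model='damo/cv_FreqHPT_human-image-generation')
result = human_image_generation({
    'source_img_path': '/your_path/source_img.jpg',
    'target_pose_path': '/your_path/target_pose.txt',
})
cv2.imwrite('generated.png', result[OutputKeys.OUTPUT_IMG])   # output is an [H, W, 3] ndarray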
@@ -70,7 +70,7 @@ class ImageCartoonPipeline(Pipeline):

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)
        img = img.astype(np.float)
        img = img.astype(float)
        result = {'img': img}
        return result

365  modelscope/pipelines/cv/image_editing_pipeline.py  Normal file
@@ -0,0 +1,365 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers import DDIMScheduler, StableDiffusionPipeline
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.image_editing import (
|
||||
MutualSelfAttentionControl, regiter_attention_editor_diffusers)
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.pipelines.multi_modal.diffusers_wrapped.diffusers_pipeline import \
|
||||
DiffusersPipeline
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = ['ImageEditingPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_editing, module_name=Pipelines.image_editing)
|
||||
class ImageEditingPipeline(DiffusersPipeline):
|
||||
|
||||
def __init__(self, model=str, preprocessor=None, **kwargs):
|
||||
""" MasaCtrl Image Editing Pipeline.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> import cv2
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> from modelscope.utils.constant import Tasks
|
||||
|
||||
>>> prompts = [
|
||||
>>> "", # source prompt
|
||||
>>> "a photo of a running corgi" # target prompt
|
||||
>>> ]
|
||||
>>> output_image_path = './result.png'
|
||||
>>> img = 'https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/public/ModelScope/test/images/corgi.jpg'
|
||||
>>> input = {'img': img, 'prompts': prompts}
|
||||
>>>
|
||||
>>> pipe = pipeline(
|
||||
>>> Tasks.image_editing,
|
||||
>>> model='damo/cv_masactrl_image-editing')
|
||||
>>>
|
||||
>>> output = pipe(input)['output_img']
|
||||
>>> cv2.imwrite(output_image_path, output)
|
||||
>>> print('pipeline: the output image path is {}'.format(output_image_path))
|
||||
"""
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
torch_dtype = kwargs.get('torch_dtype', torch.float32)
|
||||
self._device = getattr(
|
||||
kwargs, 'device',
|
||||
torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
|
||||
logger.info('load image editing pipeline done')
|
||||
scheduler = DDIMScheduler.from_pretrained(
|
||||
os.path.join(model, 'stable-diffusion-v1-4'),
|
||||
subfolder='scheduler')
|
||||
self.pipeline = _MasaCtrlPipeline.from_pretrained(
|
||||
os.path.join(model, 'stable-diffusion-v1-4'),
|
||||
scheduler=scheduler,
|
||||
torch_dtype=torch_dtype,
|
||||
use_safetensors=True).to(self._device)
|
||||
|
||||
def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
img = LoadImage.convert_to_img(input.get('img'))
|
||||
test_transforms = transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5])]) # [-1, 1]
|
||||
img = test_transforms(img).unsqueeze(0)
|
||||
img = F.interpolate(img, (512, 512))
|
||||
input['img'] = img.to(self._device)
|
||||
return input
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
if not isinstance(input, dict):
|
||||
raise ValueError(
|
||||
f'Expected the input to be a dictionary, but got {type(input)}'
|
||||
)
|
||||
prompts = input.get('prompts')
|
||||
start_code, latents_list = self.pipeline.invert(
|
||||
input.get('img'),
|
||||
prompts[0],
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=50,
|
||||
return_intermediates=True)
|
||||
start_code = start_code.expand(len(prompts), -1, -1, -1)
|
||||
STEP, LAYER = 4, 10
|
||||
editor = MutualSelfAttentionControl(STEP, LAYER)
|
||||
regiter_attention_editor_diffusers(self.pipeline, editor)
|
||||
|
||||
# inference the synthesized image
|
||||
output = self.pipeline(
|
||||
prompts,
|
||||
latents=start_code,
|
||||
guidance_scale=input.get('guidance_scale', 7.5),
|
||||
)[-1:]
|
||||
|
||||
return {'output_tensor': output}
|
||||
|
||||
def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute(
|
||||
1, 2, 0).numpy().astype('uint8')
|
||||
return {OutputKeys.OUTPUT_IMG: output_img[:, :, ::-1]}
|
||||
|
||||
|
||||
class _MasaCtrlPipeline(StableDiffusionPipeline):
|
||||
|
||||
def next_step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
x: torch.FloatTensor,
|
||||
eta=0,
|
||||
verbose=False,
|
||||
):
|
||||
"""
|
||||
Inverse sampling for DDIM Inversion
|
||||
x_t -> x_(t+1)
|
||||
"""
|
||||
if verbose:
|
||||
print('timestep: ', timestep)
|
||||
next_step = timestep
|
||||
timestep = min(
|
||||
timestep - self.scheduler.config.num_train_timesteps
|
||||
// self.scheduler.num_inference_steps, 999)
|
||||
alpha_prod_t = self.scheduler.alphas_cumprod[
|
||||
timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
|
||||
alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
|
||||
pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
|
||||
x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
|
||||
return x_next, pred_x0
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
x: torch.FloatTensor,
|
||||
eta: float = 0.0,
|
||||
verbose=False,
|
||||
):
|
||||
"""
|
||||
predict the sample the next step in the denoise process.
|
||||
x_t -> x_(t-1)
|
||||
"""
|
||||
prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
|
||||
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.scheduler.alphas_cumprod[
|
||||
prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
|
||||
pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
|
||||
x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
|
||||
return x_prev, pred_x0

@torch.no_grad()
def image2latent(self, image):
DEVICE = self._execution_device
if type(image) is Image:
image = np.array(image)
image = torch.from_numpy(image).float() / 127.5 - 1
image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
# input image density range [-1, 1]
latents = self.vae.encode(image)['latent_dist'].mean
latents = latents * 0.18215
return latents

@torch.no_grad()
def latent2image(self, latents, return_type='pt'):
latents = 1 / 0.18215 * latents.detach()
image = self.vae.decode(latents)['sample']
if return_type == 'np':
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
image = (image * 255).astype(np.uint8)
elif return_type == 'pt':
image = (image / 2 + 0.5).clamp(0, 1)

return image

@torch.no_grad()
def __call__(self,
prompt,
batch_size=1,
height=512,
width=512,
num_inference_steps=50,
guidance_scale=7.5,
eta=0.0,
latents=None,
unconditioning=None,
neg_prompt=None,
ref_intermediate_latents=None,
return_intermediates=False,
**kwds):
DEVICE = self._execution_device
if isinstance(prompt, list):
batch_size = len(prompt)
elif isinstance(prompt, str):
if batch_size > 1:
prompt = [prompt] * batch_size

# text embeddings
text_input = self.tokenizer(
prompt, padding='max_length', max_length=77, return_tensors='pt')

text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
print('input text embeddings :', text_embeddings.shape)

# define initial latents
latents_shape = (batch_size, self.unet.in_channels, height // 8,
width // 8)
if latents is None:
latents = torch.randn(latents_shape, device=DEVICE)
else:
assert latents.shape == latents_shape, f'The shape of input latent tensor {latents.shape} should equal ' \
f'the predefined one.'

# unconditional embedding for classifier free guidance
if guidance_scale > 1.:
if neg_prompt:
uc_text = neg_prompt
else:
uc_text = ''
unconditional_input = self.tokenizer(
[uc_text] * batch_size,
padding='max_length',
max_length=77,
return_tensors='pt')
unconditional_embeddings = self.text_encoder(
unconditional_input.input_ids.to(DEVICE))[0]
text_embeddings = torch.cat(
[unconditional_embeddings, text_embeddings], dim=0)

print('latents shape: ', latents.shape)
# iterative sampling
self.scheduler.set_timesteps(num_inference_steps)
latents_list = [latents]
pred_x0_list = [latents]
for i, t in enumerate(
tqdm(self.scheduler.timesteps, desc='DDIM Sampler')):
if ref_intermediate_latents is not None:
# note that the batch_size >= 2
latents_ref = ref_intermediate_latents[-1 - i]
_, latents_cur = latents.chunk(2)
latents = torch.cat([latents_ref, latents_cur])

if guidance_scale > 1.:
model_inputs = torch.cat([latents] * 2)
else:
model_inputs = latents
if unconditioning is not None and isinstance(unconditioning, list):
_, text_embeddings = text_embeddings.chunk(2)
text_embeddings = torch.cat([
unconditioning[i].expand(*text_embeddings.shape),
text_embeddings
])
# predict the noise
noise_pred = self.unet(
model_inputs, t, encoder_hidden_states=text_embeddings).sample
if guidance_scale > 1.:
noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncon + guidance_scale * (
noise_pred_con - noise_pred_uncon)
# compute the previous noise sample x_t -> x_t-1
latents, pred_x0 = self.step(noise_pred, t, latents)
latents_list.append(latents)
pred_x0_list.append(pred_x0)

image = self.latent2image(latents, return_type='pt')
if return_intermediates:
pred_x0_list = [
self.latent2image(img, return_type='pt')
for img in pred_x0_list
]
latents_list = [
self.latent2image(img, return_type='pt')
for img in latents_list
]
return image, pred_x0_list, latents_list
return image

@torch.no_grad()
def invert(self,
image: torch.Tensor,
prompt,
num_inference_steps=50,
guidance_scale=7.5,
eta=0.0,
return_intermediates=False,
**kwds):
"""
Invert a real image into a noise map with deterministic DDIM inversion.
"""
DEVICE = self._execution_device
batch_size = image.shape[0]
if isinstance(prompt, list):
if batch_size == 1:
image = image.expand(len(prompt), -1, -1, -1)
elif isinstance(prompt, str):
if batch_size > 1:
prompt = [prompt] * batch_size

# text embeddings
text_input = self.tokenizer(
prompt, padding='max_length', max_length=77, return_tensors='pt')
text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
print('input text embeddings :', text_embeddings.shape)
# define initial latents
latents = self.image2latent(image)
start_latents = latents

# unconditional embedding for classifier free guidance
if guidance_scale > 1.:
unconditional_input = self.tokenizer(
[''] * batch_size,
padding='max_length',
max_length=77,
return_tensors='pt')
unconditional_embeddings = self.text_encoder(
unconditional_input.input_ids.to(DEVICE))[0]
text_embeddings = torch.cat(
[unconditional_embeddings, text_embeddings], dim=0)

print('latents shape: ', latents.shape)
self.scheduler.set_timesteps(num_inference_steps)
print('Valid timesteps: ', reversed(self.scheduler.timesteps))
latents_list = [latents]
pred_x0_list = [latents]
for i, t in enumerate(
tqdm(
reversed(self.scheduler.timesteps),
desc='DDIM Inversion')):
if guidance_scale > 1.:
model_inputs = torch.cat([latents] * 2)
else:
model_inputs = latents

# predict the noise
noise_pred = self.unet(
model_inputs, t, encoder_hidden_states=text_embeddings).sample
if guidance_scale > 1.:
noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
noise_pred = noise_pred_uncon + guidance_scale * (
noise_pred_con - noise_pred_uncon)
# compute the next noise sample x_t -> x_(t+1)
latents, pred_x0 = self.next_step(noise_pred, t, latents)
latents_list.append(latents)
pred_x0_list.append(pred_x0)

if return_intermediates:
return latents, latents_list
return latents, start_latents

@@ -82,7 +82,7 @@ class ImagePanopticSegmentationPipeline(Pipeline):
ids = ids[legal_indices]
labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64)
segms = (pan_results[None] == ids[:, None, None])
masks = [it.astype(np.int) for it in segms]
masks = [it.astype(np.int32) for it in segms]
labels_txt = np.array(self.model.CLASSES)[labels].tolist()
outputs = {
OutputKeys.MASKS: masks,
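Context for the change above (not part of the diff): the `np.int` alias was deprecated in NumPy 1.20 and removed in 1.24, so the explicit `np.int32` dtype keeps the mask conversion working on current NumPy releases. A tiny check:

```python
import numpy as np

segms = np.zeros((2, 4, 4), dtype=bool)        # stand-in boolean masks
masks = [it.astype(np.int32) for it in segms]  # explicit dtype, valid on NumPy >= 1.24
print(masks[0].dtype)                          # int32
```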

@@ -31,7 +31,8 @@ class F3NetForProductSegmentationPipeline(Pipeline):
logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
img = LoadImage.convert_to_ndarray(input['input_path'])
img = LoadImage.convert_to_ndarray(input)

img = img.astype(np.float32)
return img

modelscope/pipelines/cv/surface_recon_common_pipeline.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.util import is_model, is_official_hub_path
from modelscope.utils.constant import Invoke, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.surface_recon_common, module_name=Pipelines.surface_recon_common)
class SurfaceReconCommonPipeline(Pipeline):
""" Surface reconstruction common pipeline
Example:

```python
>>> from modelscope.pipelines import pipeline
>>> surface_recon_common = pipeline(Tasks.surface_recon_common,
'damo/cv_surface-reconstruction-common')
>>> surface_recon_common({
'data_dir': '/data/lego', # data dir path (str)
'save_dir': './output', # save dir path (str)
})
>>> #
```
"""

def __init__(self, model, device='gpu', **kwargs):
"""
use `model` to create a surface reconstruction pipeline
Args:
model (str or Model): model_id on modelscope hub
device (str): only support gpu
"""
model = Model.from_pretrained(
model,
device=device,
model_prefetched=True,
invoked_by=Invoke.PIPELINE) if is_model(model) else model

super().__init__(model=model, **kwargs)
if not isinstance(self.model, Model):
logger.error('model object is not initialized.')
raise Exception('model object is not initialized.')
logger.info('load model done')

def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
data_dir = input['data_dir']
save_dir = input['save_dir']
if 'color' in input:
color = input['color']
else:
color = False
if 'n_directions' in input:
n_directions = input['n_directions']
else:
n_directions = 8
self.model.surface_reconstruction(data_dir, save_dir, color,
n_directions)
return {OutputKeys.OUTPUT: 'Done'}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
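A usage sketch for the new pipeline (paths are placeholders; 'color' and 'n_directions' are the optional keys read by `forward`, defaulting to False and 8):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

surface_recon = pipeline(Tasks.surface_recon_common,
                         'damo/cv_surface-reconstruction-common')
result = surface_recon({
    'data_dir': '/data/lego',   # input data directory (placeholder)
    'save_dir': './output',     # output directory (placeholder)
    'color': True,              # optional flag forwarded to model.surface_reconstruction
    'n_directions': 8,          # optional, default 8
})
print(result)                   # forward() returns {OutputKeys.OUTPUT: 'Done'}
```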

@@ -4,24 +4,27 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline
from .asr_pipeline import AutomaticSpeechRecognitionPipeline
from .diffusers_wrapped import (ChineseStableDiffusionPipeline,
StableDiffusionPipeline)
from .document_vl_embedding_pipeline import DocumentVLEmbeddingPipeline
from .generative_multi_modal_embedding_pipeline import \
GEMMMultiModalEmbeddingPipeline
from .image_captioning_pipeline import ImageCaptioningPipeline
from .visual_entailment_pipeline import VisualEntailmentPipeline
from .visual_grounding_pipeline import VisualGroundingPipeline
from .mgeo_ranking_pipeline import MGeoRankingPipeline
from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline
from .multimodal_dialogue_pipeline import MultimodalDialoguePipeline
from .prost_text_video_retrieval_pipeline import \
ProSTTextVideoRetrievalPipeline
from .soonet_video_temporal_grounding_pipeline import \
SOONetVideoTemporalGroundingPipeline
from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline
from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline
from .video_captioning_pipeline import VideoCaptioningPipeline
from .video_multi_modal_embedding_pipeline import \
VideoMultiModalEmbeddingPipeline
from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline
from .asr_pipeline import AutomaticSpeechRecognitionPipeline
from .mgeo_ranking_pipeline import MGeoRankingPipeline
from .document_vl_embedding_pipeline import DocumentVLEmbeddingPipeline
from .video_captioning_pipeline import VideoCaptioningPipeline
from .video_question_answering_pipeline import VideoQuestionAnsweringPipeline
from .diffusers_wrapped import StableDiffusionPipeline, ChineseStableDiffusionPipeline
from .soonet_video_temporal_grounding_pipeline import SOONetVideoTemporalGroundingPipeline
from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline
from .multimodal_dialogue_pipeline import MultimodalDialoguePipeline
from .videocomposer_pipeline import VideoComposerPipeline
else:
_import_structure = {
@@ -29,6 +32,8 @@ else:
'visual_entailment_pipeline': ['VisualEntailmentPipeline'],
'visual_grounding_pipeline': ['VisualGroundingPipeline'],
'multi_modal_embedding_pipeline': ['MultiModalEmbeddingPipeline'],
'prost_text_video_retrieval_pipeline':
['ProSTTextVideoRetrievalPipeline'],
'text_to_image_synthesis_pipeline': ['TextToImageSynthesisPipeline'],
'visual_question_answering_pipeline':
['VisualQuestionAnsweringPipeline'],
modelscope/pipelines/multi_modal/cone2_pipeline/__init__.py (new file, 21 lines)
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .cones2_inference_pipeline import Cones2InferencePipeline
else:
_import_structure = {
'cones2_inference_pipeline': ['Cones2InferencePipeline'],
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

@@ -0,0 +1,494 @@
# Copyright 2023 The HuggingFace Team.
# Copyright 2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.

# The implementation here is modified based on diffusers,
# originally Apache License, Copyright 2023 The HuggingFace Team

import math
from typing import Any, Dict

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline
from diffusers.models.cross_attention import CrossAttention
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import \
StableDiffusionPipelineOutput
from PIL import Image
from tqdm.auto import tqdm

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.multi_modal.diffusers_wrapped.diffusers_pipeline import \
DiffusersPipeline
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
Tasks.text_to_image_synthesis, module_name=Pipelines.cones2_inference)
class Cones2InferencePipeline(DiffusersPipeline):
r""" Cones2 Inference Pipeline.

Examples:

>>> from modelscope.pipelines import pipeline
>>> from modelscope.utils.constant import Tasks
>>> pipe = pipeline(task=Tasks.text_to_image_synthesis, model='damo/Cones2', model_revision='v1.0.1')
>>> inputs = {
>>>     "text": 'a mug and a dog on the beach',
>>>     "subject_list": [["mug", 2], ["dog", 5]],
>>>     "color_context": {"255,192,0": ["mug", 2.5], "255,0,0": ["dog", 2.5]},
>>>     "layout": 'data/test/images/mask_example.png'
>>> }
>>> result = pipe(inputs)
"""

def __init__(self, model: str, device: str = 'gpu', **kwargs):
"""
use `model` to create a stable diffusion pipeline
Args:
model: model id on modelscope hub.
device: str = 'gpu'
"""
super().__init__(model, device, **kwargs)
self.pipeline = StableDiffusionPipeline.from_pretrained(model)
self.pipeline.text_encoder.pooler = None
self.pipeline.to(self.device)

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
if not isinstance(inputs, dict):
raise ValueError(
f'Expected the input to be a dictionary, but got {type(inputs)}'
)
if 'text' not in inputs:
raise ValueError('input should contain "text", but not found')

return self.layout_guidance_sampling(
prompt=inputs.get('text'),
residual_dict=inputs.get('residual_dict', None),
subject_list=inputs.get('subject_list'),
color_context=inputs.get('color_context', None),
layout=inputs.get('layout', None),
)

@torch.no_grad()
def layout_guidance_sampling(
self,
prompt='',
residual_dict=None,
subject_list=None,
color_context=None,
layout=None,
cfg_scale=7.5,
inference_steps=50,
guidance_steps=50,
guidance_weight=0.05,
weight_negative=-1e8,
):

layout = Image.open(layout).resize((768, 768)).convert('RGB')
subject_color_dict = {
tuple(map(int, key.split(','))): value
for key, value in color_context.items()
}

vae = self.pipeline.vae
unet = self.pipeline.unet
text_encoder = self.pipeline.text_encoder
tokenizer = self.pipeline.tokenizer
unconditional_input_prompt = ''
scheduler = LMSDiscreteScheduler.from_config(
self.pipeline.scheduler.config)
scheduler.set_timesteps(inference_steps, device=self.device)
if guidance_steps > 0:
guidance_steps = min(guidance_steps, inference_steps)
scheduler_guidance = LMSDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule='scaled_linear',
num_train_timesteps=1000,
)
scheduler_guidance.set_timesteps(
guidance_steps, device=self.device)

# Process input prompt text
text_input = tokenizer(
[prompt],
padding='max_length',
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors='pt',
)

# Edit text embedding conditions with residual token embeddings.
cond_embeddings = text_encoder(text_input.input_ids.to(self.device))[0]
if residual_dict is not None:
for name, token in subject_list:
residual_token_embedding = torch.load(residual_dict[name])
cond_embeddings[0][token] += residual_token_embedding.reshape(
1024)

# Process unconditional input "" for classifier-free guidance.
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer([unconditional_input_prompt],
padding='max_length',
max_length=max_length,
return_tensors='pt')
uncond_embeddings = text_encoder(
uncond_input.input_ids.to(self.device))[0]

register_attention_control(unet)

# Calculate the hidden features for each cross attention layer.
hidden_states, uncond_hidden_states = _extract_cross_attention(
tokenizer, self.device, layout, subject_color_dict, text_input,
weight_negative)
hidden_states['CONDITION_TENSOR'] = cond_embeddings
uncond_hidden_states['CONDITION_TENSOR'] = uncond_embeddings
hidden_states['function'] = lambda w, sigma, qk: (
guidance_weight * w * math.log(1 + sigma**2)) * qk.std()
uncond_hidden_states['function'] = lambda w, sigma, qk: 0.0

# Sampling the initial latents.
latent_size = (1, unet.in_channels, 96, 96)
latents = torch.randn(latent_size).to(self.device)
latents = latents * scheduler.init_noise_sigma

for i, t in tqdm(
enumerate(scheduler.timesteps),
total=len(scheduler.timesteps)):
# Improve the harmony of generated images by self-recurrence.
if i < guidance_steps:
loop = 2
else:
loop = 1
for k in range(loop):
if i < guidance_steps:
sigma = scheduler_guidance.sigmas[i]
latent_model_input = scheduler.scale_model_input(
latents, t)
_t = t

hidden_states.update({'SIGMA': sigma})

noise_pred_text = unet(
latent_model_input,
_t,
encoder_hidden_states=hidden_states,
).sample

uncond_hidden_states.update({'SIGMA': sigma})

noise_pred_uncond = unet(
latent_model_input,
_t,
encoder_hidden_states=uncond_hidden_states,
).sample

noise_pred = noise_pred_uncond + cfg_scale * (
noise_pred_text - noise_pred_uncond)
latents = scheduler.step(noise_pred, t, latents,
1).prev_sample

# Self-recurrence.
if k < 1 and loop > 1:
noise_recurent = torch.randn(latents.shape).to(
self.device)
sigma_difference = scheduler.sigmas[
i]**2 - scheduler.sigmas[i + 1]**2
latents = latents + noise_recurent * (
sigma_difference**0.5)
else:
latent_model_input = scheduler.scale_model_input(
latents, t)
_t = t
noise_pred_text = unet(
latent_model_input,
_t,
encoder_hidden_states=cond_embeddings,
).sample

latent_model_input = scheduler.scale_model_input(
latents, t)

noise_pred_uncond = unet(
latent_model_input,
_t,
encoder_hidden_states=uncond_embeddings,
).sample

noise_pred = noise_pred_uncond + cfg_scale * (
noise_pred_text - noise_pred_uncond)
latents = scheduler.step(noise_pred, t, latents,
1).prev_sample

edited_images = _latents_to_images(vae, latents)

return StableDiffusionPipelineOutput(
images=edited_images, nsfw_content_detected=None)

def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
images = []
for img in inputs.images:
if isinstance(img, Image.Image):
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
images.append(img)
return {OutputKeys.OUTPUT_IMGS: images}
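Note that `postprocess` hands back BGR uint8 arrays under `OutputKeys.OUTPUT_IMGS`, so they can be written directly with OpenCV. A sketch, assuming `pipe` and `inputs` are the objects from the class docstring example above:

```python
import cv2
from modelscope.outputs import OutputKeys

result = pipe(inputs)  # `pipe` and `inputs` as in the Cones2InferencePipeline docstring
for idx, img in enumerate(result[OutputKeys.OUTPUT_IMGS]):
    cv2.imwrite(f'cones2_output_{idx}.png', img)  # already BGR, written as-is
```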


class Cones2AttnProcessor:

def __init__(self):
super().__init__()

def __call__(self,
attn: CrossAttention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
query = attn.to_q(hidden_states)
is_dict_format = True
if encoder_hidden_states is not None:
if 'CONDITION_TENSOR' in encoder_hidden_states:
encoder_hidden = encoder_hidden_states['CONDITION_TENSOR']
else:
encoder_hidden = encoder_hidden_states
is_dict_format = False
else:
encoder_hidden = hidden_states

key = attn.to_k(encoder_hidden)
value = attn.to_v(encoder_hidden)

query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

attention_scores = torch.matmul(query, key.transpose(-1, -2))
attention_size_of_img = attention_scores.size()[-2]

if attention_scores.size()[2] == 77:
if is_dict_format:
f = encoder_hidden_states['function']
try:
w = encoder_hidden_states[
f'CA_WEIGHT_{attention_size_of_img}']
except KeyError:
w = encoder_hidden_states['CA_WEIGHT_ORIG']
if not isinstance(w, int):
img_h, img_w, nc = w.shape
ratio = math.sqrt(img_h * img_w
/ attention_size_of_img)
w = F.interpolate(
w.permute(2, 0, 1).unsqueeze(0),
scale_factor=1 / ratio,
mode='bilinear',
align_corners=True)
w = F.interpolate(
w.reshape(1, nc, -1),
size=(attention_size_of_img, ),
mode='nearest').permute(2, 1, 0).squeeze()
else:
w = 0
if type(w) is int and w == 0:
sigma = encoder_hidden_states['SIGMA']
cross_attention_weight = f(w, sigma, attention_scores)
else:
bias = torch.zeros_like(w)
bias[torch.where(w > 0)] = attention_scores.std() * 0
sigma = encoder_hidden_states['SIGMA']
cross_attention_weight = f(w, sigma, attention_scores)
cross_attention_weight = cross_attention_weight + bias
else:
cross_attention_weight = 0.0
else:
cross_attention_weight = 0.0

attention_scores = (attention_scores
+ cross_attention_weight) * attn.scale
attention_probs = attention_scores.softmax(dim=-1)

hidden_states = torch.matmul(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

return hidden_states


def register_attention_control(unet):
attn_procs = {}
for name in unet.attn_processors.keys():
attn_procs[name] = Cones2AttnProcessor()

unet.set_attn_processor(attn_procs)


def _tokens_img_attention_weight(img_context_seperated,
tokenized_texts,
ratio: int = 8,
original_shape=False):
token_lis = tokenized_texts['input_ids'][0].tolist()
w, h = img_context_seperated[0][1].shape

w_r, h_r = round(w / ratio), round(h / ratio)
ret_tensor = torch.zeros((w_r * h_r, len(token_lis)), dtype=torch.float32)
for v_as_tokens, img_where_color in img_context_seperated:

is_in = 0

for idx, tok in enumerate(token_lis):
if token_lis[idx:idx + len(v_as_tokens)] == v_as_tokens:
is_in = 1

ret_tensor[:, idx:idx + len(v_as_tokens)] += (
_downsampling(img_where_color, w_r,
h_r).reshape(-1,
1).repeat(1, len(v_as_tokens)))

if not is_in == 1:
print(
f'Warning ratio {ratio} : tokens {v_as_tokens} not found in text'
)

if original_shape:
ret_tensor = ret_tensor.reshape((w_r, h_r, len(token_lis)))

return ret_tensor
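A quick shape check (a sketch meant to run alongside this module, not part of the diff): with the 768x768 layout used above and ratio=8, the per-token weight map flattens to 96 * 96 = 9216 rows, which is exactly the `CA_WEIGHT_9216` entry that `Cones2AttnProcessor` looks up for that attention resolution; ratios 16, 32 and 64 give 2304, 576 and 144 in the same way.

```python
import torch

# hypothetical single-subject context: token id 1234 with a blank 768x768 weight mask
img_context = [([1234], torch.zeros(768, 768))]
tokenized = {'input_ids': torch.tensor([[49406, 1234, 49407] + [0] * 74])}  # padded to 77 tokens

w8 = _tokens_img_attention_weight(img_context, tokenized, ratio=8)
print(w8.shape)  # torch.Size([9216, 77]) -> stored as 'CA_WEIGHT_9216'
```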


def _image_context_seperator(img, color_context: dict, _tokenizer, neg: float):
ret_lists = []
if img is not None:
w, h = img.size
matrix = np.zeros((h, w))
for color, v in color_context.items():
color = tuple(color)
if len(color) > 3:
color = color[:3]
if isinstance(color, str):
r, g, b = color[1:3], color[3:5], color[5:7]
color = (int(r, 16), int(g, 16), int(b, 16))
img_where_color = (np.array(img) == color).all(axis=-1)
matrix[img_where_color] = 1

for color, (subject, weight_active) in color_context.items():
if len(color) > 3:
color = color[:3]
v_input = _tokenizer(
subject,
max_length=_tokenizer.model_max_length,
truncation=True,
)

v_as_tokens = v_input['input_ids'][1:-1]
if isinstance(color, str):
r, g, b = color[1:3], color[3:5], color[5:7]
color = (int(r, 16), int(g, 16), int(b, 16))
img_where_color = (np.array(img) == color).all(axis=-1)
matrix[img_where_color] = 1
if not img_where_color.sum() > 0:
print(
f'Warning: color {color} not found in the image')

img_where_color_init = torch.where(
torch.tensor(img_where_color, dtype=torch.bool), weight_active,
neg)

img_where_color = torch.where(
torch.from_numpy(matrix == 1) & (img_where_color_init == 0.0),
torch.tensor(neg), img_where_color_init)

ret_lists.append((v_as_tokens, img_where_color))
else:
w, h = 768, 768

if len(ret_lists) == 0:
ret_lists.append(([-1], torch.zeros((w, h), dtype=torch.float32)))
return ret_lists, w, h
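For reference (not part of the diff): the user-facing `color_context` shown in the class docstring maps 'R,G,B' strings to [subject, weight] pairs, and `layout_guidance_sampling` converts it into the RGB-tuple keyed dict that this helper consumes. A minimal sketch of that conversion:

```python
color_context = {'255,192,0': ['mug', 2.5], '255,0,0': ['dog', 2.5]}
subject_color_dict = {
    tuple(map(int, key.split(','))): value
    for key, value in color_context.items()
}
print(subject_color_dict)
# {(255, 192, 0): ['mug', 2.5], (255, 0, 0): ['dog', 2.5]}
```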


def _extract_cross_attention(tokenizer, device, color_map_image, color_context,
text_input, neg):
# Process color map image and context
seperated_word_contexts, width, height = _image_context_seperator(
color_map_image, color_context, tokenizer, neg)

# Compute cross-attention weights
cross_attention_weight_1 = _tokens_img_attention_weight(
seperated_word_contexts, text_input, ratio=1,
original_shape=True).to(device)
cross_attention_weight_8 = _tokens_img_attention_weight(
seperated_word_contexts, text_input, ratio=8).to(device)
cross_attention_weight_16 = _tokens_img_attention_weight(
seperated_word_contexts, text_input, ratio=16).to(device)
cross_attention_weight_32 = _tokens_img_attention_weight(
seperated_word_contexts, text_input, ratio=32).to(device)
cross_attention_weight_64 = _tokens_img_attention_weight(
seperated_word_contexts, text_input, ratio=64).to(device)

hidden_states = {
'CA_WEIGHT_ORIG': cross_attention_weight_1,  # 768 x 768
'CA_WEIGHT_9216': cross_attention_weight_8,  # 96 x 96
'CA_WEIGHT_2304': cross_attention_weight_16,  # 48 x 48
'CA_WEIGHT_576': cross_attention_weight_32,  # 24 x 24
'CA_WEIGHT_144': cross_attention_weight_64,  # 12 x 12
}

uncond_hidden_states = {
'CA_WEIGHT_ORIG': 0,
'CA_WEIGHT_9216': 0,
'CA_WEIGHT_2304': 0,
'CA_WEIGHT_576': 0,
'CA_WEIGHT_144': 0,
}

return hidden_states, uncond_hidden_states


def _downsampling(img: torch.tensor, w: int, h: int) -> torch.tensor:
return F.interpolate(
img.unsqueeze(0).unsqueeze(1),
size=(w, h),
mode='bilinear',
align_corners=True,
).squeeze()


def _latents_to_images(vae, latents, scale_factor=0.18215):
"""Decode latents to PIL images."""
scaled_latents = 1.0 / scale_factor * latents.clone()
images = vae.decode(scaled_latents).sample
images = (images / 2 + 0.5).clamp(0, 1)
images = images.detach().cpu().permute(0, 2, 3, 1).numpy()

if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype('uint8')
pil_images = [Image.fromarray(image) for image in images]

return pil_images


def _sanitize_parameters(self, **pipeline_parameters):
"""
This method should sanitize the keyword args into preprocess params,
forward params and postprocess params for the '__call__' or '_process_single' method.

Returns:
Dict[str, str]: preprocess_params = {'image_resolution': self.model.get_resolution()}
Dict[str, str]: forward_params = pipeline_parameters
Dict[str, str]: postprocess_params = {}
"""
pipeline_parameters['image_resolution'] = self.model.get_resolution()
pipeline_parameters['modelsetting'] = self.model.get_config()
pipeline_parameters['model_dir'] = self.model.get_model_dir()
pipeline_parameters['control_type'] = self.init_control_type
pipeline_parameters['device'] = self.device
@@ -0,0 +1,56 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.device import device_placement
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.text_video_retrieval,
module_name=Pipelines.prost_text_video_retrieval)
class ProSTTextVideoRetrievalPipeline(Pipeline):
'''
https://www.modelscope.cn/models/damo/multi_modal_clip_vtretrieval_prost/summary

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
text_video_retrieval = pipeline(
Tasks.text_video_retrieval,
model='damo/multi_modal_clip_vtretrieval_prost')
video_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/multi_modal_test_video_9770.mp4'
caption = 'a person is connecting something to system'
_input = {'video': video_path, 'text': caption}
result = text_video_retrieval(_input)
'''

def __init__(self, model: str, **kwargs):
"""
use `model` to create a text_video_retrieval pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
self.model.eval()

def preprocess(self, input: Input) -> Dict[str, Any]:
return input

def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
with device_placement(self.framework, self.device_name):
out = self.forward(input)

self._check_output(out)
return out

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
return self.model(input)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
@@ -27,7 +27,9 @@ if TYPE_CHECKING:
from .translation_quality_estimation_pipeline import TranslationQualityEstimationPipeline
from .text_error_correction_pipeline import TextErrorCorrectionPipeline
from .word_alignment_pipeline import WordAlignmentPipeline
from .text_generation_pipeline import TextGenerationPipeline, TextGenerationT5Pipeline, SeqGPTPipeline
from .text_generation_pipeline import TextGenerationPipeline, TextGenerationT5Pipeline, \
SeqGPTPipeline, ChatGLM6bTextGenerationPipeline, ChatGLM6bV2TextGenerationPipeline, \
QWenChatPipeline, QWenTextGenerationPipeline, Llama2TaskPipeline
from .fid_dialogue_pipeline import FidDialoguePipeline
from .token_classification_pipeline import TokenClassificationPipeline
from .translation_pipeline import TranslationPipeline
@@ -80,7 +82,10 @@ else:
'word_alignment_pipeline': ['WordAlignmentPipeline'],
'text_generation_pipeline': [
'TextGenerationPipeline', 'TextGenerationT5Pipeline',
'SeqGPTPipeline'
'ChatGLM6bTextGenerationPipeline',
'ChatGLM6bV2TextGenerationPipeline', 'QWenChatPipeline',
'QWenTextGenerationPipeline', 'SeqGPTPipeline',
'Llama2TaskPipeline'
],
'fid_dialogue_pipeline': ['FidDialoguePipeline'],
'token_classification_pipeline': ['TokenClassificationPipeline'],