Merge pull request #108 from modelscope/master-merge-internal20230215

Master merge internal20230215
This commit is contained in:
wenmeng zhou
2023-02-16 16:13:51 +08:00
committed by GitHub
772 changed files with 59659 additions and 6291 deletions

View File

@@ -96,9 +96,9 @@ else
fi
if [[ $python_version == 3.7* ]]; then
base_tag=$base_tag-py37
elif [[ $python_version == z* ]]; then
elif [[ $python_version == 3.8* ]]; then
base_tag=$base_tag-py38
elif [[ $python_version == z* ]]; then
elif [[ $python_version == 3.9* ]]; then
base_tag=$base_tag-py39
else
echo "Unsupport python version: $python_version"
@@ -129,8 +129,15 @@ else
echo "Building dsw image well need set ModelScope lib cache location."
docker_file_content="${docker_file_content} \nENV MODELSCOPE_CACHE=/mnt/workspace/.cache/modelscope"
fi
if [ "$is_ci_test" == "True" ]; then
echo "Building CI image, uninstall modelscope"
docker_file_content="${docker_file_content} \nRUN pip uninstall modelscope -y"
fi
printf "$docker_file_content" > Dockerfile
docker build -t $IMAGE_TO_BUILD \
while true
do
docker build -t $IMAGE_TO_BUILD \
--build-arg USE_GPU \
--build-arg BASE_IMAGE \
--build-arg PYTHON_VERSION \
@@ -138,11 +145,14 @@ docker build -t $IMAGE_TO_BUILD \
--build-arg CUDATOOLKIT_VERSION \
--build-arg TENSORFLOW_VERSION \
-f Dockerfile .
if [ $? -eq 0 ]; then
echo "Image build done"
break
else
echo "Running docker build command error, we will retry"
fi
done
if [ $? -ne 0 ]; then
echo "Running docker build command error, please check the log!"
exit -1
fi
if [ "$run_ci_test" == "True" ]; then
echo "Running ci case."
export MODELSCOPE_CACHE=/home/mulin.lyh/model_scope_cache

View File

@@ -20,15 +20,15 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
fi
fi
awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
awk -F: '/^[^#]/ { print $1 }' requirements/science.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install -r requirements/framework.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install -r requirements/audio.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install -r requirements/cv.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install -r requirements/multi-modal.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install -r requirements/nlp.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install -r requirements/science.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
# test with install
python setup.py install
pip install .
else
echo "Running case in release image, run case directly!"
fi

1
.gitignore vendored
View File

@@ -122,6 +122,7 @@ tensorboard.sh
.DS_Store
replace.sh
result.png
result.jpg
# Pytorch
*.pth

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb35bff3dac9aec36e259461fecae1e1bc2ec029615f30713111cd598993676c
size 249646

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d7daff767e13d9a2187b676d958065121cd5e26da046d65cd9604e91a87525a2
size 201006

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a723c134978a17fe12ca2374d0281a8003a56fa44ff9d2249a08791714983362
size 249646

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f516e38eea7a16fd48fddc34953cb227d86d22fbcd31de0c1334bb14b96dba8
size 932252

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:430575a8cb668113d6b0e91e403be0c0e36a95bbb96c484603a625b52f71edd9
size 11858

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d486900ecca027d70453322d0f22de4b36f9534a324b8b1cda3ea86bb72bac6
size 353096

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0abad2347748bf312ab0dbce48fdc643a703d94970e1b181cf19b9be6312db8c
size 3145728

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b3a4f864cee22265fdbb8008719e0e2e36235bd4bb2fdfbc9278b0b964e86eff
size 1921140

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7f4bc4dd40c69ecc54bc9517f52fbf3df9a5f682cd9f4d4f3f1376bf33ede22d
size 2820304

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f6b6b4abfcc2fc9042c4e51c2e5f530ff84b345cd3176b11e8317143c5a7e0f
size 91130

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0e8a71df766b615e20a5e1cacd47796a5668747e039e7f6f6e1b029b40818cc2
size 196993

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6099183bbc513371c3bded04dbff688958a9c7ab569370c0fb4809fc64850e47
size 704685

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5bebb94d42fa4b8dd462fecfa7b248402a30cbc637344ce26143071ca2c470d7
size 1636

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:680d76723fc28bc6ce729a1cd6f11a7d5fc26b5bfe3b486d885417935c20f493
size 869811

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
size 87228

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5b5861ca8955f8ff906abe78f2b32bc49deee2832f4518ffe4bb584653f3c9e9
size 187443

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:40f535f4411fc9b3ea9d2d8c7a352f6f9a33465e797332bd1a4162b40aaffe5f
size 338334

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd3415c9bf1cd099a379f0b3c8049d0f602ec900c9d335b75058355d8db2b077
size 358916

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:63c6cd0f0f3b4201a9450dcf3db4b5b4a2b9ad2f48885854868d0c2b6406aac7
size 471097

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c934ced1221d27153a15c14351c575a91f3ff5a6650c3dc9e0778a4245b2804
size 1192

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f2ab6add1c8a215ca6199baa68d56bca99dbdae7391937493067a6f363b059de
size 1453

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d87bd9fa4dca7c7dbb3253e733517303d9b85c9c6600a58c9e9b7150468036da
size 1410

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6b63bc83b6f5dfeb66f3c79db6fa28b0683690b5dad80b414a03ed723b351edc
size 467695

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9de64a9f9e1903f2a72bbddccfbffd16f6ea9e7a855e673792d66e7ad74c8ff4
size 240669

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5965f3f3293fb7616e439ef4821d586de1f129bcf08279bbd10a5f42463d542f
size 240953

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:10b494cbc1a29b228745bcb26897e2524569b467b88cc9839be38504d268ca30
size 55485

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a1976ea249b4ad5409cdae403dcd154fac3c628909b6b1874cc968960e2c62d
size 8259

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2f832af4703878076e42fb41544b82147fd31b6be06713975872f16294d1a613
size 28297

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b6ab556a1d69010cfe6dd136ff3fbd17ed122c6d0c3509667ef40a656bc18464
size 87334

BIN
data/test/images/images.zip Normal file

Binary file not shown.

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:602b46c6ba1d18fd3b91fd3b47112d37ca9d8e1ed72f0c0ea93ad8d493f5182e
size 20299

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c0791f043b905f2e77ccf2f8c5b29182e1fc99cee16d9069e8bbc1704e917268
size 20631

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:78d7bf999d1a4186309693ff1b966edb3ccd40f7861a7589167cf9e33897a693
size 369725

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b28d9c33eff034a706534f195f4443f8c053a74d5553787a5cb9b20873c072f
size 1962

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bbd99f0253d6e0d10ec500cf781cc83b93809db58da54bd914b0b80b7fe8d8a4
size 2409

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a834d1272253559cdf45a5f09642fb0b5209242dca854fce849efc15cebd4028
size 4623264

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9941ac4a5dd0d9eea5d33ce0009da34d0c93c64ed062479e6c8efb4788e8ef7c
size 522972

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:824cc8beaaa8747a3ec32f4c79308e468838c448853f40e882a7cc090c71bf96
size 2151630

View File

@@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y --reinstall ca-certificates && \
apt-get clean && \
cp /tmp/resources/ubuntu20.04_sources.tuna /etc/apt/sources.list && \
apt-get update && \
apt-get install -y locales wget git strace gdb vim ffmpeg libsm6 tzdata language-pack-zh-hans ttf-wqy-microhei ttf-wqy-zenhei xfonts-wqy libxext6 build-essential ninja-build && \
apt-get install -y locales wget git strace gdb sox libopenmpi-dev curl strace vim ffmpeg libsm6 tzdata language-pack-zh-hans ttf-wqy-microhei ttf-wqy-zenhei xfonts-wqy libxext6 build-essential ninja-build && \
wget https://packagecloud.io/github/git-lfs/packages/debian/bullseye/git-lfs_3.2.0_amd64.deb/download -O ./git-lfs_3.2.0_amd64.deb && \
dpkg -i ./git-lfs_3.2.0_amd64.deb && \
rm -f ./git-lfs_3.2.0_amd64.deb && \
@@ -58,12 +58,46 @@ RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION; \
fi
# mmcv-full<=1.7.0 for mmdet3d compatible
RUN if [ "$USE_GPU" = "True" ] ; then \
CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="5.0 5.2 6.0 6.1 7.0 7.5 8.0 8.6" MMCV_WITH_OPS=1 MAX_JOBS=8 FORCE_CUDA=1 pip install --no-cache-dir mmcv-full && pip cache purge; \
CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="5.0 5.2 6.0 6.1 7.0 7.5 8.0 8.6" MMCV_WITH_OPS=1 MAX_JOBS=8 FORCE_CUDA=1 pip install --no-cache-dir 'mmcv-full<=1.7.0' && pip cache purge; \
else \
MMCV_WITH_OPS=1 MAX_JOBS=8 pip install --no-cache-dir mmcv-full && pip cache purge; \
MMCV_WITH_OPS=1 MAX_JOBS=8 pip install --no-cache-dir 'mmcv-full<=1.7.0' && pip cache purge; \
fi
# default shell bash
ENV SHELL=/bin/bash
# install special package
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html; \
else \
pip install --no-cache-dir dgl dglgo -f https://data.dgl.ai/wheels/repo.html; \
fi
# copy install scripts
COPY docker/scripts/install_unifold.sh docker/scripts/install_colmap.sh docker/scripts/install_pytorch3d_nvdiffrast.sh docker/scripts/install_tiny_cuda_nn.sh docker/scripts/install_apex.sh /tmp/
# for uniford
RUN if [ "$USE_GPU" = "True" ] ; then \
bash /tmp/install_unifold.sh; \
else \
echo 'cpu unsupport uniford'; \
fi
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir git+https://github.com/gxd1994/Pointnet2.PyTorch.git@master#subdirectory=pointnet2; \
else \
echo 'cpu unsupport Pointnet2'; \
fi
RUN pip install --no-cache-dir detectron2==0.3 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
# 3d supports
RUN bash /tmp/install_colmap.sh
RUN bash /tmp/install_tiny_cuda_nn.sh
RUN bash /tmp/install_pytorch3d_nvdiffrast.sh
# end of 3D
# install modelscope
COPY requirements /var/modelscope
RUN pip install --no-cache-dir --upgrade pip && \
@@ -76,42 +110,17 @@ RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir -r /var/modelscope/tests.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
pip cache purge
# default shell bash
ENV SHELL=/bin/bash
# install special package
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html; \
else \
pip install --no-cache-dir dgl dglgo -f https://data.dgl.ai/wheels/repo.html; \
fi
# install jupyter plugin
RUN mkdir -p /root/.local/share/jupyter/labextensions/ && \
cp -r /tmp/resources/jupyter_plugins/* /root/.local/share/jupyter/labextensions/
COPY docker/scripts/modelscope_env_init.sh /usr/local/bin/ms_env_init.sh
RUN pip install --no-cache-dir https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/xtcocotools-1.12-cp37-cp37m-linux_x86_64.whl --force
RUN pip install --no-cache-dir xtcocotools==1.12 detectron2==0.3 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html --force
# for uniford
COPY docker/scripts/install_unifold.sh /tmp/install_unifold.sh
RUN if [ "$USE_GPU" = "True" ] ; then \
bash /tmp/install_unifold.sh; \
else \
echo 'cpu unsupport uniford'; \
fi
RUN pip install --no-cache-dir mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 numpy==1.18.5 https://pypi.tuna.tsinghua.edu.cn/packages/70/ad/06f8a06cef819606cb1a521bcc144288daee5c7e73c5d722492866cb1b92/wenetruntime-1.11.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ipykernel fairseq fasttext deepspeed
COPY docker/scripts/install_apex.sh /tmp/install_apex.sh
# speechbrain==0.5.7 for audio compatible
RUN pip install --no-cache-dir speechbrain==0.5.7 adaseq>=0.5.0 mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 numpy==1.18.5 wenetruntime==1.11.0 ipykernel fairseq fasttext deepspeed
RUN if [ "$USE_GPU" = "True" ] ; then \
bash /tmp/install_apex.sh; \
else \
echo 'cpu unsupport apex'; \
fi
RUN apt-get update && apt-get install -y sox && \
apt-get clean
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir git+https://github.com/gxd1994/Pointnet2.PyTorch.git@master#subdirectory=pointnet2; \
else \
echo 'cpu unsupport Pointnet2'; \
fi

View File

@@ -1,6 +1,6 @@
export MAX_JOBS=16
git clone https://github.com/NVIDIA/apex
cd apex
TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6" pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ..
rm -rf apex
export MAX_JOBS=16 \
&& git clone https://github.com/NVIDIA/apex \
&& cd apex \
&& TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6" pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ \
&& cd .. \
&& rm -rf apex

View File

@@ -0,0 +1,24 @@
wget -q https://cmake.org/files/v3.25/cmake-3.25.2-linux-x86_64.sh \
&& mkdir /opt/cmake \
&& sh cmake-3.25.2-linux-x86_64.sh --prefix=/opt/cmake --skip-license \
&& ln -s /opt/cmake/bin/cmake /usr/local/bin/cmake \
&& rm -f cmake-3.25.2-linux-x86_64.sh \
&& apt-get update \
&& apt-get install libboost-program-options-dev libboost-filesystem-dev libboost-graph-dev libboost-system-dev libboost-test-dev libeigen3-dev libflann-dev libsuitesparse-dev libfreeimage-dev libmetis-dev libgoogle-glog-dev libgflags-dev libsqlite3-dev libglew-dev qtbase5-dev libqt5opengl5-dev libcgal-dev libceres-dev -y \
&& export CMAKE_BUILD_PARALLEL_LEVEL=36 \
&& export MAX_JOBS=16 \
&& export COLMAP_VERSION=dev \
&& export CUDA_ARCHITECTURES="all" \
&& git clone https://github.com/colmap/colmap.git \
&& cd colmap \
&& git reset --hard ${COLMAP_VERSION} \
&& mkdir build \
&& cd build \
&& cmake .. -GNinja -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHITECTURES} \
&& ninja \
&& ninja install \
&& cd ../.. \
&& rm -rf colmap \
&& apt-get clean \
&& strip --remove-section=.note.ABI-tag /usr/lib/x86_64-linux-gnu/libQt5Core.so.5 \
&& rm -rf /var/lib/apt/lists/*

View File

@@ -1,12 +0,0 @@
#!/bin/bash
set -eo pipefail
ModelScopeLib=/usr/local/modelscope/lib64
if [ ! -d /usr/local/modelscope ]; then
mkdir -p $ModelScopeLib
fi
# audio libs
wget "http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/libs/audio/libmitaec_pyio.so" -O ${ModelScopeLib}/libmitaec_pyio.so

View File

@@ -0,0 +1,14 @@
export CMAKE_BUILD_PARALLEL_LEVEL=36 && export MAX_JOBS=36 && export CMAKE_CUDA_ARCHITECTURES="50;52;60;61;70;75;80;86" \
&& pip install --no-cache-dir fvcore iopath \
&& curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz \
&& tar xzf 1.10.0.tar.gz \
&& export CUB_HOME=$PWD/cub-1.10.0 \
&& pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" \
&& rm -fr 1.10.0.tar.gz cub-1.10.0 \
&& apt-get update \
&& apt-get install -y --no-install-recommends pkg-config libglvnd0 libgl1 libglx0 libegl1 libgles2 libglvnd-dev libgl1-mesa-dev libegl1-mesa-dev libgles2-mesa-dev -y \
&& git clone https://github.com/NVlabs/nvdiffrast.git \
&& cd nvdiffrast \
&& pip install --no-cache-dir . \
&& cd .. \
&& rm -rf nvdiffrast

View File

@@ -0,0 +1,8 @@
export CMAKE_BUILD_PARALLEL_LEVEL=36 && export MAX_JOBS=36 && export TCNN_CUDA_ARCHITECTURES="50;52;60;61;70;75;80;86" \
&& git clone --recursive https://github.com/nvlabs/tiny-cuda-nn \
&& cd tiny-cuda-nn \
&& git checkout v1.6 \
&& cd bindings/torch \
&& python setup.py install \
&& cd ../../.. \
&& rm -rf tiny-cuda-nn

View File

@@ -27,9 +27,9 @@
Currently supported formats include "json", "yaml/yml".
Examples:
>>> load('/path/of/your/file') # file is storaged in disk
>>> load('https://path/of/your/file') # file is storaged in Internet
>>> load('oss://path/of/your/file') # file is storaged in petrel
>>> load('/path/of/your/file') # file is stored in disk
>>> load('https://path/of/your/file') # file is stored on internet
>>> load('oss://path/of/your/file') # file is stored in petrel
Returns:
The content from the file.

View File

@@ -5,7 +5,7 @@
.. autoclass:: {{ name }}
:members:
:special-members: __init__, __call__
..
autogenerated from source/_templates/classtemplate.rst

View File

@@ -12,3 +12,16 @@ modelscope.models.cv
:template: classtemplate.rst
easycv_base.EasyCVBaseModel
image_colorization.ddcolor.ddcolor_for_image_colorization.DDColorForImageColorization
image_deblur.nafnet_for_image_deblur.NAFNetForImageDeblur
image_defrcn_fewshot.defrcn_for_fewshot.DeFRCNForFewShot
image_denoise.nafnet_for_image_denoise.NAFNetForImageDenoise
image_face_fusion.image_face_fusion.ImageFaceFusion
image_matching.quadtree_attention_model.QuadTreeAttentionForImageMatching
image_skychange.skychange_model.ImageSkychange
language_guided_video_summarization.summarizer.ClipItVideoSummarization
panorama_depth_estimation.unifuse_model.PanoramaDepthEstimation
video_stabilization.DUTRAFTStabilizer.DUTRAFTStabilizer
video_summarization.summarizer.PGLVideoSummarization
video_super_resolution.real_basicvsr_for_video_super_resolution.RealBasicVSRNetForVideoSR
vision_middleware.model.VisionMiddlewareModel

View File

@@ -0,0 +1,24 @@
modelscope.models.multi_modal
====================
.. automodule:: modelscope.models.multi_modal
.. currentmodule:: modelscope.models.multi_modal
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
clip.CLIPForMultiModalEmbedding
diffusion.DiffusionForTextToImageSynthesis
gemm.GEMMForMultiModalEmbedding
team.TEAMForMultiModalSimilarity
mmr.VideoCLIPForMultiModalEmbedding
mplug_for_all_tasks.MPlugForAllTasks
mplug_for_all_tasks.HiTeAForAllTasks
ofa_for_all_tasks.OfaForAllTasks
ofa_for_text_to_image_synthesis_model.OfaForTextToImageSynthesis
multi_stage_diffusion.MultiStageDiffusionForTextToImageSynthesis
vldoc.VLDocForDocVLEmbedding

View File

@@ -0,0 +1,60 @@
modelscope.models.nlp
====================
.. automodule:: modelscope.models.nlp
.. currentmodule:: modelscope.models.nlp
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
bart.BartForTextErrorCorrection
bert.BertConfig
bert.BertModel
bert.BertForMaskedLM
bert.BertForTextRanking
bert.BertForSentenceEmbedding
bert.BertForSequenceClassification
bert.BertForTokenClassification
bert.BertForDocumentSegmentation
csanmt.CsanmtForTranslation
deberta_v2.DebertaV2Model
deberta_v2.DebertaV2ForMaskedLM
gpt_neo.GPTNeoModel
gpt2.GPT2Model
gpt3.GPT3ForTextGeneration
gpt3.DistributedGPT3
gpt_moe.GPTMoEForTextGeneration
gpt_moe.DistributedGPTMoE
megatron_bert.MegatronBertConfig
megatron_bert.MegatronBertModel
megatron_bert.MegatronBertForMaskedLM
palm_v2.PalmForTextGeneration
ponet.PoNetConfig
ponet.PoNetModel
ponet.PoNetForMaskedLM
space.SpaceForDialogIntent
space.SpaceForDialogModeling
space.SpaceForDST
space_T_cn.TableQuestionAnswering
space_T_en.StarForTextToSql
structbert.SbertModel
structbert.SbertForMaskedLM
structbert.SbertForSequenceClassification
structbert.SbertForTokenClassification
structbert.SbertForFaqQuestionAnswering
T5.T5ForConditionalGeneration
mglm.MGLMForTextSummarization
codegeex.CodeGeeXForCodeTranslation
codegeex.CodeGeeXForCodeGeneration
veco.VecoConfig
veco.VecoModel
veco.VecoForMaskedLM
veco.VecoForSequenceClassification
veco.VecoForTokenClassification
bloom.BloomModel
unite.UniTEModel
use.UserSatisfactionEstimation

View File

@@ -12,3 +12,5 @@ modelscope.models
bases <modelscope.models.base>
builders <modelscope.models.builder>
cv <modelscope.models.cv>
nlp <modelscope.models.nlp>
multi-modal <modelscope.models.multi_modal>

View File

@@ -0,0 +1,20 @@
modelscope.pipelines.audio
=======================
.. automodule:: modelscope.pipelines.audio
.. currentmodule:: modelscope.pipelines.audio
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
ANSPipeline
AutomaticSpeechRecognitionPipeline
InverseTextProcessingPipeline
KWSFarfieldPipeline
KeyWordSpottingKwsbpPipeline
LinearAECPipeline
TextToSpeechSambertHifiganPipeline

View File

@@ -11,4 +11,84 @@ modelscope.pipelines.cv
:nosignatures:
:template: classtemplate.rst
ActionDetectionPipeline
ActionRecognitionPipeline
AnimalRecognitionPipeline
ArcFaceRecognitionPipeline
Body2DKeypointsPipeline
CardDetectionPipeline
CMDSSLVideoEmbeddingPipeline
CrowdCountingPipeline
DDColorImageColorizationPipeline
EasyCVDetectionPipeline
EasyCVSegmentationPipeline
Face2DKeypointsPipeline
FaceAttributeRecognitionPipeline
FaceDetectionPipeline
FaceImageGenerationPipeline
FaceLivenessIrPipeline
FaceProcessingBasePipeline
FaceRecognitionOnnxFmPipeline
FaceRecognitionOodPipeline
FaceRecognitionPipeline
FacialExpressionRecognitionPipeline
FacialLandmarkConfidencePipeline
GeneralImageClassificationPipeline
GeneralRecognitionPipeline
HICOSSLVideoEmbeddingPipeline
Hand2DKeypointsPipeline
HandStaticPipeline
HumanWholebodyKeypointsPipeline
Image2ImageGenerationPipeline
Image2ImageTranslationPipeline
ImageCartoonPipeline
ImageClassificationPipeline
ImageColorEnhancePipeline
ImageColorizationPipeline
ImageDeblurPipeline
ImageDefrcnDetectionPipeline
ImageDenoisePipeline
ImageDetectionPipeline
ImageInpaintingPipeline
ImageInstanceSegmentationPipeline
ImageMatchingPipeline
ImageMattingPipeline
ImageMultiViewDepthEstimationPipeline
ImagePanopticSegmentationEasyCVPipeline
ImagePanopticSegmentationPipeline
ImagePortraitEnhancementPipeline
ImageReidPersonPipeline
ImageSalientDetectionPipeline
ImageSemanticSegmentationPipeline
ImageSkychangePipeline
ImageStyleTransferPipeline
ImageSuperResolutionPipeline
LanguageGuidedVideoSummarizationPipeline
LicensePlateDetectionPipeline
LiveCategoryPipeline
MaskDINOInstanceSegmentationPipeline
MaskFaceRecognitionPipeline
MogFaceDetectionPipeline
MovieSceneSegmentationPipeline
MtcnnFaceDetectionPipeline
OCRDetectionPipeline
OCRRecognitionPipeline
PointCloudSceneFlowEstimationPipeline
ProductRetrievalEmbeddingPipeline
RealtimeObjectDetectionPipeline
ReferringVideoObjectSegmentationPipeline
RetinaFaceDetectionPipeline
ShopSegmentationPipeline
SkinRetouchingPipeline
TableRecognitionPipeline
TextDrivenSegmentationPipeline
TinynasClassificationPipeline
UlfdFaceDetectionPipeline
VideoCategoryPipeline
VideoFrameInterpolationPipeline
VideoObjectSegmentationPipeline
VideoStabilizationPipeline
VideoSuperResolutionPipeline
VirtualTryonPipeline
VisionMiddlewarePipeline
VopRetrievalPipeline

View File

@@ -0,0 +1,28 @@
modelscope.pipelines.multi_modal
=======================
.. automodule:: modelscope.pipelines.multi_modal
.. currentmodule:: modelscope.pipelines.multi_modal
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
AutomaticSpeechRecognitionPipeline
ChineseStableDiffusionPipeline
DocumentVLEmbeddingPipeline
GEMMMultiModalEmbeddingPipeline
ImageCaptioningPipeline
MGeoRankingPipeline
MultiModalEmbeddingPipeline
StableDiffusionWrapperPipeline
TextToImageSynthesisPipeline
VideoCaptioningPipeline
VideoMultiModalEmbeddingPipeline
VideoQuestionAnsweringPipeline
VisualEntailmentPipeline
VisualGroundingPipeline
VisualQuestionAnsweringPipeline

View File

@@ -0,0 +1,45 @@
modelscope.pipelines.nlp
=======================
.. automodule:: modelscope.pipelines.nlp
.. currentmodule:: modelscope.pipelines.nlp
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
AutomaticPostEditingPipeline
CodeGeeXCodeGenerationPipeline
CodeGeeXCodeTranslationPipeline
ConversationalTextToSqlPipeline
DialogIntentPredictionPipeline
DialogModelingPipeline
DialogStateTrackingPipeline
DocumentSegmentationPipeline
ExtractiveSummarizationPipeline
FaqQuestionAnsweringPipeline
FasttextSequenceClassificationPipeline
FeatureExtractionPipeline
FillMaskPipeline
InformationExtractionPipeline
MGLMTextSummarizationPipeline
NamedEntityRecognitionPipeline
SentenceEmbeddingPipeline
SummarizationPipeline
TableQuestionAnsweringPipeline
TextClassificationPipeline
TextErrorCorrectionPipeline
TextGenerationPipeline
TextGenerationT5Pipeline
TextRankingPipeline
TokenClassificationPipeline
TranslationEvaluationPipeline
TranslationPipeline
TranslationQualityEstimationPipeline
UserSatisfactionEstimationPipeline
WordSegmentationPipeline
WordSegmentationThaiPipeline
ZeroShotClassificationPipeline

View File

@@ -12,3 +12,7 @@ modelscope.pipelines
base <modelscope.pipelines.base>
builder <modelscope.pipelines.builder>
cv <modelscope.pipelines.cv>
nlp <modelscope.pipelines.nlp>
multi-modal <modelscope.pipelines.multi-modal>
audio <modelscope.pipelines.audio>
science <modelscope.pipelines.science>

View File

@@ -0,0 +1,14 @@
modelscope.pipelines.science
=======================
.. automodule:: modelscope.pipelines.science
.. currentmodule:: modelscope.pipelines.science
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
ProteinStructurePipeline

View File

@@ -0,0 +1,44 @@
modelscope.preprocessors.nlp
====================
.. automodule:: modelscope.preprocessors.nlp
.. currentmodule:: modelscope.preprocessors.nlp
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
TextErrorCorrectionPreprocessor
TextGenerationJiebaPreprocessor
DocumentSegmentationTransformersPreprocessor
FaqQuestionAnsweringTransformersPreprocessor
FillMaskPoNetPreprocessor
FillMaskTransformersPreprocessor
TextRankingTransformersPreprocessor
RelationExtractionTransformersPreprocessor
TextClassificationTransformersPreprocessor
SentenceEmbeddingTransformersPreprocessor
TextGenerationTransformersPreprocessor
TextGenerationT5Preprocessor
TextGenerationSentencePiecePreprocessor
SentencePiecePreprocessor
TokenClassificationTransformersPreprocessor
WordSegmentationBlankSetToLabelPreprocessor
WordSegmentationPreprocessorThai
NERPreprocessorThai
NERPreprocessorViet
ZeroShotClassificationTransformersPreprocessor
DialogIntentPredictionPreprocessor
DialogModelingPreprocessor
DialogStateTrackingPreprocessor
InputFeatures
MultiWOZBPETextField
IntentBPETextField
ConversationalTextToSqlPreprocessor
TableQuestionAnsweringPreprocessor
MGLMSummarizationPreprocessor
TranslationEvaluationPreprocessor
DialogueClassificationUsePreprocessor

View File

@@ -12,3 +12,4 @@ modelscope.preprocessors
base <modelscope.preprocessors.base>
builders <modelscope.preprocessors.builder>
video <modelscope.preprocessors.video>
nlp <modelscope.preprocessors.nlp>

View File

@@ -0,0 +1,29 @@
modelscope.trainers.hooks
=======================
.. automodule:: modelscope.trainers.
.. currentmodule:: modelscope.trainers.hooks
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
builder.build_hook
hook.Hook
priority.Priority
checkpoint_hook.CheckpointHook
checkpoint_hook.BestCkptSaverHook
compression.SparsityHook
evaluation_hook.EvaluationHook
iter_timer_hook.IterTimerHook
logger.TensorboardHook
logger.TextLoggerHook
lr_scheduler_hook.LrSchedulerHook
lr_scheduler_hook.NoneLrSchedulerHook
optimizer.OptimizerHook
optimizer.NoneOptimizerHook
optimizer.ApexAMPOptimizerHook
optimizer.TorchAMPOptimizerHook

View File

@@ -0,0 +1,18 @@
modelscope.trainers.multi_modal
=======================
.. automodule:: modelscope.trainers.multi_modal
.. currentmodule:: modelscope.trainers.multi_modal
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
clip.CLIPTrainer
team.TEAMImgClsTrainer
ofa.OFATrainer
mplug.MPlugTrainer
mgeo_ranking_trainer.MGeoRankingTrainer

View File

@@ -0,0 +1,17 @@
modelscope.trainers.nlp
=======================
.. automodule:: modelscope.trainers.nlp
.. currentmodule:: modelscope.trainers.nlp
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
sequence_classification_trainer.SequenceClassificationTrainer
csanmt_translation_trainer.CsanmtTranslationTrainer
text_ranking_trainer.TextRankingTrainer
text_generation_trainer.TextGenerationTrainer

View File

@@ -12,4 +12,7 @@ modelscope.trainers
base <modelscope.trainers.base>
builder <modelscope.trainers.builder>
EpochBasedTrainer <modelscope.trainers.trainer>
Hooks <modelscope.trainers.hooks>
cv <modelscope.trainers.cv>
nlp <modelscope.trainers.nlp>
multi-modal <modelscope.trainers.multi_modal>

View File

@@ -13,7 +13,7 @@
import os
import sys
import sphinx_book_theme
# import sphinx_book_theme
sys.path.insert(0, os.path.abspath('../../'))
# -- Project information -----------------------------------------------------

View File

@@ -22,11 +22,6 @@ ModelScope DOCUMENTATION
Trainer <api/modelscope.trainers>
MsDataset <api/modelscope.msdatasets>
.. toctree::
:maxdepth: 2
:caption: Changelog
change_log.md
Indices and tables
==================

View File

@@ -1,51 +1,43 @@
import os
from dataclasses import dataclass, field
from modelscope.metainfo import Trainers
from modelscope.msdatasets.ms_dataset import MsDataset
from modelscope.trainers.builder import build_trainer
from modelscope.trainers.training_args import (ArgAttr, CliArgumentParser,
training_args)
from modelscope.trainers.training_args import TrainingArgs
def define_parser():
training_args.num_classes = ArgAttr(
cfg_node_name=[
'model.mm_model.head.num_classes',
'model.mm_model.train_cfg.augments.0.num_classes',
'model.mm_model.train_cfg.augments.1.num_classes'
],
type=int,
help='number of classes')
@dataclass
class ImageClassificationTrainingArgs(TrainingArgs):
num_classes: int = field(
default=None,
metadata={
'cfg_node': [
'model.mm_model.head.num_classes',
'model.mm_model.train_cfg.augments.0.num_classes',
'model.mm_model.train_cfg.augments.1.num_classes'
],
'help':
'number of classes',
})
training_args.train_batch_size.default = 16
training_args.train_data_worker.default = 1
training_args.max_epochs.default = 1
training_args.optimizer.default = 'AdamW'
training_args.lr.default = 1e-4
training_args.warmup_iters = ArgAttr(
'train.lr_config.warmup_iters',
type=int,
default=1,
help='number of warmup epochs')
training_args.topk = ArgAttr(
cfg_node_name=[
'train.evaluation.metric_options.topk',
'evaluation.metric_options.topk'
],
default=(1, ),
help='evaluation using topk, tuple format, eg (1,), (1,5)')
topk: tuple = field(
default=None,
metadata={
'cfg_node': [
'train.evaluation.metric_options.topk',
'evaluation.metric_options.topk'
],
'help':
'evaluation using topk, tuple format, eg (1,), (1,5)',
})
training_args.train_data = ArgAttr(
type=str, default='tany0699/cats_and_dogs', help='train dataset')
training_args.validation_data = ArgAttr(
type=str, default='tany0699/cats_and_dogs', help='validation dataset')
training_args.model_id = ArgAttr(
type=str,
default='damo/cv_vit-base_image-classification_ImageNet-labels',
help='model name')
parser = CliArgumentParser(training_args)
return parser
warmup_iters: str = field(
default=None,
metadata={
'cfg_node': 'train.lr_config.warmup_iters',
'help': 'The warmup iters',
})
def create_dataset(name, split):
@@ -54,21 +46,26 @@ def create_dataset(name, split):
dataset_name, namespace=namespace, subset_name='default', split=split)
def train(parser):
cfg_dict = parser.get_cfg_dict()
args = parser.args
train_dataset = create_dataset(args.train_data, split='train')
val_dataset = create_dataset(args.validation_data, split='validation')
def cfg_modify_fn(cfg):
cfg.merge_from_dict(cfg_dict)
return cfg
def train():
args = ImageClassificationTrainingArgs.from_cli(
model='damo/cv_vit-base_image-classification_ImageNet-labels',
max_epochs=1,
lr=1e-4,
optimizer='AdamW',
warmup_iters=1,
topk=(1, ))
if args.dataset_name is not None:
train_dataset = create_dataset(args.dataset_name, split='train')
val_dataset = create_dataset(args.dataset_name, split='validation')
else:
train_dataset = create_dataset(args.train_dataset_name, split='train')
val_dataset = create_dataset(args.val_dataset_name, split='validation')
kwargs = dict(
model=args.model_id, # model id
model=args.model, # model id
train_dataset=train_dataset, # training dataset
eval_dataset=val_dataset, # validation dataset
cfg_modify_fn=cfg_modify_fn # callback to modify configuration
cfg_modify_fn=args # callback to modify configuration
)
# in distributed training, specify pytorch launcher
@@ -82,5 +79,4 @@ def train(parser):
if __name__ == '__main__':
parser = define_parser()
train(parser)
train()

View File

@@ -1,5 +1,5 @@
PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=2 \
examples/pytorch/finetune_image_classification.py \
--num_classes 2 \
--train_data 'tany0699/cats_and_dogs' \
--validation_data 'tany0699/cats_and_dogs'
--train_dataset_name 'tany0699/cats_and_dogs' \
--val_dataset_name 'tany0699/cats_and_dogs'

View File

@@ -0,0 +1,90 @@
import os
from dataclasses import dataclass, field
from modelscope.msdatasets import MsDataset
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.training_args import TrainingArgs
def get_labels(cfg, metadata):
label2id = cfg.safe_get(metadata['cfg_node'])
if label2id is not None:
return ','.join(label2id.keys())
def set_labels(cfg, labels, metadata):
if isinstance(labels, str):
labels = labels.split(',')
cfg.merge_from_dict(
{metadata['cfg_node']: {label: id
for id, label in enumerate(labels)}})
@dataclass
class TextClassificationArguments(TrainingArgs):
first_sequence: str = field(
default=None,
metadata={
'help': 'The first sequence key of preprocessor',
'cfg_node': 'preprocessor.first_sequence'
})
second_sequence: str = field(
default=None,
metadata={
'help': 'The second sequence key of preprocessor',
'cfg_node': 'preprocessor.second_sequence'
})
label: str = field(
default=None,
metadata={
'help': 'The label key of preprocessor',
'cfg_node': 'preprocessor.label'
})
labels: str = field(
default=None,
metadata={
'help': 'The labels of the dataset',
'cfg_node': 'preprocessor.label2id',
'cfg_getter': get_labels,
'cfg_setter': set_labels,
})
preprocessor: str = field(
default=None,
metadata={
'help': 'The preprocessor type',
'cfg_node': 'preprocessor.type'
})
def __call__(self, config):
config = super().__call__(config)
config.model['num_labels'] = len(self.labels)
if config.train.lr_scheduler.type == 'LinearLR':
config.train.lr_scheduler['total_iters'] = \
int(len(train_dataset) / self.per_device_train_batch_size) * self.max_epochs
return config
args = TextClassificationArguments.from_cli(
task='text-classification', eval_metrics='seq-cls-metric')
print(args)
dataset = MsDataset.load(args.dataset_name, subset_name=args.subset_name)
train_dataset = dataset['train']
validation_dataset = dataset['validation']
kwargs = dict(
model=args.model,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
seed=args.seed,
cfg_modify_fn=args)
os.environ['LOCAL_RANK'] = str(args.local_rank)
trainer: EpochBasedTrainer = build_trainer(name='trainer', default_args=kwargs)
trainer.train()

View File

@@ -0,0 +1,12 @@
PYTHONPATH=. python examples/pytorch/text_classification/finetune_text_classification.py \
--model 'damo/nlp_structbert_backbone_base_std' \
--dataset_name 'clue' \
--subset_name 'tnews' \
--first_sequence 'sentence' \
--preprocessor.label label \
--model.num_labels 15 \
--labels '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14' \
--preprocessor 'sen-cls-tokenizer' \
--train.dataloader.workers_per_gpu 0 \
--evaluation.dataloader.workers_per_gpu 0 \
--train.optimizer.lr 1e-5 \

View File

@@ -0,0 +1 @@
{"framework":"pytorch","train":{"work_dir":"/tmp","max_epochs":10,"dataloader":{"batch_size_per_gpu":16,"workers_per_gpu":0},"optimizer":{"type":"SGD","lr":0.001},"lr_scheduler":{"type":"StepLR","step_size":2},"hooks":[{"type":"CheckpointHook","interval":1}]},"evaluation":{"dataloader":{"batch_size_per_gpu":16,"workers_per_gpu":0,"shuffle":false}}}

View File

@@ -0,0 +1,57 @@
import os
from dataclasses import dataclass, field
from datasets import load_dataset
from transformers import (BertForSequenceClassification, BertTokenizerFast,
default_data_collator)
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.default_config import DEFAULT_CONFIG, TrainingArgs
@dataclass
class TransformersArguments(TrainingArgs):
num_labels: int = field(
default=None, metadata={
'help': 'The number of labels',
})
args = TransformersArguments.from_cli(
task='text-classification', eval_metrics='seq-cls-metric')
print(args)
dataset = load_dataset(args.dataset_name, args.subset_name)
model = BertForSequenceClassification.from_pretrained(
args.model, num_labels=args.num_labels)
tokenizer = BertTokenizerFast.from_pretrained(args.model)
def tokenize_sentence(row):
return tokenizer(row['sentence'], padding='max_length', max_length=128)
# Extra columns, Rename columns
dataset = dataset.map(tokenize_sentence).remove_columns(['sentence',
'idx']).rename_column(
'label', 'labels')
cfg_file = os.path.join(args.work_dir or './', 'configuration.json')
DEFAULT_CONFIG.dump(cfg_file)
kwargs = dict(
model=model,
cfg_file=cfg_file,
# data_collator
data_collator=default_data_collator,
train_dataset=dataset['train'],
eval_dataset=dataset['validation'],
seed=args.seed,
cfg_modify_fn=args)
os.environ['LOCAL_RANK'] = str(args.local_rank)
trainer: EpochBasedTrainer = build_trainer(name='trainer', default_args=kwargs)
trainer.train()

View File

@@ -0,0 +1,5 @@
PYTHONPATH=. python examples/pytorch/transformers/finetune_transformers_model.py \
--model bert-base-uncased \
--num_labels 15 \
--dataset_name clue \
--subset_name tnews

View File

@@ -1,5 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from .base import Exporter
from .builder import build_exporter
from .nlp import SbertForSequenceClassificationExporter
from .tf_model_exporter import TfModelExporter
from .torch_model_exporter import TorchModelExporter
if is_tf_available():
from .nlp import CsanmtForTranslationExporter
from .tf_model_exporter import TfModelExporter
if is_torch_available():
from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter
from .torch_model_exporter import TorchModelExporter

View File

@@ -6,9 +6,11 @@ from typing import Dict, Union
from modelscope.models import Model
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import ModelFile
from modelscope.utils.hub import snapshot_download
from modelscope.utils.logger import get_logger
from .builder import build_exporter
logger = get_logger(__name__)
class Exporter(ABC):
"""Exporter base class to output model to onnx, torch_script, graphdef, etc.
@@ -46,7 +48,12 @@ class Exporter(ABC):
if hasattr(cfg, 'export'):
export_cfg.update(cfg.export)
export_cfg['model'] = model
exporter = build_exporter(export_cfg, task_name, kwargs)
try:
exporter = build_exporter(export_cfg, task_name, kwargs)
except KeyError as e:
raise KeyError(
f'The exporting of model \'{model_cfg.type}\' with task: \'{task_name}\' '
f'is not supported currently.') from e
return exporter
@abstractmethod

View File

@@ -1,2 +1,11 @@
from .sbert_for_sequence_classification_exporter import \
SbertForSequenceClassificationExporter
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.import_utils import is_tf_available, is_torch_available
if is_tf_available():
from .csanmt_for_translation_exporter import CsanmtForTranslationExporter
if is_torch_available():
from .sbert_for_sequence_classification_exporter import \
SbertForSequenceClassificationExporter
from .sbert_for_zero_shot_classification_exporter import \
SbertForZeroShotClassificationExporter

View File

@@ -0,0 +1,185 @@
import os
from typing import Any, Dict
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import freeze_graph
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.tf_model_exporter import TfModelExporter
from modelscope.metainfo import Models
from modelscope.pipelines.nlp.translation_pipeline import TranslationPipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import compare_arguments_nested
logger = get_logger(__name__)
if tf.__version__ >= '2.0':
tf = tf.compat.v1
tf.disable_eager_execution()
tf.logging.set_verbosity(tf.logging.INFO)
@EXPORTERS.register_module(Tasks.translation, module_name=Models.translation)
class CsanmtForTranslationExporter(TfModelExporter):
def __init__(self, model=None):
super().__init__(model)
self.pipeline = TranslationPipeline(self.model)
def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
return_dict = self.pipeline.preprocess(
"Alibaba Group's mission is to let the world have no difficult business"
)
return {'input_wids': return_dict['input_ids']}
def export_saved_model(self, output_dir, rtol=None, atol=None, **kwargs):
def _generate_signature():
receiver_tensors = {
'input_wids':
tf.saved_model.utils.build_tensor_info(
self.pipeline.input_wids)
}
export_outputs = {
'output_seqs':
tf.saved_model.utils.build_tensor_info(
self.pipeline.output['output_seqs'])
}
signature_def = tf.saved_model.signature_def_utils.build_signature_def(
receiver_tensors, export_outputs,
tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
return {'translation_signature': signature_def}
with self.pipeline._session.as_default() as sess:
builder = tf.saved_model.builder.SavedModelBuilder(output_dir)
builder.add_meta_graph_and_variables(
sess, [tag_constants.SERVING],
signature_def_map=_generate_signature(),
assets_collection=ops.get_collection(
ops.GraphKeys.ASSET_FILEPATHS),
clear_devices=True)
builder.save()
dummy_inputs = self.generate_dummy_inputs()
with tf.Session(graph=tf.Graph()) as sess:
# Restore model from the saved_modle file, that is exported by TensorFlow estimator.
MetaGraphDef = tf.saved_model.loader.load(sess, ['serve'],
output_dir)
# SignatureDef protobuf
SignatureDef_map = MetaGraphDef.signature_def
SignatureDef = SignatureDef_map['translation_signature']
# TensorInfo protobuf
X_TensorInfo = SignatureDef.inputs['input_wids']
y_TensorInfo = SignatureDef.outputs['output_seqs']
X = tf.saved_model.utils.get_tensor_from_tensor_info(
X_TensorInfo, sess.graph)
y = tf.saved_model.utils.get_tensor_from_tensor_info(
y_TensorInfo, sess.graph)
outputs = sess.run(y, feed_dict={X: dummy_inputs['input_wids']})
trans_result = self.pipeline.postprocess({'output_seqs': outputs})
logger.info(trans_result)
outputs_origin = self.pipeline.forward(
{'input_ids': dummy_inputs['input_wids']})
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested('Output match failed', outputs,
outputs_origin['output_seqs'], **tols):
raise RuntimeError(
'Export saved model failed because of validation error.')
return {'model': output_dir}
def export_frozen_graph_def(self,
output_dir: str,
rtol=None,
atol=None,
**kwargs):
input_saver_def = self.pipeline.model_loader.as_saver_def()
inference_graph_def = tf.get_default_graph().as_graph_def()
for node in inference_graph_def.node:
node.device = ''
frozen_dir = os.path.join(output_dir, 'frozen')
tf.gfile.MkDir(frozen_dir)
frozen_graph_path = os.path.join(frozen_dir,
'frozen_inference_graph.pb')
outputs = {
'output_trans_result':
tf.identity(
self.pipeline.output['output_seqs'],
name='NmtModel/output_trans_result')
}
for output_key in outputs:
tf.add_to_collection('inference_op', outputs[output_key])
output_node_names = ','.join([
'%s/%s' % ('NmtModel', output_key)
for output_key in outputs.keys()
])
print(output_node_names)
_ = freeze_graph.freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=input_saver_def,
input_checkpoint=self.pipeline.model_path,
output_node_names=output_node_names,
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0',
output_graph=frozen_graph_path,
clear_devices=True,
initializer_nodes='')
# 5. test frozen.pb
dummy_inputs = self.generate_dummy_inputs()
with self.pipeline._session.as_default() as sess:
sess.run(tf.tables_initializer())
graph = tf.Graph()
with tf.gfile.GFile(frozen_graph_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with graph.as_default():
tf.import_graph_def(graph_def, name='')
graph.finalize()
with tf.Session(graph=graph) as trans_sess:
outputs = trans_sess.run(
'NmtModel/strided_slice_9:0',
feed_dict={'input_wids:0': dummy_inputs['input_wids']})
trans_result = self.pipeline.postprocess(
{'output_seqs': outputs})
logger.info(trans_result)
outputs_origin = self.pipeline.forward(
{'input_ids': dummy_inputs['input_wids']})
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested('Output match failed', outputs,
outputs_origin['output_seqs'], **tols):
raise RuntimeError(
'Export frozen graphdef failed because of validation error.')
return {'model': frozen_graph_path}
def export_onnx(self, output_dir: str, opset=13, **kwargs):
raise NotImplementedError(
'csanmt model does not support onnx format, consider using savedmodel instead.'
)

View File

@@ -1,4 +1,3 @@
import os
from collections import OrderedDict
from typing import Any, Dict, Mapping, Tuple
@@ -7,9 +6,7 @@ from torch.utils.data.dataloader import default_collate
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.preprocessors import (
Preprocessor, TextClassificationTransformersPreprocessor,
build_preprocessor)
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import ModeKeys, Tasks
@@ -17,8 +14,6 @@ from modelscope.utils.constant import ModeKeys, Tasks
@EXPORTERS.register_module(
Tasks.text_classification, module_name=Models.structbert)
@EXPORTERS.register_module(Tasks.sentence_similarity, module_name=Models.bert)
@EXPORTERS.register_module(
Tasks.zero_shot_classification, module_name=Models.bert)
@EXPORTERS.register_module(
Tasks.sentiment_classification, module_name=Models.bert)
@EXPORTERS.register_module(Tasks.nli, module_name=Models.bert)
@@ -27,8 +22,6 @@ from modelscope.utils.constant import ModeKeys, Tasks
@EXPORTERS.register_module(
Tasks.sentiment_classification, module_name=Models.structbert)
@EXPORTERS.register_module(Tasks.nli, module_name=Models.structbert)
@EXPORTERS.register_module(
Tasks.zero_shot_classification, module_name=Models.structbert)
class SbertForSequenceClassificationExporter(TorchModelExporter):
def generate_dummy_inputs(self,

View File

@@ -0,0 +1,58 @@
from collections import OrderedDict
from typing import Any, Dict, Mapping
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Tasks
@EXPORTERS.register_module(
Tasks.zero_shot_classification, module_name=Models.bert)
@EXPORTERS.register_module(
Tasks.zero_shot_classification, module_name=Models.structbert)
class SbertForZeroShotClassificationExporter(TorchModelExporter):
def generate_dummy_inputs(self,
candidate_labels,
hypothesis_template,
max_length=128,
pair: bool = False,
**kwargs) -> Dict[str, Any]:
"""Generate dummy inputs for model exportation to onnx or other formats by tracing.
Args:
max_length(int): The max length of sentence, default 128.
hypothesis_template(str): The template of prompt, like '这篇文章的标题是{}'
candidate_labels(List): The labels of prompt,
like ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
pair(bool, `optional`): Whether to generate sentence pairs or single sentences.
Returns:
Dummy inputs.
"""
assert hasattr(
self.model, 'model_dir'
), 'model_dir attribute is required to build the preprocessor'
preprocessor = Preprocessor.from_pretrained(
self.model.model_dir, max_length=max_length)
return preprocessor(
preprocessor.nlp_tokenizer.tokenizer.unk_token,
candidate_labels=candidate_labels,
hypothesis_template=hypothesis_template)
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
dynamic_axis = {0: 'batch', 1: 'sequence'}
return OrderedDict([
('input_ids', dynamic_axis),
('attention_mask', dynamic_axis),
('token_type_ids', dynamic_axis),
])
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict({'logits': {0: 'batch'}})

View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from abc import abstractmethod
from typing import Any, Callable, Dict, Mapping
import tensorflow as tf
@@ -7,7 +8,7 @@ import tensorflow as tf
from modelscope.outputs import ModelOutputBase
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.regress_test_utils import compare_arguments_nested
from modelscope.utils.test_utils import compare_arguments_nested
from .base import Exporter
logger = get_logger()
@@ -29,6 +30,14 @@ class TfModelExporter(Exporter):
self._tf2_export_onnx(model, onnx_file, opset=opset, **kwargs)
return {'model': onnx_file}
@abstractmethod
def export_saved_model(self, output_dir: str, **kwargs):
pass
@abstractmethod
def export_frozen_graph_def(self, output_dir: str, **kwargs):
pass
def _tf2_export_onnx(self,
model,
output: str,
@@ -59,56 +68,67 @@ class TfModelExporter(Exporter):
onnx.save(onnx_model, output)
if validation:
try:
import onnx
import onnxruntime as ort
except ImportError:
logger.warn(
'Cannot validate the exported onnx file, because '
'the installation of onnx or onnxruntime cannot be found')
return
self._validate_model(dummy_inputs, model, output, rtol, atol,
call_func)
def tensor_nested_numpify(tensors):
if isinstance(tensors, (list, tuple)):
return type(tensors)(
tensor_nested_numpify(t) for t in tensors)
if isinstance(tensors, Mapping):
# return dict
return {
k: tensor_nested_numpify(t)
for k, t in tensors.items()
}
if isinstance(tensors, tf.Tensor):
t = tensors.cpu()
return t.numpy()
return tensors
def _validate_model(
self,
dummy_inputs,
model,
output,
rtol: float = None,
atol: float = None,
call_func: Callable = None,
):
try:
import onnx
import onnxruntime as ort
except ImportError:
logger.warn(
'Cannot validate the exported onnx file, because '
'the installation of onnx or onnxruntime cannot be found')
return
onnx_model = onnx.load(output)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(output)
outputs_origin = call_func(
dummy_inputs) if call_func is not None else model(dummy_inputs)
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
outputs_origin = list(
tensor_nested_numpify(outputs_origin).values())
elif isinstance(outputs_origin, (tuple, list)):
outputs_origin = list(tensor_nested_numpify(outputs_origin))
outputs = ort_session.run(
None,
tensor_nested_numpify(dummy_inputs),
)
outputs = tensor_nested_numpify(outputs)
if isinstance(outputs, dict):
outputs = list(outputs.values())
elif isinstance(outputs, tuple):
outputs = list(outputs)
def tensor_nested_numpify(tensors):
if isinstance(tensors, (list, tuple)):
return type(tensors)(tensor_nested_numpify(t) for t in tensors)
if isinstance(tensors, Mapping):
# return dict
return {
k: tensor_nested_numpify(t)
for k, t in tensors.items()
}
if isinstance(tensors, tf.Tensor):
t = tensors.cpu()
return t.numpy()
return tensors
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested('Onnx model output match failed',
outputs, outputs_origin, **tols):
raise RuntimeError(
'export onnx failed because of validation error.')
onnx_model = onnx.load(output)
onnx.checker.check_model(onnx_model, full_check=True)
ort_session = ort.InferenceSession(output)
outputs_origin = call_func(
dummy_inputs) if call_func is not None else model(dummy_inputs)
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
outputs_origin = list(
tensor_nested_numpify(outputs_origin).values())
elif isinstance(outputs_origin, (tuple, list)):
outputs_origin = list(tensor_nested_numpify(outputs_origin))
outputs = ort_session.run(
None,
tensor_nested_numpify(dummy_inputs),
)
outputs = tensor_nested_numpify(outputs)
if isinstance(outputs, dict):
outputs = list(outputs.values())
elif isinstance(outputs, tuple):
outputs = list(outputs)
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested('Onnx model output match failed',
outputs, outputs_origin, **tols):
raise RuntimeError(
'export onnx failed because of validation error.')

View File

@@ -27,9 +27,9 @@ def load(file, file_format=None, **kwargs):
Currently supported formats include "json", "yaml/yml".
Examples:
>>> load('/path/of/your/file') # file is storaged in disk
>>> load('https://path/of/your/file') # file is storaged in Internet
>>> load('oss://path/of/your/file') # file is storaged in petrel
>>> load('/path/of/your/file') # file is stored in disk
>>> load('https://path/of/your/file') # file is stored on internet
>>> load('oss://path/of/your/file') # file is stored in petrel
Returns:
The content from the file.

View File

@@ -1,15 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import pickle
from typing import Dict, Optional, Union
from urllib.parse import urlparse
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import (FILE_HASH, MODEL_META_FILE_NAME,
MODEL_META_MODEL_ID)
from modelscope.hub.constants import FILE_HASH
from modelscope.hub.git import GitCommandWrapper
from modelscope.hub.utils.caching import FileSystemCache, ModelFileSystemCache
from modelscope.hub.utils.caching import ModelFileSystemCache
from modelscope.hub.utils.utils import compute_hash
from modelscope.utils.logger import get_logger

View File

@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.constant import Fields, Tasks
class Models(object):
@@ -7,35 +8,44 @@ class Models(object):
Holds the standard model name to use for identifying different model.
This should be used to register models.
Model name should only contain model info but not task info.
Model name should only contain model information but not task information.
"""
# tinynas models
tinynas_detection = 'tinynas-detection'
tinynas_damoyolo = 'tinynas-damoyolo'
# vision models
detection = 'detection'
mask_scoring = 'MaskScoring'
image_restoration = 'image-restoration'
realtime_object_detection = 'realtime-object-detection'
realtime_video_object_detection = 'realtime-video-object-detection'
scrfd = 'scrfd'
depe = 'depe'
classification_model = 'ClassificationModel'
easyrobust_model = 'EasyRobustModel'
bnext = 'bnext'
yolopv2 = 'yolopv2'
nafnet = 'nafnet'
csrnet = 'csrnet'
adaint = 'adaint'
deeplpfnet = 'deeplpfnet'
rrdb = 'rrdb'
cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
maskdino_swin = 'maskdino_swin'
gpen = 'gpen'
product_retrieval_embedding = 'product-retrieval-embedding'
body_2d_keypoints = 'body-2d-keypoints'
body_3d_keypoints = 'body-3d-keypoints'
body_3d_keypoints_hdformer = 'hdformer'
crowd_counting = 'HRNetCrowdCounting'
face_2d_keypoints = 'face-2d-keypoints'
panoptic_segmentation = 'swinL-panoptic-segmentation'
r50_panoptic_segmentation = 'r50-panoptic-segmentation'
image_reid_person = 'passvitb'
image_inpainting = 'FFTInpainting'
image_paintbyexample = 'Stablediffusion-Paintbyexample'
video_summarization = 'pgl-video-summarization'
video_panoptic_segmentation = 'swinb-video-panoptic-segmentation'
language_guided_video_summarization = 'clip-it-language-guided-video-summarization'
swinL_semantic_segmentation = 'swinL-semantic-segmentation'
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
@@ -70,6 +80,7 @@ class Models(object):
video_human_matting = 'video-human-matting'
video_frame_interpolation = 'video-frame-interpolation'
video_object_segmentation = 'video-object-segmentation'
video_deinterlace = 'video-deinterlace'
quadtree_attention_image_matching = 'quadtree-attention-image-matching'
vision_middleware = 'vision-middleware'
video_stabilization = 'video-stabilization'
@@ -78,14 +89,31 @@ class Models(object):
image_casmvs_depth_estimation = 'image-casmvs-depth-estimation'
vop_retrieval_model = 'vop-retrieval-model'
ddcolor = 'ddcolor'
image_probing_model = 'image-probing-model'
defrcn = 'defrcn'
image_face_fusion = 'image-face-fusion'
content_check = 'content-check'
open_vocabulary_detection_vild = 'open-vocabulary-detection-vild'
ecbsr = 'ecbsr'
msrresnet_lite = 'msrresnet-lite'
object_detection_3d = 'object_detection_3d'
ddpm = 'ddpm'
ocr_recognition = 'OCRRecognition'
image_quality_assessment_mos = 'image-quality-assessment-mos'
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
m2fp = 'm2fp'
nerf_recon_acc = 'nerf-recon-acc'
bts_depth_estimation = 'bts-depth-estimation'
vision_efficient_tuning = 'vision-efficient-tuning'
bad_image_detecting = 'bad-image-detecting'
# EasyCV models
yolox = 'YOLOX'
segformer = 'Segformer'
hand_2d_keypoints = 'HRNet-Hand2D-Keypoints'
image_object_detection_auto = 'image-object-detection-auto'
dino = 'DINO'
# nlp models
bert = 'bert'
@@ -122,6 +150,12 @@ class Models(object):
unite = 'unite'
megatron_bert = 'megatron-bert'
use = 'user-satisfaction-estimation'
fid_plug = 'fid-plug'
lstm = 'lstm'
xlm_roberta = 'xlm-roberta'
transformers = 'transformers'
plug_mental = 'plug-mental'
doc2bot = 'doc2bot'
# audio models
sambert_hifigan = 'sambert-hifigan'
@@ -135,6 +169,8 @@ class Models(object):
generic_itn = 'generic-itn'
generic_punc = 'generic-punc'
generic_sv = 'generic-sv'
ecapa_tdnn_sv = 'ecapa-tdnn-sv'
generic_lm = 'generic-lm'
# multi-modal models
ofa = 'ofa'
@@ -162,6 +198,7 @@ class TaskModels(object):
fill_mask = 'fill-mask'
feature_extraction = 'feature-extraction'
text_generation = 'text-generation'
text_ranking = 'text-ranking'
class Heads(object):
@@ -179,6 +216,11 @@ class Heads(object):
information_extraction = 'information-extraction'
# text gen
text_generation = 'text-generation'
# text ranking
text_ranking = 'text-ranking'
# crf
lstm_crf = 'lstm-crf'
transformer_crf = 'transformer-crf'
class Pipelines(object):
@@ -193,6 +235,7 @@ class Pipelines(object):
"""
# vision tasks
portrait_matting = 'unet-image-matting'
universal_matting = 'unet-universal-matting'
image_denoise = 'nafnet-image-denoise'
image_deblur = 'nafnet-image-deblur'
person_image_cartoon = 'unet-person-image-cartoon'
@@ -209,16 +252,19 @@ class Pipelines(object):
hand_2d_keypoints = 'hrnetv2w18_hand-2d-keypoints_image'
human_detection = 'resnet18-human-detection'
object_detection = 'vit-object-detection'
abnormal_object_detection = 'abnormal-object-detection'
easycv_detection = 'easycv-detection'
easycv_segmentation = 'easycv-segmentation'
face_2d_keypoints = 'mobilenet_face-2d-keypoints_alignment'
salient_detection = 'u2net-salient-detection'
salient_boudary_detection = 'res2net-salient-detection'
camouflaged_detection = 'res2net-camouflaged-detection'
image_demoire = 'uhdm-image-demoireing'
image_classification = 'image-classification'
face_detection = 'resnet-face-detection-scrfd10gkps'
face_liveness_ir = 'manual-face-liveness-flir'
face_liveness_rgb = 'manual-face-liveness-flir'
face_liveness_xc = 'manual-face-liveness-flxc'
card_detection = 'resnet-card-detection-scrfd34gkps'
ulfd_face_detection = 'manual-face-detection-ulfd'
tinymog_face_detection = 'manual-face-detection-tinymog'
@@ -234,20 +280,28 @@ class Pipelines(object):
nextvit_small_daily_image_classification = 'nextvit-small_image-classification_Dailylife-labels'
convnext_base_image_classification_garbage = 'convnext-base_image-classification_garbage'
bnext_small_image_classification = 'bnext-small_image-classification_ImageNet-labels'
yolopv2_image_driving_percetion_bdd100k = 'yolopv2_image-driving-percetion_bdd100k'
common_image_classification = 'common-image-classification'
image_color_enhance = 'csrnet-image-color-enhance'
adaint_image_color_enhance = 'adaint-image-color-enhance'
deeplpf_image_color_enhance = 'deeplpf-image-color-enhance'
virtual_try_on = 'virtual-try-on'
image_colorization = 'unet-image-colorization'
image_style_transfer = 'AAMS-style-transfer'
image_super_resolution = 'rrdb-image-super-resolution'
image_debanding = 'rrdb-image-debanding'
face_image_generation = 'gan-face-image-generation'
product_retrieval_embedding = 'resnet50-product-retrieval-embedding'
realtime_object_detection = 'cspnet_realtime-object-detection_yolox'
realtime_video_object_detection = 'cspnet_realtime-video-object-detection_streamyolo'
face_recognition = 'ir101-face-recognition-cfglint'
face_recognition_ood = 'ir-face-recognition-ood-rts'
face_quality_assessment = 'manual-face-quality-assessment-fqa'
face_recognition_ood = 'ir-face-recognition-rts'
face_recognition_onnx_ir = 'manual-face-recognition-frir'
face_recognition_onnx_fm = 'manual-face-recognition-frfm'
arc_face_recognition = 'ir50-face-recognition-arcface'
mask_face_recognition = 'resnet-face-recognition-facemask'
content_check = 'resnet50-image-classification-cc'
image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
maskdino_instance_segmentation = 'maskdino-swin-image-instance-segmentation'
image2image_translation = 'image-to-image-translation'
@@ -259,6 +313,7 @@ class Pipelines(object):
image_object_detection_auto = 'yolox_image-object-detection-auto'
hand_detection = 'yolox-pai_hand-detection'
skin_retouching = 'unet-skin-retouching'
face_reconstruction = 'resnet50-face-reconstruction'
tinynas_classification = 'tinynas-classification'
easyrobust_classification = 'easyrobust-classification'
tinynas_detection = 'tinynas-detection'
@@ -277,6 +332,8 @@ class Pipelines(object):
panorama_depth_estimation = 'panorama-depth-estimation'
image_reid_person = 'passvitb-image-reid-person'
image_inpainting = 'fft-inpainting'
image_paintbyexample = 'stablediffusion-paintbyexample'
image_inpainting_sdv2 = 'image-inpainting-sdv2'
text_driven_segmentation = 'text-driven-segmentation'
movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'
shop_segmentation = 'shop-segmentation'
@@ -294,15 +351,31 @@ class Pipelines(object):
vision_middleware_multi_task = 'vision-middleware-multi-task'
video_frame_interpolation = 'video-frame-interpolation'
video_object_segmentation = 'video-object-segmentation'
video_deinterlace = 'video-deinterlace'
image_matching = 'image-matching'
video_stabilization = 'video-stabilization'
video_super_resolution = 'realbasicvsr-video-super-resolution'
pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'
image_multi_view_depth_estimation = 'image-multi-view-depth-estimation'
video_panoptic_segmentation = 'video-panoptic-segmentation'
vop_retrieval = 'vop-video-text-retrieval'
ddcolor_image_colorization = 'ddcolor-image-colorization'
image_structured_model_probing = 'image-structured-model-probing'
image_fewshot_detection = 'image-fewshot-detection'
image_face_fusion = 'image-face-fusion'
open_vocabulary_detection_vild = 'open-vocabulary-detection-vild'
ddpm_image_semantic_segmentation = 'ddpm-image-semantic-segmentation'
video_colorization = 'video-colorization'
motion_generattion = 'mdm-motion-generation'
mobile_image_super_resolution = 'mobile-image-super-resolution'
image_human_parsing = 'm2fp-image-human-parsing'
object_detection_3d_depe = 'object-detection-3d-depe'
nerf_recon_acc = 'nerf-recon-acc'
bad_image_detecting = 'bad-image-detecting'
image_quality_assessment_mos = 'image-quality-assessment-mos'
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
vision_efficient_tuning = 'vision-efficient-tuning'
# nlp tasks
automatic_post_editing = 'automatic-post-editing'
@@ -317,6 +390,7 @@ class Pipelines(object):
named_entity_recognition_thai = 'named-entity-recognition-thai'
named_entity_recognition_viet = 'named-entity-recognition-viet'
text_generation = 'text-generation'
fid_dialogue = 'fid-dialogue'
text2text_generation = 'text2text-generation'
sentiment_analysis = 'sentiment-analysis'
sentiment_classification = 'sentiment-classification'
@@ -324,6 +398,7 @@ class Pipelines(object):
fill_mask = 'fill-mask'
fill_mask_ponet = 'fill-mask-ponet'
csanmt_translation = 'csanmt-translation'
interactive_translation = 'interactive-translation'
nli = 'nli'
dialog_intent_prediction = 'dialog-intent-prediction'
dialog_modeling = 'dialog-modeling'
@@ -352,6 +427,10 @@ class Pipelines(object):
token_classification = 'token-classification'
translation_evaluation = 'translation-evaluation'
user_satisfaction_estimation = 'user-satisfaction-estimation'
siamese_uie = 'siamese-uie'
document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval'
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
# audio tasks
sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -365,6 +444,9 @@ class Pipelines(object):
itn_inference = 'itn-inference'
punc_inference = 'punc-inference'
sv_inference = 'sv-inference'
vad_inference = 'vad-inference'
speaker_verification = 'speaker-verification'
lm_inference = 'language-model'
# multi-modal tasks
image_captioning = 'image-captioning'
@@ -386,31 +468,322 @@ class Pipelines(object):
diffusers_stable_diffusion = 'diffusers-stable-diffusion'
document_vl_embedding = 'document-vl-embedding'
chinese_stable_diffusion = 'chinese-stable-diffusion'
gridvlp_multi_modal_classification = 'gridvlp-multi-modal-classification'
gridvlp_multi_modal_embedding = 'gridvlp-multi-modal-embedding'
# science tasks
protein_structure = 'unifold-protein-structure'
class Trainers(object):
""" Names for different trainer.
# Default (pipeline_name, model_repo) pair for each task.  When a pipeline is
# built for a task without an explicit model id, the model listed here is used.
# NOTE: dict keys must be unique — earlier duplicate entries for
# Tasks.text_ranking (mgeo_ranking) and Tasks.image_object_detection
# (object_detection), which Python silently shadowed with the later entries,
# have been removed so that the effective mapping is now explicit.
DEFAULT_MODEL_FOR_PIPELINE = {
    # TaskName: (pipeline_module_name, model_repo)
    Tasks.sentence_embedding:
    (Pipelines.sentence_embedding,
     'damo/nlp_corom_sentence-embedding_english-base'),
    Tasks.text_ranking: (Pipelines.text_ranking,
                         'damo/nlp_corom_passage-ranking_english-base'),
    Tasks.word_segmentation:
    (Pipelines.word_segmentation,
     'damo/nlp_structbert_word-segmentation_chinese-base'),
    Tasks.part_of_speech: (Pipelines.part_of_speech,
                           'damo/nlp_structbert_part-of-speech_chinese-base'),
    Tasks.token_classification:
    (Pipelines.part_of_speech,
     'damo/nlp_structbert_part-of-speech_chinese-base'),
    Tasks.named_entity_recognition:
    (Pipelines.named_entity_recognition,
     'damo/nlp_raner_named-entity-recognition_chinese-base-news'),
    Tasks.relation_extraction:
    (Pipelines.relation_extraction,
     'damo/nlp_bert_relation-extraction_chinese-base'),
    Tasks.information_extraction:
    (Pipelines.relation_extraction,
     'damo/nlp_bert_relation-extraction_chinese-base'),
    Tasks.sentence_similarity:
    (Pipelines.sentence_similarity,
     'damo/nlp_structbert_sentence-similarity_chinese-base'),
    Tasks.translation: (Pipelines.csanmt_translation,
                        'damo/nlp_csanmt_translation_zh2en'),
    Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'),
    Tasks.sentiment_classification:
    (Pipelines.sentiment_classification,
     'damo/nlp_structbert_sentiment-classification_chinese-base'
     ),  # TODO: revise back after passing the pr
    Tasks.portrait_matting: (Pipelines.portrait_matting,
                             'damo/cv_unet_image-matting'),
    Tasks.universal_matting: (Pipelines.universal_matting,
                              'damo/cv_unet_universal-matting'),
    Tasks.human_detection: (Pipelines.human_detection,
                            'damo/cv_resnet18_human-detection'),
    Tasks.image_denoising: (Pipelines.image_denoise,
                            'damo/cv_nafnet_image-denoise_sidd'),
    Tasks.image_deblurring: (Pipelines.image_deblur,
                             'damo/cv_nafnet_image-deblur_gopro'),
    Tasks.video_stabilization: (Pipelines.video_stabilization,
                                'damo/cv_dut-raft_video-stabilization_base'),
    Tasks.video_super_resolution:
    (Pipelines.video_super_resolution,
     'damo/cv_realbasicvsr_video-super-resolution_videolq'),
    Tasks.text_classification:
    (Pipelines.sentiment_classification,
     'damo/nlp_structbert_sentiment-classification_chinese-base'),
    Tasks.text_generation: (Pipelines.text_generation,
                            'damo/nlp_palm2.0_text-generation_chinese-base'),
    Tasks.zero_shot_classification:
    (Pipelines.zero_shot_classification,
     'damo/nlp_structbert_zero-shot-classification_chinese-base'),
    Tasks.task_oriented_conversation: (Pipelines.dialog_modeling,
                                       'damo/nlp_space_dialog-modeling'),
    Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
                                  'damo/nlp_space_dialog-state-tracking'),
    Tasks.table_question_answering:
    (Pipelines.table_question_answering_pipeline,
     'damo/nlp-convai-text2sql-pretrain-cn'),
    Tasks.document_grounded_dialog_generate:
    (Pipelines.document_grounded_dialog_generate,
     'DAMO_ConvAI/nlp_convai_generation_pretrain'),
    Tasks.document_grounded_dialog_rerank:
    (Pipelines.document_grounded_dialog_rerank,
     'damo/nlp_convai_rerank_pretrain'),
    Tasks.document_grounded_dialog_retrieval:
    (Pipelines.document_grounded_dialog_retrieval,
     'DAMO_ConvAI/nlp_convai_retrieval_pretrain'),
    Tasks.text_error_correction:
    (Pipelines.text_error_correction,
     'damo/nlp_bart_text-error-correction_chinese'),
    Tasks.image_captioning: (Pipelines.image_captioning,
                             'damo/ofa_image-caption_coco_large_en'),
    Tasks.video_captioning:
    (Pipelines.video_captioning,
     'damo/multi-modal_hitea_video-captioning_base_en'),
    Tasks.image_portrait_stylization:
    (Pipelines.person_image_cartoon,
     'damo/cv_unet_person-image-cartoon_compound-models'),
    Tasks.ocr_detection: (Pipelines.ocr_detection,
                          'damo/cv_resnet18_ocr-detection-line-level_damo'),
    Tasks.table_recognition:
    (Pipelines.table_recognition,
     'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
    Tasks.document_vl_embedding:
    (Pipelines.document_vl_embedding,
     'damo/multi-modal_convnext-roberta-base_vldoc-embedding'),
    Tasks.license_plate_detection:
    (Pipelines.license_plate_detection,
     'damo/cv_resnet18_license-plate-detection_damo'),
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.feature_extraction: (Pipelines.feature_extraction,
                               'damo/pert_feature-extraction_base-test'),
    Tasks.action_recognition: (Pipelines.action_recognition,
                               'damo/cv_TAdaConv_action-recognition'),
    Tasks.action_detection: (Pipelines.action_detection,
                             'damo/cv_ResNetC3D_action-detection_detection2d'),
    Tasks.live_category: (Pipelines.live_category,
                          'damo/cv_resnet50_live-category'),
    Tasks.video_category: (Pipelines.video_category,
                           'damo/cv_resnet50_video-category'),
    Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding,
                                  'damo/multi-modal_clip-vit-base-patch16_zh'),
    Tasks.generative_multi_modal_embedding:
    (Pipelines.generative_multi_modal_embedding,
     'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'
     ),
    Tasks.multi_modal_similarity:
    (Pipelines.multi_modal_similarity,
     'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'),
    Tasks.visual_question_answering:
    (Pipelines.visual_question_answering,
     'damo/mplug_visual-question-answering_coco_large_en'),
    Tasks.video_question_answering:
    (Pipelines.video_question_answering,
     'damo/multi-modal_hitea_video-question-answering_base_en'),
    Tasks.video_embedding: (Pipelines.cmdssl_video_embedding,
                            'damo/cv_r2p1d_video_embedding'),
    Tasks.text_to_image_synthesis:
    (Pipelines.text_to_image_synthesis,
     'damo/cv_diffusion_text-to-image-synthesis_tiny'),
    Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints,
                              'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
    Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints,
                              'damo/cv_canonical_body-3d-keypoints_video'),
    Tasks.hand_2d_keypoints:
    (Pipelines.hand_2d_keypoints,
     'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'),
    Tasks.card_detection: (Pipelines.card_detection,
                           'damo/cv_resnet_carddetection_scrfd34gkps'),
    Tasks.content_check: (Pipelines.content_check,
                          'damo/cv_resnet50_content-check_cc'),
    Tasks.face_detection:
    (Pipelines.mog_face_detection,
     'damo/cv_resnet101_face-detection_cvpr22papermogface'),
    Tasks.face_liveness: (Pipelines.face_liveness_ir,
                          'damo/cv_manual_face-liveness_flir'),
    Tasks.face_recognition: (Pipelines.face_recognition,
                             'damo/cv_ir101_facerecognition_cfglint'),
    Tasks.facial_expression_recognition:
    (Pipelines.facial_expression_recognition,
     'damo/cv_vgg19_facial-expression-recognition_fer'),
    Tasks.face_attribute_recognition:
    (Pipelines.face_attribute_recognition,
     'damo/cv_resnet34_face-attribute-recognition_fairface'),
    Tasks.face_2d_keypoints: (Pipelines.face_2d_keypoints,
                              'damo/cv_mobilenet_face-2d-keypoints_alignment'),
    Tasks.face_quality_assessment:
    (Pipelines.face_quality_assessment,
     'damo/cv_manual_face-quality-assessment_fqa'),
    Tasks.video_multi_modal_embedding:
    (Pipelines.video_multi_modal_embedding,
     'damo/multi_modal_clip_vtretrival_msrvtt_53'),
    Tasks.image_color_enhancement:
    (Pipelines.image_color_enhance,
     'damo/cv_csrnet_image-color-enhance-models'),
    Tasks.virtual_try_on: (Pipelines.virtual_try_on,
                           'damo/cv_daflow_virtual-try-on_base'),
    Tasks.image_colorization: (Pipelines.ddcolor_image_colorization,
                               'damo/cv_ddcolor_image-colorization'),
    Tasks.video_colorization: (Pipelines.video_colorization,
                               'damo/cv_unet_video-colorization'),
    Tasks.image_segmentation:
    (Pipelines.image_instance_segmentation,
     'damo/cv_swin-b_image-instance-segmentation_coco'),
    Tasks.image_driving_perception:
    (Pipelines.yolopv2_image_driving_percetion_bdd100k,
     'damo/cv_yolopv2_image-driving-perception_bdd100k'),
    Tasks.image_depth_estimation:
    (Pipelines.image_depth_estimation,
     'damo/cv_newcrfs_image-depth-estimation_indoor'),
    Tasks.indoor_layout_estimation:
    (Pipelines.indoor_layout_estimation,
     'damo/cv_panovit_indoor-layout-estimation'),
    Tasks.video_depth_estimation:
    (Pipelines.video_depth_estimation,
     'damo/cv_dro-resnet18_video-depth-estimation_indoor'),
    Tasks.panorama_depth_estimation:
    (Pipelines.panorama_depth_estimation,
     'damo/cv_unifuse_panorama-depth-estimation'),
    Tasks.image_style_transfer: (Pipelines.image_style_transfer,
                                 'damo/cv_aams_style-transfer_damo'),
    Tasks.face_image_generation: (Pipelines.face_image_generation,
                                  'damo/cv_gan_face-image-generation'),
    Tasks.image_super_resolution: (Pipelines.image_super_resolution,
                                   'damo/cv_rrdb_image-super-resolution'),
    Tasks.image_debanding: (Pipelines.image_debanding,
                            'damo/cv_rrdb_image-debanding'),
    Tasks.image_portrait_enhancement:
    (Pipelines.image_portrait_enhancement,
     'damo/cv_gpen_image-portrait-enhancement'),
    Tasks.product_retrieval_embedding:
    (Pipelines.product_retrieval_embedding,
     'damo/cv_resnet50_product-bag-embedding-models'),
    Tasks.image_to_image_generation:
    (Pipelines.image_to_image_generation,
     'damo/cv_latent_diffusion_image2image_generate'),
    Tasks.image_classification:
    (Pipelines.daily_image_classification,
     'damo/cv_vit-base_image-classification_Dailylife-labels'),
    Tasks.image_object_detection:
    (Pipelines.image_object_detection_auto,
     'damo/cv_yolox_image-object-detection-auto'),
    Tasks.ocr_recognition:
    (Pipelines.ocr_recognition,
     'damo/cv_convnextTiny_ocr-recognition-general_damo'),
    Tasks.skin_retouching: (Pipelines.skin_retouching,
                            'damo/cv_unet_skin-retouching'),
    Tasks.faq_question_answering:
    (Pipelines.faq_question_answering,
     'damo/nlp_structbert_faq-question-answering_chinese-base'),
    Tasks.crowd_counting: (Pipelines.crowd_counting,
                           'damo/cv_hrnet_crowd-counting_dcanet'),
    Tasks.video_single_object_tracking:
    (Pipelines.video_single_object_tracking,
     'damo/cv_vitb_video-single-object-tracking_ostrack'),
    Tasks.image_reid_person: (Pipelines.image_reid_person,
                              'damo/cv_passvitb_image-reid-person_market'),
    Tasks.text_driven_segmentation:
    (Pipelines.text_driven_segmentation,
     'damo/cv_vitl16_segmentation_text-driven-seg'),
    Tasks.movie_scene_segmentation: (
        Pipelines.movie_scene_segmentation,
        'damo/cv_resnet50-bert_video-scene-segmentation_movienet'),
    Tasks.shop_segmentation: (Pipelines.shop_segmentation,
                              'damo/cv_vitb16_segmentation_shop-seg'),
    Tasks.image_inpainting: (Pipelines.image_inpainting,
                             'damo/cv_fft_inpainting_lama'),
    Tasks.image_paintbyexample: (Pipelines.image_paintbyexample,
                                 'damo/cv_stable-diffusion_paint-by-example'),
    Tasks.video_inpainting: (Pipelines.video_inpainting,
                             'damo/cv_video-inpainting'),
    Tasks.video_human_matting: (Pipelines.video_human_matting,
                                'damo/cv_effnetv2_video-human-matting'),
    Tasks.video_frame_interpolation: (
        Pipelines.video_frame_interpolation,
        'damo/cv_raft_video-frame-interpolation'),
    Tasks.video_deinterlace: (Pipelines.video_deinterlace,
                              'damo/cv_unet_video-deinterlace'),
    Tasks.human_wholebody_keypoint: (
        Pipelines.human_wholebody_keypoint,
        'damo/cv_hrnetw48_human-wholebody-keypoint_image'),
    Tasks.hand_static: (Pipelines.hand_static,
                        'damo/cv_mobileface_hand-static'),
    Tasks.face_human_hand_detection: (
        Pipelines.face_human_hand_detection,
        'damo/cv_nanodet_face-human-hand-detection'),
    Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'),
    Tasks.product_segmentation: (Pipelines.product_segmentation,
                                 'damo/cv_F3Net_product-segmentation'),
    Tasks.referring_video_object_segmentation: (
        Pipelines.referring_video_object_segmentation,
        'damo/cv_swin-t_referring_video-object-segmentation'),
    Tasks.video_summarization: (Pipelines.video_summarization,
                                'damo/cv_googlenet_pgl-video-summarization'),
    Tasks.image_skychange: (Pipelines.image_skychange,
                            'damo/cv_hrnetocr_skychange'),
    Tasks.translation_evaluation: (
        Pipelines.translation_evaluation,
        'damo/nlp_unite_mup_translation_evaluation_multilingual_large'),
    Tasks.video_object_segmentation: (
        Pipelines.video_object_segmentation,
        'damo/cv_rdevos_video-object-segmentation'),
    Tasks.video_multi_object_tracking: (
        Pipelines.video_multi_object_tracking,
        'damo/cv_yolov5_video-multi-object-tracking_fairmot'),
    Tasks.image_multi_view_depth_estimation: (
        Pipelines.image_multi_view_depth_estimation,
        'damo/cv_casmvs_multi-view-depth-estimation_general'),
    Tasks.image_fewshot_detection: (
        Pipelines.image_fewshot_detection,
        'damo/cv_resnet101_detection_fewshot-defrcn'),
    Tasks.image_body_reshaping: (Pipelines.image_body_reshaping,
                                 'damo/cv_flow-based-body-reshaping_damo'),
    Tasks.image_face_fusion: (Pipelines.image_face_fusion,
                              'damo/cv_unet-image-face-fusion_damo'),
    Tasks.image_matching: (
        Pipelines.image_matching,
        'damo/cv_quadtree_attention_image-matching_outdoor'),
    Tasks.image_quality_assessment_mos: (
        Pipelines.image_quality_assessment_mos,
        'damo/cv_resnet_image-quality-assessment-mos_youtubeUGC'),
    Tasks.image_quality_assessment_degradation: (
        Pipelines.image_quality_assessment_degradation,
        'damo/cv_resnet50_image-quality-assessment_degradation'),
    Tasks.vision_efficient_tuning: (
        Pipelines.vision_efficient_tuning,
        'damo/cv_vitb16_classification_vision-efficient-tuning-adapter'),
    Tasks.object_detection_3d: (Pipelines.object_detection_3d_depe,
                                'damo/cv_object-detection-3d_depe'),
    Tasks.bad_image_detecting: (Pipelines.bad_image_detecting,
                                'damo/cv_mobilenet-v2_bad-image-detecting'),
    Tasks.nerf_recon_acc: (Pipelines.nerf_recon_acc,
                           'damo/cv_nerf-3d-reconstruction-accelerate_damo'),
    Tasks.siamese_uie: (Pipelines.siamese_uie,
                        'damo/nlp_structbert_siamese-uie_chinese-base'),
}
Holds the standard trainer name to use for identifying different trainer.
This should be used to register trainers.
For a general Trainer, you can use EpochBasedTrainer.
For a model specific Trainer, you can use ${ModelName}-${Task}-trainer.
"""
default = 'trainer'
easycv = 'easycv'
tinynas_damoyolo = 'tinynas-damoyolo'
# multi-modal trainers
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
ofa = 'ofa'
mplug = 'mplug'
mgeo_ranking_trainer = 'mgeo-ranking-trainer'
class CVTrainers(object):
# cv trainers
image_instance_segmentation = 'image-instance-segmentation'
image_portrait_enhancement = 'image-portrait-enhancement'
@@ -424,6 +797,8 @@ class Trainers(object):
image_classification = 'image-classification'
image_fewshot_detection = 'image-fewshot-detection'
class NLPTrainers(object):
# nlp trainers
bert_sentiment_analysis = 'bert-sentiment-analysis'
dialog_modeling_trainer = 'dialog-modeling-trainer'
@@ -431,14 +806,26 @@ class Trainers(object):
nlp_base_trainer = 'nlp-base-trainer'
nlp_veco_trainer = 'nlp-veco-trainer'
nlp_text_ranking_trainer = 'nlp-text-ranking-trainer'
nlp_sentence_embedding_trainer = 'nlp-sentence-embedding-trainer'
text_generation_trainer = 'text-generation-trainer'
nlp_plug_trainer = 'nlp-plug-trainer'
gpt3_trainer = 'nlp-gpt3-trainer'
faq_question_answering_trainer = 'faq-question-answering-trainer'
gpt_moe_trainer = 'nlp-gpt-moe-trainer'
table_question_answering_trainer = 'table-question-answering-trainer'
document_grounded_dialog_generate_trainer = 'document-grounded-dialog-generate-trainer'
document_grounded_dialog_rerank_trainer = 'document-grounded-dialog-rerank-trainer'
document_grounded_dialog_retrieval_trainer = 'document-grounded-dialog-retrieval-trainer'
# audio trainers
class MultiModalTrainers(object):
    # Registry names of trainers for multi-modal models.  Each value is the
    # key under which the corresponding trainer implementation is registered;
    # `Trainers.get_trainer_domain` maps any of these names/values to
    # `Fields.multi_modal`.  (No docstring on purpose: `get_trainer_domain`
    # matches against `vars(MultiModalTrainers).values()`.)
    clip_multi_modal_embedding = 'clip-multi-modal-embedding'
    ofa = 'ofa'
    mplug = 'mplug'
    mgeo_ranking_trainer = 'mgeo-ranking-trainer'
class AudioTrainers(object):
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
@@ -447,6 +834,45 @@ class Trainers(object):
speech_separation = 'speech-separation'
class Trainers(CVTrainers, NLPTrainers, MultiModalTrainers, AudioTrainers):
    """ Names for different trainer.

    Holds the standard trainer name to use for identifying different trainer.
    This should be used to register trainers.

    For a general Trainer, you can use EpochBasedTrainer.
    For a model specific Trainer, you can use ${ModelName}-${Task}-trainer.
    """

    default = 'trainer'
    easycv = 'easycv'
    tinynas_damoyolo = 'tinynas-damoyolo'

    @staticmethod
    def get_trainer_domain(attribute_or_value):
        """Resolve which domain a trainer name belongs to.

        Args:
            attribute_or_value: Either an attribute name defined on one of
                the domain trainer groups, or one of the registered string
                values.

        Returns:
            The matching ``Fields`` domain, ``Trainers.default`` /
            ``Trainers.easycv`` for the two domain-less trainers, or the
            string ``'unknown'`` when nothing matches.
        """
        # Check the domain groups in the same order as before: the first
        # group whose attribute names or values contain the argument wins.
        domain_groups = (
            (CVTrainers, Fields.cv),
            (NLPTrainers, Fields.nlp),
            (AudioTrainers, Fields.audio),
            (MultiModalTrainers, Fields.multi_modal),
        )
        for trainer_group, domain in domain_groups:
            members = vars(trainer_group)
            if (attribute_or_value in members
                    or attribute_or_value in members.values()):
                return domain
        if attribute_or_value == Trainers.default:
            return Trainers.default
        if attribute_or_value == Trainers.easycv:
            return Trainers.easycv
        return 'unknown'
class Preprocessors(object):
""" Names for different preprocessor.
@@ -466,12 +892,18 @@ class Preprocessors(object):
image_classification_mmcv_preprocessor = 'image-classification-mmcv-preprocessor'
image_color_enhance_preprocessor = 'image-color-enhance-preprocessor'
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
image_driving_perception_preprocessor = 'image-driving-perception-preprocessor'
image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor'
image_quality_assessment_mos_preprocessor = 'image-quality_assessment-mos-preprocessor'
video_summarization_preprocessor = 'video-summarization-preprocessor'
movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'
image_classification_bypass_preprocessor = 'image-classification-bypass-preprocessor'
object_detection_scrfd = 'object-detection-scrfd'
image_sky_change_preprocessor = 'image-sky-change-preprocessor'
image_demoire_preprocessor = 'image-demoire-preprocessor'
ocr_recognition = 'ocr-recognition'
bad_image_detecting_preprocessor = 'bad-image-detecting-preprocessor'
nerf_recon_acc_preprocessor = 'nerf-recon-acc-preprocessor'
# nlp preprocessor
sen_sim_tokenizer = 'sen-sim-tokenizer'
@@ -510,6 +942,10 @@ class Preprocessors(object):
sentence_piece = 'sentence-piece'
translation_evaluation = 'translation-evaluation-preprocessor'
dialog_use_preprocessor = 'dialog-use-preprocessor'
siamese_uie_preprocessor = 'siamese-uie-preprocessor'
document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval'
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'
@@ -555,10 +991,14 @@ class Metrics(object):
image_ins_seg_coco_metric = 'image-ins-seg-coco-metric'
# metrics for sequence classification task
seq_cls_metric = 'seq-cls-metric'
# loss metric
loss_metric = 'loss-metric'
# metrics for token-classification task
token_cls_metric = 'token-cls-metric'
# metrics for text-generation task
text_gen_metric = 'text-gen-metric'
# file saving wrapper
prediction_saving_wrapper = 'prediction-saving-wrapper'
# metrics for image-color-enhance task
image_color_enhance_metric = 'image-color-enhance-metric'
# metrics for image-portrait-enhancement task
@@ -576,6 +1016,12 @@ class Metrics(object):
referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric'
# metric for video stabilization task
video_stabilization_metric = 'video-stabilization-metric'
# metirc for image-quality-assessment-mos task
image_quality_assessment_mos_metric = 'image-quality-assessment-mos-metric'
# metirc for image-quality-assessment-degradation task
image_quality_assessment_degradation_metric = 'image-quality-assessment-degradation-metric'
# metric for text-ranking task
text_ranking_metric = 'text-ranking-metric'
class Optimizers(object):
@@ -609,6 +1055,7 @@ class Hooks(object):
# checkpoint
CheckpointHook = 'CheckpointHook'
BestCkptSaverHook = 'BestCkptSaverHook'
LoadCheckpointHook = 'LoadCheckpointHook'
# logger
TextLoggerHook = 'TextLoggerHook'

View File

@@ -25,7 +25,10 @@ if TYPE_CHECKING:
from .video_stabilization_metric import VideoStabilizationMetric
from .video_super_resolution_metric.video_super_resolution_metric import VideoSuperResolutionMetric
from .ppl_metric import PplMetric
from .image_quality_assessment_degradation_metric import ImageQualityAssessmentDegradationMetric
from .image_quality_assessment_mos_metric import ImageQualityAssessmentMosMetric
from .text_ranking_metric import TextRankingMetric
from .loss_metric import LossMetric
else:
_import_structure = {
'audio_noise_metric': ['AudioNoiseMetric'],
@@ -50,6 +53,12 @@ else:
'video_frame_interpolation_metric': ['VideoFrameInterpolationMetric'],
'video_stabilization_metric': ['VideoStabilizationMetric'],
'ppl_metric': ['PplMetric'],
'image_quality_assessment_degradation_metric':
['ImageQualityAssessmentDegradationMetric'],
'image_quality_assessment_mos_metric':
['ImageQualityAssessmentMosMetric'],
'text_ranking_metric': ['TextRankingMetric'],
'loss_metric': ['LossMetric']
}
import sys

View File

@@ -8,6 +8,7 @@ from modelscope.metainfo import Metrics
from modelscope.outputs import OutputKeys
from modelscope.utils.chinese_utils import remove_space_between_chinese_chars
from modelscope.utils.registry import default_group
from modelscope.utils.tensor_utils import torch_nested_numpify
from .base import Metric
from .builder import METRICS, MetricKeys
@@ -36,8 +37,10 @@ class AccuracyMetric(Metric):
eval_results = outputs[key]
break
assert type(ground_truths) == type(eval_results)
ground_truths = torch_nested_numpify(ground_truths)
for truth in ground_truths:
self.labels.append(truth)
eval_results = torch_nested_numpify(eval_results)
for result in eval_results:
if isinstance(truth, str):
if isinstance(result, list):

View File

@@ -12,7 +12,9 @@ METRICS = Registry('metrics')
class MetricKeys(object):
ACCURACY = 'accuracy'
F1 = 'f1'
Binary_F1 = 'binary-f1'
Macro_F1 = 'macro-f1'
Micro_F1 = 'micro-f1'
PRECISION = 'precision'
RECALL = 'recall'
PSNR = 'psnr'
@@ -33,6 +35,11 @@ class MetricKeys(object):
DISTORTION_VALUE = 'distortion_value'
STABILITY_SCORE = 'stability_score'
PPL = 'ppl'
PLCC = 'plcc'
SRCC = 'srcc'
RMSE = 'rmse'
MRR = 'mrr'
NDCG = 'ndcg'
task_default_metrics = {
@@ -59,6 +66,10 @@ task_default_metrics = {
Tasks.video_frame_interpolation:
[Metrics.video_frame_interpolation_metric],
Tasks.video_stabilization: [Metrics.video_stabilization_metric],
Tasks.image_quality_assessment_degradation:
[Metrics.image_quality_assessment_degradation_metric],
Tasks.image_quality_assessment_mos:
[Metrics.image_quality_assessment_mos_metric],
}

View File

@@ -0,0 +1,75 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
import tempfile
from collections import defaultdict
from typing import Dict
import cv2
import numpy as np
import torch
from scipy.stats import pearsonr, spearmanr
from tqdm import tqdm
from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys
@METRICS.register_module(
    group_key=default_group,
    module_name=Metrics.image_quality_assessment_degradation_metric)
class ImageQualityAssessmentDegradationMetric(Metric):
    """The metric for image-quality-assessment-degradation task.

    Accumulates per-(item, distortion-type) degree predictions and targets,
    then reports the mean PLCC and SROCC over all accumulated groups.
    """

    def __init__(self):
        # Both dicts are keyed by item_id + distortion_type; each value is a
        # list of 0-d/1-d float tensors collected batch by batch.
        self.inputs = defaultdict(list)
        self.outputs = defaultdict(list)

    def add(self, outputs: Dict, inputs: Dict):
        """Collect one batch of predictions and targets.

        The distortion-type code selects which predicted degree head is
        scored; batches with codes outside the known sets are ignored.
        """
        item_degradation_id = outputs['item_id'][0] + outputs[
            'distortion_type'][0]
        if outputs['distortion_type'][0] in ['01', '02', '03']:
            pred = outputs['blur_degree']
        elif outputs['distortion_type'][0] in ['09', '10', '21']:
            pred = outputs['comp_degree']
        elif outputs['distortion_type'][0] in ['11', '12', '13', '14']:
            pred = outputs['noise_degree']
        else:
            return
        self.outputs[item_degradation_id].append(pred[0].float())
        self.inputs[item_degradation_id].append(outputs['target'].float())

    def evaluate(self):
        """Return mean PLCC/SROCC across all (item, distortion) groups."""
        degree_plccs = []
        degree_sroccs = []
        for item_degradation_id, degree_value in self.inputs.items():
            degree_label = torch.cat(degree_value).flatten().data.cpu().numpy()
            degree_pred = torch.cat(self.outputs[item_degradation_id]).flatten(
            ).data.cpu().numpy()

            degree_plcc = pearsonr(degree_label, degree_pred)[0]
            degree_srocc = spearmanr(degree_label, degree_pred)[0]
            degree_plccs.append(degree_plcc)
            degree_sroccs.append(degree_srocc)

        degree_plcc_mean = np.array(degree_plccs).mean()
        degree_srocc_mean = np.array(degree_sroccs).mean()

        return {
            MetricKeys.PLCC: degree_plcc_mean,
            MetricKeys.SRCC: degree_srocc_mean,
        }

    def merge(self, other: 'ImageQualityAssessmentDegradationMetric'):
        """Fold another instance's state into this one (distributed eval).

        Bug fix: ``inputs``/``outputs`` are ``defaultdict`` objects, which
        have no ``extend`` method — the previous ``self.inputs.extend(...)``
        raised ``AttributeError``.  Merge the per-key lists instead.
        """
        for key, values in other.inputs.items():
            self.inputs[key].extend(values)
        for key, values in other.outputs.items():
            self.outputs[key].extend(values)

    def __getstate__(self):
        return self.inputs, self.outputs

    def __setstate__(self, state):
        self.inputs, self.outputs = state

View File

@@ -0,0 +1,57 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
import tempfile
from typing import Dict
import cv2
import numpy as np
import torch
from scipy.stats import pearsonr, spearmanr
from tqdm import tqdm
from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys
@METRICS.register_module(
    group_key=default_group,
    module_name=Metrics.image_quality_assessment_mos_metric)
class ImageQualityAssessmentMosMetric(Metric):
    """The metric for image-quality-assessment-mos task.

    Gathers predicted and ground-truth MOS values batch by batch, then
    reports PLCC, SROCC and RMSE over the whole evaluation set.
    """

    def __init__(self):
        # Lists of float tensors, one entry per evaluated batch.
        self.inputs = []
        self.outputs = []

    def add(self, outputs: Dict, inputs: Dict):
        """Store one batch of MOS predictions and targets as float tensors."""
        self.outputs.append(outputs['pred'].float())
        self.inputs.append(outputs['target'].float())

    def evaluate(self):
        """Return PLCC, SRCC and RMSE over all accumulated batches."""
        mos_labels = torch.cat(self.inputs).flatten().data.cpu().numpy()
        mos_preds = torch.cat(self.outputs).flatten().data.cpu().numpy()

        plcc = pearsonr(mos_labels, mos_preds)[0]
        srocc = spearmanr(mos_labels, mos_preds)[0]
        rmse = np.sqrt(np.mean(np.square(mos_labels - mos_preds)))

        return {
            MetricKeys.PLCC: plcc,
            MetricKeys.SRCC: srocc,
            MetricKeys.RMSE: rmse,
        }

    def merge(self, other: 'ImageQualityAssessmentMosMetric'):
        """Fold another instance's accumulated batches into this one."""
        self.inputs += other.inputs
        self.outputs += other.outputs

    def __getstate__(self):
        return self.inputs, self.outputs

    def __setstate__(self, state):
        self.inputs, self.outputs = state

View File

@@ -0,0 +1,46 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from modelscope.metainfo import Metrics
from modelscope.outputs import OutputKeys
from modelscope.utils.registry import default_group
from modelscope.utils.tensor_utils import (torch_nested_detach,
torch_nested_numpify)
from .base import Metric
from .builder import METRICS, MetricKeys
@METRICS.register_module(
    group_key=default_group, module_name=Metrics.loss_metric)
class LossMetric(Metric):
    """The metric class to calculate average loss of batches.

    Args:
        loss_key: The key under which the loss value is found in the
            model outputs dict.
    """

    def __init__(self, loss_key=OutputKeys.LOSS, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_key = loss_key
        # One numpy scalar per batch, accumulated across add() calls.
        self.losses = []

    def add(self, outputs: Dict, inputs: Dict):
        """Detach the batch loss from the graph and store it as numpy."""
        batch_loss = outputs[self.loss_key]
        detached = torch_nested_detach(batch_loss)
        self.losses.append(torch_nested_numpify(detached))

    def evaluate(self):
        """Return the mean of all collected batch losses."""
        return {OutputKeys.LOSS: float(np.average(self.losses))}

    def merge(self, other: 'LossMetric'):
        """Absorb the losses collected by another LossMetric instance."""
        self.losses += other.losses

    def __getstate__(self):
        return self.losses

    def __setstate__(self, state):
        self.__init__()
        self.losses = state

View File

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from modelscope.metainfo import Metrics
from modelscope.outputs import OutputKeys
from modelscope.utils.registry import default_group
from modelscope.utils.tensor_utils import (torch_nested_detach,
torch_nested_numpify)
from .base import Metric
from .builder import METRICS, MetricKeys
@METRICS.register_module(
    group_key=default_group, module_name=Metrics.prediction_saving_wrapper)
class PredictionSavingWrapper(Metric):
    """The wrapper to save predictions to file.

    A metric-shaped adapter: instead of computing scores it forwards every
    (inputs, outputs) batch to a user-supplied callback for persistence,
    and reports no metrics from evaluate().

    Args:
        saving_fn: The saving_fn used to save predictions to files.
            Called as ``saving_fn(inputs, outputs)`` once per batch.
    """

    def __init__(self, saving_fn, **kwargs):
        super().__init__(**kwargs)
        self.saving_fn = saving_fn

    def add(self, outputs: Dict, inputs: Dict):
        # Delegate persistence to the user-supplied callback immediately;
        # nothing is accumulated on this instance.
        self.saving_fn(inputs, outputs)

    def evaluate(self):
        # No scores to report: saving already happened in add().
        return {}

    def merge(self, other: 'PredictionSavingWrapper'):
        # No accumulated state, so there is nothing to combine across workers.
        pass

    def __getstate__(self):
        # Stateless for pickling purposes (saving_fn may not be picklable).
        pass

    def __setstate__(self, state):
        pass

View File

@@ -48,19 +48,29 @@ class SequenceClassificationMetric(Metric):
def evaluate(self):
preds = np.concatenate(self.preds, axis=0)
labels = np.concatenate(self.labels, axis=0)
preds = np.argmax(preds, axis=1)
return {
MetricKeys.ACCURACY:
accuracy_score(labels, preds),
MetricKeys.F1:
f1_score(
labels,
preds,
average='micro' if any([label > 1
for label in labels]) else None),
MetricKeys.Macro_F1:
f1_score(labels, preds, average='macro'),
}
assert len(preds.shape) == 2, 'Only support predictions with shape: (batch_size, num_labels),' \
'multi-label classification is not supported in this metric class.'
preds_max = np.argmax(preds, axis=1)
if preds.shape[1] > 2:
metrics = {
MetricKeys.ACCURACY: accuracy_score(labels, preds_max),
MetricKeys.Micro_F1:
f1_score(labels, preds_max, average='micro'),
MetricKeys.Macro_F1:
f1_score(labels, preds_max, average='macro'),
}
metrics[MetricKeys.F1] = metrics[MetricKeys.Micro_F1]
return metrics
else:
metrics = {
MetricKeys.ACCURACY:
accuracy_score(labels, preds_max),
MetricKeys.Binary_F1:
f1_score(labels, preds_max, average='binary'),
}
metrics[MetricKeys.F1] = metrics[MetricKeys.Binary_F1]
return metrics
def merge(self, other: 'SequenceClassificationMetric'):
self.preds.extend(other.preds)

View File

@@ -0,0 +1,91 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict, List
import numpy as np
from modelscope.metainfo import Metrics
from modelscope.metrics.base import Metric
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.utils.registry import default_group
@METRICS.register_module(
    group_key=default_group, module_name=Metrics.text_ranking_metric)
class TextRankingMetric(Metric):
    """The metric computation class for text ranking tasks.

    Accumulates (query id, score, label) triples over the evaluation run,
    then computes MRR@k and NDCG@k per query and averages over queries.

    Args:
        mrr_k: Rank cutoff for the MRR computation.
        ndcg_k: Rank cutoff for the NDCG computation.
    """
    def __init__(self, mrr_k: int = 1, ndcg_k: int = 1):
        self.labels: List = []
        self.qids: List = []
        self.logits: List = []
        self.mrr_k: int = mrr_k
        self.ndcg_k: int = ndcg_k
    def add(self, outputs: Dict[str, List], inputs: Dict[str, List]):
        """Accumulate one batch of labels, query ids and sigmoid scores."""
        self.labels.extend(inputs.pop('labels').detach().cpu().numpy())
        self.qids.extend(inputs.pop('qid').detach().cpu().numpy())
        logits = outputs['logits'].squeeze(-1).detach().cpu().numpy()
        logits = self._sigmoid(logits).tolist()
        self.logits.extend(logits)
    def evaluate(self):
        """Group accumulated examples by query id and compute the metrics."""
        rank_result = {}
        for qid, score, label in zip(self.qids, self.logits, self.labels):
            if qid not in rank_result:
                rank_result[qid] = []
            rank_result[qid].append((score, label))
        for qid in rank_result:
            rank_result[qid] = sorted(rank_result[qid], key=lambda x: x[0])
        return {
            MetricKeys.MRR: self._compute_mrr(rank_result),
            MetricKeys.NDCG: self._compute_ndcg(rank_result)
        }
    @staticmethod
    def _sigmoid(logits):
        # Numerically stable sigmoid: the previous exp(x)/(1+exp(x)) form
        # overflows to inf/inf -> nan for large positive logits, while
        # 1/(1+exp(-x)) saturates cleanly at 1.0.
        return 1.0 / (1.0 + np.exp(-logits))
    def _compute_mrr(self, result):
        """Mean reciprocal rank of the first relevant doc within top mrr_k."""
        mrr = 0
        for res in result.values():
            sorted_res = sorted(res, key=lambda x: x[0], reverse=True)
            ar = 0
            for index, ele in enumerate(sorted_res[:self.mrr_k]):
                if str(ele[1]) == '1':
                    ar = 1.0 / (index + 1)
                    break
            mrr += ar
        return mrr / len(result)
    def _compute_ndcg(self, result):
        """Average NDCG@ndcg_k over all queries."""
        ndcg = 0
        from sklearn.metrics import ndcg_score
        for res in result.values():
            # Bug fix: the sort key was `lambda x: [0]` (a constant list),
            # which never sorted the candidates by score.
            sorted_res = sorted(res, key=lambda x: x[0], reverse=True)
            labels = np.array([[ele[1] for ele in sorted_res]])
            scores = np.array([[ele[0] for ele in sorted_res]])
            ndcg += float(ndcg_score(labels, scores, k=self.ndcg_k))
        return ndcg / len(result)
    def merge(self, other: 'TextRankingMetric'):
        """Fold another shard's accumulated state into this metric."""
        self.labels.extend(other.labels)
        self.qids.extend(other.qids)
        self.logits.extend(other.logits)
    def __getstate__(self):
        return self.labels, self.qids, self.logits, self.mrr_k, self.ndcg_k
    def __setstate__(self, state):
        self.__init__()
        self.labels, self.qids, self.logits, self.mrr_k, self.ndcg_k = state

View File

@@ -9,4 +9,5 @@ from .base import Head, Model
from .builder import BACKBONES, HEADS, MODELS, build_model
if is_torch_available():
from .base import TorchModel, TorchHead
from .base.base_torch_model import TorchModel
from .base.base_torch_head import TorchHead

View File

@@ -1,3 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import ans, asr, itn, kws, tts
from . import ans, asr, itn, kws, sv, tts

View File

@@ -13,6 +13,9 @@ __all__ = ['GenericAutomaticSpeechRecognition']
@MODELS.register_module(
Tasks.auto_speech_recognition, module_name=Models.generic_asr)
@MODELS.register_module(
Tasks.voice_activity_detection, module_name=Models.generic_asr)
@MODELS.register_module(Tasks.language_model, module_name=Models.generic_asr)
class GenericAutomaticSpeechRecognition(Model):
def __init__(self, model_dir: str, am_model_name: str,

View File

@@ -120,13 +120,12 @@ class Encoder(nn.Module):
in_channels: Number of input channels.
out_channels: Number of output channels.
Example:
-------
Examples:
>>> x = torch.randn(2, 1000)
>>> encoder = Encoder(kernel_size=4, out_channels=64)
>>> h = encoder(x)
>>> h.shape
torch.Size([2, 64, 499])
>>> h.shape # torch.Size([2, 64, 499])
"""
def __init__(self,

View File

@@ -0,0 +1,18 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Speaker-verification (sv) models subpackage.
# Heavy submodules are loaded lazily: type checkers see the real imports,
# while at runtime this module object is replaced by a LazyImportModule
# that resolves names on first attribute access.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    # Real import for type checkers / IDEs only; never executed at runtime.
    from .ecapa_tdnn import SpeakerVerificationECAPATDNN
else:
    # Map of submodule name -> public symbols it provides.
    _import_structure = {'ecapa_tdnn': ['SpeakerVerificationECAPATDNN']}
    import sys
    # Replace this module in sys.modules so attribute access triggers the
    # lazy-import machinery instead of importing ecapa_tdnn (and its
    # torch/torchaudio dependencies) eagerly.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1,504 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
""" This ECAPA-TDNN implementation is adapted from https://github.com/speechbrain/speechbrain.
"""
import math
import os
from typing import Any, Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.utils.constant import Tasks
def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Build a binary mask from sequence lengths.

    Row ``i`` of the result has ``length[i]`` leading ones followed by
    zeros, up to ``max_len`` columns.

    Args:
        length: 1-D tensor of sequence lengths.
        max_len: Number of mask columns; defaults to ``length.max()``.
        dtype: Output dtype; defaults to ``length.dtype``.
        device: Output device; defaults to ``length.device``.

    Returns:
        A ``(len(length), max_len)`` mask tensor.
    """
    assert len(length.shape) == 1
    if max_len is None:
        max_len = length.max().long().item()
    positions = torch.arange(
        max_len, device=length.device, dtype=length.dtype)
    mask = positions.expand(len(length), max_len) < length.unsqueeze(1)
    out_dtype = length.dtype if dtype is None else dtype
    out_device = length.device if device is None else device
    return torch.as_tensor(mask, dtype=out_dtype, device=out_device)
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Compute symmetric [left, right] padding for a 1-D convolution.

    For stride 1 the padding keeps the output length equal to ``L_in``;
    for larger strides it pads by half the kernel width on each side.

    Args:
        L_in: Input length along the convolved dimension.
        stride: Convolution stride.
        kernel_size: Convolution kernel width.
        dilation: Convolution dilation.

    Returns:
        A two-element list ``[pad_left, pad_right]``.
    """
    if stride > 1:
        half = kernel_size // 2
        return [half, half]
    L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
    pad = (L_in - L_out) // 2
    return [pad, pad]
class Conv1d(nn.Module):
    """1-D convolution whose padding policy is applied at call time.

    Args:
        out_channels: Number of output channels.
        kernel_size: Convolution kernel width.
        in_channels: Number of input channels.
        stride: Convolution stride.
        dilation: Convolution dilation.
        padding: One of ``'same'`` (output length preserved), ``'causal'``
            (left-padded so no future frames are used) or ``'valid'``
            (no padding).
        groups: Number of convolution groups.
        bias: Whether the convolution has a bias term.
        padding_mode: ``F.pad`` mode used for 'same' padding.
    """
    def __init__(
        self,
        out_channels,
        kernel_size,
        in_channels,
        stride=1,
        dilation=1,
        padding='same',
        groups=1,
        bias=True,
        padding_mode='reflect',
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        # Padding is handled manually in forward(), so the wrapped conv
        # itself is created with padding=0.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=0,
            groups=groups,
            bias=bias,
        )
    def forward(self, x):
        """Pad ``x`` according to the configured policy and convolve."""
        if self.padding == 'same':
            x = self._manage_padding(x, self.kernel_size, self.dilation,
                                     self.stride)
        elif self.padding == 'causal':
            left_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (left_pad, 0))
        elif self.padding == 'valid':
            pass
        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got "
                + self.padding)
        return self.conv(x)
    def _manage_padding(
        self,
        x,
        kernel_size: int,
        dilation: int,
        stride: int,
    ):
        # Symmetric padding sized so the output keeps the input length.
        pad = get_padding_elem(x.shape[-1], stride, kernel_size, dilation)
        return F.pad(x, pad, mode=self.padding_mode)
class BatchNorm1d(nn.Module):
    """Thin wrapper around ``nn.BatchNorm1d`` matching the local block API.

    Args:
        input_size: Number of features/channels to normalize.
        eps: Numerical-stability constant added to the variance.
        momentum: Running-statistics momentum.
    """
    def __init__(self, input_size, eps=1e-05, momentum=0.1):
        super().__init__()
        self.norm = nn.BatchNorm1d(input_size, eps=eps, momentum=momentum)
    def forward(self, x):
        """Normalize ``x`` over the channel dimension."""
        return self.norm(x)
class TDNNBlock(nn.Module):
    """A single TDNN layer: Conv1d -> activation -> batch norm.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel width.
        dilation: Convolution dilation.
        activation: Activation class to instantiate (default ``nn.ReLU``).
        groups: Number of convolution groups.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
        activation=nn.ReLU,
        groups=1,
    ):
        super().__init__()
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            groups=groups,
        )
        self.activation = activation()
        self.norm = BatchNorm1d(input_size=out_channels)
    def forward(self, x):
        """Apply convolution, activation and normalization in sequence."""
        hidden = self.conv(x)
        hidden = self.activation(hidden)
        return self.norm(hidden)
class Res2NetBlock(torch.nn.Module):
    """Res2Net multi-scale block.

    The input is split channel-wise into ``scale`` groups; the first group
    passes through unchanged, and each following group is processed by a
    TDNN block after adding the previous group's output (hierarchical
    residual connections).

    Args:
        in_channels: Number of input channels (must divide by ``scale``).
        out_channels: Number of output channels (must divide by ``scale``).
        scale: Number of channel groups.
        kernel_size: Kernel width of the inner TDNN blocks.
        dilation: Dilation of the inner TDNN blocks.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 scale=8,
                 kernel_size=3,
                 dilation=1):
        super().__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0
        width_in = in_channels // scale
        width_hidden = out_channels // scale
        self.blocks = nn.ModuleList(
            TDNNBlock(
                width_in,
                width_hidden,
                kernel_size=kernel_size,
                dilation=dilation,
            ) for _ in range(scale - 1))
        self.scale = scale
    def forward(self, x):
        """Process channel groups hierarchically and re-concatenate."""
        chunks = torch.chunk(x, self.scale, dim=1)
        outputs = [chunks[0]]
        prev = None
        for idx, chunk in enumerate(chunks[1:]):
            block = self.blocks[idx]
            prev = block(chunk) if idx == 0 else block(chunk + prev)
            outputs.append(prev)
        return torch.cat(outputs, dim=1)
class SEBlock(nn.Module):
    """Squeeze-and-excitation block: channel-wise gating from pooled stats.

    Args:
        in_channels: Number of input channels.
        se_channels: Bottleneck size of the excitation network.
        out_channels: Number of output (gated) channels.
    """
    def __init__(self, in_channels, se_channels, out_channels):
        super().__init__()
        self.conv1 = Conv1d(
            in_channels=in_channels, out_channels=se_channels, kernel_size=1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = Conv1d(
            in_channels=se_channels, out_channels=out_channels, kernel_size=1)
        self.sigmoid = torch.nn.Sigmoid()
    def forward(self, x, lengths=None):
        """Gate ``x`` channel-wise by its (masked) temporal mean."""
        time_steps = x.shape[-1]
        if lengths is None:
            pooled = x.mean(dim=2, keepdim=True)
        else:
            # Average only over valid (non-padded) frames.
            mask = length_to_mask(
                lengths * time_steps, max_len=time_steps,
                device=x.device).unsqueeze(1)
            valid_counts = mask.sum(dim=2, keepdim=True)
            pooled = (x * mask).sum(dim=2, keepdim=True) / valid_counts
        gate = self.sigmoid(self.conv2(self.relu(self.conv1(pooled))))
        return gate * x
class AttentiveStatisticsPooling(nn.Module):
    """Attentive statistics pooling over the time axis.

    Produces, per utterance, the attention-weighted mean and standard
    deviation of the frame features, concatenated into one vector of
    shape ``(N, 2 * channels, 1)``.

    Args:
        channels: Number of input channels.
        attention_channels: Hidden size of the attention network.
        global_context: If True, condition the attention on utterance-level
            mean/std in addition to the frame features.
    """
    def __init__(self, channels, attention_channels=128, global_context=True):
        super().__init__()
        self.eps = 1e-12
        self.global_context = global_context
        tdnn_in = channels * 3 if global_context else channels
        self.tdnn = TDNNBlock(tdnn_in, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = Conv1d(
            in_channels=attention_channels,
            out_channels=channels,
            kernel_size=1)
    def forward(self, x, lengths=None):
        """Pool ``(N, C, L)`` features to ``(N, 2*C, 1)`` mean/std stats."""
        num_frames = x.shape[-1]
        def _weighted_stats(values, weights, dim=2):
            # Weighted mean and std; clamping keeps sqrt away from zero.
            mean = (weights * values).sum(dim)
            var = (weights * (values - mean.unsqueeze(dim)).pow(2)).sum(dim)
            return mean, torch.sqrt(var.clamp(self.eps))
        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)
        # Binary mask of shape [N, 1, L] marking valid frames.
        mask = length_to_mask(
            lengths * num_frames, max_len=num_frames,
            device=x.device).unsqueeze(1)
        if self.global_context:
            # torch.std is unstable for backward computation
            # (https://github.com/pytorch/pytorch/issues/4320), so the
            # statistics are computed manually from the mask.
            total = mask.sum(dim=2, keepdim=True).float()
            mean, std = _weighted_stats(x, mask / total)
            attn = torch.cat(
                [
                    x,
                    mean.unsqueeze(2).repeat(1, 1, num_frames),
                    std.unsqueeze(2).repeat(1, 1, num_frames),
                ],
                dim=1)
        else:
            attn = x
        attn = self.conv(self.tanh(self.tdnn(attn)))
        # Zero-padded frames must receive no attention mass.
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=2)
        mean, std = _weighted_stats(x, attn)
        return torch.cat((mean, std), dim=1).unsqueeze(2)
class SERes2NetBlock(nn.Module):
    """SE-Res2Net block: 1x1 TDNN -> Res2Net -> 1x1 TDNN -> SE, plus a
    residual connection (projected by a 1x1 conv when channels change).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        res2net_scale: Split factor of the inner Res2Net block.
        se_channels: Bottleneck size of the squeeze-excitation block.
        kernel_size: Kernel width of the Res2Net block.
        dilation: Dilation of the Res2Net block.
        activation: Activation class used by the TDNN blocks.
        groups: Convolution groups of the TDNN blocks.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
        activation=torch.nn.ReLU,
        groups=1,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.res2net_block = Res2NetBlock(out_channels, out_channels,
                                          res2net_scale, kernel_size, dilation)
        self.tdnn2 = TDNNBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)
        # Project the residual only when the channel count changes.
        if in_channels == out_channels:
            self.shortcut = None
        else:
            self.shortcut = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )
    def forward(self, x, lengths=None):
        """Apply the SE-Res2Net transform and add the residual."""
        residual = self.shortcut(x) if self.shortcut else x
        out = self.tdnn1(x)
        out = self.res2net_block(out)
        out = self.tdnn2(out)
        out = self.se_block(out, lengths)
        return out + residual
class ECAPA_TDNN(nn.Module):
    """An implementation of the speaker embedding model in a paper.
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).

    Args:
        input_size: Number of input feature channels (e.g. fbank bins).
        device: Unused in this implementation; kept for interface parity.
        lin_neurons: Size of the produced embedding vector.
        activation: Activation class used throughout the network.
        channels: Output channels of each of the five stages; the last
            entry is the multi-layer feature-aggregation width.
        kernel_sizes: Kernel size per stage.
        dilations: Dilation per stage.
        attention_channels: Hidden size of the attentive pooling network.
        res2net_scale: Res2Net split factor of the SE-Res2Net blocks.
        se_channels: Bottleneck size of the squeeze-excitation blocks.
        global_context: Whether pooling attends over global statistics.
        groups: Convolution groups per stage.

    NOTE(review): the list-typed defaults are shared mutable defaults;
    they are only read here, but tuples would be safer — confirm no
    caller mutates them before changing.
    """
    def __init__(
        self,
        input_size,
        device='cpu',
        lin_neurons=192,
        activation=torch.nn.ReLU,
        channels=[512, 512, 512, 512, 1536],
        kernel_sizes=[5, 3, 3, 3, 1],
        dilations=[1, 2, 3, 4, 1],
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
        groups=[1, 1, 1, 1, 1],
    ):
        super().__init__()
        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        self.channels = channels
        # NOTE: construction order and attribute names below define the
        # state_dict layout; changing either breaks checkpoint loading.
        self.blocks = nn.ModuleList()
        # The initial TDNN layer
        self.blocks.append(
            TDNNBlock(
                input_size,
                channels[0],
                kernel_sizes[0],
                dilations[0],
                activation,
                groups[0],
            ))
        # SE-Res2Net layers
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    activation=activation,
                    groups=groups[i],
                ))
        # Multi-layer feature aggregation
        self.mfa = TDNNBlock(
            channels[-1],
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            activation,
            groups=groups[-1],
        )
        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
        )
        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
        # Final linear transformation
        self.fc = Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=lin_neurons,
            kernel_size=1,
        )
    def forward(self, x, lengths=None):
        """Returns the embedding vector.
        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape (batch, time, channel).
        lengths : torch.Tensor, optional
            Relative lengths in [0, 1] used to mask padded frames.
        """
        # Convert (batch, time, channel) to (batch, channel, time).
        x = x.transpose(1, 2)
        xl = []
        for layer in self.blocks:
            # SE-Res2Net blocks accept `lengths`; the initial TDNN block
            # does not, hence the TypeError fallback.
            try:
                x = layer(x, lengths=lengths)
            except TypeError:
                x = layer(x)
            xl.append(x)
        # Multi-layer feature aggregation
        # (concatenates SE-Res2Net outputs, skipping the initial TDNN layer)
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)
        # Attentive Statistical Pooling
        x = self.asp(x, lengths=lengths)
        x = self.asp_bn(x)
        # Final linear transformation
        x = self.fc(x)
        # (batch, lin_neurons, 1) -> (batch, lin_neurons)
        x = x.transpose(1, 2).squeeze(1)
        return x
@MODELS.register_module(
    Tasks.speaker_verification, module_name=Models.ecapa_tdnn_sv)
class SpeakerVerificationECAPATDNN(TorchModel):
    """Speaker-verification model wrapping a pretrained 1024-channel
    ECAPA-TDNN: extracts fbank features from raw audio and returns the
    speaker embedding.

    Args:
        model_dir: Directory containing the pretrained checkpoint.
        model_config: Model configuration dict; must declare
            ``channel == 1024`` (the only supported width here).
        kwargs: Must include ``pretrained_model``, the checkpoint filename
            inside ``model_dir``.
    """
    def __init__(self, model_dir, model_config: Dict[str, Any], *args,
                 **kwargs):
        super().__init__(model_dir, model_config, *args, **kwargs)
        self.model_config = model_config
        self.other_config = kwargs
        if self.model_config['channel'] != 1024:
            raise ValueError(
                'modelscope error: Currently only 1024-channel ecapa tdnn is supported.'
            )
        # 80 mel bins, matching the pretrained checkpoint's input width.
        self.feature_dim = 80
        channels_config = [1024, 1024, 1024, 1024, 3072]
        self.embedding_model = ECAPA_TDNN(
            self.feature_dim, channels=channels_config)
        pretrained_model_name = kwargs['pretrained_model']
        self.__load_check_point(pretrained_model_name)
        # Inference-only: freeze batch-norm statistics.
        self.embedding_model.eval()
    def forward(self, audio):
        """Compute the speaker embedding of a single mono waveform.

        Args:
            audio: Tensor of shape [1, T] with raw samples.
                # assumes 16 kHz sample rate (Kaldi fbank default) — TODO confirm

        Returns:
            The speaker embedding produced by the ECAPA-TDNN.
        """
        assert len(audio.shape) == 2 and audio.shape[
            0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]'
        # audio shape: [1, T]
        feature = self.__extract_feature(audio)
        embedding = self.embedding_model(feature)
        return embedding
    def __extract_feature(self, audio):
        """Compute mean-normalized fbank features, shaped [1, frames, 80]."""
        feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
        # Cepstral mean normalization over the time axis.
        feature = feature - feature.mean(dim=0, keepdim=True)
        feature = feature.unsqueeze(0)
        return feature
    def __load_check_point(self, pretrained_model_name, device=None):
        """Load the checkpoint strictly (all keys must match) onto CPU
        unless another device is given."""
        if not device:
            device = torch.device('cpu')
        self.embedding_model.load_state_dict(
            torch.load(
                os.path.join(self.model_dir, pretrained_model_name),
                map_location=device),
            strict=True)

View File

@@ -1,6 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.import_utils import is_torch_available
from .base_head import * # noqa F403
from .base_model import * # noqa F403
from .base_torch_head import * # noqa F403
from .base_torch_model import * # noqa F403
if is_torch_available():
from .base_torch_model import TorchModel
from .base_torch_head import TorchHead

Some files were not shown because too many files have changed in this diff Show More