mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
Merge pull request #108 from modelscope/master-merge-internal20230215
Master merge internal20230215
This commit is contained in:
@@ -96,9 +96,9 @@ else
|
||||
fi
|
||||
if [[ $python_version == 3.7* ]]; then
|
||||
base_tag=$base_tag-py37
|
||||
elif [[ $python_version == z* ]]; then
|
||||
elif [[ $python_version == 3.8* ]]; then
|
||||
base_tag=$base_tag-py38
|
||||
elif [[ $python_version == z* ]]; then
|
||||
elif [[ $python_version == 3.9* ]]; then
|
||||
base_tag=$base_tag-py39
|
||||
else
|
||||
echo "Unsupport python version: $python_version"
|
||||
@@ -129,8 +129,15 @@ else
|
||||
echo "Building dsw image well need set ModelScope lib cache location."
|
||||
docker_file_content="${docker_file_content} \nENV MODELSCOPE_CACHE=/mnt/workspace/.cache/modelscope"
|
||||
fi
|
||||
if [ "$is_ci_test" == "True" ]; then
|
||||
echo "Building CI image, uninstall modelscope"
|
||||
docker_file_content="${docker_file_content} \nRUN pip uninstall modelscope -y"
|
||||
fi
|
||||
printf "$docker_file_content" > Dockerfile
|
||||
docker build -t $IMAGE_TO_BUILD \
|
||||
|
||||
while true
|
||||
do
|
||||
docker build -t $IMAGE_TO_BUILD \
|
||||
--build-arg USE_GPU \
|
||||
--build-arg BASE_IMAGE \
|
||||
--build-arg PYTHON_VERSION \
|
||||
@@ -138,11 +145,14 @@ docker build -t $IMAGE_TO_BUILD \
|
||||
--build-arg CUDATOOLKIT_VERSION \
|
||||
--build-arg TENSORFLOW_VERSION \
|
||||
-f Dockerfile .
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Image build done"
|
||||
break
|
||||
else
|
||||
echo "Running docker build command error, we will retry"
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Running docker build command error, please check the log!"
|
||||
exit -1
|
||||
fi
|
||||
if [ "$run_ci_test" == "True" ]; then
|
||||
echo "Running ci case."
|
||||
export MODELSCOPE_CACHE=/home/mulin.lyh/model_scope_cache
|
||||
|
||||
@@ -20,15 +20,15 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
awk -F: '/^[^#]/ { print $1 }' requirements/science.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
pip install -r requirements/framework.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
pip install -r requirements/audio.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
pip install -r requirements/cv.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
pip install -r requirements/multi-modal.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
pip install -r requirements/nlp.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
pip install -r requirements/science.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
|
||||
# test with install
|
||||
python setup.py install
|
||||
pip install .
|
||||
else
|
||||
echo "Running case in release image, run case directly!"
|
||||
fi
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -122,6 +122,7 @@ tensorboard.sh
|
||||
.DS_Store
|
||||
replace.sh
|
||||
result.png
|
||||
result.jpg
|
||||
|
||||
# Pytorch
|
||||
*.pth
|
||||
|
||||
3
data/test/audios/speaker1_a_en_16k.wav
Normal file
3
data/test/audios/speaker1_a_en_16k.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cb35bff3dac9aec36e259461fecae1e1bc2ec029615f30713111cd598993676c
|
||||
size 249646
|
||||
3
data/test/audios/speaker1_b_en_16k.wav
Normal file
3
data/test/audios/speaker1_b_en_16k.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d7daff767e13d9a2187b676d958065121cd5e26da046d65cd9604e91a87525a2
|
||||
size 201006
|
||||
3
data/test/audios/speaker2_a_en_16k.wav
Normal file
3
data/test/audios/speaker2_a_en_16k.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a723c134978a17fe12ca2374d0281a8003a56fa44ff9d2249a08791714983362
|
||||
size 249646
|
||||
3
data/test/images/GOPR0384_11_00-000001.png
Normal file
3
data/test/images/GOPR0384_11_00-000001.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1f516e38eea7a16fd48fddc34953cb227d86d22fbcd31de0c1334bb14b96dba8
|
||||
size 932252
|
||||
3
data/test/images/butterfly_lrx2_y.png
Normal file
3
data/test/images/butterfly_lrx2_y.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:430575a8cb668113d6b0e91e403be0c0e36a95bbb96c484603a625b52f71edd9
|
||||
size 11858
|
||||
3
data/test/images/content_check.jpg
Normal file
3
data/test/images/content_check.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7d486900ecca027d70453322d0f22de4b36f9534a324b8b1cda3ea86bb72bac6
|
||||
size 353096
|
||||
3
data/test/images/face_liveness_xc.png
Normal file
3
data/test/images/face_liveness_xc.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0abad2347748bf312ab0dbce48fdc643a703d94970e1b181cf19b9be6312db8c
|
||||
size 3145728
|
||||
3
data/test/images/face_reconstruction.jpg
Normal file
3
data/test/images/face_reconstruction.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b3a4f864cee22265fdbb8008719e0e2e36235bd4bb2fdfbc9278b0b964e86eff
|
||||
size 1921140
|
||||
3
data/test/images/image_debanding.png
Normal file
3
data/test/images/image_debanding.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7f4bc4dd40c69ecc54bc9517f52fbf3df9a5f682cd9f4d4f3f1376bf33ede22d
|
||||
size 2820304
|
||||
3
data/test/images/image_driving_perception.jpg
Normal file
3
data/test/images/image_driving_perception.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1f6b6b4abfcc2fc9042c4e51c2e5f530ff84b345cd3176b11e8317143c5a7e0f
|
||||
size 91130
|
||||
3
data/test/images/image_ffhq34_00041527.png
Normal file
3
data/test/images/image_ffhq34_00041527.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0e8a71df766b615e20a5e1cacd47796a5668747e039e7f6f6e1b029b40818cc2
|
||||
size 196993
|
||||
3
data/test/images/image_inpainting/image_inpainting_1.png
Normal file
3
data/test/images/image_inpainting/image_inpainting_1.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6099183bbc513371c3bded04dbff688958a9c7ab569370c0fb4809fc64850e47
|
||||
size 704685
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5bebb94d42fa4b8dd462fecfa7b248402a30cbc637344ce26143071ca2c470d7
|
||||
size 1636
|
||||
3
data/test/images/image_moire.jpg
Normal file
3
data/test/images/image_moire.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:680d76723fc28bc6ce729a1cd6f11a7d5fc26b5bfe3b486d885417935c20f493
|
||||
size 869811
|
||||
3
data/test/images/image_multiple_human_parsing.jpg
Normal file
3
data/test/images/image_multiple_human_parsing.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9
|
||||
size 87228
|
||||
3
data/test/images/image_open_vocabulary_detection.jpg
Normal file
3
data/test/images/image_open_vocabulary_detection.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5b5861ca8955f8ff906abe78f2b32bc49deee2832f4518ffe4bb584653f3c9e9
|
||||
size 187443
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:40f535f4411fc9b3ea9d2d8c7a352f6f9a33465e797332bd1a4162b40aaffe5f
|
||||
size 338334
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cd3415c9bf1cd099a379f0b3c8049d0f602ec900c9d335b75058355d8db2b077
|
||||
size 358916
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:63c6cd0f0f3b4201a9450dcf3db4b5b4a2b9ad2f48885854868d0c2b6406aac7
|
||||
size 471097
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9c934ced1221d27153a15c14351c575a91f3ff5a6650c3dc9e0778a4245b2804
|
||||
size 1192
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f2ab6add1c8a215ca6199baa68d56bca99dbdae7391937493067a6f363b059de
|
||||
size 1453
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d87bd9fa4dca7c7dbb3253e733517303d9b85c9c6600a58c9e9b7150468036da
|
||||
size 1410
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6b63bc83b6f5dfeb66f3c79db6fa28b0683690b5dad80b414a03ed723b351edc
|
||||
size 467695
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9de64a9f9e1903f2a72bbddccfbffd16f6ea9e7a855e673792d66e7ad74c8ff4
|
||||
size 240669
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5965f3f3293fb7616e439ef4821d586de1f129bcf08279bbd10a5f42463d542f
|
||||
size 240953
|
||||
3
data/test/images/image_phone.jpg
Normal file
3
data/test/images/image_phone.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:10b494cbc1a29b228745bcb26897e2524569b467b88cc9839be38504d268ca30
|
||||
size 55485
|
||||
3
data/test/images/image_single_human_parsing.jpg
Normal file
3
data/test/images/image_single_human_parsing.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2a1976ea249b4ad5409cdae403dcd154fac3c628909b6b1874cc968960e2c62d
|
||||
size 8259
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2f832af4703878076e42fb41544b82147fd31b6be06713975872f16294d1a613
|
||||
size 28297
|
||||
3
data/test/images/image_traffic_sign.jpg
Normal file
3
data/test/images/image_traffic_sign.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b6ab556a1d69010cfe6dd136ff3fbd17ed122c6d0c3509667ef40a656bc18464
|
||||
size 87334
|
||||
BIN
data/test/images/images.zip
Normal file
BIN
data/test/images/images.zip
Normal file
Binary file not shown.
3
data/test/images/ir_face_recognition_1.png
Normal file
3
data/test/images/ir_face_recognition_1.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:602b46c6ba1d18fd3b91fd3b47112d37ca9d8e1ed72f0c0ea93ad8d493f5182e
|
||||
size 20299
|
||||
3
data/test/images/ir_face_recognition_2.png
Normal file
3
data/test/images/ir_face_recognition_2.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c0791f043b905f2e77ccf2f8c5b29182e1fc99cee16d9069e8bbc1704e917268
|
||||
size 20631
|
||||
3
data/test/images/universal_matting.jpg
Normal file
3
data/test/images/universal_matting.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:78d7bf999d1a4186309693ff1b966edb3ccd40f7861a7589167cf9e33897a693
|
||||
size 369725
|
||||
3
data/test/images/vision_efficient_tuning_test_1.png
Normal file
3
data/test/images/vision_efficient_tuning_test_1.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8b28d9c33eff034a706534f195f4443f8c053a74d5553787a5cb9b20873c072f
|
||||
size 1962
|
||||
3
data/test/images/vision_efficient_tuning_test_2.png
Normal file
3
data/test/images/vision_efficient_tuning_test_2.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bbd99f0253d6e0d10ec500cf781cc83b93809db58da54bd914b0b80b7fe8d8a4
|
||||
size 2409
|
||||
3
data/test/videos/kitti-step_testing_image_02_0000.mp4
Normal file
3
data/test/videos/kitti-step_testing_image_02_0000.mp4
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a834d1272253559cdf45a5f09642fb0b5209242dca854fce849efc15cebd4028
|
||||
size 4623264
|
||||
3
data/test/videos/video_deinterlace_test.mp4
Normal file
3
data/test/videos/video_deinterlace_test.mp4
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9941ac4a5dd0d9eea5d33ce0009da34d0c93c64ed062479e6c8efb4788e8ef7c
|
||||
size 522972
|
||||
3
data/test/videos/video_nerf_recon_test.mp4
Normal file
3
data/test/videos/video_nerf_recon_test.mp4
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:824cc8beaaa8747a3ec32f4c79308e468838c448853f40e882a7cc090c71bf96
|
||||
size 2151630
|
||||
@@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y --reinstall ca-certificates && \
|
||||
apt-get clean && \
|
||||
cp /tmp/resources/ubuntu20.04_sources.tuna /etc/apt/sources.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y locales wget git strace gdb vim ffmpeg libsm6 tzdata language-pack-zh-hans ttf-wqy-microhei ttf-wqy-zenhei xfonts-wqy libxext6 build-essential ninja-build && \
|
||||
apt-get install -y locales wget git strace gdb sox libopenmpi-dev curl strace vim ffmpeg libsm6 tzdata language-pack-zh-hans ttf-wqy-microhei ttf-wqy-zenhei xfonts-wqy libxext6 build-essential ninja-build && \
|
||||
wget https://packagecloud.io/github/git-lfs/packages/debian/bullseye/git-lfs_3.2.0_amd64.deb/download -O ./git-lfs_3.2.0_amd64.deb && \
|
||||
dpkg -i ./git-lfs_3.2.0_amd64.deb && \
|
||||
rm -f ./git-lfs_3.2.0_amd64.deb && \
|
||||
@@ -58,12 +58,46 @@ RUN if [ "$USE_GPU" = "True" ] ; then \
|
||||
pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION; \
|
||||
fi
|
||||
|
||||
# mmcv-full<=1.7.0 for mmdet3d compatible
|
||||
RUN if [ "$USE_GPU" = "True" ] ; then \
|
||||
CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="5.0 5.2 6.0 6.1 7.0 7.5 8.0 8.6" MMCV_WITH_OPS=1 MAX_JOBS=8 FORCE_CUDA=1 pip install --no-cache-dir mmcv-full && pip cache purge; \
|
||||
CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="5.0 5.2 6.0 6.1 7.0 7.5 8.0 8.6" MMCV_WITH_OPS=1 MAX_JOBS=8 FORCE_CUDA=1 pip install --no-cache-dir 'mmcv-full<=1.7.0' && pip cache purge; \
|
||||
else \
|
||||
MMCV_WITH_OPS=1 MAX_JOBS=8 pip install --no-cache-dir mmcv-full && pip cache purge; \
|
||||
MMCV_WITH_OPS=1 MAX_JOBS=8 pip install --no-cache-dir 'mmcv-full<=1.7.0' && pip cache purge; \
|
||||
fi
|
||||
|
||||
# default shell bash
|
||||
ENV SHELL=/bin/bash
|
||||
# install special package
|
||||
RUN if [ "$USE_GPU" = "True" ] ; then \
|
||||
pip install --no-cache-dir dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html; \
|
||||
else \
|
||||
pip install --no-cache-dir dgl dglgo -f https://data.dgl.ai/wheels/repo.html; \
|
||||
fi
|
||||
|
||||
# copy install scripts
|
||||
COPY docker/scripts/install_unifold.sh docker/scripts/install_colmap.sh docker/scripts/install_pytorch3d_nvdiffrast.sh docker/scripts/install_tiny_cuda_nn.sh docker/scripts/install_apex.sh /tmp/
|
||||
|
||||
# for uniford
|
||||
RUN if [ "$USE_GPU" = "True" ] ; then \
|
||||
bash /tmp/install_unifold.sh; \
|
||||
else \
|
||||
echo 'cpu unsupport uniford'; \
|
||||
fi
|
||||
|
||||
RUN if [ "$USE_GPU" = "True" ] ; then \
|
||||
pip install --no-cache-dir git+https://github.com/gxd1994/Pointnet2.PyTorch.git@master#subdirectory=pointnet2; \
|
||||
else \
|
||||
echo 'cpu unsupport Pointnet2'; \
|
||||
fi
|
||||
|
||||
RUN pip install --no-cache-dir detectron2==0.3 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
|
||||
# 3d supports
|
||||
RUN bash /tmp/install_colmap.sh
|
||||
RUN bash /tmp/install_tiny_cuda_nn.sh
|
||||
RUN bash /tmp/install_pytorch3d_nvdiffrast.sh
|
||||
# end of 3D
|
||||
|
||||
# install modelscope
|
||||
COPY requirements /var/modelscope
|
||||
RUN pip install --no-cache-dir --upgrade pip && \
|
||||
@@ -76,42 +110,17 @@ RUN pip install --no-cache-dir --upgrade pip && \
|
||||
pip install --no-cache-dir -r /var/modelscope/tests.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
|
||||
pip cache purge
|
||||
|
||||
# default shell bash
|
||||
ENV SHELL=/bin/bash
|
||||
|
||||
# install special package
|
||||
RUN if [ "$USE_GPU" = "True" ] ; then \
|
||||
pip install --no-cache-dir dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html; \
|
||||
else \
|
||||
pip install --no-cache-dir dgl dglgo -f https://data.dgl.ai/wheels/repo.html; \
|
||||
fi
|
||||
|
||||
# install jupyter plugin
|
||||
RUN mkdir -p /root/.local/share/jupyter/labextensions/ && \
|
||||
cp -r /tmp/resources/jupyter_plugins/* /root/.local/share/jupyter/labextensions/
|
||||
|
||||
COPY docker/scripts/modelscope_env_init.sh /usr/local/bin/ms_env_init.sh
|
||||
RUN pip install --no-cache-dir https://modelscope.oss-cn-beijing.aliyuncs.com/releases/dependencies/xtcocotools-1.12-cp37-cp37m-linux_x86_64.whl --force
|
||||
RUN pip install --no-cache-dir xtcocotools==1.12 detectron2==0.3 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html --force
|
||||
|
||||
# for uniford
|
||||
COPY docker/scripts/install_unifold.sh /tmp/install_unifold.sh
|
||||
RUN if [ "$USE_GPU" = "True" ] ; then \
|
||||
bash /tmp/install_unifold.sh; \
|
||||
else \
|
||||
echo 'cpu unsupport uniford'; \
|
||||
fi
|
||||
|
||||
RUN pip install --no-cache-dir mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 numpy==1.18.5 https://pypi.tuna.tsinghua.edu.cn/packages/70/ad/06f8a06cef819606cb1a521bcc144288daee5c7e73c5d722492866cb1b92/wenetruntime-1.11.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ipykernel fairseq fasttext deepspeed
|
||||
COPY docker/scripts/install_apex.sh /tmp/install_apex.sh
|
||||
# speechbrain==0.5.7 for audio compatible
|
||||
RUN pip install --no-cache-dir speechbrain==0.5.7 adaseq>=0.5.0 mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 numpy==1.18.5 wenetruntime==1.11.0 ipykernel fairseq fasttext deepspeed
|
||||
RUN if [ "$USE_GPU" = "True" ] ; then \
|
||||
bash /tmp/install_apex.sh; \
|
||||
else \
|
||||
echo 'cpu unsupport apex'; \
|
||||
fi
|
||||
RUN apt-get update && apt-get install -y sox && \
|
||||
apt-get clean
|
||||
RUN if [ "$USE_GPU" = "True" ] ; then \
|
||||
pip install --no-cache-dir git+https://github.com/gxd1994/Pointnet2.PyTorch.git@master#subdirectory=pointnet2; \
|
||||
else \
|
||||
echo 'cpu unsupport Pointnet2'; \
|
||||
fi
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
export MAX_JOBS=16
|
||||
git clone https://github.com/NVIDIA/apex
|
||||
cd apex
|
||||
TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6" pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
|
||||
cd ..
|
||||
rm -rf apex
|
||||
export MAX_JOBS=16 \
|
||||
&& git clone https://github.com/NVIDIA/apex \
|
||||
&& cd apex \
|
||||
&& TORCH_CUDA_ARCH_LIST="6.0;6.1;6.2;7.0;7.5;8.0;8.6" pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ \
|
||||
&& cd .. \
|
||||
&& rm -rf apex
|
||||
|
||||
24
docker/scripts/install_colmap.sh
Normal file
24
docker/scripts/install_colmap.sh
Normal file
@@ -0,0 +1,24 @@
|
||||
wget -q https://cmake.org/files/v3.25/cmake-3.25.2-linux-x86_64.sh \
|
||||
&& mkdir /opt/cmake \
|
||||
&& sh cmake-3.25.2-linux-x86_64.sh --prefix=/opt/cmake --skip-license \
|
||||
&& ln -s /opt/cmake/bin/cmake /usr/local/bin/cmake \
|
||||
&& rm -f cmake-3.25.2-linux-x86_64.sh \
|
||||
&& apt-get update \
|
||||
&& apt-get install libboost-program-options-dev libboost-filesystem-dev libboost-graph-dev libboost-system-dev libboost-test-dev libeigen3-dev libflann-dev libsuitesparse-dev libfreeimage-dev libmetis-dev libgoogle-glog-dev libgflags-dev libsqlite3-dev libglew-dev qtbase5-dev libqt5opengl5-dev libcgal-dev libceres-dev -y \
|
||||
&& export CMAKE_BUILD_PARALLEL_LEVEL=36 \
|
||||
&& export MAX_JOBS=16 \
|
||||
&& export COLMAP_VERSION=dev \
|
||||
&& export CUDA_ARCHITECTURES="all" \
|
||||
&& git clone https://github.com/colmap/colmap.git \
|
||||
&& cd colmap \
|
||||
&& git reset --hard ${COLMAP_VERSION} \
|
||||
&& mkdir build \
|
||||
&& cd build \
|
||||
&& cmake .. -GNinja -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHITECTURES} \
|
||||
&& ninja \
|
||||
&& ninja install \
|
||||
&& cd ../.. \
|
||||
&& rm -rf colmap \
|
||||
&& apt-get clean \
|
||||
&& strip --remove-section=.note.ABI-tag /usr/lib/x86_64-linux-gnu/libQt5Core.so.5 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
@@ -1,12 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eo pipefail
|
||||
|
||||
ModelScopeLib=/usr/local/modelscope/lib64
|
||||
|
||||
if [ ! -d /usr/local/modelscope ]; then
|
||||
mkdir -p $ModelScopeLib
|
||||
fi
|
||||
|
||||
# audio libs
|
||||
wget "http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/libs/audio/libmitaec_pyio.so" -O ${ModelScopeLib}/libmitaec_pyio.so
|
||||
14
docker/scripts/install_pytorch3d_nvdiffrast.sh
Normal file
14
docker/scripts/install_pytorch3d_nvdiffrast.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
export CMAKE_BUILD_PARALLEL_LEVEL=36 && export MAX_JOBS=36 && export CMAKE_CUDA_ARCHITECTURES="50;52;60;61;70;75;80;86" \
|
||||
&& pip install --no-cache-dir fvcore iopath \
|
||||
&& curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz \
|
||||
&& tar xzf 1.10.0.tar.gz \
|
||||
&& export CUB_HOME=$PWD/cub-1.10.0 \
|
||||
&& pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" \
|
||||
&& rm -fr 1.10.0.tar.gz cub-1.10.0 \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends pkg-config libglvnd0 libgl1 libglx0 libegl1 libgles2 libglvnd-dev libgl1-mesa-dev libegl1-mesa-dev libgles2-mesa-dev -y \
|
||||
&& git clone https://github.com/NVlabs/nvdiffrast.git \
|
||||
&& cd nvdiffrast \
|
||||
&& pip install --no-cache-dir . \
|
||||
&& cd .. \
|
||||
&& rm -rf nvdiffrast
|
||||
8
docker/scripts/install_tiny_cuda_nn.sh
Normal file
8
docker/scripts/install_tiny_cuda_nn.sh
Normal file
@@ -0,0 +1,8 @@
|
||||
export CMAKE_BUILD_PARALLEL_LEVEL=36 && export MAX_JOBS=36 && export TCNN_CUDA_ARCHITECTURES="50;52;60;61;70;75;80;86" \
|
||||
&& git clone --recursive https://github.com/nvlabs/tiny-cuda-nn \
|
||||
&& cd tiny-cuda-nn \
|
||||
&& git checkout v1.6 \
|
||||
&& cd bindings/torch \
|
||||
&& python setup.py install \
|
||||
&& cd ../../.. \
|
||||
&& rm -rf tiny-cuda-nn
|
||||
@@ -27,9 +27,9 @@
|
||||
Currently supported formats include "json", "yaml/yml".
|
||||
|
||||
Examples:
|
||||
>>> load('/path/of/your/file') # file is storaged in disk
|
||||
>>> load('https://path/of/your/file') # file is storaged in Internet
|
||||
>>> load('oss://path/of/your/file') # file is storaged in petrel
|
||||
>>> load('/path/of/your/file') # file is stored in disk
|
||||
>>> load('https://path/of/your/file') # file is stored on internet
|
||||
>>> load('oss://path/of/your/file') # file is stored in petrel
|
||||
|
||||
Returns:
|
||||
The content from the file.
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
.. autoclass:: {{ name }}
|
||||
:members:
|
||||
|
||||
:special-members: __init__, __call__
|
||||
|
||||
..
|
||||
autogenerated from source/_templates/classtemplate.rst
|
||||
|
||||
@@ -12,3 +12,16 @@ modelscope.models.cv
|
||||
:template: classtemplate.rst
|
||||
|
||||
easycv_base.EasyCVBaseModel
|
||||
image_colorization.ddcolor.ddcolor_for_image_colorization.DDColorForImageColorization
|
||||
image_deblur.nafnet_for_image_deblur.NAFNetForImageDeblur
|
||||
image_defrcn_fewshot.defrcn_for_fewshot.DeFRCNForFewShot
|
||||
image_denoise.nafnet_for_image_denoise.NAFNetForImageDenoise
|
||||
image_face_fusion.image_face_fusion.ImageFaceFusion
|
||||
image_matching.quadtree_attention_model.QuadTreeAttentionForImageMatching
|
||||
image_skychange.skychange_model.ImageSkychange
|
||||
language_guided_video_summarization.summarizer.ClipItVideoSummarization
|
||||
panorama_depth_estimation.unifuse_model.PanoramaDepthEstimation
|
||||
video_stabilization.DUTRAFTStabilizer.DUTRAFTStabilizer
|
||||
video_summarization.summarizer.PGLVideoSummarization
|
||||
video_super_resolution.real_basicvsr_for_video_super_resolution.RealBasicVSRNetForVideoSR
|
||||
vision_middleware.model.VisionMiddlewareModel
|
||||
|
||||
24
docs/source/api/modelscope.models.multi_modal.rst
Normal file
24
docs/source/api/modelscope.models.multi_modal.rst
Normal file
@@ -0,0 +1,24 @@
|
||||
modelscope.models.multi_modal
|
||||
====================
|
||||
|
||||
.. automodule:: modelscope.models.multi_modal
|
||||
|
||||
.. currentmodule:: modelscope.models.multi_modal
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
clip.CLIPForMultiModalEmbedding
|
||||
diffusion.DiffusionForTextToImageSynthesis
|
||||
gemm.GEMMForMultiModalEmbedding
|
||||
team.TEAMForMultiModalSimilarity
|
||||
mmr.VideoCLIPForMultiModalEmbedding
|
||||
mplug_for_all_tasks.MPlugForAllTasks
|
||||
mplug_for_all_tasks.HiTeAForAllTasks
|
||||
ofa_for_all_tasks.OfaForAllTasks
|
||||
ofa_for_text_to_image_synthesis_model.OfaForTextToImageSynthesis
|
||||
multi_stage_diffusion.MultiStageDiffusionForTextToImageSynthesis
|
||||
vldoc.VLDocForDocVLEmbedding
|
||||
60
docs/source/api/modelscope.models.nlp.rst
Normal file
60
docs/source/api/modelscope.models.nlp.rst
Normal file
@@ -0,0 +1,60 @@
|
||||
modelscope.models.nlp
|
||||
====================
|
||||
|
||||
.. automodule:: modelscope.models.nlp
|
||||
|
||||
.. currentmodule:: modelscope.models.nlp
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
bart.BartForTextErrorCorrection
|
||||
bert.BertConfig
|
||||
bert.BertModel
|
||||
bert.BertForMaskedLM
|
||||
bert.BertForTextRanking
|
||||
bert.BertForSentenceEmbedding
|
||||
bert.BertForSequenceClassification
|
||||
bert.BertForTokenClassification
|
||||
bert.BertForDocumentSegmentation
|
||||
csanmt.CsanmtForTranslation
|
||||
deberta_v2.DebertaV2Model
|
||||
deberta_v2.DebertaV2ForMaskedLM
|
||||
gpt_neo.GPTNeoModel
|
||||
gpt2.GPT2Model
|
||||
gpt3.GPT3ForTextGeneration
|
||||
gpt3.DistributedGPT3
|
||||
gpt_moe.GPTMoEForTextGeneration
|
||||
gpt_moe.DistributedGPTMoE
|
||||
megatron_bert.MegatronBertConfig
|
||||
megatron_bert.MegatronBertModel
|
||||
megatron_bert.MegatronBertForMaskedLM
|
||||
palm_v2.PalmForTextGeneration
|
||||
ponet.PoNetConfig
|
||||
ponet.PoNetModel
|
||||
ponet.PoNetForMaskedLM
|
||||
space.SpaceForDialogIntent
|
||||
space.SpaceForDialogModeling
|
||||
space.SpaceForDST
|
||||
space_T_cn.TableQuestionAnswering
|
||||
space_T_en.StarForTextToSql
|
||||
structbert.SbertModel
|
||||
structbert.SbertForMaskedLM
|
||||
structbert.SbertForSequenceClassification
|
||||
structbert.SbertForTokenClassification
|
||||
structbert.SbertForFaqQuestionAnswering
|
||||
T5.T5ForConditionalGeneration
|
||||
mglm.MGLMForTextSummarization
|
||||
codegeex.CodeGeeXForCodeTranslation
|
||||
codegeex.CodeGeeXForCodeGeneration
|
||||
veco.VecoConfig
|
||||
veco.VecoModel
|
||||
veco.VecoForMaskedLM
|
||||
veco.VecoForSequenceClassification
|
||||
veco.VecoForTokenClassification
|
||||
bloom.BloomModel
|
||||
unite.UniTEModel
|
||||
use.UserSatisfactionEstimation
|
||||
@@ -12,3 +12,5 @@ modelscope.models
|
||||
bases <modelscope.models.base>
|
||||
builders <modelscope.models.builder>
|
||||
cv <modelscope.models.cv>
|
||||
nlp <modelscope.models.nlp>
|
||||
multi-modal <modelscope.models.multi_modal>
|
||||
|
||||
20
docs/source/api/modelscope.pipelines.audio.rst
Normal file
20
docs/source/api/modelscope.pipelines.audio.rst
Normal file
@@ -0,0 +1,20 @@
|
||||
modelscope.pipelines.audio
|
||||
=======================
|
||||
|
||||
.. automodule:: modelscope.pipelines.audio
|
||||
|
||||
.. currentmodule:: modelscope.pipelines.audio
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
ANSPipeline
|
||||
AutomaticSpeechRecognitionPipeline
|
||||
InverseTextProcessingPipeline
|
||||
KWSFarfieldPipeline
|
||||
KeyWordSpottingKwsbpPipeline
|
||||
LinearAECPipeline
|
||||
TextToSpeechSambertHifiganPipeline
|
||||
@@ -11,4 +11,84 @@ modelscope.pipelines.cv
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
ActionDetectionPipeline
|
||||
ActionRecognitionPipeline
|
||||
AnimalRecognitionPipeline
|
||||
ArcFaceRecognitionPipeline
|
||||
Body2DKeypointsPipeline
|
||||
CardDetectionPipeline
|
||||
CMDSSLVideoEmbeddingPipeline
|
||||
CrowdCountingPipeline
|
||||
DDColorImageColorizationPipeline
|
||||
EasyCVDetectionPipeline
|
||||
EasyCVSegmentationPipeline
|
||||
Face2DKeypointsPipeline
|
||||
FaceAttributeRecognitionPipeline
|
||||
FaceDetectionPipeline
|
||||
FaceImageGenerationPipeline
|
||||
FaceLivenessIrPipeline
|
||||
FaceProcessingBasePipeline
|
||||
FaceRecognitionOnnxFmPipeline
|
||||
FaceRecognitionOodPipeline
|
||||
FaceRecognitionPipeline
|
||||
FacialExpressionRecognitionPipeline
|
||||
FacialLandmarkConfidencePipeline
|
||||
GeneralImageClassificationPipeline
|
||||
GeneralRecognitionPipeline
|
||||
HICOSSLVideoEmbeddingPipeline
|
||||
Hand2DKeypointsPipeline
|
||||
HandStaticPipeline
|
||||
HumanWholebodyKeypointsPipeline
|
||||
Image2ImageGenerationPipeline
|
||||
Image2ImageTranslationPipeline
|
||||
ImageCartoonPipeline
|
||||
ImageClassificationPipeline
|
||||
ImageColorEnhancePipeline
|
||||
ImageColorizationPipeline
|
||||
ImageDeblurPipeline
|
||||
ImageDefrcnDetectionPipeline
|
||||
ImageDenoisePipeline
|
||||
ImageDetectionPipeline
|
||||
ImageInpaintingPipeline
|
||||
ImageInstanceSegmentationPipeline
|
||||
ImageMatchingPipeline
|
||||
ImageMattingPipeline
|
||||
ImageMultiViewDepthEstimationPipeline
|
||||
ImagePanopticSegmentationEasyCVPipeline
|
||||
ImagePanopticSegmentationPipeline
|
||||
ImagePortraitEnhancementPipeline
|
||||
ImageReidPersonPipeline
|
||||
ImageSalientDetectionPipeline
|
||||
ImageSemanticSegmentationPipeline
|
||||
ImageSkychangePipeline
|
||||
ImageStyleTransferPipeline
|
||||
ImageSuperResolutionPipeline
|
||||
LanguageGuidedVideoSummarizationPipeline
|
||||
LicensePlateDetectionPipeline
|
||||
LiveCategoryPipeline
|
||||
MaskDINOInstanceSegmentationPipeline
|
||||
MaskFaceRecognitionPipeline
|
||||
MogFaceDetectionPipeline
|
||||
MovieSceneSegmentationPipeline
|
||||
MtcnnFaceDetectionPipeline
|
||||
OCRDetectionPipeline
|
||||
OCRRecognitionPipeline
|
||||
PointCloudSceneFlowEstimationPipeline
|
||||
ProductRetrievalEmbeddingPipeline
|
||||
RealtimeObjectDetectionPipeline
|
||||
ReferringVideoObjectSegmentationPipeline
|
||||
RetinaFaceDetectionPipeline
|
||||
ShopSegmentationPipeline
|
||||
SkinRetouchingPipeline
|
||||
TableRecognitionPipeline
|
||||
TextDrivenSegmentationPipeline
|
||||
TinynasClassificationPipeline
|
||||
UlfdFaceDetectionPipeline
|
||||
VideoCategoryPipeline
|
||||
VideoFrameInterpolationPipeline
|
||||
VideoObjectSegmentationPipeline
|
||||
VideoStabilizationPipeline
|
||||
VideoSuperResolutionPipeline
|
||||
VirtualTryonPipeline
|
||||
VisionMiddlewarePipeline
|
||||
VopRetrievalPipeline
|
||||
|
||||
28
docs/source/api/modelscope.pipelines.multi_modal.rst
Normal file
28
docs/source/api/modelscope.pipelines.multi_modal.rst
Normal file
@@ -0,0 +1,28 @@
|
||||
modelscope.pipelines.multi_modal
|
||||
=======================
|
||||
|
||||
.. automodule:: modelscope.pipelines.multi_modal
|
||||
|
||||
.. currentmodule:: modelscope.pipelines.multi_modal
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
AutomaticSpeechRecognitionPipeline
|
||||
ChineseStableDiffusionPipeline
|
||||
DocumentVLEmbeddingPipeline
|
||||
GEMMMultiModalEmbeddingPipeline
|
||||
ImageCaptioningPipeline
|
||||
MGeoRankingPipeline
|
||||
MultiModalEmbeddingPipeline
|
||||
StableDiffusionWrapperPipeline
|
||||
TextToImageSynthesisPipeline
|
||||
VideoCaptioningPipeline
|
||||
VideoMultiModalEmbeddingPipeline
|
||||
VideoQuestionAnsweringPipeline
|
||||
VisualEntailmentPipeline
|
||||
VisualGroundingPipeline
|
||||
VisualQuestionAnsweringPipeline
|
||||
45
docs/source/api/modelscope.pipelines.nlp.rst
Normal file
45
docs/source/api/modelscope.pipelines.nlp.rst
Normal file
@@ -0,0 +1,45 @@
|
||||
modelscope.pipelines.nlp
|
||||
=======================
|
||||
|
||||
.. automodule:: modelscope.pipelines.nlp
|
||||
|
||||
.. currentmodule:: modelscope.pipelines.nlp
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
AutomaticPostEditingPipeline
|
||||
CodeGeeXCodeGenerationPipeline
|
||||
CodeGeeXCodeTranslationPipeline
|
||||
ConversationalTextToSqlPipeline
|
||||
DialogIntentPredictionPipeline
|
||||
DialogModelingPipeline
|
||||
DialogStateTrackingPipeline
|
||||
DocumentSegmentationPipeline
|
||||
ExtractiveSummarizationPipeline
|
||||
FaqQuestionAnsweringPipeline
|
||||
FasttextSequenceClassificationPipeline
|
||||
FeatureExtractionPipeline
|
||||
FillMaskPipeline
|
||||
InformationExtractionPipeline
|
||||
MGLMTextSummarizationPipeline
|
||||
NamedEntityRecognitionPipeline
|
||||
SentenceEmbeddingPipeline
|
||||
SummarizationPipeline
|
||||
TableQuestionAnsweringPipeline
|
||||
TextClassificationPipeline
|
||||
TextErrorCorrectionPipeline
|
||||
TextGenerationPipeline
|
||||
TextGenerationT5Pipeline
|
||||
TextRankingPipeline
|
||||
TokenClassificationPipeline
|
||||
TranslationEvaluationPipeline
|
||||
TranslationPipeline
|
||||
TranslationQualityEstimationPipeline
|
||||
UserSatisfactionEstimationPipeline
|
||||
WordSegmentationPipeline
|
||||
WordSegmentationThaiPipeline
|
||||
ZeroShotClassificationPipeline
|
||||
@@ -12,3 +12,7 @@ modelscope.pipelines
|
||||
base <modelscope.pipelines.base>
|
||||
builder <modelscope.pipelines.builder>
|
||||
cv <modelscope.pipelines.cv>
|
||||
nlp <modelscope.pipelines.nlp>
|
||||
multi-modal <modelscope.pipelines.multi-modal>
|
||||
audio <modelscope.pipelines.audio>
|
||||
science <modelscope.pipelines.science>
|
||||
|
||||
14
docs/source/api/modelscope.pipelines.science.rst
Normal file
14
docs/source/api/modelscope.pipelines.science.rst
Normal file
@@ -0,0 +1,14 @@
|
||||
modelscope.pipelines.science
|
||||
=======================
|
||||
|
||||
.. automodule:: modelscope.pipelines.science
|
||||
|
||||
.. currentmodule:: modelscope.pipelines.science
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
ProteinStructurePipeline
|
||||
44
docs/source/api/modelscope.preprocessors.nlp.rst
Normal file
44
docs/source/api/modelscope.preprocessors.nlp.rst
Normal file
@@ -0,0 +1,44 @@
|
||||
modelscope.preprocessors.nlp
|
||||
====================
|
||||
|
||||
.. automodule:: modelscope.preprocessors.nlp
|
||||
|
||||
.. currentmodule:: modelscope.preprocessors.nlp
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
TextErrorCorrectionPreprocessor
|
||||
TextGenerationJiebaPreprocessor
|
||||
DocumentSegmentationTransformersPreprocessor
|
||||
FaqQuestionAnsweringTransformersPreprocessor
|
||||
FillMaskPoNetPreprocessor
|
||||
FillMaskTransformersPreprocessor
|
||||
TextRankingTransformersPreprocessor
|
||||
RelationExtractionTransformersPreprocessor
|
||||
TextClassificationTransformersPreprocessor
|
||||
SentenceEmbeddingTransformersPreprocessor
|
||||
TextGenerationTransformersPreprocessor
|
||||
TextGenerationT5Preprocessor
|
||||
TextGenerationSentencePiecePreprocessor
|
||||
SentencePiecePreprocessor
|
||||
TokenClassificationTransformersPreprocessor
|
||||
WordSegmentationBlankSetToLabelPreprocessor
|
||||
WordSegmentationPreprocessorThai
|
||||
NERPreprocessorThai
|
||||
NERPreprocessorViet
|
||||
ZeroShotClassificationTransformersPreprocessor
|
||||
DialogIntentPredictionPreprocessor
|
||||
DialogModelingPreprocessor
|
||||
DialogStateTrackingPreprocessor
|
||||
InputFeatures
|
||||
MultiWOZBPETextField
|
||||
IntentBPETextField
|
||||
ConversationalTextToSqlPreprocessor
|
||||
TableQuestionAnsweringPreprocessor
|
||||
MGLMSummarizationPreprocessor
|
||||
TranslationEvaluationPreprocessor
|
||||
DialogueClassificationUsePreprocessor
|
||||
@@ -12,3 +12,4 @@ modelscope.preprocessors
|
||||
base <modelscope.preprocessors.base>
|
||||
builders <modelscope.preprocessors.builder>
|
||||
video <modelscope.preprocessors.video>
|
||||
nlp <modelscope.preprocessors.nlp>
|
||||
|
||||
29
docs/source/api/modelscope.trainers.hooks.rst
Normal file
29
docs/source/api/modelscope.trainers.hooks.rst
Normal file
@@ -0,0 +1,29 @@
|
||||
modelscope.trainers.hooks
|
||||
=======================
|
||||
|
||||
.. automodule:: modelscope.trainers.
|
||||
|
||||
.. currentmodule:: modelscope.trainers.hooks
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
builder.build_hook
|
||||
hook.Hook
|
||||
priority.Priority
|
||||
checkpoint_hook.CheckpointHook
|
||||
checkpoint_hook.BestCkptSaverHook
|
||||
compression.SparsityHook
|
||||
evaluation_hook.EvaluationHook
|
||||
iter_timer_hook.IterTimerHook
|
||||
logger.TensorboardHook
|
||||
logger.TextLoggerHook
|
||||
lr_scheduler_hook.LrSchedulerHook
|
||||
lr_scheduler_hook.NoneLrSchedulerHook
|
||||
optimizer.OptimizerHook
|
||||
optimizer.NoneOptimizerHook
|
||||
optimizer.ApexAMPOptimizerHook
|
||||
optimizer.TorchAMPOptimizerHook
|
||||
18
docs/source/api/modelscope.trainers.multi_modal.rst
Normal file
18
docs/source/api/modelscope.trainers.multi_modal.rst
Normal file
@@ -0,0 +1,18 @@
|
||||
modelscope.trainers.multi_modal
|
||||
=======================
|
||||
|
||||
.. automodule:: modelscope.trainers.multi_modal
|
||||
|
||||
.. currentmodule:: modelscope.trainers.multi_modal
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
clip.CLIPTrainer
|
||||
team.TEAMImgClsTrainer
|
||||
ofa.OFATrainer
|
||||
mplug.MPlugTrainer
|
||||
mgeo_ranking_trainer.MGeoRankingTrainer
|
||||
17
docs/source/api/modelscope.trainers.nlp.rst
Normal file
17
docs/source/api/modelscope.trainers.nlp.rst
Normal file
@@ -0,0 +1,17 @@
|
||||
modelscope.trainers.nlp
|
||||
=======================
|
||||
|
||||
.. automodule:: modelscope.trainers.nlp
|
||||
|
||||
.. currentmodule:: modelscope.trainers.nlp
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
sequence_classification_trainer.SequenceClassificationTrainer
|
||||
csanmt_translation_trainer.CsanmtTranslationTrainer
|
||||
text_ranking_trainer.TextRankingTrainer
|
||||
text_generation_trainer.TextGenerationTrainer
|
||||
@@ -12,4 +12,7 @@ modelscope.trainers
|
||||
base <modelscope.trainers.base>
|
||||
builder <modelscope.trainers.builder>
|
||||
EpochBasedTrainer <modelscope.trainers.trainer>
|
||||
Hooks <modelscope.trainers.hooks>
|
||||
cv <modelscope.trainers.cv>
|
||||
nlp <modelscope.trainers.nlp>
|
||||
multi-modal <modelscope.trainers.multi_modal>
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import sphinx_book_theme
|
||||
# import sphinx_book_theme
|
||||
|
||||
sys.path.insert(0, os.path.abspath('../../'))
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
@@ -22,11 +22,6 @@ ModelScope DOCUMENTATION
|
||||
Trainer <api/modelscope.trainers>
|
||||
MsDataset <api/modelscope.msdatasets>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Changelog
|
||||
|
||||
change_log.md
|
||||
|
||||
Indices and tables
|
||||
==================
|
||||
|
||||
@@ -1,51 +1,43 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets.ms_dataset import MsDataset
|
||||
from modelscope.trainers.builder import build_trainer
|
||||
from modelscope.trainers.training_args import (ArgAttr, CliArgumentParser,
|
||||
training_args)
|
||||
from modelscope.trainers.training_args import TrainingArgs
|
||||
|
||||
|
||||
def define_parser():
|
||||
training_args.num_classes = ArgAttr(
|
||||
cfg_node_name=[
|
||||
'model.mm_model.head.num_classes',
|
||||
'model.mm_model.train_cfg.augments.0.num_classes',
|
||||
'model.mm_model.train_cfg.augments.1.num_classes'
|
||||
],
|
||||
type=int,
|
||||
help='number of classes')
|
||||
@dataclass
|
||||
class ImageClassificationTrainingArgs(TrainingArgs):
|
||||
num_classes: int = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'cfg_node': [
|
||||
'model.mm_model.head.num_classes',
|
||||
'model.mm_model.train_cfg.augments.0.num_classes',
|
||||
'model.mm_model.train_cfg.augments.1.num_classes'
|
||||
],
|
||||
'help':
|
||||
'number of classes',
|
||||
})
|
||||
|
||||
training_args.train_batch_size.default = 16
|
||||
training_args.train_data_worker.default = 1
|
||||
training_args.max_epochs.default = 1
|
||||
training_args.optimizer.default = 'AdamW'
|
||||
training_args.lr.default = 1e-4
|
||||
training_args.warmup_iters = ArgAttr(
|
||||
'train.lr_config.warmup_iters',
|
||||
type=int,
|
||||
default=1,
|
||||
help='number of warmup epochs')
|
||||
training_args.topk = ArgAttr(
|
||||
cfg_node_name=[
|
||||
'train.evaluation.metric_options.topk',
|
||||
'evaluation.metric_options.topk'
|
||||
],
|
||||
default=(1, ),
|
||||
help='evaluation using topk, tuple format, eg (1,), (1,5)')
|
||||
topk: tuple = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'cfg_node': [
|
||||
'train.evaluation.metric_options.topk',
|
||||
'evaluation.metric_options.topk'
|
||||
],
|
||||
'help':
|
||||
'evaluation using topk, tuple format, eg (1,), (1,5)',
|
||||
})
|
||||
|
||||
training_args.train_data = ArgAttr(
|
||||
type=str, default='tany0699/cats_and_dogs', help='train dataset')
|
||||
training_args.validation_data = ArgAttr(
|
||||
type=str, default='tany0699/cats_and_dogs', help='validation dataset')
|
||||
training_args.model_id = ArgAttr(
|
||||
type=str,
|
||||
default='damo/cv_vit-base_image-classification_ImageNet-labels',
|
||||
help='model name')
|
||||
|
||||
parser = CliArgumentParser(training_args)
|
||||
return parser
|
||||
warmup_iters: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'cfg_node': 'train.lr_config.warmup_iters',
|
||||
'help': 'The warmup iters',
|
||||
})
|
||||
|
||||
|
||||
def create_dataset(name, split):
|
||||
@@ -54,21 +46,26 @@ def create_dataset(name, split):
|
||||
dataset_name, namespace=namespace, subset_name='default', split=split)
|
||||
|
||||
|
||||
def train(parser):
|
||||
cfg_dict = parser.get_cfg_dict()
|
||||
args = parser.args
|
||||
train_dataset = create_dataset(args.train_data, split='train')
|
||||
val_dataset = create_dataset(args.validation_data, split='validation')
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
cfg.merge_from_dict(cfg_dict)
|
||||
return cfg
|
||||
def train():
|
||||
args = ImageClassificationTrainingArgs.from_cli(
|
||||
model='damo/cv_vit-base_image-classification_ImageNet-labels',
|
||||
max_epochs=1,
|
||||
lr=1e-4,
|
||||
optimizer='AdamW',
|
||||
warmup_iters=1,
|
||||
topk=(1, ))
|
||||
if args.dataset_name is not None:
|
||||
train_dataset = create_dataset(args.dataset_name, split='train')
|
||||
val_dataset = create_dataset(args.dataset_name, split='validation')
|
||||
else:
|
||||
train_dataset = create_dataset(args.train_dataset_name, split='train')
|
||||
val_dataset = create_dataset(args.val_dataset_name, split='validation')
|
||||
|
||||
kwargs = dict(
|
||||
model=args.model_id, # model id
|
||||
model=args.model, # model id
|
||||
train_dataset=train_dataset, # training dataset
|
||||
eval_dataset=val_dataset, # validation dataset
|
||||
cfg_modify_fn=cfg_modify_fn # callback to modify configuration
|
||||
cfg_modify_fn=args # callback to modify configuration
|
||||
)
|
||||
|
||||
# in distributed training, specify pytorch launcher
|
||||
@@ -82,5 +79,4 @@ def train(parser):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = define_parser()
|
||||
train(parser)
|
||||
train()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=2 \
|
||||
examples/pytorch/finetune_image_classification.py \
|
||||
--num_classes 2 \
|
||||
--train_data 'tany0699/cats_and_dogs' \
|
||||
--validation_data 'tany0699/cats_and_dogs'
|
||||
--train_dataset_name 'tany0699/cats_and_dogs' \
|
||||
--val_dataset_name 'tany0699/cats_and_dogs'
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import EpochBasedTrainer, build_trainer
|
||||
from modelscope.trainers.training_args import TrainingArgs
|
||||
|
||||
|
||||
def get_labels(cfg, metadata):
|
||||
label2id = cfg.safe_get(metadata['cfg_node'])
|
||||
if label2id is not None:
|
||||
return ','.join(label2id.keys())
|
||||
|
||||
|
||||
def set_labels(cfg, labels, metadata):
|
||||
if isinstance(labels, str):
|
||||
labels = labels.split(',')
|
||||
cfg.merge_from_dict(
|
||||
{metadata['cfg_node']: {label: id
|
||||
for id, label in enumerate(labels)}})
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextClassificationArguments(TrainingArgs):
|
||||
|
||||
first_sequence: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help': 'The first sequence key of preprocessor',
|
||||
'cfg_node': 'preprocessor.first_sequence'
|
||||
})
|
||||
|
||||
second_sequence: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help': 'The second sequence key of preprocessor',
|
||||
'cfg_node': 'preprocessor.second_sequence'
|
||||
})
|
||||
|
||||
label: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help': 'The label key of preprocessor',
|
||||
'cfg_node': 'preprocessor.label'
|
||||
})
|
||||
|
||||
labels: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help': 'The labels of the dataset',
|
||||
'cfg_node': 'preprocessor.label2id',
|
||||
'cfg_getter': get_labels,
|
||||
'cfg_setter': set_labels,
|
||||
})
|
||||
|
||||
preprocessor: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help': 'The preprocessor type',
|
||||
'cfg_node': 'preprocessor.type'
|
||||
})
|
||||
|
||||
def __call__(self, config):
|
||||
config = super().__call__(config)
|
||||
config.model['num_labels'] = len(self.labels)
|
||||
if config.train.lr_scheduler.type == 'LinearLR':
|
||||
config.train.lr_scheduler['total_iters'] = \
|
||||
int(len(train_dataset) / self.per_device_train_batch_size) * self.max_epochs
|
||||
return config
|
||||
|
||||
|
||||
args = TextClassificationArguments.from_cli(
|
||||
task='text-classification', eval_metrics='seq-cls-metric')
|
||||
|
||||
print(args)
|
||||
|
||||
dataset = MsDataset.load(args.dataset_name, subset_name=args.subset_name)
|
||||
train_dataset = dataset['train']
|
||||
validation_dataset = dataset['validation']
|
||||
|
||||
kwargs = dict(
|
||||
model=args.model,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=validation_dataset,
|
||||
seed=args.seed,
|
||||
cfg_modify_fn=args)
|
||||
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
trainer: EpochBasedTrainer = build_trainer(name='trainer', default_args=kwargs)
|
||||
trainer.train()
|
||||
12
examples/pytorch/text_classification/run_train.sh
Normal file
12
examples/pytorch/text_classification/run_train.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
PYTHONPATH=. python examples/pytorch/text_classification/finetune_text_classification.py \
|
||||
--model 'damo/nlp_structbert_backbone_base_std' \
|
||||
--dataset_name 'clue' \
|
||||
--subset_name 'tnews' \
|
||||
--first_sequence 'sentence' \
|
||||
--preprocessor.label label \
|
||||
--model.num_labels 15 \
|
||||
--labels '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14' \
|
||||
--preprocessor 'sen-cls-tokenizer' \
|
||||
--train.dataloader.workers_per_gpu 0 \
|
||||
--evaluation.dataloader.workers_per_gpu 0 \
|
||||
--train.optimizer.lr 1e-5 \
|
||||
1
examples/pytorch/transformers/configuration.json
Normal file
1
examples/pytorch/transformers/configuration.json
Normal file
@@ -0,0 +1 @@
|
||||
{"framework":"pytorch","train":{"work_dir":"/tmp","max_epochs":10,"dataloader":{"batch_size_per_gpu":16,"workers_per_gpu":0},"optimizer":{"type":"SGD","lr":0.001},"lr_scheduler":{"type":"StepLR","step_size":2},"hooks":[{"type":"CheckpointHook","interval":1}]},"evaluation":{"dataloader":{"batch_size_per_gpu":16,"workers_per_gpu":0,"shuffle":false}}}
|
||||
57
examples/pytorch/transformers/finetune_transformers_model.py
Normal file
57
examples/pytorch/transformers/finetune_transformers_model.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import (BertForSequenceClassification, BertTokenizerFast,
|
||||
default_data_collator)
|
||||
|
||||
from modelscope.trainers import EpochBasedTrainer, build_trainer
|
||||
from modelscope.trainers.default_config import DEFAULT_CONFIG, TrainingArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformersArguments(TrainingArgs):
|
||||
|
||||
num_labels: int = field(
|
||||
default=None, metadata={
|
||||
'help': 'The number of labels',
|
||||
})
|
||||
|
||||
|
||||
args = TransformersArguments.from_cli(
|
||||
task='text-classification', eval_metrics='seq-cls-metric')
|
||||
|
||||
print(args)
|
||||
|
||||
dataset = load_dataset(args.dataset_name, args.subset_name)
|
||||
|
||||
model = BertForSequenceClassification.from_pretrained(
|
||||
args.model, num_labels=args.num_labels)
|
||||
tokenizer = BertTokenizerFast.from_pretrained(args.model)
|
||||
|
||||
|
||||
def tokenize_sentence(row):
|
||||
return tokenizer(row['sentence'], padding='max_length', max_length=128)
|
||||
|
||||
|
||||
# Extra columns, Rename columns
|
||||
dataset = dataset.map(tokenize_sentence).remove_columns(['sentence',
|
||||
'idx']).rename_column(
|
||||
'label', 'labels')
|
||||
|
||||
cfg_file = os.path.join(args.work_dir or './', 'configuration.json')
|
||||
DEFAULT_CONFIG.dump(cfg_file)
|
||||
|
||||
kwargs = dict(
|
||||
model=model,
|
||||
cfg_file=cfg_file,
|
||||
# data_collator
|
||||
data_collator=default_data_collator,
|
||||
train_dataset=dataset['train'],
|
||||
eval_dataset=dataset['validation'],
|
||||
seed=args.seed,
|
||||
cfg_modify_fn=args)
|
||||
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
trainer: EpochBasedTrainer = build_trainer(name='trainer', default_args=kwargs)
|
||||
trainer.train()
|
||||
5
examples/pytorch/transformers/run_train.sh
Normal file
5
examples/pytorch/transformers/run_train.sh
Normal file
@@ -0,0 +1,5 @@
|
||||
PYTHONPATH=. python examples/pytorch/transformers/finetune_transformers_model.py \
|
||||
--model bert-base-uncased \
|
||||
--num_labels 15 \
|
||||
--dataset_name clue \
|
||||
--subset_name tnews
|
||||
@@ -1,5 +1,12 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.utils.import_utils import is_tf_available, is_torch_available
|
||||
from .base import Exporter
|
||||
from .builder import build_exporter
|
||||
from .nlp import SbertForSequenceClassificationExporter
|
||||
from .tf_model_exporter import TfModelExporter
|
||||
from .torch_model_exporter import TorchModelExporter
|
||||
|
||||
if is_tf_available():
|
||||
from .nlp import CsanmtForTranslationExporter
|
||||
from .tf_model_exporter import TfModelExporter
|
||||
if is_torch_available():
|
||||
from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter
|
||||
from .torch_model_exporter import TorchModelExporter
|
||||
|
||||
@@ -6,9 +6,11 @@ from typing import Dict, Union
|
||||
from modelscope.models import Model
|
||||
from modelscope.utils.config import Config, ConfigDict
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.hub import snapshot_download
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .builder import build_exporter
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Exporter(ABC):
|
||||
"""Exporter base class to output model to onnx, torch_script, graphdef, etc.
|
||||
@@ -46,7 +48,12 @@ class Exporter(ABC):
|
||||
if hasattr(cfg, 'export'):
|
||||
export_cfg.update(cfg.export)
|
||||
export_cfg['model'] = model
|
||||
exporter = build_exporter(export_cfg, task_name, kwargs)
|
||||
try:
|
||||
exporter = build_exporter(export_cfg, task_name, kwargs)
|
||||
except KeyError as e:
|
||||
raise KeyError(
|
||||
f'The exporting of model \'{model_cfg.type}\' with task: \'{task_name}\' '
|
||||
f'is not supported currently.') from e
|
||||
return exporter
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -1,2 +1,11 @@
|
||||
from .sbert_for_sequence_classification_exporter import \
|
||||
SbertForSequenceClassificationExporter
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.utils.import_utils import is_tf_available, is_torch_available
|
||||
|
||||
if is_tf_available():
|
||||
from .csanmt_for_translation_exporter import CsanmtForTranslationExporter
|
||||
if is_torch_available():
|
||||
from .sbert_for_sequence_classification_exporter import \
|
||||
SbertForSequenceClassificationExporter
|
||||
from .sbert_for_zero_shot_classification_exporter import \
|
||||
SbertForZeroShotClassificationExporter
|
||||
|
||||
185
modelscope/exporters/nlp/csanmt_for_translation_exporter.py
Normal file
185
modelscope/exporters/nlp/csanmt_for_translation_exporter.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.tools import freeze_graph
|
||||
|
||||
from modelscope.exporters.builder import EXPORTERS
|
||||
from modelscope.exporters.tf_model_exporter import TfModelExporter
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.pipelines.nlp.translation_pipeline import TranslationPipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.test_utils import compare_arguments_nested
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if tf.__version__ >= '2.0':
|
||||
tf = tf.compat.v1
|
||||
tf.disable_eager_execution()
|
||||
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
|
||||
@EXPORTERS.register_module(Tasks.translation, module_name=Models.translation)
|
||||
class CsanmtForTranslationExporter(TfModelExporter):
|
||||
|
||||
def __init__(self, model=None):
|
||||
super().__init__(model)
|
||||
self.pipeline = TranslationPipeline(self.model)
|
||||
|
||||
def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
|
||||
return_dict = self.pipeline.preprocess(
|
||||
"Alibaba Group's mission is to let the world have no difficult business"
|
||||
)
|
||||
return {'input_wids': return_dict['input_ids']}
|
||||
|
||||
def export_saved_model(self, output_dir, rtol=None, atol=None, **kwargs):
|
||||
|
||||
def _generate_signature():
|
||||
receiver_tensors = {
|
||||
'input_wids':
|
||||
tf.saved_model.utils.build_tensor_info(
|
||||
self.pipeline.input_wids)
|
||||
}
|
||||
export_outputs = {
|
||||
'output_seqs':
|
||||
tf.saved_model.utils.build_tensor_info(
|
||||
self.pipeline.output['output_seqs'])
|
||||
}
|
||||
|
||||
signature_def = tf.saved_model.signature_def_utils.build_signature_def(
|
||||
receiver_tensors, export_outputs,
|
||||
tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
|
||||
|
||||
return {'translation_signature': signature_def}
|
||||
|
||||
with self.pipeline._session.as_default() as sess:
|
||||
builder = tf.saved_model.builder.SavedModelBuilder(output_dir)
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, [tag_constants.SERVING],
|
||||
signature_def_map=_generate_signature(),
|
||||
assets_collection=ops.get_collection(
|
||||
ops.GraphKeys.ASSET_FILEPATHS),
|
||||
clear_devices=True)
|
||||
builder.save()
|
||||
|
||||
dummy_inputs = self.generate_dummy_inputs()
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
# Restore model from the saved_modle file, that is exported by TensorFlow estimator.
|
||||
MetaGraphDef = tf.saved_model.loader.load(sess, ['serve'],
|
||||
output_dir)
|
||||
|
||||
# SignatureDef protobuf
|
||||
SignatureDef_map = MetaGraphDef.signature_def
|
||||
SignatureDef = SignatureDef_map['translation_signature']
|
||||
# TensorInfo protobuf
|
||||
X_TensorInfo = SignatureDef.inputs['input_wids']
|
||||
y_TensorInfo = SignatureDef.outputs['output_seqs']
|
||||
X = tf.saved_model.utils.get_tensor_from_tensor_info(
|
||||
X_TensorInfo, sess.graph)
|
||||
y = tf.saved_model.utils.get_tensor_from_tensor_info(
|
||||
y_TensorInfo, sess.graph)
|
||||
outputs = sess.run(y, feed_dict={X: dummy_inputs['input_wids']})
|
||||
trans_result = self.pipeline.postprocess({'output_seqs': outputs})
|
||||
logger.info(trans_result)
|
||||
|
||||
outputs_origin = self.pipeline.forward(
|
||||
{'input_ids': dummy_inputs['input_wids']})
|
||||
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
if not compare_arguments_nested('Output match failed', outputs,
|
||||
outputs_origin['output_seqs'], **tols):
|
||||
raise RuntimeError(
|
||||
'Export saved model failed because of validation error.')
|
||||
|
||||
return {'model': output_dir}
|
||||
|
||||
def export_frozen_graph_def(self,
|
||||
output_dir: str,
|
||||
rtol=None,
|
||||
atol=None,
|
||||
**kwargs):
|
||||
input_saver_def = self.pipeline.model_loader.as_saver_def()
|
||||
inference_graph_def = tf.get_default_graph().as_graph_def()
|
||||
for node in inference_graph_def.node:
|
||||
node.device = ''
|
||||
|
||||
frozen_dir = os.path.join(output_dir, 'frozen')
|
||||
tf.gfile.MkDir(frozen_dir)
|
||||
frozen_graph_path = os.path.join(frozen_dir,
|
||||
'frozen_inference_graph.pb')
|
||||
|
||||
outputs = {
|
||||
'output_trans_result':
|
||||
tf.identity(
|
||||
self.pipeline.output['output_seqs'],
|
||||
name='NmtModel/output_trans_result')
|
||||
}
|
||||
|
||||
for output_key in outputs:
|
||||
tf.add_to_collection('inference_op', outputs[output_key])
|
||||
|
||||
output_node_names = ','.join([
|
||||
'%s/%s' % ('NmtModel', output_key)
|
||||
for output_key in outputs.keys()
|
||||
])
|
||||
print(output_node_names)
|
||||
_ = freeze_graph.freeze_graph_with_def_protos(
|
||||
input_graph_def=tf.get_default_graph().as_graph_def(),
|
||||
input_saver_def=input_saver_def,
|
||||
input_checkpoint=self.pipeline.model_path,
|
||||
output_node_names=output_node_names,
|
||||
restore_op_name='save/restore_all',
|
||||
filename_tensor_name='save/Const:0',
|
||||
output_graph=frozen_graph_path,
|
||||
clear_devices=True,
|
||||
initializer_nodes='')
|
||||
|
||||
# 5. test frozen.pb
|
||||
dummy_inputs = self.generate_dummy_inputs()
|
||||
with self.pipeline._session.as_default() as sess:
|
||||
sess.run(tf.tables_initializer())
|
||||
|
||||
graph = tf.Graph()
|
||||
with tf.gfile.GFile(frozen_graph_path, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
|
||||
with graph.as_default():
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
graph.finalize()
|
||||
|
||||
with tf.Session(graph=graph) as trans_sess:
|
||||
outputs = trans_sess.run(
|
||||
'NmtModel/strided_slice_9:0',
|
||||
feed_dict={'input_wids:0': dummy_inputs['input_wids']})
|
||||
trans_result = self.pipeline.postprocess(
|
||||
{'output_seqs': outputs})
|
||||
logger.info(trans_result)
|
||||
|
||||
outputs_origin = self.pipeline.forward(
|
||||
{'input_ids': dummy_inputs['input_wids']})
|
||||
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
if not compare_arguments_nested('Output match failed', outputs,
|
||||
outputs_origin['output_seqs'], **tols):
|
||||
raise RuntimeError(
|
||||
'Export frozen graphdef failed because of validation error.')
|
||||
|
||||
return {'model': frozen_graph_path}
|
||||
|
||||
def export_onnx(self, output_dir: str, opset=13, **kwargs):
|
||||
raise NotImplementedError(
|
||||
'csanmt model does not support onnx format, consider using savedmodel instead.'
|
||||
)
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Mapping, Tuple
|
||||
|
||||
@@ -7,9 +6,7 @@ from torch.utils.data.dataloader import default_collate
|
||||
from modelscope.exporters.builder import EXPORTERS
|
||||
from modelscope.exporters.torch_model_exporter import TorchModelExporter
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.preprocessors import (
|
||||
Preprocessor, TextClassificationTransformersPreprocessor,
|
||||
build_preprocessor)
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.utils.constant import ModeKeys, Tasks
|
||||
|
||||
|
||||
@@ -17,8 +14,6 @@ from modelscope.utils.constant import ModeKeys, Tasks
|
||||
@EXPORTERS.register_module(
|
||||
Tasks.text_classification, module_name=Models.structbert)
|
||||
@EXPORTERS.register_module(Tasks.sentence_similarity, module_name=Models.bert)
|
||||
@EXPORTERS.register_module(
|
||||
Tasks.zero_shot_classification, module_name=Models.bert)
|
||||
@EXPORTERS.register_module(
|
||||
Tasks.sentiment_classification, module_name=Models.bert)
|
||||
@EXPORTERS.register_module(Tasks.nli, module_name=Models.bert)
|
||||
@@ -27,8 +22,6 @@ from modelscope.utils.constant import ModeKeys, Tasks
|
||||
@EXPORTERS.register_module(
|
||||
Tasks.sentiment_classification, module_name=Models.structbert)
|
||||
@EXPORTERS.register_module(Tasks.nli, module_name=Models.structbert)
|
||||
@EXPORTERS.register_module(
|
||||
Tasks.zero_shot_classification, module_name=Models.structbert)
|
||||
class SbertForSequenceClassificationExporter(TorchModelExporter):
|
||||
|
||||
def generate_dummy_inputs(self,
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Mapping
|
||||
|
||||
from modelscope.exporters.builder import EXPORTERS
|
||||
from modelscope.exporters.torch_model_exporter import TorchModelExporter
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@EXPORTERS.register_module(
|
||||
Tasks.zero_shot_classification, module_name=Models.bert)
|
||||
@EXPORTERS.register_module(
|
||||
Tasks.zero_shot_classification, module_name=Models.structbert)
|
||||
class SbertForZeroShotClassificationExporter(TorchModelExporter):
|
||||
|
||||
def generate_dummy_inputs(self,
|
||||
candidate_labels,
|
||||
hypothesis_template,
|
||||
max_length=128,
|
||||
pair: bool = False,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
"""Generate dummy inputs for model exportation to onnx or other formats by tracing.
|
||||
|
||||
Args:
|
||||
|
||||
max_length(int): The max length of sentence, default 128.
|
||||
hypothesis_template(str): The template of prompt, like '这篇文章的标题是{}'
|
||||
candidate_labels(List): The labels of prompt,
|
||||
like ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
|
||||
pair(bool, `optional`): Whether to generate sentence pairs or single sentences.
|
||||
|
||||
Returns:
|
||||
Dummy inputs.
|
||||
"""
|
||||
|
||||
assert hasattr(
|
||||
self.model, 'model_dir'
|
||||
), 'model_dir attribute is required to build the preprocessor'
|
||||
preprocessor = Preprocessor.from_pretrained(
|
||||
self.model.model_dir, max_length=max_length)
|
||||
return preprocessor(
|
||||
preprocessor.nlp_tokenizer.tokenizer.unk_token,
|
||||
candidate_labels=candidate_labels,
|
||||
hypothesis_template=hypothesis_template)
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
dynamic_axis = {0: 'batch', 1: 'sequence'}
|
||||
return OrderedDict([
|
||||
('input_ids', dynamic_axis),
|
||||
('attention_mask', dynamic_axis),
|
||||
('token_type_ids', dynamic_axis),
|
||||
])
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
return OrderedDict({'logits': {0: 'batch'}})
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Callable, Dict, Mapping
|
||||
|
||||
import tensorflow as tf
|
||||
@@ -7,7 +8,7 @@ import tensorflow as tf
|
||||
from modelscope.outputs import ModelOutputBase
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.regress_test_utils import compare_arguments_nested
|
||||
from modelscope.utils.test_utils import compare_arguments_nested
|
||||
from .base import Exporter
|
||||
|
||||
logger = get_logger()
|
||||
@@ -29,6 +30,14 @@ class TfModelExporter(Exporter):
|
||||
self._tf2_export_onnx(model, onnx_file, opset=opset, **kwargs)
|
||||
return {'model': onnx_file}
|
||||
|
||||
@abstractmethod
|
||||
def export_saved_model(self, output_dir: str, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def export_frozen_graph_def(self, output_dir: str, **kwargs):
|
||||
pass
|
||||
|
||||
def _tf2_export_onnx(self,
|
||||
model,
|
||||
output: str,
|
||||
@@ -59,56 +68,67 @@ class TfModelExporter(Exporter):
|
||||
onnx.save(onnx_model, output)
|
||||
|
||||
if validation:
|
||||
try:
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
logger.warn(
|
||||
'Cannot validate the exported onnx file, because '
|
||||
'the installation of onnx or onnxruntime cannot be found')
|
||||
return
|
||||
self._validate_model(dummy_inputs, model, output, rtol, atol,
|
||||
call_func)
|
||||
|
||||
def tensor_nested_numpify(tensors):
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(
|
||||
tensor_nested_numpify(t) for t in tensors)
|
||||
if isinstance(tensors, Mapping):
|
||||
# return dict
|
||||
return {
|
||||
k: tensor_nested_numpify(t)
|
||||
for k, t in tensors.items()
|
||||
}
|
||||
if isinstance(tensors, tf.Tensor):
|
||||
t = tensors.cpu()
|
||||
return t.numpy()
|
||||
return tensors
|
||||
def _validate_model(
|
||||
self,
|
||||
dummy_inputs,
|
||||
model,
|
||||
output,
|
||||
rtol: float = None,
|
||||
atol: float = None,
|
||||
call_func: Callable = None,
|
||||
):
|
||||
try:
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
logger.warn(
|
||||
'Cannot validate the exported onnx file, because '
|
||||
'the installation of onnx or onnxruntime cannot be found')
|
||||
return
|
||||
|
||||
onnx_model = onnx.load(output)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
ort_session = ort.InferenceSession(output)
|
||||
outputs_origin = call_func(
|
||||
dummy_inputs) if call_func is not None else model(dummy_inputs)
|
||||
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
|
||||
outputs_origin = list(
|
||||
tensor_nested_numpify(outputs_origin).values())
|
||||
elif isinstance(outputs_origin, (tuple, list)):
|
||||
outputs_origin = list(tensor_nested_numpify(outputs_origin))
|
||||
outputs = ort_session.run(
|
||||
None,
|
||||
tensor_nested_numpify(dummy_inputs),
|
||||
)
|
||||
outputs = tensor_nested_numpify(outputs)
|
||||
if isinstance(outputs, dict):
|
||||
outputs = list(outputs.values())
|
||||
elif isinstance(outputs, tuple):
|
||||
outputs = list(outputs)
|
||||
def tensor_nested_numpify(tensors):
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(tensor_nested_numpify(t) for t in tensors)
|
||||
if isinstance(tensors, Mapping):
|
||||
# return dict
|
||||
return {
|
||||
k: tensor_nested_numpify(t)
|
||||
for k, t in tensors.items()
|
||||
}
|
||||
if isinstance(tensors, tf.Tensor):
|
||||
t = tensors.cpu()
|
||||
return t.numpy()
|
||||
return tensors
|
||||
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
if not compare_arguments_nested('Onnx model output match failed',
|
||||
outputs, outputs_origin, **tols):
|
||||
raise RuntimeError(
|
||||
'export onnx failed because of validation error.')
|
||||
onnx_model = onnx.load(output)
|
||||
onnx.checker.check_model(onnx_model, full_check=True)
|
||||
ort_session = ort.InferenceSession(output)
|
||||
outputs_origin = call_func(
|
||||
dummy_inputs) if call_func is not None else model(dummy_inputs)
|
||||
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
|
||||
outputs_origin = list(
|
||||
tensor_nested_numpify(outputs_origin).values())
|
||||
elif isinstance(outputs_origin, (tuple, list)):
|
||||
outputs_origin = list(tensor_nested_numpify(outputs_origin))
|
||||
outputs = ort_session.run(
|
||||
None,
|
||||
tensor_nested_numpify(dummy_inputs),
|
||||
)
|
||||
outputs = tensor_nested_numpify(outputs)
|
||||
if isinstance(outputs, dict):
|
||||
outputs = list(outputs.values())
|
||||
elif isinstance(outputs, tuple):
|
||||
outputs = list(outputs)
|
||||
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
if not compare_arguments_nested('Onnx model output match failed',
|
||||
outputs, outputs_origin, **tols):
|
||||
raise RuntimeError(
|
||||
'export onnx failed because of validation error.')
|
||||
|
||||
@@ -27,9 +27,9 @@ def load(file, file_format=None, **kwargs):
|
||||
Currently supported formats include "json", "yaml/yml".
|
||||
|
||||
Examples:
|
||||
>>> load('/path/of/your/file') # file is storaged in disk
|
||||
>>> load('https://path/of/your/file') # file is storaged in Internet
|
||||
>>> load('oss://path/of/your/file') # file is storaged in petrel
|
||||
>>> load('/path/of/your/file') # file is stored in disk
|
||||
>>> load('https://path/of/your/file') # file is stored on internet
|
||||
>>> load('oss://path/of/your/file') # file is stored in petrel
|
||||
|
||||
Returns:
|
||||
The content from the file.
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
from typing import Dict, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.constants import (FILE_HASH, MODEL_META_FILE_NAME,
|
||||
MODEL_META_MODEL_ID)
|
||||
from modelscope.hub.constants import FILE_HASH
|
||||
from modelscope.hub.git import GitCommandWrapper
|
||||
from modelscope.hub.utils.caching import FileSystemCache, ModelFileSystemCache
|
||||
from modelscope.hub.utils.caching import ModelFileSystemCache
|
||||
from modelscope.hub.utils.utils import compute_hash
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from modelscope.utils.constant import Fields, Tasks
|
||||
|
||||
|
||||
class Models(object):
|
||||
@@ -7,35 +8,44 @@ class Models(object):
|
||||
Holds the standard model name to use for identifying different model.
|
||||
This should be used to register models.
|
||||
|
||||
Model name should only contain model info but not task info.
|
||||
Model name should only contain model information but not task information.
|
||||
"""
|
||||
# tinynas models
|
||||
tinynas_detection = 'tinynas-detection'
|
||||
tinynas_damoyolo = 'tinynas-damoyolo'
|
||||
|
||||
# vision models
|
||||
detection = 'detection'
|
||||
mask_scoring = 'MaskScoring'
|
||||
image_restoration = 'image-restoration'
|
||||
realtime_object_detection = 'realtime-object-detection'
|
||||
realtime_video_object_detection = 'realtime-video-object-detection'
|
||||
scrfd = 'scrfd'
|
||||
depe = 'depe'
|
||||
classification_model = 'ClassificationModel'
|
||||
easyrobust_model = 'EasyRobustModel'
|
||||
bnext = 'bnext'
|
||||
yolopv2 = 'yolopv2'
|
||||
nafnet = 'nafnet'
|
||||
csrnet = 'csrnet'
|
||||
adaint = 'adaint'
|
||||
deeplpfnet = 'deeplpfnet'
|
||||
rrdb = 'rrdb'
|
||||
cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
|
||||
maskdino_swin = 'maskdino_swin'
|
||||
gpen = 'gpen'
|
||||
product_retrieval_embedding = 'product-retrieval-embedding'
|
||||
body_2d_keypoints = 'body-2d-keypoints'
|
||||
body_3d_keypoints = 'body-3d-keypoints'
|
||||
body_3d_keypoints_hdformer = 'hdformer'
|
||||
crowd_counting = 'HRNetCrowdCounting'
|
||||
face_2d_keypoints = 'face-2d-keypoints'
|
||||
panoptic_segmentation = 'swinL-panoptic-segmentation'
|
||||
r50_panoptic_segmentation = 'r50-panoptic-segmentation'
|
||||
image_reid_person = 'passvitb'
|
||||
image_inpainting = 'FFTInpainting'
|
||||
image_paintbyexample = 'Stablediffusion-Paintbyexample'
|
||||
video_summarization = 'pgl-video-summarization'
|
||||
video_panoptic_segmentation = 'swinb-video-panoptic-segmentation'
|
||||
language_guided_video_summarization = 'clip-it-language-guided-video-summarization'
|
||||
swinL_semantic_segmentation = 'swinL-semantic-segmentation'
|
||||
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
|
||||
@@ -70,6 +80,7 @@ class Models(object):
|
||||
video_human_matting = 'video-human-matting'
|
||||
video_frame_interpolation = 'video-frame-interpolation'
|
||||
video_object_segmentation = 'video-object-segmentation'
|
||||
video_deinterlace = 'video-deinterlace'
|
||||
quadtree_attention_image_matching = 'quadtree-attention-image-matching'
|
||||
vision_middleware = 'vision-middleware'
|
||||
video_stabilization = 'video-stabilization'
|
||||
@@ -78,14 +89,31 @@ class Models(object):
|
||||
image_casmvs_depth_estimation = 'image-casmvs-depth-estimation'
|
||||
vop_retrieval_model = 'vop-retrieval-model'
|
||||
ddcolor = 'ddcolor'
|
||||
image_probing_model = 'image-probing-model'
|
||||
defrcn = 'defrcn'
|
||||
image_face_fusion = 'image-face-fusion'
|
||||
content_check = 'content-check'
|
||||
open_vocabulary_detection_vild = 'open-vocabulary-detection-vild'
|
||||
ecbsr = 'ecbsr'
|
||||
msrresnet_lite = 'msrresnet-lite'
|
||||
object_detection_3d = 'object_detection_3d'
|
||||
ddpm = 'ddpm'
|
||||
ocr_recognition = 'OCRRecognition'
|
||||
image_quality_assessment_mos = 'image-quality-assessment-mos'
|
||||
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
|
||||
m2fp = 'm2fp'
|
||||
nerf_recon_acc = 'nerf-recon-acc'
|
||||
bts_depth_estimation = 'bts-depth-estimation'
|
||||
vision_efficient_tuning = 'vision-efficient-tuning'
|
||||
|
||||
bad_image_detecting = 'bad-image-detecting'
|
||||
|
||||
# EasyCV models
|
||||
yolox = 'YOLOX'
|
||||
segformer = 'Segformer'
|
||||
hand_2d_keypoints = 'HRNet-Hand2D-Keypoints'
|
||||
image_object_detection_auto = 'image-object-detection-auto'
|
||||
dino = 'DINO'
|
||||
|
||||
# nlp models
|
||||
bert = 'bert'
|
||||
@@ -122,6 +150,12 @@ class Models(object):
|
||||
unite = 'unite'
|
||||
megatron_bert = 'megatron-bert'
|
||||
use = 'user-satisfaction-estimation'
|
||||
fid_plug = 'fid-plug'
|
||||
lstm = 'lstm'
|
||||
xlm_roberta = 'xlm-roberta'
|
||||
transformers = 'transformers'
|
||||
plug_mental = 'plug-mental'
|
||||
doc2bot = 'doc2bot'
|
||||
|
||||
# audio models
|
||||
sambert_hifigan = 'sambert-hifigan'
|
||||
@@ -135,6 +169,8 @@ class Models(object):
|
||||
generic_itn = 'generic-itn'
|
||||
generic_punc = 'generic-punc'
|
||||
generic_sv = 'generic-sv'
|
||||
ecapa_tdnn_sv = 'ecapa-tdnn-sv'
|
||||
generic_lm = 'generic-lm'
|
||||
|
||||
# multi-modal models
|
||||
ofa = 'ofa'
|
||||
@@ -162,6 +198,7 @@ class TaskModels(object):
|
||||
fill_mask = 'fill-mask'
|
||||
feature_extraction = 'feature-extraction'
|
||||
text_generation = 'text-generation'
|
||||
text_ranking = 'text-ranking'
|
||||
|
||||
|
||||
class Heads(object):
|
||||
@@ -179,6 +216,11 @@ class Heads(object):
|
||||
information_extraction = 'information-extraction'
|
||||
# text gen
|
||||
text_generation = 'text-generation'
|
||||
# text ranking
|
||||
text_ranking = 'text-ranking'
|
||||
# crf
|
||||
lstm_crf = 'lstm-crf'
|
||||
transformer_crf = 'transformer-crf'
|
||||
|
||||
|
||||
class Pipelines(object):
|
||||
@@ -193,6 +235,7 @@ class Pipelines(object):
|
||||
"""
|
||||
# vision tasks
|
||||
portrait_matting = 'unet-image-matting'
|
||||
universal_matting = 'unet-universal-matting'
|
||||
image_denoise = 'nafnet-image-denoise'
|
||||
image_deblur = 'nafnet-image-deblur'
|
||||
person_image_cartoon = 'unet-person-image-cartoon'
|
||||
@@ -209,16 +252,19 @@ class Pipelines(object):
|
||||
hand_2d_keypoints = 'hrnetv2w18_hand-2d-keypoints_image'
|
||||
human_detection = 'resnet18-human-detection'
|
||||
object_detection = 'vit-object-detection'
|
||||
abnormal_object_detection = 'abnormal-object-detection'
|
||||
easycv_detection = 'easycv-detection'
|
||||
easycv_segmentation = 'easycv-segmentation'
|
||||
face_2d_keypoints = 'mobilenet_face-2d-keypoints_alignment'
|
||||
salient_detection = 'u2net-salient-detection'
|
||||
salient_boudary_detection = 'res2net-salient-detection'
|
||||
camouflaged_detection = 'res2net-camouflaged-detection'
|
||||
image_demoire = 'uhdm-image-demoireing'
|
||||
image_classification = 'image-classification'
|
||||
face_detection = 'resnet-face-detection-scrfd10gkps'
|
||||
face_liveness_ir = 'manual-face-liveness-flir'
|
||||
face_liveness_rgb = 'manual-face-liveness-flir'
|
||||
face_liveness_xc = 'manual-face-liveness-flxc'
|
||||
card_detection = 'resnet-card-detection-scrfd34gkps'
|
||||
ulfd_face_detection = 'manual-face-detection-ulfd'
|
||||
tinymog_face_detection = 'manual-face-detection-tinymog'
|
||||
@@ -234,20 +280,28 @@ class Pipelines(object):
|
||||
nextvit_small_daily_image_classification = 'nextvit-small_image-classification_Dailylife-labels'
|
||||
convnext_base_image_classification_garbage = 'convnext-base_image-classification_garbage'
|
||||
bnext_small_image_classification = 'bnext-small_image-classification_ImageNet-labels'
|
||||
yolopv2_image_driving_percetion_bdd100k = 'yolopv2_image-driving-percetion_bdd100k'
|
||||
common_image_classification = 'common-image-classification'
|
||||
image_color_enhance = 'csrnet-image-color-enhance'
|
||||
adaint_image_color_enhance = 'adaint-image-color-enhance'
|
||||
deeplpf_image_color_enhance = 'deeplpf-image-color-enhance'
|
||||
virtual_try_on = 'virtual-try-on'
|
||||
image_colorization = 'unet-image-colorization'
|
||||
image_style_transfer = 'AAMS-style-transfer'
|
||||
image_super_resolution = 'rrdb-image-super-resolution'
|
||||
image_debanding = 'rrdb-image-debanding'
|
||||
face_image_generation = 'gan-face-image-generation'
|
||||
product_retrieval_embedding = 'resnet50-product-retrieval-embedding'
|
||||
realtime_object_detection = 'cspnet_realtime-object-detection_yolox'
|
||||
realtime_video_object_detection = 'cspnet_realtime-video-object-detection_streamyolo'
|
||||
face_recognition = 'ir101-face-recognition-cfglint'
|
||||
face_recognition_ood = 'ir-face-recognition-ood-rts'
|
||||
face_quality_assessment = 'manual-face-quality-assessment-fqa'
|
||||
face_recognition_ood = 'ir-face-recognition-rts'
|
||||
face_recognition_onnx_ir = 'manual-face-recognition-frir'
|
||||
face_recognition_onnx_fm = 'manual-face-recognition-frfm'
|
||||
arc_face_recognition = 'ir50-face-recognition-arcface'
|
||||
mask_face_recognition = 'resnet-face-recognition-facemask'
|
||||
content_check = 'resnet50-image-classification-cc'
|
||||
image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
|
||||
maskdino_instance_segmentation = 'maskdino-swin-image-instance-segmentation'
|
||||
image2image_translation = 'image-to-image-translation'
|
||||
@@ -259,6 +313,7 @@ class Pipelines(object):
|
||||
image_object_detection_auto = 'yolox_image-object-detection-auto'
|
||||
hand_detection = 'yolox-pai_hand-detection'
|
||||
skin_retouching = 'unet-skin-retouching'
|
||||
face_reconstruction = 'resnet50-face-reconstruction'
|
||||
tinynas_classification = 'tinynas-classification'
|
||||
easyrobust_classification = 'easyrobust-classification'
|
||||
tinynas_detection = 'tinynas-detection'
|
||||
@@ -277,6 +332,8 @@ class Pipelines(object):
|
||||
panorama_depth_estimation = 'panorama-depth-estimation'
|
||||
image_reid_person = 'passvitb-image-reid-person'
|
||||
image_inpainting = 'fft-inpainting'
|
||||
image_paintbyexample = 'stablediffusion-paintbyexample'
|
||||
image_inpainting_sdv2 = 'image-inpainting-sdv2'
|
||||
text_driven_segmentation = 'text-driven-segmentation'
|
||||
movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'
|
||||
shop_segmentation = 'shop-segmentation'
|
||||
@@ -294,15 +351,31 @@ class Pipelines(object):
|
||||
vision_middleware_multi_task = 'vision-middleware-multi-task'
|
||||
video_frame_interpolation = 'video-frame-interpolation'
|
||||
video_object_segmentation = 'video-object-segmentation'
|
||||
video_deinterlace = 'video-deinterlace'
|
||||
image_matching = 'image-matching'
|
||||
video_stabilization = 'video-stabilization'
|
||||
video_super_resolution = 'realbasicvsr-video-super-resolution'
|
||||
pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'
|
||||
image_multi_view_depth_estimation = 'image-multi-view-depth-estimation'
|
||||
video_panoptic_segmentation = 'video-panoptic-segmentation'
|
||||
vop_retrieval = 'vop-video-text-retrieval'
|
||||
ddcolor_image_colorization = 'ddcolor-image-colorization'
|
||||
image_structured_model_probing = 'image-structured-model-probing'
|
||||
image_fewshot_detection = 'image-fewshot-detection'
|
||||
image_face_fusion = 'image-face-fusion'
|
||||
open_vocabulary_detection_vild = 'open-vocabulary-detection-vild'
|
||||
ddpm_image_semantic_segmentation = 'ddpm-image-semantic-segmentation'
|
||||
video_colorization = 'video-colorization'
|
||||
motion_generattion = 'mdm-motion-generation'
|
||||
mobile_image_super_resolution = 'mobile-image-super-resolution'
|
||||
image_human_parsing = 'm2fp-image-human-parsing'
|
||||
object_detection_3d_depe = 'object-detection-3d-depe'
|
||||
nerf_recon_acc = 'nerf-recon-acc'
|
||||
bad_image_detecting = 'bad-image-detecting'
|
||||
|
||||
image_quality_assessment_mos = 'image-quality-assessment-mos'
|
||||
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
|
||||
vision_efficient_tuning = 'vision-efficient-tuning'
|
||||
|
||||
# nlp tasks
|
||||
automatic_post_editing = 'automatic-post-editing'
|
||||
@@ -317,6 +390,7 @@ class Pipelines(object):
|
||||
named_entity_recognition_thai = 'named-entity-recognition-thai'
|
||||
named_entity_recognition_viet = 'named-entity-recognition-viet'
|
||||
text_generation = 'text-generation'
|
||||
fid_dialogue = 'fid-dialogue'
|
||||
text2text_generation = 'text2text-generation'
|
||||
sentiment_analysis = 'sentiment-analysis'
|
||||
sentiment_classification = 'sentiment-classification'
|
||||
@@ -324,6 +398,7 @@ class Pipelines(object):
|
||||
fill_mask = 'fill-mask'
|
||||
fill_mask_ponet = 'fill-mask-ponet'
|
||||
csanmt_translation = 'csanmt-translation'
|
||||
interactive_translation = 'interactive-translation'
|
||||
nli = 'nli'
|
||||
dialog_intent_prediction = 'dialog-intent-prediction'
|
||||
dialog_modeling = 'dialog-modeling'
|
||||
@@ -352,6 +427,10 @@ class Pipelines(object):
|
||||
token_classification = 'token-classification'
|
||||
translation_evaluation = 'translation-evaluation'
|
||||
user_satisfaction_estimation = 'user-satisfaction-estimation'
|
||||
siamese_uie = 'siamese-uie'
|
||||
document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval'
|
||||
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
|
||||
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
|
||||
|
||||
# audio tasks
|
||||
sambert_hifigan_tts = 'sambert-hifigan-tts'
|
||||
@@ -365,6 +444,9 @@ class Pipelines(object):
|
||||
itn_inference = 'itn-inference'
|
||||
punc_inference = 'punc-inference'
|
||||
sv_inference = 'sv-inference'
|
||||
vad_inference = 'vad-inference'
|
||||
speaker_verification = 'speaker-verification'
|
||||
lm_inference = 'language-model'
|
||||
|
||||
# multi-modal tasks
|
||||
image_captioning = 'image-captioning'
|
||||
@@ -386,31 +468,322 @@ class Pipelines(object):
|
||||
diffusers_stable_diffusion = 'diffusers-stable-diffusion'
|
||||
document_vl_embedding = 'document-vl-embedding'
|
||||
chinese_stable_diffusion = 'chinese-stable-diffusion'
|
||||
gridvlp_multi_modal_classification = 'gridvlp-multi-modal-classification'
|
||||
gridvlp_multi_modal_embedding = 'gridvlp-multi-modal-embedding'
|
||||
|
||||
# science tasks
|
||||
protein_structure = 'unifold-protein-structure'
|
||||
|
||||
|
||||
class Trainers(object):
|
||||
""" Names for different trainer.
|
||||
DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
# TaskName: (pipeline_module_name, model_repo)
|
||||
Tasks.sentence_embedding:
|
||||
(Pipelines.sentence_embedding,
|
||||
'damo/nlp_corom_sentence-embedding_english-base'),
|
||||
Tasks.text_ranking: (Pipelines.mgeo_ranking,
|
||||
'damo/mgeo_address_ranking_chinese_base'),
|
||||
Tasks.text_ranking: (Pipelines.text_ranking,
|
||||
'damo/nlp_corom_passage-ranking_english-base'),
|
||||
Tasks.word_segmentation:
|
||||
(Pipelines.word_segmentation,
|
||||
'damo/nlp_structbert_word-segmentation_chinese-base'),
|
||||
Tasks.part_of_speech: (Pipelines.part_of_speech,
|
||||
'damo/nlp_structbert_part-of-speech_chinese-base'),
|
||||
Tasks.token_classification:
|
||||
(Pipelines.part_of_speech,
|
||||
'damo/nlp_structbert_part-of-speech_chinese-base'),
|
||||
Tasks.named_entity_recognition:
|
||||
(Pipelines.named_entity_recognition,
|
||||
'damo/nlp_raner_named-entity-recognition_chinese-base-news'),
|
||||
Tasks.relation_extraction:
|
||||
(Pipelines.relation_extraction,
|
||||
'damo/nlp_bert_relation-extraction_chinese-base'),
|
||||
Tasks.information_extraction:
|
||||
(Pipelines.relation_extraction,
|
||||
'damo/nlp_bert_relation-extraction_chinese-base'),
|
||||
Tasks.sentence_similarity:
|
||||
(Pipelines.sentence_similarity,
|
||||
'damo/nlp_structbert_sentence-similarity_chinese-base'),
|
||||
Tasks.translation: (Pipelines.csanmt_translation,
|
||||
'damo/nlp_csanmt_translation_zh2en'),
|
||||
Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'),
|
||||
Tasks.sentiment_classification:
|
||||
(Pipelines.sentiment_classification,
|
||||
'damo/nlp_structbert_sentiment-classification_chinese-base'
|
||||
), # TODO: revise back after passing the pr
|
||||
Tasks.portrait_matting: (Pipelines.portrait_matting,
|
||||
'damo/cv_unet_image-matting'),
|
||||
Tasks.universal_matting: (Pipelines.universal_matting,
|
||||
'damo/cv_unet_universal-matting'),
|
||||
Tasks.human_detection: (Pipelines.human_detection,
|
||||
'damo/cv_resnet18_human-detection'),
|
||||
Tasks.image_object_detection: (Pipelines.object_detection,
|
||||
'damo/cv_vit_object-detection_coco'),
|
||||
Tasks.image_denoising: (Pipelines.image_denoise,
|
||||
'damo/cv_nafnet_image-denoise_sidd'),
|
||||
Tasks.image_deblurring: (Pipelines.image_deblur,
|
||||
'damo/cv_nafnet_image-deblur_gopro'),
|
||||
Tasks.video_stabilization: (Pipelines.video_stabilization,
|
||||
'damo/cv_dut-raft_video-stabilization_base'),
|
||||
Tasks.video_super_resolution:
|
||||
(Pipelines.video_super_resolution,
|
||||
'damo/cv_realbasicvsr_video-super-resolution_videolq'),
|
||||
Tasks.text_classification:
|
||||
(Pipelines.sentiment_classification,
|
||||
'damo/nlp_structbert_sentiment-classification_chinese-base'),
|
||||
Tasks.text_generation: (Pipelines.text_generation,
|
||||
'damo/nlp_palm2.0_text-generation_chinese-base'),
|
||||
Tasks.zero_shot_classification:
|
||||
(Pipelines.zero_shot_classification,
|
||||
'damo/nlp_structbert_zero-shot-classification_chinese-base'),
|
||||
Tasks.task_oriented_conversation: (Pipelines.dialog_modeling,
|
||||
'damo/nlp_space_dialog-modeling'),
|
||||
Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
|
||||
'damo/nlp_space_dialog-state-tracking'),
|
||||
Tasks.table_question_answering:
|
||||
(Pipelines.table_question_answering_pipeline,
|
||||
'damo/nlp-convai-text2sql-pretrain-cn'),
|
||||
Tasks.document_grounded_dialog_generate:
|
||||
(Pipelines.document_grounded_dialog_generate,
|
||||
'DAMO_ConvAI/nlp_convai_generation_pretrain'),
|
||||
Tasks.document_grounded_dialog_rerank:
|
||||
(Pipelines.document_grounded_dialog_rerank,
|
||||
'damo/nlp_convai_rerank_pretrain'),
|
||||
Tasks.document_grounded_dialog_retrieval:
|
||||
(Pipelines.document_grounded_dialog_retrieval,
|
||||
'DAMO_ConvAI/nlp_convai_retrieval_pretrain'),
|
||||
Tasks.text_error_correction:
|
||||
(Pipelines.text_error_correction,
|
||||
'damo/nlp_bart_text-error-correction_chinese'),
|
||||
Tasks.image_captioning: (Pipelines.image_captioning,
|
||||
'damo/ofa_image-caption_coco_large_en'),
|
||||
Tasks.video_captioning:
|
||||
(Pipelines.video_captioning,
|
||||
'damo/multi-modal_hitea_video-captioning_base_en'),
|
||||
Tasks.image_portrait_stylization:
|
||||
(Pipelines.person_image_cartoon,
|
||||
'damo/cv_unet_person-image-cartoon_compound-models'),
|
||||
Tasks.ocr_detection: (Pipelines.ocr_detection,
|
||||
'damo/cv_resnet18_ocr-detection-line-level_damo'),
|
||||
Tasks.table_recognition:
|
||||
(Pipelines.table_recognition,
|
||||
'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
|
||||
Tasks.document_vl_embedding:
|
||||
(Pipelines.document_vl_embedding,
|
||||
'damo/multi-modal_convnext-roberta-base_vldoc-embedding'),
|
||||
Tasks.license_plate_detection:
|
||||
(Pipelines.license_plate_detection,
|
||||
'damo/cv_resnet18_license-plate-detection_damo'),
|
||||
Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
|
||||
Tasks.feature_extraction: (Pipelines.feature_extraction,
|
||||
'damo/pert_feature-extraction_base-test'),
|
||||
Tasks.action_recognition: (Pipelines.action_recognition,
|
||||
'damo/cv_TAdaConv_action-recognition'),
|
||||
Tasks.action_detection: (Pipelines.action_detection,
|
||||
'damo/cv_ResNetC3D_action-detection_detection2d'),
|
||||
Tasks.live_category: (Pipelines.live_category,
|
||||
'damo/cv_resnet50_live-category'),
|
||||
Tasks.video_category: (Pipelines.video_category,
|
||||
'damo/cv_resnet50_video-category'),
|
||||
Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding,
|
||||
'damo/multi-modal_clip-vit-base-patch16_zh'),
|
||||
Tasks.generative_multi_modal_embedding:
|
||||
(Pipelines.generative_multi_modal_embedding,
|
||||
'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'
|
||||
),
|
||||
Tasks.multi_modal_similarity:
|
||||
(Pipelines.multi_modal_similarity,
|
||||
'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'),
|
||||
Tasks.visual_question_answering:
|
||||
(Pipelines.visual_question_answering,
|
||||
'damo/mplug_visual-question-answering_coco_large_en'),
|
||||
Tasks.video_question_answering:
|
||||
(Pipelines.video_question_answering,
|
||||
'damo/multi-modal_hitea_video-question-answering_base_en'),
|
||||
Tasks.video_embedding: (Pipelines.cmdssl_video_embedding,
|
||||
'damo/cv_r2p1d_video_embedding'),
|
||||
Tasks.text_to_image_synthesis:
|
||||
(Pipelines.text_to_image_synthesis,
|
||||
'damo/cv_diffusion_text-to-image-synthesis_tiny'),
|
||||
Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints,
|
||||
'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
|
||||
Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints,
|
||||
'damo/cv_canonical_body-3d-keypoints_video'),
|
||||
Tasks.hand_2d_keypoints:
|
||||
(Pipelines.hand_2d_keypoints,
|
||||
'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'),
|
||||
Tasks.card_detection: (Pipelines.card_detection,
|
||||
'damo/cv_resnet_carddetection_scrfd34gkps'),
|
||||
Tasks.content_check: (Pipelines.content_check,
|
||||
'damo/cv_resnet50_content-check_cc'),
|
||||
Tasks.face_detection:
|
||||
(Pipelines.mog_face_detection,
|
||||
'damo/cv_resnet101_face-detection_cvpr22papermogface'),
|
||||
Tasks.face_liveness: (Pipelines.face_liveness_ir,
|
||||
'damo/cv_manual_face-liveness_flir'),
|
||||
Tasks.face_recognition: (Pipelines.face_recognition,
|
||||
'damo/cv_ir101_facerecognition_cfglint'),
|
||||
Tasks.facial_expression_recognition:
|
||||
(Pipelines.facial_expression_recognition,
|
||||
'damo/cv_vgg19_facial-expression-recognition_fer'),
|
||||
Tasks.face_attribute_recognition:
|
||||
(Pipelines.face_attribute_recognition,
|
||||
'damo/cv_resnet34_face-attribute-recognition_fairface'),
|
||||
Tasks.face_2d_keypoints: (Pipelines.face_2d_keypoints,
|
||||
'damo/cv_mobilenet_face-2d-keypoints_alignment'),
|
||||
Tasks.face_quality_assessment:
|
||||
(Pipelines.face_quality_assessment,
|
||||
'damo/cv_manual_face-quality-assessment_fqa'),
|
||||
Tasks.video_multi_modal_embedding:
|
||||
(Pipelines.video_multi_modal_embedding,
|
||||
'damo/multi_modal_clip_vtretrival_msrvtt_53'),
|
||||
Tasks.image_color_enhancement:
|
||||
(Pipelines.image_color_enhance,
|
||||
'damo/cv_csrnet_image-color-enhance-models'),
|
||||
Tasks.virtual_try_on: (Pipelines.virtual_try_on,
|
||||
'damo/cv_daflow_virtual-try-on_base'),
|
||||
Tasks.image_colorization: (Pipelines.ddcolor_image_colorization,
|
||||
'damo/cv_ddcolor_image-colorization'),
|
||||
Tasks.video_colorization: (Pipelines.video_colorization,
|
||||
'damo/cv_unet_video-colorization'),
|
||||
Tasks.image_segmentation:
|
||||
(Pipelines.image_instance_segmentation,
|
||||
'damo/cv_swin-b_image-instance-segmentation_coco'),
|
||||
Tasks.image_driving_perception:
|
||||
(Pipelines.yolopv2_image_driving_percetion_bdd100k,
|
||||
'damo/cv_yolopv2_image-driving-perception_bdd100k'),
|
||||
Tasks.image_depth_estimation:
|
||||
(Pipelines.image_depth_estimation,
|
||||
'damo/cv_newcrfs_image-depth-estimation_indoor'),
|
||||
Tasks.indoor_layout_estimation:
|
||||
(Pipelines.indoor_layout_estimation,
|
||||
'damo/cv_panovit_indoor-layout-estimation'),
|
||||
Tasks.video_depth_estimation:
|
||||
(Pipelines.video_depth_estimation,
|
||||
'damo/cv_dro-resnet18_video-depth-estimation_indoor'),
|
||||
Tasks.panorama_depth_estimation:
|
||||
(Pipelines.panorama_depth_estimation,
|
||||
'damo/cv_unifuse_panorama-depth-estimation'),
|
||||
Tasks.image_style_transfer: (Pipelines.image_style_transfer,
|
||||
'damo/cv_aams_style-transfer_damo'),
|
||||
Tasks.face_image_generation: (Pipelines.face_image_generation,
|
||||
'damo/cv_gan_face-image-generation'),
|
||||
Tasks.image_super_resolution: (Pipelines.image_super_resolution,
|
||||
'damo/cv_rrdb_image-super-resolution'),
|
||||
Tasks.image_debanding: (Pipelines.image_debanding,
|
||||
'damo/cv_rrdb_image-debanding'),
|
||||
Tasks.image_portrait_enhancement:
|
||||
(Pipelines.image_portrait_enhancement,
|
||||
'damo/cv_gpen_image-portrait-enhancement'),
|
||||
Tasks.product_retrieval_embedding:
|
||||
(Pipelines.product_retrieval_embedding,
|
||||
'damo/cv_resnet50_product-bag-embedding-models'),
|
||||
Tasks.image_to_image_generation:
|
||||
(Pipelines.image_to_image_generation,
|
||||
'damo/cv_latent_diffusion_image2image_generate'),
|
||||
Tasks.image_classification:
|
||||
(Pipelines.daily_image_classification,
|
||||
'damo/cv_vit-base_image-classification_Dailylife-labels'),
|
||||
Tasks.image_object_detection:
|
||||
(Pipelines.image_object_detection_auto,
|
||||
'damo/cv_yolox_image-object-detection-auto'),
|
||||
Tasks.ocr_recognition:
|
||||
(Pipelines.ocr_recognition,
|
||||
'damo/cv_convnextTiny_ocr-recognition-general_damo'),
|
||||
Tasks.skin_retouching: (Pipelines.skin_retouching,
|
||||
'damo/cv_unet_skin-retouching'),
|
||||
Tasks.faq_question_answering:
|
||||
(Pipelines.faq_question_answering,
|
||||
'damo/nlp_structbert_faq-question-answering_chinese-base'),
|
||||
Tasks.crowd_counting: (Pipelines.crowd_counting,
|
||||
'damo/cv_hrnet_crowd-counting_dcanet'),
|
||||
Tasks.video_single_object_tracking:
|
||||
(Pipelines.video_single_object_tracking,
|
||||
'damo/cv_vitb_video-single-object-tracking_ostrack'),
|
||||
Tasks.image_reid_person: (Pipelines.image_reid_person,
|
||||
'damo/cv_passvitb_image-reid-person_market'),
|
||||
Tasks.text_driven_segmentation:
|
||||
(Pipelines.text_driven_segmentation,
|
||||
'damo/cv_vitl16_segmentation_text-driven-seg'),
|
||||
Tasks.movie_scene_segmentation: (
|
||||
Pipelines.movie_scene_segmentation,
|
||||
'damo/cv_resnet50-bert_video-scene-segmentation_movienet'),
|
||||
Tasks.shop_segmentation: (Pipelines.shop_segmentation,
|
||||
'damo/cv_vitb16_segmentation_shop-seg'),
|
||||
Tasks.image_inpainting: (Pipelines.image_inpainting,
|
||||
'damo/cv_fft_inpainting_lama'),
|
||||
Tasks.image_paintbyexample: (Pipelines.image_paintbyexample,
|
||||
'damo/cv_stable-diffusion_paint-by-example'),
|
||||
Tasks.video_inpainting: (Pipelines.video_inpainting,
|
||||
'damo/cv_video-inpainting'),
|
||||
Tasks.video_human_matting: (Pipelines.video_human_matting,
|
||||
'damo/cv_effnetv2_video-human-matting'),
|
||||
Tasks.video_frame_interpolation: (
|
||||
Pipelines.video_frame_interpolation,
|
||||
'damo/cv_raft_video-frame-interpolation'),
|
||||
Tasks.video_deinterlace: (Pipelines.video_deinterlace,
|
||||
'damo/cv_unet_video-deinterlace'),
|
||||
Tasks.human_wholebody_keypoint: (
|
||||
Pipelines.human_wholebody_keypoint,
|
||||
'damo/cv_hrnetw48_human-wholebody-keypoint_image'),
|
||||
Tasks.hand_static: (Pipelines.hand_static,
|
||||
'damo/cv_mobileface_hand-static'),
|
||||
Tasks.face_human_hand_detection: (
|
||||
Pipelines.face_human_hand_detection,
|
||||
'damo/cv_nanodet_face-human-hand-detection'),
|
||||
Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'),
|
||||
Tasks.product_segmentation: (Pipelines.product_segmentation,
|
||||
'damo/cv_F3Net_product-segmentation'),
|
||||
Tasks.referring_video_object_segmentation: (
|
||||
Pipelines.referring_video_object_segmentation,
|
||||
'damo/cv_swin-t_referring_video-object-segmentation'),
|
||||
Tasks.video_summarization: (Pipelines.video_summarization,
|
||||
'damo/cv_googlenet_pgl-video-summarization'),
|
||||
Tasks.image_skychange: (Pipelines.image_skychange,
|
||||
'damo/cv_hrnetocr_skychange'),
|
||||
Tasks.translation_evaluation: (
|
||||
Pipelines.translation_evaluation,
|
||||
'damo/nlp_unite_mup_translation_evaluation_multilingual_large'),
|
||||
Tasks.video_object_segmentation: (
|
||||
Pipelines.video_object_segmentation,
|
||||
'damo/cv_rdevos_video-object-segmentation'),
|
||||
Tasks.video_multi_object_tracking: (
|
||||
Pipelines.video_multi_object_tracking,
|
||||
'damo/cv_yolov5_video-multi-object-tracking_fairmot'),
|
||||
Tasks.image_multi_view_depth_estimation: (
|
||||
Pipelines.image_multi_view_depth_estimation,
|
||||
'damo/cv_casmvs_multi-view-depth-estimation_general'),
|
||||
Tasks.image_fewshot_detection: (
|
||||
Pipelines.image_fewshot_detection,
|
||||
'damo/cv_resnet101_detection_fewshot-defrcn'),
|
||||
Tasks.image_body_reshaping: (Pipelines.image_body_reshaping,
|
||||
'damo/cv_flow-based-body-reshaping_damo'),
|
||||
Tasks.image_face_fusion: (Pipelines.image_face_fusion,
|
||||
'damo/cv_unet-image-face-fusion_damo'),
|
||||
Tasks.image_matching: (
|
||||
Pipelines.image_matching,
|
||||
'damo/cv_quadtree_attention_image-matching_outdoor'),
|
||||
Tasks.image_quality_assessment_mos: (
|
||||
Pipelines.image_quality_assessment_mos,
|
||||
'damo/cv_resnet_image-quality-assessment-mos_youtubeUGC'),
|
||||
Tasks.image_quality_assessment_degradation: (
|
||||
Pipelines.image_quality_assessment_degradation,
|
||||
'damo/cv_resnet50_image-quality-assessment_degradation'),
|
||||
Tasks.vision_efficient_tuning: (
|
||||
Pipelines.vision_efficient_tuning,
|
||||
'damo/cv_vitb16_classification_vision-efficient-tuning-adapter'),
|
||||
Tasks.object_detection_3d: (Pipelines.object_detection_3d_depe,
|
||||
'damo/cv_object-detection-3d_depe'),
|
||||
Tasks.bad_image_detecting: (Pipelines.bad_image_detecting,
|
||||
'damo/cv_mobilenet-v2_bad-image-detecting'),
|
||||
Tasks.nerf_recon_acc: (Pipelines.nerf_recon_acc,
|
||||
'damo/cv_nerf-3d-reconstruction-accelerate_damo'),
|
||||
Tasks.siamese_uie: (Pipelines.siamese_uie,
|
||||
'damo/nlp_structbert_siamese-uie_chinese-base'),
|
||||
}
|
||||
|
||||
Holds the standard trainer name to use for identifying different trainer.
|
||||
This should be used to register trainers.
|
||||
|
||||
For a general Trainer, you can use EpochBasedTrainer.
|
||||
For a model specific Trainer, you can use ${ModelName}-${Task}-trainer.
|
||||
"""
|
||||
|
||||
default = 'trainer'
|
||||
easycv = 'easycv'
|
||||
tinynas_damoyolo = 'tinynas-damoyolo'
|
||||
|
||||
# multi-modal trainers
|
||||
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
|
||||
ofa = 'ofa'
|
||||
mplug = 'mplug'
|
||||
mgeo_ranking_trainer = 'mgeo-ranking-trainer'
|
||||
|
||||
class CVTrainers(object):
|
||||
# cv trainers
|
||||
image_instance_segmentation = 'image-instance-segmentation'
|
||||
image_portrait_enhancement = 'image-portrait-enhancement'
|
||||
@@ -424,6 +797,8 @@ class Trainers(object):
|
||||
image_classification = 'image-classification'
|
||||
image_fewshot_detection = 'image-fewshot-detection'
|
||||
|
||||
|
||||
class NLPTrainers(object):
|
||||
# nlp trainers
|
||||
bert_sentiment_analysis = 'bert-sentiment-analysis'
|
||||
dialog_modeling_trainer = 'dialog-modeling-trainer'
|
||||
@@ -431,14 +806,26 @@ class Trainers(object):
|
||||
nlp_base_trainer = 'nlp-base-trainer'
|
||||
nlp_veco_trainer = 'nlp-veco-trainer'
|
||||
nlp_text_ranking_trainer = 'nlp-text-ranking-trainer'
|
||||
nlp_sentence_embedding_trainer = 'nlp-sentence-embedding-trainer'
|
||||
text_generation_trainer = 'text-generation-trainer'
|
||||
nlp_plug_trainer = 'nlp-plug-trainer'
|
||||
gpt3_trainer = 'nlp-gpt3-trainer'
|
||||
faq_question_answering_trainer = 'faq-question-answering-trainer'
|
||||
gpt_moe_trainer = 'nlp-gpt-moe-trainer'
|
||||
table_question_answering_trainer = 'table-question-answering-trainer'
|
||||
document_grounded_dialog_generate_trainer = 'document-grounded-dialog-generate-trainer'
|
||||
document_grounded_dialog_rerank_trainer = 'document-grounded-dialog-rerank-trainer'
|
||||
document_grounded_dialog_retrieval_trainer = 'document-grounded-dialog-retrieval-trainer'
|
||||
|
||||
# audio trainers
|
||||
|
||||
class MultiModalTrainers(object):
|
||||
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
|
||||
ofa = 'ofa'
|
||||
mplug = 'mplug'
|
||||
mgeo_ranking_trainer = 'mgeo-ranking-trainer'
|
||||
|
||||
|
||||
class AudioTrainers(object):
|
||||
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
|
||||
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
|
||||
speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
|
||||
@@ -447,6 +834,45 @@ class Trainers(object):
|
||||
speech_separation = 'speech-separation'
|
||||
|
||||
|
||||
class Trainers(CVTrainers, NLPTrainers, MultiModalTrainers, AudioTrainers):
|
||||
""" Names for different trainer.
|
||||
|
||||
Holds the standard trainer name to use for identifying different trainer.
|
||||
This should be used to register trainers.
|
||||
|
||||
For a general Trainer, you can use EpochBasedTrainer.
|
||||
For a model specific Trainer, you can use ${ModelName}-${Task}-trainer.
|
||||
"""
|
||||
|
||||
default = 'trainer'
|
||||
easycv = 'easycv'
|
||||
tinynas_damoyolo = 'tinynas-damoyolo'
|
||||
|
||||
@staticmethod
|
||||
def get_trainer_domain(attribute_or_value):
|
||||
if attribute_or_value in vars(
|
||||
CVTrainers) or attribute_or_value in vars(CVTrainers).values():
|
||||
return Fields.cv
|
||||
elif attribute_or_value in vars(
|
||||
NLPTrainers) or attribute_or_value in vars(
|
||||
NLPTrainers).values():
|
||||
return Fields.nlp
|
||||
elif attribute_or_value in vars(
|
||||
AudioTrainers) or attribute_or_value in vars(
|
||||
AudioTrainers).values():
|
||||
return Fields.audio
|
||||
elif attribute_or_value in vars(
|
||||
MultiModalTrainers) or attribute_or_value in vars(
|
||||
MultiModalTrainers).values():
|
||||
return Fields.multi_modal
|
||||
elif attribute_or_value == Trainers.default:
|
||||
return Trainers.default
|
||||
elif attribute_or_value == Trainers.easycv:
|
||||
return Trainers.easycv
|
||||
else:
|
||||
return 'unknown'
|
||||
|
||||
|
||||
class Preprocessors(object):
|
||||
""" Names for different preprocessor.
|
||||
|
||||
@@ -466,12 +892,18 @@ class Preprocessors(object):
|
||||
image_classification_mmcv_preprocessor = 'image-classification-mmcv-preprocessor'
|
||||
image_color_enhance_preprocessor = 'image-color-enhance-preprocessor'
|
||||
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
|
||||
image_driving_perception_preprocessor = 'image-driving-perception-preprocessor'
|
||||
image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor'
|
||||
image_quality_assessment_mos_preprocessor = 'image-quality_assessment-mos-preprocessor'
|
||||
video_summarization_preprocessor = 'video-summarization-preprocessor'
|
||||
movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'
|
||||
image_classification_bypass_preprocessor = 'image-classification-bypass-preprocessor'
|
||||
object_detection_scrfd = 'object-detection-scrfd'
|
||||
image_sky_change_preprocessor = 'image-sky-change-preprocessor'
|
||||
image_demoire_preprocessor = 'image-demoire-preprocessor'
|
||||
ocr_recognition = 'ocr-recognition'
|
||||
bad_image_detecting_preprocessor = 'bad-image-detecting-preprocessor'
|
||||
nerf_recon_acc_preprocessor = 'nerf-recon-acc-preprocessor'
|
||||
|
||||
# nlp preprocessor
|
||||
sen_sim_tokenizer = 'sen-sim-tokenizer'
|
||||
@@ -510,6 +942,10 @@ class Preprocessors(object):
|
||||
sentence_piece = 'sentence-piece'
|
||||
translation_evaluation = 'translation-evaluation-preprocessor'
|
||||
dialog_use_preprocessor = 'dialog-use-preprocessor'
|
||||
siamese_uie_preprocessor = 'siamese-uie-preprocessor'
|
||||
document_grounded_dialog_retrieval = 'document-grounded-dialog-retrieval'
|
||||
document_grounded_dialog_rerank = 'document-grounded-dialog-rerank'
|
||||
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
|
||||
|
||||
# audio preprocessor
|
||||
linear_aec_fbank = 'linear-aec-fbank'
|
||||
@@ -555,10 +991,14 @@ class Metrics(object):
|
||||
image_ins_seg_coco_metric = 'image-ins-seg-coco-metric'
|
||||
# metrics for sequence classification task
|
||||
seq_cls_metric = 'seq-cls-metric'
|
||||
# loss metric
|
||||
loss_metric = 'loss-metric'
|
||||
# metrics for token-classification task
|
||||
token_cls_metric = 'token-cls-metric'
|
||||
# metrics for text-generation task
|
||||
text_gen_metric = 'text-gen-metric'
|
||||
# file saving wrapper
|
||||
prediction_saving_wrapper = 'prediction-saving-wrapper'
|
||||
# metrics for image-color-enhance task
|
||||
image_color_enhance_metric = 'image-color-enhance-metric'
|
||||
# metrics for image-portrait-enhancement task
|
||||
@@ -576,6 +1016,12 @@ class Metrics(object):
|
||||
referring_video_object_segmentation_metric = 'referring-video-object-segmentation-metric'
|
||||
# metric for video stabilization task
|
||||
video_stabilization_metric = 'video-stabilization-metric'
|
||||
# metirc for image-quality-assessment-mos task
|
||||
image_quality_assessment_mos_metric = 'image-quality-assessment-mos-metric'
|
||||
# metirc for image-quality-assessment-degradation task
|
||||
image_quality_assessment_degradation_metric = 'image-quality-assessment-degradation-metric'
|
||||
# metric for text-ranking task
|
||||
text_ranking_metric = 'text-ranking-metric'
|
||||
|
||||
|
||||
class Optimizers(object):
|
||||
@@ -609,6 +1055,7 @@ class Hooks(object):
|
||||
# checkpoint
|
||||
CheckpointHook = 'CheckpointHook'
|
||||
BestCkptSaverHook = 'BestCkptSaverHook'
|
||||
LoadCheckpointHook = 'LoadCheckpointHook'
|
||||
|
||||
# logger
|
||||
TextLoggerHook = 'TextLoggerHook'
|
||||
|
||||
@@ -25,7 +25,10 @@ if TYPE_CHECKING:
|
||||
from .video_stabilization_metric import VideoStabilizationMetric
|
||||
from .video_super_resolution_metric.video_super_resolution_metric import VideoSuperResolutionMetric
|
||||
from .ppl_metric import PplMetric
|
||||
|
||||
from .image_quality_assessment_degradation_metric import ImageQualityAssessmentDegradationMetric
|
||||
from .image_quality_assessment_mos_metric import ImageQualityAssessmentMosMetric
|
||||
from .text_ranking_metric import TextRankingMetric
|
||||
from .loss_metric import LossMetric
|
||||
else:
|
||||
_import_structure = {
|
||||
'audio_noise_metric': ['AudioNoiseMetric'],
|
||||
@@ -50,6 +53,12 @@ else:
|
||||
'video_frame_interpolation_metric': ['VideoFrameInterpolationMetric'],
|
||||
'video_stabilization_metric': ['VideoStabilizationMetric'],
|
||||
'ppl_metric': ['PplMetric'],
|
||||
'image_quality_assessment_degradation_metric':
|
||||
['ImageQualityAssessmentDegradationMetric'],
|
||||
'image_quality_assessment_mos_metric':
|
||||
['ImageQualityAssessmentMosMetric'],
|
||||
'text_ranking_metric': ['TextRankingMetric'],
|
||||
'loss_metric': ['LossMetric']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
@@ -8,6 +8,7 @@ from modelscope.metainfo import Metrics
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.utils.chinese_utils import remove_space_between_chinese_chars
|
||||
from modelscope.utils.registry import default_group
|
||||
from modelscope.utils.tensor_utils import torch_nested_numpify
|
||||
from .base import Metric
|
||||
from .builder import METRICS, MetricKeys
|
||||
|
||||
@@ -36,8 +37,10 @@ class AccuracyMetric(Metric):
|
||||
eval_results = outputs[key]
|
||||
break
|
||||
assert type(ground_truths) == type(eval_results)
|
||||
ground_truths = torch_nested_numpify(ground_truths)
|
||||
for truth in ground_truths:
|
||||
self.labels.append(truth)
|
||||
eval_results = torch_nested_numpify(eval_results)
|
||||
for result in eval_results:
|
||||
if isinstance(truth, str):
|
||||
if isinstance(result, list):
|
||||
|
||||
@@ -12,7 +12,9 @@ METRICS = Registry('metrics')
|
||||
class MetricKeys(object):
|
||||
ACCURACY = 'accuracy'
|
||||
F1 = 'f1'
|
||||
Binary_F1 = 'binary-f1'
|
||||
Macro_F1 = 'macro-f1'
|
||||
Micro_F1 = 'micro-f1'
|
||||
PRECISION = 'precision'
|
||||
RECALL = 'recall'
|
||||
PSNR = 'psnr'
|
||||
@@ -33,6 +35,11 @@ class MetricKeys(object):
|
||||
DISTORTION_VALUE = 'distortion_value'
|
||||
STABILITY_SCORE = 'stability_score'
|
||||
PPL = 'ppl'
|
||||
PLCC = 'plcc'
|
||||
SRCC = 'srcc'
|
||||
RMSE = 'rmse'
|
||||
MRR = 'mrr'
|
||||
NDCG = 'ndcg'
|
||||
|
||||
|
||||
task_default_metrics = {
|
||||
@@ -59,6 +66,10 @@ task_default_metrics = {
|
||||
Tasks.video_frame_interpolation:
|
||||
[Metrics.video_frame_interpolation_metric],
|
||||
Tasks.video_stabilization: [Metrics.video_stabilization_metric],
|
||||
Tasks.image_quality_assessment_degradation:
|
||||
[Metrics.image_quality_assessment_degradation_metric],
|
||||
Tasks.image_quality_assessment_mos:
|
||||
[Metrics.image_quality_assessment_mos_metric],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.stats import pearsonr, spearmanr
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.metainfo import Metrics
|
||||
from modelscope.utils.registry import default_group
|
||||
from .base import Metric
|
||||
from .builder import METRICS, MetricKeys
|
||||
|
||||
|
||||
@METRICS.register_module(
|
||||
group_key=default_group,
|
||||
module_name=Metrics.image_quality_assessment_degradation_metric)
|
||||
class ImageQualityAssessmentDegradationMetric(Metric):
|
||||
"""The metric for image-quality-assessment-degradation task.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.inputs = defaultdict(list)
|
||||
self.outputs = defaultdict(list)
|
||||
|
||||
def add(self, outputs: Dict, inputs: Dict):
|
||||
item_degradation_id = outputs['item_id'][0] + outputs[
|
||||
'distortion_type'][0]
|
||||
if outputs['distortion_type'][0] in ['01', '02', '03']:
|
||||
pred = outputs['blur_degree']
|
||||
elif outputs['distortion_type'][0] in ['09', '10', '21']:
|
||||
pred = outputs['comp_degree']
|
||||
elif outputs['distortion_type'][0] in ['11', '12', '13', '14']:
|
||||
pred = outputs['noise_degree']
|
||||
else:
|
||||
return
|
||||
|
||||
self.outputs[item_degradation_id].append(pred[0].float())
|
||||
self.inputs[item_degradation_id].append(outputs['target'].float())
|
||||
|
||||
def evaluate(self):
|
||||
degree_plccs = []
|
||||
degree_sroccs = []
|
||||
|
||||
for item_degradation_id, degree_value in self.inputs.items():
|
||||
degree_label = torch.cat(degree_value).flatten().data.cpu().numpy()
|
||||
degree_pred = torch.cat(self.outputs[item_degradation_id]).flatten(
|
||||
).data.cpu().numpy()
|
||||
degree_plcc = pearsonr(degree_label, degree_pred)[0]
|
||||
degree_srocc = spearmanr(degree_label, degree_pred)[0]
|
||||
degree_plccs.append(degree_plcc)
|
||||
degree_sroccs.append(degree_srocc)
|
||||
degree_plcc_mean = np.array(degree_plccs).mean()
|
||||
degree_srocc_mean = np.array(degree_sroccs).mean()
|
||||
|
||||
return {
|
||||
MetricKeys.PLCC: degree_plcc_mean,
|
||||
MetricKeys.SRCC: degree_srocc_mean,
|
||||
}
|
||||
|
||||
def merge(self, other: 'ImageQualityAssessmentDegradationMetric'):
|
||||
self.inputs.extend(other.inputs)
|
||||
self.outputs.extend(other.outputs)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.inputs, self.outputs
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.inputs, self.outputs = state
|
||||
57
modelscope/metrics/image_quality_assessment_mos_metric.py
Normal file
57
modelscope/metrics/image_quality_assessment_mos_metric.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Dict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.stats import pearsonr, spearmanr
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.metainfo import Metrics
|
||||
from modelscope.utils.registry import default_group
|
||||
from .base import Metric
|
||||
from .builder import METRICS, MetricKeys
|
||||
|
||||
|
||||
@METRICS.register_module(
|
||||
group_key=default_group,
|
||||
module_name=Metrics.image_quality_assessment_mos_metric)
|
||||
class ImageQualityAssessmentMosMetric(Metric):
|
||||
"""The metric for image-quality-assessment-mos task.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.inputs = []
|
||||
self.outputs = []
|
||||
|
||||
def add(self, outputs: Dict, inputs: Dict):
|
||||
self.outputs.append(outputs['pred'].float())
|
||||
self.inputs.append(outputs['target'].float())
|
||||
|
||||
def evaluate(self):
|
||||
|
||||
mos_labels = torch.cat(self.inputs).flatten().data.cpu().numpy()
|
||||
mos_preds = torch.cat(self.outputs).flatten().data.cpu().numpy()
|
||||
mos_plcc = pearsonr(mos_labels, mos_preds)[0]
|
||||
mos_srocc = spearmanr(mos_labels, mos_preds)[0]
|
||||
mos_rmse = np.sqrt(np.mean((mos_labels - mos_preds)**2))
|
||||
|
||||
return {
|
||||
MetricKeys.PLCC: mos_plcc,
|
||||
MetricKeys.SRCC: mos_srocc,
|
||||
MetricKeys.RMSE: mos_rmse,
|
||||
}
|
||||
|
||||
def merge(self, other: 'ImageQualityAssessmentMosMetric'):
|
||||
self.inputs.extend(other.inputs)
|
||||
self.outputs.extend(other.outputs)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.inputs, self.outputs
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.inputs, self.outputs = state
|
||||
46
modelscope/metrics/loss_metric.py
Normal file
46
modelscope/metrics/loss_metric.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics import accuracy_score, f1_score
|
||||
|
||||
from modelscope.metainfo import Metrics
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.utils.registry import default_group
|
||||
from modelscope.utils.tensor_utils import (torch_nested_detach,
|
||||
torch_nested_numpify)
|
||||
from .base import Metric
|
||||
from .builder import METRICS, MetricKeys
|
||||
|
||||
|
||||
@METRICS.register_module(
|
||||
group_key=default_group, module_name=Metrics.loss_metric)
|
||||
class LossMetric(Metric):
|
||||
"""The metric class to calculate average loss of batches.
|
||||
|
||||
Args:
|
||||
loss_key: The key of loss
|
||||
"""
|
||||
|
||||
def __init__(self, loss_key=OutputKeys.LOSS, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss_key = loss_key
|
||||
self.losses = []
|
||||
|
||||
def add(self, outputs: Dict, inputs: Dict):
|
||||
loss = outputs[self.loss_key]
|
||||
self.losses.append(torch_nested_numpify(torch_nested_detach(loss)))
|
||||
|
||||
def evaluate(self):
|
||||
return {OutputKeys.LOSS: float(np.average(self.losses))}
|
||||
|
||||
def merge(self, other: 'LossMetric'):
|
||||
self.losses.extend(other.losses)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.losses
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__init__()
|
||||
self.losses = state
|
||||
42
modelscope/metrics/prediction_saving_wrapper.py
Normal file
42
modelscope/metrics/prediction_saving_wrapper.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics import accuracy_score, f1_score
|
||||
|
||||
from modelscope.metainfo import Metrics
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.utils.registry import default_group
|
||||
from modelscope.utils.tensor_utils import (torch_nested_detach,
|
||||
torch_nested_numpify)
|
||||
from .base import Metric
|
||||
from .builder import METRICS, MetricKeys
|
||||
|
||||
|
||||
@METRICS.register_module(
|
||||
group_key=default_group, module_name=Metrics.prediction_saving_wrapper)
|
||||
class PredictionSavingWrapper(Metric):
|
||||
"""The wrapper to save predictions to file.
|
||||
Args:
|
||||
saving_fn: The saving_fn used to save predictions to files.
|
||||
"""
|
||||
|
||||
def __init__(self, saving_fn, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.saving_fn = saving_fn
|
||||
|
||||
def add(self, outputs: Dict, inputs: Dict):
|
||||
self.saving_fn(inputs, outputs)
|
||||
|
||||
def evaluate(self):
|
||||
return {}
|
||||
|
||||
def merge(self, other: 'PredictionSavingWrapper'):
|
||||
pass
|
||||
|
||||
def __getstate__(self):
|
||||
pass
|
||||
|
||||
def __setstate__(self, state):
|
||||
pass
|
||||
@@ -48,19 +48,29 @@ class SequenceClassificationMetric(Metric):
|
||||
def evaluate(self):
|
||||
preds = np.concatenate(self.preds, axis=0)
|
||||
labels = np.concatenate(self.labels, axis=0)
|
||||
preds = np.argmax(preds, axis=1)
|
||||
return {
|
||||
MetricKeys.ACCURACY:
|
||||
accuracy_score(labels, preds),
|
||||
MetricKeys.F1:
|
||||
f1_score(
|
||||
labels,
|
||||
preds,
|
||||
average='micro' if any([label > 1
|
||||
for label in labels]) else None),
|
||||
MetricKeys.Macro_F1:
|
||||
f1_score(labels, preds, average='macro'),
|
||||
}
|
||||
assert len(preds.shape) == 2, 'Only support predictions with shape: (batch_size, num_labels),' \
|
||||
'multi-label classification is not supported in this metric class.'
|
||||
preds_max = np.argmax(preds, axis=1)
|
||||
if preds.shape[1] > 2:
|
||||
metrics = {
|
||||
MetricKeys.ACCURACY: accuracy_score(labels, preds_max),
|
||||
MetricKeys.Micro_F1:
|
||||
f1_score(labels, preds_max, average='micro'),
|
||||
MetricKeys.Macro_F1:
|
||||
f1_score(labels, preds_max, average='macro'),
|
||||
}
|
||||
|
||||
metrics[MetricKeys.F1] = metrics[MetricKeys.Micro_F1]
|
||||
return metrics
|
||||
else:
|
||||
metrics = {
|
||||
MetricKeys.ACCURACY:
|
||||
accuracy_score(labels, preds_max),
|
||||
MetricKeys.Binary_F1:
|
||||
f1_score(labels, preds_max, average='binary'),
|
||||
}
|
||||
metrics[MetricKeys.F1] = metrics[MetricKeys.Binary_F1]
|
||||
return metrics
|
||||
|
||||
def merge(self, other: 'SequenceClassificationMetric'):
|
||||
self.preds.extend(other.preds)
|
||||
|
||||
91
modelscope/metrics/text_ranking_metric.py
Normal file
91
modelscope/metrics/text_ranking_metric.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from modelscope.metainfo import Metrics
|
||||
from modelscope.metrics.base import Metric
|
||||
from modelscope.metrics.builder import METRICS, MetricKeys
|
||||
from modelscope.utils.registry import default_group
|
||||
|
||||
|
||||
@METRICS.register_module(
|
||||
group_key=default_group, module_name=Metrics.text_ranking_metric)
|
||||
class TextRankingMetric(Metric):
|
||||
"""The metric computation class for text ranking classes.
|
||||
|
||||
This metric class calculates mrr and ndcg metric for the whole evaluation dataset.
|
||||
|
||||
Args:
|
||||
target_text: The key of the target text column in the `inputs` arg.
|
||||
pred_text: The key of the predicted text column in the `outputs` arg.
|
||||
"""
|
||||
|
||||
def __init__(self, mrr_k: int = 1, ndcg_k: int = 1):
|
||||
self.labels: List = []
|
||||
self.qids: List = []
|
||||
self.logits: List = []
|
||||
self.mrr_k: int = mrr_k
|
||||
self.ndcg_k: int = ndcg_k
|
||||
|
||||
def add(self, outputs: Dict[str, List], inputs: Dict[str, List]):
|
||||
self.labels.extend(inputs.pop('labels').detach().cpu().numpy())
|
||||
self.qids.extend(inputs.pop('qid').detach().cpu().numpy())
|
||||
|
||||
logits = outputs['logits'].squeeze(-1).detach().cpu().numpy()
|
||||
logits = self._sigmoid(logits).tolist()
|
||||
self.logits.extend(logits)
|
||||
|
||||
def evaluate(self):
|
||||
rank_result = {}
|
||||
for qid, score, label in zip(self.qids, self.logits, self.labels):
|
||||
if qid not in rank_result:
|
||||
rank_result[qid] = []
|
||||
rank_result[qid].append((score, label))
|
||||
|
||||
for qid in rank_result:
|
||||
rank_result[qid] = sorted(rank_result[qid], key=lambda x: x[0])
|
||||
|
||||
return {
|
||||
MetricKeys.MRR: self._compute_mrr(rank_result),
|
||||
MetricKeys.NDCG: self._compute_ndcg(rank_result)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _sigmoid(logits):
|
||||
return np.exp(logits) / (1 + np.exp(logits))
|
||||
|
||||
def _compute_mrr(self, result):
|
||||
mrr = 0
|
||||
for res in result.values():
|
||||
sorted_res = sorted(res, key=lambda x: x[0], reverse=True)
|
||||
ar = 0
|
||||
for index, ele in enumerate(sorted_res[:self.mrr_k]):
|
||||
if str(ele[1]) == '1':
|
||||
ar = 1.0 / (index + 1)
|
||||
break
|
||||
mrr += ar
|
||||
return mrr / len(result)
|
||||
|
||||
def _compute_ndcg(self, result):
|
||||
ndcg = 0
|
||||
from sklearn.metrics import ndcg_score
|
||||
for res in result.values():
|
||||
sorted_res = sorted(res, key=lambda x: [0], reverse=True)
|
||||
labels = np.array([[ele[1] for ele in sorted_res]])
|
||||
scores = np.array([[ele[0] for ele in sorted_res]])
|
||||
ndcg += float(ndcg_score(labels, scores, k=self.ndcg_k))
|
||||
return ndcg / len(result)
|
||||
|
||||
def merge(self, other: 'TextRankingMetric'):
|
||||
self.labels.extend(other.labels)
|
||||
self.qids.extend(other.qids)
|
||||
self.logits.extend(other.logits)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.labels, self.qids, self.logits, self.mrr_k, self.ndcg_k
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__init__()
|
||||
self.labels, self.qids, self.logits, self.mrr_k, self.ndcg_k = state
|
||||
@@ -9,4 +9,5 @@ from .base import Head, Model
|
||||
from .builder import BACKBONES, HEADS, MODELS, build_model
|
||||
|
||||
if is_torch_available():
|
||||
from .base import TorchModel, TorchHead
|
||||
from .base.base_torch_model import TorchModel
|
||||
from .base.base_torch_head import TorchHead
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from . import ans, asr, itn, kws, tts
|
||||
from . import ans, asr, itn, kws, sv, tts
|
||||
|
||||
@@ -13,6 +13,9 @@ __all__ = ['GenericAutomaticSpeechRecognition']
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.auto_speech_recognition, module_name=Models.generic_asr)
|
||||
@MODELS.register_module(
|
||||
Tasks.voice_activity_detection, module_name=Models.generic_asr)
|
||||
@MODELS.register_module(Tasks.language_model, module_name=Models.generic_asr)
|
||||
class GenericAutomaticSpeechRecognition(Model):
|
||||
|
||||
def __init__(self, model_dir: str, am_model_name: str,
|
||||
|
||||
@@ -120,13 +120,12 @@ class Encoder(nn.Module):
|
||||
in_channels: Number of input channels.
|
||||
out_channels: Number of output channels.
|
||||
|
||||
Example:
|
||||
-------
|
||||
Examples:
|
||||
|
||||
>>> x = torch.randn(2, 1000)
|
||||
>>> encoder = Encoder(kernel_size=4, out_channels=64)
|
||||
>>> h = encoder(x)
|
||||
>>> h.shape
|
||||
torch.Size([2, 64, 499])
|
||||
>>> h.shape # torch.Size([2, 64, 499])
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .ecapa_tdnn import SpeakerVerificationECAPATDNN
|
||||
|
||||
else:
|
||||
_import_structure = {'ecapa_tdnn': ['SpeakerVerificationECAPATDNN']}
|
||||
import sys
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
504
modelscope/models/audio/sv/ecapa_tdnn.py
Normal file
504
modelscope/models/audio/sv/ecapa_tdnn.py
Normal file
@@ -0,0 +1,504 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
""" This ECAPA-TDNN implementation is adapted from https://github.com/speechbrain/speechbrain.
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio.compliance.kaldi as Kaldi
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import MODELS, TorchModel
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Convert a 1-D tensor of sequence lengths into a binary mask.

    Args:
        length: 1-D tensor with one (possibly fractional) length per row.
        max_len: Mask width; defaults to ``length.max()``.
        dtype: Output dtype; defaults to ``length.dtype``.
        device: Output device; defaults to ``length.device``.

    Returns:
        Tensor of shape ``(len(length), max_len)`` whose entry ``(i, j)``
        is truthy iff ``j < length[i]``.
    """
    assert len(length.shape) == 1

    width = length.max().long().item() if max_len is None else max_len
    # Broadcast a [0, width) ramp against the per-row lengths.
    positions = torch.arange(width, device=length.device, dtype=length.dtype)
    bool_mask = positions.expand(len(length), width) < length.unsqueeze(1)

    out_dtype = length.dtype if dtype is None else dtype
    out_device = length.device if device is None else device
    return torch.as_tensor(bool_mask, dtype=out_dtype, device=out_device)
|
||||
|
||||
|
||||
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Compute the (left, right) padding for 'same'-style 1-D convolution.

    For ``stride > 1`` the padding is half a kernel on each side; otherwise
    it is derived from the exact output length of an unpadded convolution so
    that the padded output matches the input length.
    """
    if stride > 1:
        return [kernel_size // 2, kernel_size // 2]

    # Output length of an unpadded conv1d (see torch.nn.Conv1d docs).
    L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
    pad = (L_in - L_out) // 2
    return [pad, pad]
|
||||
|
||||
|
||||
class Conv1d(nn.Module):
    """1-D convolution with 'same' / 'causal' / 'valid' padding applied manually.

    The wrapped ``nn.Conv1d`` is always built with ``padding=0``; any padding
    required by ``self.padding`` is applied explicitly in :meth:`forward`
    (reflect mode by default).
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        in_channels,
        stride=1,
        dilation=1,
        padding='same',
        groups=1,
        bias=True,
        padding_mode='reflect',
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=0,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Pad ``x`` according to ``self.padding``, then convolve it."""
        if self.padding == 'same':
            x = self._manage_padding(x, self.kernel_size, self.dilation,
                                     self.stride)
        elif self.padding == 'causal':
            # Left-pad only so no output frame depends on future samples.
            x = F.pad(x, ((self.kernel_size - 1) * self.dilation, 0))
        elif self.padding == 'valid':
            pass
        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got "
                + self.padding)

        return self.conv(x)

    def _manage_padding(
        self,
        x,
        kernel_size: int,
        dilation: int,
        stride: int,
    ):
        # Symmetric padding keeping the output length equal to the input's.
        pad_amounts = get_padding_elem(x.shape[-1], stride, kernel_size,
                                       dilation)
        return F.pad(x, pad_amounts, mode=self.padding_mode)
|
||||
|
||||
|
||||
class BatchNorm1d(nn.Module):
    """Thin wrapper around ``nn.BatchNorm1d`` with a keyword-style ctor."""

    def __init__(
        self,
        input_size,
        eps=1e-05,
        momentum=0.1,
    ):
        super().__init__()
        self.norm = nn.BatchNorm1d(input_size, eps=eps, momentum=momentum)

    def forward(self, x):
        """Normalize ``x`` of shape (batch, channels, time)."""
        return self.norm(x)
|
||||
|
||||
|
||||
class TDNNBlock(nn.Module):
    """Conv1d -> activation -> batch-norm block used throughout ECAPA-TDNN."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
        activation=nn.ReLU,
        groups=1,
    ):
        super(TDNNBlock, self).__init__()
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            groups=groups,
        )
        self.activation = activation()
        self.norm = BatchNorm1d(input_size=out_channels)

    def forward(self, x):
        """Apply convolution, nonlinearity and normalization in sequence."""
        out = self.conv(x)
        out = self.activation(out)
        return self.norm(out)
|
||||
|
||||
|
||||
class Res2NetBlock(torch.nn.Module):
    """Res2Net-style multi-scale block.

    Channels are split into ``scale`` groups: group 0 passes through
    unchanged, group 1 goes through its TDNNBlock, and each later group is
    summed with the previous group's output before its TDNNBlock, giving
    progressively larger receptive fields.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale=8,
                 kernel_size=3,
                 dilation=1):
        super(Res2NetBlock, self).__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        self.blocks = nn.ModuleList([
            TDNNBlock(
                in_channels // scale,
                out_channels // scale,
                kernel_size=kernel_size,
                dilation=dilation,
            ) for _ in range(scale - 1)
        ])
        self.scale = scale

    def forward(self, x):
        """Process channel groups hierarchically and re-concatenate them."""
        outputs = []
        previous = None
        for idx, chunk in enumerate(torch.chunk(x, self.scale, dim=1)):
            if idx == 0:
                previous = chunk
            elif idx == 1:
                previous = self.blocks[idx - 1](chunk)
            else:
                previous = self.blocks[idx - 1](chunk + previous)
            outputs.append(previous)
        return torch.cat(outputs, dim=1)
|
||||
|
||||
|
||||
class SEBlock(nn.Module):
    """Squeeze-and-excitation: rescale channels by a learned global gate."""

    def __init__(self, in_channels, se_channels, out_channels):
        super(SEBlock, self).__init__()

        self.conv1 = Conv1d(
            in_channels=in_channels, out_channels=se_channels, kernel_size=1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = Conv1d(
            in_channels=se_channels, out_channels=out_channels, kernel_size=1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, lengths=None):
        """Gate ``x`` (batch, channels, time) by per-channel statistics.

        When ``lengths`` is given (presumably relative lengths in [0, 1] —
        callers multiply by the time dimension here), padded frames are
        excluded from the squeeze average.
        """
        time_steps = x.shape[-1]
        if lengths is None:
            squeezed = x.mean(dim=2, keepdim=True)
        else:
            mask = length_to_mask(
                lengths * time_steps, max_len=time_steps,
                device=x.device).unsqueeze(1)
            valid = mask.sum(dim=2, keepdim=True)
            squeezed = (x * mask).sum(dim=2, keepdim=True) / valid

        gate = self.sigmoid(self.conv2(self.relu(self.conv1(squeezed))))
        return gate * x
|
||||
|
||||
|
||||
class AttentiveStatisticsPooling(nn.Module):
    """Attentive statistics pooling over the time axis.

    Computes per-channel attention weights (optionally conditioned on
    utterance-level mean/std context) and returns the attention-weighted
    mean and standard deviation concatenated along the channel dimension.
    """

    def __init__(self, channels, attention_channels=128, global_context=True):
        super().__init__()

        # Numerical floor used inside the std computation.
        self.eps = 1e-12
        self.global_context = global_context
        if global_context:
            # Attention input is [x; mean; std] -> 3x the channels.
            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
        else:
            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = Conv1d(
            in_channels=attention_channels,
            out_channels=channels,
            kernel_size=1)

    def forward(self, x, lengths=None):
        """Pool ``x`` of shape (batch, channels, time) into (batch, 2*channels, 1).

        ``lengths`` holds relative lengths (scaled by the time dimension
        below); when omitted every frame is treated as valid.
        """
        L = x.shape[-1]

        # Weighted mean/std where ``m`` holds (normalized) frame weights.
        def _compute_statistics(x, m, dim=2, eps=self.eps):
            mean = (m * x).sum(dim)
            std = torch.sqrt(
                (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
            return mean, std

        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)

        # Make binary mask of shape [N, 1, L]
        mask = length_to_mask(lengths * L, max_len=L, device=x.device)
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        if self.global_context:
            # torch.std is unstable for backward computation
            # https://github.com/pytorch/pytorch/issues/4320
            total = mask.sum(dim=2, keepdim=True).float()
            mean, std = _compute_statistics(x, mask / total)
            mean = mean.unsqueeze(2).repeat(1, 1, L)
            std = std.unsqueeze(2).repeat(1, 1, L)
            attn = torch.cat([x, mean, std], dim=1)
        else:
            attn = x

        # Apply layers
        attn = self.conv(self.tanh(self.tdnn(attn)))

        # Filter out zero-paddings so softmax assigns them zero weight.
        attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(attn, dim=2)
        mean, std = _compute_statistics(x, attn)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats
|
||||
|
||||
|
||||
class SERes2NetBlock(nn.Module):
    """Residual block: 1x1 TDNN -> Res2Net -> 1x1 TDNN -> SE, plus skip path.

    A 1x1 conv shortcut projects the input when the channel counts differ;
    otherwise the raw input is used as the residual.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
        activation=torch.nn.ReLU,
        groups=1,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.res2net_block = Res2NetBlock(out_channels, out_channels,
                                          res2net_scale, kernel_size, dilation)
        self.tdnn2 = TDNNBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        # Only project the skip path when channel counts differ.
        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x, lengths=None):
        """Return ``f(x) + residual`` where ``f`` is the SE-Res2Net stack."""
        residual = self.shortcut(x) if self.shortcut else x

        out = self.tdnn1(x)
        out = self.res2net_block(out)
        out = self.tdnn2(out)
        out = self.se_block(out, lengths)

        return out + residual
|
||||
|
||||
|
||||
class ECAPA_TDNN(nn.Module):
    """An implementation of the speaker embedding model in a paper.
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).

    Pipeline: initial TDNN -> stack of SE-Res2Net blocks -> multi-layer
    feature aggregation -> attentive statistics pooling -> linear projection
    to ``lin_neurons``-dimensional embeddings.
    """

    def __init__(
        self,
        input_size,
        device='cpu',
        lin_neurons=192,
        activation=torch.nn.ReLU,
        # NOTE(review): mutable list defaults; safe here since they are only
        # indexed, never mutated, but consider tuples.
        channels=[512, 512, 512, 512, 1536],
        kernel_sizes=[5, 3, 3, 3, 1],
        dilations=[1, 2, 3, 4, 1],
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
        groups=[1, 1, 1, 1, 1],
    ):

        super().__init__()
        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        self.channels = channels
        self.blocks = nn.ModuleList()

        # The initial TDNN layer
        self.blocks.append(
            TDNNBlock(
                input_size,
                channels[0],
                kernel_sizes[0],
                dilations[0],
                activation,
                groups[0],
            ))

        # SE-Res2Net layers
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    activation=activation,
                    groups=groups[i],
                ))

        # Multi-layer feature aggregation: input channels must equal the sum
        # of the SE-Res2Net block outputs concatenated in forward().
        self.mfa = TDNNBlock(
            channels[-1],
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            activation,
            groups=groups[-1],
        )

        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
        )
        # Pooling outputs mean+std, hence 2x the channels.
        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)

        # Final linear transformation (1x1 conv acts as a linear layer).
        self.fc = Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=lin_neurons,
            kernel_size=1,
        )

    def forward(self, x, lengths=None):
        """Returns the embedding vector.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape (batch, time, channel).
        lengths : torch.Tensor, optional
            Relative sequence lengths, forwarded to blocks that accept them.
        """
        # Blocks operate channel-first: (batch, channel, time).
        x = x.transpose(1, 2)

        xl = []
        for layer in self.blocks:
            # Some blocks accept lengths, others don't; fall back on TypeError.
            try:
                x = layer(x, lengths=lengths)
            except TypeError:
                x = layer(x)
            xl.append(x)

        # Multi-layer feature aggregation: concat all SE-Res2Net outputs
        # (skipping the initial TDNN output at xl[0]).
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)

        # Attentive Statistical Pooling
        x = self.asp(x, lengths=lengths)
        x = self.asp_bn(x)

        # Final linear transformation
        x = self.fc(x)

        # Back to (batch, embedding_dim) by dropping the singleton time axis.
        x = x.transpose(1, 2).squeeze(1)
        return x
|
||||
|
||||
|
||||
@MODELS.register_module(
    Tasks.speaker_verification, module_name=Models.ecapa_tdnn_sv)
class SpeakerVerificationECAPATDNN(TorchModel):
    """ModelScope wrapper exposing a pretrained ECAPA-TDNN speaker embedder.

    Loads a checkpoint from ``model_dir`` and maps a mono waveform to a
    speaker embedding via Kaldi-style fbank features.
    """

    def __init__(self, model_dir, model_config: Dict[str, Any], *args,
                 **kwargs):
        """Build the embedding model and load pretrained weights.

        Args:
            model_dir: Directory holding the checkpoint file.
            model_config: Must contain 'channel'; only 1024 is supported.
            **kwargs: Must contain 'pretrained_model' (checkpoint filename).
        """
        super().__init__(model_dir, model_config, *args, **kwargs)
        self.model_config = model_config
        self.other_config = kwargs
        if self.model_config['channel'] != 1024:
            raise ValueError(
                'modelscope error: Currently only 1024-channel ecapa tdnn is supported.'
            )

        # Number of mel bins used for the fbank front-end.
        self.feature_dim = 80
        channels_config = [1024, 1024, 1024, 1024, 3072]

        self.embedding_model = ECAPA_TDNN(
            self.feature_dim, channels=channels_config)

        pretrained_model_name = kwargs['pretrained_model']
        self.__load_check_point(pretrained_model_name)

        # Inference-only model: freeze batch-norm statistics etc.
        self.embedding_model.eval()

    def forward(self, audio):
        """Return the speaker embedding for a single waveform of shape [1, T]."""
        assert len(audio.shape) == 2 and audio.shape[
            0] == 1, 'modelscope error: the shape of input audio to model needs to be [1, T]'
        # audio shape: [1, T]
        feature = self.__extract_feature(audio)
        embedding = self.embedding_model(feature)

        return embedding

    def __extract_feature(self, audio):
        """Compute mean-normalized fbank features, shape (1, frames, mel)."""
        feature = Kaldi.fbank(audio, num_mel_bins=self.feature_dim)
        # Per-utterance cepstral mean normalization over the time axis.
        feature = feature - feature.mean(dim=0, keepdim=True)
        feature = feature.unsqueeze(0)
        return feature

    def __load_check_point(self, pretrained_model_name, device=None):
        """Load checkpoint weights into the embedding model (CPU by default)."""
        if not device:
            device = torch.device('cpu')
        self.embedding_model.load_state_dict(
            torch.load(
                os.path.join(self.model_dir, pretrained_model_name),
                map_location=device),
            strict=True)
|
||||
@@ -1,6 +1,9 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.utils.import_utils import is_torch_available
|
||||
from .base_head import * # noqa F403
|
||||
from .base_model import * # noqa F403
|
||||
from .base_torch_head import * # noqa F403
|
||||
from .base_torch_model import * # noqa F403
|
||||
|
||||
if is_torch_available():
|
||||
from .base_torch_model import TorchModel
|
||||
from .base_torch_head import TorchHead
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user