Mirror of https://github.com/modelscope/modelscope.git (synced 2026-02-24 20:19:51 +01:00)

Merge branch 'master' into ocr/ocr_detection
@@ -1,11 +1,9 @@
-pip install -r requirements.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-pip install -r requirements/audio.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-pip install -r requirements/cv.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-pip install -r requirements/multi-modal.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-pip install -r requirements/nlp.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+awk -F: '/^[^#]/ { print $1 }' requirements/framework.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+awk -F: '/^[^#]/ { print $1 }' requirements/audio.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+awk -F: '/^[^#]/ { print $1 }' requirements/cv.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+awk -F: '/^[^#]/ { print $1 }' requirements/multi-modal.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+awk -F: '/^[^#]/ { print $1 }' requirements/nlp.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
 pip install -r requirements/tests.txt
 # install numpy<=1.18 for tensorflow==1.15.x
 pip install "numpy<=1.18"

 git config --global --add safe.directory /Maas-lib
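The awk/xargs pipeline above installs each non-comment requirement on its own, so one failing package does not abort the rest. A rough Python equivalent of that per-package loop, offered only as an illustrative sketch (the file name and the -f index URL are taken from the commands above; install_one_by_one is a hypothetical helper):

# Sketch: mirror `awk -F: '/^[^#]/ { print $1 }' <file> | xargs -n 1 pip install -f <repo>`.
import subprocess
import sys

REPO = 'https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html'

def install_one_by_one(requirements_file):
    with open(requirements_file) as f:
        for line in f:
            # awk's /^[^#]/ keeps lines whose first character is not '#'
            if not line.strip() or line.startswith('#'):
                continue
            pkg = line.split(':')[0].strip()  # awk -F: prints field $1
            # xargs -n 1 runs one `pip install` per package
            subprocess.run(
                [sys.executable, '-m', 'pip', 'install', pkg, '-f', REPO],
                check=False)  # keep going even if one package fails

install_one_by_one('requirements/framework.txt')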
@@ -19,4 +17,10 @@ fi
 # test with install
 python setup.py install

-python tests/run.py
+if [ $# -eq 0 ]; then
+    ci_command="python tests/run.py --subprocess"
+else
+    ci_command="$@"
+fi
+echo "Running case with command: $ci_command"
+$ci_command
@@ -1,19 +0,0 @@ (entire file removed)
pip install -r requirements.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install -r requirements/audio.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install -r requirements/cv.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install -r requirements/multi-modal.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install -r requirements/nlp.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html

pip install -r requirements/tests.txt
# install numpy<=1.18 for tensorflow==1.15.x
pip install "numpy<=1.18"

# linter test
# use internal project for pre-commit due to the network problem
pre-commit run --all-files
if [ $? -ne 0 ]; then
    echo "linter test failed, please run 'pre-commit run --all-files' to check"
    exit -1
fi

PYTHONPATH=. python tests/run.py
@@ -7,7 +7,9 @@ gpus='7 6 5 4 3 2 1 0'
 cpu_sets='0-7 8-15 16-23 24-30 31-37 38-44 45-51 52-58'
 cpu_sets_arr=($cpu_sets)
 is_get_file_lock=false
-CI_COMMAND=${CI_COMMAND:-'bash .dev_scripts/ci_container_test.sh'}
+# export RUN_CASE_COMMAND='python tests/run.py --run_config tests/run_config.yaml'
+CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh $RUN_CASE_BASE_COMMAND}
+echo "ci command: $CI_COMMAND"
 for gpu in $gpus
 do
     exec {lock_fd}>"/tmp/gpu$gpu" || exit 1
@@ -15,6 +17,7 @@ do
    echo "get gpu lock $gpu"
    CONTAINER_NAME="modelscope-ci-$gpu"
    let is_get_file_lock=true

    # pull image if there are update
    docker pull ${IMAGE_NAME}:${IMAGE_VERSION}
    docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
@@ -32,10 +35,13 @@ do
        -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
        -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
        -e TEST_LEVEL=$TEST_LEVEL \
        -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
        -e MODEL_TAG_URL=$MODEL_TAG_URL \
        --workdir=$CODE_DIR_IN_CONTAINER \
        --net host \
        ${IMAGE_NAME}:${IMAGE_VERSION} \
        $CI_COMMAND

    if [ $? -ne 0 ]; then
        echo "Running test case failed, please check the log!"
        exit -1
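The `exec {lock_fd}>"/tmp/gpu$gpu"` line above opens one lock file per GPU so that concurrent CI runs each claim a free device. A minimal Python sketch of the same claim-a-resource-by-file-lock idea, assuming Linux and the /tmp/gpuN naming used by the script (this helper is not part of the repository):

# Sketch: non-blocking per-GPU file locks, like the shell's lock-file loop.
import fcntl

def claim_free_gpu(gpu_ids=(7, 6, 5, 4, 3, 2, 1, 0)):
    """Return (gpu_id, lock_file); the lock is held while lock_file stays open."""
    for gpu in gpu_ids:
        lock_file = open(f'/tmp/gpu{gpu}', 'w')
        try:
            # LOCK_NB makes this fail immediately if another job holds the GPU.
            fcntl.flock(lock_file, fcntl.LOCK_EX | fcntl.LOCK_NB)
            return gpu, lock_file
        except BlockingIOError:
            lock_file.close()  # busy; try the next GPU
    raise RuntimeError('no free GPU lock available')

gpu, lock = claim_free_gpu()
print(f'get gpu lock {gpu}')  # released when `lock` is closed or the process exits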
.gitattributes (vendored, 2 changes)
@@ -4,4 +4,6 @@
 *.wav filter=lfs diff=lfs merge=lfs -text
 *.JPEG filter=lfs diff=lfs merge=lfs -text
 *.jpeg filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
+*.avi filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
@@ -25,4 +25,4 @@ python:
 install:
-    - requirements: requirements/docs.txt
+    - requirements: requirements/readthedocs.txt
-    - requirements: requirements/runtime.txt
+    - requirements: requirements/framework.txt
@@ -2,7 +2,6 @@
     "framework": "pytorch",

     "task": "image_classification",
-    "work_dir": "./work_dir",

     "model": {
         "type": "classification",
@@ -119,6 +118,7 @@
     },

     "train": {
+        "work_dir": "./work_dir",
         "dataloader": {
             "batch_size_per_gpu": 2,
             "workers_per_gpu": 1
New Git LFS pointer files under data/test (each is @@ -0,0 +1,3 @@ containing "version https://git-lfs.github.com/spec/v1", an oid, and a size):

audios/1ch_nihaomiya.wav (oid sha256:4f7f5a0a4efca1e83463cb44460c66b56fb7cd673eb6da37924637bc05ef758d, size 1440044)
images/face_emotion.jpg (oid sha256:712b5525e37080d33f62d6657609dbef20e843ccc04ee5c788ea11aa7c08545e, size 123341)
images/face_human_hand_detection.jpg (oid sha256:8fddc7be8381eb244cd692601f1c1e6cf3484b44bb4e73df0bc7de29352eb487, size 23889)
images/facial_expression_recognition.jpg (oid sha256:bdb1cef5a5fd5f938a856311011c4820ddc45946a470b9929c61e59b6a065633, size 161535)
images/hand_keypoints.jpg (oid sha256:c05d58edee7398de37b8e479410676d6b97cfde69cc003e8356a348067e71988, size 7750)
images/hand_static.jpg (oid sha256:94b8e281d77ee6d3ea2a8a0c9408ecdbd29fe75f33ea5399b6ea00070ba77bd6, size 13090)
images/image-text-retrieval.jpg (oid sha256:b012c7e966f6550874ccb85ef9602d483aa89b8623dff9ffcdb0faab8f2ca9ab, size 218143)
images/image_panoptic_segmentation.jpg (oid sha256:59b1da30af12f76b691990363e0d221050a59cf53fc4a97e776bcb00228c6c2a, size 245864)
images/image_segmentation.jpg (oid sha256:af6fa61274e497ecc170de5adc4b8e7ac89eba2bc22a6aa119b08ec7adbe9459, size 146140)
images/image_semantic_segmentation.jpg (oid sha256:59b1da30af12f76b691990363e0d221050a59cf53fc4a97e776bcb00228c6c2a, size 245864)
(file name not shown in this extraction) (oid sha256:331ead75033fa2f01f6be72a2f8e34d581fcb593308067815d4bb136bb13b766, size 54390)
images/mog_face_detection.jpg (oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9, size 87228)
images/mtcnn_face_detection.jpg (oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9, size 87228)
images/multimodal_similarity.jpg (oid sha256:1f24abbba43782d733dedbb0b4f416635af50263862e5632963ac9263e430555, size 88542)
images/product_segmentation.jpg (oid sha256:a16038f7809127eb3e03cbae049592d193707e095309daca78f7d108d67fe4ec, size 108357)
images/retina_face_detection.jpg (oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9, size 87228)
images/shop_segmentation.jpg (oid sha256:f5ecc371c8b0ca09d0e11df89bc549000937eafc451929586426fe657ade25a0, size 238607)
images/text_driven_segmentation.jpg (oid sha256:2c7d2f279e3b317f1d0de18410a0585e122166fa2464c17b88a0c813f6c58bd4, size 67861)
images/ulfd_face_detection.jpg (oid sha256:176c824d99af119b36f743d3d90b44529167b0e4fc6db276da60fa140ee3f4a9, size 87228)
regression/fill_mask_bert_zh.bin (oid sha256:541183383bb06aa3ca2c44a68cd51c1be5e3e984a1dee2c58092b9552660f3ce, size 61883)
regression/fill_mask_sbert_en.bin (oid sha256:8f0afcd9d2aa5ac9569114203bd9db4f1a520c903a88fd4854370cdde0e7eab7, size 119940)
regression/fill_mask_sbert_zh.bin (oid sha256:4fd6fa6b23c2fdaf876606a767d9b64b1924e1acddfc06ac42db73ba86083280, size 119940)
regression/fill_mask_veco_en.bin (oid sha256:4d37672a0e299a08d2daf5c7fc29bfce96bb15701fe5e5e68f068861ac2ee705, size 119619)
regression/fill_mask_veco_zh.bin (oid sha256:c692e0753cfe349e520511427727a8252f141fa10e85f9a61562845e8d731f9a, size 119619)
regression/sbert_nli.bin (oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62, size 62231)
regression/sbert_sen_sim.bin (oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a, size 62235)
regression/sbert_ws_en.bin (oid sha256:9103ce2bc89212f67fb49ce70783b7667e376900d0f70fb8f5c4432eb74bc572, size 60801)
regression/sbert_ws_zh.bin (oid sha256:2d4dee34c7e83b77db04fb2f0d1200bfd37c7c24954c58e185da5cb96445975c, size 60801)
regression/sbert_zero_shot.bin (oid sha256:9e3ecc2c30d382641d561f84849b199c12bb1a9418e8099a191153f6f5275a85, size 61589)
videos/Walking.54138969.mp4 (oid sha256:7663f9a32ea57086bf66c4b9e9ebe0fd418986c67716c7be02ca917e72ddc0ba, size 8155895)
videos/action_detection_test_video.mp4 (oid sha256:0b7c3bc7c82ea5fee9d83130041df01046d89143ff77058b04577455ff6fdc92, size 3191059)
videos/mask_dir/mask_00000_00320.png (oid sha256:b158f6029d9763d7f84042f7c5835f398c688fdbb6b3f4fe6431101d4118c66c, size 2766)
videos/mask_dir/mask_00321_00633.png (oid sha256:0dcf46b93077e2229ab69cd6ddb80e2689546c575ee538bb2033fee1124ef3e3, size 2761)
videos/movie_scene_segmentation_test_video.mp4 (oid sha256:03002807dc2aa180c3ae104e764c7a4d6c421d186a5d552f97d338467ae6c443, size 12722029)
videos/video_inpainting_test.mp4 (oid sha256:9c9870df5a86acaaec67063183dace795479cd0f05296f13058995f475149c56, size 2957783)
@@ -34,7 +34,8 @@ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${a
     cp /tmp/resources/conda.tuna ~/.condarc && \
     source /root/.bashrc && \
     conda install --yes python==${PYTHON_VERSION} && \
-    pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
+    pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
+    pip config set install.trusted-host pypi.tuna.tsinghua.edu.cn

ARG USE_GPU=True
@@ -42,15 +43,15 @@ ARG USE_GPU=True
 ARG TORCH_VERSION=1.12.0
 ARG CUDATOOLKIT_VERSION=11.3
 RUN if [ "$USE_GPU" = "True" ] ; then \
-        conda install --yes pytorch==$TORCH_VERSION torchvision torchaudio cudatoolkit=$CUDATOOLKIT_VERSION -c pytorch && conda clean --yes --all; \
+        pip install --no-cache-dir torch==$TORCH_VERSION torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113; \
     else \
-        conda install pytorch==$TORCH_VERSION torchvision torchaudio cpuonly -c pytorch; \
+        pip install --no-cache-dir torch==$TORCH_VERSION torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu; \
     fi

 # install tensorflow
 ARG TENSORFLOW_VERSION=1.15.5
 RUN if [ "$USE_GPU" = "True" ] ; then \
-        pip install --no-cache-dir --use-deprecated=legacy-resolver tensorflow==$TENSORFLOW_VERSION -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
+        pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
     else \
         pip install --no-cache-dir tensorflow==$TENSORFLOW_VERSION; \
     fi
@@ -64,7 +65,7 @@ RUN if [ "$USE_GPU" = "True" ] ; then \
 # install modelscope
 COPY requirements /var/modelscope
 RUN pip install --no-cache-dir --upgrade pip && \
-    pip install --no-cache-dir -r /var/modelscope/runtime.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
+    pip install --no-cache-dir -r /var/modelscope/framework.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
     pip install --no-cache-dir -r /var/modelscope/audio.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
     pip install --no-cache-dir -r /var/modelscope/cv.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
     pip install --no-cache-dir -r /var/modelscope/multi-modal.txt -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html && \
@@ -75,7 +76,7 @@ RUN pip install --no-cache-dir --upgrade pip && \
 ENV SHELL=/bin/bash

 # install special package
-RUN pip install --no-cache-dir mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 numpy==1.18.5 datasets==2.1.0
+RUN pip install --no-cache-dir mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 datasets==2.1.0 numpy==1.18.5 ipykernel fairseq

 RUN if [ "$USE_GPU" = "True" ] ; then \
         pip install --no-cache-dir dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html; \
modelscope/exporters/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .base import Exporter
from .builder import build_exporter
from .nlp import SbertForSequenceClassificationExporter
from .torch_model_exporter import TorchModelExporter
modelscope/exporters/base.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from abc import ABC, abstractmethod

from modelscope.models import Model
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import ModelFile
from .builder import build_exporter


class Exporter(ABC):
    """Exporter base class to output model to onnx, torch_script, graphdef, etc.
    """

    def __init__(self):
        self.model = None

    @classmethod
    def from_model(cls, model: Model, **kwargs):
        """Build the Exporter instance.

        @param model: A model instance. It will be used to output the generated file,
            and the configuration.json in its model_dir field will be used to create the exporter instance.
        @param kwargs: Extra kwargs used to create the Exporter instance.
        @return: The Exporter instance.
        """
        cfg = Config.from_file(
            os.path.join(model.model_dir, ModelFile.CONFIGURATION))
        task_name = cfg.task
        model_cfg = cfg.model
        if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
            model_cfg.type = model_cfg.model_type
        export_cfg = ConfigDict({'type': model_cfg.type})
        if hasattr(cfg, 'export'):
            export_cfg.update(cfg.export)
        exporter = build_exporter(export_cfg, task_name, kwargs)
        exporter.model = model
        return exporter

    @abstractmethod
    def export_onnx(self, outputs: str, opset=11, **kwargs):
        """Export the model as onnx format files.

        In some cases several files may be generated,
        so please return a dict which contains the generated name with the file path.

        @param opset: The version of the ONNX operator set to use.
        @param outputs: The output dir.
        @param kwargs: In this default implementation,
            kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape).
        @return: A dict containing the model name with the model file path.
        """
        pass
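A minimal usage sketch of the Exporter API above; the model id is hypothetical, and it assumes the output directory already exists:

# Sketch: build an exporter from a loaded model and write ONNX files.
from modelscope.exporters import Exporter
from modelscope.models import Model

model = Model.from_pretrained('damo/some-structbert-model')  # hypothetical id
exporter = Exporter.from_model(model)  # reads configuration.json in model.model_dir
result = exporter.export_onnx(outputs='/tmp/export', opset=11)
print(result)  # a dict mapping a name to the generated file path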
modelscope/exporters/builder.py (new file, 21 lines)
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from modelscope.utils.config import ConfigDict
from modelscope.utils.registry import Registry, build_from_cfg

EXPORTERS = Registry('exporters')


def build_exporter(cfg: ConfigDict,
                   task_name: str = None,
                   default_args: dict = None):
    """Build exporter by the given model config dict.

    Args:
        cfg (:obj:`ConfigDict`): config dict for exporter object.
        task_name (str, optional): task name, refer to
            :obj:`Tasks` for more details
        default_args (dict, optional): Default initialization arguments.
    """
    return build_from_cfg(
        cfg, EXPORTERS, group_key=task_name, default_args=default_args)
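The registry above lets each exporter self-register under a task group, mirroring the register_module calls used elsewhere in this diff. A sketch with hypothetical task and model names:

# Sketch: plugging a custom exporter into the EXPORTERS registry.
from modelscope.exporters.builder import EXPORTERS, build_exporter
from modelscope.utils.config import ConfigDict

@EXPORTERS.register_module('my-task', module_name='my-model')  # hypothetical names
class MyExporter:
    def __init__(self, **kwargs):
        self.options = kwargs

cfg = ConfigDict({'type': 'my-model'})
exporter = build_exporter(cfg, task_name='my-task', default_args={'opset': 11})
assert isinstance(exporter, MyExporter)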
modelscope/exporters/nlp/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .sbert_for_sequence_classification_exporter import \
    SbertForSequenceClassificationExporter
modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import os
from collections import OrderedDict
from typing import Any, Dict, Mapping, Tuple

from torch.utils.data.dataloader import default_collate

from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.preprocessors import Preprocessor, build_preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import ModeKeys, Tasks


@EXPORTERS.register_module(
    Tasks.sentence_similarity, module_name=Models.structbert)
@EXPORTERS.register_module(
    Tasks.sentiment_classification, module_name=Models.structbert)
@EXPORTERS.register_module(Tasks.nli, module_name=Models.structbert)
@EXPORTERS.register_module(
    Tasks.zero_shot_classification, module_name=Models.structbert)
class SbertForSequenceClassificationExporter(TorchModelExporter):

    def generate_dummy_inputs(self,
                              shape: Tuple = None,
                              **kwargs) -> Dict[str, Any]:
        """Generate dummy inputs for model export to onnx or other formats by tracing.

        @param shape: A tuple of input shape which should have at most two dimensions.
            shape = (1, ): batch_size=1, sequence_length will be taken from the preprocessor.
            shape = (8, 128): batch_size=8, sequence_length=128, which will override the
            config of the preprocessor.
        @return: Dummy inputs.
        """
        cfg = Config.from_file(
            os.path.join(self.model.model_dir, 'configuration.json'))
        field_name = Tasks.find_field_by_task(cfg.task)
        if 'type' not in cfg.preprocessor and 'val' in cfg.preprocessor:
            cfg = cfg.preprocessor.val
        else:
            cfg = cfg.preprocessor

        batch_size = 1
        sequence_length = {}
        if shape is not None:
            if len(shape) == 1:
                batch_size = shape[0]
            elif len(shape) == 2:
                batch_size, max_length = shape
                sequence_length = {'sequence_length': max_length}

        cfg.update({
            'model_dir': self.model.model_dir,
            'mode': ModeKeys.TRAIN,
            **sequence_length
        })
        preprocessor: Preprocessor = build_preprocessor(cfg, field_name)
        if preprocessor.pair:
            first_sequence = preprocessor.tokenizer.unk_token
            second_sequence = preprocessor.tokenizer.unk_token
        else:
            first_sequence = preprocessor.tokenizer.unk_token
            second_sequence = None

        batched = []
        for _ in range(batch_size):
            batched.append(preprocessor((first_sequence, second_sequence)))
        return default_collate(batched)

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        dynamic_axis = {0: 'batch', 1: 'sequence'}
        return OrderedDict([
            ('input_ids', dynamic_axis),
            ('attention_mask', dynamic_axis),
            ('token_type_ids', dynamic_axis),
        ])

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict({'logits': {0: 'batch'}})
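The inputs/outputs properties above feed the dynamic_axes argument of torch.onnx.export; the base class below merges them with itertools.chain. A self-contained sketch of the dict that merge produces:

# Sketch: the dynamic_axes mapping built from the properties above.
from collections import OrderedDict
from itertools import chain

dynamic_axis = {0: 'batch', 1: 'sequence'}
inputs = OrderedDict([('input_ids', dynamic_axis),
                      ('attention_mask', dynamic_axis),
                      ('token_type_ids', dynamic_axis)])
outputs = OrderedDict({'logits': {0: 'batch'}})

dynamic_axes = {name: axes for name, axes in chain(inputs.items(), outputs.items())}
print(dynamic_axes)
# {'input_ids': {0: 'batch', 1: 'sequence'}, 'attention_mask': {0: 'batch', 1: 'sequence'},
#  'token_type_ids': {0: 'batch', 1: 'sequence'}, 'logits': {0: 'batch'}}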
modelscope/exporters/torch_model_exporter.py (new file, 247 lines)
@@ -0,0 +1,247 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from contextlib import contextmanager
from itertools import chain
from typing import Any, Dict, Mapping

import torch
from torch import nn
from torch.onnx import export as onnx_export
from torch.onnx.utils import _decide_input_format

from modelscope.models import TorchModel
from modelscope.pipelines.base import collate_fn
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.regress_test_utils import compare_arguments_nested
from modelscope.utils.tensor_utils import torch_nested_numpify
from .base import Exporter

logger = get_logger(__name__)


class TorchModelExporter(Exporter):
    """The torch base class of exporter.

    This class provides the default implementations for exporting onnx and torch script.
    Each specific model may implement its own exporter by overriding export_onnx/export_torch_script,
    and provide implementations for the generate_dummy_inputs/inputs/outputs methods.
    """

    def export_onnx(self, outputs: str, opset=11, **kwargs):
        """Export the model as onnx format files.

        In some cases several files may be generated,
        so please return a dict which contains the generated name with the file path.

        @param opset: The version of the ONNX operator set to use.
        @param outputs: The output dir.
        @param kwargs: In this default implementation,
            you can pass the arguments needed by _torch_export_onnx; other unrecognized args
            will be carried to generate_dummy_inputs as extra arguments (such as input shape).
        @return: A dict containing the model key - model file path pairs.
        """
        model = self.model
        if not isinstance(model, nn.Module) and hasattr(model, 'model'):
            model = model.model
        onnx_file = os.path.join(outputs, ModelFile.ONNX_MODEL_FILE)
        self._torch_export_onnx(model, onnx_file, opset=opset, **kwargs)
        return {'model': onnx_file}

    def export_torch_script(self, outputs: str, **kwargs):
        """Export the model as torch script files.

        In some cases several files may be generated,
        so please return a dict which contains the generated name with the file path.

        @param outputs: The output dir.
        @param kwargs: In this default implementation,
            you can pass the arguments needed by _torch_export_torch_script; other unrecognized args
            will be carried to generate_dummy_inputs as extra arguments (like input shape).
        @return: A dict containing the model name with the model file path.
        """
        model = self.model
        if not isinstance(model, nn.Module) and hasattr(model, 'model'):
            model = model.model
        ts_file = os.path.join(outputs, ModelFile.TS_MODEL_FILE)
        # generate ts by tracing
        self._torch_export_torch_script(model, ts_file, **kwargs)
        return {'model': ts_file}

    def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
        """Generate dummy inputs for model export to onnx or other formats by tracing.
        @return: Dummy inputs.
        """
        return None

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return an ordered dict containing the model's input argument names with their dynamic axes.

        For information about dynamic axes, please check the dynamic_axes argument of torch.onnx.export.
        """
        return None

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return an ordered dict containing the model's output argument names with their dynamic axes.

        For information about dynamic axes, please check the dynamic_axes argument of torch.onnx.export.
        """
        return None

    def _torch_export_onnx(self,
                           model: nn.Module,
                           output: str,
                           opset: int = 11,
                           device: str = 'cpu',
                           validation: bool = True,
                           rtol: float = None,
                           atol: float = None,
                           **kwargs):
        """Export the model to an onnx format file.

        @param model: A torch.nn.Module instance to export.
        @param output: The output file.
        @param opset: The version of the ONNX operator set to use.
        @param device: The device used to forward.
        @param validation: Whether to validate the exported file.
        @param rtol: The rtol used to regress the outputs.
        @param atol: The atol used to regress the outputs.
        """
        dummy_inputs = self.generate_dummy_inputs(**kwargs)
        inputs = self.inputs
        outputs = self.outputs
        if dummy_inputs is None or inputs is None or outputs is None:
            raise NotImplementedError(
                'Model property dummy_inputs,inputs,outputs must be set.')

        with torch.no_grad():
            model.eval()
            device = torch.device(device)
            model.to(device)
            dummy_inputs = collate_fn(dummy_inputs, device)

            if isinstance(dummy_inputs, Mapping):
                dummy_inputs = dict(dummy_inputs)
            onnx_outputs = list(self.outputs.keys())

            with replace_call():
                onnx_export(
                    model,
                    (dummy_inputs, ),
                    f=output,
                    input_names=list(inputs.keys()),
                    output_names=onnx_outputs,
                    dynamic_axes={
                        name: axes
                        for name, axes in chain(inputs.items(),
                                                outputs.items())
                    },
                    do_constant_folding=True,
                    opset_version=opset,
                )

        if validation:
            try:
                import onnx
                import onnxruntime as ort
            except ImportError:
                logger.warn(
                    'Cannot validate the exported onnx file, because '
                    'the installation of onnx or onnxruntime cannot be found')
                return
            onnx_model = onnx.load(output)
            onnx.checker.check_model(onnx_model)
            ort_session = ort.InferenceSession(output)
            with torch.no_grad():
                model.eval()
                outputs_origin = model.forward(
                    *_decide_input_format(model, dummy_inputs))
            if isinstance(outputs_origin, Mapping):
                outputs_origin = torch_nested_numpify(
                    list(outputs_origin.values()))
            outputs = ort_session.run(
                onnx_outputs,
                torch_nested_numpify(dummy_inputs),
            )

            tols = {}
            if rtol is not None:
                tols['rtol'] = rtol
            if atol is not None:
                tols['atol'] = atol
            if not compare_arguments_nested('Onnx model output match failed',
                                            outputs, outputs_origin, **tols):
                raise RuntimeError(
                    'export onnx failed because of validation error.')

    def _torch_export_torch_script(self,
                                   model: nn.Module,
                                   output: str,
                                   device: str = 'cpu',
                                   validation: bool = True,
                                   rtol: float = None,
                                   atol: float = None,
                                   **kwargs):
        """Export the model to a torch script file.

        @param model: A torch.nn.Module instance to export.
        @param output: The output file.
        @param device: The device used to forward.
        @param validation: Whether to validate the exported file.
        @param rtol: The rtol used to regress the outputs.
        @param atol: The atol used to regress the outputs.
        """
        model.eval()
        dummy_inputs = self.generate_dummy_inputs(**kwargs)
        if dummy_inputs is None:
            raise NotImplementedError(
                'Model property dummy_inputs must be set.')
        dummy_inputs = collate_fn(dummy_inputs, device)
        if isinstance(dummy_inputs, Mapping):
            dummy_inputs = tuple(dummy_inputs.values())
        with torch.no_grad():
            model.eval()
            with replace_call():
                traced_model = torch.jit.trace(
                    model, dummy_inputs, strict=False)
        torch.jit.save(traced_model, output)

        if validation:
            ts_model = torch.jit.load(output)
            with torch.no_grad():
                model.eval()
                ts_model.eval()
                outputs = ts_model.forward(*dummy_inputs)
                outputs = torch_nested_numpify(outputs)
                outputs_origin = model.forward(*dummy_inputs)
                outputs_origin = torch_nested_numpify(outputs_origin)
            tols = {}
            if rtol is not None:
                tols['rtol'] = rtol
            if atol is not None:
                tols['atol'] = atol
            if not compare_arguments_nested(
                    'Torch script model output match failed', outputs,
                    outputs_origin, **tols):
                raise RuntimeError(
                    'export torch script failed because of validation error.')


@contextmanager
def replace_call():
    """This function is used to recover the original call method.

    The Model class of modelscope overrides the call method. When exporting to onnx or torchscript, torch will
    prepare the parameters according to the prototype of the forward method and trace the call method, which causes
    problems. Here we revert the call method to the default implementation of torch.nn.Module, and change it
    back after the tracing is done.
    """
    TorchModel.call_origin, TorchModel.__call__ = TorchModel.__call__, TorchModel._call_impl
    yield
    TorchModel.__call__ = TorchModel.call_origin
    del TorchModel.call_origin
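replace_call above is an instance of a general swap-and-restore context manager: temporarily rebind an attribute so that tracing sees torch.nn.Module's stock __call__, then put the override back. A generic, self-contained sketch of that pattern (swap_attr and MyModel are illustrative, not modelscope API):

# Sketch: temporarily replace a class attribute, restoring it afterwards.
from contextlib import contextmanager

@contextmanager
def swap_attr(owner, name, replacement):
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        setattr(owner, name, original)  # restored even if the body raises

class MyModel:
    def __call__(self, x):
        return ('wrapped', x)

with swap_attr(MyModel, '__call__', lambda self, x: ('plain', x)):
    print(MyModel()(1))  # ('plain', 1) while the swap is active
print(MyModel()(1))      # ('wrapped', 1) afterwards

Unlike this sketch, replace_call restores the attribute only after a clean yield; the try/finally variant also survives exceptions raised during tracing.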
@@ -1,2 +1,4 @@
-from .file import File
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from .file import File, LocalStorage
 from .io import dump, dumps, load
@@ -240,7 +240,7 @@ class File(object):
     @staticmethod
     def _get_storage(uri):
         assert isinstance(uri,
-                          str), f'uri should be str type, buf got {type(uri)}'
+                          str), f'uri should be str type, but got {type(uri)}'

         if '://' not in uri:
             # local path
@@ -1,3 +1,5 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
 from .base import FormatHandler
 from .json import JsonHandler
 from .yaml import YamlHandler
@@ -1,5 +1,4 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import json
 import numpy as np

 from .base import FormatHandler
@@ -22,14 +21,16 @@ def set_default(obj):


 class JsonHandler(FormatHandler):
+    """Use jsonplus, serialization of Python types to JSON that "just works"."""

     def load(self, file):
-        return json.load(file)
+        import jsonplus
+        return jsonplus.loads(file.read())

     def dump(self, obj, file, **kwargs):
         kwargs.setdefault('default', set_default)
-        json.dump(obj, file, **kwargs)
+        file.write(self.dumps(obj, **kwargs))

     def dumps(self, obj, **kwargs):
+        import jsonplus
         kwargs.setdefault('default', set_default)
-        return json.dumps(obj, **kwargs)
+        return jsonplus.dumps(obj, **kwargs)
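The handler now delegates to jsonplus, which round-trips Python types that stdlib json rejects. A small sketch of the behavior relied on above, assuming the jsonplus package from PyPI is installed:

# Sketch: jsonplus keeps datetime/Decimal values through a dump/load cycle.
import datetime
from decimal import Decimal

import jsonplus

payload = {'when': datetime.datetime(2022, 8, 1, 12, 0), 'price': Decimal('9.99')}
text = jsonplus.dumps(payload)   # stdlib json.dumps(payload) raises TypeError
restored = jsonplus.loads(text)  # values come back typed, not as plain strings
assert restored['when'] == payload['when']
assert restored['price'] == payload['price']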
@@ -1,7 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 import os
+import pickle
 import shutil
 import subprocess
 from collections import defaultdict
 from http import HTTPStatus
 from http.cookiejar import CookieJar
@@ -16,8 +17,7 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
                                       API_RESPONSE_FIELD_MESSAGE,
                                       API_RESPONSE_FIELD_USERNAME,
                                       DEFAULT_CREDENTIALS_PATH)
-from modelscope.msdatasets.config import (DOWNLOADED_DATASETS_PATH,
-                                          HUB_DATASET_ENDPOINT)
+from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH
 from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
                                        DEFAULT_MODEL_REVISION,
                                        DatasetFormations, DatasetMetaFormats,
@@ -26,7 +26,8 @@ from modelscope.utils.logger import get_logger
 from .errors import (InvalidParameter, NotExistError, RequestError,
                      datahub_raise_on_error, handle_http_response, is_ok,
                      raise_on_error)
-from .utils.utils import get_endpoint, model_id_to_group_owner_name
+from .utils.utils import (get_dataset_hub_endpoint, get_endpoint,
+                          model_id_to_group_owner_name)

 logger = get_logger()
@@ -35,7 +36,8 @@ class HubApi:

     def __init__(self, endpoint=None, dataset_endpoint=None):
         self.endpoint = endpoint if endpoint is not None else get_endpoint()
-        self.dataset_endpoint = dataset_endpoint if dataset_endpoint is not None else HUB_DATASET_ENDPOINT
+        self.dataset_endpoint = dataset_endpoint if dataset_endpoint is not None else get_dataset_hub_endpoint(
+        )

     def login(
         self,
@@ -376,6 +378,27 @@ class HubApi:
             f'ststoken?Revision={revision}'
         return self.datahub_remote_call(datahub_url)

+    def get_dataset_access_config_session(
+            self,
+            cookies: CookieJar,
+            dataset_name: str,
+            namespace: str,
+            revision: Optional[str] = DEFAULT_DATASET_REVISION):
+
+        datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
+                      f'ststoken?Revision={revision}'
+
+        cookies = requests.utils.dict_from_cookiejar(cookies)
+        r = requests.get(url=datahub_url, cookies=cookies)
+        resp = r.json()
+        raise_on_error(resp)
+        return resp['Data']
+
     def on_dataset_download(self, dataset_name: str, namespace: str) -> None:
         url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase'
         r = requests.post(url)
         r.raise_for_status()

     @staticmethod
     def datahub_remote_call(url):
         r = requests.get(url)
@@ -383,6 +406,9 @@ class HubApi:
         datahub_raise_on_error(url, resp)
         return resp['Data']

+    def check_cookies_upload_data(self, use_cookies) -> CookieJar:
+        return self._check_cookie(use_cookies=use_cookies)
+

 class ModelScopeConfig:
     path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
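A usage sketch of the new session-based method above; the dataset name and namespace are hypothetical, and it assumes ModelScopeConfig.get_cookies() returns the cookies saved by an earlier HubApi.login():

# Sketch: fetch dataset access credentials with a logged-in session.
from modelscope.hub.api import HubApi, ModelScopeConfig

api = HubApi()
cookies = ModelScopeConfig.get_cookies()  # assumed saved-login helper
data = api.get_dataset_access_config_session(
    cookies=cookies,
    dataset_name='my_dataset',    # hypothetical
    namespace='my_namespace')     # hypothetical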
@@ -1,3 +1,5 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
 from pathlib import Path

 MODELSCOPE_URL_SCHEME = 'http://'
@@ -1,3 +1,5 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
 from http import HTTPStatus

 from requests.exceptions import HTTPError
@@ -49,8 +51,8 @@ def handle_http_response(response, logger, cookies, model_id):
     except HTTPError:
         if cookies is None:  # code in [403] and
             logger.error(
-                f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be private. \
-                Please login first.')
+                f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \
+                private. Please login first.')
         raise
@@ -60,7 +62,7 @@ def raise_on_error(rsp):

     Args:
         rsp (_type_): The server response
     """
-    if rsp['Code'] == HTTPStatus.OK and rsp['Success']:
+    if rsp['Code'] == HTTPStatus.OK:
         return True
     else:
         raise RequestError(rsp['Message'])
@@ -1,3 +1,5 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
 import copy
 import os
 import sys
@@ -1,3 +1,5 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
 import os
 import subprocess
 from typing import List
@@ -39,17 +41,28 @@ class GitCommandWrapper(metaclass=Singleton):
             subprocess.CompletedProcess: the command response
         """
         logger.debug(' '.join(args))
+        git_env = os.environ.copy()
+        git_env['GIT_TERMINAL_PROMPT'] = '0'
         response = subprocess.run(
             [self.git_path, *args],
             stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE)  # compatible for python3.6
+            stderr=subprocess.PIPE,
+            env=git_env,
+        )  # compatible for python3.6
         try:
             response.check_returncode()
             return response
         except subprocess.CalledProcessError as error:
-            raise GitError(
-                'stdout: %s, stderr: %s' %
-                (response.stdout.decode('utf8'), error.stderr.decode('utf8')))
+            if response.returncode == 1:
+                logger.info('Nothing to commit.')
+                return response
+            else:
+                logger.error(
+                    'There are error run git command, you may need to login first.'
+                )
+                raise GitError('stdout: %s, stderr: %s' %
+                               (response.stdout.decode('utf8'),
+                                error.stderr.decode('utf8')))

     def config_auth_token(self, repo_dir, auth_token):
         url = self.get_repo_remote_url(repo_dir)
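Exporting GIT_TERMINAL_PROMPT=0 above makes git fail immediately instead of hanging CI on an interactive credential prompt. A standalone sketch of the same subprocess pattern:

# Sketch: run git non-interactively; missing credentials become a fast failure.
import os
import subprocess

env = os.environ.copy()
env['GIT_TERMINAL_PROMPT'] = '0'
result = subprocess.run(['git', 'fetch', 'origin'],
                        stdout=subprocess.PIPE,
                        stderr=subprocess.PIPE,
                        env=env)
if result.returncode != 0:
    print('git failed:', result.stderr.decode('utf8'))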
@@ -1,8 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 import os
 from typing import Optional

 from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException
-from modelscope.utils.constant import DEFAULT_MODEL_REVISION
+from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
+                                       DEFAULT_MODEL_REVISION)
 from modelscope.utils.logger import get_logger
 from .api import ModelScopeConfig
 from .git import GitCommandWrapper
@@ -15,14 +18,12 @@ class Repository:
     """A local representation of the model git repository.
     """

-    def __init__(
-        self,
-        model_dir: str,
-        clone_from: str,
-        revision: Optional[str] = DEFAULT_MODEL_REVISION,
-        auth_token: Optional[str] = None,
-        git_path: Optional[str] = None,
-    ):
+    def __init__(self,
+                 model_dir: str,
+                 clone_from: str,
+                 revision: Optional[str] = DEFAULT_MODEL_REVISION,
+                 auth_token: Optional[str] = None,
+                 git_path: Optional[str] = None):
         """
         Instantiate a Repository object by cloning the remote ModelScopeHub repo
         Args:
@@ -41,6 +42,11 @@ class Repository:
         self.model_dir = model_dir
         self.model_base_dir = os.path.dirname(model_dir)
         self.model_repo_name = os.path.basename(model_dir)

+        if not revision:
+            err_msg = 'a non-default value of revision cannot be empty.'
+            raise InvalidParameter(err_msg)
+
         if auth_token:
             self.auth_token = auth_token
         else:
@@ -86,6 +92,7 @@ class Repository:
                  branch: Optional[str] = DEFAULT_MODEL_REVISION,
                  force: bool = False):
         """Push local files to remote; this method will do:
         git pull
         git add
         git commit
         git push
@@ -117,3 +124,118 @@ class Repository:
             url=url,
             local_branch=branch,
             remote_branch=branch)


class DatasetRepository:
    """A local representation of the dataset (metadata) git repository.
    """

    def __init__(self,
                 repo_work_dir: str,
                 dataset_id: str,
                 revision: Optional[str] = DEFAULT_DATASET_REVISION,
                 auth_token: Optional[str] = None,
                 git_path: Optional[str] = None):
        """
        Instantiate a DatasetRepository object by cloning the remote ModelScope dataset repo
        Args:
            repo_work_dir(`str`):
                The dataset repo root directory.
            dataset_id:
                dataset id in ModelScope from which to git clone
            revision(`Optional[str]`):
                revision of the dataset you want to clone from. Can be any of a branch, tag or commit hash
            auth_token(`Optional[str]`):
                token obtained when calling `HubApi.login()`. Usually you can safely ignore this parameter,
                as the token is already saved when you log in the first time; if None, the saved token is used.
            git_path(`Optional[str]`):
                The git command line path; if None, we use 'git'
        """
        self.dataset_id = dataset_id
        if not repo_work_dir or not isinstance(repo_work_dir, str):
            err_msg = 'dataset_work_dir must be provided!'
            raise InvalidParameter(err_msg)
        self.repo_work_dir = repo_work_dir.rstrip('/')
        if not self.repo_work_dir:
            err_msg = 'dataset_work_dir can not be root dir!'
            raise InvalidParameter(err_msg)
        self.repo_base_dir = os.path.dirname(self.repo_work_dir)
        self.repo_name = os.path.basename(self.repo_work_dir)

        if not revision:
            err_msg = 'a non-default value of revision cannot be empty.'
            raise InvalidParameter(err_msg)
        self.revision = revision

        if auth_token:
            self.auth_token = auth_token
        else:
            self.auth_token = ModelScopeConfig.get_token()

        self.git_wrapper = GitCommandWrapper(git_path)
        os.makedirs(self.repo_work_dir, exist_ok=True)
        self.repo_url = self._get_repo_url(dataset_id=dataset_id)

    def clone(self) -> str:
        # check local repo dir, directory not empty.
        if os.listdir(self.repo_work_dir):
            remote_url = self._get_remote_url()
            remote_url = self.git_wrapper.remove_token_from_url(remote_url)
            # no need to clone again
            if remote_url and remote_url == self.repo_url:
                return ''

        logger.info('Cloning repo from {} '.format(self.repo_url))
        self.git_wrapper.clone(self.repo_base_dir, self.auth_token,
                               self.repo_url, self.repo_name, self.revision)
        return self.repo_work_dir

    def push(self,
             commit_message: str,
             branch: Optional[str] = DEFAULT_DATASET_REVISION,
             force: bool = False):
        """Push local files to remote; this method will do:
        git pull
        git add
        git commit
        git push
        Args:
            commit_message (str): commit message
            branch (Optional[str], optional): which branch to push.
            force (Optional[bool]): whether to use forced-push.
        """
        if commit_message is None or not isinstance(commit_message, str):
            msg = 'commit_message must be provided!'
            raise InvalidParameter(msg)

        if not isinstance(force, bool):
            raise InvalidParameter('force must be bool')

        if not self.auth_token:
            raise NotLoginException('Must login to push, please login first.')

        self.git_wrapper.config_auth_token(self.repo_work_dir, self.auth_token)
        self.git_wrapper.add_user_info(self.repo_base_dir, self.repo_name)

        remote_url = self._get_remote_url()
        remote_url = self.git_wrapper.remove_token_from_url(remote_url)

        self.git_wrapper.pull(self.repo_work_dir)
        self.git_wrapper.add(self.repo_work_dir, all_files=True)
        self.git_wrapper.commit(self.repo_work_dir, commit_message)
        self.git_wrapper.push(
            repo_dir=self.repo_work_dir,
            token=self.auth_token,
            url=remote_url,
            local_branch=branch,
            remote_branch=branch)

    def _get_repo_url(self, dataset_id):
        return f'{get_endpoint()}/datasets/{dataset_id}.git'

    def _get_remote_url(self):
        try:
            remote = self.git_wrapper.get_repo_remote_url(self.repo_work_dir)
        except GitError:
            remote = None
        return remote
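A usage sketch of the DatasetRepository class added above; the dataset id and working directory are hypothetical, and push() requires a prior login:

# Sketch: clone a dataset repo, edit it locally, push the changes back.
from modelscope.hub.repository import DatasetRepository

repo = DatasetRepository(
    repo_work_dir='/tmp/my_dataset',          # hypothetical
    dataset_id='my_namespace/my_dataset')     # hypothetical
repo.clone()  # returns '' if the directory already tracks this repo
# ... modify files under /tmp/my_dataset ...
repo.push(commit_message='update dataset metadata')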
@@ -1,3 +1,5 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
 import os
 import tempfile
 from pathlib import Path
@@ -1,3 +1,5 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
 import hashlib
 import os
 import pickle
@@ -1,7 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 import hashlib
 import os
+from typing import Optional

-from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
+from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DATA_ENDPOINT,
+                                      DEFAULT_MODELSCOPE_DOMAIN,
                                       DEFAULT_MODELSCOPE_GROUP,
                                       MODEL_ID_SEPARATOR,
                                       MODELSCOPE_URL_SCHEME)
@@ -22,14 +26,16 @@ def model_id_to_group_owner_name(model_id):
     return group_or_owner, name


-def get_cache_dir():
+def get_cache_dir(model_id: Optional[str] = None):
     """
     cache dir precedence:
         function parameter > environment > ~/.cache/modelscope/hub
     """
     default_cache_dir = get_default_cache_dir()
-    return os.getenv('MODELSCOPE_CACHE', os.path.join(default_cache_dir,
-                                                      'hub'))
+    base_path = os.getenv('MODELSCOPE_CACHE',
+                          os.path.join(default_cache_dir, 'hub'))
+    return base_path if model_id is None else os.path.join(
+        base_path, model_id + '/')


def get_endpoint():
@@ -38,6 +44,11 @@ def get_endpoint():
     return MODELSCOPE_URL_SCHEME + modelscope_domain


+def get_dataset_hub_endpoint():
+    return os.environ.get('HUB_DATASET_ENDPOINT',
+                          DEFAULT_MODELSCOPE_DATA_ENDPOINT)
+
+
 def compute_hash(file_path):
     BUFFER_SIZE = 1024 * 64  # 64k buffer size
     sha256_hash = hashlib.sha256()
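Per get_cache_dir above, MODELSCOPE_CACHE overrides the ~/.cache default, and a model id is now appended when given. A small sketch of the resulting paths (the override value and model id are illustrative):

# Sketch: cache-dir resolution as implemented above.
import os

os.environ['MODELSCOPE_CACHE'] = '/data/ms_cache'  # optional override
from modelscope.hub.utils.utils import get_cache_dir

print(get_cache_dir())                   # /data/ms_cache
print(get_cache_dir('damo/some-model'))  # /data/ms_cache/damo/some-model/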
@@ -9,8 +9,11 @@ class Models(object):

     Model name should only contain model info but not task info.
     """
+    tinynas_detection = 'tinynas-detection'

     # vision models
     detection = 'detection'
+    realtime_object_detection = 'realtime-object-detection'
     scrfd = 'scrfd'
     classification_model = 'ClassificationModel'
     nafnet = 'nafnet'
@@ -19,23 +22,54 @@ class Models(object):
     gpen = 'gpen'
     product_retrieval_embedding = 'product-retrieval-embedding'
     body_2d_keypoints = 'body-2d-keypoints'
     body_3d_keypoints = 'body-3d-keypoints'
     crowd_counting = 'HRNetCrowdCounting'
     face_2d_keypoints = 'face-2d-keypoints'
     panoptic_segmentation = 'swinL-panoptic-segmentation'
     image_reid_person = 'passvitb'
     video_summarization = 'pgl-video-summarization'
     swinL_semantic_segmentation = 'swinL-semantic-segmentation'
     vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
     text_driven_segmentation = 'text-driven-segmentation'
     resnet50_bert = 'resnet50-bert'
     fer = 'fer'
     retinaface = 'retinaface'
     shop_segmentation = 'shop-segmentation'
     mogface = 'mogface'
     mtcnn = 'mtcnn'
     ulfd = 'ulfd'
     video_inpainting = 'video-inpainting'
     hand_static = 'hand-static'
     face_human_hand_detection = 'face-human-hand-detection'
     face_emotion = 'face-emotion'
     product_segmentation = 'product-segmentation'

     # EasyCV models
     yolox = 'YOLOX'
     segformer = 'Segformer'

     # nlp models
     bert = 'bert'
     palm = 'palm-v2'
     structbert = 'structbert'
     deberta_v2 = 'deberta_v2'
     veco = 'veco'
     translation = 'csanmt-translation'
     space_dst = 'space-dst'
     space_intent = 'space-intent'
     space_modeling = 'space-modeling'
     star = 'star'
     star3 = 'star3'
     tcrf = 'transformer-crf'
     transformer_softmax = 'transformer-softmax'
     lcrf = 'lstm-crf'
     gcnncrf = 'gcnn-crf'
     bart = 'bart'
     gpt3 = 'gpt3'
     plug = 'plug'
     bert_for_ds = 'bert-for-document-segmentation'
     ponet = 'ponet'
     T5 = 'T5'

     # audio models
     sambert_hifigan = 'sambert-hifigan'
@@ -50,21 +84,33 @@ class Models(object):
     gemm = 'gemm-generative-multi-modal'
     mplug = 'mplug'
     diffusion = 'diffusion-text-to-image-synthesis'
     multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis'
     team = 'team-multi-modal-similarity'
     video_clip = 'video-clip-multi-modal-embedding'


 class TaskModels(object):
     # nlp task
     text_classification = 'text-classification'
     token_classification = 'token-classification'
     information_extraction = 'information-extraction'
     fill_mask = 'fill-mask'
     feature_extraction = 'feature-extraction'


 class Heads(object):
     # nlp heads

     # text cls
     text_classification = 'text-classification'
-    # mlm
+    # fill mask
     fill_mask = 'fill-mask'
     bert_mlm = 'bert-mlm'
     # roberta mlm
     roberta_mlm = 'roberta-mlm'
     # token cls
     token_classification = 'token-classification'
     # extraction
     information_extraction = 'information-extraction'


 class Pipelines(object):
@@ -86,12 +132,23 @@ class Pipelines(object):
     animal_recognition = 'resnet101-animal-recognition'
     general_recognition = 'resnet101-general-recognition'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
     hicossl_video_embedding = 'hicossl-s3dg-video_embedding'
     body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
     body_3d_keypoints = 'canonical_body-3d-keypoints_video'
     hand_2d_keypoints = 'hrnetv2w18_hand-2d-keypoints_image'
     human_detection = 'resnet18-human-detection'
     object_detection = 'vit-object-detection'
     easycv_detection = 'easycv-detection'
     easycv_segmentation = 'easycv-segmentation'
     face_2d_keypoints = 'mobilenet_face-2d-keypoints_alignment'
     salient_detection = 'u2net-salient-detection'
     image_classification = 'image-classification'
     face_detection = 'resnet-face-detection-scrfd10gkps'
     ulfd_face_detection = 'manual-face-detection-ulfd'
     facial_expression_recognition = 'vgg19-facial-expression-recognition-fer'
     retina_face_detection = 'resnet50-face-detection-retinaface'
     mog_face_detection = 'resnet101-face-detection-cvpr22papermogface'
     mtcnn_face_detection = 'manual-face-detection-mtcnn'
     live_category = 'live-category'
     general_image_classification = 'vit-base_image-classification_ImageNet-labels'
     daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
@@ -102,6 +159,7 @@ class Pipelines(object):
     image_super_resolution = 'rrdb-image-super-resolution'
     face_image_generation = 'gan-face-image-generation'
     product_retrieval_embedding = 'resnet50-product-retrieval-embedding'
+    realtime_object_detection = 'cspnet_realtime-object-detection_yolox'
     face_recognition = 'ir101-face-recognition-cfglint'
     image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
     image2image_translation = 'image-to-image-translation'
@@ -112,20 +170,36 @@ class Pipelines(object):
     image_to_image_generation = 'image-to-image-generation'
     skin_retouching = 'unet-skin-retouching'
     tinynas_classification = 'tinynas-classification'
     tinynas_detection = 'tinynas-detection'
     crowd_counting = 'hrnet-crowd-counting'
     action_detection = 'ResNetC3D-action-detection'
     video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
     image_panoptic_segmentation = 'image-panoptic-segmentation'
     video_summarization = 'googlenet_pgl_video_summarization'
     image_semantic_segmentation = 'image-semantic-segmentation'
     image_reid_person = 'passvitb-image-reid-person'
     text_driven_segmentation = 'text-driven-segmentation'
     movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'
     shop_segmentation = 'shop-segmentation'
     video_inpainting = 'video-inpainting'
     pst_action_recognition = 'patchshift-action-recognition'
     hand_static = 'hand-static'
     face_human_hand_detection = 'face-human-hand-detection'
     face_emotion = 'face-emotion'
     product_segmentation = 'product-segmentation'

     # nlp tasks
     sentence_similarity = 'sentence-similarity'
     word_segmentation = 'word-segmentation'
     part_of_speech = 'part-of-speech'
     named_entity_recognition = 'named-entity-recognition'
     text_generation = 'text-generation'
     text2text_generation = 'text2text-generation'
     sentiment_analysis = 'sentiment-analysis'
     sentiment_classification = 'sentiment-classification'
     text_classification = 'text-classification'
     fill_mask = 'fill-mask'
     fill_mask_ponet = 'fill-mask-ponet'
     csanmt_translation = 'csanmt-translation'
     nli = 'nli'
     dialog_intent_prediction = 'dialog-intent-prediction'
@@ -133,7 +207,15 @@ class Pipelines(object):
|
||||
dialog_state_tracking = 'dialog-state-tracking'
|
||||
zero_shot_classification = 'zero-shot-classification'
|
||||
text_error_correction = 'text-error-correction'
|
||||
plug_generation = 'plug-generation'
|
||||
faq_question_answering = 'faq-question-answering'
|
||||
conversational_text_to_sql = 'conversational-text-to-sql'
|
||||
table_question_answering_pipeline = 'table-question-answering-pipeline'
|
||||
sentence_embedding = 'sentence-embedding'
|
||||
passage_ranking = 'passage-ranking'
|
||||
relation_extraction = 'relation-extraction'
|
||||
document_segmentation = 'document-segmentation'
|
||||
feature_extraction = 'feature-extraction'
|
||||
|
||||
# audio tasks
|
||||
sambert_hifigan_tts = 'sambert-hifigan-tts'
|
||||
@@ -150,8 +232,10 @@ class Pipelines(object):
|
||||
visual_question_answering = 'visual-question-answering'
|
||||
visual_grounding = 'visual-grounding'
|
||||
visual_entailment = 'visual-entailment'
|
||||
multi_modal_similarity = 'multi-modal-similarity'
|
||||
text_to_image_synthesis = 'text-to-image-synthesis'
|
||||
video_multi_modal_embedding = 'video-multi-modal-embedding'
|
||||
image_text_retrieval = 'image-text-retrieval'
|
||||
|
||||
|
||||
class Trainers(object):
|
||||
@@ -165,6 +249,7 @@ class Trainers(object):
|
||||
"""
|
||||
|
||||
default = 'trainer'
|
||||
easycv = 'easycv'
|
||||
|
||||
# multi-modal trainers
|
||||
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
|
||||
@@ -173,11 +258,18 @@ class Trainers(object):
|
||||
image_instance_segmentation = 'image-instance-segmentation'
|
||||
image_portrait_enhancement = 'image-portrait-enhancement'
|
||||
video_summarization = 'video-summarization'
|
||||
movie_scene_segmentation = 'movie-scene-segmentation'
|
||||
|
||||
# nlp trainers
|
||||
bert_sentiment_analysis = 'bert-sentiment-analysis'
|
||||
dialog_modeling_trainer = 'dialog-modeling-trainer'
|
||||
dialog_intent_trainer = 'dialog-intent-trainer'
|
||||
nlp_base_trainer = 'nlp-base-trainer'
|
||||
nlp_veco_trainer = 'nlp-veco-trainer'
|
||||
nlp_passage_ranking_trainer = 'nlp-passage-ranking-trainer'
|
||||
|
||||
# audio trainers
|
||||
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
|
||||
|
||||
|
||||
class Preprocessors(object):
|
||||
@@ -198,11 +290,14 @@ class Preprocessors(object):
|
||||
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
|
||||
image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor'
|
||||
video_summarization_preprocessor = 'video-summarization-preprocessor'
|
||||
movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'
|
||||
|
||||
# nlp preprocessor
|
||||
sen_sim_tokenizer = 'sen-sim-tokenizer'
|
||||
cross_encoder_tokenizer = 'cross-encoder-tokenizer'
|
||||
bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
|
||||
text_gen_tokenizer = 'text-gen-tokenizer'
|
||||
text2text_gen_preprocessor = 'text2text-gen-preprocessor'
|
||||
token_cls_tokenizer = 'token-cls-tokenizer'
|
||||
ner_tokenizer = 'ner-tokenizer'
|
||||
nli_tokenizer = 'nli-tokenizer'
|
||||
@@ -213,9 +308,18 @@ class Preprocessors(object):
|
||||
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
|
||||
zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'
|
||||
text_error_correction = 'text-error-correction'
|
||||
sentence_embedding = 'sentence-embedding'
|
||||
passage_ranking = 'passage-ranking'
|
||||
sequence_labeling_tokenizer = 'sequence-labeling-tokenizer'
|
||||
word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
|
||||
fill_mask = 'fill-mask'
|
||||
fill_mask_ponet = 'fill-mask-ponet'
|
||||
faq_question_answering_preprocessor = 'faq-question-answering-preprocessor'
|
||||
conversational_text_to_sql = 'conversational-text-to-sql'
|
||||
table_question_answering_preprocessor = 'table-question-answering-preprocessor'
|
||||
re_tokenizer = 're-tokenizer'
|
||||
document_segmentation = 'document-segmentation'
|
||||
feature_extraction = 'feature-extraction'
|
||||
|
||||
# audio preprocessor
|
||||
linear_aec_fbank = 'linear-aec-fbank'
|
||||
@@ -234,6 +338,7 @@ class Metrics(object):
|
||||
|
||||
# accuracy
|
||||
accuracy = 'accuracy'
|
||||
audio_noise_metric = 'audio-noise-metric'
|
||||
|
||||
# metrics for image denoise task
|
||||
image_denoise_metric = 'image-denoise-metric'
|
||||
@@ -251,6 +356,8 @@ class Metrics(object):
|
||||
# metrics for image-portrait-enhancement task
|
||||
image_portrait_enhancement_metric = 'image-portrait-enhancement-metric'
|
||||
video_summarization_metric = 'video-summarization-metric'
|
||||
# metric for movie-scene-segmentation task
|
||||
movie_scene_segmentation_metric = 'movie-scene-segmentation-metric'
|
||||
|
||||
|
||||
class Optimizers(object):
|
||||
@@ -300,3 +407,13 @@ class LR_Schedulers(object):
|
||||
LinearWarmup = 'LinearWarmup'
|
||||
ConstantWarmup = 'ConstantWarmup'
|
||||
ExponentialWarmup = 'ExponentialWarmup'
|
||||
|
||||
|
||||
class Datasets(object):
|
||||
""" Names for different datasets.
|
||||
"""
|
||||
ClsDataset = 'ClsDataset'
|
||||
Face2dKeypointsDataset = 'Face2dKeypointsDataset'
|
||||
SegDataset = 'SegDataset'
|
||||
DetDataset = 'DetDataset'
|
||||
DetImagesMixDataset = 'DetImagesMixDataset'
|
||||
|
||||
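The classes in this file are pure registries of string keys. A minimal sketch of how one of the Pipelines constants above is typically consumed; the exact pipeline() call signature is an assumption, not part of this diff:

from modelscope.metainfo import Pipelines
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# The string constant selects the registered pipeline implementation.
p = pipeline(
    task=Tasks.movie_scene_segmentation,
    pipeline_name=Pipelines.movie_scene_segmentation)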
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .audio_noise_metric import AudioNoiseMetric
    from .base import Metric
    from .builder import METRICS, build_metric, task_default_metrics
    from .image_color_enhance_metric import ImageColorEnhanceMetric
@@ -15,9 +16,11 @@ if TYPE_CHECKING:
    from .text_generation_metric import TextGenerationMetric
    from .token_classification_metric import TokenClassificationMetric
    from .video_summarization_metric import VideoSummarizationMetric
    from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric

else:
    _import_structure = {
        'audio_noise_metric': ['AudioNoiseMetric'],
        'base': ['Metric'],
        'builder': ['METRICS', 'build_metric', 'task_default_metrics'],
        'image_color_enhance_metric': ['ImageColorEnhanceMetric'],
@@ -30,6 +33,7 @@ else:
        'text_generation_metric': ['TextGenerationMetric'],
        'token_classification_metric': ['TokenClassificationMetric'],
        'video_summarization_metric': ['VideoSummarizationMetric'],
        'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'],
    }

    import sys

modelscope/metrics/audio_noise_metric.py (new file, +40 lines)
@@ -0,0 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Dict

from modelscope.metainfo import Metrics
from modelscope.metrics.base import Metric
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.utils.registry import default_group


@METRICS.register_module(
    group_key=default_group, module_name=Metrics.audio_noise_metric)
class AudioNoiseMetric(Metric):
    """
    The metric computation class for the acoustic noise suppression task.
    """

    def __init__(self):
        self.loss = []
        self.amp_loss = []
        self.phase_loss = []
        self.sisnr = []

    def add(self, outputs: Dict, inputs: Dict):
        self.loss.append(outputs['loss'].data.cpu())
        self.amp_loss.append(outputs['amp_loss'].data.cpu())
        self.phase_loss.append(outputs['phase_loss'].data.cpu())
        self.sisnr.append(outputs['sisnr'].data.cpu())

    def evaluate(self):
        avg_loss = sum(self.loss) / len(self.loss)
        avg_sisnr = sum(self.sisnr) / len(self.sisnr)
        avg_amp = sum(self.amp_loss) / len(self.amp_loss)
        avg_phase = sum(self.phase_loss) / len(self.phase_loss)
        total_loss = avg_loss + avg_amp + avg_phase + avg_sisnr
        return {
            'total_loss': total_loss.item(),
            'avg_sisnr': avg_sisnr.item(),
            MetricKeys.AVERAGE_LOSS: avg_loss.item()
        }
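A minimal usage sketch of the metric's add/evaluate cycle; the loss values below are made up, each entry being a scalar tensor as produced by the FRCRN loss further down in this diff:

import torch

from modelscope.metrics.audio_noise_metric import AudioNoiseMetric

metric = AudioNoiseMetric()
for _ in range(2):  # one add() call per evaluation batch
    outputs = {
        'loss': torch.tensor(0.5),
        'amp_loss': torch.tensor(0.2),
        'phase_loss': torch.tensor(0.1),
        'sisnr': torch.tensor(-9.0),
    }
    metric.add(outputs, inputs={})
print(metric.evaluate())
# -> {'total_loss': -8.2, 'avg_sisnr': -9.0, 'avg_loss': 0.5}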
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict, Mapping, Union

from modelscope.metainfo import Metrics
from modelscope.utils.config import ConfigDict
@@ -15,7 +16,12 @@ class MetricKeys(object):
    RECALL = 'recall'
    PSNR = 'psnr'
    SSIM = 'ssim'
    AVERAGE_LOSS = 'avg_loss'
    FScore = 'fscore'
    BLEU_1 = 'bleu-1'
    BLEU_4 = 'bleu-4'
    ROUGE_1 = 'rouge-1'
    ROUGE_L = 'rouge-l'


task_default_metrics = {
@@ -30,19 +36,25 @@ task_default_metrics = {
    Tasks.image_portrait_enhancement:
    [Metrics.image_portrait_enhancement_metric],
    Tasks.video_summarization: [Metrics.video_summarization_metric],
    Tasks.image_captioning: [Metrics.text_gen_metric],
    Tasks.visual_question_answering: [Metrics.text_gen_metric],
    Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric],
}


def build_metric(metric_name: str,
def build_metric(metric_cfg: Union[str, Dict],
                 field: str = default_group,
                 default_args: dict = None):
    """ Build a metric given its name or config and a field.

    Args:
        metric_name (:obj:`str`): The metric name.
        metric_cfg (str | dict): The metric name or metric config dict.
        field (str, optional): The field of this metric, default value: 'default' for all fields.
        default_args (dict, optional): Default initialization arguments.
    """
    cfg = ConfigDict({'type': metric_name})
    if isinstance(metric_cfg, Mapping):
        assert 'type' in metric_cfg
    else:
        metric_cfg = ConfigDict({'type': metric_cfg})
    return build_from_cfg(
        cfg, METRICS, group_key=field, default_args=default_args)
        metric_cfg, METRICS, group_key=field, default_args=default_args)
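With this change build_metric accepts either a bare registry name or a full config dict; a short sketch (both names come from the Metrics class above):

from modelscope.metainfo import Metrics
from modelscope.metrics.builder import build_metric

# String form: wrapped into ConfigDict({'type': name}) internally.
m1 = build_metric(Metrics.movie_scene_segmentation_metric)

# Dict form: must carry a 'type' key; any extra keys are forwarded to the
# metric's constructor by the registry.
m2 = build_metric({'type': Metrics.audio_noise_metric})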
@@ -1,3 +1,5 @@
# Part of the implementation is borrowed and modified from MMDetection, publicly available at
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/coco.py
import os.path as osp
import tempfile
from collections import OrderedDict

modelscope/metrics/movie_scene_segmentation_metric.py (new file, +54 lines)
@@ -0,0 +1,54 @@
# The implementation here is modified based on BaSSL,
# originally Apache 2.0 License and publicly available at https://github.com/kakaobrain/bassl
from typing import Dict

import numpy as np

from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group
from modelscope.utils.tensor_utils import (torch_nested_detach,
                                           torch_nested_numpify)
from .base import Metric
from .builder import METRICS, MetricKeys


@METRICS.register_module(
    group_key=default_group,
    module_name=Metrics.movie_scene_segmentation_metric)
class MovieSceneSegmentationMetric(Metric):
    """The metric computation class for the movie scene segmentation task.
    """

    def __init__(self):
        self.preds = []
        self.labels = []
        self.eps = 1e-5

    def add(self, outputs: Dict, inputs: Dict):
        preds = outputs['pred']
        labels = inputs['label']
        self.preds.extend(preds)
        self.labels.extend(labels)

    def evaluate(self):
        gts = np.array(torch_nested_numpify(torch_nested_detach(self.labels)))
        prob = np.array(torch_nested_numpify(torch_nested_detach(self.preds)))

        gt_one = gts == 1
        gt_zero = gts == 0
        pred_one = prob == 1
        pred_zero = prob == 0

        tp = (gt_one * pred_one).sum()
        fp = (gt_zero * pred_one).sum()
        fn = (gt_one * pred_zero).sum()

        precision = 100.0 * tp / (tp + fp + self.eps)
        recall = 100.0 * tp / (tp + fn + self.eps)
        f1 = 2 * precision * recall / (precision + recall)

        return {
            MetricKeys.F1: f1,
            MetricKeys.RECALL: recall,
            MetricKeys.PRECISION: precision
        }
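The F1 here is computed from hard 0/1 predictions rather than probabilities; a quick worked check with made-up labels:

import numpy as np

gts = np.array([1, 1, 0, 0])
prob = np.array([1, 0, 1, 0])  # hard predictions, despite the name
tp = ((gts == 1) & (prob == 1)).sum()  # 1
fp = ((gts == 0) & (prob == 1)).sum()  # 1
fn = ((gts == 1) & (prob == 0)).sum()  # 1
precision = 100.0 * tp / (tp + fp + 1e-5)  # ~50.0
recall = 100.0 * tp / (tp + fn + 1e-5)     # ~50.0
f1 = 2 * precision * recall / (precision + recall)  # ~50.0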
@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Dict

import numpy as np
@@ -14,9 +16,9 @@ from .builder import METRICS, MetricKeys
@METRICS.register_module(
    group_key=default_group, module_name=Metrics.seq_cls_metric)
class SequenceClassificationMetric(Metric):
    """The metric computation class for sequence classification classes.
    """The metric computation class for sequence classification tasks.

    This metric class calculates accuracy for the whole input batches.
    This metric class calculates accuracy of the whole input batches.
    """

    def __init__(self, *args, **kwargs):

@@ -1,9 +1,14 @@
from typing import Dict
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Dict, Iterable, List

from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge

from modelscope.metainfo import Metrics
from modelscope.metrics.base import Metric
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys


@METRICS.register_module(
@@ -15,20 +20,49 @@ class TextGenerationMetric(Metric):
    """

    def __init__(self):
        self.preds = []
        self.tgts = []
        from rouge_score import rouge_scorer
        self.scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
        self.preds: List[str] = []
        self.tgts: List[str] = []
        self.rouge = Rouge()

    def add(self, outputs: Dict, inputs: Dict):
    @staticmethod
    def is_chinese_char(char: str):
        # the length of char must be 1
        return '\u4e00' <= char <= '\u9fa5'

    # add a space around each Chinese char
    def rebuild_str(self, string: str):
        return ' '.join(''.join([
            f' {char} ' if self.is_chinese_char(char) else char
            for char in string
        ]).split())

    def add(self, outputs: Dict[str, List[str]], inputs: Dict = None):
        ground_truths = outputs['tgts']
        eval_results = outputs['preds']
        self.preds.extend(eval_results)
        self.tgts.extend(ground_truths)
        for truth in ground_truths:
            self.tgts.append(self.rebuild_str(truth))
        for result in eval_results:
            self.preds.append(self.rebuild_str(result))

    def evaluate(self):
        scores = [
            self.scorer.score(pred, tgt)['rougeL'].fmeasure
            for pred, tgt in zip(self.preds, self.tgts)
        ]
        return {MetricKeys.F1: sum(scores) / len(scores)}

        def mean(iter: Iterable) -> float:
            return sum(iter) / len(self.preds)

        rouge_scores = self.rouge.get_scores(hyps=self.preds, refs=self.tgts)
        rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores))
        rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores))
        pred_split = tuple(pred.split(' ') for pred in self.preds)
        tgt_split = tuple(tgt.split(' ') for tgt in self.tgts)
        bleu_1 = mean(
            sentence_bleu([tgt], pred, weights=(1, 0, 0, 0))
            for pred, tgt in zip(pred_split, tgt_split))
        bleu_4 = mean(
            sentence_bleu([tgt], pred)
            for pred, tgt in zip(pred_split, tgt_split))
        return {
            MetricKeys.ROUGE_1: rouge_1,
            MetricKeys.ROUGE_L: rouge_l,
            MetricKeys.BLEU_1: bleu_1,
            MetricKeys.BLEU_4: bleu_4
        }
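The rebuild_str step makes word-level ROUGE/BLEU behave character-wise for Chinese by spacing out CJK characters while leaving other tokens intact; a quick sketch of the effect:

def is_chinese_char(char: str) -> bool:
    return '\u4e00' <= char <= '\u9fa5'

def rebuild_str(string: str) -> str:
    return ' '.join(''.join(
        f' {c} ' if is_chinese_char(c) else c for c in string).split())

print(rebuild_str('我喜欢NLP'))  # -> '我 喜 欢 NLP'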
@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import importlib
from typing import Dict, List, Optional, Union


@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch.nn as nn

from .layer_base import LayerBase

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np
import torch as th
import torch.nn as nn

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np
import torch as th
import torch.nn as nn

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import abc
import re


@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np
import torch as th
import torch.nn as nn

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn.functional as F


@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import math

import torch

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn as nn
import torch.nn.functional as F

@@ -4,11 +4,11 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .frcrn import FRCRNModel
    from .frcrn import FRCRNDecorator

else:
    _import_structure = {
        'frcrn': ['FRCRNModel'],
        'frcrn': ['FRCRNDecorator'],
    }

    import sys

@@ -1,3 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
#
# The implementation of class ComplexConv2d, ComplexConvTranspose2d and
# ComplexBatchNorm2d here is modified based on Jongho Choi(sweetcocoa@snu.ac.kr
# / Seoul National Univ., ESTsoft ) and publicly available at
# https://github.com/sweetcocoa/DeepComplexUNetPyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Dict

@@ -14,54 +15,10 @@ from .conv_stft import ConviSTFT, ConvSTFT
from .unet import UNet


class FTB(nn.Module):

    def __init__(self, input_dim=257, in_channel=9, r_channel=5):

        super(FTB, self).__init__()
        self.in_channel = in_channel
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(r_channel), nn.ReLU())

        self.conv1d = nn.Sequential(
            nn.Conv1d(
                r_channel * input_dim, in_channel, kernel_size=9, padding=4),
            nn.BatchNorm1d(in_channel), nn.ReLU())
        self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(in_channel), nn.ReLU())

    def forward(self, inputs):
        '''
        inputs should be [Batch, Ca, Dim, Time]
        '''
        # T-F attention
        conv1_out = self.conv1(inputs)
        B, C, D, T = conv1_out.size()
        reshape1_out = torch.reshape(conv1_out, [B, C * D, T])
        conv1d_out = self.conv1d(reshape1_out)
        conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T])

        # now is also [B,C,D,T]
        att_out = conv1d_out * inputs

        # transpose to [B,C,T,D]
        att_out = torch.transpose(att_out, 2, 3)
        freqfc_out = self.freq_fc(att_out)
        att_out = torch.transpose(freqfc_out, 2, 3)

        cat_out = torch.cat([att_out, inputs], 1)
        outputs = self.conv2(cat_out)
        return outputs


@MODELS.register_module(
    Tasks.acoustic_noise_suppression,
    module_name=Models.speech_frcrn_ans_cirm_16k)
class FRCRNModel(TorchModel):
class FRCRNDecorator(TorchModel):
    r""" A decorator of FRCRN for integrating into the modelscope framework """

    def __init__(self, model_dir: str, *args, **kwargs):
@@ -71,32 +28,42 @@ class FRCRNModel(TorchModel):
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        kwargs.pop('device')
        self.model = FRCRN(*args, **kwargs)
        model_bin_file = os.path.join(model_dir,
                                      ModelFile.TORCH_MODEL_BIN_FILE)
        if os.path.exists(model_bin_file):
            checkpoint = torch.load(model_bin_file)
            self.model.load_state_dict(checkpoint, strict=False)
            checkpoint = torch.load(
                model_bin_file, map_location=torch.device('cpu'))
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                # a newly trained user model is based on FRCRNDecorator
                self.load_state_dict(checkpoint['state_dict'])
            else:
                # the released model on ModelScope is based on FRCRN
                self.model.load_state_dict(checkpoint, strict=False)

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        output = self.model.forward(input)
        return {
            'spec_l1': output[0],
            'wav_l1': output[1],
            'mask_l1': output[2],
            'spec_l2': output[3],
            'wav_l2': output[4],
            'mask_l2': output[5]
    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        result_list = self.model.forward(inputs['noisy'])
        output = {
            'spec_l1': result_list[0],
            'wav_l1': result_list[1],
            'mask_l1': result_list[2],
            'spec_l2': result_list[3],
            'wav_l2': result_list[4],
            'mask_l2': result_list[5]
        }

    def to(self, *args, **kwargs):
        self.model = self.model.to(*args, **kwargs)
        return self

    def eval(self):
        self.model = self.model.train(False)
        return self
        if 'clean' in inputs:
            mix_result = self.model.loss(
                inputs['noisy'], inputs['clean'], result_list, mode='Mix')
            output.update(mix_result)
            sisnr_result = self.model.loss(
                inputs['noisy'], inputs['clean'], result_list, mode='SiSNR')
            output.update(sisnr_result)
            # the logger hook will use items under 'log_vars'
            output['log_vars'] = {k: mix_result[k].item() for k in mix_result}
            output['log_vars'].update(
                {k: sisnr_result[k].item()
                 for k in sisnr_result})
        return output
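The decorator's forward now doubles as the training step: when the batch also carries a 'clean' reference, the FRCRN losses are attached to the output, and these are exactly the keys AudioNoiseMetric above consumes. A sketch of the input contract (tensor shapes are made up):

import torch

# Inference: only the noisy waveform is required.
inputs_infer = {'noisy': torch.randn(1, 16000)}

# Training/evaluation: the clean reference triggers the loss computation,
# so the output dict additionally carries 'loss', 'amp_loss', 'phase_loss',
# 'sisnr' and a 'log_vars' entry for the logging hook.
inputs_train = {
    'noisy': torch.randn(1, 16000),
    'clean': torch.randn(1, 16000),
}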
class FRCRN(nn.Module):
@@ -111,7 +78,8 @@ class FRCRN(nn.Module):
                 win_len=400,
                 win_inc=100,
                 fft_len=512,
                 win_type='hanning'):
                 win_type='hanning',
                 **kwargs):
        r"""
        Args:
            complex: Whether to use complex networks.
@@ -237,7 +205,7 @@ class FRCRN(nn.Module):
                if count != 3:
                    loss = self.loss_1layer(noisy, est_spec, est_wav, labels,
                                            est_mask, mode)
            return loss
            return dict(sisnr=loss)

        elif mode == 'Mix':
            count = 0
@@ -252,7 +220,7 @@ class FRCRN(nn.Module):
                    amp_loss, phase_loss, SiSNR_loss = self.loss_1layer(
                        noisy, est_spec, est_wav, labels, est_mask, mode)
                    loss = amp_loss + phase_loss + SiSNR_loss
            return loss, amp_loss, phase_loss
            return dict(loss=loss, amp_loss=amp_loss, phase_loss=phase_loss)

    def loss_1layer(self, noisy, est, est_wav, labels, cmp_mask, mode='Mix'):
        r""" Compute the loss by mode

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from torch import nn


@@ -1,3 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
#
# The implementation here is modified based on
# Jongho Choi(sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft )
# and publicly available at
# https://github.com/sweetcocoa/DeepComplexUNetPyTorch

import torch
import torch.nn as nn


@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np
import torch
import torch.nn as nn

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn as nn
import torch.nn.functional as F

@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Dict


@@ -1,3 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import math
import struct
from enum import Enum

modelscope/models/audio/tts/models/__init__.py (Executable file → Normal file, -9 lines)
@@ -1,9 +0,0 @@
from .robutrans import RobuTrans
from .vocoder_models import Generator


def create_am_model(name, hparams):
    if name == 'robutrans':
        return RobuTrans(hparams)
    else:
        raise Exception('Unknown model: ' + name)
@@ -1,460 +0,0 @@
import tensorflow as tf


def encoder_prenet(inputs,
                   n_conv_layers,
                   filters,
                   kernel_size,
                   dense_units,
                   is_training,
                   mask=None,
                   scope='encoder_prenet'):
    x = inputs
    with tf.variable_scope(scope):
        for i in range(n_conv_layers):
            x = conv1d(
                x,
                filters,
                kernel_size,
                is_training,
                activation=tf.nn.relu,
                dropout=True,
                mask=mask,
                scope='conv1d_{}'.format(i))
        x = tf.layers.dense(
            x, units=dense_units, activation=None, name='dense')
    return x


def decoder_prenet(inputs,
                   prenet_units,
                   dense_units,
                   is_training,
                   scope='decoder_prenet'):
    x = inputs
    with tf.variable_scope(scope):
        for i, units in enumerate(prenet_units):
            x = tf.layers.dense(
                x,
                units=units,
                activation=tf.nn.relu,
                name='dense_{}'.format(i))
            x = tf.layers.dropout(
                x, rate=0.5, training=is_training, name='dropout_{}'.format(i))
        x = tf.layers.dense(
            x, units=dense_units, activation=None, name='dense')
    return x


def encoder(inputs,
            input_lengths,
            n_conv_layers,
            filters,
            kernel_size,
            lstm_units,
            is_training,
            embedded_inputs_speaker,
            mask=None,
            scope='encoder'):
    with tf.variable_scope(scope):
        x = conv_and_lstm(
            inputs,
            input_lengths,
            n_conv_layers,
            filters,
            kernel_size,
            lstm_units,
            is_training,
            embedded_inputs_speaker,
            mask=mask)
    return x


def prenet(inputs, prenet_units, is_training, scope='prenet'):
    x = inputs
    with tf.variable_scope(scope):
        for i, units in enumerate(prenet_units):
            x = tf.layers.dense(
                x,
                units=units,
                activation=tf.nn.relu,
                name='dense_{}'.format(i))
            x = tf.layers.dropout(
                x, rate=0.5, training=is_training, name='dropout_{}'.format(i))
    return x


def postnet_residual_ulstm(inputs,
                           n_conv_layers,
                           filters,
                           kernel_size,
                           lstm_units,
                           output_units,
                           is_training,
                           scope='postnet_residual_ulstm'):
    with tf.variable_scope(scope):
        x = conv_and_ulstm(inputs, None, n_conv_layers, filters, kernel_size,
                           lstm_units, is_training)
        x = conv1d(
            x,
            output_units,
            kernel_size,
            is_training,
            activation=None,
            dropout=False,
            scope='conv1d_{}'.format(n_conv_layers - 1))
    return x


def postnet_residual_lstm(inputs,
                          n_conv_layers,
                          filters,
                          kernel_size,
                          lstm_units,
                          output_units,
                          is_training,
                          scope='postnet_residual_lstm'):
    with tf.variable_scope(scope):
        x = conv_and_lstm(inputs, None, n_conv_layers, filters, kernel_size,
                          lstm_units, is_training)
        x = conv1d(
            x,
            output_units,
            kernel_size,
            is_training,
            activation=None,
            dropout=False,
            scope='conv1d_{}'.format(n_conv_layers - 1))
    return x


def postnet_linear_ulstm(inputs,
                         n_conv_layers,
                         filters,
                         kernel_size,
                         lstm_units,
                         output_units,
                         is_training,
                         scope='postnet_linear'):
    with tf.variable_scope(scope):
        x = conv_and_ulstm(inputs, None, n_conv_layers, filters, kernel_size,
                           lstm_units, is_training)
        x = tf.layers.dense(x, units=output_units)
    return x


def postnet_linear_lstm(inputs,
                        n_conv_layers,
                        filters,
                        kernel_size,
                        lstm_units,
                        output_units,
                        output_lengths,
                        is_training,
                        embedded_inputs_speaker2,
                        mask=None,
                        scope='postnet_linear'):
    with tf.variable_scope(scope):
        x = conv_and_lstm_dec(
            inputs,
            output_lengths,
            n_conv_layers,
            filters,
            kernel_size,
            lstm_units,
            is_training,
            embedded_inputs_speaker2,
            mask=mask)
        x = tf.layers.dense(x, units=output_units)
    return x


def postnet_linear(inputs,
                   n_conv_layers,
                   filters,
                   kernel_size,
                   lstm_units,
                   output_units,
                   output_lengths,
                   is_training,
                   embedded_inputs_speaker2,
                   mask=None,
                   scope='postnet_linear'):
    with tf.variable_scope(scope):
        x = conv_dec(
            inputs,
            output_lengths,
            n_conv_layers,
            filters,
            kernel_size,
            lstm_units,
            is_training,
            embedded_inputs_speaker2,
            mask=mask)
    return x


def conv_and_lstm(inputs,
                  sequence_lengths,
                  n_conv_layers,
                  filters,
                  kernel_size,
                  lstm_units,
                  is_training,
                  embedded_inputs_speaker,
                  mask=None,
                  scope='conv_and_lstm'):
    from tensorflow.contrib.rnn import LSTMBlockCell
    x = inputs
    with tf.variable_scope(scope):
        for i in range(n_conv_layers):
            x = conv1d(
                x,
                filters,
                kernel_size,
                is_training,
                activation=tf.nn.relu,
                dropout=True,
                mask=mask,
                scope='conv1d_{}'.format(i))

        x = tf.concat([x, embedded_inputs_speaker], axis=2)

        outputs, states = tf.nn.bidirectional_dynamic_rnn(
            LSTMBlockCell(lstm_units),
            LSTMBlockCell(lstm_units),
            x,
            sequence_length=sequence_lengths,
            dtype=tf.float32)
        x = tf.concat(outputs, axis=-1)

    return x


def conv_and_lstm_dec(inputs,
                      sequence_lengths,
                      n_conv_layers,
                      filters,
                      kernel_size,
                      lstm_units,
                      is_training,
                      embedded_inputs_speaker2,
                      mask=None,
                      scope='conv_and_lstm'):
    x = inputs
    from tensorflow.contrib.rnn import LSTMBlockCell
    with tf.variable_scope(scope):
        for i in range(n_conv_layers):
            x = conv1d(
                x,
                filters,
                kernel_size,
                is_training,
                activation=tf.nn.relu,
                dropout=True,
                mask=mask,
                scope='conv1d_{}'.format(i))

        x = tf.concat([x, embedded_inputs_speaker2], axis=2)

        outputs, states = tf.nn.bidirectional_dynamic_rnn(
            LSTMBlockCell(lstm_units),
            LSTMBlockCell(lstm_units),
            x,
            sequence_length=sequence_lengths,
            dtype=tf.float32)
        x = tf.concat(outputs, axis=-1)
    return x


def conv_dec(inputs,
             sequence_lengths,
             n_conv_layers,
             filters,
             kernel_size,
             lstm_units,
             is_training,
             embedded_inputs_speaker2,
             mask=None,
             scope='conv_and_lstm'):
    x = inputs
    with tf.variable_scope(scope):
        for i in range(n_conv_layers):
            x = conv1d(
                x,
                filters,
                kernel_size,
                is_training,
                activation=tf.nn.relu,
                dropout=True,
                mask=mask,
                scope='conv1d_{}'.format(i))
        x = tf.concat([x, embedded_inputs_speaker2], axis=2)
    return x


def conv_and_ulstm(inputs,
                   sequence_lengths,
                   n_conv_layers,
                   filters,
                   kernel_size,
                   lstm_units,
                   is_training,
                   scope='conv_and_ulstm'):
    x = inputs
    with tf.variable_scope(scope):
        for i in range(n_conv_layers):
            x = conv1d(
                x,
                filters,
                kernel_size,
                is_training,
                activation=tf.nn.relu,
                dropout=True,
                scope='conv1d_{}'.format(i))

        outputs, states = tf.nn.dynamic_rnn(
            LSTMBlockCell(lstm_units),
            x,
            sequence_length=sequence_lengths,
            dtype=tf.float32)

    return outputs


def conv1d(inputs,
           filters,
           kernel_size,
           is_training,
           activation=None,
           dropout=False,
           mask=None,
           scope='conv1d'):
    with tf.variable_scope(scope):
        if mask is not None:
            inputs = inputs * tf.expand_dims(mask, -1)
        x = tf.layers.conv1d(
            inputs, filters=filters, kernel_size=kernel_size, padding='same')
        if mask is not None:
            x = x * tf.expand_dims(mask, -1)

        x = tf.layers.batch_normalization(x, training=is_training)
        if activation is not None:
            x = activation(x)
        if dropout:
            x = tf.layers.dropout(x, rate=0.5, training=is_training)
        return x


def conv1d_dp(inputs,
              filters,
              kernel_size,
              is_training,
              activation=None,
              dropout=False,
              dropoutrate=0.5,
              mask=None,
              scope='conv1d'):
    with tf.variable_scope(scope):
        if mask is not None:
            inputs = inputs * tf.expand_dims(mask, -1)
        x = tf.layers.conv1d(
            inputs, filters=filters, kernel_size=kernel_size, padding='same')
        if mask is not None:
            x = x * tf.expand_dims(mask, -1)

        x = tf.contrib.layers.layer_norm(x)
        if activation is not None:
            x = activation(x)
        if dropout:
            x = tf.layers.dropout(x, rate=dropoutrate, training=is_training)
        return x


def duration_predictor(inputs,
                       n_conv_layers,
                       filters,
                       kernel_size,
                       lstm_units,
                       input_lengths,
                       is_training,
                       embedded_inputs_speaker,
                       mask=None,
                       scope='duration_predictor'):
    with tf.variable_scope(scope):
        x = inputs
        for i in range(n_conv_layers):
            x = conv1d_dp(
                x,
                filters,
                kernel_size,
                is_training,
                activation=tf.nn.relu,
                dropout=True,
                dropoutrate=0.1,
                mask=mask,
                scope='conv1d_{}'.format(i))

        x = tf.concat([x, embedded_inputs_speaker], axis=2)

        outputs, states = tf.nn.bidirectional_dynamic_rnn(
            LSTMBlockCell(lstm_units),
            LSTMBlockCell(lstm_units),
            x,
            sequence_length=input_lengths,
            dtype=tf.float32)
        x = tf.concat(outputs, axis=-1)

        x = tf.layers.dense(x, units=1)
        x = tf.nn.relu(x)
    return x


def duration_predictor2(inputs,
                        n_conv_layers,
                        filters,
                        kernel_size,
                        input_lengths,
                        is_training,
                        mask=None,
                        scope='duration_predictor'):
    with tf.variable_scope(scope):
        x = inputs
        for i in range(n_conv_layers):
            x = conv1d_dp(
                x,
                filters,
                kernel_size,
                is_training,
                activation=tf.nn.relu,
                dropout=True,
                dropoutrate=0.1,
                mask=mask,
                scope='conv1d_{}'.format(i))

        x = tf.layers.dense(x, units=1)
        x = tf.nn.relu(x)
    return x


def conv_prenet(inputs,
                n_conv_layers,
                filters,
                kernel_size,
                is_training,
                mask=None,
                scope='conv_prenet'):
    x = inputs
    with tf.variable_scope(scope):
        for i in range(n_conv_layers):
            x = conv1d(
                x,
                filters,
                kernel_size,
                is_training,
                activation=tf.nn.relu,
                dropout=True,
                mask=mask,
                scope='conv1d_{}'.format(i))

    return x
@@ -1,82 +0,0 @@
"""Functions for compatibility with different TensorFlow versions."""

import tensorflow as tf


def is_tf2():
    """Returns ``True`` if running TensorFlow 2.0."""
    return tf.__version__.startswith('2')


def tf_supports(symbol):
    """Returns ``True`` if TensorFlow defines :obj:`symbol`."""
    return _string_to_tf_symbol(symbol) is not None


def tf_any(*symbols):
    """Returns the first supported symbol."""
    for symbol in symbols:
        module = _string_to_tf_symbol(symbol)
        if module is not None:
            return module
    return None


def tf_compat(v2=None, v1=None):  # pylint: disable=invalid-name
    """Returns the compatible symbol based on the current TensorFlow version.

    Args:
      v2: The candidate v2 symbol name.
      v1: The candidate v1 symbol name.

    Returns:
      A TensorFlow symbol.

    Raises:
      ValueError: if no symbol can be found.
    """
    candidates = []
    if v2 is not None:
        candidates.append(v2)
    if v1 is not None:
        candidates.append(v1)
        candidates.append('compat.v1.%s' % v1)
    symbol = tf_any(*candidates)
    if symbol is None:
        raise ValueError('Failure to resolve the TensorFlow symbol')
    return symbol


def name_from_variable_scope(name=''):
    """Creates a name prefixed by the current variable scope."""
    var_scope = tf_compat(v1='get_variable_scope')().name
    compat_name = ''
    if name:
        compat_name = '%s/' % name
    if var_scope:
        compat_name = '%s/%s' % (var_scope, compat_name)
    return compat_name


def reuse():
    """Returns ``True`` if the current variable scope is marked for reuse."""
    return tf_compat(v1='get_variable_scope')().reuse


def _string_to_tf_symbol(symbol):
    modules = symbol.split('.')
    namespace = tf
    for module in modules:
        namespace = getattr(namespace, module, None)
        if namespace is None:
            return None
    return namespace


# pylint: disable=invalid-name
gfile_copy = tf_compat(v2='io.gfile.copy', v1='gfile.Copy')
gfile_exists = tf_compat(v2='io.gfile.exists', v1='gfile.Exists')
gfile_open = tf_compat(v2='io.gfile.GFile', v1='gfile.GFile')
is_tensor = tf_compat(v2='is_tensor', v1='contrib.framework.is_tensor')
logging = tf_compat(v1='logging')
nest = tf_compat(v2='nest', v1='contrib.framework.nest')

modelscope/models/audio/tts/models/datasets/kantts_data4fs.py (new file, +238 lines)
@@ -0,0 +1,238 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os

import json
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from modelscope.utils.logger import get_logger
from .units import KanTtsLinguisticUnit

logger = get_logger()


class KanTtsText2MelDataset(Dataset):

    def __init__(self, metadata_filename, config_filename, cache=False):
        super(KanTtsText2MelDataset, self).__init__()

        self.cache = cache

        with open(config_filename) as f:
            self._config = json.loads(f.read())

        # Load metadata:
        self._datadir = os.path.dirname(metadata_filename)
        with open(metadata_filename, encoding='utf-8') as f:
            self._metadata = [line.strip().split('|') for line in f]
            self._length_lst = [int(x[2]) for x in self._metadata]
            hours = sum(
                self._length_lst) * self._config['audio']['frame_shift_ms'] / (
                    3600 * 1000)

        logger.info('Loaded metadata for %d examples (%.2f hours)' %
                    (len(self._metadata), hours))
        logger.info('Minimum length: %d, Maximum length: %d' %
                    (min(self._length_lst), max(self._length_lst)))

        self.ling_unit = KanTtsLinguisticUnit(config_filename)
        self.pad_executor = KanTtsText2MelPad()

        self.r = self._config['am']['outputs_per_step']
        self.num_mels = self._config['am']['num_mels']

        if 'adv' in self._config:
            self.feat_window = self._config['adv']['random_window']
        else:
            self.feat_window = None
        logger.info(self.feat_window)

        self.data_cache = [
            self.cache_load(i) for i in tqdm(range(self.__len__()))
        ] if self.cache else []

    def get_frames_lst(self):
        return self._length_lst

    def __getitem__(self, index):
        if self.cache:
            sample = self.data_cache[index]
            return sample

        return self.cache_load(index)

    def cache_load(self, index):
        sample = {}

        meta = self._metadata[index]

        sample['utt_id'] = meta[0]

        sample['mel_target'] = np.load(os.path.join(
            self._datadir, meta[1]))[:, :self.num_mels]
        sample['output_length'] = len(sample['mel_target'])

        lfeat_symbol = meta[3]
        sample['ling'] = self.ling_unit.encode_symbol_sequence(lfeat_symbol)

        sample['duration'] = np.load(os.path.join(self._datadir, meta[4]))

        sample['pitch_contour'] = np.load(os.path.join(self._datadir, meta[5]))

        sample['energy_contour'] = np.load(
            os.path.join(self._datadir, meta[6]))

        return sample

    def __len__(self):
        return len(self._metadata)

    def collate_fn(self, batch):
        data_dict = {}

        max_input_length = max((len(x['ling'][0]) for x in batch))

        # pure linguistic info: sy|tone|syllable_flag|word_segment

        # sy
        lfeat_type = self.ling_unit._lfeat_type_list[0]
        inputs_sy = self.pad_executor._prepare_scalar_inputs(
            [x['ling'][0] for x in batch], max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type]).long()
        # tone
        lfeat_type = self.ling_unit._lfeat_type_list[1]
        inputs_tone = self.pad_executor._prepare_scalar_inputs(
            [x['ling'][1] for x in batch], max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type]).long()

        # syllable_flag
        lfeat_type = self.ling_unit._lfeat_type_list[2]
        inputs_syllable_flag = self.pad_executor._prepare_scalar_inputs(
            [x['ling'][2] for x in batch], max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type]).long()

        # word_segment
        lfeat_type = self.ling_unit._lfeat_type_list[3]
        inputs_ws = self.pad_executor._prepare_scalar_inputs(
            [x['ling'][3] for x in batch], max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type]).long()

        # emotion category
        lfeat_type = self.ling_unit._lfeat_type_list[4]
        data_dict['input_emotions'] = self.pad_executor._prepare_scalar_inputs(
            [x['ling'][4] for x in batch], max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type]).long()

        # speaker category
        lfeat_type = self.ling_unit._lfeat_type_list[5]
        data_dict['input_speakers'] = self.pad_executor._prepare_scalar_inputs(
            [x['ling'][5] for x in batch], max_input_length,
            self.ling_unit._sub_unit_pad[lfeat_type]).long()

        data_dict['input_lings'] = torch.stack(
            [inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2)

        data_dict['valid_input_lengths'] = torch.as_tensor(
            [len(x['ling'][0]) - 1 for x in batch], dtype=torch.long
        )  # The symbol sequence ends with one '~', so length - 1 is used.

        data_dict['valid_output_lengths'] = torch.as_tensor(
            [x['output_length'] for x in batch], dtype=torch.long)
        max_output_length = torch.max(data_dict['valid_output_lengths']).item()
        max_output_round_length = self.pad_executor._round_up(
            max_output_length, self.r)

        if self.feat_window is not None:
            active_feat_len = np.minimum(max_output_round_length,
                                         self.feat_window)
            if active_feat_len < self.feat_window:
                max_output_round_length = self.pad_executor._round_up(
                    self.feat_window, self.r)
                active_feat_len = self.feat_window

            max_offsets = [x['output_length'] - active_feat_len for x in batch]
            feat_offsets = [
                np.random.randint(0, np.maximum(1, offset))
                for offset in max_offsets
            ]
            feat_offsets = torch.from_numpy(
                np.asarray(feat_offsets, dtype=np.int32)).long()
            data_dict['feat_offsets'] = feat_offsets

        data_dict['mel_targets'] = self.pad_executor._prepare_targets(
            [x['mel_target'] for x in batch], max_output_round_length, 0.0)
        data_dict['durations'] = self.pad_executor._prepare_durations(
            [x['duration'] for x in batch], max_input_length,
            max_output_round_length)

        data_dict['pitch_contours'] = self.pad_executor._prepare_scalar_inputs(
            [x['pitch_contour'] for x in batch], max_input_length,
            0.0).float()
        data_dict[
            'energy_contours'] = self.pad_executor._prepare_scalar_inputs(
                [x['energy_contour'] for x in batch], max_input_length,
                0.0).float()

        data_dict['utt_ids'] = [x['utt_id'] for x in batch]

        return data_dict


class KanTtsText2MelPad(object):

    def __init__(self):
        super(KanTtsText2MelPad, self).__init__()

    def _pad1D(self, x, length, pad):
        return np.pad(
            x, (0, length - x.shape[0]), mode='constant', constant_values=pad)

    def _pad2D(self, x, length, pad):
        return np.pad(
            x, [(0, length - x.shape[0]), (0, 0)],
            mode='constant',
            constant_values=pad)

    def _pad_durations(self, duration, max_in_len, max_out_len):
        framenum = np.sum(duration)
        symbolnum = duration.shape[0]
        if framenum < max_out_len:
            padframenum = max_out_len - framenum
            duration = np.insert(
                duration, symbolnum, values=padframenum, axis=0)
            duration = np.insert(
                duration,
                symbolnum + 1,
                values=[0] * (max_in_len - symbolnum - 1),
                axis=0)
        else:
            if symbolnum < max_in_len:
                duration = np.insert(
                    duration,
                    symbolnum,
                    values=[0] * (max_in_len - symbolnum),
                    axis=0)
        return duration

    def _round_up(self, x, multiple):
        remainder = x % multiple
        return x if remainder == 0 else x + multiple - remainder

    def _prepare_scalar_inputs(self, inputs, max_len, pad):
        return torch.from_numpy(
            np.stack([self._pad1D(x, max_len, pad) for x in inputs]))

    def _prepare_targets(self, targets, max_len, pad):
        return torch.from_numpy(
            np.stack([self._pad2D(t, max_len, pad) for t in targets])).float()

    def _prepare_durations(self, durations, max_in_len, max_out_len):
        return torch.from_numpy(
            np.stack([
                self._pad_durations(t, max_in_len, max_out_len)
                for t in durations
            ])).long()
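The duration padding keeps each row consistent with the padded mel length: any missing frames are attributed to one extra pad symbol and the remaining symbol slots get zero duration. A small worked sketch:

import numpy as np

duration = np.array([3, 2])      # two symbols, 5 frames in total
max_in_len, max_out_len = 4, 8   # padded symbol and frame lengths

padded = np.insert(duration, 2, values=8 - 5, axis=0)      # [3, 2, 3]
padded = np.insert(padded, 3, values=[0] * (4 - 2 - 1), axis=0)  # [3, 2, 3, 0]
assert padded.sum() == max_out_len and padded.shape[0] == max_in_len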
modelscope/models/audio/tts/models/datasets/samplers.py (new file, +131 lines)
@@ -0,0 +1,131 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import math
import random

import torch
from torch import distributed as dist
from torch.utils.data import Sampler


class LenSortGroupPoolSampler(Sampler):

    def __init__(self, data_source, length_lst, group_size):
        super(LenSortGroupPoolSampler, self).__init__(data_source)

        self.data_source = data_source
        self.length_lst = length_lst
        self.group_size = group_size

        self.num = len(self.length_lst)
        self.buckets = self.num // group_size

    def __iter__(self):

        def getkey(item):
            return item[1]

        random_lst = torch.randperm(self.num).tolist()
        random_len_lst = [(i, self.length_lst[i]) for i in random_lst]

        # Bucket examples based on similar output sequence length for efficiency:
        groups = [
            random_len_lst[i:i + self.group_size]
            for i in range(0, self.num, self.group_size)
        ]
        if (self.num % self.group_size):
            groups.append(random_len_lst[self.buckets * self.group_size:-1])

        indices = []

        for group in groups:
            group.sort(key=getkey, reverse=True)
            for item in group:
                indices.append(item[0])

        return iter(indices)

    def __len__(self):
        return len(self.data_source)


class DistributedLenSortGroupPoolSampler(Sampler):

    def __init__(self,
                 dataset,
                 length_lst,
                 group_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=True):
        super(DistributedLenSortGroupPoolSampler, self).__init__(dataset)

        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(
                    'modelscope error: Requires distributed package to be available'
                )
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(
                    'modelscope error: Requires distributed package to be available'
                )
            rank = dist.get_rank()
        self.dataset = dataset
        self.length_lst = length_lst
        self.group_size = group_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(
            math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.buckets = self.num_samples // group_size
        self.shuffle = shuffle

    def __iter__(self):

        def getkey(item):
            return item[1]

        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        random_len_lst = [(i, self.length_lst[i]) for i in indices]

        # Bucket examples based on similar output sequence length for efficiency:
        groups = [
            random_len_lst[i:i + self.group_size]
            for i in range(0, self.num_samples, self.group_size)
        ]
        if (self.num_samples % self.group_size):
            groups.append(random_len_lst[self.buckets * self.group_size:-1])

        new_indices = []

        for group in groups:
            group.sort(key=getkey, reverse=True)
            for item in group:
                new_indices.append(item[0])

        return iter(new_indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
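A hedged sketch of wiring the length-bucketing sampler into a DataLoader; `dataset` stands in for a KanTtsText2MelDataset instance:

from torch.utils.data import DataLoader

lengths = dataset.get_frames_lst()  # per-example frame counts
sampler = LenSortGroupPoolSampler(dataset, lengths, group_size=32)
loader = DataLoader(
    dataset,
    batch_size=16,
    sampler=sampler,  # a custom sampler replaces shuffle=True
    collate_fn=dataset.collate_fn)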
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .ling_unit import * # noqa F403
|
||||
@@ -0,0 +1,88 @@
# from https://github.com/keithito/tacotron
# Cleaners are transformations that run over the input text at both training and eval time.
#
# Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
# hyperparameter. Some cleaners are English-specific. You'll typically want to use:
#   1. "english_cleaners" for English text
#   2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
#      the Unidecode library (https://pypi.python.org/pypi/Unidecode)
#   3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
#      the symbols in symbols.py to match your data).

import re

from unidecode import unidecode

from .numbers import normalize_numbers

# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
    (re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
    for x in [('mrs', 'misess'),
              ('mr', 'mister'),
              ('dr', 'doctor'),
              ('st', 'saint'),
              ('co', 'company'),
              ('jr', 'junior'),
              ('maj', 'major'),
              ('gen', 'general'),
              ('drs', 'doctors'),
              ('rev', 'reverend'),
              ('lt', 'lieutenant'),
              ('hon', 'honorable'),
              ('sgt', 'sergeant'),
              ('capt', 'captain'),
              ('esq', 'esquire'),
              ('ltd', 'limited'),
              ('col', 'colonel'),
              ('ft', 'fort'), ]]  # yapf:disable


def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def expand_numbers(text):
    return normalize_numbers(text)


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, ' ', text)


def convert_to_ascii(text):
    return unidecode(text)


def basic_cleaners(text):
    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    '''Pipeline for non-English text that transliterates to ASCII.'''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners(text):
    '''Pipeline for English text, including number and abbreviation expansion.'''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text
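For orientation, a hypothetical usage sketch (not part of this diff) of the cleaner pipeline above. The import path is inferred from the file headers in this diff, and the expected output assumes the tacotron-style numbers module referenced by expand_numbers:

# Assumed import path; may differ depending on how the package is installed.
from modelscope.models.audio.tts.models.datasets.units import cleaners

text = cleaners.english_cleaners('Dr.  Müller, 2nd floor')
# roughly: 'doctor muller, second floor' (ASCII transliteration, number
# and abbreviation expansion, whitespace collapsed)
print(text)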
395 modelscope/models/audio/tts/models/datasets/units/ling_unit.py Normal file
@@ -0,0 +1,395 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import abc
import codecs
import os
import re
import shutil

import json
import numpy as np

from . import cleaners as cleaners

# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')


def _clean_text(text, cleaner_names):
    for name in cleaner_names:
        cleaner = getattr(cleaners, name)
        if not cleaner:
            raise Exception(
                'modelscope error: configuration cleaner unknown: %s' % name)
        text = cleaner(text)
    return text


class LinguisticBaseUnit(abc.ABC):

    def set_config_params(self, config_params):
        self.config_params = config_params

    def save(self, config, config_name, path):
        t_path = os.path.join(path, config_name)
        if config != t_path:
            os.makedirs(path, exist_ok=True)
            shutil.copyfile(config, os.path.join(path, config_name))


class KanTtsLinguisticUnit(LinguisticBaseUnit):

    def __init__(self, config, path, has_mask=True):
        super(KanTtsLinguisticUnit, self).__init__()

        # special symbols
        self._pad = '_'
        self._eos = '~'
        self._mask = '@[MASK]'
        self._has_mask = has_mask
        self._unit_config = config
        self._path = path

        self._cleaner_names = [
            x.strip() for x in self._unit_config['cleaners'].split(',')
        ]
        self._lfeat_type_list = self._unit_config['lfeat_type_list'].strip(
        ).split(',')

        self.build()

    def get_unit_size(self):
        ling_unit_size = {}
        ling_unit_size['sy'] = len(self.sy)
        ling_unit_size['tone'] = len(self.tone)
        ling_unit_size['syllable_flag'] = len(self.syllable_flag)
        ling_unit_size['word_segment'] = len(self.word_segment)

        if 'emo_category' in self._lfeat_type_list:
            ling_unit_size['emotion'] = len(self.emo_category)
        if 'speaker_category' in self._lfeat_type_list:
            ling_unit_size['speaker'] = len(self.speaker)

        return ling_unit_size

    def build(self):

        self._sub_unit_dim = {}
        self._sub_unit_pad = {}
        # sy sub-unit
        _characters = ''

        _ch_symbols = []

        sy_path = os.path.join(self._path, self._unit_config['sy'])
        f = codecs.open(sy_path, 'r')
        for line in f:
            line = line.strip('\r\n')
            _ch_symbols.append(line)

        _arpabet = ['@' + s for s in _ch_symbols]

        # Export all symbols:
        self.sy = list(_characters) + _arpabet + [self._pad, self._eos]
        if self._has_mask:
            self.sy.append(self._mask)
        self._sy_to_id = {s: i for i, s in enumerate(self.sy)}
        self._id_to_sy = {i: s for i, s in enumerate(self.sy)}
        self._sub_unit_dim['sy'] = len(self.sy)
        self._sub_unit_pad['sy'] = self._sy_to_id['_']

        # tone sub-unit
        _characters = ''

        _ch_tones = []

        tone_path = os.path.join(self._path, self._unit_config['tone'])
        f = codecs.open(tone_path, 'r')
        for line in f:
            line = line.strip('\r\n')
            _ch_tones.append(line)

        # Export all tones:
        self.tone = list(_characters) + _ch_tones + [self._pad, self._eos]
        if self._has_mask:
            self.tone.append(self._mask)
        self._tone_to_id = {s: i for i, s in enumerate(self.tone)}
        self._id_to_tone = {i: s for i, s in enumerate(self.tone)}
        self._sub_unit_dim['tone'] = len(self.tone)
        self._sub_unit_pad['tone'] = self._tone_to_id['_']

        # syllable flag sub-unit
        _characters = ''

        _ch_syllable_flags = []

        sy_flag_path = os.path.join(self._path,
                                    self._unit_config['syllable_flag'])
        f = codecs.open(sy_flag_path, 'r')
        for line in f:
            line = line.strip('\r\n')
            _ch_syllable_flags.append(line)

        # Export all syllable flags:
        self.syllable_flag = list(_characters) + _ch_syllable_flags + [
            self._pad, self._eos
        ]
        if self._has_mask:
            self.syllable_flag.append(self._mask)
        self._syllable_flag_to_id = {
            s: i
            for i, s in enumerate(self.syllable_flag)
        }
        self._id_to_syllable_flag = {
            i: s
            for i, s in enumerate(self.syllable_flag)
        }
        self._sub_unit_dim['syllable_flag'] = len(self.syllable_flag)
        self._sub_unit_pad['syllable_flag'] = self._syllable_flag_to_id['_']

        # word segment sub-unit
        _characters = ''

        _ch_word_segments = []

        ws_path = os.path.join(self._path, self._unit_config['word_segment'])
        f = codecs.open(ws_path, 'r')
        for line in f:
            line = line.strip('\r\n')
            _ch_word_segments.append(line)

        # Export all word segments:
        self.word_segment = list(_characters) + _ch_word_segments + [
            self._pad, self._eos
        ]
        if self._has_mask:
            self.word_segment.append(self._mask)
        self._word_segment_to_id = {
            s: i
            for i, s in enumerate(self.word_segment)
        }
        self._id_to_word_segment = {
            i: s
            for i, s in enumerate(self.word_segment)
        }
        self._sub_unit_dim['word_segment'] = len(self.word_segment)
        self._sub_unit_pad['word_segment'] = self._word_segment_to_id['_']

        if 'emo_category' in self._lfeat_type_list:
            # emotion category sub-unit
            _characters = ''

            _ch_emo_types = []

            emo_path = os.path.join(self._path,
                                    self._unit_config['emo_category'])
            f = codecs.open(emo_path, 'r')
            for line in f:
                line = line.strip('\r\n')
                _ch_emo_types.append(line)

            # Export all emotion categories:
            self.emo_category = list(_characters) + _ch_emo_types + [
                self._pad, self._eos
            ]
            if self._has_mask:
                self.emo_category.append(self._mask)
            self._emo_category_to_id = {
                s: i
                for i, s in enumerate(self.emo_category)
            }
            self._id_to_emo_category = {
                i: s
                for i, s in enumerate(self.emo_category)
            }
            self._sub_unit_dim['emo_category'] = len(self.emo_category)
            self._sub_unit_pad['emo_category'] = self._emo_category_to_id['_']

        if 'speaker_category' in self._lfeat_type_list:
            # speaker category sub-unit
            _characters = ''

            _ch_speakers = []

            speaker_path = os.path.join(self._path,
                                        self._unit_config['speaker_category'])
            f = codecs.open(speaker_path, 'r')
            for line in f:
                line = line.strip('\r\n')
                _ch_speakers.append(line)

            # Export all speakers:
            self.speaker = list(_characters) + _ch_speakers + [
                self._pad, self._eos
            ]
            if self._has_mask:
                self.speaker.append(self._mask)
            self._speaker_to_id = {s: i for i, s in enumerate(self.speaker)}
            self._id_to_speaker = {i: s for i, s in enumerate(self.speaker)}
            self._sub_unit_dim['speaker_category'] = len(self._speaker_to_id)
            self._sub_unit_pad['speaker_category'] = self._speaker_to_id['_']

    def encode_symbol_sequence(self, lfeat_symbol):
        lfeat_symbol = lfeat_symbol.strip().split(' ')

        lfeat_symbol_separate = [''] * int(len(self._lfeat_type_list))
        for this_lfeat_symbol in lfeat_symbol:
            this_lfeat_symbol = this_lfeat_symbol.strip('{').strip('}').split(
                '$')
            index = 0
            while index < len(lfeat_symbol_separate):
                lfeat_symbol_separate[index] = lfeat_symbol_separate[
                    index] + this_lfeat_symbol[index] + ' '
                index = index + 1

        input_and_label_data = []
        index = 0
        while index < len(self._lfeat_type_list):
            sequence = self.encode_sub_unit(
                lfeat_symbol_separate[index].strip(),
                self._lfeat_type_list[index])
            sequence_array = np.asarray(sequence, dtype=np.int32)
            input_and_label_data.append(sequence_array)
            index = index + 1

        return input_and_label_data

    def decode_symbol_sequence(self, sequence):
        result = []
        for i, lfeat_type in enumerate(self._lfeat_type_list):
            s = ''
            sequence_item = sequence[i].tolist()
            if lfeat_type == 'sy':
                s = self.decode_sy(sequence_item)
            elif lfeat_type == 'tone':
                s = self.decode_tone(sequence_item)
            elif lfeat_type == 'syllable_flag':
                s = self.decode_syllable_flag(sequence_item)
            elif lfeat_type == 'word_segment':
                s = self.decode_word_segment(sequence_item)
            elif lfeat_type == 'emo_category':
                s = self.decode_emo_category(sequence_item)
            elif lfeat_type == 'speaker_category':
                s = self.decode_speaker_category(sequence_item)
            else:
                raise Exception(
                    'modelscope error: configuration lfeat type(%s) unknown.'
                    % lfeat_type)
            result.append('%s:%s' % (lfeat_type, s))

        return result

    def encode_sub_unit(self, this_lfeat_symbol, lfeat_type):
        sequence = []
        if lfeat_type == 'sy':
            this_lfeat_symbol = this_lfeat_symbol.strip().split(' ')
            this_lfeat_symbol_format = ''
            index = 0
            while index < len(this_lfeat_symbol):
                this_lfeat_symbol_format = this_lfeat_symbol_format + '{' + this_lfeat_symbol[
                    index] + '}' + ' '
                index = index + 1
            sequence = self.encode_text(this_lfeat_symbol_format,
                                        self._cleaner_names)
        elif lfeat_type == 'tone':
            sequence = self.encode_tone(this_lfeat_symbol)
        elif lfeat_type == 'syllable_flag':
            sequence = self.encode_syllable_flag(this_lfeat_symbol)
        elif lfeat_type == 'word_segment':
            sequence = self.encode_word_segment(this_lfeat_symbol)
        elif lfeat_type == 'emo_category':
            sequence = self.encode_emo_category(this_lfeat_symbol)
        elif lfeat_type == 'speaker_category':
            sequence = self.encode_speaker_category(this_lfeat_symbol)
        else:
            raise Exception(
                'modelscope error: configuration lfeat type(%s) unknown.'
                % lfeat_type)

        return sequence

    def encode_text(self, text, cleaner_names):
        sequence = []

        # Check for curly braces and treat their contents as ARPAbet:
        while len(text):
            m = _curly_re.match(text)
            if not m:
                sequence += self.encode_sy(_clean_text(text, cleaner_names))
                break
            sequence += self.encode_sy(_clean_text(m.group(1), cleaner_names))
            sequence += self.encode_arpanet(m.group(2))
            text = m.group(3)

        # Append EOS token
        sequence.append(self._sy_to_id['~'])
        return sequence

    def encode_sy(self, sy):
        return [self._sy_to_id[s] for s in sy if self.should_keep_sy(s)]

    def decode_sy(self, id):
        s = self._id_to_sy[id]
        if len(s) > 1 and s[0] == '@':
            s = s[1:]
        return s

    def should_keep_sy(self, s):
        return s in self._sy_to_id and s != '_' and s != '~'

    def encode_arpanet(self, text):
        return self.encode_sy(['@' + s for s in text.split()])

    def encode_tone(self, tone):
        tones = tone.strip().split(' ')
        sequence = []
        for this_tone in tones:
            sequence.append(self._tone_to_id[this_tone])
        sequence.append(self._tone_to_id['~'])
        return sequence

    def decode_tone(self, id):
        return self._id_to_tone[id]

    def encode_syllable_flag(self, syllable_flag):
        syllable_flags = syllable_flag.strip().split(' ')
        sequence = []
        for this_syllable_flag in syllable_flags:
            sequence.append(self._syllable_flag_to_id[this_syllable_flag])
        sequence.append(self._syllable_flag_to_id['~'])
        return sequence

    def decode_syllable_flag(self, id):
        return self._id_to_syllable_flag[id]

    def encode_word_segment(self, word_segment):
        word_segments = word_segment.strip().split(' ')
        sequence = []
        for this_word_segment in word_segments:
            sequence.append(self._word_segment_to_id[this_word_segment])
        sequence.append(self._word_segment_to_id['~'])
        return sequence

    def decode_word_segment(self, id):
        return self._id_to_word_segment[id]

    def encode_emo_category(self, emo_type):
        emo_categories = emo_type.strip().split(' ')
        sequence = []
        for this_category in emo_categories:
            sequence.append(self._emo_category_to_id[this_category])
        sequence.append(self._emo_category_to_id['~'])
        return sequence

    def decode_emo_category(self, id):
        return self._id_to_emo_category[id]

    def encode_speaker_category(self, speaker):
        speakers = speaker.strip().split(' ')
        sequence = []
        for this_speaker in speakers:
            sequence.append(self._speaker_to_id[this_speaker])
        sequence.append(self._speaker_to_id['~'])
        return sequence

    def decode_speaker_category(self, id):
        return self._id_to_speaker[id]
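For orientation, a small sketch (not part of this diff) of the curly-brace convention that encode_text above implements: plain text is run through the cleaners, while brace-wrapped chunks are treated as phone symbols via encode_arpanet. The regex behaviour can be checked standalone with the stdlib:

import re

# Same pattern as _curly_re above: prefix, brace-wrapped phones, remainder.
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')

m = _curly_re.match('hello {HH AH0} world')
print(m.group(1))  # 'hello '   -> cleaned and encoded as plain symbols
print(m.group(2))  # 'HH AH0'   -> prefixed with '@' and encoded as phones
print(m.group(3))  # ' world'   -> processed on the next loop iteration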
3 modelscope/models/audio/tts/text/numbers.py → modelscope/models/audio/tts/models/datasets/units/numbers.py Executable file → Normal file
@@ -1,3 +1,6 @@
# The implementation is adopted from tacotron,
# made publicly available under the MIT License at https://github.com/keithito/tacotron

import re

import inflect
@@ -1,273 +0,0 @@
import tensorflow as tf


def build_sequence_mask(sequence_length,
                        maximum_length=None,
                        dtype=tf.float32):
    """Builds the dot product mask.

    Args:
      sequence_length: The sequence length.
      maximum_length: Optional size of the returned time dimension. Otherwise
        it is the maximum of :obj:`sequence_length`.
      dtype: The type of the mask tensor.

    Returns:
      A broadcastable ``tf.Tensor`` of type :obj:`dtype` and shape
      ``[batch_size, max_length]``.
    """
    mask = tf.sequence_mask(
        sequence_length, maxlen=maximum_length, dtype=dtype)

    return mask


def norm(inputs):
    """Layer normalizes :obj:`inputs`."""
    return tf.contrib.layers.layer_norm(inputs, begin_norm_axis=-1)


def pad_in_time(x, padding_shape):
    """Helper function to pad a tensor in the time dimension and retain the static depth dimension.

    Args:
      x: [Batch, Time, Frequency]
      padding_shape: [left, right] padding sizes (constant value 0) in the time dimension

    Returns:
      The padded x.
    """

    depth = x.get_shape().as_list()[-1]
    x = tf.pad(x, [[0, 0], padding_shape, [0, 0]])
    x.set_shape((None, None, depth))

    return x


def pad_in_time_right(x, padding_length):
    """Helper function to pad a tensor in the time dimension and retain the static depth dimension.

    Args:
      x: [Batch, Time, Frequency]
      padding_length: padding size (constant value 0) after the time dimension

    Returns:
      The padded x.
    """
    depth = x.get_shape().as_list()[-1]
    x = tf.pad(x, [[0, 0], [0, padding_length], [0, 0]])
    x.set_shape((None, None, depth))

    return x


def feed_forward(x, ffn_dim, memory_units, mode, dropout=0.0):
    """Implements the Transformer's "Feed Forward" layer.

    .. math::

        ffn(x) = max(0, x*W_1 + b_1)*W_2

    Args:
      x: The input.
      ffn_dim: The number of units of the nonlinear transformation.
      memory_units: The number of units of the linear transformation.
      mode: A ``tf.estimator.ModeKeys`` mode.
      dropout: The probability to drop units from the inner transformation.

    Returns:
      The transformed input.
    """
    inner = tf.layers.conv1d(x, ffn_dim, 1, activation=tf.nn.relu)
    inner = tf.layers.dropout(
        inner, rate=dropout, training=mode == tf.estimator.ModeKeys.TRAIN)
    outer = tf.layers.conv1d(inner, memory_units, 1, use_bias=False)

    return outer


def drop_and_add(inputs, outputs, mode, dropout=0.0):
    """Drops units in the outputs and adds the previous values.

    Args:
      inputs: The input of the previous layer.
      outputs: The output of the previous layer.
      mode: A ``tf.estimator.ModeKeys`` mode.
      dropout: The probability to drop units in :obj:`outputs`.

    Returns:
      The residual output.
    """
    outputs = tf.layers.dropout(outputs, rate=dropout, training=mode)

    input_dim = inputs.get_shape().as_list()[-1]
    output_dim = outputs.get_shape().as_list()[-1]

    if input_dim == output_dim:
        outputs += inputs

    return outputs


def MemoryBlock(
    inputs,
    filter_size,
    mode,
    mask=None,
    dropout=0.0,
):
    """Defines the bidirectional memory block in FSMN.

    Args:
      inputs: The output of the previous layer. [Batch, Time, Frequency]
      filter_size: The memory block filter size.
      mode: Training or evaluation.
      mask: A ``tf.Tensor`` applied to the memory block output.
      dropout: The probability to drop units from the output.

    Returns:
      output: A 3-D tensor ([Batch, Time, Frequency]).
    """
    static_shape = inputs.get_shape().as_list()
    depth = static_shape[-1]
    inputs = tf.expand_dims(inputs, axis=1)  # [Batch, 1, Time, Frequency]
    depthwise_filter = tf.get_variable(
        'depth_conv_w',
        shape=[1, filter_size, depth, 1],
        initializer=tf.glorot_uniform_initializer(),
        dtype=tf.float32)
    memory = tf.nn.depthwise_conv2d(
        input=inputs,
        filter=depthwise_filter,
        strides=[1, 1, 1, 1],
        padding='SAME',
        rate=[1, 1],
        data_format='NHWC')
    memory = memory + inputs
    output = tf.layers.dropout(memory, rate=dropout, training=mode)
    output = tf.reshape(
        output,
        [tf.shape(output)[0], tf.shape(output)[2], depth])
    if mask is not None:
        output = output * tf.expand_dims(mask, -1)

    return output


def MemoryBlockV2(
    inputs,
    filter_size,
    mode,
    shift=0,
    mask=None,
    dropout=0.0,
):
    """Defines the bidirectional memory block in FSMN.

    Args:
      inputs: The output of the previous layer. [Batch, Time, Frequency]
      filter_size: The memory block filter size.
      mode: Training or evaluation.
      shift: Extra left padding, to control delay.
      mask: A ``tf.Tensor`` applied to the memory block output.
      dropout: The probability to drop units from the output.

    Returns:
      output: A 3-D tensor ([Batch, Time, Frequency]).
    """
    if mask is not None:
        inputs = inputs * tf.expand_dims(mask, -1)

    static_shape = inputs.get_shape().as_list()
    depth = static_shape[-1]
    # padding
    left_padding = int(round((filter_size - 1) / 2))
    right_padding = int((filter_size - 1) / 2)
    if shift > 0:
        left_padding = left_padding + shift
        right_padding = right_padding - shift
    pad_inputs = pad_in_time(inputs, [left_padding, right_padding])
    pad_inputs = tf.expand_dims(
        pad_inputs, axis=1)  # [Batch, 1, Time, Frequency]
    depthwise_filter = tf.get_variable(
        'depth_conv_w',
        shape=[1, filter_size, depth, 1],
        initializer=tf.glorot_uniform_initializer(),
        dtype=tf.float32)
    memory = tf.nn.depthwise_conv2d(
        input=pad_inputs,
        filter=depthwise_filter,
        strides=[1, 1, 1, 1],
        padding='VALID',
        rate=[1, 1],
        data_format='NHWC')
    memory = tf.reshape(
        memory,
        [tf.shape(memory)[0], tf.shape(memory)[2], depth])
    memory = memory + inputs
    output = tf.layers.dropout(memory, rate=dropout, training=mode)
    if mask is not None:
        output = output * tf.expand_dims(mask, -1)

    return output


def UniMemoryBlock(
    inputs,
    filter_size,
    mode,
    cache=None,
    mask=None,
    dropout=0.0,
):
    """Defines the unidirectional memory block in FSMN.

    Args:
      inputs: The output of the previous layer. [Batch, Time, Frequency]
      filter_size: The memory block filter size.
      mode: Training or evaluation.
      cache: Cached queries, for streaming inference.
      mask: A ``tf.Tensor`` applied to the memory block output.
      dropout: The probability to drop units from the output.

    Returns:
      output: A 3-D tensor ([Batch, Time, Frequency]).
    """
    if cache is not None:
        static_shape = cache['queries'].get_shape().as_list()
        depth = static_shape[-1]
        queries = tf.slice(cache['queries'], [0, 1, 0], [
            tf.shape(cache['queries'])[0],
            tf.shape(cache['queries'])[1] - 1, depth
        ])
        queries = tf.concat([queries, inputs], axis=1)
        cache['queries'] = queries
    else:
        padding_length = filter_size - 1
        queries = pad_in_time(inputs, [padding_length, 0])

    queries = tf.expand_dims(queries, axis=1)  # [Batch, 1, Time, Frequency]
    static_shape = queries.get_shape().as_list()
    depth = static_shape[-1]
    depthwise_filter = tf.get_variable(
        'depth_conv_w',
        shape=[1, filter_size, depth, 1],
        initializer=tf.glorot_uniform_initializer(),
        dtype=tf.float32)
    memory = tf.nn.depthwise_conv2d(
        input=queries,
        filter=depthwise_filter,
        strides=[1, 1, 1, 1],
        padding='VALID',
        rate=[1, 1],
        data_format='NHWC')
    memory = tf.reshape(
        memory,
        [tf.shape(memory)[0], tf.shape(memory)[2], depth])
    memory = memory + inputs
    output = tf.layers.dropout(memory, rate=dropout, training=mode)
    if mask is not None:
        output = output * tf.expand_dims(mask, -1)

    return output
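For orientation, a short sketch (not part of this diff) of how the shift parameter in MemoryBlockV2 above trades lookahead for delay. With filter_size = 5 the block normally sees 2 past and 2 future frames; shifting moves the window toward the past:

# Padding arithmetic from MemoryBlockV2 (pure Python, no TensorFlow needed).
def memory_paddings(filter_size, shift=0):
    left = int(round((filter_size - 1) / 2))
    right = int((filter_size - 1) / 2)
    if shift > 0:
        left, right = left + shift, right - shift
    return left, right

print(memory_paddings(5))           # (2, 2): symmetric context, 2-frame lookahead
print(memory_paddings(5, shift=2))  # (4, 0): fully causal, no lookahead

Padded length T + left + right fed through a width-filter_size VALID convolution yields T frames again, so the residual add with the unpadded inputs lines up.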
@@ -1,178 +0,0 @@
import tensorflow as tf

from . import fsmn


class FsmnEncoder():
    """Encoder using FSMN."""

    def __init__(self,
                 filter_size,
                 fsmn_num_layers,
                 dnn_num_layers,
                 num_memory_units=512,
                 ffn_inner_dim=2048,
                 dropout=0.0,
                 position_encoder=None):
        """Initializes the parameters of the encoder.

        Args:
          filter_size: The total order of the memory block.
          fsmn_num_layers: The number of FSMN layers.
          dnn_num_layers: The number of DNN layers.
          num_memory_units: The number of memory units.
          ffn_inner_dim: The number of units of the inner linear transformation
            in the feed forward layer.
          dropout: The probability to drop units from the outputs.
          position_encoder: The :class:`opennmt.layers.position.PositionEncoder`
            to apply on inputs, or ``None``.
        """
        super(FsmnEncoder, self).__init__()
        self.filter_size = filter_size
        self.fsmn_num_layers = fsmn_num_layers
        self.dnn_num_layers = dnn_num_layers
        self.num_memory_units = num_memory_units
        self.ffn_inner_dim = ffn_inner_dim
        self.dropout = dropout
        self.position_encoder = position_encoder

    def encode(self, inputs, sequence_length=None, mode=True):
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs)

        inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode)

        mask = fsmn.build_sequence_mask(
            sequence_length, maximum_length=tf.shape(inputs)[1])

        state = ()

        for layer in range(self.fsmn_num_layers):
            with tf.variable_scope('fsmn_layer_{}'.format(layer)):
                with tf.variable_scope('ffn'):
                    context = fsmn.feed_forward(
                        inputs,
                        self.ffn_inner_dim,
                        self.num_memory_units,
                        mode,
                        dropout=self.dropout)

                with tf.variable_scope('memory'):
                    memory = fsmn.MemoryBlock(
                        context,
                        self.filter_size,
                        mode,
                        mask=mask,
                        dropout=self.dropout)

                memory = fsmn.drop_and_add(
                    inputs, memory, mode, dropout=self.dropout)

                inputs = memory
                state += (tf.reduce_mean(inputs, axis=1), )

        for layer in range(self.dnn_num_layers):
            with tf.variable_scope('dnn_layer_{}'.format(layer)):
                transformed = fsmn.feed_forward(
                    inputs,
                    self.ffn_inner_dim,
                    self.num_memory_units,
                    mode,
                    dropout=self.dropout)

                inputs = transformed
                state += (tf.reduce_mean(inputs, axis=1), )

        outputs = inputs
        return (outputs, state, sequence_length)


class FsmnEncoderV2():
    """Encoder using FSMN."""

    def __init__(self,
                 filter_size,
                 fsmn_num_layers,
                 dnn_num_layers,
                 num_memory_units=512,
                 ffn_inner_dim=2048,
                 dropout=0.0,
                 shift=0,
                 position_encoder=None):
        """Initializes the parameters of the encoder.

        Args:
          filter_size: The total order of the memory block.
          fsmn_num_layers: The number of FSMN layers.
          dnn_num_layers: The number of DNN layers.
          num_memory_units: The number of memory units.
          ffn_inner_dim: The number of units of the inner linear transformation
            in the feed forward layer.
          dropout: The probability to drop units from the outputs.
          shift: Extra left padding, to control delay.
          position_encoder: The :class:`opennmt.layers.position.PositionEncoder`
            to apply on inputs, or ``None``.
        """
        super(FsmnEncoderV2, self).__init__()
        self.filter_size = filter_size
        self.fsmn_num_layers = fsmn_num_layers
        self.dnn_num_layers = dnn_num_layers
        self.num_memory_units = num_memory_units
        self.ffn_inner_dim = ffn_inner_dim
        self.dropout = dropout
        self.shift = shift
        if not isinstance(shift, list):
            self.shift = [shift for _ in range(self.fsmn_num_layers)]
        self.position_encoder = position_encoder

    def encode(self, inputs, sequence_length=None, mode=True):
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs)

        inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode)

        mask = fsmn.build_sequence_mask(
            sequence_length, maximum_length=tf.shape(inputs)[1])

        state = ()
        for layer in range(self.fsmn_num_layers):
            with tf.variable_scope('fsmn_layer_{}'.format(layer)):
                with tf.variable_scope('ffn'):
                    context = fsmn.feed_forward(
                        inputs,
                        self.ffn_inner_dim,
                        self.num_memory_units,
                        mode,
                        dropout=self.dropout)

                with tf.variable_scope('memory'):
                    memory = fsmn.MemoryBlockV2(
                        context,
                        self.filter_size,
                        mode,
                        shift=self.shift[layer],
                        mask=mask,
                        dropout=self.dropout)

                memory = fsmn.drop_and_add(
                    inputs, memory, mode, dropout=self.dropout)

                inputs = memory
                state += (tf.reduce_mean(inputs, axis=1), )

        for layer in range(self.dnn_num_layers):
            with tf.variable_scope('dnn_layer_{}'.format(layer)):
                transformed = fsmn.feed_forward(
                    inputs,
                    self.ffn_inner_dim,
                    self.num_memory_units,
                    mode,
                    dropout=self.dropout)

                inputs = transformed
                state += (tf.reduce_mean(inputs, axis=1), )

        outputs = inputs
        return (outputs, state, sequence_length)
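For orientation, a hypothetical TensorFlow 1.x usage sketch (not part of this diff) of the removed FsmnEncoder; the module path is assumed, and shapes follow the docstrings above:

# Assumed import; the file was deleted in this commit, so this only applies
# to checkouts that still carry it.
import tensorflow as tf
from fsmn_encoder import FsmnEncoder

encoder = FsmnEncoder(
    filter_size=11, fsmn_num_layers=6, dnn_num_layers=2,
    num_memory_units=512, ffn_inner_dim=2048, dropout=0.1)

inputs = tf.placeholder(tf.float32, [None, None, 512])  # [Batch, Time, Frequency]
lengths = tf.placeholder(tf.int32, [None])
# mode=True enables dropout, i.e. training-time behaviour.
outputs, state, _ = encoder.encode(inputs, sequence_length=lengths, mode=True)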
Some files were not shown because too many files have changed in this diff.