Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-25 04:29:22 +01:00)

Merge pull request #199 from modelscope/master-merge-internal20230315
Master merge internal20230315

3  .gitignore (vendored)
@@ -2,7 +2,7 @@
__pycache__/
*.py[cod]
*$py.class

test.py

# C extensions
*.so

@@ -123,6 +123,7 @@ tensorboard.sh
replace.sh
result.png
result.jpg
result.mp4

# Pytorch
*.pth
BIN  data/test/audios/speech_with_noise_48k.pcm (new file)
Binary file not shown.
3  data/test/audios/speech_with_noise_48k.wav (new file)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9e76c8448e93934ed9c8827b76f702d07fccc3e586900903617971471235800
size 475278

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:78094cc48fbcfd9b6d321fe13619ecc72b65e006fc1b4c4458409ade9979486d
-size 129862
+oid sha256:d53a77b0be82993ed44bbb9244cda42bf460f8dcdf87ff3cfdbfdc7191ff418d
+size 121984
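The test assets above are tracked with Git LFS, so the diff only shows the three-line pointer files (`version`, `oid sha256:...`, `size`) rather than the binary data. A minimal sketch of reading such a pointer, assuming the exact layout shown in these hunks:

```python
def read_lfs_pointer(path):
    """Parse a Git LFS pointer file into its version, sha256 oid and size."""
    fields = dict(line.split(' ', 1) for line in open(path).read().splitlines() if line)
    return {
        'version': fields['version'],
        'oid': fields['oid'].split(':', 1)[1],   # strip the 'sha256:' prefix
        'size': int(fields['size']),
    }

# e.g. read_lfs_pointer('data/test/audios/speech_with_noise_48k.wav')
# -> {'version': 'https://git-lfs.github.com/spec/v1', 'oid': 'a9e76c84...', 'size': 475278}
```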
3  data/test/images/human_reconstruction.jpg (new file)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:06ec486657dffbf244563a844c98c19d49b7a45b99da702403b52bb9e6bf3c0a
size 226072

3  data/test/images/image_camouflag_detection.jpg (new file)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4c713215f7fb4da5382c9137347ee52956a7a44d5979c4cffd3c9b6d1d7e878f
size 19445

3  data/test/images/image_depth_estimation_kitti_007517.png (new file)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f2a83dab7fd7fedff65979fd2496fd86f0a36f222a5a0e6c81fbb161043b9a45
size 786657

3  data/test/images/image_smokefire_detection.jpg (new file)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:713082e6967760d5a0d1ae07af62ecc58f9b8b0ab418394556dc5c6c31c27056
size 63761

3  data/test/images/lineless_table_recognition.jpg (new file)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2053b3bcb7abfe5c22b4954b81899dcffbb99302af6f179c43d45265c732d804
size 26493
46  data/test/images/ocr_detection/test_gts/X51007339105.txt (new file)
@@ -0,0 +1,46 @@
|
||||
44.0,131.0,513.0,131.0,513.0,166.0,44.0,166.0,SANYU STATIONERY SHOP
|
||||
45.0,179.0,502.0,179.0,502.0,204.0,45.0,204.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X
|
||||
43.0,205.0,242.0,205.0,242.0,226.0,43.0,226.0,40170 SETIA ALAM
|
||||
44.0,231.0,431.0,231.0,431.0,255.0,44.0,255.0,MOBILE /WHATSAPPS : +6012-918 7937
|
||||
41.0,264.0,263.0,264.0,263.0,286.0,41.0,286.0,TEL: +603-3362 4137
|
||||
42.0,291.0,321.0,291.0,321.0,312.0,42.0,312.0,GST ID NO: 001531760640
|
||||
409.0,303.0,591.0,303.0,591.0,330.0,409.0,330.0,TAX INVOICE
|
||||
33.0,321.0,139.0,321.0,139.0,343.0,33.0,343.0,OWNED BY :
|
||||
34.0,343.0,376.0,343.0,376.0,369.0,34.0,369.0,SANYU SUPPLY SDN BHD (1135772-K)
|
||||
37.0,397.0,303.0,397.0,303.0,420.0,37.0,420.0,CASH SALES COUNTER
|
||||
53.0,459.0,188.0,459.0,188.0,483.0,53.0,483.0,1. 2012-0043
|
||||
79.0,518.0,193.0,518.0,193.0,545.0,79.0,545.0,1 X 3.3000
|
||||
270.0,460.0,585.0,460.0,585.0,484.0,270.0,484.0,JOURNAL BOOK 80PGS A4 70G
|
||||
271.0,487.0,527.0,487.0,527.0,513.0,271.0,513.0,CARD COVER (SJB-4013)
|
||||
479.0,522.0,527.0,522.0,527.0,542.0,479.0,542.0,3.30
|
||||
553.0,524.0,583.0,524.0,583.0,544.0,553.0,544.0,SR
|
||||
27.0,557.0,342.0,557.0,342.0,581.0,27.0,581.0,TOTAL SALES INCLUSIVE GST @6%
|
||||
240.0,587.0,331.0,587.0,331.0,612.0,240.0,612.0,DISCOUNT
|
||||
239.0,629.0,293.0,629.0,293.0,651.0,239.0,651.0,TOTAL
|
||||
240.0,659.0,345.0,659.0,345.0,687.0,240.0,687.0,ROUND ADJ
|
||||
239.0,701.0,347.0,701.0,347.0,723.0,239.0,723.0,FINAL TOTAL
|
||||
239.0,758.0,301.0,758.0,301.0,781.0,239.0,781.0,CASH
|
||||
240.0,789.0,329.0,789.0,329.0,811.0,240.0,811.0,CHANGE
|
||||
478.0,561.0,524.0,561.0,524.0,581.0,478.0,581.0,3.30
|
||||
477.0,590.0,525.0,590.0,525.0,614.0,477.0,614.0,0.00
|
||||
482.0,634.0,527.0,634.0,527.0,655.0,482.0,655.0,3.30
|
||||
480.0,664.0,528.0,664.0,528.0,684.0,480.0,684.0,0.00
|
||||
481.0,704.0,527.0,704.0,527.0,726.0,481.0,726.0,3.30
|
||||
481.0,760.0,526.0,760.0,526.0,782.0,481.0,782.0,5.00
|
||||
482.0,793.0,528.0,793.0,528.0,814.0,482.0,814.0,1.70
|
||||
28.0,834.0,172.0,834.0,172.0,859.0,28.0,859.0,GST SUMMARY
|
||||
253.0,834.0,384.0,834.0,384.0,859.0,253.0,859.0,AMOUNT(RM)
|
||||
475.0,834.0,566.0,834.0,566.0,860.0,475.0,860.0,TAX(RM)
|
||||
28.0,864.0,128.0,864.0,128.0,889.0,28.0,889.0,SR @ 6%
|
||||
337.0,864.0,385.0,864.0,385.0,886.0,337.0,886.0,3.11
|
||||
518.0,867.0,565.0,867.0,565.0,887.0,518.0,887.0,0.19
|
||||
25.0,943.0,290.0,943.0,290.0,967.0,25.0,967.0,INV NO: CS-SA-0076015
|
||||
316.0,942.0,516.0,942.0,516.0,967.0,316.0,967.0,DATE : 05/04/2017
|
||||
65.0,1084.0,569.0,1084.0,569.0,1110.0,65.0,1110.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE
|
||||
112.0,1135.0,524.0,1135.0,524.0,1163.0,112.0,1163.0,THANK YOU FOR YOUR PATRONAGE
|
||||
189.0,1169.0,441.0,1169.0,441.0,1192.0,189.0,1192.0,PLEASE COME AGAIN.
|
||||
115.0,1221.0,517.0,1221.0,517.0,1245.0,115.0,1245.0,TERIMA KASIH SILA DATANG LAGI
|
||||
65.0,1271.0,569.0,1271.0,569.0,1299.0,65.0,1299.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF
|
||||
48.0,1305.0,584.0,1305.0,584.0,1330.0,48.0,1330.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY
|
||||
244.0,1339.0,393.0,1339.0,393.0,1359.0,244.0,1359.0,PURPOSE **
|
||||
85.0,1389.0,548.0,1389.0,548.0,1419.0,85.0,1419.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY
|
||||
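Each line in these ground-truth files encodes one text region as eight polygon coordinates (x1,y1,...,x4,y4) followed by the transcription, which may itself contain commas. A small parsing sketch under that assumption:

```python
def parse_gt_line(line):
    """Split 'x1,y1,...,x4,y4,text' into a 4-point polygon and the transcription."""
    parts = line.rstrip('\n').split(',', 8)          # the text is everything after the 8th comma
    coords = list(map(float, parts[:8]))
    polygon = [(coords[i], coords[i + 1]) for i in range(0, 8, 2)]
    return polygon, parts[8]

polygon, text = parse_gt_line(
    '44.0,131.0,513.0,131.0,513.0,166.0,44.0,166.0,SANYU STATIONERY SHOP')
assert text == 'SANYU STATIONERY SHOP' and len(polygon) == 4
```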
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:afe8e0d24bed53078472e6e4a00f81cc4e251e88d35bc49afb59cf3fab36fcf8
size 348614

1  data/test/images/ocr_detection/test_list.txt (new file)
@@ -0,0 +1 @@
X51007339105.jpg
46  data/test/images/ocr_detection/train_gts/X51007339133.txt (new file)
@@ -0,0 +1,46 @@
|
||||
46.0,135.0,515.0,135.0,515.0,170.0,46.0,170.0,SANYU STATIONERY SHOP
|
||||
49.0,184.0,507.0,184.0,507.0,207.0,49.0,207.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X
|
||||
47.0,209.0,245.0,209.0,245.0,230.0,47.0,230.0,40170 SETIA ALAM
|
||||
48.0,236.0,433.0,236.0,433.0,258.0,48.0,258.0,MOBILE /WHATSAPPS : +6012-918 7937
|
||||
46.0,267.0,266.0,267.0,266.0,287.0,46.0,287.0,TEL: +603-3362 4137
|
||||
47.0,292.0,325.0,292.0,325.0,315.0,47.0,315.0,GST ID NO: 001531760640
|
||||
410.0,303.0,594.0,303.0,594.0,331.0,410.0,331.0,TAX INVOICE
|
||||
38.0,322.0,141.0,322.0,141.0,344.0,38.0,344.0,OWNED BY :
|
||||
38.0,345.0,378.0,345.0,378.0,368.0,38.0,368.0,SANYU SUPPLY SDN BHD (1135772-K)
|
||||
43.0,398.0,303.0,398.0,303.0,422.0,43.0,422.0,CASH SALES COUNTER
|
||||
55.0,462.0,194.0,462.0,194.0,485.0,55.0,485.0,1. 2012-0029
|
||||
81.0,523.0,194.0,523.0,194.0,544.0,81.0,544.0,3 X 2.9000
|
||||
275.0,463.0,597.0,463.0,597.0,486.0,275.0,486.0,RESTAURANT ORDER CHIT NCR
|
||||
274.0,491.0,347.0,491.0,347.0,512.0,274.0,512.0,3.5"X6"
|
||||
482.0,524.0,529.0,524.0,529.0,545.0,482.0,545.0,8.70
|
||||
556.0,525.0,585.0,525.0,585.0,546.0,556.0,546.0,SR
|
||||
28.0,559.0,346.0,559.0,346.0,581.0,28.0,581.0,TOTAL SALES INCLUSIVE GST @6%
|
||||
243.0,590.0,335.0,590.0,335.0,612.0,243.0,612.0,DISCOUNT
|
||||
241.0,632.0,296.0,632.0,296.0,654.0,241.0,654.0,TOTAL
|
||||
242.0,661.0,349.0,661.0,349.0,685.0,242.0,685.0,ROUND ADJ
|
||||
243.0,703.0,348.0,703.0,348.0,724.0,243.0,724.0,FINAL TOTAL
|
||||
244.0,760.0,302.0,760.0,302.0,780.0,244.0,780.0,CASH
|
||||
241.0,792.0,332.0,792.0,332.0,812.0,241.0,812.0,CHANGE
|
||||
481.0,562.0,530.0,562.0,530.0,584.0,481.0,584.0,8.70
|
||||
482.0,594.0,528.0,594.0,528.0,613.0,482.0,613.0,0.00
|
||||
483.0,636.0,530.0,636.0,530.0,654.0,483.0,654.0,8.70
|
||||
483.0,666.0,533.0,666.0,533.0,684.0,483.0,684.0,0.00
|
||||
484.0,707.0,532.0,707.0,532.0,726.0,484.0,726.0,8.70
|
||||
473.0,764.0,535.0,764.0,535.0,783.0,473.0,783.0,10.00
|
||||
486.0,793.0,532.0,793.0,532.0,815.0,486.0,815.0,1.30
|
||||
31.0,836.0,176.0,836.0,176.0,859.0,31.0,859.0,GST SUMMARY
|
||||
257.0,836.0,391.0,836.0,391.0,858.0,257.0,858.0,AMOUNT(RM)
|
||||
479.0,837.0,569.0,837.0,569.0,859.0,479.0,859.0,TAX(RM)
|
||||
33.0,867.0,130.0,867.0,130.0,889.0,33.0,889.0,SR @ 6%
|
||||
341.0,867.0,389.0,867.0,389.0,889.0,341.0,889.0,8.21
|
||||
522.0,869.0,573.0,869.0,573.0,890.0,522.0,890.0,0.49
|
||||
30.0,945.0,292.0,945.0,292.0,967.0,30.0,967.0,INV NO: CS-SA-0120436
|
||||
323.0,945.0,520.0,945.0,520.0,967.0,323.0,967.0,DATE : 27/10/2017
|
||||
70.0,1089.0,572.0,1089.0,572.0,1111.0,70.0,1111.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE
|
||||
116.0,1142.0,526.0,1142.0,526.0,1162.0,116.0,1162.0,THANK YOU FOR YOUR PATRONAGE
|
||||
199.0,1173.0,445.0,1173.0,445.0,1193.0,199.0,1193.0,PLEASE COME AGAIN.
|
||||
121.0,1225.0,524.0,1225.0,524.0,1246.0,121.0,1246.0,TERIMA KASIH SILA DATANG LAGI
|
||||
72.0,1273.0,573.0,1273.0,573.0,1299.0,72.0,1299.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF
|
||||
55.0,1308.0,591.0,1308.0,591.0,1328.0,55.0,1328.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY
|
||||
249.0,1338.0,396.0,1338.0,396.0,1361.0,249.0,1361.0,PURPOSE **
|
||||
93.0,1391.0,553.0,1391.0,553.0,1416.0,93.0,1416.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY
|
||||
46  data/test/images/ocr_detection/train_gts/X51007339135.txt (new file)
@@ -0,0 +1,46 @@
|
||||
44.0,131.0,517.0,131.0,517.0,166.0,44.0,166.0,SANYU STATIONERY SHOP
|
||||
48.0,180.0,510.0,180.0,510.0,204.0,48.0,204.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X
|
||||
47.0,206.0,247.0,206.0,247.0,229.0,47.0,229.0,40170 SETIA ALAM
|
||||
47.0,232.0,434.0,232.0,434.0,258.0,47.0,258.0,MOBILE /WHATSAPPS : +6012-918 7937
|
||||
47.0,265.0,268.0,265.0,268.0,285.0,47.0,285.0,TEL: +603-3362 4137
|
||||
48.0,287.0,325.0,287.0,325.0,313.0,48.0,313.0,GST ID NO: 001531760640
|
||||
411.0,301.0,599.0,301.0,599.0,333.0,411.0,333.0,TAX INVOICE
|
||||
38.0,321.0,143.0,321.0,143.0,342.0,38.0,342.0,OWNED BY :
|
||||
39.0,342.0,379.0,342.0,379.0,367.0,39.0,367.0,SANYU SUPPLY SDN BHD (1135772-K)
|
||||
42.0,397.0,305.0,397.0,305.0,420.0,42.0,420.0,CASH SALES COUNTER
|
||||
57.0,459.0,195.0,459.0,195.0,483.0,57.0,483.0,1. 2012-0029
|
||||
82.0,518.0,199.0,518.0,199.0,540.0,82.0,540.0,3 X 2.9000
|
||||
274.0,459.0,600.0,459.0,600.0,483.0,274.0,483.0,RESTAURANT ORDER CHIT NCR
|
||||
274.0,486.0,347.0,486.0,347.0,508.0,274.0,508.0,3.5"X6"
|
||||
483.0,521.0,530.0,521.0,530.0,541.0,483.0,541.0,8.70
|
||||
557.0,517.0,588.0,517.0,588.0,543.0,557.0,543.0,SR
|
||||
31.0,556.0,347.0,556.0,347.0,578.0,31.0,578.0,TOTAL SALES INCLUSIVE GST @6%
|
||||
244.0,585.0,335.0,585.0,335.0,608.0,244.0,608.0,DISCOUNT
|
||||
241.0,626.0,302.0,626.0,302.0,651.0,241.0,651.0,TOTAL
|
||||
245.0,659.0,354.0,659.0,354.0,683.0,245.0,683.0,ROUND ADJ
|
||||
244.0,698.0,351.0,698.0,351.0,722.0,244.0,722.0,FINAL TOTAL
|
||||
482.0,558.0,529.0,558.0,529.0,578.0,482.0,578.0,8.70
|
||||
484.0,591.0,531.0,591.0,531.0,608.0,484.0,608.0,0.00
|
||||
485.0,630.0,533.0,630.0,533.0,651.0,485.0,651.0,8.70
|
||||
485.0,661.0,532.0,661.0,532.0,681.0,485.0,681.0,0.00
|
||||
484.0,703.0,532.0,703.0,532.0,723.0,484.0,723.0,8.70
|
||||
474.0,760.0,534.0,760.0,534.0,777.0,474.0,777.0,10.00
|
||||
488.0,789.0,532.0,789.0,532.0,808.0,488.0,808.0,1.30
|
||||
33.0,829.0,179.0,829.0,179.0,855.0,33.0,855.0,GST SUMMARY
|
||||
261.0,828.0,390.0,828.0,390.0,855.0,261.0,855.0,AMOUNT(RM)
|
||||
482.0,830.0,572.0,830.0,572.0,856.0,482.0,856.0,TAX(RM)
|
||||
32.0,862.0,135.0,862.0,135.0,885.0,32.0,885.0,SR @ 6%
|
||||
344.0,860.0,389.0,860.0,389.0,884.0,344.0,884.0,8.21
|
||||
523.0,862.0,575.0,862.0,575.0,885.0,523.0,885.0,0.49
|
||||
32.0,941.0,295.0,941.0,295.0,961.0,32.0,961.0,INV NO: CS-SA-0122588
|
||||
72.0,1082.0,576.0,1082.0,576.0,1106.0,72.0,1106.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE
|
||||
115.0,1135.0,528.0,1135.0,528.0,1157.0,115.0,1157.0,THANK YOU FOR YOUR PATRONAGE
|
||||
195.0,1166.0,445.0,1166.0,445.0,1189.0,195.0,1189.0,PLEASE COME AGAIN.
|
||||
122.0,1217.0,528.0,1217.0,528.0,1246.0,122.0,1246.0,TERIMA KASIH SILA DATANG LAGI
|
||||
72.0,1270.0,576.0,1270.0,576.0,1294.0,72.0,1294.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF
|
||||
55.0,1301.0,592.0,1301.0,592.0,1325.0,55.0,1325.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY
|
||||
251.0,1329.0,400.0,1329.0,400.0,1354.0,251.0,1354.0,PURPOSE **
|
||||
95.0,1386.0,558.0,1386.0,558.0,1413.0,95.0,1413.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY
|
||||
243.0,752.0,305.0,752.0,305.0,779.0,243.0,779.0,CASH
|
||||
244.0,784.0,336.0,784.0,336.0,807.0,244.0,807.0,CHANGE
|
||||
316.0,939.0,525.0,939.0,525.0,967.0,316.0,967.0,DATE: 06/11/2017
|
||||
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ffcc55042093629aaa54d26516de77b45c7b612c0516bad21517e1963e7b518c
size 352297

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:56addcb7d36f9b3732e0c4efd04d7e31d291c0763a32dcb556fe262bcbb0520a
size 353731

2  data/test/images/ocr_detection/train_list.txt (new file)
@@ -0,0 +1,2 @@
X51007339133.jpg
X51007339135.jpg
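The list files pair each image name with a ground-truth file of the same stem under `train_gts/` (and `test_list.txt` with `test_gts/`). A rough sketch of walking that layout; the `train_images/` directory name is an assumption, since this diff only shows the list and GT files:

```python
import os

def load_split(root, list_name, gts_dir, images_dir):
    """Yield (image_path, gt_path) pairs for an ocr_detection split."""
    with open(os.path.join(root, list_name)) as f:
        for name in filter(None, (line.strip() for line in f)):
            stem = os.path.splitext(name)[0]
            yield (os.path.join(root, images_dir, name),
                   os.path.join(root, gts_dir, stem + '.txt'))

# e.g. list(load_split('data/test/images/ocr_detection',
#                      'train_list.txt', 'train_gts', 'train_images'))
```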
3  data/test/images/vidt_test1.jpg (new file)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7e87ea289bc59863ed81129d5991ede97bf5335c173ab9f36e4e4cfdc858e41
size 120137

3  data/test/images/vision_efficient_tuning_test_apple.jpg (new file)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:407d70db9f01bc7a6f34377e36c3f2f5eefdfca8bd3c578226bf5b31b73325dc
size 127213

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c67733db75dc7fd773561a5091329fd5ee919b2268a3a65718261722607698f
size 226882
@@ -1,14 +0,0 @@
modelscope.msdatasets.cv
================================

.. automodule:: modelscope.msdatasets.cv

.. currentmodule:: modelscope.msdatasets.cv

.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    easycv_base.EasyCVBaseDataset
    image_classification.ClsDataset
@@ -0,0 +1,41 @@
modelscope.msdatasets.dataset_cls.custom_datasets
====================

.. automodule:: modelscope.msdatasets.dataset_cls.custom_datasets

.. currentmodule:: modelscope.msdatasets.dataset_cls.custom_datasets


.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    EasyCVBaseDataset
    TorchCustomDataset
    MovieSceneSegmentationDataset
    ImageInstanceSegmentationCocoDataset
    GoproImageDeblurringDataset
    LanguageGuidedVideoSummarizationDataset
    MGeoRankingDataset
    RedsImageDeblurringDataset
    TextRankingDataset
    VecoDataset
    VideoSummarizationDataset
    BadImageDetectingDataset
    ImageInpaintingDataset
    ImagePortraitEnhancementDataset
    ImageQualityAssessmentDegradationDataset
    ImageQualityAssessmentMosDataset
    ReferringVideoObjectSegmentationDataset
    SiddImageDenoisingDataset
    VideoFrameInterpolationDataset
    VideoStabilizationDataset
    VideoSuperResolutionDataset
    SegDataset
    FaceKeypointDataset
    HandCocoWholeBodyDataset
    WholeBodyCocoTopDownDataset
    ClsDataset
    DetImagesMixDataset
    DetDataset

15  docs/source/api/modelscope.msdatasets.dataset_cls.rst (new file)
@@ -0,0 +1,15 @@
modelscope.msdatasets.dataset_cls
====================

.. automodule:: modelscope.msdatasets.dataset_cls

.. currentmodule:: modelscope.msdatasets.dataset_cls


.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    ExternalDataset
    NativeIterableDataset

@@ -10,5 +10,4 @@ modelscope.msdatasets.ms_dataset
    :nosignatures:
    :template: classtemplate.rst

-   MsMapDataset
    MsDataset
102  examples/pytorch/text_generation/finetune_text_generation.py (new file)
@@ -0,0 +1,102 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import EpochBasedTrainer, build_trainer
|
||||
from modelscope.trainers.training_args import TrainingArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextGenerationArguments(TrainingArgs):
|
||||
|
||||
trainer: str = field(
|
||||
default=Trainers.default, metadata={
|
||||
'help': 'The trainer used',
|
||||
})
|
||||
|
||||
work_dir: str = field(
|
||||
default='./tmp',
|
||||
metadata={
|
||||
'help': 'The working path for saving checkpoint',
|
||||
})
|
||||
|
||||
src_txt: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help': 'The source text key of preprocessor',
|
||||
'cfg_node': 'preprocessor.src_txt'
|
||||
})
|
||||
|
||||
tgt_txt: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help': 'The target text key of preprocessor',
|
||||
'cfg_node': 'preprocessor.tgt_txt'
|
||||
})
|
||||
|
||||
preprocessor: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help': 'The preprocessor type',
|
||||
'cfg_node': 'preprocessor.type'
|
||||
})
|
||||
|
||||
lr_scheduler: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help': 'The lr scheduler type',
|
||||
'cfg_node': 'train.lr_scheduler.type'
|
||||
})
|
||||
|
||||
world_size: int = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help': 'The parallel world size',
|
||||
'cfg_node': 'megatron.world_size'
|
||||
})
|
||||
|
||||
tensor_model_parallel_size: int = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help': 'The tensor model parallel size',
|
||||
'cfg_node': 'megatron.tensor_model_parallel_size'
|
||||
})
|
||||
|
||||
def __call__(self, config):
|
||||
config = super().__call__(config)
|
||||
if config.train.lr_scheduler.type == 'noam':
|
||||
config.train.lr_scheduler = {
|
||||
'type': 'LambdaLR',
|
||||
'lr_lambda': noam_lambda,
|
||||
'options': {
|
||||
'by_epoch': False
|
||||
}
|
||||
}
|
||||
config.train.hooks.append({'type': 'MegatronHook'})
|
||||
return config
|
||||
|
||||
|
||||
def noam_lambda(current_step: int):
|
||||
current_step += 1
|
||||
return min(current_step**(-0.5), current_step * 100**(-1.5))
|
||||
|
||||
|
||||
args = TextGenerationArguments.from_cli(task='text-generation')
|
||||
|
||||
print(args)
|
||||
|
||||
dataset = MsDataset.load(args.dataset_name)
|
||||
train_dataset = dataset['train']
|
||||
eval_dataset = dataset['validation' if 'validation' in dataset else 'test']
|
||||
|
||||
kwargs = dict(
|
||||
model=args.model,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
seed=args.seed,
|
||||
work_dir=args.work_dir,
|
||||
cfg_modify_fn=args)
|
||||
|
||||
trainer: EpochBasedTrainer = build_trainer(
|
||||
name=args.trainer, default_args=kwargs)
|
||||
trainer.train()
|
||||
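`noam_lambda` above is the Noam (inverse-square-root) warm-up schedule with roughly 100 warm-up steps: the factor grows as `step * 100**(-1.5)` until about step 100 and then decays as `step**(-0.5)`. The `__call__` hook rewrites the config so the trainer builds a `LambdaLR` around it; outside the trainer the same schedule can be exercised directly. A small sketch (the learning rate matches the `--lr 3e-4` used in `run_train.sh` below):

```python
import torch

def noam_lambda(current_step: int):
    # identical to the schedule defined in finetune_text_generation.py above
    current_step += 1
    return min(current_step**(-0.5), current_step * 100**(-1.5))

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)

for step in range(5):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())   # lr ramps up linearly during warm-up
```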
22  examples/pytorch/text_generation/run_train.sh (new file)
@@ -0,0 +1,22 @@
DATA_PARALLEL_SIZE=2
TENSOR_MODEL_PARALLEL_SIZE=2

WORLD_SIZE=$(($DATA_PARALLEL_SIZE * $TENSOR_MODEL_PARALLEL_SIZE))


PYTHONPATH=. torchrun --nproc_per_node $WORLD_SIZE examples/pytorch/text_generation/finetune_text_generation.py \
    --trainer 'nlp-gpt3-trainer' \
    --work_dir './tmp' \
    --model 'damo/nlp_gpt3_text-generation_1.3B' \
    --dataset_name 'chinese-poetry-collection' \
    --preprocessor 'text-gen-jieba-tokenizer' \
    --src_txt 'text1' \
    --tgt_txt 'text2' \
    --max_epochs 3 \
    --per_device_train_batch_size 16 \
    --lr 3e-4 \
    --lr_scheduler 'noam' \
    --eval_metrics 'ppl' \
    --world_size $WORLD_SIZE \
    --tensor_model_parallel_size $TENSOR_MODEL_PARALLEL_SIZE \
    # --dataset_name 'DuReader_robust-QG' \ # input&output
@@ -3,6 +3,9 @@
import argparse

from modelscope.cli.download import DownloadCMD
from modelscope.cli.modelcard import ModelCardCMD
from modelscope.cli.pipeline import PipelineCMD
from modelscope.cli.plugins import PluginsCMD


def run_cmd():

@@ -11,6 +14,9 @@ def run_cmd():
    subparsers = parser.add_subparsers(help='modelscope commands helpers')

    DownloadCMD.define_args(subparsers)
    PluginsCMD.define_args(subparsers)
    PipelineCMD.define_args(subparsers)
    ModelCardCMD.define_args(subparsers)

    args = parser.parse_args()
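Each `*CMD` class hooks itself into the shared parser through `define_args(subparsers)` and `set_defaults(func=...)`, so the CLI entry point presumably only has to construct `args.func(args)` and call `execute()`. A stripped-down illustration of that dispatch pattern (a toy `HelloCMD`, not the actual DownloadCMD implementation):

```python
import argparse

class HelloCMD:
    name = 'hello'

    def __init__(self, args):
        self.args = args

    @staticmethod
    def define_args(subparsers):
        parser = subparsers.add_parser(HelloCMD.name)
        parser.add_argument('--who', default='ModelScope')
        parser.set_defaults(func=HelloCMD)   # the entry point later calls args.func(args)

    def execute(self):
        print(f'hello, {self.args.who}')

parser = argparse.ArgumentParser('demo')
subparsers = parser.add_subparsers(help='demo commands')
HelloCMD.define_args(subparsers)
args = parser.parse_args(['hello', '--who', 'world'])
args.func(args).execute()   # -> hello, world
```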
178  modelscope/cli/modelcard.py (new file)
@@ -0,0 +1,178 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from argparse import ArgumentParser
|
||||
from string import Template
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.hub.utils.utils import get_endpoint
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
curren_path = os.path.dirname(os.path.abspath(__file__))
|
||||
template_path = os.path.join(curren_path, 'template')
|
||||
|
||||
|
||||
def subparser_func(args):
|
||||
""" Fuction which will be called for a specific sub parser.
|
||||
"""
|
||||
return ModelCardCMD(args)
|
||||
|
||||
|
||||
class ModelCardCMD(CLICommand):
|
||||
name = 'modelcard'
|
||||
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.api = HubApi()
|
||||
self.api.login(args.access_token)
|
||||
self.model_id = os.path.join(
|
||||
self.args.group_id, self.args.model_id
|
||||
) if '/' not in self.args.model_id else self.args.model_id
|
||||
self.url = os.path.join(get_endpoint(), self.model_id)
|
||||
|
||||
@staticmethod
|
||||
def define_args(parsers: ArgumentParser):
|
||||
""" define args for create or upload modelcard command.
|
||||
"""
|
||||
parser = parsers.add_parser(ModelCardCMD.name)
|
||||
parser.add_argument(
|
||||
'-tk',
|
||||
'--access_token',
|
||||
type=str,
|
||||
required=True,
|
||||
help='the access token used to authenticate with ModelScope')
|
||||
parser.add_argument(
|
||||
'-act',
|
||||
'--action',
|
||||
type=str,
|
||||
required=True,
|
||||
choices=['create', 'upload', 'download'],
|
||||
help='the action to perform on ModelScope [create, upload, download]')
|
||||
parser.add_argument(
|
||||
'-gid',
|
||||
'--group_id',
|
||||
type=str,
|
||||
default='damo',
|
||||
help='the group name of ModelScope, eg, damo')
|
||||
parser.add_argument(
|
||||
'-mid',
|
||||
'--model_id',
|
||||
type=str,
|
||||
required=True,
|
||||
help='the model name of ModelScope')
|
||||
parser.add_argument(
|
||||
'-vis',
|
||||
'--visibility',
|
||||
type=int,
|
||||
default=5,
|
||||
help='the visibility of ModelScope')
|
||||
parser.add_argument(
|
||||
'-lic',
|
||||
'--license',
|
||||
type=str,
|
||||
default='Apache License 2.0',
|
||||
help='the license of visit ModelScope')
|
||||
parser.add_argument(
|
||||
'-ch',
|
||||
'--chinese_name',
|
||||
type=str,
|
||||
default='这是我的第一个模型',
|
||||
help='the chinese name of ModelScope')
|
||||
parser.add_argument(
|
||||
'-md',
|
||||
'--model_dir',
|
||||
type=str,
|
||||
default='.',
|
||||
help='the model_dir of configuration.json')
|
||||
parser.add_argument(
|
||||
'-vt',
|
||||
'--version_tag',
|
||||
type=str,
|
||||
default=None,
|
||||
help='the tag of uploaded model')
|
||||
parser.add_argument(
|
||||
'-vi',
|
||||
'--version_info',
|
||||
type=str,
|
||||
default=None,
|
||||
help='the info of uploaded model')
|
||||
parser.set_defaults(func=subparser_func)
|
||||
|
||||
def create_model(self):
|
||||
from modelscope.hub.constants import Licenses, ModelVisibility
|
||||
visibilities = [
|
||||
getattr(ModelVisibility, attr) for attr in dir(ModelVisibility)
|
||||
if not attr.startswith('__')
|
||||
]
|
||||
if self.args.visibility not in visibilities:
|
||||
raise ValueError('The visibility must be one of %s!' % visibilities)
|
||||
licenses = [
|
||||
getattr(Licenses, attr) for attr in dir(Licenses)
|
||||
if not attr.startswith('__')
|
||||
]
|
||||
if self.args.license not in licenses:
|
||||
raise ValueError('The license must be one of %s!' % licenses)
|
||||
try:
|
||||
self.api.get_model(self.model_id)
|
||||
except Exception as e:
|
||||
logger.info('>>> %s' % type(e))
|
||||
self.api.create_model(
|
||||
model_id=self.model_id,
|
||||
visibility=self.args.visibility,
|
||||
license=self.args.license,
|
||||
chinese_name=self.args.chinese_name,
|
||||
)
|
||||
self.pprint()
|
||||
|
||||
def get_model_url(self):
|
||||
return self.api.get_model_url(self.model_id)
|
||||
|
||||
def push_model(self, tpl_dir='readme.tpl'):
|
||||
from modelscope.hub.repository import Repository
|
||||
if self.args.version_tag and self.args.version_info:
|
||||
clone_dir = tempfile.TemporaryDirectory().name
|
||||
repo = Repository(clone_dir, clone_from=self.model_id)
|
||||
repo.tag_and_push(self.args.version_tag, self.args.version_info)
|
||||
shutil.rmtree(clone_dir)
|
||||
else:
|
||||
cfg_file = os.path.join(self.args.model_dir, 'README.md')
|
||||
if not os.path.exists(cfg_file):
|
||||
with open(os.path.join(template_path,
|
||||
tpl_dir)) as tpl_file_path:
|
||||
tpl = Template(tpl_file_path.read())
|
||||
f = open(cfg_file, 'w')
|
||||
f.write(tpl.substitute(model_id=self.model_id))
|
||||
f.close()
|
||||
self.api.push_model(
|
||||
model_id=self.model_id,
|
||||
model_dir=self.args.model_dir,
|
||||
visibility=self.args.visibility,
|
||||
license=self.args.license,
|
||||
chinese_name=self.args.chinese_name)
|
||||
self.pprint()
|
||||
|
||||
def pprint(self):
|
||||
logger.info('>>> Clone the model_git < %s >, commit and push it.'
|
||||
% self.get_model_url())
|
||||
logger.info('>>> Open the url < %s >, check and read it.' % self.url)
|
||||
logger.info('>>> Visit the model_id < %s >, download and run it.'
|
||||
% self.model_id)
|
||||
|
||||
def execute(self):
|
||||
if self.args.action == 'create':
|
||||
self.create_model()
|
||||
elif self.args.action == 'upload':
|
||||
self.push_model()
|
||||
elif self.args.action == 'download':
|
||||
snapshot_download(
|
||||
self.model_id,
|
||||
cache_dir=self.args.model_dir,
|
||||
revision=self.args.version_tag)
|
||||
else:
|
||||
raise ValueError(
|
||||
'The action parameter must be one of [create, upload, download]')
|
||||
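The `create` and `upload` actions above are thin wrappers around `HubApi`. A hedged sketch of driving the same calls from Python, using only the methods and default values that appear in this file; the token and `damo/my_first_model` id are placeholders:

```python
from modelscope.hub.api import HubApi

api = HubApi()
api.login('<YOUR_ACCESS_TOKEN>')        # same value the -tk/--access_token flag takes

model_id = 'damo/my_first_model'        # placeholder, mirroring the group_id/model_id join above

# Equivalent of `--action create` with the CLI defaults shown above.
api.create_model(
    model_id=model_id,
    visibility=5,                       # CLI default for --visibility
    license='Apache License 2.0',       # CLI default for --license
    chinese_name='这是我的第一个模型',     # CLI default for --chinese_name
)

# Equivalent of `--action upload`: push a local directory holding configuration.json (and README.md).
api.push_model(model_id=model_id, model_dir='./my_model_dir')
```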
127  modelscope/cli/pipeline.py (new file)
@@ -0,0 +1,127 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from string import Template
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
curren_path = os.path.dirname(os.path.abspath(__file__))
|
||||
template_path = os.path.join(curren_path, 'template')
|
||||
|
||||
|
||||
def subparser_func(args):
|
||||
""" Fuction which will be called for a specific sub parser.
|
||||
"""
|
||||
return PipelineCMD(args)
|
||||
|
||||
|
||||
class PipelineCMD(CLICommand):
|
||||
name = 'pipeline'
|
||||
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
@staticmethod
|
||||
def define_args(parsers: ArgumentParser):
|
||||
""" define args for create pipeline template command.
|
||||
"""
|
||||
parser = parsers.add_parser(PipelineCMD.name)
|
||||
parser.add_argument(
|
||||
'-act',
|
||||
'--action',
|
||||
type=str,
|
||||
required=True,
|
||||
choices=['create'],
|
||||
help='the action of command pipeline[create]')
|
||||
parser.add_argument(
|
||||
'-tpl',
|
||||
'--tpl_file_path',
|
||||
type=str,
|
||||
default='template.tpl',
|
||||
help='the template be selected for ModelScope[template.tpl]')
|
||||
parser.add_argument(
|
||||
'-s',
|
||||
'--save_file_path',
|
||||
type=str,
|
||||
default='./',
|
||||
help='the name of custom template be saved for ModelScope')
|
||||
parser.add_argument(
|
||||
'-f',
|
||||
'--filename',
|
||||
type=str,
|
||||
default='ms_wrapper.py',
|
||||
help='the init name of custom template be saved for ModelScope')
|
||||
parser.add_argument(
|
||||
'-t',
|
||||
'--task_name',
|
||||
type=str,
|
||||
required=True,
|
||||
help='the unique task_name for ModelScope')
|
||||
parser.add_argument(
|
||||
'-m',
|
||||
'--model_name',
|
||||
type=str,
|
||||
default='MyCustomModel',
|
||||
help='the class of model name for ModelScope')
|
||||
parser.add_argument(
|
||||
'-p',
|
||||
'--preprocessor_name',
|
||||
type=str,
|
||||
default='MyCustomPreprocessor',
|
||||
help='the class of preprocessor name for ModelScope')
|
||||
parser.add_argument(
|
||||
'-pp',
|
||||
'--pipeline_name',
|
||||
type=str,
|
||||
default='MyCustomPipeline',
|
||||
help='the class of pipeline name for ModelScope')
|
||||
parser.add_argument(
|
||||
'-config',
|
||||
'--configuration_path',
|
||||
type=str,
|
||||
default='./',
|
||||
help='the path of configuration.json for ModelScope')
|
||||
parser.set_defaults(func=subparser_func)
|
||||
|
||||
def create_template(self):
|
||||
if self.args.tpl_file_path not in os.listdir(template_path):
|
||||
tpl_file_path = self.args.tpl_file_path
|
||||
else:
|
||||
tpl_file_path = os.path.join(template_path,
|
||||
self.args.tpl_file_path)
|
||||
if not os.path.exists(tpl_file_path):
|
||||
raise ValueError('%s not exists!' % tpl_file_path)
|
||||
|
||||
save_file_path = self.args.save_file_path if self.args.save_file_path != './' else os.getcwd(
|
||||
)
|
||||
os.makedirs(save_file_path, exist_ok=True)
|
||||
if not self.args.filename.endswith('.py'):
|
||||
raise ValueError('the FILENAME must end with .py ')
|
||||
save_file_name = self.args.filename
|
||||
save_pkl_path = os.path.join(save_file_path, save_file_name)
|
||||
|
||||
if not self.args.configuration_path.endswith('/'):
|
||||
self.args.configuration_path = self.args.configuration_path + '/'
|
||||
|
||||
lines = []
|
||||
with open(tpl_file_path) as tpl_file:
|
||||
tpl = Template(tpl_file.read())
|
||||
lines.append(tpl.substitute(**vars(self.args)))
|
||||
|
||||
with open(save_pkl_path, 'w') as save_file:
|
||||
save_file.writelines(lines)
|
||||
|
||||
logger.info('>>> Configuration be saved in %s/%s' %
|
||||
(self.args.configuration_path, 'configuration.json'))
|
||||
logger.info('>>> Task_name: %s, Created in %s' %
|
||||
(self.args.task_name, save_pkl_path))
|
||||
logger.info('Open the file < %s >, update and run it.' % save_pkl_path)
|
||||
|
||||
def execute(self):
|
||||
if self.args.action == 'create':
|
||||
self.create_template()
|
||||
else:
|
||||
raise ValueError('The parameter of action must be in [create]')
|
||||
118  modelscope/cli/plugins.py (new file)
@@ -0,0 +1,118 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.utils.plugins import PluginsManager
|
||||
|
||||
plugins_manager = PluginsManager()
|
||||
|
||||
|
||||
def subparser_func(args):
|
||||
""" Fuction which will be called for a specific sub parser.
|
||||
"""
|
||||
return PluginsCMD(args)
|
||||
|
||||
|
||||
class PluginsCMD(CLICommand):
|
||||
name = 'plugin'
|
||||
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
@staticmethod
|
||||
def define_args(parsers: ArgumentParser):
|
||||
""" define args for install command.
|
||||
"""
|
||||
parser = parsers.add_parser(PluginsCMD.name)
|
||||
subparsers = parser.add_subparsers(dest='command')
|
||||
|
||||
PluginsInstallCMD.define_args(subparsers)
|
||||
PluginsUninstallCMD.define_args(subparsers)
|
||||
PluginsListCMD.define_args(subparsers)
|
||||
|
||||
parser.set_defaults(func=subparser_func)
|
||||
|
||||
def execute(self):
|
||||
print(self.args)
|
||||
if self.args.command == PluginsInstallCMD.name:
|
||||
PluginsInstallCMD.execute(self.args)
|
||||
if self.args.command == PluginsUninstallCMD.name:
|
||||
PluginsUninstallCMD.execute(self.args)
|
||||
if self.args.command == PluginsListCMD.name:
|
||||
PluginsListCMD.execute(self.args)
|
||||
|
||||
|
||||
class PluginsInstallCMD(PluginsCMD):
|
||||
name = 'install'
|
||||
|
||||
@staticmethod
|
||||
def define_args(parsers: ArgumentParser):
|
||||
install = parsers.add_parser(PluginsInstallCMD.name)
|
||||
install.add_argument(
|
||||
'package',
|
||||
type=str,
|
||||
nargs='+',
|
||||
default=None,
|
||||
help='Name of the package to be installed.')
|
||||
install.add_argument(
|
||||
'--index_url',
|
||||
'-i',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Base URL of the Python Package Index.')
|
||||
install.add_argument(
|
||||
'--force_update',
|
||||
'-f',
|
||||
type=str,
|
||||
default=False,
|
||||
help='If force update the package')
|
||||
|
||||
@staticmethod
|
||||
def execute(args):
|
||||
plugins_manager.install_plugins(
|
||||
list(args.package),
|
||||
index_url=args.index_url,
|
||||
force_update=args.force_update)
|
||||
|
||||
|
||||
class PluginsUninstallCMD(PluginsCMD):
|
||||
name = 'uninstall'
|
||||
|
||||
@staticmethod
|
||||
def define_args(parsers: ArgumentParser):
|
||||
install = parsers.add_parser(PluginsUninstallCMD.name)
|
||||
install.add_argument(
|
||||
'package',
|
||||
type=str,
|
||||
nargs='+',
|
||||
default=None,
|
||||
help='Name of the package to be uninstalled.')
|
||||
install.add_argument(
|
||||
'--yes',
|
||||
'-y',
|
||||
type=str,
|
||||
default=False,
|
||||
help='Confirm the uninstall without prompting.')
|
||||
|
||||
@staticmethod
|
||||
def execute(args):
|
||||
plugins_manager.uninstall_plugins(list(args.package), is_yes=args.yes)
|
||||
|
||||
|
||||
class PluginsListCMD(PluginsCMD):
|
||||
name = 'list'
|
||||
|
||||
@staticmethod
|
||||
def define_args(parsers: ArgumentParser):
|
||||
install = parsers.add_parser(PluginsListCMD.name)
|
||||
install.add_argument(
|
||||
'--all',
|
||||
'-a',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Show all of the plugins including those not installed.')
|
||||
|
||||
@staticmethod
|
||||
def execute(args):
|
||||
plugins_manager.list_plugins(show_all=args.all)
|
||||
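These subcommands delegate to `PluginsManager`, so the same operations can be driven from Python with the calls that appear in this file. A short sketch; the package name is a placeholder:

```python
from modelscope.utils.plugins import PluginsManager

manager = PluginsManager()

# Equivalent of `modelscope plugin install <package>`; 'some-plugin-package' is hypothetical.
manager.install_plugins(['some-plugin-package'], index_url=None, force_update=False)
manager.list_plugins(show_all=True)          # what `modelscope plugin list --all` ends up calling
manager.uninstall_plugins(['some-plugin-package'], is_yes=True)
```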
10  modelscope/cli/template/readme.tpl (new file)
@@ -0,0 +1,10 @@
---
license: Apache License 2.0
---
###### This model currently uses the default introduction template and is in the "pre-release" stage; the page is visible only to its owner.
###### Please complete the model card according to the [model contribution guide](https://www.modelscope.cn/docs/%E5%A6%82%E4%BD%95%E6%92%B0%E5%86%99%E5%A5%BD%E7%94%A8%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%8D%A1%E7%89%87) in time. The ModelScope platform will display the model once the card is complete. Thank you for your understanding.

#### Clone with HTTP
```bash
git clone https://www.modelscope.cn/${model_id}.git
```
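Both `modelcard.py` and `pipeline.py` fill these `.tpl` files with `string.Template`, so `${model_id}` (and, in `template.tpl`, `${task_name}`, `${model_name}` and friends) are ordinary template placeholders. A tiny sketch of the substitution step they perform; the model id is a placeholder:

```python
from string import Template

tpl = Template(open('modelscope/cli/template/readme.tpl').read())
readme = tpl.substitute(model_id='damo/my_first_model')   # hypothetical model id
print(readme)   # the '${model_id}' in the clone URL is now filled in
```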
139  modelscope/cli/template/template.tpl (new file)
@@ -0,0 +1,139 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.preprocessors.base import Preprocessor
|
||||
from modelscope.pipelines.base import Model, Pipeline
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors.builder import PREPROCESSORS
|
||||
from modelscope.models.builder import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module('${task_name}', module_name='my-custom-model')
|
||||
class ${model_name}(TorchModel):
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
self.model = self.init_model(**kwargs)
|
||||
|
||||
def forward(self, input_tensor, **forward_params):
|
||||
return self.model(input_tensor, **forward_params)
|
||||
|
||||
def init_model(self, **kwargs):
|
||||
"""Provide default implementation based on TorchModel and user can reimplement it.
|
||||
include init model and load ckpt from the model_dir, maybe include preprocessor
|
||||
if nothing to do, then return lambdx x: x
|
||||
"""
|
||||
return lambda x: x
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module('${task_name}', module_name='my-custom-preprocessor')
|
||||
class ${preprocessor_name}(Preprocessor):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.transforms = self.init_preprocessor(**kwargs)
|
||||
|
||||
def __call__(self, results):
|
||||
return self.transforms(results)
|
||||
|
||||
def init_preprocessor(self, **kwarg):
|
||||
""" Provide default implementation based on preprocess_cfg and user can reimplement it.
|
||||
if nothing to do, then return lambdx x: x
|
||||
"""
|
||||
return lambda x: x
|
||||
|
||||
|
||||
@PIPELINES.register_module('${task_name}', module_name='my-custom-pipeline')
|
||||
class ${pipeline_name}(Pipeline):
|
||||
""" Give simple introduction to this pipeline.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> input = "Hello, ModelScope!"
|
||||
>>> my_pipeline = pipeline('my-task', 'my-model-id')
|
||||
>>> result = my_pipeline(input)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, model, preprocessor=None, **kwargs):
|
||||
"""
|
||||
use `model` and `preprocessor` to create a custom pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
preprocessor: the class of method be init_preprocessor
|
||||
"""
|
||||
super().__init__(model=model, auto_collate=False)
|
||||
assert isinstance(model, str) or isinstance(model, Model), \
|
||||
'model must be a single str or Model'
|
||||
if isinstance(model, str):
|
||||
pipe_model = Model.from_pretrained(model)
|
||||
elif isinstance(model, Model):
|
||||
pipe_model = model
|
||||
else:
|
||||
raise NotImplementedError
|
||||
pipe_model.eval()
|
||||
|
||||
if preprocessor is None:
|
||||
preprocessor = ${preprocessor_name}()
|
||||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
def _sanitize_parameters(self, **pipeline_parameters):
|
||||
"""
|
||||
this method should sanitize the keyword args to preprocessor params,
|
||||
forward params and postprocess params on '__call__' or '_process_single' method
|
||||
considered to be a normal classmethod with default implementation / output
|
||||
|
||||
Default Returns:
|
||||
Dict[str, str]: preprocess_params = {}
|
||||
Dict[str, str]: forward_params = {}
|
||||
Dict[str, str]: postprocess_params = pipeline_parameters
|
||||
"""
|
||||
return {}, pipeline_parameters, {}
|
||||
|
||||
def _check_input(self, inputs):
|
||||
pass
|
||||
|
||||
def _check_output(self, outputs):
|
||||
pass
|
||||
|
||||
def forward(self, inputs, **forward_params):
|
||||
""" Provide default implementation using self.model and user can reimplement it
|
||||
"""
|
||||
return super().forward(inputs, **forward_params)
|
||||
|
||||
def postprocess(self, inputs):
|
||||
""" If current pipeline support model reuse, common postprocess
|
||||
code should be write here.
|
||||
|
||||
Args:
|
||||
inputs: input data
|
||||
|
||||
Return:
|
||||
dict of results: a dict containing outputs of model, each
|
||||
output should have the standard output name.
|
||||
"""
|
||||
return inputs
|
||||
|
||||
|
||||
# Tip: usr_config_path is the local directory where the configuration is saved; after uploading to the ModelScope hub, use the model_id instead
|
||||
usr_config_path = '${configuration_path}'
|
||||
config = Config({
|
||||
'framework': 'pytorch',
|
||||
'task': '${task_name}',
|
||||
'model': {'type': 'my-custom-model'},
|
||||
"pipeline": {"type": "my-custom-pipeline"}
|
||||
})
|
||||
config.dump('${configuration_path}' + 'configuration.json')
|
||||
|
||||
if __name__ == "__main__":
|
||||
from modelscope.models import Model
|
||||
from modelscope.pipelines import pipeline
|
||||
# model = Model.from_pretrained(usr_config_path)
|
||||
input = "Hello, ModelScope!"
|
||||
inference = pipeline('${task_name}', model=usr_config_path)
|
||||
output = inference(input)
|
||||
print(output)
|
||||
@@ -1,14 +1,38 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.utils.import_utils import is_tf_available, is_torch_available
|
||||
from .base import Exporter
|
||||
from .builder import build_exporter
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if is_tf_available():
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .base import Exporter
|
||||
from .builder import build_exporter
|
||||
from .cv import CartoonTranslationExporter
|
||||
from .nlp import CsanmtForTranslationExporter
|
||||
from .tf_model_exporter import TfModelExporter
|
||||
if is_torch_available():
|
||||
from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter
|
||||
from .torch_model_exporter import TorchModelExporter
|
||||
from .cv import FaceDetectionSCRFDExporter
|
||||
else:
|
||||
_import_structure = {
|
||||
'base': ['Exporter'],
|
||||
'builder': ['build_exporter'],
|
||||
'cv': ['CartoonTranslationExporter', 'FaceDetectionSCRFDExporter'],
|
||||
'nlp': [
|
||||
'CsanmtForTranslationExporter',
|
||||
'SbertForSequenceClassificationExporter',
|
||||
'SbertForZeroShotClassificationExporter'
|
||||
],
|
||||
'tf_model_exporter': ['TfModelExporter'],
|
||||
'torch_model_exporter': ['TorchModelExporter'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -1,7 +1,27 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from modelscope.utils.import_utils import is_tf_available, is_torch_available
|
||||
|
||||
if is_tf_available():
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .cartoon_translation_exporter import CartoonTranslationExporter
|
||||
if is_torch_available():
|
||||
from .object_detection_damoyolo_exporter import ObjectDetectionDamoyoloExporter
|
||||
from .face_detection_scrfd_exporter import FaceDetectionSCRFDExporter
|
||||
else:
|
||||
_import_structure = {
|
||||
'cartoon_translation_exporter': ['CartoonTranslationExporter'],
|
||||
'object_detection_damoyolo_exporter':
|
||||
['ObjectDetectionDamoyoloExporter'],
|
||||
'face_detection_scrfd_exporter': ['FaceDetectionSCRFDExporter'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from functools import partial
from typing import Mapping

import numpy as np
import onnx
import torch

from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.utils.constant import ModelFile, Tasks


@EXPORTERS.register_module(
    Tasks.image_object_detection, module_name=Models.tinynas_damoyolo)
class ObjectDetectionDamoyoloExporter(TorchModelExporter):

    def export_onnx(self,
                    output_dir: str,
                    opset=11,
                    input_shape=(1, 3, 640, 640)):
        onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
        dummy_input = torch.randn(*input_shape)
        self.model.head.nms = False
        self.model.onnx_export = True
        self.model.eval()
        _ = self.model(dummy_input)
        torch.onnx._export(
            self.model,
            dummy_input,
            onnx_file,
            input_names=[
                'images',
            ],
            output_names=[
                'pred',
            ],
            opset_version=opset)

        return {'model': onnx_file}
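The exporter names the graph input `images` and the output `pred`, so the produced file can be sanity-checked with onnxruntime. A small sketch, assuming `export_onnx` has already been run and that `./onnx_out/model.onnx` is the path it returned (the exact file name comes from `ModelFile.ONNX_MODEL_FILE` and is an assumption here):

```python
import numpy as np
import onnxruntime as ort

onnx_path = './onnx_out/model.onnx'          # path returned as export_onnx(...)['model']
session = ort.InferenceSession(onnx_path)

dummy = np.random.randn(1, 3, 640, 640).astype(np.float32)   # matches the default input_shape
(pred,) = session.run(['pred'], {'images': dummy})
print(pred.shape)
```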
@@ -1,11 +1,31 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.utils.import_utils import is_tf_available, is_torch_available
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if is_tf_available():
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .csanmt_for_translation_exporter import CsanmtForTranslationExporter
|
||||
if is_torch_available():
|
||||
from .model_for_token_classification_exporter import ModelForSequenceClassificationExporter
|
||||
from .sbert_for_sequence_classification_exporter import \
|
||||
SbertForSequenceClassificationExporter
|
||||
from .sbert_for_zero_shot_classification_exporter import \
|
||||
SbertForZeroShotClassificationExporter
|
||||
else:
|
||||
_import_structure = {
|
||||
'csanmt_for_translation_exporter': ['CsanmtForTranslationExporter'],
|
||||
'model_for_token_classification_exporter':
|
||||
['ModelForSequenceClassificationExporter'],
|
||||
'sbert_for_zero_shot_classification_exporter':
|
||||
['SbertForZeroShotClassificationExporter'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Mapping
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from modelscope.exporters.builder import EXPORTERS
|
||||
from modelscope.exporters.torch_model_exporter import TorchModelExporter
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.outputs import ModelOutputBase
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.regress_test_utils import (compare_arguments_nested,
|
||||
numpify_tensor_nested)
|
||||
|
||||
|
||||
@EXPORTERS.register_module(Tasks.transformer_crf, module_name=Models.tcrf)
|
||||
@EXPORTERS.register_module(Tasks.token_classification, module_name=Models.tcrf)
|
||||
@EXPORTERS.register_module(
|
||||
Tasks.named_entity_recognition, module_name=Models.tcrf)
|
||||
@EXPORTERS.register_module(Tasks.part_of_speech, module_name=Models.tcrf)
|
||||
@EXPORTERS.register_module(Tasks.word_segmentation, module_name=Models.tcrf)
|
||||
class ModelForSequenceClassificationExporter(TorchModelExporter):
|
||||
|
||||
def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
|
||||
"""Generate dummy inputs for model exportation to onnx or other formats by tracing.
|
||||
|
||||
Args:
|
||||
shape: A tuple of input shape which should have at most two dimensions.
|
||||
shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor.
|
||||
shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor.
|
||||
pair(bool, `optional`): Whether to generate sentence pairs or single sentences.
|
||||
|
||||
Returns:
|
||||
Dummy inputs.
|
||||
"""
|
||||
|
||||
assert hasattr(
|
||||
self.model, 'model_dir'
|
||||
), 'model_dir attribute is required to build the preprocessor'
|
||||
|
||||
preprocessor = Preprocessor.from_pretrained(
|
||||
self.model.model_dir, return_text=False)
|
||||
return preprocessor('2023')
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
dynamic_axis = {0: 'batch', 1: 'sequence'}
|
||||
return OrderedDict([
|
||||
('input_ids', dynamic_axis),
|
||||
('attention_mask', dynamic_axis),
|
||||
('offset_mapping', dynamic_axis),
|
||||
('label_mask', dynamic_axis),
|
||||
])
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
dynamic_axis = {0: 'batch', 1: 'sequence'}
|
||||
return OrderedDict([
|
||||
('predictions', dynamic_axis),
|
||||
])
|
||||
|
||||
def _validate_onnx_model(self,
|
||||
dummy_inputs,
|
||||
model,
|
||||
output,
|
||||
onnx_outputs,
|
||||
rtol: float = None,
|
||||
atol: float = None):
|
||||
try:
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
'Cannot validate the exported onnx file, because '
|
||||
'the installation of onnx or onnxruntime cannot be found')
|
||||
return
|
||||
onnx_model = onnx.load(output)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
ort_session = ort.InferenceSession(output)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
outputs_origin = model.forward(
|
||||
*self._decide_input_format(model, dummy_inputs))
|
||||
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
|
||||
outputs_origin = list(
|
||||
numpify_tensor_nested(outputs_origin).values())
|
||||
elif isinstance(outputs_origin, (tuple, list)):
|
||||
outputs_origin = list(numpify_tensor_nested(outputs_origin))
|
||||
|
||||
outputs_origin = [outputs_origin[0]]  # keep `predictions`, drop other outputs
|
||||
|
||||
np_dummy_inputs = numpify_tensor_nested(dummy_inputs)
|
||||
np_dummy_inputs['label_mask'] = np_dummy_inputs['label_mask'].astype(
|
||||
bool)
|
||||
outputs = ort_session.run(onnx_outputs, np_dummy_inputs)
|
||||
outputs = numpify_tensor_nested(outputs)
|
||||
if isinstance(outputs, dict):
|
||||
outputs = list(outputs.values())
|
||||
elif isinstance(outputs, tuple):
|
||||
outputs = list(outputs)
|
||||
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
if not compare_arguments_nested('Onnx model output match failed',
|
||||
outputs, outputs_origin, **tols):
|
||||
raise RuntimeError(
|
||||
'export onnx failed because of validation error.')
|
||||
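The validation helper runs the original PyTorch model and the exported ONNX graph on the same dummy inputs and compares the numpified outputs within `rtol`/`atol`. `compare_arguments_nested` itself is not shown in this diff; conceptually, the per-tensor check amounts to something like the sketch below (the tolerance defaults here are illustrative, not the library's):

```python
import numpy as np

def outputs_match(onnx_out, torch_out, rtol=1e-4, atol=1e-5):
    """Rough stand-in for the nested comparison: element-wise closeness within tolerances."""
    return all(
        np.allclose(a, b, rtol=rtol, atol=atol)
        for a, b in zip(onnx_out, torch_out)
    )

print(outputs_match([np.array([1.0, 2.0])], [np.array([1.0, 2.000001])]))  # True
```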
@@ -213,45 +213,58 @@ class TorchModelExporter(Exporter):
|
||||
)
|
||||
|
||||
if validation:
|
||||
try:
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
'Cannot validate the exported onnx file, because '
|
||||
'the installation of onnx or onnxruntime cannot be found')
|
||||
return
|
||||
onnx_model = onnx.load(output)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
ort_session = ort.InferenceSession(output)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
outputs_origin = model.forward(
|
||||
*self._decide_input_format(model, dummy_inputs))
|
||||
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
|
||||
outputs_origin = list(
|
||||
numpify_tensor_nested(outputs_origin).values())
|
||||
elif isinstance(outputs_origin, (tuple, list)):
|
||||
outputs_origin = list(numpify_tensor_nested(outputs_origin))
|
||||
outputs = ort_session.run(
|
||||
onnx_outputs,
|
||||
numpify_tensor_nested(dummy_inputs),
|
||||
)
|
||||
outputs = numpify_tensor_nested(outputs)
|
||||
if isinstance(outputs, dict):
|
||||
outputs = list(outputs.values())
|
||||
elif isinstance(outputs, tuple):
|
||||
outputs = list(outputs)
|
||||
self._validate_onnx_model(dummy_inputs, model, output,
|
||||
onnx_outputs, rtol, atol)
|
||||
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
if not compare_arguments_nested('Onnx model output match failed',
|
||||
outputs, outputs_origin, **tols):
|
||||
raise RuntimeError(
|
||||
'export onnx failed because of validation error.')
|
||||
def _validate_onnx_model(self,
|
||||
dummy_inputs,
|
||||
model,
|
||||
output,
|
||||
onnx_outputs,
|
||||
rtol: float = None,
|
||||
atol: float = None):
|
||||
try:
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
'Cannot validate the exported onnx file, because '
|
||||
'the installation of onnx or onnxruntime cannot be found')
|
||||
return
|
||||
onnx_model = onnx.load(output)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
ort_session = ort.InferenceSession(output)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
outputs_origin = model.forward(
|
||||
*self._decide_input_format(model, dummy_inputs))
|
||||
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
|
||||
outputs_origin = list(
|
||||
numpify_tensor_nested(outputs_origin).values())
|
||||
elif isinstance(outputs_origin, (tuple, list)):
|
||||
outputs_origin = list(numpify_tensor_nested(outputs_origin))
|
||||
|
||||
outputs = ort_session.run(
|
||||
onnx_outputs,
|
||||
numpify_tensor_nested(dummy_inputs),
|
||||
)
|
||||
outputs = numpify_tensor_nested(outputs)
|
||||
if isinstance(outputs, dict):
|
||||
outputs = list(outputs.values())
|
||||
elif isinstance(outputs, tuple):
|
||||
outputs = list(outputs)
|
||||
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
print(outputs)
|
||||
print(outputs_origin)
|
||||
if not compare_arguments_nested('Onnx model output match failed',
|
||||
outputs, outputs_origin, **tols):
|
||||
raise RuntimeError(
|
||||
'export onnx failed because of validation error.')
|
||||
|
||||
def _torch_export_torch_script(self,
|
||||
model: nn.Module,
|
||||
@@ -307,28 +320,33 @@ class TorchModelExporter(Exporter):
|
||||
torch.jit.save(traced_model, output)
|
||||
|
||||
if validation:
|
||||
ts_model = torch.jit.load(output)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
ts_model.eval()
|
||||
outputs = ts_model.forward(*dummy_inputs)
|
||||
outputs = numpify_tensor_nested(outputs)
|
||||
outputs_origin = model.forward(*dummy_inputs)
|
||||
outputs_origin = numpify_tensor_nested(outputs_origin)
|
||||
if isinstance(outputs, dict):
|
||||
outputs = list(outputs.values())
|
||||
if isinstance(outputs_origin, dict):
|
||||
outputs_origin = list(outputs_origin.values())
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
if not compare_arguments_nested(
|
||||
'Torch script model output match failed', outputs,
|
||||
outputs_origin, **tols):
|
||||
raise RuntimeError(
|
||||
'export torch script failed because of validation error.')
|
||||
self._validate_torch_script_model(dummy_inputs, model, output,
|
||||
rtol, atol)
|
||||
|
||||
def _validate_torch_script_model(self, dummy_inputs, model, output, rtol,
|
||||
atol):
|
||||
ts_model = torch.jit.load(output)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
ts_model.eval()
|
||||
outputs = ts_model.forward(*dummy_inputs)
|
||||
outputs = numpify_tensor_nested(outputs)
|
||||
outputs_origin = model.forward(*dummy_inputs)
|
||||
outputs_origin = numpify_tensor_nested(outputs_origin)
|
||||
if isinstance(outputs, dict):
|
||||
outputs = list(outputs.values())
|
||||
if isinstance(outputs_origin, dict):
|
||||
outputs_origin = list(outputs_origin.values())
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
if not compare_arguments_nested(
|
||||
'Torch script model output match failed', outputs,
|
||||
outputs_origin, **tols):
|
||||
raise RuntimeError(
|
||||
'export torch script failed because of validation error.')
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -121,7 +121,7 @@ def model_file_download(
|
||||
|
||||
if model_file['Path'] == file_path:
|
||||
if cache.exists(model_file):
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'File {model_file["Name"]} already in cache, skip downloading!'
|
||||
)
|
||||
return cache.get_file_by_info(model_file)
|
||||
@@ -209,7 +209,7 @@ def http_get_file(
|
||||
tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
|
||||
get_headers = {} if headers is None else copy.deepcopy(headers)
|
||||
with temp_file_manager() as temp_file:
|
||||
-logger.info('downloading %s to %s', url, temp_file.name)
+logger.debug('downloading %s to %s', url, temp_file.name)
|
||||
# retry sleep 0.5s, 1s, 2s, 4s
|
||||
retry = Retry(
|
||||
total=API_FILE_DOWNLOAD_RETRY_TIMES,
|
||||
@@ -248,7 +248,7 @@ def http_get_file(
|
||||
retry = retry.increment('GET', url, error=e)
|
||||
retry.sleep()
|
||||
|
||||
-logger.info('storing %s in cache at %s', url, local_dir)
+logger.debug('storing %s in cache at %s', url, local_dir)
|
||||
downloaded_length = os.path.getsize(temp_file.name)
|
||||
if total != downloaded_length:
|
||||
os.remove(temp_file.name)
|
||||
|
||||
@@ -122,7 +122,7 @@ def snapshot_download(model_id: str,
|
||||
# check model_file is exist in cache, if existed, skip download, otherwise download
|
||||
if cache.exists(model_file):
|
||||
file_name = os.path.basename(model_file['Name'])
|
||||
-logger.info(
+logger.debug(
|
||||
f'File {file_name} already in cache, skip downloading!'
|
||||
)
|
||||
continue
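The download helpers above share a pattern: skip files already present in the cache and discard partial downloads whose size disagrees with the server. A standalone sketch of the size check, using plain requests instead of the repo's retry helper (names are illustrative):

import os
import requests

def fetch_file(url, dest_path):
    # Stream the download, then verify the byte count against Content-Length.
    with requests.get(url, stream=True) as resp:
        resp.raise_for_status()
        expected = int(resp.headers.get('Content-Length', 0))
        with open(dest_path, 'wb') as f:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    if expected and os.path.getsize(dest_path) != expected:
        os.remove(dest_path)  # drop the partial file, mirroring the check above
        raise IOError(f'incomplete download of {url}')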
|
||||
|
||||
@@ -46,6 +46,7 @@ class Models(object):
|
||||
image_paintbyexample = 'Stablediffusion-Paintbyexample'
|
||||
video_summarization = 'pgl-video-summarization'
|
||||
video_panoptic_segmentation = 'swinb-video-panoptic-segmentation'
|
||||
video_instance_segmentation = 'swinb-video-instance-segmentation'
|
||||
language_guided_video_summarization = 'clip-it-language-guided-video-summarization'
|
||||
swinL_semantic_segmentation = 'swinL-semantic-segmentation'
|
||||
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
|
||||
@@ -78,16 +79,19 @@ class Models(object):
|
||||
image_body_reshaping = 'image-body-reshaping'
|
||||
image_skychange = 'image-skychange'
|
||||
video_human_matting = 'video-human-matting'
|
||||
human_reconstruction = 'human-reconstruction'
|
||||
video_frame_interpolation = 'video-frame-interpolation'
|
||||
video_object_segmentation = 'video-object-segmentation'
|
||||
video_deinterlace = 'video-deinterlace'
|
||||
quadtree_attention_image_matching = 'quadtree-attention-image-matching'
|
||||
vision_middleware = 'vision-middleware'
|
||||
vidt = 'vidt'
|
||||
video_stabilization = 'video-stabilization'
|
||||
real_basicvsr = 'real-basicvsr'
|
||||
rcp_sceneflow_estimation = 'rcp-sceneflow-estimation'
|
||||
image_casmvs_depth_estimation = 'image-casmvs-depth-estimation'
|
||||
vop_retrieval_model = 'vop-retrieval-model'
|
||||
vop_retrieval_model_se = 'vop-retrieval-model-se'
|
||||
ddcolor = 'ddcolor'
|
||||
image_probing_model = 'image-probing-model'
|
||||
defrcn = 'defrcn'
|
||||
@@ -100,7 +104,9 @@ class Models(object):
|
||||
ddpm = 'ddpm'
|
||||
ocr_recognition = 'OCRRecognition'
|
||||
ocr_detection = 'OCRDetection'
|
||||
lineless_table_recognition = 'LoreModel'
|
||||
image_quality_assessment_mos = 'image-quality-assessment-mos'
|
||||
image_quality_assessment_man = 'image-quality-assessment-man'
|
||||
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
|
||||
m2fp = 'm2fp'
|
||||
nerf_recon_acc = 'nerf-recon-acc'
|
||||
@@ -108,6 +114,7 @@ class Models(object):
|
||||
vision_efficient_tuning = 'vision-efficient-tuning'
|
||||
bad_image_detecting = 'bad-image-detecting'
|
||||
controllable_image_generation = 'controllable-image-generation'
|
||||
longshortnet = 'longshortnet'
|
||||
|
||||
# EasyCV models
|
||||
yolox = 'YOLOX'
|
||||
@@ -157,10 +164,12 @@ class Models(object):
|
||||
transformers = 'transformers'
|
||||
plug_mental = 'plug-mental'
|
||||
doc2bot = 'doc2bot'
|
||||
peer = 'peer'
|
||||
|
||||
# audio models
|
||||
sambert_hifigan = 'sambert-hifigan'
|
||||
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
|
||||
speech_dfsmn_ans = 'speech_dfsmn_ans'
|
||||
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
|
||||
speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
|
||||
speech_mossformer_separation_temporal_8k = 'speech_mossformer_separation_temporal_8k'
|
||||
@@ -177,14 +186,17 @@ class Models(object):
|
||||
ofa = 'ofa'
|
||||
clip = 'clip-multi-modal-embedding'
|
||||
gemm = 'gemm-generative-multi-modal'
|
||||
rleg = 'rleg-generative-multi-modal'
|
||||
mplug = 'mplug'
|
||||
diffusion = 'diffusion-text-to-image-synthesis'
|
||||
multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis'
|
||||
video_synthesis = 'latent-text-to-video-synthesis'
|
||||
team = 'team-multi-modal-similarity'
|
||||
video_clip = 'video-clip-multi-modal-embedding'
|
||||
mgeo = 'mgeo'
|
||||
vldoc = 'vldoc'
|
||||
hitea = 'hitea'
|
||||
soonet = 'soonet'
|
||||
|
||||
# science models
|
||||
unifold = 'unifold'
|
||||
@@ -242,6 +254,7 @@ class Pipelines(object):
|
||||
person_image_cartoon = 'unet-person-image-cartoon'
|
||||
ocr_detection = 'resnet18-ocr-detection'
|
||||
table_recognition = 'dla34-table-recognition'
|
||||
lineless_table_recognition = 'lore-lineless-table-recognition'
|
||||
license_plate_detection = 'resnet18-license-plate-detection'
|
||||
action_recognition = 'TAdaConv_action-recognition'
|
||||
animal_recognition = 'resnet101-animal-recognition'
|
||||
@@ -322,6 +335,7 @@ class Pipelines(object):
|
||||
crowd_counting = 'hrnet-crowd-counting'
|
||||
action_detection = 'ResNetC3D-action-detection'
|
||||
video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
|
||||
video_single_object_tracking_procontext = 'procontext-vitb-video-single-object-tracking'
|
||||
video_multi_object_tracking = 'video-multi-object-tracking'
|
||||
image_panoptic_segmentation = 'image-panoptic-segmentation'
|
||||
image_panoptic_segmentation_easycv = 'image-panoptic-segmentation-easycv'
|
||||
@@ -350,7 +364,9 @@ class Pipelines(object):
|
||||
referring_video_object_segmentation = 'referring-video-object-segmentation'
|
||||
image_skychange = 'image-skychange'
|
||||
video_human_matting = 'video-human-matting'
|
||||
human_reconstruction = 'human-reconstruction'
|
||||
vision_middleware_multi_task = 'vision-middleware-multi-task'
|
||||
vidt = 'vidt'
|
||||
video_frame_interpolation = 'video-frame-interpolation'
|
||||
video_object_segmentation = 'video-object-segmentation'
|
||||
video_deinterlace = 'video-deinterlace'
|
||||
@@ -360,7 +376,9 @@ class Pipelines(object):
|
||||
pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'
|
||||
image_multi_view_depth_estimation = 'image-multi-view-depth-estimation'
|
||||
video_panoptic_segmentation = 'video-panoptic-segmentation'
|
||||
video_instance_segmentation = 'video-instance-segmentation'
|
||||
vop_retrieval = 'vop-video-text-retrieval'
|
||||
vop_retrieval_se = 'vop-video-text-retrieval-se'
|
||||
ddcolor_image_colorization = 'ddcolor-image-colorization'
|
||||
image_structured_model_probing = 'image-structured-model-probing'
|
||||
image_fewshot_detection = 'image-fewshot-detection'
|
||||
@@ -377,8 +395,10 @@ class Pipelines(object):
|
||||
controllable_image_generation = 'controllable-image-generation'
|
||||
|
||||
image_quality_assessment_mos = 'image-quality-assessment-mos'
|
||||
image_quality_assessment_man = 'image-quality-assessment-man'
|
||||
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
|
||||
vision_efficient_tuning = 'vision-efficient-tuning'
|
||||
image_bts_depth_estimation = 'image-bts-depth-estimation'
|
||||
|
||||
# nlp tasks
|
||||
automatic_post_editing = 'automatic-post-editing'
|
||||
@@ -441,6 +461,7 @@ class Pipelines(object):
|
||||
sambert_hifigan_tts = 'sambert-hifigan-tts'
|
||||
speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
|
||||
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
|
||||
speech_dfsmn_ans_psm_48k_causal = 'speech_dfsmn_ans_psm_48k_causal'
|
||||
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
|
||||
speech_separation = 'speech-separation'
|
||||
kws_kwsbp = 'kws-kwsbp'
|
||||
@@ -453,6 +474,7 @@ class Pipelines(object):
|
||||
vad_inference = 'vad-inference'
|
||||
speaker_verification = 'speaker-verification'
|
||||
lm_inference = 'language-score-prediction'
|
||||
speech_timestamp_inference = 'speech-timestamp-inference'
|
||||
|
||||
# multi-modal tasks
|
||||
image_captioning = 'image-captioning'
|
||||
@@ -472,10 +494,13 @@ class Pipelines(object):
|
||||
video_captioning = 'video-captioning'
|
||||
video_question_answering = 'video-question-answering'
|
||||
diffusers_stable_diffusion = 'diffusers-stable-diffusion'
|
||||
disco_guided_diffusion = 'disco_guided_diffusion'
|
||||
document_vl_embedding = 'document-vl-embedding'
|
||||
chinese_stable_diffusion = 'chinese-stable-diffusion'
|
||||
text_to_video_synthesis = 'latent-text-to-video-synthesis' # latent-text-to-video-synthesis
|
||||
gridvlp_multi_modal_classification = 'gridvlp-multi-modal-classification'
|
||||
gridvlp_multi_modal_embedding = 'gridvlp-multi-modal-embedding'
|
||||
soonet_video_temporal_grounding = 'soonet-video-temporal-grounding'
|
||||
|
||||
# science tasks
|
||||
protein_structure = 'unifold-protein-structure'
|
||||
@@ -574,6 +599,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Tasks.table_recognition:
|
||||
(Pipelines.table_recognition,
|
||||
'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
|
||||
Tasks.lineless_table_recognition:
|
||||
(Pipelines.lineless_table_recognition,
|
||||
'damo/cv_resnet-transformer_table-structure-recognition_lore'),
|
||||
Tasks.document_vl_embedding:
|
||||
(Pipelines.document_vl_embedding,
|
||||
'damo/multi-modal_convnext-roberta-base_vldoc-embedding'),
|
||||
@@ -611,6 +639,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Tasks.text_to_image_synthesis:
|
||||
(Pipelines.text_to_image_synthesis,
|
||||
'damo/cv_diffusion_text-to-image-synthesis_tiny'),
|
||||
Tasks.text_to_video_synthesis: (Pipelines.text_to_video_synthesis,
|
||||
'damo/text-to-video-synthesis'),
|
||||
Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints,
|
||||
'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
|
||||
Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints,
|
||||
@@ -708,9 +738,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
'damo/cv_vitb_video-single-object-tracking_ostrack'),
|
||||
Tasks.image_reid_person: (Pipelines.image_reid_person,
|
||||
'damo/cv_passvitb_image-reid-person_market'),
|
||||
-Tasks.text_driven_segmentation:
-(Pipelines.text_driven_segmentation,
-'damo/cv_vitl16_segmentation_text-driven-seg'),
+Tasks.text_driven_segmentation: (
+Pipelines.text_driven_segmentation,
+'damo/cv_vitl16_segmentation_text-driven-seg'),
|
||||
Tasks.movie_scene_segmentation: (
|
||||
Pipelines.movie_scene_segmentation,
|
||||
'damo/cv_resnet50-bert_video-scene-segmentation_movienet'),
|
||||
@@ -727,6 +757,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
'damo/cv_video-inpainting'),
|
||||
Tasks.video_human_matting: (Pipelines.video_human_matting,
|
||||
'damo/cv_effnetv2_video-human-matting'),
|
||||
Tasks.human_reconstruction: (Pipelines.human_reconstruction,
|
||||
'damo/cv_hrnet_image-human-reconstruction'),
|
||||
Tasks.video_frame_interpolation: (
|
||||
Pipelines.video_frame_interpolation,
|
||||
'damo/cv_raft_video-frame-interpolation'),
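With the new default entries above (for example lineless table recognition and human reconstruction), the task name alone is enough to build a pipeline; the model id is resolved through DEFAULT_MODEL_FOR_PIPELINE. A hedged usage sketch with a placeholder input path:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

table_parser = pipeline(Tasks.lineless_table_recognition)
result = table_parser('table.jpg')  # 'table.jpg' is an illustrative local image
print(result)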
|
||||
@@ -805,7 +837,11 @@ class CVTrainers(object):
|
||||
image_classification_team = 'image-classification-team'
|
||||
image_classification = 'image-classification'
|
||||
image_fewshot_detection = 'image-fewshot-detection'
|
||||
ocr_recognition = 'ocr-recognition'
|
||||
ocr_detection_db = 'ocr-detection-db'
|
||||
nerf_recon_acc = 'nerf-recon-acc'
|
||||
action_detection = 'action-detection'
|
||||
vision_efficient_tuning = 'vision-efficient-tuning'
|
||||
|
||||
|
||||
class NLPTrainers(object):
|
||||
@@ -826,6 +862,7 @@ class NLPTrainers(object):
|
||||
document_grounded_dialog_generate_trainer = 'document-grounded-dialog-generate-trainer'
|
||||
document_grounded_dialog_rerank_trainer = 'document-grounded-dialog-rerank-trainer'
|
||||
document_grounded_dialog_retrieval_trainer = 'document-grounded-dialog-retrieval-trainer'
|
||||
siamese_uie_trainer = 'siamese-uie-trainer'
|
||||
|
||||
|
||||
class MultiModalTrainers(object):
|
||||
@@ -904,6 +941,7 @@ class Preprocessors(object):
|
||||
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
|
||||
image_driving_perception_preprocessor = 'image-driving-perception-preprocessor'
|
||||
image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor'
|
||||
image_quality_assessment_man_preprocessor = 'image-quality_assessment-man-preprocessor'
|
||||
image_quality_assessment_mos_preprocessor = 'image-quality_assessment-mos-preprocessor'
|
||||
video_summarization_preprocessor = 'video-summarization-preprocessor'
|
||||
movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'
|
||||
@@ -916,6 +954,7 @@ class Preprocessors(object):
|
||||
bad_image_detecting_preprocessor = 'bad-image-detecting-preprocessor'
|
||||
nerf_recon_acc_preprocessor = 'nerf-recon-acc-preprocessor'
|
||||
controllable_image_generation_preprocessor = 'controllable-image-generation-preprocessor'
|
||||
image_classification_preprocessor = 'image-classification-preprocessor'
|
||||
|
||||
# nlp preprocessor
|
||||
sen_sim_tokenizer = 'sen-sim-tokenizer'
|
||||
@@ -1035,6 +1074,9 @@ class Metrics(object):
|
||||
image_quality_assessment_degradation_metric = 'image-quality-assessment-degradation-metric'
|
||||
# metric for text-ranking task
|
||||
text_ranking_metric = 'text-ranking-metric'
|
||||
# metric for image-colorization task
|
||||
image_colorization_metric = 'image-colorization-metric'
|
||||
ocr_recognition_metric = 'ocr-recognition-metric'
|
||||
|
||||
|
||||
class Optimizers(object):
|
||||
@@ -1087,6 +1129,7 @@ class Hooks(object):
|
||||
EarlyStopHook = 'EarlyStopHook'
|
||||
DeepspeedHook = 'DeepspeedHook'
|
||||
MegatronHook = 'MegatronHook'
|
||||
DDPHook = 'DDPHook'
|
||||
|
||||
|
||||
class LR_Schedulers(object):
|
||||
@@ -1098,7 +1141,7 @@ class LR_Schedulers(object):
|
||||
ExponentialWarmup = 'ExponentialWarmup'
|
||||
|
||||
|
||||
class Datasets(object):
|
||||
class CustomDatasets(object):
|
||||
""" Names for different datasets.
|
||||
"""
|
||||
ClsDataset = 'ClsDataset'
|
||||
@@ -1110,3 +1153,6 @@ class Datasets(object):
|
||||
DetImagesMixDataset = 'DetImagesMixDataset'
|
||||
PanopticDataset = 'PanopticDataset'
|
||||
PairedDataset = 'PairedDataset'
|
||||
SiddDataset = 'SiddDataset'
|
||||
GoproDataset = 'GoproDataset'
|
||||
RedsDataset = 'RedsDataset'
|
||||
|
||||
@@ -29,6 +29,8 @@ if TYPE_CHECKING:
|
||||
from .image_quality_assessment_mos_metric import ImageQualityAssessmentMosMetric
|
||||
from .text_ranking_metric import TextRankingMetric
|
||||
from .loss_metric import LossMetric
|
||||
from .image_colorization_metric import ImageColorizationMetric
|
||||
from .ocr_recognition_metric import OCRRecognitionMetric
|
||||
else:
|
||||
_import_structure = {
|
||||
'audio_noise_metric': ['AudioNoiseMetric'],
|
||||
@@ -58,7 +60,9 @@ else:
|
||||
'image_quality_assessment_mos_metric':
|
||||
['ImageQualityAssessmentMosMetric'],
|
||||
'text_ranking_metric': ['TextRankingMetric'],
|
||||
-'loss_metric': ['LossMetric']
+'loss_metric': ['LossMetric'],
+'image_colorization_metric': ['ImageColorizationMetric'],
+'ocr_recognition_metric': ['OCRRecognitionMetric']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
modelscope/metrics/action_detection_evaluator.py (new file, 198 lines)
@@ -0,0 +1,198 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from detectron2.evaluation import DatasetEvaluator
|
||||
from detectron2.evaluation.pascal_voc_evaluation import voc_ap
|
||||
from detectron2.structures.boxes import Boxes, pairwise_iou
|
||||
from detectron2.utils import comm
|
||||
from scipy import interpolate
|
||||
|
||||
|
||||
class DetEvaluator(DatasetEvaluator):
|
||||
|
||||
def __init__(self, class_names, output_dir, distributed=False):
|
||||
self.num_classes = len(class_names)
|
||||
self.class_names = class_names
|
||||
self.output_dir = output_dir
|
||||
self.distributed = distributed
|
||||
self.predictions = []
|
||||
self.gts = []
|
||||
|
||||
def reset(self):
|
||||
self.predictions.clear()
|
||||
self.gts.clear()
|
||||
|
||||
def process(self, input, output):
|
||||
"""
|
||||
|
||||
:param input: list of dataloader samples, each carrying ground-truth 'instances'
:param output: matching list of model outputs, each carrying predicted 'instances'
:return: None; ground truths and predictions are accumulated for evaluate()
|
||||
"""
|
||||
gt_instances = [x['instances'].to('cpu') for x in input]
|
||||
pred_instances = [x['instances'].to('cpu') for x in output]
|
||||
self.gts.extend(gt_instances)
|
||||
self.predictions.extend(pred_instances)
|
||||
|
||||
def get_instance_by_class(self, instances, c):
|
||||
instances = copy.deepcopy(instances)
|
||||
name = 'gt_classes' if instances.has('gt_classes') else 'pred_classes'
|
||||
idxs = np.where(instances.get(name).numpy() == c)[0].tolist()
|
||||
data = {}
|
||||
for k, v in instances.get_fields().items():
|
||||
data[k] = [v[i] for i in idxs]
|
||||
return data
|
||||
|
||||
def evaluate(self):
|
||||
if self.distributed:
|
||||
comm.synchronize()
|
||||
self.predictions = sum(comm.gather(self.predictions, dst=0), [])
|
||||
self.gts = sum(comm.gather(self.gts, dst=0), [])
|
||||
if not comm.is_main_process():
|
||||
return
|
||||
logger = logging.getLogger('detectron2.human.' + __name__)
|
||||
logger.info(', '.join([f'{a}' for a in self.class_names]))
|
||||
maps = []
|
||||
precisions = []
|
||||
recalls = []
|
||||
for iou_th in [0.3, 0.5, 0.7]:
|
||||
aps, prs, ths = self.calc_map(iou_th)
|
||||
map = np.nanmean([x for x in aps if x > 0.01])
|
||||
maps.append(map)
|
||||
logger.info(f'iou_th:{iou_th},' + 'Aps:'
|
||||
+ ','.join([f'{ap:.2f}'
|
||||
for ap in aps]) + f', {map:.3f}')
|
||||
precision, recall = zip(*prs)
|
||||
logger.info('precision:'
|
||||
+ ', '.join([f'{p:.2f}' for p in precision]))
|
||||
logger.info('recall: ' + ', '.join([f'{p:.2f}' for p in recall]))
|
||||
logger.info('score th: ' + ', '.join([f'{p:.2f}' for p in ths]))
|
||||
logger.info(f'mean-precision:{np.nanmean(precision):.3f}')
|
||||
logger.info(f'mean-recall:{np.nanmean(recall):.3f}')
|
||||
precisions.append(np.nanmean(precision))
|
||||
recalls.append(np.nanmean(recall))
|
||||
|
||||
res = OrderedDict({
|
||||
'det': {
|
||||
'mAP': np.nanmean(maps),
|
||||
'precision': np.nanmean(precisions),
|
||||
'recall': np.nanmean(recalls)
|
||||
}
|
||||
})
|
||||
return res
|
||||
|
||||
def calc_map(self, iou_th):
|
||||
aps = []
|
||||
prs = []
|
||||
ths = []
|
||||
# for each class
|
||||
interpolate_precs = []
|
||||
for c in range(self.num_classes):
|
||||
ap, recalls, precisions, scores = self.det_eval(iou_th, c)
|
||||
if iou_th == 0.3:
|
||||
p1 = interpolate_precision(recalls, precisions)
|
||||
interpolate_precs.append(p1)
|
||||
recalls = np.concatenate(([0.0], recalls, [1.0]))
|
||||
precisions = np.concatenate(([0.0], precisions, [0.0]))
|
||||
scores = np.concatenate(([1.0], scores, [0.0]))
|
||||
t = precisions + recalls
|
||||
t[t == 0] = 1e-5
|
||||
f_score = 2 * precisions * recalls / t
|
||||
f_score[np.isnan(f_score)] = 0
|
||||
idx = np.argmax(f_score)
|
||||
# print(iou_th,c,np.argmax(f_score),np.argmax(t))
|
||||
precision_recall = (precisions[idx], recalls[idx])
|
||||
prs.append(precision_recall)
|
||||
aps.append(ap)
|
||||
ths.append(scores[idx])
|
||||
if iou_th == 0.3:
|
||||
interpolate_precs = np.stack(interpolate_precs, axis=1)
|
||||
df = pd.DataFrame(data=interpolate_precs)
|
||||
df.to_csv(
|
||||
osp.join(self.output_dir, 'pr_data.csv'),
|
||||
index=False,
|
||||
columns=None)
|
||||
return aps, prs, ths
|
||||
|
||||
def det_eval(self, iou_th, class_id):
|
||||
c = class_id
|
||||
class_res_gt = {}
|
||||
npos = 0
|
||||
# for each sample
|
||||
for i, (gt, pred) in enumerate(zip(self.gts, self.predictions)):
|
||||
gt_classes = gt.gt_classes.tolist()
|
||||
pred_classes = pred.pred_classes.tolist()
|
||||
if c not in gt_classes + pred_classes:
|
||||
continue
|
||||
pred_data = self.get_instance_by_class(pred, c)
|
||||
gt_data = self.get_instance_by_class(gt, c)
|
||||
res = {}
|
||||
if c in gt_classes:
|
||||
res.update({
|
||||
'gt_bbox': Boxes.cat(gt_data['gt_boxes']),
|
||||
'det': [False] * len(gt_data['gt_classes'])
|
||||
})
|
||||
if c in pred_classes:
|
||||
res.update({'pred_bbox': Boxes.cat(pred_data['pred_boxes'])})
|
||||
res.update(
|
||||
{'pred_score': [s.item() for s in pred_data['scores']]})
|
||||
class_res_gt[i] = res
|
||||
npos += len(gt_data['gt_classes'])
|
||||
|
||||
all_preds = []
|
||||
for img_id, res in class_res_gt.items():
|
||||
if 'pred_bbox' in res:
|
||||
for i in range(len(res['pred_bbox'])):
|
||||
bbox = res['pred_bbox'][i]
|
||||
score = res['pred_score'][i]
|
||||
all_preds.append([img_id, bbox, score])
|
||||
sorted_preds = list(
|
||||
sorted(all_preds, key=lambda x: x[2], reverse=True))
|
||||
scores = [s[-1] for s in sorted_preds]
|
||||
nd = len(sorted_preds)
|
||||
tp = np.zeros(nd)
|
||||
fp = np.zeros(nd)
|
||||
for d in range(nd):
|
||||
img_id, pred_bbox, score = sorted_preds[d]
|
||||
R = class_res_gt[sorted_preds[d][0]]
|
||||
ovmax = -np.inf
|
||||
if 'gt_bbox' in R:
|
||||
gt_bbox = R['gt_bbox']
|
||||
IoUs = pairwise_iou(pred_bbox, gt_bbox).numpy()
|
||||
ovmax = IoUs[0].max()
|
||||
jmax = np.argmax(IoUs[0])  # index of the matched gt box in this image
|
||||
if ovmax > iou_th:
|
||||
if not R['det'][jmax]:  # this gt box has not been matched yet
|
||||
tp[d] = 1.0
|
||||
R['det'][jmax] = True
|
||||
else:  # duplicate detection of an already matched gt
|
||||
fp[d] = 1.0
|
||||
else:
|
||||
fp[d] = 1.0
|
||||
fp = np.cumsum(fp)
|
||||
tp = np.cumsum(tp)
|
||||
rec = tp / float(npos)
|
||||
# avoid divide by zero in case the first detection matches a difficult
|
||||
# ground truth
|
||||
prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
|
||||
ap = voc_ap(rec, prec, False)
|
||||
return ap, rec, prec, scores
|
||||
|
||||
|
||||
def interpolate_precision(rec, prec):
|
||||
rec = np.concatenate(([0.0], rec, [1.0, 1.1]))
|
||||
prec = np.concatenate(([1.0], prec, [0.0]))
|
||||
for i in range(prec.size - 1, 0, -1):
|
||||
prec[i - 1] = np.maximum(prec[i - 1], prec[i])
|
||||
i = np.where(rec[1:] != rec[:-1])[0]  # sample points where recall changes
|
||||
rec, prec = rec[i], prec[i]
|
||||
f = interpolate.interp1d(rec, prec)
|
||||
r1 = np.linspace(0, 1, 101)
|
||||
p1 = f(r1)
|
||||
return p1
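Run next to the function above, a toy input makes the 101-point interpolation concrete (values are illustrative only):

import numpy as np

rec = np.array([0.1, 0.4, 0.4, 0.8])
prec = np.array([1.0, 0.75, 0.6, 0.5])
p101 = interpolate_precision(rec, prec)
print(p101.shape)  # (101,): precision sampled at recall 0.00, 0.01, ..., 1.00
print(p101.max())  # 1.0 near recall 0 after the monotone envelope step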
|
||||
@@ -40,6 +40,7 @@ class MetricKeys(object):
|
||||
RMSE = 'rmse'
|
||||
MRR = 'mrr'
|
||||
NDCG = 'ndcg'
|
||||
AR = 'AR'
|
||||
|
||||
|
||||
task_default_metrics = {
|
||||
@@ -71,6 +72,7 @@ task_default_metrics = {
|
||||
Tasks.image_quality_assessment_mos:
|
||||
[Metrics.image_quality_assessment_mos_metric],
|
||||
Tasks.bad_image_detecting: [Metrics.accuracy],
|
||||
Tasks.ocr_recognition: [Metrics.ocr_recognition_metric],
|
||||
}
|
||||
|
||||
|
||||
|
||||
modelscope/metrics/image_colorization_metric.py (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from scipy import linalg
|
||||
|
||||
from modelscope.metainfo import Metrics
|
||||
from modelscope.models.cv.image_inpainting.modules.inception import InceptionV3
|
||||
from modelscope.utils.registry import default_group
|
||||
from modelscope.utils.tensor_utils import (torch_nested_detach,
|
||||
torch_nested_numpify)
|
||||
from .base import Metric
|
||||
from .builder import METRICS, MetricKeys
|
||||
from .image_denoise_metric import calculate_psnr
|
||||
from .image_inpainting_metric import FIDScore
|
||||
|
||||
|
||||
@METRICS.register_module(
|
||||
group_key=default_group, module_name=Metrics.image_colorization_metric)
|
||||
class ImageColorizationMetric(Metric):
|
||||
"""The metric computation class for image colorization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.preds = []
|
||||
self.targets = []
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
self.FID = FIDScore().to(device)
|
||||
|
||||
def add(self, outputs: Dict, inputs: Dict):
|
||||
ground_truths = outputs['preds']
|
||||
eval_results = outputs['targets']
|
||||
self.preds.append(eval_results)
|
||||
self.targets.append(ground_truths)
|
||||
|
||||
def evaluate(self):
|
||||
psnr_list = []
|
||||
for (pred, target) in zip(self.preds, self.targets):
|
||||
self.FID(pred, target)
|
||||
psnr_list.append(calculate_psnr(target[0], pred[0], crop_border=0))
|
||||
fid = self.FID.get_value()
|
||||
return {MetricKeys.PSNR: np.mean(psnr_list), MetricKeys.FID: fid}
|
||||
|
||||
def merge(self, other: 'ImageColorizationMetric'):
|
||||
self.preds.extend(other.preds)
|
||||
self.targets.extend(other.targets)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.preds, self.targets
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__init__()
|
||||
self.preds, self.targets = state
|
||||
modelscope/metrics/ocr_recognition_metric.py (new file, 79 lines)
@@ -0,0 +1,79 @@
|
||||
from typing import Dict
|
||||
|
||||
import edit_distance as ed
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.metainfo import Metrics
|
||||
from modelscope.utils.registry import default_group
|
||||
from .base import Metric
|
||||
from .builder import METRICS, MetricKeys
|
||||
|
||||
|
||||
def cal_distance(label_list, pre_list):
|
||||
y = ed.SequenceMatcher(a=label_list, b=pre_list)
|
||||
yy = y.get_opcodes()
|
||||
insert = 0
|
||||
delete = 0
|
||||
replace = 0
|
||||
for item in yy:
|
||||
if item[0] == 'insert':
|
||||
insert += item[-1] - item[-2]
|
||||
if item[0] == 'delete':
|
||||
delete += item[2] - item[1]
|
||||
if item[0] == 'replace':
|
||||
replace += item[-1] - item[-2]
|
||||
distance = insert + delete + replace
|
||||
return distance, (delete, replace, insert)
|
||||
|
||||
|
||||
@METRICS.register_module(
|
||||
group_key=default_group, module_name=Metrics.ocr_recognition_metric)
|
||||
class OCRRecognitionMetric(Metric):
|
||||
"""The metric computation class for ocr recognition.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.preds = []
|
||||
self.targets = []
|
||||
self.loss_sum = 0.
|
||||
self.nsample = 0
|
||||
self.iter_sum = 0
|
||||
|
||||
def add(self, outputs: Dict, inputs: Dict):
|
||||
pred = outputs['preds']
|
||||
loss = outputs['loss']
|
||||
target = inputs['labels']
|
||||
self.preds.extend(pred)
|
||||
self.targets.extend(target)
|
||||
self.loss_sum += loss.data.cpu().numpy()
|
||||
self.nsample += len(pred)
|
||||
self.iter_sum += 1
|
||||
|
||||
def evaluate(self):
|
||||
total_chars = 0
|
||||
total_distance = 0
|
||||
total_fullmatch = 0
|
||||
for (pred, target) in zip(self.preds, self.targets):
|
||||
distance, _ = cal_distance(target, pred)
|
||||
total_chars += len(target)
|
||||
total_distance += distance
|
||||
total_fullmatch += (target == pred)
|
||||
accuracy = float(total_fullmatch) / self.nsample
|
||||
AR = 1 - float(total_distance) / total_chars
|
||||
average_loss = self.loss_sum / self.iter_sum if self.iter_sum > 0 else 0
|
||||
return {
|
||||
MetricKeys.ACCURACY: accuracy,
|
||||
MetricKeys.AR: AR,
|
||||
MetricKeys.AVERAGE_LOSS: average_loss
|
||||
}
|
||||
|
||||
def merge(self, other: 'OCRRecognitionMetric'):
|
||||
pass
|
||||
|
||||
def __getstate__(self):
|
||||
pass
|
||||
|
||||
def __setstate__(self, state):
|
||||
pass
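A toy call shows how cal_distance feeds the accuracy rate (AR) reported by the metric:

dist, (deletions, replacements, insertions) = cal_distance(
    list('modelscope'), list('modelscop'))
print(dist)  # 1: a single trailing deletion
print(1 - dist / len('modelscope'))  # 0.9, the per-character AR for this pair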
|
||||
@@ -9,6 +9,7 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
_import_structure = {
|
||||
'frcrn': ['FRCRNDecorator'],
|
||||
'dnoise_net': ['DenoiseNet'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
@@ -7,57 +7,8 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class UniDeepFsmn(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
|
||||
super(UniDeepFsmn, self).__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
|
||||
if lorder is None:
|
||||
return
|
||||
|
||||
self.lorder = lorder
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.linear = nn.Linear(input_dim, hidden_size)
|
||||
|
||||
self.project = nn.Linear(hidden_size, output_dim, bias=False)
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
output_dim,
|
||||
output_dim, [lorder, 1], [1, 1],
|
||||
groups=output_dim,
|
||||
bias=False)
|
||||
|
||||
def forward(self, input):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
input: torch with shape: batch (b) x sequence(T) x feature (h)
|
||||
|
||||
Returns:
|
||||
batch (b) x channel (c) x sequence(T) x feature (h)
|
||||
"""
|
||||
f1 = F.relu(self.linear(input))
|
||||
|
||||
p1 = self.project(f1)
|
||||
|
||||
x = torch.unsqueeze(p1, 1)
|
||||
# x: batch (b) x channel (c) x sequence(T) x feature (h)
|
||||
x_per = x.permute(0, 3, 2, 1)
|
||||
# x_per: batch (b) x feature (h) x sequence(T) x channel (c)
|
||||
y = F.pad(x_per, [0, 0, self.lorder - 1, 0])
|
||||
|
||||
out = x_per + self.conv1(y)
|
||||
|
||||
out1 = out.permute(0, 3, 2, 1)
|
||||
# out1: batch (b) x channel (c) x sequence(T) x feature (h)
|
||||
return input + out1.squeeze()
|
||||
from modelscope.models.audio.ans.layers.uni_deep_fsmn import UniDeepFsmn
|
||||
|
||||
|
||||
class ComplexUniDeepFsmn(nn.Module):
|
||||
|
||||
modelscope/models/audio/ans/denoise_net.py (new file, 73 lines)
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Related papers:
|
||||
# Shengkui Zhao, Trung Hieu Nguyen, Bin Ma, “Monaural Speech Enhancement with Complex Convolutional
|
||||
# Block Attention Module and Joint Time Frequency Losses”, ICASSP 2021.
|
||||
# Shiliang Zhang, Ming Lei, Zhijie Yan, Lirong Dai, “Deep-FSMN for Large Vocabulary Continuous Speech
|
||||
# Recognition”, arXiv:1803.05030, 2018.
|
||||
|
||||
from torch import nn
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import MODELS, TorchModel
|
||||
from modelscope.models.audio.ans.layers.activations import (RectifiedLinear,
|
||||
Sigmoid)
|
||||
from modelscope.models.audio.ans.layers.affine_transform import AffineTransform
|
||||
from modelscope.models.audio.ans.layers.uni_deep_fsmn import UniDeepFsmn
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.acoustic_noise_suppression, module_name=Models.speech_dfsmn_ans)
|
||||
class DfsmnAns(TorchModel):
|
||||
"""Denoise model with DFSMN.
|
||||
|
||||
Args:
|
||||
model_dir (str): the model path.
|
||||
fsmn_depth (int): the depth of deepfsmn
|
||||
lorder (int):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_dir: str,
|
||||
fsmn_depth=9,
|
||||
lorder=20,
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
self.lorder = lorder
|
||||
self.linear1 = AffineTransform(120, 256)
|
||||
self.relu = RectifiedLinear(256, 256)
|
||||
repeats = [
|
||||
UniDeepFsmn(256, 256, lorder, 256) for i in range(fsmn_depth)
|
||||
]
|
||||
self.deepfsmn = nn.Sequential(*repeats)
|
||||
self.linear2 = AffineTransform(256, 961)
|
||||
self.sig = Sigmoid(961, 961)
|
||||
|
||||
def forward(self, input):
|
||||
"""
|
||||
Args:
|
||||
input: fbank feature [batch_size,number_of_frame,feature_dimension]
|
||||
|
||||
Returns:
|
||||
mask value [batch_size, number_of_frame, FFT_size/2+1]
|
||||
"""
|
||||
x1 = self.linear1(input)
|
||||
x2 = self.relu(x1)
|
||||
x3 = self.deepfsmn(x2)
|
||||
x4 = self.linear2(x3)
|
||||
x5 = self.sig(x4)
|
||||
return x5
|
||||
|
||||
def to_kaldi_nnet(self):
|
||||
re_str = ''
|
||||
re_str += '<Nnet>\n'
|
||||
re_str += self.linear1.to_kaldi_nnet()
|
||||
re_str += self.relu.to_kaldi_nnet()
|
||||
for dfsmn in self.deepfsmn:
|
||||
re_str += dfsmn.to_kaldi_nnet()
|
||||
re_str += self.linear2.to_kaldi_nnet()
|
||||
re_str += self.sig.to_kaldi_nnet()
|
||||
re_str += '</Nnet>\n'
|
||||
|
||||
return re_str
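A shape-only sketch of the DFSMN denoiser above, assuming TorchModel.__init__ does not itself load a checkpoint (random weights, no real audio involved; '.' is a placeholder model_dir):

import torch
from modelscope.models.audio.ans.denoise_net import DfsmnAns

ans = DfsmnAns(model_dir='.')
fbank = torch.randn(1, 100, 120)  # batch x frames x 120-dim fbank
mask = ans(fbank)
print(mask.shape)  # torch.Size([1, 100, 961]) = FFT_size/2 + 1 mask bins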
|
||||
@@ -78,7 +78,7 @@ class FRCRN(nn.Module):
|
||||
win_len=400,
|
||||
win_inc=100,
|
||||
fft_len=512,
|
||||
-win_type='hanning',
+win_type='hann',
|
||||
**kwargs):
|
||||
r"""
|
||||
Args:
|
||||
|
||||
modelscope/models/audio/ans/layers/activations.py (new file, 62 lines)
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from modelscope.models.audio.ans.layers.layer_base import LayerBase
|
||||
|
||||
|
||||
class RectifiedLinear(LayerBase):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(RectifiedLinear, self).__init__()
|
||||
self.dim = input_dim
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, input):
|
||||
return self.relu(input)
|
||||
|
||||
def to_kaldi_nnet(self):
|
||||
re_str = ''
|
||||
re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
|
||||
return re_str
|
||||
|
||||
def load_kaldi_nnet(self, instr):
|
||||
return instr
|
||||
|
||||
|
||||
class LogSoftmax(LayerBase):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(LogSoftmax, self).__init__()
|
||||
self.dim = input_dim
|
||||
self.ls = nn.LogSoftmax()
|
||||
|
||||
def forward(self, input):
|
||||
return self.ls(input)
|
||||
|
||||
def to_kaldi_nnet(self):
|
||||
re_str = ''
|
||||
re_str += '<Softmax> %d %d\n' % (self.dim, self.dim)
|
||||
return re_str
|
||||
|
||||
def load_kaldi_nnet(self, instr):
|
||||
return instr
|
||||
|
||||
|
||||
class Sigmoid(LayerBase):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(Sigmoid, self).__init__()
|
||||
self.dim = input_dim
|
||||
self.sig = nn.Sigmoid()
|
||||
|
||||
def forward(self, input):
|
||||
return self.sig(input)
|
||||
|
||||
def to_kaldi_nnet(self):
|
||||
re_str = ''
|
||||
re_str += '<Sigmoid> %d %d\n' % (self.dim, self.dim)
|
||||
return re_str
|
||||
|
||||
def load_kaldi_nnet(self, instr):
|
||||
return instr
|
||||
modelscope/models/audio/ans/layers/affine_transform.py (new file, 86 lines)
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from modelscope.models.audio.ans.layers.layer_base import (LayerBase,
|
||||
to_kaldi_matrix)
|
||||
from modelscope.utils.audio.audio_utils import (expect_kaldi_matrix,
|
||||
expect_token_number)
|
||||
|
||||
|
||||
class AffineTransform(LayerBase):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(AffineTransform, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
|
||||
self.linear = nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(self, input):
|
||||
return self.linear(input)
|
||||
|
||||
def to_kaldi_nnet(self):
|
||||
re_str = ''
|
||||
|
||||
re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
|
||||
self.input_dim)
|
||||
|
||||
re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
|
||||
|
||||
linear_weights = self.state_dict()['linear.weight']
|
||||
|
||||
x = linear_weights.squeeze().numpy()
|
||||
|
||||
re_str += to_kaldi_matrix(x)
|
||||
|
||||
linear_bias = self.state_dict()['linear.bias']
|
||||
|
||||
x = linear_bias.squeeze().numpy()
|
||||
|
||||
re_str += to_kaldi_matrix(x)
|
||||
|
||||
return re_str
|
||||
|
||||
def load_kaldi_nnet(self, instr):
|
||||
|
||||
output = expect_token_number(
|
||||
instr,
|
||||
'<LearnRateCoef>',
|
||||
)
|
||||
if output is None:
|
||||
raise Exception('AffineTransform format error')
|
||||
|
||||
instr, lr = output
|
||||
|
||||
output = expect_token_number(instr, '<BiasLearnRateCoef>')
|
||||
if output is None:
|
||||
raise Exception('AffineTransform format error')
|
||||
|
||||
instr, lr = output
|
||||
|
||||
output = expect_token_number(instr, '<MaxNorm>')
|
||||
if output is None:
|
||||
raise Exception('AffineTransform format error')
|
||||
|
||||
instr, lr = output
|
||||
|
||||
output = expect_kaldi_matrix(instr)
|
||||
|
||||
if output is None:
|
||||
raise Exception('AffineTransform format error')
|
||||
|
||||
instr, mat = output
|
||||
|
||||
self.linear.weight = th.nn.Parameter(
|
||||
th.from_numpy(mat).type(th.FloatTensor))
|
||||
|
||||
output = expect_kaldi_matrix(instr)
|
||||
if output is None:
|
||||
raise Exception('AffineTransform format error')
|
||||
|
||||
instr, mat = output
|
||||
self.linear.bias = th.nn.Parameter(
|
||||
th.from_numpy(mat).type(th.FloatTensor))
|
||||
return instr
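Run alongside the class above, a quick way to see the Kaldi nnet1 text produced by to_kaldi_nnet is to serialize a tiny layer (output sketched as comments, weight values elided):

layer = AffineTransform(input_dim=3, output_dim=2)
print(layer.to_kaldi_nnet())
# <AffineTransform> 2 3
# <LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0
# [ ... 2x3 weight matrix ... ]
# [ ... bias vector ... ]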
|
||||
modelscope/models/audio/ans/layers/layer_base.py (new file, 31 lines)
@@ -0,0 +1,31 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import abc
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def to_kaldi_matrix(np_mat):
|
||||
""" function that transform as str numpy mat to standard kaldi str matrix
|
||||
|
||||
Args:
|
||||
np_mat: numpy mat
|
||||
"""
|
||||
np.set_printoptions(threshold=np.inf, linewidth=np.nan)
|
||||
out_str = str(np_mat)
|
||||
out_str = out_str.replace('[', '')
|
||||
out_str = out_str.replace(']', '')
|
||||
return '[ %s ]\n' % out_str
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class LayerBase(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(LayerBase, self).__init__()
|
||||
|
||||
@abc.abstractmethod
|
||||
def to_kaldi_nnet(self):
|
||||
pass
|
||||
modelscope/models/audio/ans/layers/uni_deep_fsmn.py (new file, 156 lines)
@@ -0,0 +1,156 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.audio.ans.layers.layer_base import (LayerBase,
|
||||
to_kaldi_matrix)
|
||||
from modelscope.utils.audio.audio_utils import (expect_kaldi_matrix,
|
||||
expect_token_number)
|
||||
|
||||
|
||||
class UniDeepFsmn(LayerBase):
|
||||
|
||||
def __init__(self, input_dim, output_dim, lorder=1, hidden_size=None):
|
||||
super(UniDeepFsmn, self).__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.lorder = lorder
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.linear = nn.Linear(input_dim, hidden_size)
|
||||
self.project = nn.Linear(hidden_size, output_dim, bias=False)
|
||||
self.conv1 = nn.Conv2d(
|
||||
output_dim,
|
||||
output_dim, (lorder, 1), (1, 1),
|
||||
groups=output_dim,
|
||||
bias=False)
|
||||
|
||||
def forward(self, input):
|
||||
"""
|
||||
|
||||
Args:
|
||||
input: torch with shape: batch (b) x sequence(T) x feature (h)
|
||||
|
||||
Returns:
|
||||
batch (b) x channel (c) x sequence(T) x feature (h)
|
||||
"""
|
||||
f1 = F.relu(self.linear(input))
|
||||
p1 = self.project(f1)
|
||||
x = torch.unsqueeze(p1, 1)
|
||||
# x: batch (b) x channel (c) x sequence(T) x feature (h)
|
||||
x_per = x.permute(0, 3, 2, 1)
|
||||
# x_per: batch (b) x feature (h) x sequence(T) x channel (c)
|
||||
y = F.pad(x_per, [0, 0, self.lorder - 1, 0])
|
||||
|
||||
out = x_per + self.conv1(y)
|
||||
out1 = out.permute(0, 3, 2, 1)
|
||||
# out1: batch (b) x channel (c) x sequence(T) x feature (h)
|
||||
return input + out1.squeeze()
|
||||
|
||||
def to_kaldi_nnet(self):
|
||||
re_str = ''
|
||||
re_str += '<UniDeepFsmn> %d %d\n'\
|
||||
% (self.output_dim, self.input_dim)
|
||||
re_str += '<LearnRateCoef> %d <HidSize> %d <LOrder> %d <LStride> %d <MaxNorm> 0\n'\
|
||||
% (1, self.hidden_size, self.lorder, 1)
|
||||
|
||||
lfiters = self.state_dict()['conv1.weight']
|
||||
x = np.flipud(lfiters.squeeze().numpy().T)
|
||||
re_str += to_kaldi_matrix(x)
|
||||
proj_weights = self.state_dict()['project.weight']
|
||||
x = proj_weights.squeeze().numpy()
|
||||
re_str += to_kaldi_matrix(x)
|
||||
linear_weights = self.state_dict()['linear.weight']
|
||||
x = linear_weights.squeeze().numpy()
|
||||
re_str += to_kaldi_matrix(x)
|
||||
linear_bias = self.state_dict()['linear.bias']
|
||||
x = linear_bias.squeeze().numpy()
|
||||
re_str += to_kaldi_matrix(x)
|
||||
return re_str
|
||||
|
||||
def load_kaldi_nnet(self, instr):
|
||||
output = expect_token_number(
|
||||
instr,
|
||||
'<LearnRateCoef>',
|
||||
)
|
||||
if output is None:
|
||||
raise Exception('UniDeepFsmn format error')
|
||||
instr, lr = output
|
||||
|
||||
output = expect_token_number(
|
||||
instr,
|
||||
'<HidSize>',
|
||||
)
|
||||
if output is None:
|
||||
raise Exception('UniDeepFsmn format error')
|
||||
instr, hiddensize = output
|
||||
self.hidden_size = int(hiddensize)
|
||||
|
||||
output = expect_token_number(
|
||||
instr,
|
||||
'<LOrder>',
|
||||
)
|
||||
if output is None:
|
||||
raise Exception('UniDeepFsmn format error')
|
||||
instr, lorder = output
|
||||
self.lorder = int(lorder)
|
||||
|
||||
output = expect_token_number(
|
||||
instr,
|
||||
'<LStride>',
|
||||
)
|
||||
if output is None:
|
||||
raise Exception('UniDeepFsmn format error')
|
||||
instr, lstride = output
|
||||
self.lstride = lstride
|
||||
|
||||
output = expect_token_number(
|
||||
instr,
|
||||
'<MaxNorm>',
|
||||
)
|
||||
if output is None:
|
||||
raise Exception('UniDeepFsmn format error')
|
||||
|
||||
output = expect_kaldi_matrix(instr)
|
||||
if output is None:
|
||||
raise Exception('Fsmn format error')
|
||||
instr, mat = output
|
||||
mat1 = np.fliplr(mat.T).copy()
|
||||
self.conv1 = nn.Conv2d(
|
||||
self.output_dim,
|
||||
self.output_dim, (self.lorder, 1), (1, 1),
|
||||
groups=self.output_dim,
|
||||
bias=False)
|
||||
mat_th = torch.from_numpy(mat1).type(torch.FloatTensor)
|
||||
mat_th = mat_th.unsqueeze(1)
|
||||
mat_th = mat_th.unsqueeze(3)
|
||||
self.conv1.weight = torch.nn.Parameter(mat_th)
|
||||
|
||||
output = expect_kaldi_matrix(instr)
|
||||
if output is None:
|
||||
raise Exception('UniDeepFsmn format error')
|
||||
instr, mat = output
|
||||
|
||||
self.project = nn.Linear(self.hidden_size, self.output_dim, bias=False)
|
||||
self.linear = nn.Linear(self.input_dim, self.hidden_size)
|
||||
self.project.weight = torch.nn.Parameter(
|
||||
torch.from_numpy(mat).type(torch.FloatTensor))
|
||||
|
||||
output = expect_kaldi_matrix(instr)
|
||||
if output is None:
|
||||
raise Exception('UniDeepFsmn format error')
|
||||
instr, mat = output
|
||||
self.linear.weight = torch.nn.Parameter(
|
||||
torch.from_numpy(mat).type(torch.FloatTensor))
|
||||
|
||||
output = expect_kaldi_matrix(instr)
|
||||
if output is None:
|
||||
raise Exception('UniDeepFsmn format error')
|
||||
instr, mat = output
|
||||
self.linear.bias = torch.nn.Parameter(
|
||||
torch.from_numpy(mat).type(torch.FloatTensor))
|
||||
return instr
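Because the memory block is residual, its output shape matches its input; a minimal random-tensor check:

import torch

fsmn = UniDeepFsmn(input_dim=256, output_dim=256, lorder=20, hidden_size=256)
x = torch.randn(2, 50, 256)  # batch x frames x features
print(fsmn(x).shape)  # torch.Size([2, 50, 256]), same as the input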
|
||||
@@ -17,6 +17,7 @@ __all__ = ['GenericAutomaticSpeechRecognition']
|
||||
Tasks.voice_activity_detection, module_name=Models.generic_asr)
|
||||
@MODELS.register_module(
|
||||
Tasks.language_score_prediction, module_name=Models.generic_asr)
|
||||
@MODELS.register_module(Tasks.speech_timestamp, module_name=Models.generic_asr)
|
||||
class GenericAutomaticSpeechRecognition(Model):
|
||||
|
||||
def __init__(self, model_dir: str, am_model_name: str,
|
||||
|
||||
@@ -68,6 +68,7 @@ class FSMNSeleNetV2Decorator(TorchModel):
|
||||
'keyword':
|
||||
self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()),
|
||||
'offset': self._sc.kwsKeywordOffset(),
|
||||
'channel': self._sc.kwsBestChannel(),
|
||||
'length': self._sc.kwsKeywordLength(),
|
||||
'confidence': self._sc.kwsConfidence()
|
||||
}
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .datasets.dataset import get_am_datasets, get_voc_datasets
|
||||
from .models import model_builder
|
||||
from .models.hifigan.hifigan import Generator
|
||||
from .train.loss import criterion_builder
|
||||
from .train.trainer import GAN_Trainer, Sambert_Trainer
|
||||
from .utils.ling_unit.ling_unit import KanTtsLinguisticUnit
|
||||
@@ -1,36 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
|
||||
DATA_TYPE_DICT = {
|
||||
'txt': {
|
||||
'load_func': np.loadtxt,
|
||||
'desc': 'plain txt file or readable by np.loadtxt',
|
||||
},
|
||||
'wav': {
|
||||
'load_func': lambda x: wavfile.read(x)[1],
|
||||
'desc': 'wav file or readable by soundfile.read',
|
||||
},
|
||||
'npy': {
|
||||
'load_func': np.load,
|
||||
'desc': 'any .npy format file',
|
||||
},
|
||||
# PCM data type can be loaded by binary format
|
||||
'bin_f32': {
|
||||
'load_func': lambda x: np.fromfile(x, dtype=np.float32),
|
||||
'desc': 'binary file with float32 format',
|
||||
},
|
||||
'bin_f64': {
|
||||
'load_func': lambda x: np.fromfile(x, dtype=np.float64),
|
||||
'desc': 'binary file with float64 format',
|
||||
},
|
||||
'bin_i32': {
|
||||
'load_func': lambda x: np.fromfile(x, dtype=np.int32),
|
||||
'desc': 'binary file with int32 format',
|
||||
},
|
||||
'bin_i16': {
|
||||
'load_func': lambda x: np.fromfile(x, dtype=np.int16),
|
||||
'desc': 'binary file with int16 format',
|
||||
},
|
||||
}
|
||||
File diff suppressed because it is too large
@@ -1,158 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
|
||||
import modelscope.models.audio.tts.kantts.train.scheduler as kantts_scheduler
|
||||
from modelscope.models.audio.tts.kantts.utils.ling_unit.ling_unit import \
|
||||
get_fpdict
|
||||
from .hifigan import (Generator, MultiPeriodDiscriminator,
|
||||
MultiScaleDiscriminator, MultiSpecDiscriminator)
|
||||
from .pqmf import PQMF
|
||||
from .sambert.kantts_sambert import KanTtsSAMBERT, KanTtsTextsyBERT
|
||||
|
||||
|
||||
def optimizer_builder(model_params, opt_name, opt_params):
|
||||
opt_cls = getattr(torch.optim, opt_name)
|
||||
optimizer = opt_cls(model_params, **opt_params)
|
||||
return optimizer
|
||||
|
||||
|
||||
def scheduler_builder(optimizer, sche_name, sche_params):
|
||||
scheduler_cls = getattr(kantts_scheduler, sche_name)
|
||||
scheduler = scheduler_cls(optimizer, **sche_params)
|
||||
return scheduler
|
||||
|
||||
|
||||
def hifigan_model_builder(config, device, rank, distributed):
|
||||
model = {}
|
||||
optimizer = {}
|
||||
scheduler = {}
|
||||
model['discriminator'] = {}
|
||||
optimizer['discriminator'] = {}
|
||||
scheduler['discriminator'] = {}
|
||||
for model_name in config['Model'].keys():
|
||||
if model_name == 'Generator':
|
||||
params = config['Model'][model_name]['params']
|
||||
model['generator'] = Generator(**params).to(device)
|
||||
optimizer['generator'] = optimizer_builder(
|
||||
model['generator'].parameters(),
|
||||
config['Model'][model_name]['optimizer'].get('type', 'Adam'),
|
||||
config['Model'][model_name]['optimizer'].get('params', {}),
|
||||
)
|
||||
scheduler['generator'] = scheduler_builder(
|
||||
optimizer['generator'],
|
||||
config['Model'][model_name]['scheduler'].get('type', 'StepLR'),
|
||||
config['Model'][model_name]['scheduler'].get('params', {}),
|
||||
)
|
||||
else:
|
||||
params = config['Model'][model_name]['params']
|
||||
model['discriminator'][model_name] = globals()[model_name](
|
||||
**params).to(device)
|
||||
optimizer['discriminator'][model_name] = optimizer_builder(
|
||||
model['discriminator'][model_name].parameters(),
|
||||
config['Model'][model_name]['optimizer'].get('type', 'Adam'),
|
||||
config['Model'][model_name]['optimizer'].get('params', {}),
|
||||
)
|
||||
scheduler['discriminator'][model_name] = scheduler_builder(
|
||||
optimizer['discriminator'][model_name],
|
||||
config['Model'][model_name]['scheduler'].get('type', 'StepLR'),
|
||||
config['Model'][model_name]['scheduler'].get('params', {}),
|
||||
)
|
||||
|
||||
out_channels = config['Model']['Generator']['params']['out_channels']
|
||||
if out_channels > 1:
|
||||
model['pqmf'] = PQMF(
|
||||
subbands=out_channels, **config.get('pqmf', {})).to(device)
|
||||
|
||||
# FIXME: pywavelets buffer leads to gradient error in DDP training
|
||||
# Solution: https://github.com/pytorch/pytorch/issues/22095
|
||||
if distributed:
|
||||
model['generator'] = DistributedDataParallel(
|
||||
model['generator'],
|
||||
device_ids=[rank],
|
||||
output_device=rank,
|
||||
broadcast_buffers=False,
|
||||
)
|
||||
for model_name in model['discriminator'].keys():
|
||||
model['discriminator'][model_name] = DistributedDataParallel(
|
||||
model['discriminator'][model_name],
|
||||
device_ids=[rank],
|
||||
output_device=rank,
|
||||
broadcast_buffers=False,
|
||||
)
|
||||
|
||||
return model, optimizer, scheduler
|
||||
|
||||
|
||||
def sambert_model_builder(config, device, rank, distributed):
|
||||
model = {}
|
||||
optimizer = {}
|
||||
scheduler = {}
|
||||
|
||||
model['KanTtsSAMBERT'] = KanTtsSAMBERT(
|
||||
config['Model']['KanTtsSAMBERT']['params']).to(device)
|
||||
|
||||
fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False)
|
||||
if fp_enable:
|
||||
fp_dict = {
|
||||
k: torch.from_numpy(v).long().unsqueeze(0).to(device)
|
||||
for k, v in get_fpdict(config).items()
|
||||
}
|
||||
model['KanTtsSAMBERT'].fp_dict = fp_dict
|
||||
|
||||
optimizer['KanTtsSAMBERT'] = optimizer_builder(
|
||||
model['KanTtsSAMBERT'].parameters(),
|
||||
config['Model']['KanTtsSAMBERT']['optimizer'].get('type', 'Adam'),
|
||||
config['Model']['KanTtsSAMBERT']['optimizer'].get('params', {}),
|
||||
)
|
||||
scheduler['KanTtsSAMBERT'] = scheduler_builder(
|
||||
optimizer['KanTtsSAMBERT'],
|
||||
config['Model']['KanTtsSAMBERT']['scheduler'].get('type', 'StepLR'),
|
||||
config['Model']['KanTtsSAMBERT']['scheduler'].get('params', {}),
|
||||
)
|
||||
|
||||
if distributed:
|
||||
model['KanTtsSAMBERT'] = DistributedDataParallel(
|
||||
model['KanTtsSAMBERT'], device_ids=[rank], output_device=rank)
|
||||
|
||||
return model, optimizer, scheduler
|
||||
|
||||
|
||||
def sybert_model_builder(config, device, rank, distributed):
|
||||
model = {}
|
||||
optimizer = {}
|
||||
scheduler = {}
|
||||
|
||||
model['KanTtsTextsyBERT'] = KanTtsTextsyBERT(
|
||||
config['Model']['KanTtsTextsyBERT']['params']).to(device)
|
||||
optimizer['KanTtsTextsyBERT'] = optimizer_builder(
|
||||
model['KanTtsTextsyBERT'].parameters(),
|
||||
config['Model']['KanTtsTextsyBERT']['optimizer'].get('type', 'Adam'),
|
||||
config['Model']['KanTtsTextsyBERT']['optimizer'].get('params', {}),
|
||||
)
|
||||
scheduler['KanTtsTextsyBERT'] = scheduler_builder(
|
||||
optimizer['KanTtsTextsyBERT'],
|
||||
config['Model']['KanTtsTextsyBERT']['scheduler'].get('type', 'StepLR'),
|
||||
config['Model']['KanTtsTextsyBERT']['scheduler'].get('params', {}),
|
||||
)
|
||||
|
||||
if distributed:
|
||||
model['KanTtsTextsyBERT'] = DistributedDataParallel(
|
||||
model['KanTtsTextsyBERT'], device_ids=[rank], output_device=rank)
|
||||
|
||||
return model, optimizer, scheduler
|
||||
|
||||
|
||||
model_dict = {
|
||||
'hifigan': hifigan_model_builder,
|
||||
'sambert': sambert_model_builder,
|
||||
'sybert': sybert_model_builder,
|
||||
}
|
||||
|
||||
|
||||
def model_builder(config, device='cpu', rank=0, distributed=False):
|
||||
builder_func = model_dict[config['model_type']]
|
||||
model, optimizer, scheduler = builder_func(config, device, rank,
|
||||
distributed)
|
||||
return model, optimizer, scheduler
|
||||
@@ -1,4 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .hifigan import (Generator, MultiPeriodDiscriminator,
|
||||
MultiScaleDiscriminator, MultiSpecDiscriminator)
|
||||
@@ -1,613 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import copy
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from pytorch_wavelets import DWT1DForward
|
||||
from torch.nn.utils import spectral_norm, weight_norm
|
||||
|
||||
from modelscope.models.audio.tts.kantts.utils.audio_torch import stft
|
||||
from .layers import (CausalConv1d, CausalConvTranspose1d, Conv1d,
|
||||
ConvTranspose1d, ResidualBlock, SourceModule)
|
||||
|
||||
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=80,
|
||||
out_channels=1,
|
||||
channels=512,
|
||||
kernel_size=7,
|
||||
upsample_scales=(8, 8, 2, 2),
|
||||
upsample_kernal_sizes=(16, 16, 4, 4),
|
||||
resblock_kernel_sizes=(3, 7, 11),
|
||||
resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)],
|
||||
repeat_upsample=True,
|
||||
bias=True,
|
||||
causal=True,
|
||||
nonlinear_activation='LeakyReLU',
|
||||
nonlinear_activation_params={'negative_slope': 0.1},
|
||||
use_weight_norm=True,
|
||||
nsf_params=None,
|
||||
):
|
||||
super(Generator, self).__init__()
|
||||
|
||||
# check hyperparameters are valid
|
||||
assert kernel_size % 2 == 1, 'Kernal size must be odd number.'
|
||||
assert len(upsample_scales) == len(upsample_kernal_sizes)
|
||||
assert len(resblock_dilations) == len(resblock_kernel_sizes)
|
||||
|
||||
self.upsample_scales = upsample_scales
|
||||
self.repeat_upsample = repeat_upsample
|
||||
self.num_upsamples = len(upsample_kernal_sizes)
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.out_channels = out_channels
|
||||
self.nsf_enable = nsf_params is not None
|
||||
|
||||
self.transpose_upsamples = torch.nn.ModuleList()
|
||||
self.repeat_upsamples = torch.nn.ModuleList() # for repeat upsampling
|
||||
self.conv_blocks = torch.nn.ModuleList()
|
||||
|
||||
conv_cls = CausalConv1d if causal else Conv1d
|
||||
conv_transposed_cls = CausalConvTranspose1d if causal else ConvTranspose1d
|
||||
|
||||
self.conv_pre = conv_cls(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
padding=(kernel_size - 1) // 2)
|
||||
|
||||
for i in range(len(upsample_kernal_sizes)):
|
||||
self.transpose_upsamples.append(
|
||||
torch.nn.Sequential(
|
||||
getattr(
|
||||
torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
conv_transposed_cls(
|
||||
channels // (2**i),
|
||||
channels // (2**(i + 1)),
|
||||
upsample_kernal_sizes[i],
|
||||
upsample_scales[i],
|
||||
padding=(upsample_kernal_sizes[i] - upsample_scales[i])
|
||||
// 2,
|
||||
),
|
||||
))
|
||||
|
||||
if repeat_upsample:
|
||||
self.repeat_upsamples.append(
|
||||
nn.Sequential(
|
||||
nn.Upsample(
|
||||
mode='nearest', scale_factor=upsample_scales[i]),
|
||||
getattr(torch.nn, nonlinear_activation)(
|
||||
**nonlinear_activation_params),
|
||||
conv_cls(
|
||||
channels // (2**i),
|
||||
channels // (2**(i + 1)),
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
),
|
||||
))
|
||||
|
||||
for j in range(len(resblock_kernel_sizes)):
|
||||
self.conv_blocks.append(
|
||||
ResidualBlock(
|
||||
channels=channels // (2**(i + 1)),
|
||||
kernel_size=resblock_kernel_sizes[j],
|
||||
dilation=resblock_dilations[j],
|
||||
nonlinear_activation=nonlinear_activation,
|
||||
nonlinear_activation_params=nonlinear_activation_params,
|
||||
causal=causal,
|
||||
))
|
||||
|
||||
self.conv_post = conv_cls(
|
||||
channels // (2**(i + 1)),
|
||||
out_channels,
|
||||
kernel_size,
|
||||
1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
|
||||
if self.nsf_enable:
|
||||
self.source_module = SourceModule(
|
||||
nb_harmonics=nsf_params['nb_harmonics'],
|
||||
upsample_ratio=np.cumprod(self.upsample_scales)[-1],
|
||||
sampling_rate=nsf_params['sampling_rate'],
|
||||
)
|
||||
self.source_downs = nn.ModuleList()
|
||||
self.downsample_rates = [1] + self.upsample_scales[::-1][:-1]
|
||||
self.downsample_cum_rates = np.cumprod(self.downsample_rates)
|
||||
|
||||
for i, u in enumerate(self.downsample_cum_rates[::-1]):
|
||||
if u == 1:
|
||||
self.source_downs.append(
|
||||
Conv1d(1, channels // (2**(i + 1)), 1, 1))
|
||||
else:
|
||||
self.source_downs.append(
|
||||
conv_cls(
|
||||
1,
|
||||
channels // (2**(i + 1)),
|
||||
u * 2,
|
||||
u,
|
||||
padding=u // 2,
|
||||
))
|
||||
|
||||
def forward(self, x):
|
||||
if self.nsf_enable:
|
||||
mel = x[:, :-2, :]
|
||||
pitch = x[:, -2:-1, :]
|
||||
uv = x[:, -1:, :]
|
||||
excitation = self.source_module(pitch, uv)
|
||||
else:
|
||||
mel = x
|
||||
|
||||
x = self.conv_pre(mel)
|
||||
for i in range(self.num_upsamples):
|
||||
# FIXME: sin function here seems to be causing issues
|
||||
x = torch.sin(x) + x
|
||||
rep = self.repeat_upsamples[i](x)
|
||||
|
||||
if self.nsf_enable:
|
||||
# Downsampling the excitation signal
|
||||
e = self.source_downs[i](excitation)
|
||||
# augment inputs with the excitation
|
||||
x = rep + e
|
||||
else:
|
||||
# transconv
|
||||
up = self.transpose_upsamples[i](x)
|
||||
x = rep + up[:, :, :rep.shape[-1]]
|
||||
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.conv_blocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.conv_blocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for layer in self.transpose_upsamples:
|
||||
layer[-1].remove_weight_norm()
|
||||
for layer in self.repeat_upsamples:
|
||||
layer[-1].remove_weight_norm()
|
||||
for layer in self.conv_blocks:
|
||||
layer.remove_weight_norm()
|
||||
self.conv_pre.remove_weight_norm()
|
||||
self.conv_post.remove_weight_norm()
|
||||
if self.nsf_enable:
|
||||
self.source_module.remove_weight_norm()
|
||||
for layer in self.source_downs:
|
||||
layer.remove_weight_norm()
|
||||
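# Shape sketch for the Generator above (illustrative, assuming
# in_channels=80 mel bins and upsample_scales=[8, 8, 2, 2], i.e. an
# overall hop of 256):
#   mel = torch.randn(1, 80, 100)   # (B, in_channels, T_frames)
#   wav = generator(mel)            # (B, out_channels, 100 * 256)
# With nsf_params set, the last two input rows are read as pitch (Hz) and
# a voiced/unvoiced flag; they are turned into an excitation signal that is
# downsampled and added at each stage instead of the transposed-conv path.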
|
||||
|
||||
class PeriodDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
period=3,
|
||||
kernel_sizes=[5, 3],
|
||||
channels=32,
|
||||
downsample_scales=[3, 3, 3, 3, 1],
|
||||
max_downsample_channels=1024,
|
||||
bias=True,
|
||||
nonlinear_activation='LeakyReLU',
|
||||
nonlinear_activation_params={'negative_slope': 0.1},
|
||||
use_spectral_norm=False,
|
||||
):
|
||||
super(PeriodDiscriminator, self).__init__()
|
||||
self.period = period
|
||||
norm_f = weight_norm if not use_spectral_norm else spectral_norm
|
||||
self.convs = nn.ModuleList()
|
||||
in_chs, out_chs = in_channels, channels
|
||||
|
||||
for downsample_scale in downsample_scales:
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
(kernel_sizes[0], 1),
|
||||
(downsample_scale, 1),
|
||||
padding=((kernel_sizes[0] - 1) // 2, 0),
|
||||
)),
|
||||
getattr(
|
||||
torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
in_chs = out_chs
|
||||
out_chs = min(out_chs * 4, max_downsample_channels)
|
||||
|
||||
self.conv_post = nn.Conv2d(
|
||||
out_chs,
|
||||
out_channels,
|
||||
(kernel_sizes[1] - 1, 1),
|
||||
1,
|
||||
padding=((kernel_sizes[1] - 1) // 2, 0),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), 'reflect')
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for layer in self.convs:
|
||||
x = layer(x)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
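# Descriptive note on PeriodDiscriminator.forward above: the waveform is
# right-padded (reflect) to a multiple of `period` and reshaped to
# (B, C, T // period, period), so the 2D convs with (k, 1) kernels stride
# over whole periods and compare samples exactly `period` steps apart.
# E.g. period=3 and T=100 -> pad to 102 -> view (B, C, 34, 3).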
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
periods=[2, 3, 5, 7, 11],
|
||||
discriminator_params={
|
||||
'in_channels': 1,
|
||||
'out_channels': 1,
|
||||
'kernel_sizes': [5, 3],
|
||||
'channels': 32,
|
||||
'downsample_scales': [3, 3, 3, 3, 1],
|
||||
'max_downsample_channels': 1024,
|
||||
'bias': True,
|
||||
'nonlinear_activation': 'LeakyReLU',
|
||||
'nonlinear_activation_params': {
|
||||
'negative_slope': 0.1
|
||||
},
|
||||
'use_spectral_norm': False,
|
||||
},
|
||||
):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList()
|
||||
for period in periods:
|
||||
params = copy.deepcopy(discriminator_params)
|
||||
params['period'] = period
|
||||
self.discriminators += [PeriodDiscriminator(**params)]
|
||||
|
||||
def forward(self, y):
|
||||
y_d_rs = []
|
||||
fmap_rs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
|
||||
return y_d_rs, fmap_rs
|
||||
|
||||
|
||||
class ScaleDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_sizes=[15, 41, 5, 3],
|
||||
channels=128,
|
||||
max_downsample_channels=1024,
|
||||
max_groups=16,
|
||||
bias=True,
|
||||
downsample_scales=[2, 2, 4, 4, 1],
|
||||
nonlinear_activation='LeakyReLU',
|
||||
nonlinear_activation_params={'negative_slope': 0.1},
|
||||
use_spectral_norm=False,
|
||||
):
|
||||
super(ScaleDiscriminator, self).__init__()
|
||||
norm_f = weight_norm if not use_spectral_norm else spectral_norm
|
||||
|
||||
assert len(kernel_sizes) == 4
|
||||
for ks in kernel_sizes:
|
||||
assert ks % 2 == 1
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv1d(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_sizes[0],
|
||||
bias=bias,
|
||||
padding=(kernel_sizes[0] - 1) // 2,
|
||||
)),
|
||||
getattr(torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
in_chs = channels
|
||||
out_chs = channels
|
||||
groups = 4
|
||||
|
||||
for downsample_scale in downsample_scales:
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv1d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_sizes[1],
|
||||
stride=downsample_scale,
|
||||
padding=(kernel_sizes[1] - 1) // 2,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)),
|
||||
getattr(
|
||||
torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
in_chs = out_chs
|
||||
out_chs = min(in_chs * 2, max_downsample_channels)
|
||||
groups = min(groups * 4, max_groups)
|
||||
|
||||
out_chs = min(in_chs * 2, max_downsample_channels)
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv1d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_sizes[2],
|
||||
stride=1,
|
||||
padding=(kernel_sizes[2] - 1) // 2,
|
||||
bias=bias,
|
||||
)),
|
||||
getattr(torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
|
||||
self.conv_post = norm_f(
|
||||
nn.Conv1d(
|
||||
out_chs,
|
||||
out_channels,
|
||||
kernel_size=kernel_sizes[3],
|
||||
stride=1,
|
||||
padding=(kernel_sizes[3] - 1) // 2,
|
||||
bias=bias,
|
||||
))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
for layer in self.convs:
|
||||
x = layer(x)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scales=3,
|
||||
downsample_pooling='DWT',
|
||||
# follow the official implementation setting
|
||||
downsample_pooling_params={
|
||||
'kernel_size': 4,
|
||||
'stride': 2,
|
||||
'padding': 2,
|
||||
},
|
||||
discriminator_params={
|
||||
'in_channels': 1,
|
||||
'out_channels': 1,
|
||||
'kernel_sizes': [15, 41, 5, 3],
|
||||
'channels': 128,
|
||||
'max_downsample_channels': 1024,
|
||||
'max_groups': 16,
|
||||
'bias': True,
|
||||
'downsample_scales': [2, 2, 4, 4, 1],
|
||||
'nonlinear_activation': 'LeakyReLU',
|
||||
'nonlinear_activation_params': {
|
||||
'negative_slope': 0.1
|
||||
},
|
||||
},
|
||||
follow_official_norm=False,
|
||||
):
|
||||
super(MultiScaleDiscriminator, self).__init__()
|
||||
self.discriminators = torch.nn.ModuleList()
|
||||
|
||||
# add discriminators
|
||||
for i in range(scales):
|
||||
params = copy.deepcopy(discriminator_params)
|
||||
if follow_official_norm:
|
||||
params['use_spectral_norm'] = (i == 0)
|
||||
self.discriminators += [ScaleDiscriminator(**params)]
|
||||
|
||||
if downsample_pooling == 'DWT':
|
||||
self.meanpools = nn.ModuleList(
|
||||
[DWT1DForward(wave='db3', J=1),
|
||||
DWT1DForward(wave='db3', J=1)])
|
||||
self.aux_convs = nn.ModuleList([
|
||||
weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
|
||||
weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
|
||||
])
|
||||
else:
|
||||
self.meanpools = nn.ModuleList(
|
||||
[nn.AvgPool1d(4, 2, padding=2),
|
||||
nn.AvgPool1d(4, 2, padding=2)])
|
||||
self.aux_convs = None
|
||||
|
||||
def forward(self, y):
|
||||
y_d_rs = []
|
||||
fmap_rs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
if i != 0:
|
||||
if self.aux_convs is None:
|
||||
y = self.meanpools[i - 1](y)
|
||||
else:
|
||||
yl, yh = self.meanpools[i - 1](y)
|
||||
y = torch.cat([yl, yh[0]], dim=1)
|
||||
y = self.aux_convs[i - 1](y)
|
||||
y = F.leaky_relu(y, 0.1)
|
||||
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
|
||||
return y_d_rs, fmap_rs
|
||||
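# Descriptive note on MultiScaleDiscriminator above: discriminator 0 sees
# the raw waveform; later scales see a downsampled version, either from
# average pooling or, when downsample_pooling == 'DWT', from a 1-level db3
# wavelet transform whose low/high bands are concatenated and mixed back
# to one channel by a weight-normed 15-tap conv, followed by a leaky ReLU.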
|
||||
|
||||
class SpecDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels=32,
|
||||
init_kernel=15,
|
||||
kernel_size=11,
|
||||
stride=2,
|
||||
use_spectral_norm=False,
|
||||
fft_size=1024,
|
||||
shift_size=120,
|
||||
win_length=600,
|
||||
window='hann_window',
|
||||
nonlinear_activation='LeakyReLU',
|
||||
nonlinear_activation_params={'negative_slope': 0.1},
|
||||
):
|
||||
super(SpecDiscriminator, self).__init__()
|
||||
self.fft_size = fft_size
|
||||
self.shift_size = shift_size
|
||||
self.win_length = win_length
|
||||
# fft_size // 2 + 1
|
||||
norm_f = weight_norm if not use_spectral_norm else spectral_norm
|
||||
final_kernel = 5
|
||||
post_conv_kernel = 3
|
||||
blocks = 3
|
||||
self.convs = nn.ModuleList()
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
fft_size // 2 + 1,
|
||||
channels,
|
||||
(init_kernel, 1),
|
||||
(1, 1),
|
||||
padding=(init_kernel - 1) // 2,
|
||||
)),
|
||||
getattr(torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
|
||||
for i in range(blocks):
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)),
|
||||
getattr(
|
||||
torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
(final_kernel, 1),
|
||||
(1, 1),
|
||||
padding=(final_kernel - 1) // 2,
|
||||
)),
|
||||
getattr(torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
|
||||
self.conv_post = norm_f(
|
||||
nn.Conv2d(
|
||||
channels,
|
||||
1,
|
||||
(post_conv_kernel, 1),
|
||||
(1, 1),
|
||||
padding=((post_conv_kernel - 1) // 2, 0),
|
||||
))
|
||||
self.register_buffer('window', getattr(torch, window)(win_length))
|
||||
|
||||
def forward(self, wav):
|
||||
with torch.no_grad():
|
||||
wav = torch.squeeze(wav, 1)
|
||||
x_mag = stft(wav, self.fft_size, self.shift_size, self.win_length,
|
||||
self.window)
|
||||
x = torch.transpose(x_mag, 2, 1).unsqueeze(-1)
|
||||
fmap = []
|
||||
for layer in self.convs:
|
||||
x = layer(x)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = x.squeeze(-1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiSpecDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fft_sizes=[1024, 2048, 512],
|
||||
hop_sizes=[120, 240, 50],
|
||||
win_lengths=[600, 1200, 240],
|
||||
discriminator_params={
|
||||
'channels': 15,
|
||||
'init_kernel': 1,
|
||||
'kernel_size': 11,
|
||||
'stride': 2,
|
||||
'use_spectral_norm': False,
|
||||
'window': 'hann_window',
|
||||
'nonlinear_activation': 'LeakyReLU',
|
||||
'nonlinear_activation_params': {
|
||||
'negative_slope': 0.1
|
||||
},
|
||||
},
|
||||
):
|
||||
super(MultiSpecDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList()
|
||||
for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes,
|
||||
win_lengths):
|
||||
params = copy.deepcopy(discriminator_params)
|
||||
params['fft_size'] = fft_size
|
||||
params['shift_size'] = hop_size
|
||||
params['win_length'] = win_length
|
||||
self.discriminators += [SpecDiscriminator(**params)]
|
||||
|
||||
def forward(self, y):
|
||||
y_d = []
|
||||
fmap = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
x, x_map = d(y)
|
||||
y_d.append(x)
|
||||
fmap.append(x_map)
|
||||
|
||||
return y_d, fmap
|
||||
@@ -1,288 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions.normal import Normal
|
||||
from torch.distributions.uniform import Uniform
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
|
||||
from modelscope.models.audio.tts.kantts.models.utils import init_weights
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
class Conv1d(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode='zeros',
|
||||
):
|
||||
super(Conv1d, self).__init__()
|
||||
self.conv1d = weight_norm(
|
||||
nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
))
|
||||
self.conv1d.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1d(x)
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_weight_norm(self.conv1d)
|
||||
|
||||
|
||||
class CausalConv1d(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode='zeros',
|
||||
):
|
||||
super(CausalConv1d, self).__init__()
|
||||
self.pad = (kernel_size - 1) * dilation
|
||||
self.conv1d = weight_norm(
|
||||
nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
))
|
||||
self.conv1d.apply(init_weights)
|
||||
|
||||
def forward(self, x): # bdt
|
||||
x = F.pad(
|
||||
x, (self.pad, 0, 0, 0, 0, 0), 'constant'
|
||||
) # pad widths are given starting from the last dimension and moving forward.
|
||||
# x = F.pad(x, (self.pad, self.pad, 0, 0, 0, 0), "constant")
|
||||
x = self.conv1d(x)[:, :, :x.size(2)]
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_weight_norm(self.conv1d)
|
||||
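# Descriptive note on CausalConv1d above: it pads (kernel_size - 1) * dilation
# zeros on the left only, so an output sample at time t never depends on
# inputs after t; with the conv itself using padding=0, the output keeps
# the original sequence length.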
|
||||
|
||||
class ConvTranspose1d(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=0,
|
||||
output_padding=0,
|
||||
):
|
||||
super(ConvTranspose1d, self).__init__()
|
||||
self.deconv = weight_norm(
|
||||
nn.ConvTranspose1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=padding,
|
||||
output_padding=0,
|
||||
))
|
||||
self.deconv.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
return self.deconv(x)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_weight_norm(self.deconv)
|
||||
|
||||
|
||||
# FIXME: HACK to get shape right
|
||||
class CausalConvTranspose1d(torch.nn.Module):
|
||||
"""CausalConvTranspose1d module with customized initialization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=0,
|
||||
output_padding=0,
|
||||
):
|
||||
"""Initialize CausalConvTranspose1d module."""
|
||||
super(CausalConvTranspose1d, self).__init__()
|
||||
self.deconv = weight_norm(
|
||||
nn.ConvTranspose1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=0,
|
||||
output_padding=0,
|
||||
))
|
||||
self.stride = stride
|
||||
self.deconv.apply(init_weights)
|
||||
self.pad = kernel_size - stride
|
||||
|
||||
def forward(self, x):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, in_channels, T_in).
|
||||
Returns:
|
||||
Tensor: Output tensor (B, out_channels, T_out).
|
||||
"""
|
||||
# x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant")
|
||||
return self.deconv(x)[:, :, :-self.pad]
|
||||
# return self.deconv(x)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_weight_norm(self.deconv)
|
||||
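# Descriptive note on CausalConvTranspose1d above: with padding=0 the
# transposed conv yields T * stride + (kernel_size - stride) samples;
# trimming the trailing (kernel_size - stride) samples (self.pad) brings
# the output back to exactly T * stride, which is the shape fix the FIXME
# above refers to.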
|
||||
|
||||
class ResidualBlock(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
dilation=(1, 3, 5),
|
||||
nonlinear_activation='LeakyReLU',
|
||||
nonlinear_activation_params={'negative_slope': 0.1},
|
||||
causal=False,
|
||||
):
|
||||
super(ResidualBlock, self).__init__()
|
||||
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
||||
conv_cls = CausalConv1d if causal else Conv1d
|
||||
self.convs1 = nn.ModuleList([
|
||||
conv_cls(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[i],
|
||||
padding=get_padding(kernel_size, dilation[i]),
|
||||
) for i in range(len(dilation))
|
||||
])
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
conv_cls(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
) for i in range(len(dilation))
|
||||
])
|
||||
|
||||
self.activation = getattr(
|
||||
torch.nn, nonlinear_activation)(**nonlinear_activation_params)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = self.activation(x)
|
||||
xt = c1(xt)
|
||||
xt = self.activation(xt)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for layer in self.convs1:
|
||||
layer.remove_weight_norm()
|
||||
for layer in self.convs2:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
|
||||
class SourceModule(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
nb_harmonics,
|
||||
upsample_ratio,
|
||||
sampling_rate,
|
||||
alpha=0.1,
|
||||
sigma=0.003):
|
||||
super(SourceModule, self).__init__()
|
||||
|
||||
self.nb_harmonics = nb_harmonics
|
||||
self.upsample_ratio = upsample_ratio
|
||||
self.sampling_rate = sampling_rate
|
||||
self.alpha = alpha
|
||||
self.sigma = sigma
|
||||
|
||||
self.ffn = nn.Sequential(
|
||||
weight_norm(
|
||||
nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, pitch, uv):
|
||||
"""
|
||||
:param pitch: [B, 1, frame_len], Hz
|
||||
:param uv: [B, 1, frame_len] vuv flag
|
||||
:return: [B, 1, sample_len]
|
||||
"""
|
||||
with torch.no_grad():
|
||||
pitch_samples = F.interpolate(
|
||||
pitch, scale_factor=(self.upsample_ratio), mode='nearest')
|
||||
uv_samples = F.interpolate(
|
||||
uv, scale_factor=(self.upsample_ratio), mode='nearest')
|
||||
|
||||
F_mat = torch.zeros(
|
||||
(pitch_samples.size(0), self.nb_harmonics + 1,
|
||||
pitch_samples.size(-1))).to(pitch_samples.device)
|
||||
for i in range(self.nb_harmonics + 1):
|
||||
F_mat[:, i:i
|
||||
+ 1, :] = pitch_samples * (i + 1) / self.sampling_rate
|
||||
|
||||
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
|
||||
u_dist = Uniform(low=-np.pi, high=np.pi)
|
||||
phase_vec = u_dist.sample(
|
||||
sample_shape=(pitch.size(0), self.nb_harmonics + 1,
|
||||
1)).to(F_mat.device)
|
||||
phase_vec[:, 0, :] = 0
|
||||
|
||||
n_dist = Normal(loc=0.0, scale=self.sigma)
|
||||
noise = n_dist.sample(
|
||||
sample_shape=(
|
||||
pitch_samples.size(0),
|
||||
self.nb_harmonics + 1,
|
||||
pitch_samples.size(-1),
|
||||
)).to(F_mat.device)
|
||||
|
||||
e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise
|
||||
e_unvoice = self.alpha / 3 / self.sigma * noise
|
||||
|
||||
e = e_voice * uv_samples + e_unvoice * (1 - uv_samples)
|
||||
|
||||
return self.ffn(e)
|
||||
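# Descriptive note on the excitation computed in forward() above: pitch and
# the voiced flag are first repeated up to sample rate; row i of F_mat then
# holds the normalized frequency of harmonic i, a cumulative sum over time
# gives the phase, and alpha * sin(phase + random offset) + noise forms the
# voiced excitation while scaled noise alone is used for unvoiced samples.
# The 1x1 conv + tanh (self.ffn) mixes the nb_harmonics + 1 rows into one
# channel.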
|
||||
def remove_weight_norm(self):
|
||||
remove_weight_norm(self.ffn[0])
|
||||
@@ -1,133 +0,0 @@
|
||||
# The implementation is adopted from kan-bayashi's ParallelWaveGAN,
|
||||
# made publicly available under the MIT License at https://github.com/kan-bayashi/ParallelWaveGAN
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from scipy.signal import kaiser
|
||||
|
||||
|
||||
def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
|
||||
"""Design prototype filter for PQMF.
|
||||
|
||||
This method is based on `A Kaiser window approach for the design of prototype
|
||||
filters of cosine modulated filterbanks`_.
|
||||
|
||||
Args:
|
||||
taps (int): The number of filter taps.
|
||||
cutoff_ratio (float): Cut-off frequency ratio.
|
||||
beta (float): Beta coefficient for kaiser window.
|
||||
|
||||
Returns:
|
||||
ndarray: Impulse response of prototype filter (taps + 1,).
|
||||
|
||||
.. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
|
||||
https://ieeexplore.ieee.org/abstract/document/681427
|
||||
|
||||
"""
|
||||
# check the arguments are valid
|
||||
assert taps % 2 == 0, 'The number of taps must be an even number.'
|
||||
assert 0.0 < cutoff_ratio < 1.0, 'Cutoff ratio must be > 0.0 and < 1.0.'
|
||||
|
||||
# make initial filter
|
||||
omega_c = np.pi * cutoff_ratio
|
||||
with np.errstate(invalid='ignore'):
|
||||
h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
|
||||
np.pi * (np.arange(taps + 1) - 0.5 * taps))
|
||||
h_i[taps
|
||||
// 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
|
||||
|
||||
# apply kaiser window
|
||||
w = kaiser(taps + 1, beta)
|
||||
h = h_i * w
|
||||
|
||||
return h
|
||||
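# Illustrative note on design_prototype_filter above (assumed default
# values): taps=62, cutoff_ratio=0.142 and beta=9.0 give a 63-tap
# linear-phase lowpass -- an ideal sinc truncated to taps + 1 points and
# shaped by a Kaiser window -- which the PQMF below cosine-modulates into
# the per-subband analysis/synthesis filters.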
|
||||
|
||||
class PQMF(torch.nn.Module):
|
||||
"""PQMF module.
|
||||
|
||||
This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
|
||||
|
||||
.. _`Near-perfect-reconstruction pseudo-QMF banks`:
|
||||
https://ieeexplore.ieee.org/document/258122
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
|
||||
"""Initilize PQMF module.
|
||||
|
||||
The cutoff_ratio and beta parameters are optimized for #subbands = 4.
|
||||
See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.
|
||||
|
||||
Args:
|
||||
subbands (int): The number of subbands.
|
||||
taps (int): The number of filter taps.
|
||||
cutoff_ratio (float): Cut-off frequency ratio.
|
||||
beta (float): Beta coefficient for kaiser window.
|
||||
|
||||
"""
|
||||
super(PQMF, self).__init__()
|
||||
|
||||
# build analysis & synthesis filter coefficients
|
||||
h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
|
||||
h_analysis = np.zeros((subbands, len(h_proto)))
|
||||
h_synthesis = np.zeros((subbands, len(h_proto)))
|
||||
for k in range(subbands):
|
||||
h_analysis[k] = (
|
||||
2 * h_proto * np.cos((2 * k + 1) * # noqa W504
|
||||
(np.pi / (2 * subbands)) * # noqa W504
|
||||
(np.arange(taps + 1) - (taps / 2))
|
||||
+ (-1)**k * np.pi / 4))
|
||||
h_synthesis[k] = (
|
||||
2 * h_proto * np.cos((2 * k + 1) * # noqa W504
|
||||
(np.pi / (2 * subbands)) * # noqa W504
|
||||
(np.arange(taps + 1) - (taps / 2))
|
||||
- (-1)**k * np.pi / 4))
|
||||
|
||||
# convert to tensor
|
||||
analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
|
||||
synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
|
||||
|
||||
# register coefficients as buffer
|
||||
self.register_buffer('analysis_filter', analysis_filter)
|
||||
self.register_buffer('synthesis_filter', synthesis_filter)
|
||||
|
||||
# filter for downsampling & upsampling
|
||||
updown_filter = torch.zeros((subbands, subbands, subbands)).float()
|
||||
for k in range(subbands):
|
||||
updown_filter[k, k, 0] = 1.0
|
||||
self.register_buffer('updown_filter', updown_filter)
|
||||
self.subbands = subbands
|
||||
|
||||
# keep padding info
|
||||
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
|
||||
|
||||
def analysis(self, x):
|
||||
"""Analysis with PQMF.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, 1, T).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, subbands, T // subbands).
|
||||
|
||||
"""
|
||||
x = F.conv1d(self.pad_fn(x), self.analysis_filter)
|
||||
return F.conv1d(x, self.updown_filter, stride=self.subbands)
|
||||
|
||||
def synthesis(self, x):
|
||||
"""Synthesis with PQMF.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, subbands, T // subbands).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, 1, T).
|
||||
|
||||
"""
|
||||
# NOTE(kan-bayashi): Power will be decreased, so multiply by # subbands here.
|
||||
# Not sure this is the correct way, it is better to check again.
|
||||
x = F.conv_transpose1d(
|
||||
x, self.updown_filter * self.subbands, stride=self.subbands)
|
||||
return F.conv1d(self.pad_fn(x), self.synthesis_filter)
|
||||
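# Usage sketch for the PQMF module above (illustrative, assuming the
# default 4 subbands):
#   pqmf = PQMF()
#   x = torch.randn(1, 1, 1024)      # (B, 1, T)
#   subband = pqmf.analysis(x)       # (B, 4, T // 4)
#   recon = pqmf.synthesis(subband)  # (B, 1, T), near-perfect reconstruction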
@@ -1,372 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ScaledDotProductAttention(nn.Module):
|
||||
""" Scaled Dot-Product Attention """
|
||||
|
||||
def __init__(self, temperature, dropatt=0.0):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
self.softmax = nn.Softmax(dim=2)
|
||||
self.dropatt = nn.Dropout(dropatt)
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
|
||||
attn = torch.bmm(q, k.transpose(1, 2))
|
||||
attn = attn / self.temperature
|
||||
|
||||
if mask is not None:
|
||||
attn = attn.masked_fill(mask, -np.inf)
|
||||
|
||||
attn = self.softmax(attn)
|
||||
attn = self.dropatt(attn)
|
||||
output = torch.bmm(attn, v)
|
||||
|
||||
return output, attn
|
||||
|
||||
|
||||
class Prenet(nn.Module):
|
||||
|
||||
def __init__(self, in_units, prenet_units, out_units=0):
|
||||
super(Prenet, self).__init__()
|
||||
|
||||
self.fcs = nn.ModuleList()
|
||||
for in_dim, out_dim in zip([in_units] + prenet_units[:-1],
|
||||
prenet_units):
|
||||
self.fcs.append(nn.Linear(in_dim, out_dim))
|
||||
self.fcs.append(nn.ReLU())
|
||||
self.fcs.append(nn.Dropout(0.5))
|
||||
|
||||
if out_units:
|
||||
self.fcs.append(nn.Linear(prenet_units[-1], out_units))
|
||||
|
||||
def forward(self, input):
|
||||
output = input
|
||||
for layer in self.fcs:
|
||||
output = layer(output)
|
||||
return output
|
||||
|
||||
|
||||
class MultiHeadSelfAttention(nn.Module):
|
||||
""" Multi-Head SelfAttention module """
|
||||
|
||||
def __init__(self, n_head, d_in, d_model, d_head, dropout, dropatt=0.0):
|
||||
super().__init__()
|
||||
|
||||
self.n_head = n_head
|
||||
self.d_head = d_head
|
||||
self.d_in = d_in
|
||||
self.d_model = d_model
|
||||
|
||||
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
|
||||
self.w_qkv = nn.Linear(d_in, 3 * n_head * d_head)
|
||||
|
||||
self.attention = ScaledDotProductAttention(
|
||||
temperature=np.power(d_head, 0.5), dropatt=dropatt)
|
||||
|
||||
self.fc = nn.Linear(n_head * d_head, d_model)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, input, mask=None):
|
||||
d_head, n_head = self.d_head, self.n_head
|
||||
|
||||
sz_b, len_in, _ = input.size()
|
||||
|
||||
residual = input
|
||||
|
||||
x = self.layer_norm(input)
|
||||
qkv = self.w_qkv(x)
|
||||
q, k, v = qkv.chunk(3, -1)
|
||||
|
||||
q = q.view(sz_b, len_in, n_head, d_head)
|
||||
k = k.view(sz_b, len_in, n_head, d_head)
|
||||
v = v.view(sz_b, len_in, n_head, d_head)
|
||||
|
||||
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_in,
|
||||
d_head) # (n*b) x l x d
|
||||
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_in,
|
||||
d_head) # (n*b) x l x d
|
||||
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_in,
|
||||
d_head) # (n*b) x l x d
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
|
||||
output, attn = self.attention(q, k, v, mask=mask)
|
||||
|
||||
output = output.view(n_head, sz_b, len_in, d_head)
|
||||
output = (output.permute(1, 2, 0,
|
||||
3).contiguous().view(sz_b, len_in,
|
||||
-1)) # b x l x (n*d)
|
||||
|
||||
output = self.dropout(self.fc(output))
|
||||
if output.size(-1) == residual.size(-1):
|
||||
output = output + residual
|
||||
|
||||
return output, attn
|
||||
|
||||
|
||||
class PositionwiseConvFeedForward(nn.Module):
|
||||
""" A two-feed-forward-layer module """
|
||||
|
||||
def __init__(self,
|
||||
d_in,
|
||||
d_hid,
|
||||
kernel_size=(3, 1),
|
||||
dropout_inner=0.1,
|
||||
dropout=0.1):
|
||||
super().__init__()
|
||||
# Use Conv1D
|
||||
# position-wise
|
||||
self.w_1 = nn.Conv1d(
|
||||
d_in,
|
||||
d_hid,
|
||||
kernel_size=kernel_size[0],
|
||||
padding=(kernel_size[0] - 1) // 2,
|
||||
)
|
||||
# position-wise
|
||||
self.w_2 = nn.Conv1d(
|
||||
d_hid,
|
||||
d_in,
|
||||
kernel_size=kernel_size[1],
|
||||
padding=(kernel_size[1] - 1) // 2,
|
||||
)
|
||||
|
||||
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
|
||||
self.dropout_inner = nn.Dropout(dropout_inner)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
residual = x
|
||||
x = self.layer_norm(x)
|
||||
|
||||
output = x.transpose(1, 2)
|
||||
output = F.relu(self.w_1(output))
|
||||
if mask is not None:
|
||||
output = output.masked_fill(mask.unsqueeze(1), 0)
|
||||
output = self.dropout_inner(output)
|
||||
output = self.w_2(output)
|
||||
output = output.transpose(1, 2)
|
||||
output = self.dropout(output)
|
||||
|
||||
output = output + residual
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FFTBlock(nn.Module):
|
||||
"""FFT Block"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_in,
|
||||
d_model,
|
||||
n_head,
|
||||
d_head,
|
||||
d_inner,
|
||||
kernel_size,
|
||||
dropout,
|
||||
dropout_attn=0.0,
|
||||
dropout_relu=0.0,
|
||||
):
|
||||
super(FFTBlock, self).__init__()
|
||||
self.slf_attn = MultiHeadSelfAttention(
|
||||
n_head,
|
||||
d_in,
|
||||
d_model,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
dropatt=dropout_attn)
|
||||
self.pos_ffn = PositionwiseConvFeedForward(
|
||||
d_model,
|
||||
d_inner,
|
||||
kernel_size,
|
||||
dropout_inner=dropout_relu,
|
||||
dropout=dropout)
|
||||
|
||||
def forward(self, input, mask=None, slf_attn_mask=None):
|
||||
output, slf_attn = self.slf_attn(input, mask=slf_attn_mask)
|
||||
if mask is not None:
|
||||
output = output.masked_fill(mask.unsqueeze(-1), 0)
|
||||
|
||||
output = self.pos_ffn(output, mask=mask)
|
||||
if mask is not None:
|
||||
output = output.masked_fill(mask.unsqueeze(-1), 0)
|
||||
|
||||
return output, slf_attn
|
||||
|
||||
|
||||
class MultiHeadPNCAAttention(nn.Module):
|
||||
""" Multi-Head Attention PNCA module """
|
||||
|
||||
def __init__(self, n_head, d_model, d_mem, d_head, dropout, dropatt=0.0):
|
||||
super().__init__()
|
||||
|
||||
self.n_head = n_head
|
||||
self.d_head = d_head
|
||||
self.d_model = d_model
|
||||
self.d_mem = d_mem
|
||||
|
||||
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
||||
|
||||
self.w_x_qkv = nn.Linear(d_model, 3 * n_head * d_head)
|
||||
self.fc_x = nn.Linear(n_head * d_head, d_model)
|
||||
|
||||
self.w_h_kv = nn.Linear(d_mem, 2 * n_head * d_head)
|
||||
self.fc_h = nn.Linear(n_head * d_head, d_model)
|
||||
|
||||
self.attention = ScaledDotProductAttention(
|
||||
temperature=np.power(d_head, 0.5), dropatt=dropatt)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def update_x_state(self, x):
|
||||
d_head, n_head = self.d_head, self.n_head
|
||||
|
||||
sz_b, len_x, _ = x.size()
|
||||
|
||||
x_qkv = self.w_x_qkv(x)
|
||||
x_q, x_k, x_v = x_qkv.chunk(3, -1)
|
||||
|
||||
x_q = x_q.view(sz_b, len_x, n_head, d_head)
|
||||
x_k = x_k.view(sz_b, len_x, n_head, d_head)
|
||||
x_v = x_v.view(sz_b, len_x, n_head, d_head)
|
||||
|
||||
x_q = x_q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
|
||||
x_k = x_k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
|
||||
x_v = x_v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
|
||||
|
||||
if self.x_state_size:
|
||||
self.x_k = torch.cat([self.x_k, x_k], dim=1)
|
||||
self.x_v = torch.cat([self.x_v, x_v], dim=1)
|
||||
else:
|
||||
self.x_k = x_k
|
||||
self.x_v = x_v
|
||||
|
||||
self.x_state_size += len_x
|
||||
|
||||
return x_q, x_k, x_v
|
||||
|
||||
def update_h_state(self, h):
|
||||
if self.h_state_size == h.size(1):
|
||||
return None, None
|
||||
|
||||
d_head, n_head = self.d_head, self.n_head
|
||||
|
||||
# H
|
||||
sz_b, len_h, _ = h.size()
|
||||
|
||||
h_kv = self.w_h_kv(h)
|
||||
h_k, h_v = h_kv.chunk(2, -1)
|
||||
|
||||
h_k = h_k.view(sz_b, len_h, n_head, d_head)
|
||||
h_v = h_v.view(sz_b, len_h, n_head, d_head)
|
||||
|
||||
self.h_k = h_k.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head)
|
||||
self.h_v = h_v.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head)
|
||||
|
||||
self.h_state_size += len_h
|
||||
|
||||
return h_k, h_v
|
||||
|
||||
def reset_state(self):
|
||||
self.h_k = None
|
||||
self.h_v = None
|
||||
self.h_state_size = 0
|
||||
self.x_k = None
|
||||
self.x_v = None
|
||||
self.x_state_size = 0
|
||||
|
||||
def forward(self, x, h, mask_x=None, mask_h=None):
|
||||
residual = x
|
||||
self.update_h_state(h)
|
||||
x_q, x_k, x_v = self.update_x_state(self.layer_norm(x))
|
||||
|
||||
d_head, n_head = self.d_head, self.n_head
|
||||
|
||||
sz_b, len_in, _ = x.size()
|
||||
|
||||
# X
|
||||
if mask_x is not None:
|
||||
mask_x = mask_x.repeat(n_head, 1, 1) # (n*b) x .. x ..
|
||||
output_x, attn_x = self.attention(x_q, self.x_k, self.x_v, mask=mask_x)
|
||||
|
||||
output_x = output_x.view(n_head, sz_b, len_in, d_head)
|
||||
output_x = (output_x.permute(1, 2, 0,
|
||||
3).contiguous().view(sz_b, len_in,
|
||||
-1)) # b x l x (n*d)
|
||||
output_x = self.fc_x(output_x)
|
||||
|
||||
# H
|
||||
if mask_h is not None:
|
||||
mask_h = mask_h.repeat(n_head, 1, 1)
|
||||
output_h, attn_h = self.attention(x_q, self.h_k, self.h_v, mask=mask_h)
|
||||
|
||||
output_h = output_h.view(n_head, sz_b, len_in, d_head)
|
||||
output_h = (output_h.permute(1, 2, 0,
|
||||
3).contiguous().view(sz_b, len_in,
|
||||
-1)) # b x l x (n*d)
|
||||
output_h = self.fc_h(output_h)
|
||||
|
||||
output = output_x + output_h
|
||||
|
||||
output = self.dropout(output)
|
||||
|
||||
output = output + residual
|
||||
|
||||
return output, attn_x, attn_h
|
||||
|
||||
|
||||
class PNCABlock(nn.Module):
|
||||
"""PNCA Block"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
d_mem,
|
||||
n_head,
|
||||
d_head,
|
||||
d_inner,
|
||||
kernel_size,
|
||||
dropout,
|
||||
dropout_attn=0.0,
|
||||
dropout_relu=0.0,
|
||||
):
|
||||
super(PNCABlock, self).__init__()
|
||||
self.pnca_attn = MultiHeadPNCAAttention(
|
||||
n_head,
|
||||
d_model,
|
||||
d_mem,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
dropatt=dropout_attn)
|
||||
self.pos_ffn = PositionwiseConvFeedForward(
|
||||
d_model,
|
||||
d_inner,
|
||||
kernel_size,
|
||||
dropout_inner=dropout_relu,
|
||||
dropout=dropout)
|
||||
|
||||
def forward(self,
|
||||
input,
|
||||
memory,
|
||||
mask=None,
|
||||
pnca_x_attn_mask=None,
|
||||
pnca_h_attn_mask=None):
|
||||
output, pnca_attn_x, pnca_attn_h = self.pnca_attn(
|
||||
input, memory, pnca_x_attn_mask, pnca_h_attn_mask)
|
||||
if mask is not None:
|
||||
output = output.masked_fill(mask.unsqueeze(-1), 0)
|
||||
|
||||
output = self.pos_ffn(output, mask=mask)
|
||||
if mask is not None:
|
||||
output = output.masked_fill(mask.unsqueeze(-1), 0)
|
||||
|
||||
return output, pnca_attn_x, pnca_attn_h
|
||||
|
||||
def reset_state(self):
|
||||
self.pnca_attn.reset_state()
|
||||
@@ -1,147 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from . import Prenet
|
||||
from .fsmn import FsmnEncoderV2
|
||||
|
||||
|
||||
class LengthRegulator(nn.Module):
|
||||
|
||||
def __init__(self, r=1):
|
||||
super(LengthRegulator, self).__init__()
|
||||
|
||||
self.r = r
|
||||
|
||||
def forward(self, inputs, durations, masks=None):
|
||||
reps = (durations + 0.5).long()
|
||||
output_lens = reps.sum(dim=1)
|
||||
max_len = output_lens.max()
|
||||
reps_cumsum = torch.cumsum(
|
||||
F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
|
||||
range_ = torch.arange(max_len).to(inputs.device)[None, :, None]
|
||||
mult = (reps_cumsum[:, :, :-1] <= range_) & (
|
||||
reps_cumsum[:, :, 1:] > range_)
|
||||
mult = mult.float()
|
||||
out = torch.matmul(mult, inputs)
|
||||
|
||||
if masks is not None:
|
||||
out = out.masked_fill(masks.unsqueeze(-1), 0.0)
|
||||
|
||||
seq_len = out.size(1)
|
||||
padding = self.r - int(seq_len) % self.r
|
||||
if padding < self.r:
|
||||
out = F.pad(
|
||||
out.transpose(1, 2), (0, padding, 0, 0, 0, 0), value=0.0)
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
return out, output_lens
|
||||
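# Illustrative note on LengthRegulator.forward above: durations are rounded
# to frame counts and each encoder step is repeated that many times via a
# 0/1 selection matrix, e.g. durations [2, 1, 3] map input frames
# [a, b, c] to [a, a, b, c, c, c]; the result is then right-padded to a
# multiple of r (outputs per step).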
|
||||
|
||||
class VarRnnARPredictor(nn.Module):
|
||||
|
||||
def __init__(self, cond_units, prenet_units, rnn_units):
|
||||
super(VarRnnARPredictor, self).__init__()
|
||||
|
||||
self.prenet = Prenet(1, prenet_units)
|
||||
self.lstm = nn.LSTM(
|
||||
prenet_units[-1] + cond_units,
|
||||
rnn_units,
|
||||
num_layers=2,
|
||||
batch_first=True,
|
||||
bidirectional=False,
|
||||
)
|
||||
self.fc = nn.Linear(rnn_units, 1)
|
||||
|
||||
def forward(self, inputs, cond, h=None, masks=None):
|
||||
x = torch.cat([self.prenet(inputs), cond], dim=-1)
|
||||
# The input can also be a packed variable length sequence,
|
||||
# here we simply skip packing, which is fine given the mask and the uni-directional LSTM.
|
||||
x, h_new = self.lstm(x, h)
|
||||
|
||||
x = self.fc(x).squeeze(-1)
|
||||
x = F.relu(x)
|
||||
|
||||
if masks is not None:
|
||||
x = x.masked_fill(masks, 0.0)
|
||||
|
||||
return x, h_new
|
||||
|
||||
def infer(self, cond, masks=None):
|
||||
batch_size, length = cond.size(0), cond.size(1)
|
||||
|
||||
output = []
|
||||
x = torch.zeros((batch_size, 1)).to(cond.device)
|
||||
h = None
|
||||
|
||||
for i in range(length):
|
||||
x, h = self.forward(x.unsqueeze(1), cond[:, i:i + 1, :], h=h)
|
||||
output.append(x)
|
||||
|
||||
output = torch.cat(output, dim=-1)
|
||||
|
||||
if masks is not None:
|
||||
output = output.masked_fill(masks, 0.0)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class VarFsmnRnnNARPredictor(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim,
|
||||
filter_size,
|
||||
fsmn_num_layers,
|
||||
num_memory_units,
|
||||
ffn_inner_dim,
|
||||
dropout,
|
||||
shift,
|
||||
lstm_units,
|
||||
):
|
||||
super(VarFsmnRnnNARPredictor, self).__init__()
|
||||
|
||||
self.fsmn = FsmnEncoderV2(
|
||||
filter_size,
|
||||
fsmn_num_layers,
|
||||
in_dim,
|
||||
num_memory_units,
|
||||
ffn_inner_dim,
|
||||
dropout,
|
||||
shift,
|
||||
)
|
||||
self.blstm = nn.LSTM(
|
||||
num_memory_units,
|
||||
lstm_units,
|
||||
num_layers=1,
|
||||
batch_first=True,
|
||||
bidirectional=True,
|
||||
)
|
||||
self.fc = nn.Linear(2 * lstm_units, 1)
|
||||
|
||||
def forward(self, inputs, masks=None):
|
||||
input_lengths = None
|
||||
if masks is not None:
|
||||
input_lengths = torch.sum((~masks).float(), dim=1).long()
|
||||
|
||||
x = self.fsmn(inputs, masks)
|
||||
|
||||
if input_lengths is not None:
|
||||
x = nn.utils.rnn.pack_padded_sequence(
|
||||
x,
|
||||
input_lengths.tolist(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False)
|
||||
x, _ = self.blstm(x)
|
||||
x, _ = nn.utils.rnn.pad_packed_sequence(
|
||||
x, batch_first=True, total_length=inputs.size(1))
|
||||
else:
|
||||
x, _ = self.blstm(x)
|
||||
|
||||
x = self.fc(x).squeeze(-1)
|
||||
|
||||
if masks is not None:
|
||||
x = x.masked_fill(masks, 0.0)
|
||||
|
||||
return x
|
||||
@@ -1,73 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numba as nb
|
||||
import numpy as np
|
||||
|
||||
|
||||
@nb.jit(nopython=True)
|
||||
def mas(attn_map, width=1):
|
||||
# assumes mel x text
|
||||
opt = np.zeros_like(attn_map)
|
||||
attn_map = np.log(attn_map)
|
||||
attn_map[0, 1:] = -np.inf
|
||||
log_p = np.zeros_like(attn_map)
|
||||
log_p[0, :] = attn_map[0, :]
|
||||
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
|
||||
for i in range(1, attn_map.shape[0]):
|
||||
for j in range(attn_map.shape[1]): # for each text dim
|
||||
prev_j = np.arange(max(0, j - width), j + 1)
|
||||
prev_log = np.array(
|
||||
[log_p[i - 1, prev_idx] for prev_idx in prev_j])
|
||||
|
||||
ind = np.argmax(prev_log)
|
||||
log_p[i, j] = attn_map[i, j] + prev_log[ind]
|
||||
prev_ind[i, j] = prev_j[ind]
|
||||
|
||||
# now backtrack
|
||||
curr_text_idx = attn_map.shape[1] - 1
|
||||
for i in range(attn_map.shape[0] - 1, -1, -1):
|
||||
opt[i, curr_text_idx] = 1
|
||||
curr_text_idx = prev_ind[i, curr_text_idx]
|
||||
opt[0, curr_text_idx] = 1
|
||||
return opt
|
||||
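# Descriptive note on mas() above: it is a monotonic alignment search -- a
# Viterbi-style dynamic program over the log attention map where, at each
# mel frame, the path either stays on the current text token or advances
# by at most `width` tokens, then backtracks from the last token to
# produce a hard 0/1 alignment. mas_width1 below is the same algorithm
# hard-coded for width=1.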
|
||||
|
||||
@nb.jit(nopython=True)
|
||||
def mas_width1(attn_map):
|
||||
"""mas with hardcoded width=1"""
|
||||
# assumes mel x text
|
||||
opt = np.zeros_like(attn_map)
|
||||
attn_map = np.log(attn_map)
|
||||
attn_map[0, 1:] = -np.inf
|
||||
log_p = np.zeros_like(attn_map)
|
||||
log_p[0, :] = attn_map[0, :]
|
||||
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
|
||||
for i in range(1, attn_map.shape[0]):
|
||||
for j in range(attn_map.shape[1]): # for each text dim
|
||||
prev_log = log_p[i - 1, j]
|
||||
prev_j = j
|
||||
|
||||
if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
|
||||
prev_log = log_p[i - 1, j - 1]
|
||||
prev_j = j - 1
|
||||
|
||||
log_p[i, j] = attn_map[i, j] + prev_log
|
||||
prev_ind[i, j] = prev_j
|
||||
|
||||
# now backtrack
|
||||
curr_text_idx = attn_map.shape[1] - 1
|
||||
for i in range(attn_map.shape[0] - 1, -1, -1):
|
||||
opt[i, curr_text_idx] = 1
|
||||
curr_text_idx = prev_ind[i, curr_text_idx]
|
||||
opt[0, curr_text_idx] = 1
|
||||
return opt
|
||||
|
||||
|
||||
@nb.jit(nopython=True, parallel=True)
|
||||
def b_mas(b_attn_map, in_lens, out_lens, width=1):
|
||||
assert width == 1
|
||||
attn_out = np.zeros_like(b_attn_map)
|
||||
|
||||
for b in nb.prange(b_attn_map.shape[0]):
|
||||
out = mas_width1(b_attn_map[b, 0, :out_lens[b], :in_lens[b]])
|
||||
attn_out[b, 0, :out_lens[b], :in_lens[b]] = out
|
||||
return attn_out
|
||||
@@ -1,131 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ConvNorm(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=None,
|
||||
dilation=1,
|
||||
bias=True,
|
||||
w_init_gain='linear',
|
||||
):
|
||||
super(ConvNorm, self).__init__()
|
||||
if padding is None:
|
||||
assert kernel_size % 2 == 1
|
||||
padding = int(dilation * (kernel_size - 1) / 2)
|
||||
|
||||
self.conv = torch.nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
torch.nn.init.xavier_uniform_(
|
||||
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
|
||||
|
||||
def forward(self, signal):
|
||||
conv_signal = self.conv(signal)
|
||||
return conv_signal
|
||||
|
||||
|
||||
class ConvAttention(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_mel_channels=80,
|
||||
n_text_channels=512,
|
||||
n_att_channels=80,
|
||||
temperature=1.0,
|
||||
use_query_proj=True,
|
||||
):
|
||||
super(ConvAttention, self).__init__()
|
||||
self.temperature = temperature
|
||||
self.att_scaling_factor = np.sqrt(n_att_channels)
|
||||
self.softmax = torch.nn.Softmax(dim=3)
|
||||
self.log_softmax = torch.nn.LogSoftmax(dim=3)
|
||||
self.attn_proj = torch.nn.Conv2d(n_att_channels, 1, kernel_size=1)
|
||||
self.use_query_proj = bool(use_query_proj)
|
||||
|
||||
self.key_proj = nn.Sequential(
|
||||
ConvNorm(
|
||||
n_text_channels,
|
||||
n_text_channels * 2,
|
||||
kernel_size=3,
|
||||
bias=True,
|
||||
w_init_gain='relu',
|
||||
),
|
||||
torch.nn.ReLU(),
|
||||
ConvNorm(
|
||||
n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
|
||||
)
|
||||
|
||||
self.query_proj = nn.Sequential(
|
||||
ConvNorm(
|
||||
n_mel_channels,
|
||||
n_mel_channels * 2,
|
||||
kernel_size=3,
|
||||
bias=True,
|
||||
w_init_gain='relu',
|
||||
),
|
||||
torch.nn.ReLU(),
|
||||
ConvNorm(
|
||||
n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
|
||||
torch.nn.ReLU(),
|
||||
ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, queries, keys, mask=None, attn_prior=None):
|
||||
"""Attention mechanism for flowtron parallel
|
||||
Unlike in Flowtron, we have no restrictions such as causality etc,
|
||||
since we only need this during training.
|
||||
|
||||
Args:
|
||||
queries (torch.tensor): B x C x T1 tensor
|
||||
(probably going to be mel data)
|
||||
keys (torch.tensor): B x C2 x T2 tensor (text data)
|
||||
mask (torch.tensor): uint8 binary mask for variable length entries
|
||||
(should be in the T2 domain)
|
||||
Output:
|
||||
attn (torch.tensor): B x 1 x T1 x T2 attention mask.
|
||||
Final dim T2 should sum to 1
|
||||
"""
|
||||
keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
|
||||
|
||||
# Beware can only do this since query_dim = attn_dim = n_mel_channels
|
||||
if self.use_query_proj:
|
||||
queries_enc = self.query_proj(queries)
|
||||
else:
|
||||
queries_enc = queries
|
||||
|
||||
# different ways of computing attn,
|
||||
# one is isotropic gaussians (per phoneme)
|
||||
# Simplistic Gaussian Isotropic Attention
|
||||
|
||||
# B x n_attn_dims x T1 x T2
|
||||
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None])**2
|
||||
# compute log likelihood from a gaussian
|
||||
attn = -0.0005 * attn.sum(1, keepdim=True)
|
||||
if attn_prior is not None:
|
||||
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None]
|
||||
+ 1e-8)
|
||||
|
||||
attn_logprob = attn.clone()
|
||||
|
||||
if mask is not None:
|
||||
attn.data.masked_fill_(
|
||||
mask.unsqueeze(1).unsqueeze(1), -float('inf'))
|
||||
|
||||
attn = self.softmax(attn) # Softmax along T2
|
||||
return attn, attn_logprob
|
||||
@@ -1,127 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FeedForwardNet(nn.Module):
|
||||
""" A two-feed-forward-layer module """
|
||||
|
||||
def __init__(self, d_in, d_hid, d_out, kernel_size=[1, 1], dropout=0.1):
|
||||
super().__init__()
|
||||
|
||||
# Use Conv1D
|
||||
# position-wise
|
||||
self.w_1 = nn.Conv1d(
|
||||
d_in,
|
||||
d_hid,
|
||||
kernel_size=kernel_size[0],
|
||||
padding=(kernel_size[0] - 1) // 2,
|
||||
)
|
||||
# position-wise
|
||||
self.w_2 = nn.Conv1d(
|
||||
d_hid,
|
||||
d_out,
|
||||
kernel_size=kernel_size[1],
|
||||
padding=(kernel_size[1] - 1) // 2,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
output = x.transpose(1, 2)
|
||||
output = F.relu(self.w_1(output))
|
||||
output = self.dropout(output)
|
||||
output = self.w_2(output)
|
||||
output = output.transpose(1, 2)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class MemoryBlockV2(nn.Module):
|
||||
|
||||
def __init__(self, d, filter_size, shift, dropout=0.0):
|
||||
super(MemoryBlockV2, self).__init__()
|
||||
|
||||
left_padding = int(round((filter_size - 1) / 2))
|
||||
right_padding = int((filter_size - 1) / 2)
|
||||
if shift > 0:
|
||||
left_padding += shift
|
||||
right_padding -= shift
|
||||
|
||||
self.lp, self.rp = left_padding, right_padding
|
||||
|
||||
self.conv_dw = nn.Conv1d(d, d, filter_size, 1, 0, groups=d, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, input, mask=None):
|
||||
if mask is not None:
|
||||
input = input.masked_fill(mask.unsqueeze(-1), 0)
|
||||
|
||||
x = F.pad(
|
||||
input, (0, 0, self.lp, self.rp, 0, 0), mode='constant', value=0.0)
|
||||
output = (
|
||||
self.conv_dw(x.contiguous().transpose(1,
|
||||
2)).contiguous().transpose(
|
||||
1, 2))
|
||||
output += input
|
||||
output = self.dropout(output)
|
||||
|
||||
if mask is not None:
|
||||
output = output.masked_fill(mask.unsqueeze(-1), 0)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FsmnEncoderV2(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter_size,
|
||||
fsmn_num_layers,
|
||||
input_dim,
|
||||
num_memory_units,
|
||||
ffn_inner_dim,
|
||||
dropout=0.0,
|
||||
shift=0,
|
||||
):
|
||||
super(FsmnEncoderV2, self).__init__()
|
||||
|
||||
self.filter_size = filter_size
|
||||
self.fsmn_num_layers = fsmn_num_layers
|
||||
self.num_memory_units = num_memory_units
|
||||
self.ffn_inner_dim = ffn_inner_dim
|
||||
self.dropout = dropout
|
||||
self.shift = shift
|
||||
if not isinstance(shift, list):
|
||||
self.shift = [shift for _ in range(self.fsmn_num_layers)]
|
||||
|
||||
self.ffn_lst = nn.ModuleList()
|
||||
self.ffn_lst.append(
|
||||
FeedForwardNet(
|
||||
input_dim, ffn_inner_dim, num_memory_units, dropout=dropout))
|
||||
for i in range(1, fsmn_num_layers):
|
||||
self.ffn_lst.append(
|
||||
FeedForwardNet(
|
||||
num_memory_units,
|
||||
ffn_inner_dim,
|
||||
num_memory_units,
|
||||
dropout=dropout))
|
||||
|
||||
self.memory_block_lst = nn.ModuleList()
|
||||
for i in range(fsmn_num_layers):
|
||||
self.memory_block_lst.append(
|
||||
MemoryBlockV2(num_memory_units, filter_size, self.shift[i],
|
||||
dropout))
|
||||
|
||||
def forward(self, input, mask=None):
|
||||
x = F.dropout(input, self.dropout, self.training)
|
||||
for (ffn, memory_block) in zip(self.ffn_lst, self.memory_block_lst):
|
||||
context = ffn(x)
|
||||
memory = memory_block(context, mask)
|
||||
memory = F.dropout(memory, self.dropout, self.training)
|
||||
if memory.size(-1) == x.size(-1):
|
||||
memory += x
|
||||
x = memory
|
||||
|
||||
return x
|
||||
File diff suppressed because it is too large
@@ -1,102 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SinusoidalPositionEncoder(nn.Module):
|
||||
|
||||
def __init__(self, max_len, depth):
|
||||
super(SinusoidalPositionEncoder, self).__init__()
|
||||
|
||||
self.max_len = max_len
|
||||
self.depth = depth
|
||||
self.position_enc = nn.Parameter(
|
||||
self.get_sinusoid_encoding_table(max_len, depth).unsqueeze(0),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
bz_in, len_in, _ = input.size()
|
||||
if len_in > self.max_len:
|
||||
self.max_len = len_in
|
||||
self.position_enc.data = (
|
||||
self.get_sinusoid_encoding_table(
|
||||
self.max_len, self.depth).unsqueeze(0).to(input.device))
|
||||
|
||||
output = input + self.position_enc[:, :len_in, :].expand(bz_in, -1, -1)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
|
||||
""" Sinusoid position encoding table """
|
||||
|
||||
def cal_angle(position, hid_idx):
|
||||
return position / np.power(10000, hid_idx / float(d_hid / 2 - 1))
|
||||
|
||||
def get_posi_angle_vec(position):
|
||||
return [cal_angle(position, hid_j) for hid_j in range(d_hid // 2)]
|
||||
|
||||
scaled_time_table = np.array(
|
||||
[get_posi_angle_vec(pos_i + 1) for pos_i in range(n_position)])
|
||||
|
||||
sinusoid_table = np.zeros((n_position, d_hid))
|
||||
sinusoid_table[:, :d_hid // 2] = np.sin(scaled_time_table)
|
||||
sinusoid_table[:, d_hid // 2:] = np.cos(scaled_time_table)
|
||||
|
||||
if padding_idx is not None:
|
||||
# zero vector for padding dimension
|
||||
sinusoid_table[padding_idx] = 0.0
|
||||
|
||||
return torch.FloatTensor(sinusoid_table)
|
||||
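# Descriptive note on get_sinusoid_encoding_table above: unlike the
# interleaved sin/cos layout of the original Transformer, this table puts
# all sine components in the first d_hid // 2 columns and all cosine
# components in the second half, and positions are counted from 1
# (pos_i + 1) rather than 0.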
|
||||
|
||||
class DurSinusoidalPositionEncoder(nn.Module):
|
||||
|
||||
def __init__(self, depth, outputs_per_step):
|
||||
super(DurSinusoidalPositionEncoder, self).__init__()
|
||||
|
||||
self.depth = depth
|
||||
self.outputs_per_step = outputs_per_step
|
||||
|
||||
inv_timescales = [
|
||||
np.power(10000, 2 * (hid_idx // 2) / depth)
|
||||
for hid_idx in range(depth)
|
||||
]
|
||||
self.inv_timescales = nn.Parameter(
|
||||
torch.FloatTensor(inv_timescales), requires_grad=False)
|
||||
|
||||
def forward(self, durations, masks=None):
|
||||
reps = (durations + 0.5).long()
|
||||
output_lens = reps.sum(dim=1)
|
||||
max_len = output_lens.max()
|
||||
reps_cumsum = torch.cumsum(
|
||||
F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
|
||||
range_ = torch.arange(max_len).to(durations.device)[None, :, None]
|
||||
mult = (reps_cumsum[:, :, :-1] <= range_) & (
|
||||
reps_cumsum[:, :, 1:] > range_)
|
||||
mult = mult.float()
|
||||
offsets = torch.matmul(mult,
|
||||
reps_cumsum[:,
|
||||
0, :-1].unsqueeze(-1)).squeeze(-1)
|
||||
dur_pos = range_[:, :, 0] - offsets + 1
|
||||
|
||||
if masks is not None:
|
||||
assert masks.size(1) == dur_pos.size(1)
|
||||
dur_pos = dur_pos.masked_fill(masks, 0.0)
|
||||
|
||||
seq_len = dur_pos.size(1)
|
||||
padding = self.outputs_per_step - int(seq_len) % self.outputs_per_step
|
||||
if padding < self.outputs_per_step:
|
||||
dur_pos = F.pad(dur_pos, (0, padding, 0, 0), value=0.0)
|
||||
|
||||
position_embedding = dur_pos[:, :, None] / self.inv_timescales[None,
|
||||
None, :]
|
||||
position_embedding[:, :, 0::2] = torch.sin(position_embedding[:, :,
|
||||
0::2])
|
||||
position_embedding[:, :, 1::2] = torch.cos(position_embedding[:, :,
|
||||
1::2])
|
||||
|
||||
return position_embedding
|
||||
@@ -1,26 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import torch
|
||||
|
||||
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('Conv') != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def get_mask_from_lengths(lengths, max_len=None):
|
||||
batch_size = lengths.shape[0]
|
||||
if max_len is None:
|
||||
max_len = torch.max(lengths).item()
|
||||
|
||||
ids = (
|
||||
torch.arange(0, max_len).unsqueeze(0).expand(batch_size,
|
||||
-1).to(lengths.device))
|
||||
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
|
||||
|
||||
return mask
|
||||
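# Illustrative note on get_mask_from_lengths above: the returned mask is
# True at padded positions, e.g. lengths [3, 5] with max_len 5 gives
#   [[False, False, False, True, True],
#    [False, False, False, False, False]].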
@@ -1,774 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from glob import glob
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .core.dsp import (load_wav, melspectrogram, save_wav, trim_silence,
|
||||
trim_silence_with_interval)
|
||||
from .core.utils import (align_length, average_by_duration, compute_mean,
|
||||
compute_std, encode_16bits, f0_norm_mean_std,
|
||||
get_energy, get_pitch, norm_mean_std,
|
||||
parse_interval_file, volume_normalize)
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
default_audio_config = {
|
||||
# Preprocess
|
||||
'wav_normalize': True,
|
||||
'trim_silence': True,
|
||||
'trim_silence_threshold_db': 60,
|
||||
'preemphasize': False,
|
||||
# Feature extraction
|
||||
'sampling_rate': 24000,
|
||||
'hop_length': 240,
|
||||
'win_length': 1024,
|
||||
'n_mels': 80,
|
||||
'n_fft': 1024,
|
||||
'fmin': 50.0,
|
||||
'fmax': 7600.0,
|
||||
'min_level_db': -100,
|
||||
'ref_level_db': 20,
|
||||
'phone_level_feature': True,
|
||||
'num_workers': 16,
|
||||
# Normalization
|
||||
'norm_type': 'mean_std', # 'mean_std', 'global norm'
|
||||
'max_norm': 1.0,
|
||||
'symmetric': False,
|
||||
}
|
||||
|
||||
|
||||
class AudioProcessor:
|
||||
|
||||
def __init__(self, config=None):
|
||||
if not isinstance(config, dict):
|
||||
logging.warning(
|
||||
'[AudioProcessor] config is not a dict, falling back to the default config.'
|
||||
)
|
||||
self.config = default_audio_config
|
||||
else:
|
||||
self.config = config
|
||||
|
||||
for key in self.config:
|
||||
setattr(self, key, self.config[key])
|
||||
|
||||
self.min_wav_length = int(self.config['sampling_rate'] * 0.5)
|
||||
|
||||
self.badcase_list = []
|
||||
self.pcm_dict = {}
|
||||
self.mel_dict = {}
|
||||
self.f0_dict = {}
|
||||
self.uv_dict = {}
|
||||
self.nccf_dict = {}
|
||||
self.f0uv_dict = {}
|
||||
self.energy_dict = {}
|
||||
self.dur_dict = {}
|
||||
logging.info('[AudioProcessor] Initialize AudioProcessor.')
|
||||
logging.info('[AudioProcessor] config params:')
|
||||
for key in self.config:
|
||||
logging.info('[AudioProcessor] %s: %s', key, self.config[key])
|
||||
|
||||
def calibrate_SyllableDuration(self, raw_dur_dir, raw_metafile,
|
||||
out_cali_duration_dir):
|
||||
with open(raw_metafile, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
output_dur_dir = out_cali_duration_dir
|
||||
os.makedirs(output_dur_dir, exist_ok=True)
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
index, symbols = line.split('\t')
|
||||
symbols = [
|
||||
symbol.strip('{').strip('}').split('$')[0]
|
||||
for symbol in symbols.strip().split(' ')
|
||||
]
|
||||
dur_file = os.path.join(raw_dur_dir, index + '.npy')
|
||||
phone_file = os.path.join(raw_dur_dir, index + '.phone')
|
||||
if not os.path.exists(dur_file) or not os.path.exists(phone_file):
|
||||
logging.warning(
|
||||
                    '[AudioProcessor] duration file or phone file does not exist: %s',
|
||||
index)
|
||||
continue
|
||||
with open(phone_file, 'r') as f:
|
||||
phones = f.readlines()
|
||||
dur = np.load(dur_file)
|
||||
cali_duration = []
|
||||
|
||||
dur_idx = 0
|
||||
syll_idx = 0
|
||||
|
||||
while dur_idx < len(dur) and syll_idx < len(symbols):
|
||||
if phones[dur_idx].strip() == 'sil':
|
||||
dur_idx += 1
|
||||
continue
|
||||
|
||||
if phones[dur_idx].strip(
|
||||
) == 'sp' and symbols[syll_idx][0] != '#':
|
||||
dur_idx += 1
|
||||
continue
|
||||
|
||||
if symbols[syll_idx] in ['ga', 'go', 'ge']:
|
||||
cali_duration.append(0)
|
||||
syll_idx += 1
|
||||
# print("NONE", symbols[syll_idx], 0)
|
||||
continue
|
||||
|
||||
if symbols[syll_idx][0] == '#':
|
||||
if phones[dur_idx].strip() != 'sp':
|
||||
cali_duration.append(0)
|
||||
# print("NONE", symbols[syll_idx], 0)
|
||||
syll_idx += 1
|
||||
continue
|
||||
else:
|
||||
cali_duration.append(dur[dur_idx])
|
||||
# print(phones[dur_idx].strip(), symbols[syll_idx], dur[dur_idx])
|
||||
dur_idx += 1
|
||||
syll_idx += 1
|
||||
continue
|
||||
# A corresponding phone is found
|
||||
cali_duration.append(dur[dur_idx])
|
||||
# print(phones[dur_idx].strip(), symbols[syll_idx], dur[dur_idx])
|
||||
dur_idx += 1
|
||||
syll_idx += 1
|
||||
# Add #4 phone duration
|
||||
cali_duration.append(0)
|
||||
if len(cali_duration) != len(symbols):
|
||||
                logging.error(
                    '[Duration Calibrating] Syllable duration %s is not equal '
                    'to the number of symbols %s, index: %s',
                    len(cali_duration), len(symbols), index)
|
||||
continue
|
||||
|
||||
# Align with mel frames
|
||||
durs = np.array(cali_duration)
|
||||
if len(self.mel_dict) > 0:
|
||||
pair_mel = self.mel_dict.get(index, None)
|
||||
if pair_mel is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Interval file %s has no corresponding mel',
|
||||
index,
|
||||
)
|
||||
continue
|
||||
mel_frames = pair_mel.shape[0]
|
||||
dur_frames = np.sum(durs)
|
||||
if np.sum(durs) > mel_frames:
|
||||
durs[-2] -= dur_frames - mel_frames
|
||||
elif np.sum(durs) < mel_frames:
|
||||
durs[-2] += mel_frames - np.sum(durs)
|
||||
|
||||
if durs[-2] < 0:
|
||||
logging.error(
|
||||
'[AudioProcessor] Duration calibrating failed for %s, mismatch frames %s',
|
||||
index,
|
||||
durs[-2],
|
||||
)
|
||||
self.badcase_list.append(index)
|
||||
continue
|
||||
|
||||
self.dur_dict[index] = durs
|
||||
|
||||
np.save(
|
||||
os.path.join(output_dur_dir, index + '.npy'),
|
||||
self.dur_dict[index])
|
||||
|
||||
def amp_normalize(self, src_wav_dir, out_wav_dir):
|
||||
if self.wav_normalize:
|
||||
logging.info('[AudioProcessor] Amplitude normalization started')
|
||||
os.makedirs(out_wav_dir, exist_ok=True)
|
||||
res = volume_normalize(src_wav_dir, out_wav_dir)
|
||||
logging.info('[AudioProcessor] Amplitude normalization finished')
|
||||
return res
|
||||
else:
|
||||
logging.info('[AudioProcessor] No amplitude normalization')
|
||||
os.symlink(src_wav_dir, out_wav_dir, target_is_directory=True)
|
||||
return True
|
||||
|
||||
def get_pcm_dict(self, src_wav_dir):
|
||||
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
if len(self.pcm_dict) > 0:
|
||||
return self.pcm_dict
|
||||
|
||||
logging.info('[AudioProcessor] Start to load pcm from %s', src_wav_dir)
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(wav_list)) as progress:
|
||||
futures = []
|
||||
for wav_path in wav_list:
|
||||
future = executor.submit(load_wav, wav_path,
|
||||
self.sampling_rate)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
wav_name = os.path.splitext(os.path.basename(wav_path))[0]
|
||||
futures.append((future, wav_name))
|
||||
for future, wav_name in futures:
|
||||
pcm = future.result()
|
||||
if len(pcm) < self.min_wav_length:
|
||||
logging.warning('[AudioProcessor] %s is too short, skip',
|
||||
wav_name)
|
||||
self.badcase_list.append(wav_name)
|
||||
continue
|
||||
self.pcm_dict[wav_name] = pcm
|
||||
|
||||
return self.pcm_dict
|
||||
|
||||
def trim_silence_wav(self, src_wav_dir, out_wav_dir=None):
|
||||
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
logging.info('[AudioProcessor] Trim silence started')
|
||||
if out_wav_dir is None:
|
||||
out_wav_dir = src_wav_dir
|
||||
else:
|
||||
os.makedirs(out_wav_dir, exist_ok=True)
|
||||
pcm_dict = self.get_pcm_dict(src_wav_dir)
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(wav_list)) as progress:
|
||||
futures = []
|
||||
for wav_basename, pcm_data in pcm_dict.items():
|
||||
future = executor.submit(
|
||||
trim_silence,
|
||||
pcm_data,
|
||||
self.trim_silence_threshold_db,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
for future, wav_basename in tqdm(futures):
|
||||
pcm = future.result()
|
||||
if len(pcm) < self.min_wav_length:
|
||||
logging.warning('[AudioProcessor] %s is too short, skip',
|
||||
wav_basename)
|
||||
self.badcase_list.append(wav_basename)
|
||||
self.pcm_dict.pop(wav_basename)
|
||||
continue
|
||||
self.pcm_dict[wav_basename] = pcm
|
||||
save_wav(
|
||||
self.pcm_dict[wav_basename],
|
||||
os.path.join(out_wav_dir, wav_basename + '.wav'),
|
||||
self.sampling_rate,
|
||||
)
|
||||
|
||||
logging.info('[AudioProcessor] Trim silence finished')
|
||||
return True
|
||||
|
||||
def trim_silence_wav_with_interval(self,
|
||||
src_wav_dir,
|
||||
dur_dir,
|
||||
out_wav_dir=None):
|
||||
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
logging.info('[AudioProcessor] Trim silence with interval started')
|
||||
if out_wav_dir is None:
|
||||
out_wav_dir = src_wav_dir
|
||||
else:
|
||||
os.makedirs(out_wav_dir, exist_ok=True)
|
||||
pcm_dict = self.get_pcm_dict(src_wav_dir)
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(wav_list)) as progress:
|
||||
futures = []
|
||||
for wav_basename, pcm_data in pcm_dict.items():
|
||||
future = executor.submit(
|
||||
trim_silence_with_interval,
|
||||
pcm_data,
|
||||
self.dur_dict.get(wav_basename, None),
|
||||
self.hop_length,
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
for future, wav_basename in tqdm(futures):
|
||||
trimed_pcm = future.result()
|
||||
if trimed_pcm is None:
|
||||
continue
|
||||
if len(trimed_pcm) < self.min_wav_length:
|
||||
logging.warning('[AudioProcessor] %s is too short, skip',
|
||||
wav_basename)
|
||||
self.badcase_list.append(wav_basename)
|
||||
self.pcm_dict.pop(wav_basename)
|
||||
continue
|
||||
self.pcm_dict[wav_basename] = trimed_pcm
|
||||
save_wav(
|
||||
self.pcm_dict[wav_basename],
|
||||
os.path.join(out_wav_dir, wav_basename + '.wav'),
|
||||
self.sampling_rate,
|
||||
)
|
||||
|
||||
        logging.info('[AudioProcessor] Trim silence with interval finished')
|
||||
return True
|
||||
|
||||
def mel_extract(self, src_wav_dir, out_feature_dir):
|
||||
os.makedirs(out_feature_dir, exist_ok=True)
|
||||
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
pcm_dict = self.get_pcm_dict(src_wav_dir)
|
||||
|
||||
logging.info('[AudioProcessor] Melspec extraction started')
|
||||
|
||||
# Get global normed mel spec
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(wav_list)) as progress:
|
||||
futures = []
|
||||
for wav_basename, pcm_data in pcm_dict.items():
|
||||
future = executor.submit(
|
||||
melspectrogram,
|
||||
pcm_data,
|
||||
self.sampling_rate,
|
||||
self.n_fft,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
self.n_mels,
|
||||
self.max_norm,
|
||||
self.min_level_db,
|
||||
self.ref_level_db,
|
||||
self.fmin,
|
||||
self.fmax,
|
||||
self.symmetric,
|
||||
self.preemphasize,
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
|
||||
for future, wav_basename in futures:
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Melspec extraction failed for %s',
|
||||
wav_basename,
|
||||
)
|
||||
self.badcase_list.append(wav_basename)
|
||||
else:
|
||||
melspec = result
|
||||
self.mel_dict[wav_basename] = melspec
|
||||
|
||||
logging.info('[AudioProcessor] Melspec extraction finished')
|
||||
|
||||
# FIXME: is this step necessary?
|
||||
# Do mean std norm on global-normed melspec
|
||||
logging.info('Melspec statistic proceeding...')
|
||||
mel_mean = compute_mean(list(self.mel_dict.values()), dims=self.n_mels)
|
||||
mel_std = compute_std(
|
||||
list(self.mel_dict.values()), mel_mean, dims=self.n_mels)
|
||||
logging.info('Melspec statistic done')
|
||||
np.savetxt(
|
||||
os.path.join(out_feature_dir, 'mel_mean.txt'),
|
||||
mel_mean,
|
||||
fmt='%.6f')
|
||||
np.savetxt(
|
||||
os.path.join(out_feature_dir, 'mel_std.txt'), mel_std, fmt='%.6f')
|
||||
logging.info(
|
||||
'[AudioProcessor] melspec mean and std saved to:\n{},\n{}'.format(
|
||||
os.path.join(out_feature_dir, 'mel_mean.txt'),
|
||||
os.path.join(out_feature_dir, 'mel_std.txt'),
|
||||
))
|
||||
|
||||
logging.info('[AudioProcessor] Melspec mean std norm is proceeding...')
|
||||
for wav_basename in self.mel_dict:
|
||||
melspec = self.mel_dict[wav_basename]
|
||||
norm_melspec = norm_mean_std(melspec, mel_mean, mel_std)
|
||||
np.save(
|
||||
os.path.join(out_feature_dir, wav_basename + '.npy'),
|
||||
norm_melspec)
|
||||
|
||||
logging.info('[AudioProcessor] Melspec normalization finished')
|
||||
logging.info('[AudioProcessor] Normed Melspec saved to %s',
|
||||
out_feature_dir)
|
||||
|
||||
return True
|
||||
|
||||
def duration_generate(self, src_interval_dir, out_feature_dir):
|
||||
os.makedirs(out_feature_dir, exist_ok=True)
|
||||
interval_list = glob(os.path.join(src_interval_dir, '*.interval'))
|
||||
|
||||
logging.info('[AudioProcessor] Duration generation started')
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(interval_list)) as progress:
|
||||
futures = []
|
||||
for interval_file_path in interval_list:
|
||||
future = executor.submit(
|
||||
parse_interval_file,
|
||||
interval_file_path,
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future,
|
||||
os.path.splitext(
|
||||
os.path.basename(interval_file_path))[0]))
|
||||
|
||||
logging.info(
|
||||
'[AudioProcessor] Duration align with mel is proceeding...')
|
||||
for future, wav_basename in futures:
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Duration generate failed for %s',
|
||||
wav_basename)
|
||||
self.badcase_list.append(wav_basename)
|
||||
else:
|
||||
durs, phone_list = result
|
||||
# Align length with melspec
|
||||
if len(self.mel_dict) > 0:
|
||||
pair_mel = self.mel_dict.get(wav_basename, None)
|
||||
if pair_mel is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Interval file %s has no corresponding mel',
|
||||
wav_basename,
|
||||
)
|
||||
continue
|
||||
mel_frames = pair_mel.shape[0]
|
||||
dur_frames = np.sum(durs)
|
||||
if np.sum(durs) > mel_frames:
|
||||
durs[-1] -= dur_frames - mel_frames
|
||||
elif np.sum(durs) < mel_frames:
|
||||
durs[-1] += mel_frames - np.sum(durs)
|
||||
|
||||
if durs[-1] < 0:
|
||||
logging.error(
|
||||
'[AudioProcessor] Duration align failed for %s, mismatch frames %s',
|
||||
wav_basename,
|
||||
durs[-1],
|
||||
)
|
||||
self.badcase_list.append(wav_basename)
|
||||
continue
|
||||
|
||||
self.dur_dict[wav_basename] = durs
|
||||
|
||||
np.save(
|
||||
os.path.join(out_feature_dir, wav_basename + '.npy'),
|
||||
durs)
|
||||
with open(
|
||||
os.path.join(out_feature_dir,
|
||||
wav_basename + '.phone'), 'w') as f:
|
||||
f.write('\n'.join(phone_list))
|
||||
logging.info('[AudioProcessor] Duration generate finished')
|
||||
|
||||
return True
|
||||
|
||||
def pitch_extract(self, src_wav_dir, out_f0_dir, out_frame_f0_dir,
|
||||
out_frame_uv_dir):
|
||||
os.makedirs(out_f0_dir, exist_ok=True)
|
||||
os.makedirs(out_frame_f0_dir, exist_ok=True)
|
||||
os.makedirs(out_frame_uv_dir, exist_ok=True)
|
||||
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
pcm_dict = self.get_pcm_dict(src_wav_dir)
|
||||
mel_dict = self.mel_dict
|
||||
|
||||
logging.info('[AudioProcessor] Pitch extraction started')
|
||||
# Get raw pitch
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(wav_list)) as progress:
|
||||
futures = []
|
||||
for wav_basename, pcm_data in pcm_dict.items():
|
||||
future = executor.submit(
|
||||
get_pitch,
|
||||
encode_16bits(pcm_data),
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
|
||||
logging.info(
|
||||
'[AudioProcessor] Pitch align with mel is proceeding...')
|
||||
for future, wav_basename in futures:
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Pitch extraction failed for %s',
|
||||
wav_basename)
|
||||
self.badcase_list.append(wav_basename)
|
||||
else:
|
||||
f0, uv, f0uv = result
|
||||
if len(mel_dict) > 0:
|
||||
f0 = align_length(f0, mel_dict.get(wav_basename, None))
|
||||
uv = align_length(uv, mel_dict.get(wav_basename, None))
|
||||
f0uv = align_length(f0uv,
|
||||
mel_dict.get(wav_basename, None))
|
||||
|
||||
if f0 is None or uv is None or f0uv is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Pitch length mismatch with mel in %s',
|
||||
wav_basename,
|
||||
)
|
||||
self.badcase_list.append(wav_basename)
|
||||
continue
|
||||
self.f0_dict[wav_basename] = f0
|
||||
self.uv_dict[wav_basename] = uv
|
||||
self.f0uv_dict[wav_basename] = f0uv
|
||||
|
||||
# Normalize f0
|
||||
logging.info('[AudioProcessor] Pitch normalization is proceeding...')
|
||||
f0_mean = compute_mean(list(self.f0uv_dict.values()), dims=1)
|
||||
f0_std = compute_std(list(self.f0uv_dict.values()), f0_mean, dims=1)
|
||||
np.savetxt(
|
||||
os.path.join(out_f0_dir, 'f0_mean.txt'), f0_mean, fmt='%.6f')
|
||||
np.savetxt(os.path.join(out_f0_dir, 'f0_std.txt'), f0_std, fmt='%.6f')
|
||||
logging.info(
|
||||
'[AudioProcessor] f0 mean and std saved to:\n{},\n{}'.format(
|
||||
os.path.join(out_f0_dir, 'f0_mean.txt'),
|
||||
os.path.join(out_f0_dir, 'f0_std.txt'),
|
||||
))
|
||||
logging.info('[AudioProcessor] Pitch mean std norm is proceeding...')
|
||||
for wav_basename in self.f0uv_dict:
|
||||
f0 = self.f0uv_dict[wav_basename]
|
||||
norm_f0 = f0_norm_mean_std(f0, f0_mean, f0_std)
|
||||
self.f0uv_dict[wav_basename] = norm_f0
|
||||
|
||||
for wav_basename in self.f0_dict:
|
||||
f0 = self.f0_dict[wav_basename]
|
||||
norm_f0 = f0_norm_mean_std(f0, f0_mean, f0_std)
|
||||
self.f0_dict[wav_basename] = norm_f0
|
||||
|
||||
# save frame f0 to a specific dir
|
||||
for wav_basename in self.f0_dict:
|
||||
np.save(
|
||||
os.path.join(out_frame_f0_dir, wav_basename + '.npy'),
|
||||
self.f0_dict[wav_basename].reshape(-1),
|
||||
)
|
||||
|
||||
for wav_basename in self.uv_dict:
|
||||
np.save(
|
||||
os.path.join(out_frame_uv_dir, wav_basename + '.npy'),
|
||||
self.uv_dict[wav_basename].reshape(-1),
|
||||
)
|
||||
|
||||
# phone level average
|
||||
# if there is no duration then save the frame-level f0
|
||||
if self.phone_level_feature and len(self.dur_dict) > 0:
|
||||
logging.info(
|
||||
'[AudioProcessor] Pitch turn to phone-level is proceeding...')
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(self.f0uv_dict)) as progress:
|
||||
futures = []
|
||||
for wav_basename in self.f0uv_dict:
|
||||
future = executor.submit(
|
||||
average_by_duration,
|
||||
self.f0uv_dict.get(wav_basename, None),
|
||||
self.dur_dict.get(wav_basename, None),
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
|
||||
for future, wav_basename in futures:
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Pitch extraction failed in phone level avg for: %s',
|
||||
wav_basename,
|
||||
)
|
||||
self.badcase_list.append(wav_basename)
|
||||
else:
|
||||
avg_f0 = result
|
||||
self.f0uv_dict[wav_basename] = avg_f0
|
||||
|
||||
for wav_basename in self.f0uv_dict:
|
||||
np.save(
|
||||
os.path.join(out_f0_dir, wav_basename + '.npy'),
|
||||
self.f0uv_dict[wav_basename].reshape(-1),
|
||||
)
|
||||
|
||||
logging.info('[AudioProcessor] Pitch normalization finished')
|
||||
logging.info('[AudioProcessor] Normed f0 saved to %s', out_f0_dir)
|
||||
logging.info('[AudioProcessor] Pitch extraction finished')
|
||||
|
||||
return True
|
||||
|
||||
def energy_extract(self, src_wav_dir, out_energy_dir,
|
||||
out_frame_energy_dir):
|
||||
os.makedirs(out_energy_dir, exist_ok=True)
|
||||
os.makedirs(out_frame_energy_dir, exist_ok=True)
|
||||
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
pcm_dict = self.get_pcm_dict(src_wav_dir)
|
||||
mel_dict = self.mel_dict
|
||||
|
||||
logging.info('[AudioProcessor] Energy extraction started')
|
||||
# Get raw energy
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(wav_list)) as progress:
|
||||
futures = []
|
||||
for wav_basename, pcm_data in pcm_dict.items():
|
||||
future = executor.submit(get_energy, pcm_data, self.hop_length,
|
||||
self.win_length, self.n_fft)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
|
||||
for future, wav_basename in futures:
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Energy extraction failed for %s',
|
||||
wav_basename)
|
||||
self.badcase_list.append(wav_basename)
|
||||
else:
|
||||
energy = result
|
||||
if len(mel_dict) > 0:
|
||||
energy = align_length(energy,
|
||||
mel_dict.get(wav_basename, None))
|
||||
if energy is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Energy length mismatch with mel in %s',
|
||||
wav_basename,
|
||||
)
|
||||
self.badcase_list.append(wav_basename)
|
||||
continue
|
||||
self.energy_dict[wav_basename] = energy
|
||||
|
||||
        logging.info('Energy statistic proceeding...')
|
||||
# Normalize energy
|
||||
energy_mean = compute_mean(list(self.energy_dict.values()), dims=1)
|
||||
energy_std = compute_std(
|
||||
list(self.energy_dict.values()), energy_mean, dims=1)
|
||||
np.savetxt(
|
||||
os.path.join(out_energy_dir, 'energy_mean.txt'),
|
||||
energy_mean,
|
||||
fmt='%.6f')
|
||||
np.savetxt(
|
||||
os.path.join(out_energy_dir, 'energy_std.txt'),
|
||||
energy_std,
|
||||
fmt='%.6f')
|
||||
logging.info(
|
||||
'[AudioProcessor] energy mean and std saved to:\n{},\n{}'.format(
|
||||
os.path.join(out_energy_dir, 'energy_mean.txt'),
|
||||
os.path.join(out_energy_dir, 'energy_std.txt'),
|
||||
))
|
||||
|
||||
logging.info('[AudioProcessor] Energy mean std norm is proceeding...')
|
||||
for wav_basename in self.energy_dict:
|
||||
energy = self.energy_dict[wav_basename]
|
||||
norm_energy = f0_norm_mean_std(energy, energy_mean, energy_std)
|
||||
self.energy_dict[wav_basename] = norm_energy
|
||||
|
||||
# save frame energy to a specific dir
|
||||
for wav_basename in self.energy_dict:
|
||||
np.save(
|
||||
os.path.join(out_frame_energy_dir, wav_basename + '.npy'),
|
||||
self.energy_dict[wav_basename].reshape(-1),
|
||||
)
|
||||
|
||||
# phone level average
|
||||
# if there is no duration then save the frame-level energy
|
||||
if self.phone_level_feature and len(self.dur_dict) > 0:
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(self.energy_dict)) as progress:
|
||||
futures = []
|
||||
for wav_basename in self.energy_dict:
|
||||
future = executor.submit(
|
||||
average_by_duration,
|
||||
self.energy_dict.get(wav_basename, None),
|
||||
self.dur_dict.get(wav_basename, None),
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
|
||||
for future, wav_basename in futures:
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Energy extraction failed in phone level avg for: %s',
|
||||
wav_basename,
|
||||
)
|
||||
self.badcase_list.append(wav_basename)
|
||||
else:
|
||||
avg_energy = result
|
||||
self.energy_dict[wav_basename] = avg_energy
|
||||
|
||||
for wav_basename in self.energy_dict:
|
||||
np.save(
|
||||
os.path.join(out_energy_dir, wav_basename + '.npy'),
|
||||
self.energy_dict[wav_basename].reshape(-1),
|
||||
)
|
||||
|
||||
logging.info('[AudioProcessor] Energy normalization finished')
|
||||
logging.info('[AudioProcessor] Normed Energy saved to %s',
|
||||
out_energy_dir)
|
||||
logging.info('[AudioProcessor] Energy extraction finished')
|
||||
|
||||
return True
|
||||
|
||||
def process(self, src_voice_dir, out_data_dir, aux_metafile=None):
|
||||
succeed = True
|
||||
|
||||
raw_wav_dir = os.path.join(src_voice_dir, 'wav')
|
||||
src_interval_dir = os.path.join(src_voice_dir, 'interval')
|
||||
|
||||
out_mel_dir = os.path.join(out_data_dir, 'mel')
|
||||
out_f0_dir = os.path.join(out_data_dir, 'f0')
|
||||
out_frame_f0_dir = os.path.join(out_data_dir, 'frame_f0')
|
||||
out_frame_uv_dir = os.path.join(out_data_dir, 'frame_uv')
|
||||
out_energy_dir = os.path.join(out_data_dir, 'energy')
|
||||
out_frame_energy_dir = os.path.join(out_data_dir, 'frame_energy')
|
||||
out_duration_dir = os.path.join(out_data_dir, 'raw_duration')
|
||||
out_cali_duration_dir = os.path.join(out_data_dir, 'duration')
|
||||
|
||||
os.makedirs(out_data_dir, exist_ok=True)
|
||||
|
||||
with_duration = os.path.exists(src_interval_dir)
|
||||
train_wav_dir = os.path.join(out_data_dir, 'wav')
|
||||
|
||||
succeed = self.amp_normalize(raw_wav_dir, train_wav_dir)
|
||||
if not succeed:
|
||||
logging.error('[AudioProcessor] amp_normalize failed, exit')
|
||||
return False
|
||||
|
||||
if with_duration:
|
||||
# Raw duration, non-trimmed
|
||||
succeed = self.duration_generate(src_interval_dir,
|
||||
out_duration_dir)
|
||||
if not succeed:
|
||||
logging.error(
|
||||
'[AudioProcessor] duration_generate failed, exit')
|
||||
return False
|
||||
|
||||
if self.trim_silence:
|
||||
if with_duration:
|
||||
succeed = self.trim_silence_wav_with_interval(
|
||||
train_wav_dir, out_duration_dir)
|
||||
if not succeed:
|
||||
logging.error(
|
||||
'[AudioProcessor] trim_silence_wav_with_interval failed, exit'
|
||||
)
|
||||
return False
|
||||
else:
|
||||
succeed = self.trim_silence_wav(train_wav_dir)
|
||||
if not succeed:
|
||||
logging.error(
|
||||
'[AudioProcessor] trim_silence_wav failed, exit')
|
||||
return False
|
||||
|
||||
succeed = self.mel_extract(train_wav_dir, out_mel_dir)
|
||||
if not succeed:
|
||||
logging.error('[AudioProcessor] mel_extract failed, exit')
|
||||
return False
|
||||
|
||||
if aux_metafile is not None and with_duration:
|
||||
self.calibrate_SyllableDuration(out_duration_dir, aux_metafile,
|
||||
out_cali_duration_dir)
|
||||
|
||||
succeed = self.pitch_extract(train_wav_dir, out_f0_dir,
|
||||
out_frame_f0_dir, out_frame_uv_dir)
|
||||
if not succeed:
|
||||
logging.error('[AudioProcessor] pitch_extract failed, exit')
|
||||
return False
|
||||
|
||||
succeed = self.energy_extract(train_wav_dir, out_energy_dir,
|
||||
out_frame_energy_dir)
|
||||
if not succeed:
|
||||
logging.error('[AudioProcessor] energy_extract failed, exit')
|
||||
return False
|
||||
|
||||
# recording badcase list
|
||||
with open(os.path.join(out_data_dir, 'badlist.txt'), 'w') as f:
|
||||
f.write('\n'.join(self.badcase_list))
|
||||
|
||||
logging.info('[AudioProcessor] All features extracted successfully!')
|
||||
|
||||
return succeed
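

# Illustrative driver (hypothetical paths, not part of the original module):
# running the full feature pipeline on a voice directory laid out as
# <voice>/wav/*.wav with optional <voice>/interval/*.interval files.
if __name__ == '__main__':
    processor = AudioProcessor(default_audio_config)
    ok = processor.process(
        src_voice_dir='/path/to/voice_dir',
        out_data_dir='/path/to/feature_dir')
    print('[AudioProcessor] feature extraction succeeded:', ok)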
|
||||
@@ -1,240 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.

import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile


def _stft(y, hop_length, win_length, n_fft):
    return librosa.stft(
        y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)


def _istft(y, hop_length, win_length):
    return librosa.istft(y, hop_length=hop_length, win_length=win_length)


def _db_to_amp(x):
    return np.power(10.0, x * 0.05)


def _amp_to_db(x):
    return 20 * np.log10(np.maximum(1e-5, x))


def load_wav(path, sr):
    return librosa.load(path, sr=sr)[0]


def save_wav(wav, path, sr):
|
||||
if wav.dtype == np.float32 or wav.dtype == np.float64:
|
||||
quant_wav = 32767 * wav
|
||||
else:
|
||||
quant_wav = wav
|
||||
# maximize the volume to avoid clipping
|
||||
# wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
||||
wavfile.write(path, sr, quant_wav.astype(np.int16))
|
||||
|
||||
|
||||
def trim_silence(wav, top_db, hop_length, win_length):
|
||||
trimed_wav, _ = librosa.effects.trim(
|
||||
wav, top_db=top_db, frame_length=win_length, hop_length=hop_length)
|
||||
return trimed_wav
|
||||
|
||||
|
||||
def trim_silence_with_interval(wav, interval, hop_length):
|
||||
if interval is None:
|
||||
return None
|
||||
leading_sil = interval[0]
|
||||
tailing_sil = interval[-1]
|
||||
trim_wav = wav[leading_sil * hop_length:-tailing_sil * hop_length]
|
||||
return trim_wav
|
||||
|
||||
|
||||
def preemphasis(wav, k=0.98, preemphasize=False):
|
||||
if preemphasize:
|
||||
return signal.lfilter([1, -k], [1], wav)
|
||||
return wav
|
||||
|
||||
|
||||
def inv_preemphasis(wav, k=0.98, inv_preemphasize=False):
|
||||
if inv_preemphasize:
|
||||
return signal.lfilter([1], [1, -k], wav)
|
||||
return wav
|
||||
|
||||
|
||||
def _normalize(S, max_norm=1.0, min_level_db=-100, symmetric=False):
|
||||
if symmetric:
|
||||
return np.clip(
|
||||
(2 * max_norm) * ((S - min_level_db) / (-min_level_db)) - max_norm,
|
||||
-max_norm,
|
||||
max_norm,
|
||||
)
|
||||
else:
|
||||
return np.clip(max_norm * ((S - min_level_db) / (-min_level_db)), 0,
|
||||
max_norm)
|
||||
|
||||
|
||||
def _denormalize(D, max_norm=1.0, min_level_db=-100, symmetric=False):
|
||||
if symmetric:
|
||||
return ((np.clip(D, -max_norm, max_norm) + max_norm) * -min_level_db
|
||||
/ # noqa W504
|
||||
(2 * max_norm)) + min_level_db
|
||||
else:
|
||||
return (np.clip(D, 0, max_norm) * -min_level_db
|
||||
/ max_norm) + min_level_db
|
||||
|
||||
|
||||
def _griffin_lim(S, n_fft, hop_length, win_length, griffin_lim_iters=60):
|
||||
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
|
||||
    S_complex = np.abs(S).astype(complex)
|
||||
y = _istft(
|
||||
S_complex * angles, hop_length=hop_length, win_length=win_length)
|
||||
for i in range(griffin_lim_iters):
|
||||
angles = np.exp(1j * np.angle(
|
||||
_stft(
|
||||
y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)))
|
||||
y = _istft(
|
||||
S_complex * angles, hop_length=hop_length, win_length=win_length)
|
||||
return y
|
||||
|
||||
|
||||
def spectrogram(
|
||||
y,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
max_norm=1.0,
|
||||
min_level_db=-100,
|
||||
ref_level_db=20,
|
||||
symmetric=False,
|
||||
):
|
||||
D = _stft(preemphasis(y), hop_length, win_length, n_fft)
|
||||
S = _amp_to_db(np.abs(D)) - ref_level_db
|
||||
return _normalize(S, max_norm, min_level_db, symmetric)
|
||||
|
||||
|
||||
def inv_spectrogram(
|
||||
spectrogram,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
max_norm=1.0,
|
||||
min_level_db=-100,
|
||||
ref_level_db=20,
|
||||
symmetric=False,
|
||||
power=1.5,
|
||||
):
|
||||
S = _db_to_amp(
|
||||
_denormalize(spectrogram, max_norm, min_level_db, symmetric)
|
||||
+ ref_level_db)
|
||||
return _griffin_lim(S**power, n_fft, hop_length, win_length)
|
||||
|
||||
|
||||
def _build_mel_basis(sample_rate, n_fft=1024, fmin=50, fmax=8000, n_mels=80):
|
||||
assert fmax <= sample_rate // 2
|
||||
return librosa.filters.mel(
|
||||
sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
|
||||
|
||||
|
||||
# mel linear Conversions
|
||||
_mel_basis = None
|
||||
_inv_mel_basis = None
|
||||
|
||||
|
||||
def _linear_to_mel(spectogram,
|
||||
sample_rate,
|
||||
n_fft=1024,
|
||||
fmin=50,
|
||||
fmax=8000,
|
||||
n_mels=80):
|
||||
global _mel_basis
|
||||
if _mel_basis is None:
|
||||
_mel_basis = _build_mel_basis(sample_rate, n_fft, fmin, fmax, n_mels)
|
||||
return np.dot(_mel_basis, spectogram)
|
||||
|
||||
|
||||
def _mel_to_linear(mel_spectrogram,
|
||||
sample_rate,
|
||||
n_fft=1024,
|
||||
fmin=50,
|
||||
fmax=8000,
|
||||
n_mels=80):
|
||||
global _inv_mel_basis
|
||||
if _inv_mel_basis is None:
|
||||
_inv_mel_basis = np.linalg.pinv(
|
||||
_build_mel_basis(sample_rate, n_fft, fmin, fmax, n_mels))
|
||||
return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
|
||||
|
||||
|
||||
def melspectrogram(
|
||||
y,
|
||||
sample_rate,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
n_mels=80,
|
||||
max_norm=1.0,
|
||||
min_level_db=-100,
|
||||
ref_level_db=20,
|
||||
fmin=50,
|
||||
fmax=8000,
|
||||
symmetric=False,
|
||||
preemphasize=False,
|
||||
):
|
||||
D = _stft(
|
||||
preemphasis(y, preemphasize=preemphasize),
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_fft=n_fft,
|
||||
)
|
||||
S = (
|
||||
_amp_to_db(
|
||||
_linear_to_mel(
|
||||
np.abs(D),
|
||||
sample_rate=sample_rate,
|
||||
n_fft=n_fft,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
n_mels=n_mels,
|
||||
)) - ref_level_db)
|
||||
return _normalize(
|
||||
S, max_norm=max_norm, min_level_db=min_level_db, symmetric=symmetric).T
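

# Illustrative sketch (not part of the original module): computing a
# normalized mel spectrogram for a synthetic 440 Hz tone, with parameter
# values mirroring default_audio_config (24 kHz audio, 10 ms hop).
if __name__ == '__main__':
    _sr = 24000
    _tone = 0.5 * np.sin(
        2 * np.pi * 440.0 * np.arange(_sr) / _sr).astype(np.float32)
    _mel = melspectrogram(
        _tone, _sr, n_fft=1024, hop_length=240, win_length=1024, n_mels=80,
        fmin=50.0, fmax=7600.0)
    print(_mel.shape)  # (num_frames, 80), values clipped to [0, max_norm]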
|
||||
|
||||
|
||||
def inv_mel_spectrogram(
|
||||
mel_spectrogram,
|
||||
sample_rate,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
n_mels=80,
|
||||
max_norm=1.0,
|
||||
min_level_db=-100,
|
||||
ref_level_db=20,
|
||||
fmin=50,
|
||||
fmax=8000,
|
||||
power=1.5,
|
||||
symmetric=False,
|
||||
preemphasize=False,
|
||||
):
|
||||
D = _denormalize(
|
||||
mel_spectrogram,
|
||||
max_norm=max_norm,
|
||||
min_level_db=min_level_db,
|
||||
symmetric=symmetric,
|
||||
)
|
||||
S = _mel_to_linear(
|
||||
_db_to_amp(D + ref_level_db),
|
||||
sample_rate=sample_rate,
|
||||
n_fft=n_fft,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
n_mels=n_mels,
|
||||
)
|
||||
return inv_preemphasis(
|
||||
_griffin_lim(S**power, n_fft, hop_length, win_length),
|
||||
preemphasize=preemphasize,
|
||||
)
|
||||
@@ -1,531 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from glob import glob
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import pysptk
|
||||
import sox
|
||||
from scipy.io import wavfile
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .dsp import _stft
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
anchor_hist = np.array([
|
||||
0.0,
|
||||
0.00215827,
|
||||
0.00354383,
|
||||
0.00442313,
|
||||
0.00490274,
|
||||
0.00532907,
|
||||
0.00602185,
|
||||
0.00690115,
|
||||
0.00810019,
|
||||
0.00948574,
|
||||
0.0120437,
|
||||
0.01489475,
|
||||
0.01873168,
|
||||
0.02302158,
|
||||
0.02872369,
|
||||
0.03669065,
|
||||
0.04636291,
|
||||
0.05843325,
|
||||
0.07700506,
|
||||
0.11052491,
|
||||
0.16802558,
|
||||
0.25997868,
|
||||
0.37942979,
|
||||
0.50730083,
|
||||
0.62006395,
|
||||
0.71092459,
|
||||
0.76877165,
|
||||
0.80762057,
|
||||
0.83458566,
|
||||
0.85672795,
|
||||
0.87660538,
|
||||
0.89251266,
|
||||
0.90578204,
|
||||
0.91569411,
|
||||
0.92541966,
|
||||
0.93383959,
|
||||
0.94162004,
|
||||
0.94940048,
|
||||
0.95539568,
|
||||
0.96136424,
|
||||
0.9670397,
|
||||
0.97290168,
|
||||
0.97705835,
|
||||
0.98116174,
|
||||
0.98465228,
|
||||
0.98814282,
|
||||
0.99152678,
|
||||
0.99421796,
|
||||
0.9965894,
|
||||
0.99840128,
|
||||
1.0,
|
||||
])
|
||||
|
||||
anchor_bins = np.array([
|
||||
0.033976,
|
||||
0.03529014,
|
||||
0.03660428,
|
||||
0.03791842,
|
||||
0.03923256,
|
||||
0.0405467,
|
||||
0.04186084,
|
||||
0.04317498,
|
||||
0.04448912,
|
||||
0.04580326,
|
||||
0.0471174,
|
||||
0.04843154,
|
||||
0.04974568,
|
||||
0.05105982,
|
||||
0.05237396,
|
||||
0.0536881,
|
||||
0.05500224,
|
||||
0.05631638,
|
||||
0.05763052,
|
||||
0.05894466,
|
||||
0.0602588,
|
||||
0.06157294,
|
||||
0.06288708,
|
||||
0.06420122,
|
||||
0.06551536,
|
||||
0.0668295,
|
||||
0.06814364,
|
||||
0.06945778,
|
||||
0.07077192,
|
||||
0.07208606,
|
||||
0.0734002,
|
||||
0.07471434,
|
||||
0.07602848,
|
||||
0.07734262,
|
||||
0.07865676,
|
||||
0.0799709,
|
||||
0.08128504,
|
||||
0.08259918,
|
||||
0.08391332,
|
||||
0.08522746,
|
||||
0.0865416,
|
||||
0.08785574,
|
||||
0.08916988,
|
||||
0.09048402,
|
||||
0.09179816,
|
||||
0.0931123,
|
||||
0.09442644,
|
||||
0.09574058,
|
||||
0.09705472,
|
||||
0.09836886,
|
||||
0.099683,
|
||||
])
|
||||
|
||||
hist_bins = 50
|
||||
|
||||
|
||||
def amp_info(wav_file_path):
|
||||
"""
|
||||
Returns the amplitude info of the wav file.
|
||||
"""
|
||||
stats = sox.file_info.stat(wav_file_path)
|
||||
amp_rms = stats['RMS amplitude']
|
||||
amp_max = stats['Maximum amplitude']
|
||||
amp_mean = stats['Mean amplitude']
|
||||
length = stats['Length (seconds)']
|
||||
|
||||
return {
|
||||
'amp_rms': amp_rms,
|
||||
'amp_max': amp_max,
|
||||
'amp_mean': amp_mean,
|
||||
'length': length,
|
||||
'basename': os.path.basename(wav_file_path),
|
||||
}
|
||||
|
||||
|
||||
def statistic_amplitude(src_wav_dir):
|
||||
"""
|
||||
Returns the amplitude info of the wav file.
|
||||
"""
|
||||
wav_lst = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
with ProcessPoolExecutor(max_workers=8) as executor, tqdm(
|
||||
total=len(wav_lst)) as progress:
|
||||
futures = []
|
||||
for wav_file_path in wav_lst:
|
||||
future = executor.submit(amp_info, wav_file_path)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append(future)
|
||||
|
||||
amp_info_lst = [future.result() for future in futures]
|
||||
|
||||
amp_info_lst = sorted(amp_info_lst, key=lambda x: x['amp_rms'])
|
||||
|
||||
logging.info('Average amplitude RMS : {}'.format(
|
||||
np.mean([x['amp_rms'] for x in amp_info_lst])))
|
||||
|
||||
return amp_info_lst
|
||||
|
||||
|
||||
def volume_normalize(src_wav_dir, out_wav_dir):
|
||||
logging.info('Volume statistic proceeding...')
|
||||
amp_info_lst = statistic_amplitude(src_wav_dir)
|
||||
logging.info('Volume statistic done.')
|
||||
|
||||
rms_amp_lst = [x['amp_rms'] for x in amp_info_lst]
|
||||
src_hist, src_bins = np.histogram(
|
||||
rms_amp_lst, bins=hist_bins, density=True)
|
||||
src_hist = src_hist / np.sum(src_hist)
|
||||
src_hist = np.cumsum(src_hist)
|
||||
src_hist = np.insert(src_hist, 0, 0.0)
|
||||
|
||||
logging.info('Volume normalization proceeding...')
|
||||
for amp_info in tqdm(amp_info_lst):
|
||||
rms_amp = amp_info['amp_rms']
|
||||
rms_amp = np.clip(rms_amp, src_bins[0], src_bins[-1])
|
||||
|
||||
src_idx = np.where(rms_amp >= src_bins)[0][-1]
|
||||
src_pos = src_hist[src_idx]
|
||||
anchor_idx = np.where(src_pos >= anchor_hist)[0][-1]
|
||||
|
||||
if src_idx == hist_bins or anchor_idx == hist_bins:
|
||||
rms_amp = anchor_bins[-1]
|
||||
else:
|
||||
rms_amp = (rms_amp - src_bins[src_idx]) / (
|
||||
src_bins[src_idx + 1] - src_bins[src_idx]) * (
|
||||
anchor_bins[anchor_idx + 1]
|
||||
- anchor_bins[anchor_idx]) + anchor_bins[anchor_idx]
|
||||
|
||||
scale = rms_amp / amp_info['amp_rms']
|
||||
|
||||
        # FIXME: This is a hack to avoid sound clipping.
|
||||
sr, data = wavfile.read(
|
||||
os.path.join(src_wav_dir, amp_info['basename']))
|
||||
wavfile.write(
|
||||
os.path.join(out_wav_dir, amp_info['basename']),
|
||||
sr,
|
||||
(data * scale).astype(np.int16),
|
||||
)
|
||||
logging.info('Volume normalization done.')
|
||||
|
||||
return True
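

# Note on the approach above: the per-file RMS amplitudes reported by sox are
# histogram-matched to the fixed anchor distribution (anchor_hist/anchor_bins)
# through their cumulative histograms, and each wav is then rescaled by
# target_rms / current_rms before being written back as 16-bit PCM.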
|
||||
|
||||
|
||||
def interp_f0(f0_data):
|
||||
"""
|
||||
linear interpolation
|
||||
"""
|
||||
f0_data[f0_data < 1] = 0
|
||||
xp = np.nonzero(f0_data)
|
||||
yp = f0_data[xp]
|
||||
x = np.arange(f0_data.size)
|
||||
contour_f0 = np.interp(x, xp[0], yp).astype(np.float32)
|
||||
return contour_f0
|
||||
|
||||
|
||||
def frame_nccf(x, y):
|
||||
norm_coef = (np.sum(x**2.0) * np.sum(y**2.0) + 1e-30)**0.5
|
||||
return (np.sum(x * y) / norm_coef + 1.0) / 2.0
|
||||
|
||||
|
||||
def get_nccf(pcm_data, f0, min_f0=40, max_f0=800, fs=160, sr=16000):
|
||||
if pcm_data.dtype == np.int16:
|
||||
pcm_data = pcm_data.astype(np.float32) / 32768
|
||||
frame_len = int(sr / 200)
|
||||
frame_num = int(len(pcm_data) // fs)
|
||||
frame_num = min(frame_num, len(f0))
|
||||
|
||||
pad_len = int(sr / min_f0) + frame_len
|
||||
|
||||
pad_zeros = np.zeros([pad_len], dtype=np.float32)
|
||||
data = np.hstack((pad_zeros, pcm_data.astype(np.float32), pad_zeros))
|
||||
|
||||
nccf = np.zeros((frame_num), dtype=np.float32)
|
||||
|
||||
for i in range(frame_num):
|
||||
curr_f0 = np.clip(f0[i], min_f0, max_f0)
|
||||
lag = int(sr / curr_f0 + 0.5)
|
||||
j = i * fs + pad_len - frame_len // 2
|
||||
|
||||
l_data = data[j:j + frame_len]
|
||||
l_data -= l_data.mean()
|
||||
|
||||
r_data = data[j + lag:j + lag + frame_len]
|
||||
r_data -= r_data.mean()
|
||||
|
||||
nccf[i] = frame_nccf(l_data, r_data)
|
||||
|
||||
return nccf
|
||||
|
||||
|
||||
def smooth(data, win_len):
|
||||
if win_len % 2 == 0:
|
||||
win_len += 1
|
||||
hwin = win_len // 2
|
||||
win = np.hanning(win_len)
|
||||
win /= win.sum()
|
||||
data = data.reshape([-1])
|
||||
pad_data = np.pad(data, hwin, mode='edge')
|
||||
|
||||
for i in range(data.shape[0]):
|
||||
data[i] = np.dot(win, pad_data[i:i + win_len])
|
||||
|
||||
return data.reshape([-1, 1])
|
||||
|
||||
|
||||
# support: rapt, swipe
|
||||
# unsupport: reaper, world(DIO)
|
||||
def RAPT_FUNC(v1, v2, v3, v4, v5):
|
||||
return pysptk.sptk.rapt(
|
||||
v1.astype(np.float32), fs=v2, hopsize=v3, min=v4, max=v5)
|
||||
|
||||
|
||||
def SWIPE_FUNC(v1, v2, v3, v4, v5):
|
||||
return pysptk.sptk.swipe(
|
||||
v1.astype(np.float64), fs=v2, hopsize=v3, min=v4, max=v5)
|
||||
|
||||
|
||||
def PYIN_FUNC(v1, v2, v3, v4, v5):
|
||||
f0_mel = librosa.pyin(
|
||||
v1.astype(np.float32), sr=v2, frame_length=v3 * 4, fmin=v4, fmax=v5)[0]
|
||||
f0_mel = np.where(np.isnan(f0_mel), 0.0, f0_mel)
|
||||
return f0_mel
|
||||
|
||||
|
||||
def get_pitch(pcm_data, sampling_rate=16000, hop_length=160):
|
||||
log_f0_list = []
|
||||
uv_list = []
|
||||
low, high = 40, 800
|
||||
|
||||
cali_f0 = pysptk.sptk.rapt(
|
||||
pcm_data.astype(np.float32),
|
||||
fs=sampling_rate,
|
||||
hopsize=hop_length,
|
||||
min=low,
|
||||
max=high,
|
||||
)
|
||||
f0_range = np.sort(np.unique(cali_f0))
|
||||
|
||||
if len(f0_range) > 20:
|
||||
low = max(f0_range[10] - 50, low)
|
||||
high = min(f0_range[-10] + 50, high)
|
||||
|
||||
func_dict = {'rapt': RAPT_FUNC, 'swipe': SWIPE_FUNC}
|
||||
|
||||
for func_name in func_dict:
|
||||
f0 = func_dict[func_name](pcm_data, sampling_rate, hop_length, low,
|
||||
high)
|
||||
uv = f0 > 0
|
||||
|
||||
if len(f0) < 10 or f0.max() < low:
|
||||
            logging.error(
                '{} method: calculated F0 is too low.'.format(func_name))
|
||||
continue
|
||||
else:
|
||||
f0 = np.clip(f0, 1e-30, high)
|
||||
log_f0 = np.log(f0)
|
||||
contour_log_f0 = interp_f0(log_f0)
|
||||
|
||||
log_f0_list.append(contour_log_f0)
|
||||
uv_list.append(uv)
|
||||
|
||||
if len(log_f0_list) == 0:
|
||||
logging.error('F0 estimation failed.')
|
||||
return None
|
||||
|
||||
min_len = float('inf')
|
||||
for log_f0 in log_f0_list:
|
||||
min_len = min(min_len, log_f0.shape[0])
|
||||
|
||||
multi_log_f0 = np.zeros([len(log_f0_list), min_len], dtype=np.float32)
|
||||
multi_uv = np.zeros([len(log_f0_list), min_len], dtype=np.float32)
|
||||
|
||||
for i in range(len(log_f0_list)):
|
||||
multi_log_f0[i, :] = log_f0_list[i][:min_len]
|
||||
multi_uv[i, :] = uv_list[i][:min_len]
|
||||
|
||||
log_f0 = smooth(np.median(multi_log_f0, axis=0), 5)
|
||||
uv = (smooth(np.median(multi_uv, axis=0), 5) > 0.5).astype(np.float32)
|
||||
|
||||
f0 = np.exp(log_f0)
|
||||
|
||||
min_len = min(f0.shape[0], uv.shape[0])
|
||||
|
||||
return f0[:min_len], uv[:min_len], f0[:min_len] * uv[:min_len]
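

# Note on get_pitch: an initial RAPT pass narrows the F0 search range, RAPT
# and SWIPE contours are then computed with pysptk, their interpolated log-F0
# tracks and voiced/unvoiced decisions are median-fused and Hanning-smoothed,
# and the function returns (continuous f0, uv flag, f0 * uv) per frame.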
|
||||
|
||||
|
||||
def get_energy(pcm_data, hop_length, win_length, n_fft):
|
||||
D = _stft(pcm_data, hop_length, win_length, n_fft)
|
||||
S, _ = librosa.magphase(D)
|
||||
energy = np.sqrt(np.sum(S**2, axis=0))
|
||||
|
||||
return energy.reshape((-1, 1))
|
||||
|
||||
|
||||
def align_length(in_data, tgt_data, basename=None):
|
||||
if in_data is None or tgt_data is None:
|
||||
logging.error('{}: Input data is None.'.format(basename))
|
||||
return None
|
||||
in_len = in_data.shape[0]
|
||||
tgt_len = tgt_data.shape[0]
|
||||
if abs(in_len - tgt_len) > 20:
|
||||
logging.error(
|
||||
            '{}: Input data length differs from the target length by too much.'
|
||||
.format(basename))
|
||||
return None
|
||||
|
||||
if in_len < tgt_len:
|
||||
out_data = np.pad(
|
||||
in_data, ((0, tgt_len - in_len), (0, 0)),
|
||||
'constant',
|
||||
constant_values=0.0)
|
||||
else:
|
||||
out_data = in_data[:tgt_len]
|
||||
|
||||
return out_data
|
||||
|
||||
|
||||
def compute_mean(data_list, dims=80):
|
||||
mean_vector = np.zeros((1, dims))
|
||||
all_frame_number = 0
|
||||
|
||||
for data in tqdm(data_list):
|
||||
if data is None:
|
||||
continue
|
||||
features = data.reshape((-1, dims))
|
||||
current_frame_number = np.shape(features)[0]
|
||||
mean_vector += np.sum(features[:, :], axis=0)
|
||||
all_frame_number += current_frame_number
|
||||
|
||||
mean_vector /= float(all_frame_number)
|
||||
return mean_vector
|
||||
|
||||
|
||||
def compute_std(data_list, mean_vector, dims=80):
|
||||
std_vector = np.zeros((1, dims))
|
||||
all_frame_number = 0
|
||||
|
||||
for data in tqdm(data_list):
|
||||
if data is None:
|
||||
continue
|
||||
features = data.reshape((-1, dims))
|
||||
current_frame_number = np.shape(features)[0]
|
||||
mean_matrix = np.tile(mean_vector, (current_frame_number, 1))
|
||||
std_vector += np.sum((features[:, :] - mean_matrix)**2, axis=0)
|
||||
all_frame_number += current_frame_number
|
||||
|
||||
std_vector /= float(all_frame_number)
|
||||
std_vector = std_vector**0.5
|
||||
return std_vector
|
||||
|
||||
|
||||
F0_MIN = 0.0
|
||||
F0_MAX = 800.0
|
||||
|
||||
ENERGY_MIN = 0.0
|
||||
ENERGY_MAX = 200.0
|
||||
|
||||
CLIP_FLOOR = 1e-3
|
||||
|
||||
|
||||
def f0_norm_min_max(f0):
|
||||
zero_idxs = np.where(f0 <= CLIP_FLOOR)[0]
|
||||
res = (2 * f0 - F0_MIN - F0_MAX) / (F0_MAX - F0_MIN)
|
||||
res[zero_idxs] = 0.0
|
||||
return res
|
||||
|
||||
|
||||
def f0_denorm_min_max(f0):
|
||||
zero_idxs = np.where(f0 == 0.0)[0]
|
||||
res = (f0 * (F0_MAX - F0_MIN) + F0_MIN + F0_MAX) / 2
|
||||
res[zero_idxs] = 0.0
|
||||
return res
|
||||
|
||||
|
||||
def energy_norm_min_max(energy):
|
||||
zero_idxs = np.where(energy == 0.0)[0]
|
||||
res = (2 * energy - ENERGY_MIN - ENERGY_MAX) / (ENERGY_MAX - ENERGY_MIN)
|
||||
res[zero_idxs] = 0.0
|
||||
return res
|
||||
|
||||
|
||||
def energy_denorm_min_max(energy):
|
||||
zero_idxs = np.where(energy == 0.0)[0]
|
||||
res = (energy * (ENERGY_MAX - ENERGY_MIN) + ENERGY_MIN + ENERGY_MAX) / 2
|
||||
res[zero_idxs] = 0.0
|
||||
return res
|
||||
|
||||
|
||||
def norm_log(x):
|
||||
zero_idxs = np.where(x <= CLIP_FLOOR)[0]
|
||||
x[zero_idxs] = 1.0
|
||||
res = np.log(x)
|
||||
return res
|
||||
|
||||
|
||||
def denorm_log(x):
|
||||
zero_idxs = np.where(x == 0.0)[0]
|
||||
res = np.exp(x)
|
||||
res[zero_idxs] = 0.0
|
||||
return res
|
||||
|
||||
|
||||
def f0_norm_mean_std(x, mean, std):
|
||||
zero_idxs = np.where(x == 0.0)[0]
|
||||
x = (x - mean) / std
|
||||
x[zero_idxs] = 0.0
|
||||
return x
|
||||
|
||||
|
||||
def norm_mean_std(x, mean, std):
|
||||
x = (x - mean) / std
|
||||
return x
|
||||
|
||||
|
||||
def parse_interval_file(file_path, sampling_rate, hop_length):
|
||||
with open(file_path, 'r') as f:
|
||||
lines = f.readlines()
|
||||
# second
|
||||
frame_intervals = 1.0 * hop_length / sampling_rate
|
||||
skip_lines = 12
|
||||
dur_list = []
|
||||
phone_list = []
|
||||
|
||||
line_index = skip_lines
|
||||
|
||||
while line_index < len(lines):
|
||||
phone_begin = float(lines[line_index])
|
||||
phone_end = float(lines[line_index + 1])
|
||||
phone = lines[line_index + 2].strip()[1:-1]
|
||||
dur_list.append(
|
||||
int(round((phone_end - phone_begin) / frame_intervals)))
|
||||
phone_list.append(phone)
|
||||
line_index += 3
|
||||
|
||||
if len(dur_list) == 0 or len(phone_list) == 0:
|
||||
return None
|
||||
|
||||
return np.array(dur_list), phone_list
|
||||
|
||||
|
||||
def average_by_duration(x, durs):
|
||||
if x is None or durs is None:
|
||||
return None
|
||||
durs_cum = np.cumsum(np.pad(durs, (1, 0), 'constant'))
|
||||
|
||||
# average over each symbol's duration
|
||||
x_symbol = np.zeros((durs.shape[0], ), dtype=np.float32)
|
||||
for idx, start, end in zip(
|
||||
range(durs.shape[0]), durs_cum[:-1], durs_cum[1:]):
|
||||
values = x[start:end][np.where(x[start:end] != 0.0)[0]]
|
||||
x_symbol[idx] = np.mean(values) if len(values) > 0 else 0.0
|
||||
|
||||
return x_symbol.astype(np.float32)
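

# Usage sketch (toy values, not part of the original module): frame-level
# values are averaged per phone according to each phone's duration, with
# zero frames excluded from the mean.
if __name__ == '__main__':
    _frame_f0 = np.array([100., 110., 0., 200., 210., 220.], dtype=np.float32)
    _durs = np.array([3, 3])
    print(average_by_duration(_frame_f0, _durs))  # [105. 210.]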
|
||||
|
||||
|
||||
def encode_16bits(x):
|
||||
if x.min() > -1.0 and x.max() < 1.0:
|
||||
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
|
||||
else:
|
||||
return x
|
||||
@@ -1,186 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import argparse
|
||||
import codecs
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import yaml
|
||||
|
||||
from modelscope import __version__
|
||||
from modelscope.models.audio.tts.kantts.datasets.dataset import (AmDataset,
|
||||
VocDataset)
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .audio_processor.audio_processor import AudioProcessor
|
||||
from .fp_processor import FpProcessor, is_fp_line
|
||||
from .languages import languages
|
||||
from .script_convertor.text_script_convertor import TextScriptConvertor
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(
|
||||
os.path.abspath(__file__))) # NOQA: E402
|
||||
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
def gen_metafile(
|
||||
voice_output_dir,
|
||||
fp_enable=False,
|
||||
badlist=None,
|
||||
split_ratio=0.98,
|
||||
):
|
||||
|
||||
voc_train_meta = os.path.join(voice_output_dir, 'train.lst')
|
||||
voc_valid_meta = os.path.join(voice_output_dir, 'valid.lst')
|
||||
if not os.path.exists(voc_train_meta) or not os.path.exists(
|
||||
voc_valid_meta):
|
||||
VocDataset.gen_metafile(
|
||||
os.path.join(voice_output_dir, 'wav'),
|
||||
voice_output_dir,
|
||||
split_ratio,
|
||||
)
|
||||
logging.info('Voc metafile generated.')
|
||||
|
||||
raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt')
|
||||
am_train_meta = os.path.join(voice_output_dir, 'am_train.lst')
|
||||
am_valid_meta = os.path.join(voice_output_dir, 'am_valid.lst')
|
||||
if not os.path.exists(am_train_meta) or not os.path.exists(am_valid_meta):
|
||||
AmDataset.gen_metafile(
|
||||
raw_metafile,
|
||||
voice_output_dir,
|
||||
am_train_meta,
|
||||
am_valid_meta,
|
||||
badlist,
|
||||
split_ratio,
|
||||
)
|
||||
logging.info('AM metafile generated.')
|
||||
|
||||
if fp_enable:
|
||||
fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt')
|
||||
am_train_meta = os.path.join(voice_output_dir, 'am_fpadd_train.lst')
|
||||
am_valid_meta = os.path.join(voice_output_dir, 'am_fpadd_valid.lst')
|
||||
if not os.path.exists(am_train_meta) or not os.path.exists(
|
||||
am_valid_meta):
|
||||
AmDataset.gen_metafile(
|
||||
fpadd_metafile,
|
||||
voice_output_dir,
|
||||
am_train_meta,
|
||||
am_valid_meta,
|
||||
badlist,
|
||||
split_ratio,
|
||||
)
|
||||
            logging.info('AM fpadd metafile generated.')
|
||||
|
||||
fprm_metafile = os.path.join(voice_output_dir, 'fprm_metafile.txt')
|
||||
am_train_meta = os.path.join(voice_output_dir, 'am_fprm_train.lst')
|
||||
am_valid_meta = os.path.join(voice_output_dir, 'am_fprm_valid.lst')
|
||||
if not os.path.exists(am_train_meta) or not os.path.exists(
|
||||
am_valid_meta):
|
||||
AmDataset.gen_metafile(
|
||||
fprm_metafile,
|
||||
voice_output_dir,
|
||||
am_train_meta,
|
||||
am_valid_meta,
|
||||
badlist,
|
||||
split_ratio,
|
||||
)
|
||||
            logging.info('AM fprm metafile generated.')
|
||||
|
||||
|
||||
def process_data(
|
||||
voice_input_dir,
|
||||
voice_output_dir,
|
||||
language_dir,
|
||||
audio_config,
|
||||
speaker_name=None,
|
||||
targetLang='PinYin',
|
||||
skip_script=False,
|
||||
):
|
||||
foreignLang = 'EnUS'
|
||||
emo_tag_path = None
|
||||
|
||||
phoneset_path = os.path.join(language_dir, targetLang,
|
||||
languages[targetLang]['phoneset_path'])
|
||||
posset_path = os.path.join(language_dir, targetLang,
|
||||
languages[targetLang]['posset_path'])
|
||||
f2t_map_path = os.path.join(language_dir, targetLang,
|
||||
languages[targetLang]['f2t_map_path'])
|
||||
s2p_map_path = os.path.join(language_dir, targetLang,
|
||||
languages[targetLang]['s2p_map_path'])
|
||||
|
||||
logging.info(f'phoneset_path={phoneset_path}')
|
||||
# dir of plain text/sentences for training byte based model
|
||||
plain_text_dir = os.path.join(voice_input_dir, 'text')
|
||||
|
||||
if speaker_name is None:
|
||||
speaker_name = os.path.basename(voice_input_dir)
|
||||
|
||||
if audio_config is not None:
|
||||
with open(audio_config, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
config['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S',
|
||||
time.localtime())
|
||||
config['modelscope_version'] = __version__
|
||||
|
||||
with open(os.path.join(voice_output_dir, 'audio_config.yaml'), 'w') as f:
|
||||
yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
|
||||
|
||||
if skip_script:
|
||||
logging.info('Skip script conversion')
|
||||
raw_metafile = None
|
||||
# Script processor
|
||||
if not skip_script:
|
||||
if os.path.exists(plain_text_dir):
|
||||
TextScriptConvertor.turn_text_into_bytes(
|
||||
os.path.join(plain_text_dir, 'text.txt'),
|
||||
os.path.join(voice_output_dir, 'raw_metafile.txt'),
|
||||
speaker_name,
|
||||
)
|
||||
fp_enable = False
|
||||
else:
|
||||
tsc = TextScriptConvertor(
|
||||
phoneset_path,
|
||||
posset_path,
|
||||
targetLang,
|
||||
foreignLang,
|
||||
f2t_map_path,
|
||||
s2p_map_path,
|
||||
emo_tag_path,
|
||||
speaker_name,
|
||||
)
|
||||
tsc.process(
|
||||
os.path.join(voice_input_dir, 'prosody', 'prosody.txt'),
|
||||
os.path.join(voice_output_dir, 'Script.xml'),
|
||||
os.path.join(voice_output_dir, 'raw_metafile.txt'),
|
||||
)
|
||||
prosody = os.path.join(voice_input_dir, 'prosody', 'prosody.txt')
|
||||
# FP processor
|
||||
with codecs.open(prosody, 'r', 'utf-8') as f:
|
||||
lines = f.readlines()
|
||||
fp_enable = is_fp_line(lines[1])
|
||||
raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt')
|
||||
|
||||
if fp_enable:
|
||||
FP = FpProcessor()
|
||||
|
||||
FP.process(
|
||||
voice_output_dir,
|
||||
prosody,
|
||||
raw_metafile,
|
||||
)
|
||||
logging.info('Processing fp done.')
|
||||
|
||||
# Audio processor
|
||||
ap = AudioProcessor(config['audio_config'])
|
||||
ap.process(
|
||||
voice_input_dir,
|
||||
voice_output_dir,
|
||||
raw_metafile,
|
||||
)
|
||||
|
||||
logging.info('Processing done.')
|
||||
|
||||
# Generate Voc&AM metafile
|
||||
gen_metafile(voice_output_dir, fp_enable, ap.badcase_list)
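

# Pipeline summary (descriptive note): process_data converts the prosody
# script (or plain text for byte-based models) into a phone-level metafile,
# optionally derives FP-added/FP-removed variants, runs AudioProcessor to
# extract mel/f0/energy/duration features, and finally writes the Voc/AM
# train and valid metafiles via gen_metafile.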
|
||||
@@ -1,156 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import random
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
def is_fp_line(line):
|
||||
fp_category_list = ['FP', 'I', 'N', 'Q']
|
||||
elements = line.strip().split(' ')
|
||||
res = True
|
||||
for ele in elements:
|
||||
if ele not in fp_category_list:
|
||||
res = False
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
class FpProcessor:
|
||||
|
||||
def __init__(self):
|
||||
# TODO: Add more audio processing methods.
|
||||
self.res = []
|
||||
|
||||
def is_fp_line(line):
|
||||
fp_category_list = ['FP', 'I', 'N', 'Q']
|
||||
elements = line.strip().split(' ')
|
||||
res = True
|
||||
for ele in elements:
|
||||
if ele not in fp_category_list:
|
||||
res = False
|
||||
break
|
||||
return res
|
||||
|
||||
# TODO: adjust idx judgment rule
|
||||
def addfp(self, voice_output_dir, prosody, raw_metafile_lines):
|
||||
|
||||
fp_category_list = ['FP', 'I', 'N']
|
||||
|
||||
f = open(prosody)
|
||||
prosody_lines = f.readlines()
|
||||
f.close()
|
||||
|
||||
idx = ''
|
||||
fp = ''
|
||||
fp_label_dict = {}
|
||||
i = 0
|
||||
while i < len(prosody_lines):
|
||||
if len(prosody_lines[i].strip().split('\t')) == 2:
|
||||
idx = prosody_lines[i].strip().split('\t')[0]
|
||||
i += 1
|
||||
else:
|
||||
fp_enable = is_fp_line(prosody_lines[i])
|
||||
if fp_enable:
|
||||
fp = prosody_lines[i].strip().split('\t')[0].split(' ')
|
||||
for label in fp:
|
||||
if label not in fp_category_list:
|
||||
logging.warning('fp label not in fp_category_list')
|
||||
break
|
||||
i += 4
|
||||
else:
|
||||
fp = [
|
||||
'N' for _ in range(
|
||||
len(prosody_lines[i].strip().split('\t')
|
||||
[0].replace('/ ', '').replace('. ', '').split(
|
||||
' ')))
|
||||
]
|
||||
i += 1
|
||||
fp_label_dict[idx] = fp
|
||||
|
||||
fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt')
|
||||
f_out = open(fpadd_metafile, 'w')
|
||||
for line in raw_metafile_lines:
|
||||
tokens = line.strip().split('\t')
|
||||
if len(tokens) == 2:
|
||||
uttname = tokens[0]
|
||||
symbol_sequences = tokens[1].split(' ')
|
||||
|
||||
error_flag = False
|
||||
idx = 0
|
||||
out_str = uttname + '\t'
|
||||
|
||||
for this_symbol_sequence in symbol_sequences:
|
||||
emotion = this_symbol_sequence.split('$')[4]
|
||||
this_symbol_sequence = this_symbol_sequence.replace(
|
||||
emotion, 'emotion_neutral')
|
||||
|
||||
if idx < len(fp_label_dict[uttname]):
|
||||
if fp_label_dict[uttname][idx] == 'FP':
|
||||
if 'none' not in this_symbol_sequence:
|
||||
this_symbol_sequence = this_symbol_sequence.replace(
|
||||
'emotion_neutral', 'emotion_disgust')
|
||||
syllable_label = this_symbol_sequence.split('$')[2]
|
||||
if syllable_label == 's_both' or syllable_label == 's_end':
|
||||
idx += 1
|
||||
elif idx > len(fp_label_dict[uttname]):
|
||||
logging.warning(uttname + ' not match')
|
||||
error_flag = True
|
||||
out_str = out_str + this_symbol_sequence + ' '
|
||||
|
||||
# if idx != len(fp_label_dict[uttname]):
|
||||
# logging.warning(
|
||||
# "{} length mismatch, length: {} ".format(
|
||||
# idx, len(fp_label_dict[uttname])
|
||||
# )
|
||||
# )
|
||||
|
||||
if not error_flag:
|
||||
f_out.write(out_str.strip() + '\n')
|
||||
f_out.close()
|
||||
return fpadd_metafile
|
||||
|
||||
def removefp(self, voice_output_dir, fpadd_metafile, raw_metafile_lines):
|
||||
|
||||
f = open(fpadd_metafile)
|
||||
fpadd_metafile_lines = f.readlines()
|
||||
f.close()
|
||||
|
||||
fprm_metafile = os.path.join(voice_output_dir, 'fprm_metafile.txt')
|
||||
f_out = open(fprm_metafile, 'w')
|
||||
for i in range(len(raw_metafile_lines)):
|
||||
tokens = raw_metafile_lines[i].strip().split('\t')
|
||||
symbol_sequences = tokens[1].split(' ')
|
||||
fpadd_tokens = fpadd_metafile_lines[i].strip().split('\t')
|
||||
fpadd_symbol_sequences = fpadd_tokens[1].split(' ')
|
||||
|
||||
error_flag = False
|
||||
out_str = tokens[0] + '\t'
|
||||
idx = 0
|
||||
length = len(symbol_sequences)
|
||||
while idx < length:
|
||||
if '$emotion_disgust' in fpadd_symbol_sequences[idx]:
|
||||
if idx + 1 < length and 'none' in fpadd_symbol_sequences[
|
||||
idx + 1]:
|
||||
idx = idx + 2
|
||||
else:
|
||||
idx = idx + 1
|
||||
continue
|
||||
out_str = out_str + symbol_sequences[idx] + ' '
|
||||
idx = idx + 1
|
||||
|
||||
if not error_flag:
|
||||
f_out.write(out_str.strip() + '\n')
|
||||
f_out.close()
|
||||
|
||||
def process(self, voice_output_dir, prosody, raw_metafile):
|
||||
|
||||
with open(raw_metafile, 'r') as f:
|
||||
lines = f.readlines()
|
||||
random.shuffle(lines)
|
||||
|
||||
fpadd_metafile = self.addfp(voice_output_dir, prosody, lines)
|
||||
self.removefp(voice_output_dir, fpadd_metafile, lines)
|
||||
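Note: the FpProcessor above relies on is_fp_line to decide whether a prosody annotation line carries only filled-pause labels (FP/I/N/Q) or ordinary text/pronunciation content. A minimal, self-contained sketch of that check; the sample strings are illustrative, not taken from a real prosody file:

# Sketch of the filled-pause line check used by FpProcessor.
def is_fp_line(line):
    fp_category_list = ['FP', 'I', 'N', 'Q']
    return all(ele in fp_category_list for ele in line.strip().split(' '))

print(is_fp_line('N N FP N'))     # True: every token is an FP-category label
print(is_fp_line('ni3 hao3 #2'))  # False: pronunciation tokens are not labels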
@@ -1,46 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
languages = {
|
||||
'PinYin': {
|
||||
'phoneset_path': 'PhoneSet.xml',
|
||||
'posset_path': 'PosSet.xml',
|
||||
'f2t_map_path': 'En2ChPhoneMap.txt',
|
||||
's2p_map_path': 'py2phoneMap.txt',
|
||||
'tonelist_path': 'tonelist.txt',
|
||||
},
|
||||
'ZhHK': {
|
||||
'phoneset_path': 'PhoneSet.xml',
|
||||
'posset_path': 'PosSet.xml',
|
||||
'f2t_map_path': 'En2ChPhoneMap.txt',
|
||||
's2p_map_path': 'py2phoneMap.txt',
|
||||
'tonelist_path': 'tonelist.txt',
|
||||
},
|
||||
'WuuShanghai': {
|
||||
'phoneset_path': 'PhoneSet.xml',
|
||||
'posset_path': 'PosSet.xml',
|
||||
'f2t_map_path': 'En2ChPhoneMap.txt',
|
||||
's2p_map_path': 'py2phoneMap.txt',
|
||||
'tonelist_path': 'tonelist.txt',
|
||||
},
|
||||
'Sichuan': {
|
||||
'phoneset_path': 'PhoneSet.xml',
|
||||
'posset_path': 'PosSet.xml',
|
||||
'f2t_map_path': 'En2ChPhoneMap.txt',
|
||||
's2p_map_path': 'py2phoneMap.txt',
|
||||
'tonelist_path': 'tonelist.txt',
|
||||
},
|
||||
'EnGB': {
|
||||
'phoneset_path': 'PhoneSet.xml',
|
||||
'posset_path': 'PosSet.xml',
|
||||
'f2t_map_path': '',
|
||||
's2p_map_path': '',
|
||||
'tonelist_path': 'tonelist.txt',
|
||||
},
|
||||
'EnUS': {
|
||||
'phoneset_path': 'PhoneSet.xml',
|
||||
'posset_path': 'PosSet.xml',
|
||||
'f2t_map_path': '',
|
||||
's2p_map_path': '',
|
||||
'tonelist_path': 'tonelist.txt',
|
||||
}
|
||||
}
|
||||
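The languages table above stores only relative resource filenames per language. A hedged sketch of how such an entry could be joined to a voice resource directory; the directory layout and the helper name resolve_paths are assumptions made for illustration:

import os

# Hypothetical helper: resolve a language entry against a resource directory.
def resolve_paths(entry, resource_dir):
    return {key: os.path.join(resource_dir, name) if name else ''
            for key, name in entry.items()}

pinyin_entry = {
    'phoneset_path': 'PhoneSet.xml',
    's2p_map_path': 'py2phoneMap.txt',
    'f2t_map_path': '',  # empty entries (as for EnGB/EnUS) stay empty
}
print(resolve_paths(pinyin_entry, '/path/to/voice/resources'))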
@@ -1,242 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Tone(Enum):
|
||||
UnAssigned = -1
|
||||
NoneTone = 0
|
||||
YinPing = 1 # ZhHK: YinPingYinRu EnUS: primary stress
|
||||
YangPing = 2 # ZhHK: YinShang EnUS: secondary stress
|
||||
ShangSheng = 3 # ZhHK: YinQuZhongRu
|
||||
QuSheng = 4 # ZhHK: YangPing
|
||||
QingSheng = 5 # ZhHK: YangShang
|
||||
YangQuYangRu = 6 # ZhHK: YangQuYangRu
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(Tone, cls).__new__(cls, in_str)
|
||||
|
||||
if in_str in ['UnAssigned', '-1']:
|
||||
return Tone.UnAssigned
|
||||
elif in_str in ['NoneTone', '0']:
|
||||
return Tone.NoneTone
|
||||
elif in_str in ['YinPing', '1']:
|
||||
return Tone.YinPing
|
||||
elif in_str in ['YangPing', '2']:
|
||||
return Tone.YangPing
|
||||
elif in_str in ['ShangSheng', '3']:
|
||||
return Tone.ShangSheng
|
||||
elif in_str in ['QuSheng', '4']:
|
||||
return Tone.QuSheng
|
||||
elif in_str in ['QingSheng', '5']:
|
||||
return Tone.QingSheng
|
||||
elif in_str in ['YangQuYangRu', '6']:
|
||||
return Tone.YangQuYangRu
|
||||
else:
|
||||
return Tone.NoneTone
|
||||
|
||||
|
||||
class BreakLevel(Enum):
|
||||
UnAssigned = -1
|
||||
L0 = 0
|
||||
L1 = 1
|
||||
L2 = 2
|
||||
L3 = 3
|
||||
L4 = 4
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(BreakLevel, cls).__new__(cls, in_str)
|
||||
|
||||
if in_str in ['UnAssigned', '-1']:
|
||||
return BreakLevel.UnAssigned
|
||||
elif in_str in ['L0', '0']:
|
||||
return BreakLevel.L0
|
||||
elif in_str in ['L1', '1']:
|
||||
return BreakLevel.L1
|
||||
elif in_str in ['L2', '2']:
|
||||
return BreakLevel.L2
|
||||
elif in_str in ['L3', '3']:
|
||||
return BreakLevel.L3
|
||||
elif in_str in ['L4', '4']:
|
||||
return BreakLevel.L4
|
||||
else:
|
||||
return BreakLevel.UnAssigned
|
||||
|
||||
|
||||
class SentencePurpose(Enum):
|
||||
Declarative = 0
|
||||
Interrogative = 1
|
||||
Exclamatory = 2
|
||||
Imperative = 3
|
||||
|
||||
|
||||
class Language(Enum):
|
||||
Neutral = 0
|
||||
EnUS = 1033
|
||||
EnGB = 2057
|
||||
ZhCN = 2052
|
||||
PinYin = 2053
|
||||
WuuShanghai = 2054
|
||||
Sichuan = 2055
|
||||
ZhHK = 3076
|
||||
ZhEn = ZhCN | EnUS
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(Language, cls).__new__(cls, in_str)
|
||||
|
||||
if in_str in ['Neutral', '0']:
|
||||
return Language.Neutral
|
||||
elif in_str in ['EnUS', '1033']:
|
||||
return Language.EnUS
|
||||
elif in_str in ['EnGB', '2057']:
|
||||
return Language.EnGB
|
||||
elif in_str in ['ZhCN', '2052']:
|
||||
return Language.ZhCN
|
||||
elif in_str in ['PinYin', '2053']:
|
||||
return Language.PinYin
|
||||
elif in_str in ['WuuShanghai', '2054']:
|
||||
return Language.WuuShanghai
|
||||
elif in_str in ['Sichuan', '2055']:
|
||||
return Language.Sichuan
|
||||
elif in_str in ['ZhHK', '3076']:
|
||||
return Language.ZhHK
|
||||
elif in_str in ['ZhEn', '2052|1033']:
|
||||
return Language.ZhEn
|
||||
else:
|
||||
return Language.Neutral
|
||||
|
||||
|
||||
"""
|
||||
Phone Types
|
||||
"""
|
||||
|
||||
|
||||
class PhoneCVType(Enum):
|
||||
NULL = -1
|
||||
Consonant = 1
|
||||
Vowel = 2
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(PhoneCVType, cls).__new__(cls, in_str)
|
||||
|
||||
if in_str in ['consonant', 'Consonant']:
|
||||
return PhoneCVType.Consonant
|
||||
elif in_str in ['vowel', 'Vowel']:
|
||||
return PhoneCVType.Vowel
|
||||
else:
|
||||
return PhoneCVType.NULL
|
||||
|
||||
|
||||
class PhoneIFType(Enum):
|
||||
NULL = -1
|
||||
Initial = 1
|
||||
Final = 2
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(PhoneIFType, cls).__new__(cls, in_str)
|
||||
if in_str in ['initial', 'Initial']:
|
||||
return PhoneIFType.Initial
|
||||
elif in_str in ['final', 'Final']:
|
||||
return PhoneIFType.Final
|
||||
else:
|
||||
return PhoneIFType.NULL
|
||||
|
||||
|
||||
class PhoneUVType(Enum):
|
||||
NULL = -1
|
||||
Voiced = 1
|
||||
UnVoiced = 2
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(PhoneUVType, cls).__new__(cls, in_str)
|
||||
if in_str in ['voiced', 'Voiced']:
|
||||
return PhoneUVType.Voiced
|
||||
elif in_str in ['unvoiced', 'UnVoiced']:
|
||||
return PhoneUVType.UnVoiced
|
||||
else:
|
||||
return PhoneUVType.NULL
|
||||
|
||||
|
||||
class PhoneAPType(Enum):
|
||||
NULL = -1
|
||||
DoubleLips = 1
|
||||
LipTooth = 2
|
||||
FrontTongue = 3
|
||||
CentralTongue = 4
|
||||
BackTongue = 5
|
||||
Dorsal = 6
|
||||
Velar = 7
|
||||
Low = 8
|
||||
Middle = 9
|
||||
High = 10
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(PhoneAPType, cls).__new__(cls, in_str)
|
||||
if in_str in ['doublelips', 'DoubleLips']:
|
||||
return PhoneAPType.DoubleLips
|
||||
elif in_str in ['liptooth', 'LipTooth']:
|
||||
return PhoneAPType.LipTooth
|
||||
elif in_str in ['fronttongue', 'FrontTongue']:
|
||||
return PhoneAPType.FrontTongue
|
||||
elif in_str in ['centraltongue', 'CentralTongue']:
|
||||
return PhoneAPType.CentralTongue
|
||||
elif in_str in ['backtongue', 'BackTongue']:
|
||||
return PhoneAPType.BackTongue
|
||||
elif in_str in ['dorsal', 'Dorsal']:
|
||||
return PhoneAPType.Dorsal
|
||||
elif in_str in ['velar', 'Velar']:
|
||||
return PhoneAPType.Velar
|
||||
elif in_str in ['low', 'Low']:
|
||||
return PhoneAPType.Low
|
||||
elif in_str in ['middle', 'Middle']:
|
||||
return PhoneAPType.Middle
|
||||
elif in_str in ['high', 'High']:
|
||||
return PhoneAPType.High
|
||||
else:
|
||||
return PhoneAPType.NULL
|
||||
|
||||
|
||||
class PhoneAMType(Enum):
|
||||
NULL = -1
|
||||
Stop = 1
|
||||
Affricate = 2
|
||||
Fricative = 3
|
||||
Nasal = 4
|
||||
Lateral = 5
|
||||
Open = 6
|
||||
Close = 7
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(PhoneAMType, cls).__new__(cls, in_str)
|
||||
if in_str in ['stop', 'Stop']:
|
||||
return PhoneAMType.Stop
|
||||
elif in_str in ['affricate', 'Affricate']:
|
||||
return PhoneAMType.Affricate
|
||||
elif in_str in ['fricative', 'Fricative']:
|
||||
return PhoneAMType.Fricative
|
||||
elif in_str in ['nasal', 'Nasal']:
|
||||
return PhoneAMType.Nasal
|
||||
elif in_str in ['lateral', 'Lateral']:
|
||||
return PhoneAMType.Lateral
|
||||
elif in_str in ['open', 'Open']:
|
||||
return PhoneAMType.Open
|
||||
elif in_str in ['close', 'Close']:
|
||||
return PhoneAMType.Close
|
||||
else:
|
||||
return PhoneAMType.NULL
|
||||
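The parse classmethods above accept either the symbolic member name or its numeric string form and fall back to a default member for unrecognized input. A short usage sketch; the import path core_types is an assumption for illustration:

from core_types import Tone, Language  # hypothetical import path

assert Tone.parse('ShangSheng') is Tone.ShangSheng
assert Tone.parse('3') is Tone.ShangSheng
assert Tone.parse('unknown') is Tone.NoneTone      # fallback for unknown input

assert Language.parse('PinYin') is Language.PinYin
assert Language.parse('2053') is Language.PinYin
assert Language.parse('xx') is Language.Neutral    # fallback for unknown input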
@@ -1,48 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .core_types import (PhoneAMType, PhoneAPType, PhoneCVType, PhoneIFType,
|
||||
PhoneUVType)
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class Phone(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_id = None
|
||||
self.m_name = None
|
||||
self.m_cv_type = PhoneCVType.NULL
|
||||
self.m_if_type = PhoneIFType.NULL
|
||||
self.m_uv_type = PhoneUVType.NULL
|
||||
self.m_ap_type = PhoneAPType.NULL
|
||||
self.m_am_type = PhoneAMType.NULL
|
||||
self.m_bnd = False
|
||||
|
||||
def __str__(self):
|
||||
return self.m_name
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def load(self, phone_node):
|
||||
ns = '{http://schemas.alibaba-inc.com/tts}'
|
||||
|
||||
id_node = phone_node.find(ns + 'id')
|
||||
self.m_id = int(id_node.text)
|
||||
|
||||
name_node = phone_node.find(ns + 'name')
|
||||
self.m_name = name_node.text
|
||||
|
||||
cv_node = phone_node.find(ns + 'cv')
|
||||
self.m_cv_type = PhoneCVType.parse(cv_node.text)
|
||||
|
||||
if_node = phone_node.find(ns + 'if')
|
||||
self.m_if_type = PhoneIFType.parse(if_node.text)
|
||||
|
||||
uv_node = phone_node.find(ns + 'uv')
|
||||
self.m_uv_type = PhoneUVType.parse(uv_node.text)
|
||||
|
||||
ap_node = phone_node.find(ns + 'ap')
|
||||
self.m_ap_type = PhoneAPType.parse(ap_node.text)
|
||||
|
||||
am_node = phone_node.find(ns + 'am')
|
||||
self.m_am_type = PhoneAMType.parse(am_node.text)
|
||||
@@ -1,39 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .phone import Phone
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
class PhoneSet(XmlObj):
|
||||
|
||||
def __init__(self, phoneset_path):
|
||||
self.m_phone_list = []
|
||||
self.m_id_map = {}
|
||||
self.m_name_map = {}
|
||||
self.load(phoneset_path)
|
||||
|
||||
def load(self, file_path):
|
||||
# alibaba tts xml namespace
|
||||
ns = '{http://schemas.alibaba-inc.com/tts}'
|
||||
|
||||
phoneset_root = ET.parse(file_path).getroot()
|
||||
for phone_node in phoneset_root.findall(ns + 'phone'):
|
||||
phone = Phone()
|
||||
phone.load(phone_node)
|
||||
self.m_phone_list.append(phone)
|
||||
if phone.m_id in self.m_id_map:
|
||||
logging.error('PhoneSet.Load: duplicate id: %d', phone.m_id)
|
||||
self.m_id_map[phone.m_id] = phone
|
||||
|
||||
if phone.m_name in self.m_name_map:
|
||||
logging.error('PhoneSet.Load: duplicate name: %s',
|
||||
phone.m_name)
|
||||
self.m_name_map[phone.m_name] = phone
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
@@ -1,43 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class Pos(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_id = None
|
||||
self.m_name = None
|
||||
self.m_desc = None
|
||||
self.m_level = 1
|
||||
self.m_parent = None
|
||||
self.m_sub_pos_list = []
|
||||
|
||||
def __str__(self):
|
||||
return self.m_name
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def load(self, pos_node):
|
||||
ns = '{http://schemas.alibaba-inc.com/tts}'
|
||||
|
||||
id_node = pos_node.find(ns + 'id')
|
||||
self.m_id = int(id_node.text)
|
||||
|
||||
name_node = pos_node.find(ns + 'name')
|
||||
self.m_name = name_node.text
|
||||
|
||||
desc_node = pos_node.find(ns + 'desc')
|
||||
self.m_desc = desc_node.text
|
||||
|
||||
sub_node = pos_node.find(ns + 'sub')
|
||||
if sub_node is not None:
|
||||
for sub_pos_node in sub_node.findall(ns + 'pos'):
|
||||
sub_pos = Pos()
|
||||
sub_pos.load(sub_pos_node)
|
||||
sub_pos.m_parent = self
|
||||
sub_pos.m_level = self.m_level + 1
|
||||
self.m_sub_pos_list.append(sub_pos)
|
||||
|
||||
return
|
||||
@@ -1,50 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from .pos import Pos
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class PosSet(XmlObj):
|
||||
|
||||
def __init__(self, posset_path):
|
||||
self.m_pos_list = []
|
||||
self.m_id_map = {}
|
||||
self.m_name_map = {}
|
||||
self.load(posset_path)
|
||||
|
||||
def load(self, file_path):
|
||||
# alibaba tts xml namespace
|
||||
ns = '{http://schemas.alibaba-inc.com/tts}'
|
||||
|
||||
posset_root = ET.parse(file_path).getroot()
|
||||
for pos_node in posset_root.findall(ns + 'pos'):
|
||||
pos = Pos()
|
||||
pos.load(pos_node)
|
||||
self.m_pos_list.append(pos)
|
||||
if pos.m_id in self.m_id_map:
|
||||
logging.error('PosSet.Load: duplicate id: %d', pos.m_id)
|
||||
self.m_id_map[pos.m_id] = pos
|
||||
|
||||
if pos.m_name in self.m_name_map:
|
||||
logging.error('PosSet.Load: duplicate name: %s',
|
||||
pos.m_name)
|
||||
self.m_name_map[pos.m_name] = pos
|
||||
|
||||
if len(pos.m_sub_pos_list) > 0:
|
||||
for sub_pos in pos.m_sub_pos_list:
|
||||
self.m_pos_list.append(sub_pos)
|
||||
if sub_pos.m_id in self.m_id_map:
|
||||
logging.error('PosSet.Load: duplicate id: %d',
|
||||
sub_pos.m_id)
|
||||
self.m_id_map[sub_pos.m_id] = sub_pos
|
||||
|
||||
if sub_pos.m_name in self.m_name_map:
|
||||
logging.error('PosSet.Load: duplicate name: %s',
|
||||
sub_pos.m_name)
|
||||
self.m_name_map[sub_pos.m_name] = sub_pos
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
@@ -1,35 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
from xml.dom import minidom
|
||||
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class Script(XmlObj):
|
||||
|
||||
def __init__(self, phoneset, posset):
|
||||
self.m_phoneset = phoneset
|
||||
self.m_posset = posset
|
||||
self.m_items = []
|
||||
|
||||
def save(self, outputXMLPath):
|
||||
root = ET.Element('script')
|
||||
|
||||
root.set('uttcount', str(len(self.m_items)))
|
||||
root.set('xmlns', 'http://schemas.alibaba-inc.com/tts')
|
||||
for item in self.m_items:
|
||||
item.save(root)
|
||||
|
||||
xmlstr = minidom.parseString(ET.tostring(root)).toprettyxml(
|
||||
indent=' ', encoding='utf-8')
|
||||
with open(outputXMLPath, 'wb') as f:
|
||||
f.write(xmlstr)
|
||||
|
||||
def save_meta_file(self):
|
||||
meta_lines = []
|
||||
|
||||
for item in self.m_items:
|
||||
meta_lines.append(item.save_metafile())
|
||||
|
||||
return meta_lines
|
||||
@@ -1,40 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class ScriptItem(XmlObj):
|
||||
|
||||
def __init__(self, phoneset, posset):
|
||||
if phoneset is None or posset is None:
|
||||
raise Exception('ScriptItem.__init__: phoneset or posset is None')
|
||||
self.m_phoneset = phoneset
|
||||
self.m_posset = posset
|
||||
|
||||
self.m_id = None
|
||||
self.m_text = ''
|
||||
self.m_scriptSentence_list = []
|
||||
self.m_status = None
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def save(self, parent_node):
|
||||
utterance_node = ET.SubElement(parent_node, 'utterance')
|
||||
utterance_node.set('id', self.m_id)
|
||||
|
||||
text_node = ET.SubElement(utterance_node, 'text')
|
||||
text_node.text = self.m_text
|
||||
|
||||
for sentence in self.m_scriptSentence_list:
|
||||
sentence.save(utterance_node)
|
||||
|
||||
def save_metafile(self):
|
||||
meta_line = self.m_id + '\t'
|
||||
|
||||
for sentence in self.m_scriptSentence_list:
|
||||
meta_line += sentence.save_metafile()
|
||||
|
||||
return meta_line
|
||||
@@ -1,185 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class WrittenSentence(XmlObj):
|
||||
|
||||
def __init__(self, posset):
|
||||
self.m_written_word_list = []
|
||||
self.m_written_mark_list = []
|
||||
self.m_posset = posset
|
||||
self.m_align_list = []
|
||||
self.m_alignCursor = 0
|
||||
self.m_accompanyIndex = 0
|
||||
self.m_sequence = ''
|
||||
self.m_text = ''
|
||||
|
||||
def add_host(self, writtenWord):
|
||||
self.m_written_word_list.append(writtenWord)
|
||||
self.m_align_list.append(self.m_alignCursor)
|
||||
|
||||
def load_host(self):
|
||||
pass
|
||||
|
||||
def save_host(self):
|
||||
pass
|
||||
|
||||
def add_accompany(self, writtenMark):
|
||||
self.m_written_mark_list.append(writtenMark)
|
||||
self.m_alignCursor += 1
|
||||
self.m_accompanyIndex += 1
|
||||
|
||||
def save_accompany(self):
|
||||
pass
|
||||
|
||||
def load_accompany(self):
|
||||
pass
|
||||
|
||||
# Get the mark span corresponding to a specific written word
|
||||
def get_accompany_span(self, host_index):
|
||||
if host_index == -1:
|
||||
return (0, self.m_align_list[0])
|
||||
|
||||
accompany_begin = self.m_align_list[host_index]
|
||||
accompany_end = (
|
||||
self.m_align_list[host_index + 1]
|
||||
if host_index + 1 < len(self.m_written_word_list) else len(
|
||||
self.m_written_mark_list))
|
||||
|
||||
return (accompany_begin, accompany_end)
|
||||
|
||||
def get_elements(self):
|
||||
accompany_begin, accompany_end = self.get_accompany_span(-1)
|
||||
res_lst = [
|
||||
self.m_written_mark_list[i]
|
||||
for i in range(accompany_begin, accompany_end)
|
||||
]
|
||||
|
||||
for j in range(len(self.m_written_word_list)):
|
||||
accompany_begin, accompany_end = self.get_accompany_span(j)
|
||||
res_lst.extend([self.m_written_word_list[j]])
|
||||
res_lst.extend([
|
||||
self.m_written_mark_list[i]
|
||||
for i in range(accompany_begin, accompany_end)
|
||||
])
|
||||
|
||||
return res_lst
|
||||
|
||||
def build_sequence(self):
|
||||
self.m_sequence = ' '.join([str(ele) for ele in self.get_elements()])
|
||||
|
||||
def build_text(self):
|
||||
self.m_text = ''.join([str(ele) for ele in self.get_elements()])
|
||||
|
||||
|
||||
class SpokenSentence(XmlObj):
|
||||
|
||||
def __init__(self, phoneset):
|
||||
self.m_spoken_word_list = []
|
||||
self.m_spoken_mark_list = []
|
||||
self.m_phoneset = phoneset
|
||||
self.m_align_list = []
|
||||
self.m_alignCursor = 0
|
||||
self.m_accompanyIndex = 0
|
||||
self.m_sequence = ''
|
||||
self.m_text = ''
|
||||
|
||||
def __len__(self):
|
||||
return len(self.m_spoken_word_list)
|
||||
|
||||
def add_host(self, spokenWord):
|
||||
self.m_spoken_word_list.append(spokenWord)
|
||||
self.m_align_list.append(self.m_alignCursor)
|
||||
|
||||
def save_host(self):
|
||||
pass
|
||||
|
||||
def load_host(self):
|
||||
pass
|
||||
|
||||
def add_accompany(self, spokenMark):
|
||||
self.m_spoken_mark_list.append(spokenMark)
|
||||
self.m_alignCursor += 1
|
||||
self.m_accompanyIndex += 1
|
||||
|
||||
def save_accompany(self):
|
||||
pass
|
||||
|
||||
# Get the mark span corresponding to a specific spoken word
|
||||
def get_accompany_span(self, host_index):
|
||||
if host_index == -1:
|
||||
return (0, self.m_align_list[0])
|
||||
|
||||
accompany_begin = self.m_align_list[host_index]
|
||||
accompany_end = (
|
||||
self.m_align_list[host_index + 1]
|
||||
if host_index + 1 < len(self.m_spoken_word_list) else len(
|
||||
self.m_spoken_mark_list))
|
||||
|
||||
return (accompany_begin, accompany_end)
|
||||
|
||||
def get_elements(self):
|
||||
accompany_begin, accompany_end = self.get_accompany_span(-1)
|
||||
res_lst = [
|
||||
self.m_spoken_mark_list[i]
|
||||
for i in range(accompany_begin, accompany_end)
|
||||
]
|
||||
|
||||
for j in range(len(self.m_spoken_word_list)):
|
||||
accompany_begin, accompany_end = self.get_accompany_span(j)
|
||||
res_lst.extend([self.m_spoken_word_list[j]])
|
||||
res_lst.extend([
|
||||
self.m_spoken_mark_list[i]
|
||||
for i in range(accompany_begin, accompany_end)
|
||||
])
|
||||
|
||||
return res_lst
|
||||
|
||||
def load_accompany(self):
|
||||
pass
|
||||
|
||||
def build_sequence(self):
|
||||
self.m_sequence = ' '.join([str(ele) for ele in self.get_elements()])
|
||||
|
||||
def build_text(self):
|
||||
self.m_text = ''.join([str(ele) for ele in self.get_elements()])
|
||||
|
||||
def save(self, parent_node):
|
||||
spoken_node = ET.SubElement(parent_node, 'spoken')
|
||||
spoken_node.set('wordcount', str(len(self.m_spoken_word_list)))
|
||||
|
||||
text_node = ET.SubElement(spoken_node, 'text')
|
||||
text_node.text = self.m_sequence
|
||||
|
||||
for word in self.m_spoken_word_list:
|
||||
word.save(spoken_node)
|
||||
|
||||
def save_metafile(self):
|
||||
meta_line_list = [
|
||||
word.save_metafile() for word in self.m_spoken_word_list
|
||||
]
|
||||
|
||||
return ' '.join(meta_line_list)
|
||||
|
||||
|
||||
class ScriptSentence(XmlObj):
|
||||
|
||||
def __init__(self, phoneset, posset):
|
||||
self.m_phoneset = phoneset
|
||||
self.m_posset = posset
|
||||
self.m_writtenSentence = WrittenSentence(posset)
|
||||
self.m_spokenSentence = SpokenSentence(phoneset)
|
||||
self.m_text = ''
|
||||
|
||||
def save(self, parent_node):
|
||||
if len(self.m_spokenSentence) > 0:
|
||||
self.m_spokenSentence.save(parent_node)
|
||||
|
||||
def save_metafile(self):
|
||||
if len(self.m_spokenSentence) > 0:
|
||||
return self.m_spokenSentence.save_metafile()
|
||||
else:
|
||||
return ''
|
||||
@@ -1,120 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from .core_types import Language
|
||||
from .syllable import SyllableList
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class WrittenWord(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_name = None
|
||||
self.m_POS = None
|
||||
|
||||
def __str__(self):
|
||||
return self.m_name
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
|
||||
class WrittenMark(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_punctuation = None
|
||||
|
||||
def __str__(self):
|
||||
return self.m_punctuation
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
|
||||
class SpokenWord(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_name = None
|
||||
self.m_language = None
|
||||
self.m_syllable_list = []
|
||||
self.m_breakText = '1'
|
||||
self.m_POS = '0'
|
||||
|
||||
def __str__(self):
|
||||
return self.m_name
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def save(self, parent_node):
|
||||
|
||||
word_node = ET.SubElement(parent_node, 'word')
|
||||
|
||||
name_node = ET.SubElement(word_node, 'name')
|
||||
name_node.text = self.m_name
|
||||
|
||||
if (len(self.m_syllable_list) > 0
|
||||
and self.m_syllable_list[0].m_language != Language.Neutral):
|
||||
language_node = ET.SubElement(word_node, 'lang')
|
||||
language_node.text = self.m_syllable_list[0].m_language.name
|
||||
|
||||
SyllableList(self.m_syllable_list).save(word_node)
|
||||
|
||||
break_node = ET.SubElement(word_node, 'break')
|
||||
break_node.text = self.m_breakText
|
||||
|
||||
POS_node = ET.SubElement(word_node, 'POS')
|
||||
POS_node.text = self.m_POS
|
||||
|
||||
return
|
||||
|
||||
def save_metafile(self):
|
||||
word_phone_cnt = sum(
|
||||
[syllable.phone_count() for syllable in self.m_syllable_list])
|
||||
word_syllable_cnt = len(self.m_syllable_list)
|
||||
single_syllable_word = word_syllable_cnt == 1
|
||||
meta_line_list = []
|
||||
|
||||
for idx, syll in enumerate(self.m_syllable_list):
|
||||
if word_phone_cnt == 1:
|
||||
word_pos = 'word_both'
|
||||
elif idx == 0:
|
||||
word_pos = 'word_begin'
|
||||
elif idx == len(self.m_syllable_list) - 1:
|
||||
word_pos = 'word_end'
|
||||
else:
|
||||
word_pos = 'word_middle'
|
||||
meta_line_list.append(
|
||||
syll.save_metafile(
|
||||
word_pos, single_syllable_word=single_syllable_word))
|
||||
|
||||
if self.m_breakText != '0' and self.m_breakText is not None:
|
||||
meta_line_list.append('{{#{}$tone_none$s_none$word_none}}'.format(
|
||||
self.m_breakText))
|
||||
|
||||
return ' '.join(meta_line_list)
|
||||
|
||||
|
||||
class SpokenMark(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_breakLevel = None
|
||||
|
||||
def break_level2text(self):
|
||||
return '#' + str(self.m_breakLevel.value)
|
||||
|
||||
def __str__(self):
|
||||
return self.break_level2text()
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
@@ -1,112 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class Syllable(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_phone_list = []
|
||||
self.m_tone = None
|
||||
self.m_language = None
|
||||
self.m_breaklevel = None
|
||||
|
||||
def pronunciation_text(self):
|
||||
return ' '.join([str(phone) for phone in self.m_phone_list])
|
||||
|
||||
def phone_count(self):
|
||||
return len(self.m_phone_list)
|
||||
|
||||
def tone_text(self):
|
||||
return str(self.m_tone.value)
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def get_phone_meta(self,
|
||||
phone_name,
|
||||
word_pos,
|
||||
syll_pos,
|
||||
tone_text,
|
||||
single_syllable_word=False):
|
||||
# Special case: word with single syllable, the last phone's word_pos should be "word_end"
|
||||
if word_pos == 'word_begin' and syll_pos == 's_end' and single_syllable_word:
|
||||
word_pos = 'word_end'
|
||||
elif word_pos == 'word_begin' and syll_pos not in [
|
||||
's_begin',
|
||||
's_both',
|
||||
]: # FIXME: keep consistent with the Engine logic
|
||||
word_pos = 'word_middle'
|
||||
elif word_pos == 'word_end' and syll_pos not in ['s_end', 's_both']:
|
||||
word_pos = 'word_middle'
|
||||
else:
|
||||
pass
|
||||
|
||||
return '{{{}$tone{}${}${}}}'.format(phone_name, tone_text, syll_pos,
|
||||
word_pos)
|
||||
|
||||
def save_metafile(self, word_pos, single_syllable_word=False):
|
||||
syllable_phone_cnt = len(self.m_phone_list)
|
||||
|
||||
meta_line_list = []
|
||||
|
||||
for idx, phone in enumerate(self.m_phone_list):
|
||||
if syllable_phone_cnt == 1:
|
||||
syll_pos = 's_both'
|
||||
elif idx == 0:
|
||||
syll_pos = 's_begin'
|
||||
elif idx == len(self.m_phone_list) - 1:
|
||||
syll_pos = 's_end'
|
||||
else:
|
||||
syll_pos = 's_middle'
|
||||
meta_line_list.append(
|
||||
self.get_phone_meta(
|
||||
phone,
|
||||
word_pos,
|
||||
syll_pos,
|
||||
self.tone_text(),
|
||||
single_syllable_word=single_syllable_word,
|
||||
))
|
||||
|
||||
return ' '.join(meta_line_list)
|
||||
|
||||
|
||||
class SyllableList(XmlObj):
|
||||
|
||||
def __init__(self, syllables):
|
||||
self.m_syllable_list = syllables
|
||||
|
||||
def __len__(self):
|
||||
return len(self.m_syllable_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.m_syllable_list[index]
|
||||
|
||||
def pronunciation_text(self):
|
||||
return ' - '.join([
|
||||
syllable.pronunciation_text() for syllable in self.m_syllable_list
|
||||
])
|
||||
|
||||
def tone_text(self):
|
||||
return ''.join(
|
||||
[syllable.tone_text() for syllable in self.m_syllable_list])
|
||||
|
||||
def save(self, parent_node):
|
||||
syllable_node = ET.SubElement(parent_node, 'syllable')
|
||||
syllable_node.set('syllcount', str(len(self.m_syllable_list)))
|
||||
|
||||
phone_node = ET.SubElement(syllable_node, 'phone')
|
||||
phone_node.text = self.pronunciation_text()
|
||||
|
||||
tone_node = ET.SubElement(syllable_node, 'tone')
|
||||
tone_node.text = self.tone_text()
|
||||
|
||||
return
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
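Syllable.save_metafile above emits one token per phone in the form {phone$tone<t>$<syllable position>$<word position>}. A small sketch reproducing that layout; the phone name 'ni_c' is a made-up placeholder:

# Sketch of the metafile token layout produced by Syllable.get_phone_meta.
def phone_meta(phone_name, tone_text, syll_pos, word_pos):
    return '{{{}$tone{}${}${}}}'.format(phone_name, tone_text, syll_pos, word_pos)

print(phone_meta('ni_c', '3', 's_begin', 'word_begin'))
# -> {ni_c$tone3$s_begin$word_begin}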
@@ -1,322 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import re
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .core_types import Language, PhoneCVType, Tone
|
||||
from .syllable import Syllable
|
||||
from .utils import NgBreakPattern
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
class DefaultSyllableFormatter:
|
||||
|
||||
def __init__(self):
|
||||
return
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
logging.warning('Using DefaultSyllableFormatter dry run: %s', pronText)
|
||||
return True
|
||||
|
||||
|
||||
RegexNg2en = re.compile(NgBreakPattern)
|
||||
RegexQingSheng = re.compile(r'([1-5]5)')
|
||||
RegexPron = re.compile(r'(?P<Pron>[a-z]+)(?P<Tone>[1-6])')
|
||||
|
||||
|
||||
class ZhCNSyllableFormatter:
|
||||
|
||||
def __init__(self, sy2ph_map):
|
||||
self.m_sy2ph_map = sy2ph_map
|
||||
|
||||
def normalize_pron(self, pronText):
|
||||
# Replace Qing Sheng
|
||||
newPron = pronText.replace('6', '2')
|
||||
newPron = re.sub(RegexQingSheng, '5', newPron)
|
||||
|
||||
# FIXME(Jin): ng case overrides newPron
|
||||
match = RegexNg2en.search(newPron)
|
||||
if match:
|
||||
newPron = 'en' + match.group('break')
|
||||
|
||||
return newPron
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
if phoneset is None or syllable_list is None or pronText is None:
|
||||
logging.error('ZhCNSyllableFormatter.Format: invalid input')
|
||||
return False
|
||||
pronText = self.normalize_pron(pronText)
|
||||
|
||||
if pronText in self.m_sy2ph_map:
|
||||
phone_list = self.m_sy2ph_map[pronText].split(' ')
|
||||
if len(phone_list) == 3:
|
||||
syll = Syllable()
|
||||
for phone in phone_list:
|
||||
syll.m_phone_list.append(phone)
|
||||
syll.m_tone = Tone.parse(
|
||||
pronText[-1]) # FIXME(Jin): assume tone is the last char
|
||||
syll.m_language = Language.ZhCN
|
||||
syllable_list.append(syll)
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
'ZhCNSyllableFormatter.Format: invalid pronText: %s',
|
||||
pronText)
|
||||
return False
|
||||
else:
|
||||
logging.error(
|
||||
'ZhCNSyllableFormatter.Format: syllable to phone map missing key: %s',
|
||||
pronText,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class PinYinSyllableFormatter:
|
||||
|
||||
def __init__(self, sy2ph_map):
|
||||
self.m_sy2ph_map = sy2ph_map
|
||||
|
||||
def normalize_pron(self, pronText):
|
||||
newPron = pronText.replace('6', '2')
|
||||
newPron = re.sub(RegexQingSheng, '5', newPron)
|
||||
|
||||
# FIXME(Jin): ng case overrides newPron
|
||||
match = RegexNg2en.search(newPron)
|
||||
if match:
|
||||
newPron = 'en' + match.group('break')
|
||||
|
||||
return newPron
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
if phoneset is None or syllable_list is None or pronText is None:
|
||||
logging.error('PinYinSyllableFormatter.Format: invalid input')
|
||||
return False
|
||||
pronText = self.normalize_pron(pronText)
|
||||
|
||||
match = RegexPron.search(pronText)
|
||||
|
||||
if match:
|
||||
pron = match.group('Pron')
|
||||
tone = match.group('Tone')
|
||||
else:
|
||||
logging.error(
|
||||
'PinYinSyllableFormatter.Format: pronunciation is not valid: %s',
|
||||
pronText,
|
||||
)
|
||||
return False
|
||||
|
||||
if pron in self.m_sy2ph_map:
|
||||
phone_list = self.m_sy2ph_map[pron].split(' ')
|
||||
if len(phone_list) in [1, 2]:
|
||||
syll = Syllable()
|
||||
for phone in phone_list:
|
||||
syll.m_phone_list.append(phone)
|
||||
syll.m_tone = Tone.parse(tone)
|
||||
syll.m_language = Language.PinYin
|
||||
syllable_list.append(syll)
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
'PinYinSyllableFormatter.Format: invalid phone: %s', pron)
|
||||
return False
|
||||
else:
|
||||
logging.error(
|
||||
'PinYinSyllableFormatter.Format: syllable to phone map missing key: %s',
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class ZhHKSyllableFormatter:
|
||||
|
||||
def __init__(self, sy2ph_map):
|
||||
self.m_sy2ph_map = sy2ph_map
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
if phoneset is None or syllable_list is None or pronText is None:
|
||||
logging.error('ZhHKSyllableFormatter.Format: invalid input')
|
||||
return False
|
||||
|
||||
match = RegexPron.search(pronText)
|
||||
if match:
|
||||
pron = match.group('Pron')
|
||||
tone = match.group('Tone')
|
||||
else:
|
||||
logging.error(
|
||||
'ZhHKSyllableFormatter.Format: pronunciation is not valid: %s',
|
||||
pronText)
|
||||
return False
|
||||
|
||||
if pron in self.m_sy2ph_map:
|
||||
phone_list = self.m_sy2ph_map[pron].split(' ')
|
||||
if len(phone_list) in [1, 2]:
|
||||
syll = Syllable()
|
||||
for phone in phone_list:
|
||||
syll.m_phone_list.append(phone)
|
||||
syll.m_tone = Tone.parse(tone)
|
||||
syll.m_language = Language.ZhHK
|
||||
syllable_list.append(syll)
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
'ZhHKSyllableFormatter.Format: invalid phone: %s', pron)
|
||||
return False
|
||||
else:
|
||||
logging.error(
|
||||
'ZhHKSyllableFormatter.Format: syllable to phone map missing key: %s',
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class WuuShanghaiSyllableFormatter:
|
||||
|
||||
def __init__(self, sy2ph_map):
|
||||
self.m_sy2ph_map = sy2ph_map
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
if phoneset is None or syllable_list is None or pronText is None:
|
||||
logging.error('WuuShanghaiSyllableFormatter.Format: invalid input')
|
||||
return False
|
||||
|
||||
match = RegexPron.search(pronText)
|
||||
if match:
|
||||
pron = match.group('Pron')
|
||||
tone = match.group('Tone')
|
||||
else:
|
||||
logging.error(
|
||||
'WuuShanghaiSyllableFormatter.Format: pronunciation is not valid: %s',
|
||||
pronText,
|
||||
)
|
||||
return False
|
||||
|
||||
if pron in self.m_sy2ph_map:
|
||||
phone_list = self.m_sy2ph_map[pron].split(' ')
|
||||
if len(phone_list) in [1, 2]:
|
||||
syll = Syllable()
|
||||
for phone in phone_list:
|
||||
syll.m_phone_list.append(phone)
|
||||
syll.m_tone = Tone.parse(tone)
|
||||
syll.m_language = Language.WuuShanghai
|
||||
syllable_list.append(syll)
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
'WuuShanghaiSyllableFormatter.Format: invalid phone: %s',
|
||||
pron)
|
||||
return False
|
||||
else:
|
||||
logging.error(
|
||||
'WuuShanghaiSyllableFormatter.Format: syllable to phone map missing key: %s',
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class SichuanSyllableFormatter:
|
||||
|
||||
def __init__(self, sy2ph_map):
|
||||
self.m_sy2ph_map = sy2ph_map
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
if phoneset is None or syllable_list is None or pronText is None:
|
||||
logging.error('SichuanSyllableFormatter.Format: invalid input')
|
||||
return False
|
||||
|
||||
match = RegexPron.search(pronText)
|
||||
if match:
|
||||
pron = match.group('Pron')
|
||||
tone = match.group('Tone')
|
||||
else:
|
||||
logging.error(
|
||||
'SichuanSyllableFormatter.Format: pronunciation is not valid: %s',
|
||||
pronText,
|
||||
)
|
||||
return False
|
||||
|
||||
if pron in self.m_sy2ph_map:
|
||||
phone_list = self.m_sy2ph_map[pron].split(' ')
|
||||
if len(phone_list) in [1, 2]:
|
||||
syll = Syllable()
|
||||
for phone in phone_list:
|
||||
syll.m_phone_list.append(phone)
|
||||
syll.m_tone = Tone.parse(tone)
|
||||
syll.m_language = Language.Sichuan
|
||||
syllable_list.append(syll)
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
'SichuanSyllableFormatter.Format: invalid phone: %s', pron)
|
||||
return False
|
||||
else:
|
||||
logging.error(
|
||||
'SichuanSyllableFormatter.Format: syllable to phone map missing key: %s',
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class EnXXSyllableFormatter:
|
||||
|
||||
def __init__(self, language):
|
||||
self.m_f2t_map = None
|
||||
self.m_language = language
|
||||
|
||||
def normalize_pron(self, pronText):
|
||||
newPron = pronText.replace('#', '.')
|
||||
newPron = (
|
||||
newPron.replace('03',
|
||||
'0').replace('13',
|
||||
'1').replace('23',
|
||||
'2').replace('3', ''))
|
||||
newPron = newPron.replace('2', '0')
|
||||
|
||||
return newPron
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
if phoneset is None or syllable_list is None or pronText is None:
|
||||
logging.error('EnXXSyllableFormatter.Format: invalid input')
|
||||
return False
|
||||
pronText = self.normalize_pron(pronText)
|
||||
|
||||
syllables = [ele.strip() for ele in pronText.split('.')]
|
||||
|
||||
for i in range(len(syllables)):
|
||||
syll = Syllable()
|
||||
syll.m_language = self.m_language
|
||||
syll.m_tone = Tone.parse('0')
|
||||
|
||||
phones = re.split(r'[\s]+', syllables[i])
|
||||
|
||||
for j in range(len(phones)):
|
||||
phoneName = phones[j].lower()
|
||||
toneName = '0'
|
||||
|
||||
if '0' in phoneName or '1' in phoneName or '2' in phoneName:
|
||||
toneName = phoneName[-1]
|
||||
phoneName = phoneName[:-1]
|
||||
|
||||
phoneName_lst = None
|
||||
if self.m_f2t_map is not None:
|
||||
phoneName_lst = self.m_f2t_map.get(phoneName, None)
|
||||
if phoneName_lst is None:
|
||||
phoneName_lst = [phoneName]
|
||||
|
||||
for new_phoneName in phoneName_lst:
|
||||
phone_obj = phoneset.m_name_map.get(new_phoneName, None)
|
||||
if phone_obj is None:
|
||||
logging.error(
|
||||
'EnXXSyllableFormatter.Format: phone %s not found',
|
||||
new_phoneName,
|
||||
)
|
||||
return False
|
||||
phone_obj.m_name = new_phoneName
|
||||
syll.m_phone_list.append(phone_obj)
|
||||
if phone_obj.m_cv_type == PhoneCVType.Vowel:
|
||||
syll.m_tone = Tone.parse(toneName)
|
||||
|
||||
if j == len(phones) - 1:
|
||||
phone_obj.m_bnd = True
|
||||
syllable_list.append(syll)
|
||||
return True
|
||||
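The PinYin/ZhHK/WuuShanghai/Sichuan formatters above all split a pronunciation such as 'ni3' into its toneless syllable and tone digit with RegexPron, then look the syllable up in a syllable-to-phone map. A hedged sketch of that lookup; the map entries and phone names are illustrative, the real table is loaded from py2phoneMap.txt:

import re

RegexPron = re.compile(r'(?P<Pron>[a-z]+)(?P<Tone>[1-6])')
sy2ph_map = {'ni': 'n_c i_c', 'hao': 'h_c ao_c'}  # hypothetical entries

def lookup(pron_text):
    match = RegexPron.search(pron_text)
    if match is None or match.group('Pron') not in sy2ph_map:
        return None
    return sy2ph_map[match.group('Pron')].split(' '), match.group('Tone')

print(lookup('ni3'))   # (['n_c', 'i_c'], '3')
print(lookup('xyz'))   # None: no tone digit, so the pronunciation is rejected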
@@ -1,116 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import codecs
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
WordPattern = r'((?P<Word>\w+)(\(\w+\))?)'
|
||||
BreakPattern = r'(?P<Break>(\*?#(?P<BreakLevel>[0-4])))'
|
||||
MarkPattern = r'(?P<Mark>[、,。!?:“”《》·])'
|
||||
POSPattern = r'(?P<POS>(\*?\|(?P<POSClass>[1-9])))'
|
||||
PhraseTonePattern = r'(?P<PhraseTone>(\*?%([L|H])))'
|
||||
|
||||
NgBreakPattern = r'^ng(?P<break>\d)'
|
||||
|
||||
RegexWord = re.compile(WordPattern + r'\s*')
|
||||
RegexBreak = re.compile(BreakPattern + r'\s*')
|
||||
RegexID = re.compile(r'^(?P<ID>[a-zA-Z\-_0-9\.]+)\s*')
|
||||
RegexSentence = re.compile(r'({}|{}|{}|{}|{})\s*'.format(
|
||||
WordPattern, BreakPattern, MarkPattern, POSPattern, PhraseTonePattern))
|
||||
RegexForeignLang = re.compile(r'[A-Z@]')
|
||||
RegexSpace = re.compile(r'^\s*')
|
||||
RegexNeutralTone = re.compile(r'[1-5]5')
|
||||
|
||||
|
||||
def do_character_normalization(line):
|
||||
return unicodedata.normalize('NFKC', line)
|
||||
|
||||
|
||||
def do_prosody_text_normalization(line):
|
||||
tokens = line.split('\t')
|
||||
text = tokens[1]
|
||||
# Remove punctuations
|
||||
text = text.replace(u'。', ' ')
|
||||
text = text.replace(u'、', ' ')
|
||||
text = text.replace(u'“', ' ')
|
||||
text = text.replace(u'”', ' ')
|
||||
text = text.replace(u'‘', ' ')
|
||||
text = text.replace(u'’', ' ')
|
||||
text = text.replace(u'|', ' ')
|
||||
text = text.replace(u'《', ' ')
|
||||
text = text.replace(u'》', ' ')
|
||||
text = text.replace(u'【', ' ')
|
||||
text = text.replace(u'】', ' ')
|
||||
text = text.replace(u'—', ' ')
|
||||
text = text.replace(u'―', ' ')
|
||||
text = text.replace('.', ' ')
|
||||
text = text.replace('!', ' ')
|
||||
text = text.replace('?', ' ')
|
||||
text = text.replace('(', ' ')
|
||||
text = text.replace(')', ' ')
|
||||
text = text.replace('[', ' ')
|
||||
text = text.replace(']', ' ')
|
||||
text = text.replace('{', ' ')
|
||||
text = text.replace('}', ' ')
|
||||
text = text.replace('~', ' ')
|
||||
text = text.replace(':', ' ')
|
||||
text = text.replace(';', ' ')
|
||||
text = text.replace('+', ' ')
|
||||
text = text.replace(',', ' ')
|
||||
# text = text.replace('·', ' ')
|
||||
text = text.replace('"', ' ')
|
||||
text = text.replace(
|
||||
'-',
|
||||
'') # don't replace with a space, to keep compound words like two-year-old
|
||||
text = text.replace(
|
||||
"'", '') # don't replace by space because English word like that's
|
||||
|
||||
# Replace break
|
||||
text = text.replace('/', '#2')
|
||||
text = text.replace('%', '#3')
|
||||
# Remove useless spaces surround #2 #3 #4
|
||||
text = re.sub(r'(#\d)[ ]+', r'\1', text)
|
||||
text = re.sub(r'[ ]+(#\d)', r'\1', text)
|
||||
# Replace space by #1
|
||||
text = re.sub('[ ]+', '#1', text)
|
||||
|
||||
# Remove break at the end of the text
|
||||
text = re.sub(r'#\d$', '', text)
|
||||
|
||||
# Add #1 between target language and foreign language
|
||||
text = re.sub(r"([a-zA-Z])([^a-zA-Z\d\#\s\'\%\/\-])", r'\1#1\2', text)
|
||||
text = re.sub(r"([^a-zA-Z\d\#\s\'\%\/\-])([a-zA-Z])", r'\1#1\2', text)
|
||||
|
||||
return tokens[0] + '\t' + text
|
||||
|
||||
|
||||
def is_fp_line(line):
|
||||
fp_category_list = ['FP', 'I', 'N', 'Q']
|
||||
elements = line.strip().split(' ')
|
||||
res = True
|
||||
for ele in elements:
|
||||
if ele not in fp_category_list:
|
||||
res = False
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
def format_prosody(src_prosody):
|
||||
formatted_lines = []
|
||||
with codecs.open(src_prosody, 'r', 'utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
idx = 0
|
||||
while idx < len(lines):
|
||||
line = do_character_normalization(lines[idx])
|
||||
|
||||
if len(line.strip().split('\t')) == 2:
|
||||
line = do_prosody_text_normalization(line)
|
||||
else:
|
||||
fp_enable = is_fp_line(line)
|
||||
if fp_enable:
|
||||
idx += 3
|
||||
continue
|
||||
formatted_lines.append(line)
|
||||
idx += 1
|
||||
return formatted_lines
|
||||
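format_prosody above normalizes the text half of each prosody line: punctuation becomes spaces, '/' and '%' become #2 and #3 breaks, remaining spaces become #1, and a trailing break is dropped. A condensed, runnable sketch of those substitutions on a made-up sentence:

import re

# Condensed version of the break-mark handling in do_prosody_text_normalization;
# the sample sentence is illustrative only.
text = '今天 天气 / 很 好 %'
text = text.replace('/', '#2').replace('%', '#3')
text = re.sub(r'(#\d)[ ]+', r'\1', text)   # drop spaces after a break mark
text = re.sub(r'[ ]+(#\d)', r'\1', text)   # drop spaces before a break mark
text = re.sub('[ ]+', '#1', text)          # remaining spaces become #1 breaks
text = re.sub(r'#\d$', '', text)           # strip a break at the end of the text
print(text)  # 今天#1天气#2很#1好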
@@ -1,19 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
|
||||
class XmlObj:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def load_data(self):
|
||||
pass
|
||||
|
||||
def save_data(self):
|
||||
pass
|
||||
@@ -1,500 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
from bitstring import BitArray
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .core.core_types import BreakLevel, Language
|
||||
from .core.phone_set import PhoneSet
|
||||
from .core.pos_set import PosSet
|
||||
from .core.script import Script
|
||||
from .core.script_item import ScriptItem
|
||||
from .core.script_sentence import ScriptSentence
|
||||
from .core.script_word import SpokenMark, SpokenWord, WrittenMark, WrittenWord
|
||||
from .core.utils import (RegexForeignLang, RegexID, RegexSentence,
|
||||
format_prosody)
|
||||
|
||||
from .core.utils import RegexNeutralTone # isort:skip
|
||||
|
||||
from .core.syllable_formatter import ( # isort:skip
|
||||
EnXXSyllableFormatter, PinYinSyllableFormatter, # isort:skip
|
||||
SichuanSyllableFormatter, # isort:skip
|
||||
WuuShanghaiSyllableFormatter, ZhCNSyllableFormatter, # isort:skip
|
||||
ZhHKSyllableFormatter) # isort:skip
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
class TextScriptConvertor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
phoneset_path,
|
||||
posset_path,
|
||||
target_lang,
|
||||
foreign_lang,
|
||||
f2t_map_path,
|
||||
s2p_map_path,
|
||||
m_emo_tag_path,
|
||||
m_speaker,
|
||||
):
|
||||
self.m_f2p_map = {}
|
||||
self.m_s2p_map = {}
|
||||
self.m_phoneset = PhoneSet(phoneset_path)
|
||||
self.m_posset = PosSet(posset_path)
|
||||
self.m_target_lang = Language.parse(target_lang)
|
||||
self.m_foreign_lang = Language.parse(foreign_lang)
|
||||
self.m_emo_tag_path = m_emo_tag_path
|
||||
self.m_speaker = m_speaker
|
||||
|
||||
self.load_f2tmap(f2t_map_path)
|
||||
self.load_s2pmap(s2p_map_path)
|
||||
|
||||
self.m_target_lang_syllable_formatter = self.init_syllable_formatter(
|
||||
self.m_target_lang)
|
||||
self.m_foreign_lang_syllable_formatter = self.init_syllable_formatter(
|
||||
self.m_foreign_lang)
|
||||
|
||||
def parse_sentence(self, sentence, line_num):
|
||||
script_item = ScriptItem(self.m_phoneset, self.m_posset)
|
||||
script_sentence = ScriptSentence(self.m_phoneset, self.m_posset)
|
||||
script_item.m_scriptSentence_list.append(script_sentence)
|
||||
|
||||
written_sentence = script_sentence.m_writtenSentence
|
||||
spoken_sentence = script_sentence.m_spokenSentence
|
||||
|
||||
position = 0
|
||||
|
||||
sentence = sentence.strip()
|
||||
|
||||
# Get ID
|
||||
match = re.search(RegexID, sentence)
|
||||
if match is None:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_sentence: invalid line: %s,\
|
||||
line ID is needed',
|
||||
line_num,
|
||||
)
|
||||
return None
|
||||
else:
|
||||
sentence_id = match.group('ID')
|
||||
script_item.m_id = sentence_id
|
||||
position += match.end()
|
||||
|
||||
prevSpokenWord = SpokenWord()
|
||||
|
||||
prevWord = False
|
||||
lastBreak = False
|
||||
|
||||
for m in re.finditer(RegexSentence, sentence[position:]):
|
||||
if m is None:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_sentence:\
|
||||
invalid line: %s, there is no matched pattern',
|
||||
line_num,
|
||||
)
|
||||
return None
|
||||
|
||||
if m.group('Word') is not None:
|
||||
wordName = m.group('Word')
|
||||
written_word = WrittenWord()
|
||||
written_word.m_name = wordName
|
||||
written_sentence.add_host(written_word)
|
||||
|
||||
spoken_word = SpokenWord()
|
||||
spoken_word.m_name = wordName
|
||||
prevSpokenWord = spoken_word
|
||||
prevWord = True
|
||||
lastBreak = False
|
||||
elif m.group('Break') is not None:
|
||||
breakText = m.group('BreakLevel')
|
||||
if len(breakText) == 0:
|
||||
breakLevel = BreakLevel.L1
|
||||
else:
|
||||
breakLevel = BreakLevel.parse(breakText)
|
||||
if prevWord:
|
||||
prevSpokenWord.m_breakText = breakText
|
||||
spoken_sentence.add_host(prevSpokenWord)
|
||||
|
||||
if breakLevel != BreakLevel.L1:
|
||||
spokenMark = SpokenMark()
|
||||
spokenMark.m_breakLevel = breakLevel
|
||||
spoken_sentence.add_accompany(spokenMark)
|
||||
|
||||
lastBreak = True
|
||||
|
||||
elif m.group('PhraseTone') is not None:
|
||||
pass
|
||||
elif m.group('POS') is not None:
|
||||
POSClass = m.group('POSClass')
|
||||
if prevWord:
|
||||
prevSpokenWord.m_POS = POSClass
|
||||
prevWord = False
|
||||
elif m.group('Mark') is not None:
|
||||
markText = m.group('Mark')
|
||||
|
||||
writtenMark = WrittenMark()
|
||||
writtenMark.m_punctuation = markText
|
||||
written_sentence.add_accompany(writtenMark)
|
||||
else:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_sentence:\
|
||||
invalid line: %s, matched pattern is unrecognized',
|
||||
line_num,
|
||||
)
|
||||
return None
|
||||
|
||||
if not lastBreak:
|
||||
prevSpokenWord.m_breakText = '4'
|
||||
spoken_sentence.add_host(prevSpokenWord)
|
||||
|
||||
spoken_word_cnt = len(spoken_sentence.m_spoken_word_list)
|
||||
spoken_mark_cnt = len(spoken_sentence.m_spoken_mark_list)
|
||||
if (spoken_word_cnt > 0
|
||||
and spoken_sentence.m_align_list[spoken_word_cnt - 1]
|
||||
== spoken_mark_cnt):
|
||||
spokenMark = SpokenMark()
|
||||
spokenMark.m_breakLevel = BreakLevel.L4
|
||||
spoken_sentence.add_accompany(spokenMark)
|
||||
|
||||
written_sentence.build_sequence()
|
||||
spoken_sentence.build_sequence()
|
||||
written_sentence.build_text()
|
||||
spoken_sentence.build_text()
|
||||
|
||||
script_sentence.m_text = written_sentence.m_text
|
||||
script_item.m_text = written_sentence.m_text
|
||||
|
||||
return script_item
|
||||
|
||||
def format_syllable(self, pron, syllable_list):
|
||||
isForeign = RegexForeignLang.search(pron) is not None
|
||||
if self.m_foreign_lang_syllable_formatter is not None and isForeign:
|
||||
return self.m_foreign_lang_syllable_formatter.format(
|
||||
self.m_phoneset, pron, syllable_list)
|
||||
else:
|
||||
return self.m_target_lang_syllable_formatter.format(
|
||||
self.m_phoneset, pron, syllable_list)
|
||||
|
||||
def get_word_prons(self, pronText):
|
||||
prons = pronText.split('/')
|
||||
res = []
|
||||
|
||||
for pron in prons:
|
||||
if re.search(RegexForeignLang, pron):
|
||||
res.append(pron.strip())
|
||||
else:
|
||||
res.extend(pron.strip().split(' '))
|
||||
return res
|
||||
|
||||
def is_erhuayin(self, pron):
|
||||
pron = RegexNeutralTone.sub('5', pron)
|
||||
pron = pron[:-1]
|
||||
|
||||
return pron[-1] == 'r' and pron != 'er'
|
||||
|
||||
def parse_pronunciation(self, script_item, pronunciation, line_num):
|
||||
spoken_sentence = script_item.m_scriptSentence_list[0].m_spokenSentence
|
||||
|
||||
wordProns = self.get_word_prons(pronunciation)
|
||||
|
||||
wordIndex = 0
|
||||
pronIndex = 0
|
||||
succeed = True
|
||||
|
||||
while pronIndex < len(wordProns):
|
||||
language = Language.Neutral
|
||||
syllable_list = []
|
||||
|
||||
pron = wordProns[pronIndex].strip()
|
||||
|
||||
succeed = self.format_syllable(pron, syllable_list)
|
||||
if not succeed:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation:\
|
||||
invalid line: %s, error pronunciation: %s,\
|
||||
syllable format error',
|
||||
line_num,
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
language = syllable_list[0].m_language
|
||||
|
||||
if wordIndex < len(spoken_sentence.m_spoken_word_list):
|
||||
if language in [Language.EnGB, Language.EnUS]:
|
||||
spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_syllable_list.extend(syllable_list)
|
||||
wordIndex += 1
|
||||
pronIndex += 1
|
||||
elif language in [
|
||||
Language.ZhCN,
|
||||
Language.PinYin,
|
||||
Language.ZhHK,
|
||||
Language.WuuShanghai,
|
||||
Language.Sichuan,
|
||||
]:
|
||||
charCount = len(
|
||||
spoken_sentence.m_spoken_word_list[wordIndex].m_name)
|
||||
if (language in [
|
||||
Language.ZhCN, Language.PinYin, Language.Sichuan
|
||||
] and self.is_erhuayin(pron) and '儿' in spoken_sentence.
|
||||
m_spoken_word_list[wordIndex].m_name):
|
||||
spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_name = spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_name.replace('儿', '')
|
||||
charCount -= 1
|
||||
if charCount == 1:
|
||||
spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_syllable_list.extend(syllable_list)
|
||||
wordIndex += 1
|
||||
pronIndex += 1
|
||||
else:
|
||||
# FIXME(Jin): Just skip the first char then match the rest char.
|
||||
i = 1
|
||||
while i >= 1 and i < charCount:
|
||||
pronIndex += 1
|
||||
if pronIndex < len(wordProns):
|
||||
pron = wordProns[pronIndex].strip()
|
||||
succeed = self.format_syllable(
|
||||
pron, syllable_list)
|
||||
if not succeed:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
|
||||
error pronunciation: %s, syllable format error',
|
||||
line_num,
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
if (language in [
|
||||
Language.ZhCN,
|
||||
Language.PinYin,
|
||||
Language.Sichuan,
|
||||
] and self.is_erhuayin(pron)
|
||||
and '儿' in spoken_sentence.
|
||||
m_spoken_word_list[wordIndex].m_name):
|
||||
spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_name = spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_name.replace('儿', '')
|
||||
charCount -= 1
|
||||
else:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
|
||||
error pronunciation: %s, Word count mismatch with Pron count',
|
||||
line_num,
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
i += 1
|
||||
spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_syllable_list.extend(syllable_list)
|
||||
wordIndex += 1
|
||||
pronIndex += 1
|
||||
else:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
|
||||
unsupported language: %s',
|
||||
line_num,
|
||||
language.name,
|
||||
)
|
||||
return False
|
||||
|
||||
else:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
|
||||
error pronunciation: %s, word index is out of range',
|
||||
line_num,
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
if pronIndex != len(wordProns):
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
|
||||
error pronunciation: %s, pron count mismatch with word count',
|
||||
line_num,
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
|
||||
if wordIndex != len(spoken_sentence.m_spoken_word_list):
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
|
||||
error pronunciation: %s, word count mismatch with word index',
|
||||
line_num,
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def load_f2tmap(self, file_path):
|
||||
with open(file_path, 'r') as f:
|
||||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
elements = line.split('\t')
|
||||
if len(elements) != 2:
|
||||
logging.error(
|
||||
'TextScriptConvertor.LoadF2TMap: invalid line: %s',
|
||||
line)
|
||||
continue
|
||||
key = elements[0]
|
||||
value = elements[1]
|
||||
value_list = value.split(' ')
|
||||
if key in self.m_f2p_map:
|
||||
logging.error(
|
||||
'TextScriptConvertor.LoadF2TMap: duplicate key: %s',
|
||||
key)
|
||||
self.m_f2p_map[key] = value_list
|
||||
|
||||
def load_s2pmap(self, file_path):
|
||||
with open(file_path, 'r') as f:
|
||||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
elements = line.split('\t')
|
||||
if len(elements) != 2:
|
||||
logging.error(
|
||||
'TextScriptConvertor.LoadS2PMap: invalid line: %s',
|
||||
line)
|
||||
continue
|
||||
key = elements[0]
|
||||
value = elements[1]
|
||||
if key in self.m_s2p_map:
|
||||
logging.error(
|
||||
'TextScriptConvertor.LoadS2PMap: duplicate key: %s',
|
||||
key)
|
||||
self.m_s2p_map[key] = value
|
||||
|
||||
def init_syllable_formatter(self, targetLang):
|
||||
if targetLang == Language.ZhCN:
|
||||
if len(self.m_s2p_map) == 0:
|
||||
logging.error(
|
||||
'TextScriptConvertor.InitSyllableFormatter: ZhCN syllable to phone map is empty'
|
||||
)
|
||||
return None
|
||||
return ZhCNSyllableFormatter(self.m_s2p_map)
|
||||
elif targetLang == Language.PinYin:
|
||||
if len(self.m_s2p_map) == 0:
|
||||
logging.error(
|
||||
'TextScriptConvertor.InitSyllableFormatter: PinYin syllable to phone map is empty'
|
||||
)
|
||||
return None
|
||||
return PinYinSyllableFormatter(self.m_s2p_map)
|
||||
elif targetLang == Language.ZhHK:
|
||||
if len(self.m_s2p_map) == 0:
|
||||
logging.error(
|
||||
'TextScriptConvertor.InitSyllableFormatter: ZhHK syllable to phone map is empty'
|
||||
)
|
||||
return None
|
||||
return ZhHKSyllableFormatter(self.m_s2p_map)
|
||||
elif targetLang == Language.WuuShanghai:
|
||||
if len(self.m_s2p_map) == 0:
|
||||
logging.error(
|
||||
'TextScriptConvertor.InitSyllableFormatter: WuuShanghai syllable to phone map is empty'
|
||||
)
|
||||
return None
|
||||
return WuuShanghaiSyllableFormatter(self.m_s2p_map)
|
||||
elif targetLang == Language.Sichuan:
|
||||
if len(self.m_s2p_map) == 0:
|
||||
logging.error(
|
||||
'TextScriptConvertor.InitSyllableFormatter: Sichuan syllable to phone map is empty'
|
||||
)
|
||||
return None
|
||||
return SichuanSyllableFormatter(self.m_s2p_map)
|
||||
elif targetLang == Language.EnGB:
|
||||
formatter = EnXXSyllableFormatter(Language.EnGB)
|
||||
if len(self.m_f2p_map) != 0:
|
||||
formatter.m_f2t_map = self.m_f2p_map
|
||||
return formatter
|
||||
elif targetLang == Language.EnUS:
|
||||
formatter = EnXXSyllableFormatter(Language.EnUS)
|
||||
if len(self.m_f2p_map) != 0:
|
||||
formatter.m_f2t_map = self.m_f2p_map
|
||||
return formatter
|
||||
else:
|
||||
logging.error(
|
||||
'TextScriptConvertor.InitSyllableFormatter: unsupported language: %s',
|
||||
targetLang,
|
||||
)
|
||||
return None
|
||||
|
||||
def process(self, textScriptPath, outputXMLPath, outputMetafile):
|
||||
script = Script(self.m_phoneset, self.m_posset)
|
||||
formatted_lines = format_prosody(textScriptPath)
|
||||
line_num = 0
|
||||
for line in tqdm(formatted_lines):
|
||||
if line_num % 2 == 0:
|
||||
sentence = line.strip()
|
||||
item = self.parse_sentence(sentence, line_num)
|
||||
else:
|
||||
if item is not None:
|
||||
pronunciation = line.strip()
|
||||
res = self.parse_pronunciation(item, pronunciation,
|
||||
line_num)
|
||||
if res:
|
||||
script.m_items.append(item)
|
||||
|
||||
line_num += 1
|
||||
|
||||
script.save(outputXMLPath)
|
||||
logging.info('TextScriptConvertor.process:\nSave script to: %s',
|
||||
outputXMLPath)
|
||||
|
||||
meta_lines = script.save_meta_file()
|
||||
emo = 'emotion_neutral'
|
||||
speaker = self.m_speaker
|
||||
|
||||
meta_lines_tagged = []
|
||||
for line in meta_lines:
|
||||
line_id, line_text = line.split('\t')
|
||||
syll_items = line_text.split(' ')
|
||||
syll_items_tagged = []
|
||||
for syll_item in syll_items:
|
||||
syll_item_tagged = syll_item[:-1] + '$' + emo + '$' + speaker + '}'
|
||||
syll_items_tagged.append(syll_item_tagged)
|
||||
meta_lines_tagged.append(line_id + '\t'
|
||||
+ ' '.join(syll_items_tagged))
|
||||
with open(outputMetafile, 'w') as f:
|
||||
for line in meta_lines_tagged:
|
||||
f.write(line + '\n')
|
||||
|
||||
logging.info('TextScriptConvertor.process:\nSave metafile to: %s',
|
||||
outputMetafile)
|
||||
|
||||
@staticmethod
|
||||
def turn_text_into_bytes(plain_text_path, output_meta_file_path, speaker):
|
||||
meta_lines = []
|
||||
with open(plain_text_path, 'r') as in_file:
|
||||
for text_line in in_file:
|
||||
[sentence_id, sentence] = text_line.strip().split('\t')
|
||||
sequence = []
|
||||
for character in sentence:
|
||||
hex_string = character.encode('utf-8').hex()
|
||||
i = 0
|
||||
while i < len(hex_string):
|
||||
byte_hex = hex_string[i:i + 2]
|
||||
bit_array = BitArray(hex=byte_hex)
|
||||
integer = bit_array.uint
|
||||
if integer > 255:
|
||||
logging.error(
|
||||
'TextScriptConverter.turn_text_into_bytes: invalid byte conversion in sentence {} \
|
||||
character {}: (uint) {} - (hex) {}'.
|
||||
format(
|
||||
sentence_id,
|
||||
character,
|
||||
integer,
|
||||
character.encode('utf-8').hex(),
|
||||
))
|
||||
continue
|
||||
sequence.append('{{{}$emotion_neutral${}}}'.format(
|
||||
integer, speaker))
|
||||
i += 2
|
||||
if sequence[-1][1:].split('$')[0] not in ['33', '46', '63']:
|
||||
sequence.append(
|
||||
'{{46$emotion_neutral${}}}'.format(speaker))
|
||||
meta_lines.append('{}\t{}\n'.format(sentence_id,
|
||||
' '.join(sequence)))
|
||||
with open(output_meta_file_path, 'w') as out_file:
|
||||
out_file.writelines(meta_lines)
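# A minimal, self-contained sketch of the byte-token scheme implemented by
# turn_text_into_bytes above; the helper name and the 'F7' speaker default are
# illustrative only, not part of the original interface.
def _char_to_byte_tokens(character, speaker='F7'):
    # Every UTF-8 byte of the character becomes one
    # '{<byte>$emotion_neutral$<speaker>}' token.
    return [
        '{{{}$emotion_neutral${}}}'.format(byte, speaker)
        for byte in character.encode('utf-8')
    ]

# Example: '你' encodes to the UTF-8 bytes 228, 189, 160, so it expands to
# three tokens; the caller appends a '{46$emotion_neutral$<speaker>}' token
# (the '.' byte) whenever the sentence does not already end in '!', '.' or '?'
# (bytes 33, 46 and 63).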
|
||||
@@ -1,562 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.audio.tts.kantts.models.utils import \
|
||||
get_mask_from_lengths
|
||||
from modelscope.models.audio.tts.kantts.utils.audio_torch import (
|
||||
MelSpectrogram, stft)
|
||||
|
||||
|
||||
class MelReconLoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, loss_type='mae'):
|
||||
super(MelReconLoss, self).__init__()
|
||||
self.loss_type = loss_type
|
||||
if loss_type == 'mae':
|
||||
self.criterion = torch.nn.L1Loss(reduction='none')
|
||||
elif loss_type == 'mse':
|
||||
self.criterion = torch.nn.MSELoss(reduction='none')
|
||||
else:
|
||||
raise ValueError('Unknown loss type: {}'.format(loss_type))
|
||||
|
||||
def forward(self,
|
||||
output_lengths,
|
||||
mel_targets,
|
||||
dec_outputs,
|
||||
postnet_outputs=None):
|
||||
output_masks = get_mask_from_lengths(
|
||||
output_lengths, max_len=mel_targets.size(1))
|
||||
output_masks = ~output_masks
|
||||
valid_outputs = output_masks.sum()
|
||||
|
||||
mel_loss_ = torch.sum(
|
||||
self.criterion(mel_targets, dec_outputs)
|
||||
* output_masks.unsqueeze(-1)) / (
|
||||
valid_outputs * mel_targets.size(-1))
|
||||
|
||||
if postnet_outputs is not None:
|
||||
mel_loss = torch.sum(
|
||||
self.criterion(mel_targets, postnet_outputs)
|
||||
* output_masks.unsqueeze(-1)) / (
|
||||
valid_outputs * mel_targets.size(-1))
|
||||
else:
|
||||
mel_loss = 0.0
|
||||
|
||||
return mel_loss_, mel_loss
|
||||
|
||||
|
||||
class ProsodyReconLoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, loss_type='mae'):
|
||||
super(ProsodyReconLoss, self).__init__()
|
||||
self.loss_type = loss_type
|
||||
if loss_type == 'mae':
|
||||
self.criterion = torch.nn.L1Loss(reduction='none')
|
||||
elif loss_type == 'mse':
|
||||
self.criterion = torch.nn.MSELoss(reduction='none')
|
||||
else:
|
||||
raise ValueError('Unknown loss type: {}'.format(loss_type))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_lengths,
|
||||
duration_targets,
|
||||
pitch_targets,
|
||||
energy_targets,
|
||||
log_duration_predictions,
|
||||
pitch_predictions,
|
||||
energy_predictions,
|
||||
):
|
||||
input_masks = get_mask_from_lengths(
|
||||
input_lengths, max_len=duration_targets.size(1))
|
||||
input_masks = ~input_masks
|
||||
valid_inputs = input_masks.sum()
|
||||
|
||||
dur_loss = (
|
||||
torch.sum(
|
||||
self.criterion(
|
||||
torch.log(duration_targets.float() + 1),
|
||||
log_duration_predictions) * input_masks) / valid_inputs)
|
||||
pitch_loss = (
|
||||
torch.sum(
|
||||
self.criterion(pitch_targets, pitch_predictions) * input_masks)
|
||||
/ valid_inputs)
|
||||
energy_loss = (
|
||||
torch.sum(
|
||||
self.criterion(energy_targets, energy_predictions)
|
||||
* input_masks) / valid_inputs)
|
||||
|
||||
return dur_loss, pitch_loss, energy_loss
|
||||
|
||||
|
||||
class FpCELoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, loss_type='ce', weight=[1, 4, 4, 8]):
|
||||
super(FpCELoss, self).__init__()
|
||||
self.loss_type = loss_type
|
||||
weight_ce = torch.FloatTensor(weight).cuda()
|
||||
self.criterion = torch.nn.CrossEntropyLoss(
|
||||
weight=weight_ce, reduction='none')
|
||||
|
||||
def forward(self, input_lengths, fp_pd, fp_label):
|
||||
input_masks = get_mask_from_lengths(
|
||||
input_lengths, max_len=fp_label.size(1))
|
||||
input_masks = ~input_masks
|
||||
valid_inputs = input_masks.sum()
|
||||
|
||||
fp_loss = (
|
||||
torch.sum(
|
||||
self.criterion(fp_pd.transpose(2, 1), fp_label) * input_masks)
|
||||
/ valid_inputs)
|
||||
|
||||
return fp_loss
|
||||
|
||||
|
||||
class GeneratorAdversarialLoss(torch.nn.Module):
|
||||
"""Generator adversarial loss module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
average_by_discriminators=True,
|
||||
loss_type='mse',
|
||||
):
|
||||
"""Initialize GeneratorAversarialLoss module."""
|
||||
super().__init__()
|
||||
self.average_by_discriminators = average_by_discriminators
|
||||
assert loss_type in ['mse', 'hinge'], f'{loss_type} is not supported.'
|
||||
if loss_type == 'mse':
|
||||
self.criterion = self._mse_loss
|
||||
else:
|
||||
self.criterion = self._hinge_loss
|
||||
|
||||
def forward(self, outputs):
|
||||
"""Calcualate generator adversarial loss.
|
||||
|
||||
Args:
|
||||
outputs (Tensor or list): Discriminator outputs or list of
|
||||
discriminator outputs.
|
||||
|
||||
Returns:
|
||||
Tensor: Generator adversarial loss value.
|
||||
|
||||
"""
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
adv_loss = 0.0
|
||||
for i, outputs_ in enumerate(outputs):
|
||||
adv_loss += self.criterion(outputs_)
|
||||
if self.average_by_discriminators:
|
||||
adv_loss /= i + 1
|
||||
else:
|
||||
adv_loss = self.criterion(outputs)
|
||||
|
||||
return adv_loss
|
||||
|
||||
def _mse_loss(self, x):
|
||||
return F.mse_loss(x, x.new_ones(x.size()))
|
||||
|
||||
def _hinge_loss(self, x):
|
||||
return -x.mean()
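# For reference, the two generator objectives selectable above are:
#   'mse'   : mean((D(G(z)) - 1)^2)   (least-squares GAN generator loss)
#   'hinge' : -mean(D(G(z)))
# where D(G(z)) is the discriminator output (or list of outputs) passed
# to forward().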
|
||||
|
||||
|
||||
class DiscriminatorAdversarialLoss(torch.nn.Module):
|
||||
"""Discriminator adversarial loss module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
average_by_discriminators=True,
|
||||
loss_type='mse',
|
||||
):
|
||||
"""Initialize DiscriminatorAversarialLoss module."""
|
||||
super().__init__()
|
||||
self.average_by_discriminators = average_by_discriminators
|
||||
assert loss_type in ['mse', 'hinge'], f'{loss_type} is not supported.'
|
||||
if loss_type == 'mse':
|
||||
self.fake_criterion = self._mse_fake_loss
|
||||
self.real_criterion = self._mse_real_loss
|
||||
else:
|
||||
self.fake_criterion = self._hinge_fake_loss
|
||||
self.real_criterion = self._hinge_real_loss
|
||||
|
||||
def forward(self, outputs_hat, outputs):
|
||||
"""Calcualate discriminator adversarial loss.
|
||||
|
||||
Args:
|
||||
outputs_hat (Tensor or list): Discriminator outputs or list of
|
||||
discriminator outputs calculated from generator outputs.
|
||||
outputs (Tensor or list): Discriminator outputs or list of
|
||||
discriminator outputs calculated from groundtruth.
|
||||
|
||||
Returns:
|
||||
Tensor: Discriminator real loss value.
|
||||
Tensor: Discriminator fake loss value.
|
||||
|
||||
"""
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
real_loss = 0.0
|
||||
fake_loss = 0.0
|
||||
for i, (outputs_hat_,
|
||||
outputs_) in enumerate(zip(outputs_hat, outputs)):
|
||||
if isinstance(outputs_hat_, (tuple, list)):
|
||||
# NOTE(kan-bayashi): case including feature maps
|
||||
outputs_hat_ = outputs_hat_[-1]
|
||||
outputs_ = outputs_[-1]
|
||||
real_loss += self.real_criterion(outputs_)
|
||||
fake_loss += self.fake_criterion(outputs_hat_)
|
||||
if self.average_by_discriminators:
|
||||
fake_loss /= i + 1
|
||||
real_loss /= i + 1
|
||||
else:
|
||||
real_loss = self.real_criterion(outputs)
|
||||
fake_loss = self.fake_criterion(outputs_hat)
|
||||
|
||||
return real_loss, fake_loss
|
||||
|
||||
def _mse_real_loss(self, x):
|
||||
return F.mse_loss(x, x.new_ones(x.size()))
|
||||
|
||||
def _mse_fake_loss(self, x):
|
||||
return F.mse_loss(x, x.new_zeros(x.size()))
|
||||
|
||||
def _hinge_real_loss(self, x):
|
||||
return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
|
||||
|
||||
def _hinge_fake_loss(self, x):
|
||||
return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
|
||||
|
||||
|
||||
class FeatureMatchLoss(torch.nn.Module):
|
||||
"""Feature matching loss module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
average_by_layers=True,
|
||||
average_by_discriminators=True,
|
||||
):
|
||||
"""Initialize FeatureMatchLoss module."""
|
||||
super().__init__()
|
||||
self.average_by_layers = average_by_layers
|
||||
self.average_by_discriminators = average_by_discriminators
|
||||
|
||||
def forward(self, feats_hat, feats):
|
||||
"""Calcualate feature matching loss.
|
||||
|
||||
Args:
|
||||
feats_hat (list): List of list of discriminator outputs
|
||||
calculated from generator outputs.
|
||||
feats (list): List of list of discriminator outputs
|
||||
calculated from groundtruth.
|
||||
|
||||
Returns:
|
||||
Tensor: Feature matching loss value.
|
||||
|
||||
"""
|
||||
feat_match_loss = 0.0
|
||||
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
|
||||
feat_match_loss_ = 0.0
|
||||
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
|
||||
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
|
||||
if self.average_by_layers:
|
||||
feat_match_loss_ /= j + 1
|
||||
feat_match_loss += feat_match_loss_
|
||||
if self.average_by_discriminators:
|
||||
feat_match_loss /= i + 1
|
||||
|
||||
return feat_match_loss
|
||||
|
||||
|
||||
class MelSpectrogramLoss(torch.nn.Module):
|
||||
"""Mel-spectrogram loss."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fs=22050,
|
||||
fft_size=1024,
|
||||
hop_size=256,
|
||||
win_length=None,
|
||||
window='hann',
|
||||
num_mels=80,
|
||||
fmin=80,
|
||||
fmax=7600,
|
||||
center=True,
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
eps=1e-10,
|
||||
log_base=10.0,
|
||||
):
|
||||
"""Initialize Mel-spectrogram loss."""
|
||||
super().__init__()
|
||||
self.mel_spectrogram = MelSpectrogram(
|
||||
fs=fs,
|
||||
fft_size=fft_size,
|
||||
hop_size=hop_size,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
num_mels=num_mels,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
center=center,
|
||||
normalized=normalized,
|
||||
onesided=onesided,
|
||||
eps=eps,
|
||||
log_base=log_base,
|
||||
)
|
||||
|
||||
def forward(self, y_hat, y):
|
||||
"""Calculate Mel-spectrogram loss.
|
||||
|
||||
Args:
|
||||
y_hat (Tensor): Generated single tensor (B, 1, T).
|
||||
y (Tensor): Groundtruth single tensor (B, 1, T).
|
||||
|
||||
Returns:
|
||||
Tensor: Mel-spectrogram loss value.
|
||||
|
||||
"""
|
||||
mel_hat = self.mel_spectrogram(y_hat)
|
||||
mel = self.mel_spectrogram(y)
|
||||
mel_loss = F.l1_loss(mel_hat, mel)
|
||||
|
||||
return mel_loss
|
||||
|
||||
|
||||
class SpectralConvergenceLoss(torch.nn.Module):
|
||||
"""Spectral convergence loss module."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initilize spectral convergence loss module."""
|
||||
super(SpectralConvergenceLoss, self).__init__()
|
||||
|
||||
def forward(self, x_mag, y_mag):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||
|
||||
Returns:
|
||||
Tensor: Spectral convergence loss value.
|
||||
|
||||
"""
|
||||
return torch.norm(y_mag - x_mag, p='fro') / torch.norm(y_mag, p='fro')
|
||||
|
||||
|
||||
class LogSTFTMagnitudeLoss(torch.nn.Module):
|
||||
"""Log STFT magnitude loss module."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initilize los STFT magnitude loss module."""
|
||||
super(LogSTFTMagnitudeLoss, self).__init__()
|
||||
|
||||
def forward(self, x_mag, y_mag):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||
|
||||
Returns:
|
||||
Tensor: Log STFT magnitude loss value.
|
||||
|
||||
"""
|
||||
return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
|
||||
|
||||
|
||||
class STFTLoss(torch.nn.Module):
|
||||
"""STFT loss module."""
|
||||
|
||||
def __init__(self,
|
||||
fft_size=1024,
|
||||
shift_size=120,
|
||||
win_length=600,
|
||||
window='hann_window'):
|
||||
"""Initialize STFT loss module."""
|
||||
super(STFTLoss, self).__init__()
|
||||
self.fft_size = fft_size
|
||||
self.shift_size = shift_size
|
||||
self.win_length = win_length
|
||||
self.spectral_convergence_loss = SpectralConvergenceLoss()
|
||||
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
|
||||
# NOTE(kan-bayashi): Use register_buffer to fix #223
|
||||
self.register_buffer('window', getattr(torch, window)(win_length))
|
||||
|
||||
def forward(self, x, y):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Predicted signal (B, T).
|
||||
y (Tensor): Groundtruth signal (B, T).
|
||||
|
||||
Returns:
|
||||
Tensor: Spectral convergence loss value.
|
||||
Tensor: Log STFT magnitude loss value.
|
||||
|
||||
"""
|
||||
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length,
|
||||
self.window)
|
||||
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length,
|
||||
self.window)
|
||||
sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
|
||||
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
||||
|
||||
return sc_loss, mag_loss
|
||||
|
||||
|
||||
class MultiResolutionSTFTLoss(torch.nn.Module):
|
||||
"""Multi resolution STFT loss module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fft_sizes=[1024, 2048, 512],
|
||||
hop_sizes=[120, 240, 50],
|
||||
win_lengths=[600, 1200, 240],
|
||||
window='hann_window',
|
||||
):
|
||||
"""Initialize Multi resolution STFT loss module.
|
||||
|
||||
Args:
|
||||
fft_sizes (list): List of FFT sizes.
|
||||
hop_sizes (list): List of hop sizes.
|
||||
win_lengths (list): List of window lengths.
|
||||
window (str): Window function type.
|
||||
|
||||
"""
|
||||
super(MultiResolutionSTFTLoss, self).__init__()
|
||||
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
|
||||
self.stft_losses = torch.nn.ModuleList()
|
||||
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
|
||||
self.stft_losses += [STFTLoss(fs, ss, wl, window)]
|
||||
|
||||
def forward(self, x, y):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Predicted signal (B, T) or (B, #subband, T).
|
||||
y (Tensor): Groundtruth signal (B, T) or (B, #subband, T).
|
||||
|
||||
Returns:
|
||||
Tensor: Multi resolution spectral convergence loss value.
|
||||
Tensor: Multi resolution log STFT magnitude loss value.
|
||||
|
||||
"""
|
||||
if len(x.shape) == 3:
|
||||
x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T)
|
||||
y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T)
|
||||
sc_loss = 0.0
|
||||
mag_loss = 0.0
|
||||
for f in self.stft_losses:
|
||||
sc_l, mag_l = f(x, y)
|
||||
sc_loss += sc_l
|
||||
mag_loss += mag_l
|
||||
sc_loss /= len(self.stft_losses)
|
||||
mag_loss /= len(self.stft_losses)
|
||||
|
||||
return sc_loss, mag_loss
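# A minimal usage sketch (the shapes and the 0.5 weight are illustrative):
#   mr_stft_loss = MultiResolutionSTFTLoss()
#   y_hat, y = torch.randn(4, 22050), torch.randn(4, 22050)   # (B, T)
#   sc_loss, mag_loss = mr_stft_loss(y_hat, y)
#   stft_loss = 0.5 * (sc_loss + mag_loss)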
|
||||
|
||||
|
||||
class SeqCELoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, loss_type='ce'):
|
||||
super(SeqCELoss, self).__init__()
|
||||
self.loss_type = loss_type
|
||||
self.criterion = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
|
||||
def forward(self, logits, targets, masks):
|
||||
loss = self.criterion(logits.contiguous().view(-1, logits.size(-1)),
|
||||
targets.contiguous().view(-1))
|
||||
preds = torch.argmax(logits, dim=-1).contiguous().view(-1)
|
||||
masks = masks.contiguous().view(-1)
|
||||
|
||||
loss = (loss * masks).sum() / masks.sum()
|
||||
err = torch.sum((preds != targets.view(-1)) * masks) / masks.sum()
|
||||
|
||||
return loss, err
|
||||
|
||||
|
||||
class AttentionBinarizationLoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, start_epoch=0, warmup_epoch=100):
|
||||
super(AttentionBinarizationLoss, self).__init__()
|
||||
self.start_epoch = start_epoch
|
||||
self.warmup_epoch = warmup_epoch
|
||||
|
||||
def forward(self, epoch, hard_attention, soft_attention, eps=1e-12):
|
||||
log_sum = torch.log(
|
||||
torch.clamp(soft_attention[hard_attention == 1], min=eps)).sum()
|
||||
kl_loss = -log_sum / hard_attention.sum()
|
||||
if epoch < self.start_epoch:
|
||||
warmup_ratio = 0
|
||||
else:
|
||||
warmup_ratio = min(1.0,
|
||||
(epoch - self.start_epoch) / self.warmup_epoch)
|
||||
return kl_loss * warmup_ratio
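# Worked example of the warm-up factor with the defaults above
# (start_epoch=0, warmup_epoch=100): at epoch 50 the ratio is
# min(1.0, (50 - 0) / 100) = 0.5, so only half of the KL term is applied;
# from epoch 100 onward the full binarization loss is used.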
|
||||
|
||||
|
||||
class AttentionCTCLoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, blank_logprob=-1):
|
||||
super(AttentionCTCLoss, self).__init__()
|
||||
self.log_softmax = torch.nn.LogSoftmax(dim=3)
|
||||
self.blank_logprob = blank_logprob
|
||||
self.CTCLoss = torch.nn.CTCLoss(zero_infinity=True)
|
||||
|
||||
def forward(self, attn_logprob, in_lens, out_lens):
|
||||
key_lens = in_lens
|
||||
query_lens = out_lens
|
||||
attn_logprob_padded = F.pad(
|
||||
input=attn_logprob,
|
||||
pad=(1, 0, 0, 0, 0, 0, 0, 0),
|
||||
value=self.blank_logprob)
|
||||
cost_total = 0.0
|
||||
for bid in range(attn_logprob.shape[0]):
|
||||
target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
|
||||
curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)
|
||||
curr_logprob = curr_logprob[:query_lens[bid], :, :key_lens[bid]
|
||||
+ 1]
|
||||
curr_logprob = self.log_softmax(curr_logprob[None])[0]
|
||||
ctc_cost = self.CTCLoss(
|
||||
curr_logprob,
|
||||
target_seq,
|
||||
input_lengths=query_lens[bid:bid + 1],
|
||||
target_lengths=key_lens[bid:bid + 1],
|
||||
)
|
||||
cost_total += ctc_cost
|
||||
cost = cost_total / attn_logprob.shape[0]
|
||||
return cost
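# The loop above scores each batch element separately: the alignment target is
# simply the monotonically increasing key-index sequence 1..key_len, a blank
# column with log-probability blank_logprob is padded in front, and
# torch.nn.CTCLoss measures how well the soft attention matrix follows that
# monotonic path.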
|
||||
|
||||
|
||||
loss_dict = {
|
||||
'generator_adv_loss': GeneratorAdversarialLoss,
|
||||
'discriminator_adv_loss': DiscriminatorAdversarialLoss,
|
||||
'stft_loss': MultiResolutionSTFTLoss,
|
||||
'mel_loss': MelSpectrogramLoss,
|
||||
'subband_stft_loss': MultiResolutionSTFTLoss,
|
||||
'feat_match_loss': FeatureMatchLoss,
|
||||
'MelReconLoss': MelReconLoss,
|
||||
'ProsodyReconLoss': ProsodyReconLoss,
|
||||
'SeqCELoss': SeqCELoss,
|
||||
'AttentionBinarizationLoss': AttentionBinarizationLoss,
|
||||
'AttentionCTCLoss': AttentionCTCLoss,
|
||||
'FpCELoss': FpCELoss,
|
||||
}
|
||||
|
||||
|
||||
def criterion_builder(config, device='cpu'):
|
||||
"""Criterion builder.
|
||||
Args:
|
||||
config (dict): Config dictionary.
|
||||
Returns:
|
||||
criterion (dict): Loss dictionary
|
||||
"""
|
||||
criterion = {}
|
||||
for key, value in config['Loss'].items():
|
||||
if key in loss_dict:
|
||||
if value['enable']:
|
||||
criterion[key] = loss_dict[key](
|
||||
**value.get('params', {})).to(device)
|
||||
setattr(criterion[key], 'weights', value.get('weights', 1.0))
|
||||
else:
|
||||
raise NotImplementedError('{} is not implemented'.format(key))
|
||||
|
||||
return criterion
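# A minimal sketch of the config structure criterion_builder expects
# (loss names must be keys of loss_dict; the values shown are illustrative):
#   config = {
#       'Loss': {
#           'mel_loss': {'enable': True, 'params': {'fs': 22050}, 'weights': 1.0},
#           'stft_loss': {'enable': False},
#       }
#   }
#   criterion = criterion_builder(config, device='cpu')
#   # -> {'mel_loss': MelSpectrogramLoss(...)} with criterion['mel_loss'].weights == 1.0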
|
||||
@@ -1,44 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler
|
||||
|
||||
|
||||
class FindLR(_LRScheduler):
|
||||
"""
|
||||
inspired by fast.ai @https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, max_steps, max_lr=10):
|
||||
self.max_steps = max_steps
|
||||
self.max_lr = max_lr
|
||||
super().__init__(optimizer)
|
||||
|
||||
def get_lr(self):
|
||||
return [
|
||||
base_lr * ((self.max_lr / base_lr)**(
|
||||
self.last_epoch / # noqa W504
|
||||
(self.max_steps - 1))) for base_lr in self.base_lrs
|
||||
]
|
||||
|
||||
|
||||
class NoamLR(_LRScheduler):
|
||||
"""
|
||||
Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate
|
||||
linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally
|
||||
to the inverse square root of the step number, scaled by the inverse square root of the
|
||||
dimensionality of the model. Time will tell if this is just madness or it's actually important.
|
||||
Parameters
|
||||
----------
|
||||
warmup_steps: ``int``, required.
|
||||
The number of steps to linearly increase the learning rate.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, warmup_steps):
|
||||
self.warmup_steps = warmup_steps
|
||||
super().__init__(optimizer)
|
||||
|
||||
def get_lr(self):
|
||||
last_epoch = max(1, self.last_epoch)
|
||||
scale = self.warmup_steps**0.5 * min(
|
||||
last_epoch**(-0.5), last_epoch * self.warmup_steps**(-1.5))
|
||||
return [base_lr * scale for base_lr in self.base_lrs]
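# A minimal usage sketch (the optimizer and warmup value are illustrative):
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   scheduler = NoamLR(optimizer, warmup_steps=4000)
#   ...
#   scheduler.step()
# At step s the schedule gives
#   lr = base_lr * warmup_steps**0.5 * min(s**-0.5, s * warmup_steps**-1.5),
# i.e. a linear warm-up to base_lr over the first warmup_steps steps followed
# by an s**-0.5 decay.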
|
||||
File diff suppressed because it is too large
@@ -1,188 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
|
||||
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
|
||||
|
||||
|
||||
def stft(x, fft_size, hop_size, win_length, window):
|
||||
"""Perform STFT and convert to magnitude spectrogram.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input signal tensor (B, T).
|
||||
fft_size (int): FFT size.
|
||||
hop_size (int): Hop size.
|
||||
win_length (int): Window length.
|
||||
window (str): Window function type.
|
||||
|
||||
Returns:
|
||||
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
||||
|
||||
"""
|
||||
if is_pytorch_17plus:
|
||||
x_stft = torch.stft(
|
||||
x, fft_size, hop_size, win_length, window, return_complex=False)
|
||||
else:
|
||||
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
|
||||
real = x_stft[..., 0]
|
||||
imag = x_stft[..., 1]
|
||||
|
||||
return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return 20 * torch.log10(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
return torch.pow(10.0, x * 0.05) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(
|
||||
magnitudes,
|
||||
min_level_db=-100.0,
|
||||
ref_level_db=20.0,
|
||||
norm_abs_value=4.0,
|
||||
symmetric=True,
|
||||
):
|
||||
output = dynamic_range_compression_torch(magnitudes) - ref_level_db
|
||||
|
||||
if symmetric:
|
||||
return torch.clamp(
|
||||
2 * norm_abs_value * ((output - min_level_db) / # noqa W504
|
||||
(-min_level_db)) - norm_abs_value,
|
||||
min=-norm_abs_value,
|
||||
max=norm_abs_value)
|
||||
else:
|
||||
return torch.clamp(
|
||||
norm_abs_value * ((output - min_level_db) / (-min_level_db)),
|
||||
min=0.0,
|
||||
max=norm_abs_value)
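# With the defaults above (min_level_db=-100, ref_level_db=20,
# norm_abs_value=4, symmetric=True) a magnitude is first mapped to dB via
# 20 * log10(clamp(x, 1e-5)) - 20 and then rescaled so the result lies
# in [-4, 4].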
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(
|
||||
magnitudes,
|
||||
min_level_db=-100.0,
|
||||
ref_level_db=20.0,
|
||||
norm_abs_value=4.0,
|
||||
symmetric=True,
|
||||
):
|
||||
if symmetric:
|
||||
magnitudes = torch.clamp(
|
||||
magnitudes, min=-norm_abs_value, max=norm_abs_value)
|
||||
magnitudes = (magnitudes + norm_abs_value) * (-min_level_db) / (
|
||||
2 * norm_abs_value) + min_level_db
|
||||
else:
|
||||
magnitudes = torch.clamp(magnitudes, min=0.0, max=norm_abs_value)
|
||||
magnitudes = (magnitudes) * (-min_level_db) / (
|
||||
norm_abs_value) + min_level_db
|
||||
|
||||
output = dynamic_range_decompression_torch(magnitudes + ref_level_db)
|
||||
return output
|
||||
|
||||
|
||||
class MelSpectrogram(torch.nn.Module):
|
||||
"""Calculate Mel-spectrogram."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fs=22050,
|
||||
fft_size=1024,
|
||||
hop_size=256,
|
||||
win_length=None,
|
||||
window='hann',
|
||||
num_mels=80,
|
||||
fmin=80,
|
||||
fmax=7600,
|
||||
center=True,
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
eps=1e-10,
|
||||
log_base=10.0,
|
||||
pad_mode='constant',
|
||||
):
|
||||
"""Initialize MelSpectrogram module."""
|
||||
super().__init__()
|
||||
self.fft_size = fft_size
|
||||
if win_length is None:
|
||||
self.win_length = fft_size
|
||||
else:
|
||||
self.win_length = win_length
|
||||
self.hop_size = hop_size
|
||||
self.center = center
|
||||
self.normalized = normalized
|
||||
self.onesided = onesided
|
||||
if window is not None and not hasattr(torch, f'{window}_window'):
|
||||
raise ValueError(f'{window} window is not implemented')
|
||||
self.window = window
|
||||
self.eps = eps
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
fmin = 0 if fmin is None else fmin
|
||||
fmax = fs / 2 if fmax is None else fmax
|
||||
melmat = librosa.filters.mel(
|
||||
sr=fs,
|
||||
n_fft=fft_size,
|
||||
n_mels=num_mels,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
)
|
||||
self.register_buffer('melmat', torch.from_numpy(melmat.T).float())
|
||||
self.stft_params = {
|
||||
'n_fft': self.fft_size,
|
||||
'win_length': self.win_length,
|
||||
'hop_length': self.hop_size,
|
||||
'center': self.center,
|
||||
'normalized': self.normalized,
|
||||
'onesided': self.onesided,
|
||||
'pad_mode': self.pad_mode,
|
||||
}
|
||||
if is_pytorch_17plus:
|
||||
self.stft_params['return_complex'] = False
|
||||
|
||||
self.log_base = log_base
|
||||
if self.log_base is None:
|
||||
self.log = torch.log
|
||||
elif self.log_base == 2.0:
|
||||
self.log = torch.log2
|
||||
elif self.log_base == 10.0:
|
||||
self.log = torch.log10
|
||||
else:
|
||||
raise ValueError(f'log_base: {log_base} is not supported.')
|
||||
|
||||
def forward(self, x):
|
||||
"""Calculate Mel-spectrogram.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input waveform tensor (B, T) or (B, 1, T).
|
||||
|
||||
Returns:
|
||||
Tensor: Mel-spectrogram (B, #mels, #frames).
|
||||
|
||||
"""
|
||||
if x.dim() == 3:
|
||||
# (B, C, T) -> (B*C, T)
|
||||
x = x.reshape(-1, x.size(2))
|
||||
|
||||
if self.window is not None:
|
||||
window_func = getattr(torch, f'{self.window}_window')
|
||||
window = window_func(
|
||||
self.win_length, dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
window = None
|
||||
|
||||
x_stft = torch.stft(x, window=window, **self.stft_params)
|
||||
# (B, #freqs, #frames, 2) -> (B, #frames, #freqs, 2)
|
||||
x_stft = x_stft.transpose(1, 2)
|
||||
x_power = x_stft[..., 0]**2 + x_stft[..., 1]**2
|
||||
x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps))
|
||||
|
||||
x_mel = torch.matmul(x_amp, self.melmat)
|
||||
x_mel = torch.clamp(x_mel, min=self.eps)
|
||||
x_mel = spectral_normalize_torch(x_mel)
|
||||
|
||||
# return self.log(x_mel).transpose(1, 2)
|
||||
return x_mel.transpose(1, 2)
|
||||
@@ -1,26 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import ttsfrd
|
||||
|
||||
|
||||
def text_to_mit_symbols(texts, resources_dir, speaker):
|
||||
fe = ttsfrd.TtsFrontendEngine()
|
||||
fe.initialize(resources_dir)
|
||||
fe.set_lang_type('Zh-CN')
|
||||
|
||||
symbols_lst = []
|
||||
for idx, text in enumerate(texts):
|
||||
text = text.strip()
|
||||
res = fe.gen_tacotron_symbols(text)
|
||||
res = res.replace('F7', speaker)
|
||||
sentences = res.split('\n')
|
||||
for sentence in sentences:
|
||||
arr = sentence.split('\t')
|
||||
# skip the empty line
|
||||
if len(arr) != 2:
|
||||
continue
|
||||
sub_index, symbols = sentence.split('\t')
|
||||
symbol_str = '{}_{}\t{}\n'.format(idx, sub_index, symbols)
|
||||
symbols_lst.append(symbol_str)
|
||||
|
||||
return symbols_lst
|
||||
@@ -1,85 +0,0 @@
|
||||
# from https://github.com/keithito/tacotron
|
||||
# Cleaners are transformations that run over the input text at both training and eval time.
|
||||
#
|
||||
# Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
||||
# hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
||||
# 1. "english_cleaners" for English text
|
||||
# 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
||||
# the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
||||
# 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
||||
# the symbols in symbols.py to match your data).
|
||||
|
||||
import re
|
||||
|
||||
from unidecode import unidecode
|
||||
|
||||
from .numbers import normalize_numbers
|
||||
|
||||
# Regular expression matching whitespace:
|
||||
_whitespace_re = re.compile(r'\s+')
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = [
|
||||
(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
|
||||
for x in [('mrs', 'misess'), ('mr', 'mister'), (
|
||||
'dr', 'doctor'), ('st', 'saint'), ('co', 'company'), (
|
||||
'jr',
|
||||
'junior'), ('maj', 'major'), ('gen', 'general'), (
|
||||
'drs', 'doctors'), ('rev', 'reverend'), (
|
||||
'lt',
|
||||
'lieutenant'), ('hon', 'honorable'), (
|
||||
'sgt',
|
||||
'sergeant'), ('capt', 'captain'), (
|
||||
'esq',
|
||||
'esquire'), ('ltd',
|
||||
'limited'), ('col',
|
||||
'colonel'), ('ft',
|
||||
'fort')]
|
||||
]
|
||||
|
||||
|
||||
def expand_abbreviations(text):
|
||||
for regex, replacement in _abbreviations:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def expand_numbers(text):
|
||||
return normalize_numbers(text)
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, ' ', text)
|
||||
|
||||
|
||||
def convert_to_ascii(text):
|
||||
return unidecode(text)
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def transliteration_cleaners(text):
|
||||
"""Pipeline for non-English text that transliterates to ASCII."""
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def english_cleaners(text):
|
||||
"""Pipeline for English text, including number and abbreviation expansion."""
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = expand_numbers(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
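# A rough usage sketch (the exact number expansion comes from .numbers and the
# output shown is only the expected result):
#   english_cleaners('Dr. Smith bought 2 apples.')
#   # convert_to_ascii -> lowercase -> expand_numbers -> expand_abbreviations
#   # -> 'doctor smith bought two apples.'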
|
||||
@@ -1,37 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.

emotion_types = [
    'emotion_none',
    'emotion_neutral',
    'emotion_angry',
    'emotion_disgust',
    'emotion_fear',
    'emotion_happy',
    'emotion_sad',
    'emotion_surprise',
    'emotion_calm',
    'emotion_gentle',
    'emotion_relax',
    'emotion_lyrical',
    'emotion_serious',
    'emotion_disgruntled',
    'emotion_satisfied',
    'emotion_disappointed',
    'emotion_excited',
    'emotion_anxiety',
    'emotion_jealousy',
    'emotion_hate',
    'emotion_pity',
    'emotion_pleasure',
    'emotion_arousal',
    'emotion_dominance',
    'emotion_placeholder1',
    'emotion_placeholder2',
    'emotion_placeholder3',
    'emotion_placeholder4',
    'emotion_placeholder5',
    'emotion_placeholder6',
    'emotion_placeholder7',
    'emotion_placeholder8',
    'emotion_placeholder9',
]
|
||||
Some files were not shown because too many files have changed in this diff