diff --git a/.gitignore b/.gitignore
index cf36a205..790daab3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
-
+test.py
# C extensions
*.so
@@ -123,6 +123,7 @@ tensorboard.sh
replace.sh
result.png
result.jpg
+result.mp4
# Pytorch
*.pth
diff --git a/data/test/audios/speech_with_noise_48k.pcm b/data/test/audios/speech_with_noise_48k.pcm
new file mode 100644
index 00000000..3e4af18f
Binary files /dev/null and b/data/test/audios/speech_with_noise_48k.pcm differ
diff --git a/data/test/audios/speech_with_noise_48k.wav b/data/test/audios/speech_with_noise_48k.wav
new file mode 100644
index 00000000..ccee3da3
--- /dev/null
+++ b/data/test/audios/speech_with_noise_48k.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9e76c8448e93934ed9c8827b76f702d07fccc3e586900903617971471235800
+size 475278
diff --git a/data/test/images/dogs.jpg b/data/test/images/dogs.jpg
index 450a969d..1003c3fd 100644
--- a/data/test/images/dogs.jpg
+++ b/data/test/images/dogs.jpg
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:78094cc48fbcfd9b6d321fe13619ecc72b65e006fc1b4c4458409ade9979486d
-size 129862
+oid sha256:d53a77b0be82993ed44bbb9244cda42bf460f8dcdf87ff3cfdbfdc7191ff418d
+size 121984
diff --git a/data/test/images/human_reconstruction.jpg b/data/test/images/human_reconstruction.jpg
new file mode 100644
index 00000000..4fe2753a
--- /dev/null
+++ b/data/test/images/human_reconstruction.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06ec486657dffbf244563a844c98c19d49b7a45b99da702403b52bb9e6bf3c0a
+size 226072
diff --git a/data/test/images/image_camouflag_detection.jpg b/data/test/images/image_camouflag_detection.jpg
new file mode 100644
index 00000000..5029067d
--- /dev/null
+++ b/data/test/images/image_camouflag_detection.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c713215f7fb4da5382c9137347ee52956a7a44d5979c4cffd3c9b6d1d7e878f
+size 19445
diff --git a/data/test/images/image_depth_estimation_kitti_007517.png b/data/test/images/image_depth_estimation_kitti_007517.png
new file mode 100644
index 00000000..785bd5db
--- /dev/null
+++ b/data/test/images/image_depth_estimation_kitti_007517.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f2a83dab7fd7fedff65979fd2496fd86f0a36f222a5a0e6c81fbb161043b9a45
+size 786657
diff --git a/data/test/images/image_smokefire_detection.jpg b/data/test/images/image_smokefire_detection.jpg
new file mode 100644
index 00000000..733e1429
--- /dev/null
+++ b/data/test/images/image_smokefire_detection.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:713082e6967760d5a0d1ae07af62ecc58f9b8b0ab418394556dc5c6c31c27056
+size 63761
diff --git a/data/test/images/lineless_table_recognition.jpg b/data/test/images/lineless_table_recognition.jpg
new file mode 100644
index 00000000..8db3a657
--- /dev/null
+++ b/data/test/images/lineless_table_recognition.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2053b3bcb7abfe5c22b4954b81899dcffbb99302af6f179c43d45265c732d804
+size 26493
diff --git a/data/test/images/ocr_detection/test_gts/X51007339105.txt b/data/test/images/ocr_detection/test_gts/X51007339105.txt
new file mode 100644
index 00000000..45d10a96
--- /dev/null
+++ b/data/test/images/ocr_detection/test_gts/X51007339105.txt
@@ -0,0 +1,46 @@
+44.0,131.0,513.0,131.0,513.0,166.0,44.0,166.0,SANYU STATIONERY SHOP
+45.0,179.0,502.0,179.0,502.0,204.0,45.0,204.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X
+43.0,205.0,242.0,205.0,242.0,226.0,43.0,226.0,40170 SETIA ALAM
+44.0,231.0,431.0,231.0,431.0,255.0,44.0,255.0,MOBILE /WHATSAPPS : +6012-918 7937
+41.0,264.0,263.0,264.0,263.0,286.0,41.0,286.0,TEL: +603-3362 4137
+42.0,291.0,321.0,291.0,321.0,312.0,42.0,312.0,GST ID NO: 001531760640
+409.0,303.0,591.0,303.0,591.0,330.0,409.0,330.0,TAX INVOICE
+33.0,321.0,139.0,321.0,139.0,343.0,33.0,343.0,OWNED BY :
+34.0,343.0,376.0,343.0,376.0,369.0,34.0,369.0,SANYU SUPPLY SDN BHD (1135772-K)
+37.0,397.0,303.0,397.0,303.0,420.0,37.0,420.0,CASH SALES COUNTER
+53.0,459.0,188.0,459.0,188.0,483.0,53.0,483.0,1. 2012-0043
+79.0,518.0,193.0,518.0,193.0,545.0,79.0,545.0,1 X 3.3000
+270.0,460.0,585.0,460.0,585.0,484.0,270.0,484.0,JOURNAL BOOK 80PGS A4 70G
+271.0,487.0,527.0,487.0,527.0,513.0,271.0,513.0,CARD COVER (SJB-4013)
+479.0,522.0,527.0,522.0,527.0,542.0,479.0,542.0,3.30
+553.0,524.0,583.0,524.0,583.0,544.0,553.0,544.0,SR
+27.0,557.0,342.0,557.0,342.0,581.0,27.0,581.0,TOTAL SALES INCLUSIVE GST @6%
+240.0,587.0,331.0,587.0,331.0,612.0,240.0,612.0,DISCOUNT
+239.0,629.0,293.0,629.0,293.0,651.0,239.0,651.0,TOTAL
+240.0,659.0,345.0,659.0,345.0,687.0,240.0,687.0,ROUND ADJ
+239.0,701.0,347.0,701.0,347.0,723.0,239.0,723.0,FINAL TOTAL
+239.0,758.0,301.0,758.0,301.0,781.0,239.0,781.0,CASH
+240.0,789.0,329.0,789.0,329.0,811.0,240.0,811.0,CHANGE
+478.0,561.0,524.0,561.0,524.0,581.0,478.0,581.0,3.30
+477.0,590.0,525.0,590.0,525.0,614.0,477.0,614.0,0.00
+482.0,634.0,527.0,634.0,527.0,655.0,482.0,655.0,3.30
+480.0,664.0,528.0,664.0,528.0,684.0,480.0,684.0,0.00
+481.0,704.0,527.0,704.0,527.0,726.0,481.0,726.0,3.30
+481.0,760.0,526.0,760.0,526.0,782.0,481.0,782.0,5.00
+482.0,793.0,528.0,793.0,528.0,814.0,482.0,814.0,1.70
+28.0,834.0,172.0,834.0,172.0,859.0,28.0,859.0,GST SUMMARY
+253.0,834.0,384.0,834.0,384.0,859.0,253.0,859.0,AMOUNT(RM)
+475.0,834.0,566.0,834.0,566.0,860.0,475.0,860.0,TAX(RM)
+28.0,864.0,128.0,864.0,128.0,889.0,28.0,889.0,SR @ 6%
+337.0,864.0,385.0,864.0,385.0,886.0,337.0,886.0,3.11
+518.0,867.0,565.0,867.0,565.0,887.0,518.0,887.0,0.19
+25.0,943.0,290.0,943.0,290.0,967.0,25.0,967.0,INV NO: CS-SA-0076015
+316.0,942.0,516.0,942.0,516.0,967.0,316.0,967.0,DATE : 05/04/2017
+65.0,1084.0,569.0,1084.0,569.0,1110.0,65.0,1110.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE
+112.0,1135.0,524.0,1135.0,524.0,1163.0,112.0,1163.0,THANK YOU FOR YOUR PATRONAGE
+189.0,1169.0,441.0,1169.0,441.0,1192.0,189.0,1192.0,PLEASE COME AGAIN.
+115.0,1221.0,517.0,1221.0,517.0,1245.0,115.0,1245.0,TERIMA KASIH SILA DATANG LAGI
+65.0,1271.0,569.0,1271.0,569.0,1299.0,65.0,1299.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF
+48.0,1305.0,584.0,1305.0,584.0,1330.0,48.0,1330.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY
+244.0,1339.0,393.0,1339.0,393.0,1359.0,244.0,1359.0,PURPOSE **
+85.0,1389.0,548.0,1389.0,548.0,1419.0,85.0,1419.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY
diff --git a/data/test/images/ocr_detection/test_images/X51007339105.jpg b/data/test/images/ocr_detection/test_images/X51007339105.jpg
new file mode 100644
index 00000000..ac166703
--- /dev/null
+++ b/data/test/images/ocr_detection/test_images/X51007339105.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afe8e0d24bed53078472e6e4a00f81cc4e251e88d35bc49afb59cf3fab36fcf8
+size 348614
diff --git a/data/test/images/ocr_detection/test_list.txt b/data/test/images/ocr_detection/test_list.txt
new file mode 100644
index 00000000..2155af92
--- /dev/null
+++ b/data/test/images/ocr_detection/test_list.txt
@@ -0,0 +1 @@
+X51007339105.jpg
diff --git a/data/test/images/ocr_detection/train_gts/X51007339133.txt b/data/test/images/ocr_detection/train_gts/X51007339133.txt
new file mode 100644
index 00000000..87841e60
--- /dev/null
+++ b/data/test/images/ocr_detection/train_gts/X51007339133.txt
@@ -0,0 +1,46 @@
+46.0,135.0,515.0,135.0,515.0,170.0,46.0,170.0,SANYU STATIONERY SHOP
+49.0,184.0,507.0,184.0,507.0,207.0,49.0,207.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X
+47.0,209.0,245.0,209.0,245.0,230.0,47.0,230.0,40170 SETIA ALAM
+48.0,236.0,433.0,236.0,433.0,258.0,48.0,258.0,MOBILE /WHATSAPPS : +6012-918 7937
+46.0,267.0,266.0,267.0,266.0,287.0,46.0,287.0,TEL: +603-3362 4137
+47.0,292.0,325.0,292.0,325.0,315.0,47.0,315.0,GST ID NO: 001531760640
+410.0,303.0,594.0,303.0,594.0,331.0,410.0,331.0,TAX INVOICE
+38.0,322.0,141.0,322.0,141.0,344.0,38.0,344.0,OWNED BY :
+38.0,345.0,378.0,345.0,378.0,368.0,38.0,368.0,SANYU SUPPLY SDN BHD (1135772-K)
+43.0,398.0,303.0,398.0,303.0,422.0,43.0,422.0,CASH SALES COUNTER
+55.0,462.0,194.0,462.0,194.0,485.0,55.0,485.0,1. 2012-0029
+81.0,523.0,194.0,523.0,194.0,544.0,81.0,544.0,3 X 2.9000
+275.0,463.0,597.0,463.0,597.0,486.0,275.0,486.0,RESTAURANT ORDER CHIT NCR
+274.0,491.0,347.0,491.0,347.0,512.0,274.0,512.0,3.5"X6"
+482.0,524.0,529.0,524.0,529.0,545.0,482.0,545.0,8.70
+556.0,525.0,585.0,525.0,585.0,546.0,556.0,546.0,SR
+28.0,559.0,346.0,559.0,346.0,581.0,28.0,581.0,TOTAL SALES INCLUSIVE GST @6%
+243.0,590.0,335.0,590.0,335.0,612.0,243.0,612.0,DISCOUNT
+241.0,632.0,296.0,632.0,296.0,654.0,241.0,654.0,TOTAL
+242.0,661.0,349.0,661.0,349.0,685.0,242.0,685.0,ROUND ADJ
+243.0,703.0,348.0,703.0,348.0,724.0,243.0,724.0,FINAL TOTAL
+244.0,760.0,302.0,760.0,302.0,780.0,244.0,780.0,CASH
+241.0,792.0,332.0,792.0,332.0,812.0,241.0,812.0,CHANGE
+481.0,562.0,530.0,562.0,530.0,584.0,481.0,584.0,8.70
+482.0,594.0,528.0,594.0,528.0,613.0,482.0,613.0,0.00
+483.0,636.0,530.0,636.0,530.0,654.0,483.0,654.0,8.70
+483.0,666.0,533.0,666.0,533.0,684.0,483.0,684.0,0.00
+484.0,707.0,532.0,707.0,532.0,726.0,484.0,726.0,8.70
+473.0,764.0,535.0,764.0,535.0,783.0,473.0,783.0,10.00
+486.0,793.0,532.0,793.0,532.0,815.0,486.0,815.0,1.30
+31.0,836.0,176.0,836.0,176.0,859.0,31.0,859.0,GST SUMMARY
+257.0,836.0,391.0,836.0,391.0,858.0,257.0,858.0,AMOUNT(RM)
+479.0,837.0,569.0,837.0,569.0,859.0,479.0,859.0,TAX(RM)
+33.0,867.0,130.0,867.0,130.0,889.0,33.0,889.0,SR @ 6%
+341.0,867.0,389.0,867.0,389.0,889.0,341.0,889.0,8.21
+522.0,869.0,573.0,869.0,573.0,890.0,522.0,890.0,0.49
+30.0,945.0,292.0,945.0,292.0,967.0,30.0,967.0,INV NO: CS-SA-0120436
+323.0,945.0,520.0,945.0,520.0,967.0,323.0,967.0,DATE : 27/10/2017
+70.0,1089.0,572.0,1089.0,572.0,1111.0,70.0,1111.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE
+116.0,1142.0,526.0,1142.0,526.0,1162.0,116.0,1162.0,THANK YOU FOR YOUR PATRONAGE
+199.0,1173.0,445.0,1173.0,445.0,1193.0,199.0,1193.0,PLEASE COME AGAIN.
+121.0,1225.0,524.0,1225.0,524.0,1246.0,121.0,1246.0,TERIMA KASIH SILA DATANG LAGI
+72.0,1273.0,573.0,1273.0,573.0,1299.0,72.0,1299.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF
+55.0,1308.0,591.0,1308.0,591.0,1328.0,55.0,1328.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY
+249.0,1338.0,396.0,1338.0,396.0,1361.0,249.0,1361.0,PURPOSE **
+93.0,1391.0,553.0,1391.0,553.0,1416.0,93.0,1416.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY
diff --git a/data/test/images/ocr_detection/train_gts/X51007339135.txt b/data/test/images/ocr_detection/train_gts/X51007339135.txt
new file mode 100644
index 00000000..ed779b40
--- /dev/null
+++ b/data/test/images/ocr_detection/train_gts/X51007339135.txt
@@ -0,0 +1,46 @@
+44.0,131.0,517.0,131.0,517.0,166.0,44.0,166.0,SANYU STATIONERY SHOP
+48.0,180.0,510.0,180.0,510.0,204.0,48.0,204.0,NO. 31G&33G, JALAN SETIA INDAH X ,U13/X
+47.0,206.0,247.0,206.0,247.0,229.0,47.0,229.0,40170 SETIA ALAM
+47.0,232.0,434.0,232.0,434.0,258.0,47.0,258.0,MOBILE /WHATSAPPS : +6012-918 7937
+47.0,265.0,268.0,265.0,268.0,285.0,47.0,285.0,TEL: +603-3362 4137
+48.0,287.0,325.0,287.0,325.0,313.0,48.0,313.0,GST ID NO: 001531760640
+411.0,301.0,599.0,301.0,599.0,333.0,411.0,333.0,TAX INVOICE
+38.0,321.0,143.0,321.0,143.0,342.0,38.0,342.0,OWNED BY :
+39.0,342.0,379.0,342.0,379.0,367.0,39.0,367.0,SANYU SUPPLY SDN BHD (1135772-K)
+42.0,397.0,305.0,397.0,305.0,420.0,42.0,420.0,CASH SALES COUNTER
+57.0,459.0,195.0,459.0,195.0,483.0,57.0,483.0,1. 2012-0029
+82.0,518.0,199.0,518.0,199.0,540.0,82.0,540.0,3 X 2.9000
+274.0,459.0,600.0,459.0,600.0,483.0,274.0,483.0,RESTAURANT ORDER CHIT NCR
+274.0,486.0,347.0,486.0,347.0,508.0,274.0,508.0,3.5"X6"
+483.0,521.0,530.0,521.0,530.0,541.0,483.0,541.0,8.70
+557.0,517.0,588.0,517.0,588.0,543.0,557.0,543.0,SR
+31.0,556.0,347.0,556.0,347.0,578.0,31.0,578.0,TOTAL SALES INCLUSIVE GST @6%
+244.0,585.0,335.0,585.0,335.0,608.0,244.0,608.0,DISCOUNT
+241.0,626.0,302.0,626.0,302.0,651.0,241.0,651.0,TOTAL
+245.0,659.0,354.0,659.0,354.0,683.0,245.0,683.0,ROUND ADJ
+244.0,698.0,351.0,698.0,351.0,722.0,244.0,722.0,FINAL TOTAL
+482.0,558.0,529.0,558.0,529.0,578.0,482.0,578.0,8.70
+484.0,591.0,531.0,591.0,531.0,608.0,484.0,608.0,0.00
+485.0,630.0,533.0,630.0,533.0,651.0,485.0,651.0,8.70
+485.0,661.0,532.0,661.0,532.0,681.0,485.0,681.0,0.00
+484.0,703.0,532.0,703.0,532.0,723.0,484.0,723.0,8.70
+474.0,760.0,534.0,760.0,534.0,777.0,474.0,777.0,10.00
+488.0,789.0,532.0,789.0,532.0,808.0,488.0,808.0,1.30
+33.0,829.0,179.0,829.0,179.0,855.0,33.0,855.0,GST SUMMARY
+261.0,828.0,390.0,828.0,390.0,855.0,261.0,855.0,AMOUNT(RM)
+482.0,830.0,572.0,830.0,572.0,856.0,482.0,856.0,TAX(RM)
+32.0,862.0,135.0,862.0,135.0,885.0,32.0,885.0,SR @ 6%
+344.0,860.0,389.0,860.0,389.0,884.0,344.0,884.0,8.21
+523.0,862.0,575.0,862.0,575.0,885.0,523.0,885.0,0.49
+32.0,941.0,295.0,941.0,295.0,961.0,32.0,961.0,INV NO: CS-SA-0122588
+72.0,1082.0,576.0,1082.0,576.0,1106.0,72.0,1106.0,GOODS SOLD ARE NOT RETURNABLE & REFUNDABLE
+115.0,1135.0,528.0,1135.0,528.0,1157.0,115.0,1157.0,THANK YOU FOR YOUR PATRONAGE
+195.0,1166.0,445.0,1166.0,445.0,1189.0,195.0,1189.0,PLEASE COME AGAIN.
+122.0,1217.0,528.0,1217.0,528.0,1246.0,122.0,1246.0,TERIMA KASIH SILA DATANG LAGI
+72.0,1270.0,576.0,1270.0,576.0,1294.0,72.0,1294.0,** PLEASE KEEP THIS RECEIPT FOR PROVE OF
+55.0,1301.0,592.0,1301.0,592.0,1325.0,55.0,1325.0,PURCHASE DATE FOR I.T PRODUCT WARRANTY
+251.0,1329.0,400.0,1329.0,400.0,1354.0,251.0,1354.0,PURPOSE **
+95.0,1386.0,558.0,1386.0,558.0,1413.0,95.0,1413.0,FOLLOW US IN FACEBOOK : SANYU.STATIONERY
+243.0,752.0,305.0,752.0,305.0,779.0,243.0,779.0,CASH
+244.0,784.0,336.0,784.0,336.0,807.0,244.0,807.0,CHANGE
+316.0,939.0,525.0,939.0,525.0,967.0,316.0,967.0,DATE: 06/11/2017
diff --git a/data/test/images/ocr_detection/train_images/X51007339133.jpg b/data/test/images/ocr_detection/train_images/X51007339133.jpg
new file mode 100644
index 00000000..87ba004d
--- /dev/null
+++ b/data/test/images/ocr_detection/train_images/X51007339133.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffcc55042093629aaa54d26516de77b45c7b612c0516bad21517e1963e7b518c
+size 352297
diff --git a/data/test/images/ocr_detection/train_images/X51007339135.jpg b/data/test/images/ocr_detection/train_images/X51007339135.jpg
new file mode 100644
index 00000000..d4ac4814
--- /dev/null
+++ b/data/test/images/ocr_detection/train_images/X51007339135.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56addcb7d36f9b3732e0c4efd04d7e31d291c0763a32dcb556fe262bcbb0520a
+size 353731
diff --git a/data/test/images/ocr_detection/train_list.txt b/data/test/images/ocr_detection/train_list.txt
new file mode 100644
index 00000000..1bdf326f
--- /dev/null
+++ b/data/test/images/ocr_detection/train_list.txt
@@ -0,0 +1,2 @@
+X51007339133.jpg
+X51007339135.jpg
diff --git a/data/test/images/vidt_test1.jpg b/data/test/images/vidt_test1.jpg
new file mode 100644
index 00000000..6f4bc051
--- /dev/null
+++ b/data/test/images/vidt_test1.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b7e87ea289bc59863ed81129d5991ede97bf5335c173ab9f36e4e4cfdc858e41
+size 120137
diff --git a/data/test/images/vision_efficient_tuning_test_apple.jpg b/data/test/images/vision_efficient_tuning_test_apple.jpg
new file mode 100644
index 00000000..7da7fcab
--- /dev/null
+++ b/data/test/images/vision_efficient_tuning_test_apple.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:407d70db9f01bc7a6f34377e36c3f2f5eefdfca8bd3c578226bf5b31b73325dc
+size 127213
diff --git a/data/test/images/vision_efficient_tuning_test_sunflower.jpg b/data/test/images/vision_efficient_tuning_test_sunflower.jpg
new file mode 100644
index 00000000..7ebf088a
--- /dev/null
+++ b/data/test/images/vision_efficient_tuning_test_sunflower.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c67733db75dc7fd773561a5091329fd5ee919b2268a3a65718261722607698f
+size 226882
diff --git a/docs/source/api/modelscope.msdatasets.cv.rst b/docs/source/api/modelscope.msdatasets.cv.rst
deleted file mode 100644
index ef0a8a3b..00000000
--- a/docs/source/api/modelscope.msdatasets.cv.rst
+++ /dev/null
@@ -1,14 +0,0 @@
-modelscope.msdatasets.cv
-================================
-
-.. automodule:: modelscope.msdatasets.cv
-
-.. currentmodule:: modelscope.msdatasets.cv
-
-.. autosummary::
- :toctree: generated
- :nosignatures:
- :template: classtemplate.rst
-
- easycv_base.EasyCVBaseDataset
- image_classification.ClsDataset
diff --git a/docs/source/api/modelscope.msdatasets.dataset_cls.custom_datasets.rst b/docs/source/api/modelscope.msdatasets.dataset_cls.custom_datasets.rst
new file mode 100644
index 00000000..b5a4b0f6
--- /dev/null
+++ b/docs/source/api/modelscope.msdatasets.dataset_cls.custom_datasets.rst
@@ -0,0 +1,41 @@
+modelscope.msdatasets.dataset_cls.custom_datasets
+=================================================
+
+.. automodule:: modelscope.msdatasets.dataset_cls.custom_datasets
+
+.. currentmodule:: modelscope.msdatasets.dataset_cls.custom_datasets
+
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
+ :template: classtemplate.rst
+
+ EasyCVBaseDataset
+ TorchCustomDataset
+ MovieSceneSegmentationDataset
+ ImageInstanceSegmentationCocoDataset
+ GoproImageDeblurringDataset
+ LanguageGuidedVideoSummarizationDataset
+ MGeoRankingDataset
+ RedsImageDeblurringDataset
+ TextRankingDataset
+ VecoDataset
+ VideoSummarizationDataset
+ BadImageDetectingDataset
+ ImageInpaintingDataset
+ ImagePortraitEnhancementDataset
+ ImageQualityAssessmentDegradationDataset
+ ImageQualityAssessmentMosDataset
+ ReferringVideoObjectSegmentationDataset
+ SiddImageDenoisingDataset
+ VideoFrameInterpolationDataset
+ VideoStabilizationDataset
+ VideoSuperResolutionDataset
+ SegDataset
+ FaceKeypointDataset
+ HandCocoWholeBodyDataset
+ WholeBodyCocoTopDownDataset
+ ClsDataset
+ DetImagesMixDataset
+ DetDataset
diff --git a/docs/source/api/modelscope.msdatasets.dataset_cls.rst b/docs/source/api/modelscope.msdatasets.dataset_cls.rst
new file mode 100644
index 00000000..d415b800
--- /dev/null
+++ b/docs/source/api/modelscope.msdatasets.dataset_cls.rst
@@ -0,0 +1,15 @@
+modelscope.msdatasets.dataset_cls
+=================================
+
+.. automodule:: modelscope.msdatasets.dataset_cls
+
+.. currentmodule:: modelscope.msdatasets.dataset_cls
+
+
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
+ :template: classtemplate.rst
+
+ ExternalDataset
+ NativeIterableDataset
diff --git a/docs/source/api/modelscope.msdatasets.ms_dataset.rst b/docs/source/api/modelscope.msdatasets.ms_dataset.rst
index 03cc8d97..92df1e89 100644
--- a/docs/source/api/modelscope.msdatasets.ms_dataset.rst
+++ b/docs/source/api/modelscope.msdatasets.ms_dataset.rst
@@ -10,5 +10,4 @@ modelscope.msdatasets.ms_dataset
:nosignatures:
:template: classtemplate.rst
- MsMapDataset
MsDataset
diff --git a/examples/pytorch/text_generation/finetune_text_generation.py b/examples/pytorch/text_generation/finetune_text_generation.py
new file mode 100644
index 00000000..5168e00e
--- /dev/null
+++ b/examples/pytorch/text_generation/finetune_text_generation.py
@@ -0,0 +1,102 @@
+from dataclasses import dataclass, field
+
+from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import EpochBasedTrainer, build_trainer
+from modelscope.trainers.training_args import TrainingArgs
+
+
+@dataclass
+class TextGenerationArguments(TrainingArgs):
+
+ trainer: str = field(
+ default=Trainers.default, metadata={
+ 'help': 'The trainer used',
+ })
+
+ work_dir: str = field(
+ default='./tmp',
+ metadata={
+ 'help': 'The working path for saving checkpoint',
+ })
+
+ src_txt: str = field(
+ default=None,
+ metadata={
+ 'help': 'The source text key of preprocessor',
+ 'cfg_node': 'preprocessor.src_txt'
+ })
+
+ tgt_txt: str = field(
+ default=None,
+ metadata={
+ 'help': 'The target text key of preprocessor',
+ 'cfg_node': 'preprocessor.tgt_txt'
+ })
+
+ preprocessor: str = field(
+ default=None,
+ metadata={
+ 'help': 'The preprocessor type',
+ 'cfg_node': 'preprocessor.type'
+ })
+
+ lr_scheduler: str = field(
+ default=None,
+ metadata={
+ 'help': 'The lr scheduler type',
+ 'cfg_node': 'train.lr_scheduler.type'
+ })
+
+ world_size: int = field(
+ default=None,
+ metadata={
+ 'help': 'The parallel world size',
+ 'cfg_node': 'megatron.world_size'
+ })
+
+ tensor_model_parallel_size: int = field(
+ default=None,
+ metadata={
+ 'help': 'The tensor model parallel size',
+ 'cfg_node': 'megatron.tensor_model_parallel_size'
+ })
+
+ def __call__(self, config):
+ config = super().__call__(config)
+ if config.train.lr_scheduler.type == 'noam':
+ config.train.lr_scheduler = {
+ 'type': 'LambdaLR',
+ 'lr_lambda': noam_lambda,
+ 'options': {
+ 'by_epoch': False
+ }
+ }
+ config.train.hooks.append({'type': 'MegatronHook'})
+ return config
+
+
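+# A Noam-style learning-rate schedule (cf. "Attention Is All You Need") with a
+# fixed warmup of 100 steps: the lambda grows linearly for roughly the first 100
+# steps and then decays in proportion to current_step ** -0.5.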
+def noam_lambda(current_step: int):
+ current_step += 1
+ return min(current_step**(-0.5), current_step * 100**(-1.5))
+
+
+args = TextGenerationArguments.from_cli(task='text-generation')
+
+print(args)
+
+dataset = MsDataset.load(args.dataset_name)
+train_dataset = dataset['train']
+eval_dataset = dataset['validation' if 'validation' in dataset else 'test']
+
+kwargs = dict(
+ model=args.model,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ seed=args.seed,
+ work_dir=args.work_dir,
+ cfg_modify_fn=args)
+
+trainer: EpochBasedTrainer = build_trainer(
+ name=args.trainer, default_args=kwargs)
+trainer.train()
diff --git a/examples/pytorch/text_generation/run_train.sh b/examples/pytorch/text_generation/run_train.sh
new file mode 100644
index 00000000..cbecd11a
--- /dev/null
+++ b/examples/pytorch/text_generation/run_train.sh
@@ -0,0 +1,22 @@
+DATA_PARALLEL_SIZE=2
+TENSOR_MODEL_PARALLEL_SIZE=2
+
+WORLD_SIZE=$(($DATA_PARALLEL_SIZE * $TENSOR_MODEL_PARALLEL_SIZE))
+
+
+PYTHONPATH=. torchrun --nproc_per_node $WORLD_SIZE examples/pytorch/text_generation/finetune_text_generation.py \
+ --trainer 'nlp-gpt3-trainer' \
+ --work_dir './tmp' \
+ --model 'damo/nlp_gpt3_text-generation_1.3B' \
+ --dataset_name 'chinese-poetry-collection' \
+ --preprocessor 'text-gen-jieba-tokenizer' \
+ --src_txt 'text1' \
+ --tgt_txt 'text2' \
+ --max_epochs 3 \
+ --per_device_train_batch_size 16 \
+ --lr 3e-4 \
+ --lr_scheduler 'noam' \
+ --eval_metrics 'ppl' \
+ --world_size $WORLD_SIZE \
+ --tensor_model_parallel_size $TENSOR_MODEL_PARALLEL_SIZE \
+ # --dataset_name 'DuReader_robust-QG' \ # input&output
diff --git a/modelscope/cli/cli.py b/modelscope/cli/cli.py
index 0f7a8139..a25502fd 100644
--- a/modelscope/cli/cli.py
+++ b/modelscope/cli/cli.py
@@ -3,6 +3,9 @@
import argparse
from modelscope.cli.download import DownloadCMD
+from modelscope.cli.modelcard import ModelCardCMD
+from modelscope.cli.pipeline import PipelineCMD
+from modelscope.cli.plugins import PluginsCMD
def run_cmd():
@@ -11,6 +14,9 @@ def run_cmd():
subparsers = parser.add_subparsers(help='modelscope commands helpers')
DownloadCMD.define_args(subparsers)
+ PluginsCMD.define_args(subparsers)
+ PipelineCMD.define_args(subparsers)
+ ModelCardCMD.define_args(subparsers)
args = parser.parse_args()
diff --git a/modelscope/cli/modelcard.py b/modelscope/cli/modelcard.py
new file mode 100644
index 00000000..72372894
--- /dev/null
+++ b/modelscope/cli/modelcard.py
@@ -0,0 +1,178 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+from argparse import ArgumentParser
+from string import Template
+
+from modelscope.cli.base import CLICommand
+from modelscope.hub.api import HubApi
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.hub.utils.utils import get_endpoint
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+current_path = os.path.dirname(os.path.abspath(__file__))
+template_path = os.path.join(current_path, 'template')
+
+
+def subparser_func(args):
+ """ Fuction which will be called for a specific sub parser.
+ """
+ return ModelCardCMD(args)
+
+
+class ModelCardCMD(CLICommand):
+ name = 'modelcard'
+
+ def __init__(self, args):
+ self.args = args
+ self.api = HubApi()
+ self.api.login(args.access_token)
+ self.model_id = os.path.join(
+ self.args.group_id, self.args.model_id
+ ) if '/' not in self.args.model_id else self.args.model_id
+ self.url = os.path.join(get_endpoint(), self.model_id)
+
+ @staticmethod
+ def define_args(parsers: ArgumentParser):
+ """ define args for create or upload modelcard command.
+ """
+ parser = parsers.add_parser(ModelCardCMD.name)
+ parser.add_argument(
+ '-tk',
+ '--access_token',
+ type=str,
+ required=True,
+            help='the access token used to authenticate with ModelScope')
+ parser.add_argument(
+ '-act',
+ '--action',
+ type=str,
+ required=True,
+ choices=['create', 'upload', 'download'],
+            help='the action to perform [create, upload, download]')
+ parser.add_argument(
+ '-gid',
+ '--group_id',
+ type=str,
+ default='damo',
+            help='the group (namespace) on ModelScope, e.g. damo')
+ parser.add_argument(
+ '-mid',
+ '--model_id',
+ type=str,
+ required=True,
+            help='the model name on ModelScope')
+ parser.add_argument(
+ '-vis',
+ '--visibility',
+ type=int,
+ default=5,
+            help='the visibility of the model on ModelScope')
+ parser.add_argument(
+ '-lic',
+ '--license',
+ type=str,
+ default='Apache License 2.0',
+            help='the license of the model')
+ parser.add_argument(
+ '-ch',
+ '--chinese_name',
+ type=str,
+ default='这是我的第一个模型',
+            help='the Chinese name of the model')
+ parser.add_argument(
+ '-md',
+ '--model_dir',
+ type=str,
+ default='.',
+            help='the local model directory containing configuration.json')
+ parser.add_argument(
+ '-vt',
+ '--version_tag',
+ type=str,
+ default=None,
+            help='the version tag of the uploaded model')
+ parser.add_argument(
+ '-vi',
+ '--version_info',
+ type=str,
+ default=None,
+            help='the version description of the uploaded model')
+ parser.set_defaults(func=subparser_func)
+
+ def create_model(self):
+ from modelscope.hub.constants import Licenses, ModelVisibility
+ visibilities = [
+ getattr(ModelVisibility, attr) for attr in dir(ModelVisibility)
+ if not attr.startswith('__')
+ ]
+ if self.args.visibility not in visibilities:
+            raise ValueError('The visibility must be in %s!' % visibilities)
+ licenses = [
+ getattr(Licenses, attr) for attr in dir(Licenses)
+ if not attr.startswith('__')
+ ]
+ if self.args.license not in licenses:
+            raise ValueError('The license must be in %s!' % licenses)
+ try:
+ self.api.get_model(self.model_id)
+ except Exception as e:
+ logger.info('>>> %s' % type(e))
+ self.api.create_model(
+ model_id=self.model_id,
+ visibility=self.args.visibility,
+ license=self.args.license,
+ chinese_name=self.args.chinese_name,
+ )
+ self.pprint()
+
+ def get_model_url(self):
+ return self.api.get_model_url(self.model_id)
+
+ def push_model(self, tpl_dir='readme.tpl'):
+ from modelscope.hub.repository import Repository
+ if self.args.version_tag and self.args.version_info:
+            clone_dir = tempfile.mkdtemp()
+ repo = Repository(clone_dir, clone_from=self.model_id)
+ repo.tag_and_push(self.args.version_tag, self.args.version_info)
+ shutil.rmtree(clone_dir)
+ else:
+ cfg_file = os.path.join(self.args.model_dir, 'README.md')
+ if not os.path.exists(cfg_file):
+ with open(os.path.join(template_path,
+ tpl_dir)) as tpl_file_path:
+ tpl = Template(tpl_file_path.read())
+                with open(cfg_file, 'w') as readme_file:
+                    readme_file.write(
+                        tpl.substitute(model_id=self.model_id))
+ self.api.push_model(
+ model_id=self.model_id,
+ model_dir=self.args.model_dir,
+ visibility=self.args.visibility,
+ license=self.args.license,
+ chinese_name=self.args.chinese_name)
+ self.pprint()
+
+ def pprint(self):
+        logger.info('>>> Clone the model repo < %s >, commit and push it.'
+                    % self.get_model_url())
+        logger.info('>>> Open the url < %s > to check and read it.' % self.url)
+        logger.info('>>> Use the model_id < %s > to download and run it.'
+                    % self.model_id)
+
+ def execute(self):
+ if self.args.action == 'create':
+ self.create_model()
+ elif self.args.action == 'upload':
+ self.push_model()
+ elif self.args.action == 'download':
+ snapshot_download(
+ self.model_id,
+ cache_dir=self.args.model_dir,
+ revision=self.args.version_tag)
+ else:
+            raise ValueError('The parameter of action must be in '
+                             '[create, upload, download]')
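+
+
+# Illustrative CLI usage (a sketch, assuming the package's `modelscope` console
+# entry point; the token, group and model names below are placeholders):
+#   modelscope modelcard -tk <access_token> -act create -gid my_group -mid my_model
+#   modelscope modelcard -tk <access_token> -act upload -mid my_group/my_model -md ./my_model
+#   modelscope modelcard -tk <access_token> -act download -mid my_group/my_model -md ./cache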
diff --git a/modelscope/cli/pipeline.py b/modelscope/cli/pipeline.py
new file mode 100644
index 00000000..59cabdf9
--- /dev/null
+++ b/modelscope/cli/pipeline.py
@@ -0,0 +1,127 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from argparse import ArgumentParser
+from string import Template
+
+from modelscope.cli.base import CLICommand
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+current_path = os.path.dirname(os.path.abspath(__file__))
+template_path = os.path.join(current_path, 'template')
+
+
+def subparser_func(args):
+ """ Fuction which will be called for a specific sub parser.
+ """
+ return PipelineCMD(args)
+
+
+class PipelineCMD(CLICommand):
+ name = 'pipeline'
+
+ def __init__(self, args):
+ self.args = args
+
+ @staticmethod
+ def define_args(parsers: ArgumentParser):
+ """ define args for create pipeline template command.
+ """
+ parser = parsers.add_parser(PipelineCMD.name)
+ parser.add_argument(
+ '-act',
+ '--action',
+ type=str,
+ required=True,
+ choices=['create'],
+            help='the action to perform [create]')
+ parser.add_argument(
+ '-tpl',
+ '--tpl_file_path',
+ type=str,
+ default='template.tpl',
+            help='the template file to use [default: template.tpl]')
+ parser.add_argument(
+ '-s',
+ '--save_file_path',
+ type=str,
+ default='./',
+            help='the directory where the generated file will be saved')
+ parser.add_argument(
+ '-f',
+ '--filename',
+ type=str,
+ default='ms_wrapper.py',
+            help='the filename of the generated file, e.g. ms_wrapper.py')
+ parser.add_argument(
+ '-t',
+ '--task_name',
+ type=str,
+ required=True,
+            help='the unique task name of the custom pipeline')
+ parser.add_argument(
+ '-m',
+ '--model_name',
+ type=str,
+ default='MyCustomModel',
+            help='the class name of the custom model')
+ parser.add_argument(
+ '-p',
+ '--preprocessor_name',
+ type=str,
+ default='MyCustomPreprocessor',
+            help='the class name of the custom preprocessor')
+ parser.add_argument(
+ '-pp',
+ '--pipeline_name',
+ type=str,
+ default='MyCustomPipeline',
+            help='the class name of the custom pipeline')
+ parser.add_argument(
+ '-config',
+ '--configuration_path',
+ type=str,
+ default='./',
+            help='the directory where configuration.json will be saved')
+ parser.set_defaults(func=subparser_func)
+
+ def create_template(self):
+ if self.args.tpl_file_path not in os.listdir(template_path):
+ tpl_file_path = self.args.tpl_file_path
+ else:
+ tpl_file_path = os.path.join(template_path,
+ self.args.tpl_file_path)
+ if not os.path.exists(tpl_file_path):
+ raise ValueError('%s not exists!' % tpl_file_path)
+
+ save_file_path = self.args.save_file_path if self.args.save_file_path != './' else os.getcwd(
+ )
+ os.makedirs(save_file_path, exist_ok=True)
+ if not self.args.filename.endswith('.py'):
+            raise ValueError('the filename must end with .py')
+        save_file_name = self.args.filename
+        save_py_path = os.path.join(save_file_path, save_file_name)
+
+ if not self.args.configuration_path.endswith('/'):
+ self.args.configuration_path = self.args.configuration_path + '/'
+
+ lines = []
+ with open(tpl_file_path) as tpl_file:
+ tpl = Template(tpl_file.read())
+ lines.append(tpl.substitute(**vars(self.args)))
+
+        with open(save_py_path, 'w') as save_file:
+ save_file.writelines(lines)
+
+        logger.info('>>> Configuration will be saved to %sconfiguration.json'
+                    % self.args.configuration_path)
+        logger.info('>>> Task_name: %s, created in %s' %
+                    (self.args.task_name, save_py_path))
+        logger.info('Open the file < %s >, update and run it.' % save_py_path)
+
+ def execute(self):
+ if self.args.action == 'create':
+ self.create_template()
+ else:
+ raise ValueError('The parameter of action must be in [create]')
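+
+
+# Illustrative CLI usage (a sketch, assuming the package's `modelscope` console
+# entry point; the task and class names below are placeholders):
+#   modelscope pipeline -act create -t my-custom-task -m MyCustomModel \
+#       -p MyCustomPreprocessor -pp MyCustomPipeline -s ./my_pipeline -f ms_wrapper.py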
diff --git a/modelscope/cli/plugins.py b/modelscope/cli/plugins.py
new file mode 100644
index 00000000..e40457df
--- /dev/null
+++ b/modelscope/cli/plugins.py
@@ -0,0 +1,118 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from argparse import ArgumentParser
+
+from modelscope.cli.base import CLICommand
+from modelscope.utils.plugins import PluginsManager
+
+plugins_manager = PluginsManager()
+
+
+def subparser_func(args):
+ """ Fuction which will be called for a specific sub parser.
+ """
+ return PluginsCMD(args)
+
+
+class PluginsCMD(CLICommand):
+ name = 'plugin'
+
+ def __init__(self, args):
+ self.args = args
+
+ @staticmethod
+ def define_args(parsers: ArgumentParser):
+ """ define args for install command.
+ """
+ parser = parsers.add_parser(PluginsCMD.name)
+ subparsers = parser.add_subparsers(dest='command')
+
+ PluginsInstallCMD.define_args(subparsers)
+ PluginsUninstallCMD.define_args(subparsers)
+ PluginsListCMD.define_args(subparsers)
+
+ parser.set_defaults(func=subparser_func)
+
+ def execute(self):
+ if self.args.command == PluginsInstallCMD.name:
+ PluginsInstallCMD.execute(self.args)
+ if self.args.command == PluginsUninstallCMD.name:
+ PluginsUninstallCMD.execute(self.args)
+ if self.args.command == PluginsListCMD.name:
+ PluginsListCMD.execute(self.args)
+
+
+class PluginsInstallCMD(PluginsCMD):
+ name = 'install'
+
+ @staticmethod
+ def define_args(parsers: ArgumentParser):
+ install = parsers.add_parser(PluginsInstallCMD.name)
+ install.add_argument(
+ 'package',
+ type=str,
+ nargs='+',
+ default=None,
+ help='Name of the package to be installed.')
+ install.add_argument(
+ '--index_url',
+ '-i',
+ type=str,
+ default=None,
+ help='Base URL of the Python Package Index.')
+ install.add_argument(
+ '--force_update',
+ '-f',
+ type=str,
+ default=False,
+            help='Whether to force update the package.')
+
+ @staticmethod
+ def execute(args):
+ plugins_manager.install_plugins(
+ list(args.package),
+ index_url=args.index_url,
+ force_update=args.force_update)
+
+
+class PluginsUninstallCMD(PluginsCMD):
+ name = 'uninstall'
+
+ @staticmethod
+ def define_args(parsers: ArgumentParser):
+ install = parsers.add_parser(PluginsUninstallCMD.name)
+ install.add_argument(
+ 'package',
+ type=str,
+ nargs='+',
+ default=None,
+            help='Name of the package to be uninstalled.')
+ install.add_argument(
+ '--yes',
+ '-y',
+ type=str,
+ default=False,
+            help='Confirm the uninstallation without prompting.')
+
+ @staticmethod
+ def execute(args):
+ plugins_manager.uninstall_plugins(list(args.package), is_yes=args.yes)
+
+
+class PluginsListCMD(PluginsCMD):
+ name = 'list'
+
+ @staticmethod
+ def define_args(parsers: ArgumentParser):
+ install = parsers.add_parser(PluginsListCMD.name)
+ install.add_argument(
+ '--all',
+ '-a',
+ type=str,
+ default=None,
+ help='Show all of the plugins including those not installed.')
+
+ @staticmethod
+ def execute(args):
+        plugins_manager.list_plugins(show_all=args.all)
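+
+
+# Illustrative CLI usage (a sketch, assuming the package's `modelscope` console
+# entry point; the plugin package name below is a placeholder):
+#   modelscope plugin install some-plugin-package
+#   modelscope plugin list
+#   modelscope plugin uninstall some-plugin-package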
diff --git a/modelscope/cli/template/readme.tpl b/modelscope/cli/template/readme.tpl
new file mode 100644
index 00000000..15e10a7a
--- /dev/null
+++ b/modelscope/cli/template/readme.tpl
@@ -0,0 +1,10 @@
+---
+license: Apache License 2.0
+---
+###### This model currently uses the default introduction template and is in the "pre-release" stage, so the page is visible to the owner only.
+###### Please complete the model card as described in the [model contribution guide](https://www.modelscope.cn/docs/%E5%A6%82%E4%BD%95%E6%92%B0%E5%86%99%E5%A5%BD%E7%94%A8%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%8D%A1%E7%89%87). ModelScope will display the model once the card is completed. Thank you for your understanding.
+
+#### Clone with HTTP
+```bash
+ git clone https://www.modelscope.cn/${model_id}.git
+```
diff --git a/modelscope/cli/template/template.tpl b/modelscope/cli/template/template.tpl
new file mode 100644
index 00000000..d24f1b71
--- /dev/null
+++ b/modelscope/cli/template/template.tpl
@@ -0,0 +1,139 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import torch
+
+from modelscope.models.base import TorchModel
+from modelscope.preprocessors.base import Preprocessor
+from modelscope.pipelines.base import Model, Pipeline
+from modelscope.utils.config import Config
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors.builder import PREPROCESSORS
+from modelscope.models.builder import MODELS
+
+
+@MODELS.register_module('${task_name}', module_name='my-custom-model')
+class ${model_name}(TorchModel):
+
+ def __init__(self, model_dir, *args, **kwargs):
+ super().__init__(model_dir, *args, **kwargs)
+ self.model = self.init_model(**kwargs)
+
+ def forward(self, input_tensor, **forward_params):
+ return self.model(input_tensor, **forward_params)
+
+ def init_model(self, **kwargs):
+ """Provide default implementation based on TorchModel and user can reimplement it.
+ include init model and load ckpt from the model_dir, maybe include preprocessor
+ if nothing to do, then return lambdx x: x
+ """
+ return lambda x: x
+
+
+@PREPROCESSORS.register_module('${task_name}', module_name='my-custom-preprocessor')
+class ${preprocessor_name}(Preprocessor):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+        self.transforms = self.init_preprocessor(**kwargs)
+
+ def __call__(self, results):
+        return self.transforms(results)
+
+    def init_preprocessor(self, **kwargs):
+        """ Provide a default implementation based on preprocess_cfg; users can reimplement it.
+            If there is nothing to do, return lambda x: x.
+ """
+ return lambda x: x
+
+
+@PIPELINES.register_module('${task_name}', module_name='my-custom-pipeline')
+class ${pipeline_name}(Pipeline):
+ """ Give simple introduction to this pipeline.
+
+ Examples:
+
+ >>> from modelscope.pipelines import pipeline
+ >>> input = "Hello, ModelScope!"
+ >>> my_pipeline = pipeline('my-task', 'my-model-id')
+ >>> result = my_pipeline(input)
+
+ """
+
+ def __init__(self, model, preprocessor=None, **kwargs):
+ """
+ use `model` and `preprocessor` to create a custom pipeline for prediction
+ Args:
+ model: model id on modelscope hub.
+            preprocessor: the preprocessor instance; if None, a default one is created.
+ """
+ super().__init__(model=model, auto_collate=False)
+ assert isinstance(model, str) or isinstance(model, Model), \
+ 'model must be a single str or Model'
+ if isinstance(model, str):
+ pipe_model = Model.from_pretrained(model)
+ elif isinstance(model, Model):
+ pipe_model = model
+ else:
+ raise NotImplementedError
+ pipe_model.eval()
+
+ if preprocessor is None:
+ preprocessor = ${preprocessor_name}()
+ super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
+
+ def _sanitize_parameters(self, **pipeline_parameters):
+ """
+        This method splits the keyword args passed to '__call__' or '_process_single'
+        into preprocess params, forward params and postprocess params. It is a
+        normal method with a default implementation / output.
+
+ Default Returns:
+ Dict[str, str]: preprocess_params = {}
+ Dict[str, str]: forward_params = {}
+ Dict[str, str]: postprocess_params = pipeline_parameters
+ """
+ return {}, pipeline_parameters, {}
+
+ def _check_input(self, inputs):
+ pass
+
+ def _check_output(self, outputs):
+ pass
+
+ def forward(self, inputs, **forward_params):
+ """ Provide default implementation using self.model and user can reimplement it
+ """
+ return super().forward(inputs, **forward_params)
+
+ def postprocess(self, inputs):
+ """ If current pipeline support model reuse, common postprocess
+ code should be write here.
+
+ Args:
+ inputs: input data
+
+ Return:
+ dict of results: a dict containing outputs of model, each
+ output should have the standard output name.
+ """
+ return inputs
+
+
+# Tip: usr_config_path is the local path where the configuration is saved; after uploading to the ModelScope hub, it becomes the model_id.
+usr_config_path = '${configuration_path}'
+config = Config({
+ 'framework': 'pytorch',
+ 'task': '${task_name}',
+ 'model': {'type': 'my-custom-model'},
+ "pipeline": {"type": "my-custom-pipeline"}
+})
+config.dump('${configuration_path}' + 'configuration.json')
+
+if __name__ == "__main__":
+ from modelscope.models import Model
+ from modelscope.pipelines import pipeline
+ # model = Model.from_pretrained(usr_config_path)
+ input = "Hello, ModelScope!"
+ inference = pipeline('${task_name}', model=usr_config_path)
+ output = inference(input)
+ print(output)
diff --git a/modelscope/exporters/__init__.py b/modelscope/exporters/__init__.py
index c5f3ad50..8b627816 100644
--- a/modelscope/exporters/__init__.py
+++ b/modelscope/exporters/__init__.py
@@ -1,14 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from modelscope.utils.import_utils import is_tf_available, is_torch_available
-from .base import Exporter
-from .builder import build_exporter
+from typing import TYPE_CHECKING
-if is_tf_available():
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .base import Exporter
+ from .builder import build_exporter
from .cv import CartoonTranslationExporter
from .nlp import CsanmtForTranslationExporter
from .tf_model_exporter import TfModelExporter
-if is_torch_available():
from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter
from .torch_model_exporter import TorchModelExporter
from .cv import FaceDetectionSCRFDExporter
+else:
+ _import_structure = {
+ 'base': ['Exporter'],
+ 'builder': ['build_exporter'],
+ 'cv': ['CartoonTranslationExporter', 'FaceDetectionSCRFDExporter'],
+ 'nlp': [
+ 'CsanmtForTranslationExporter',
+ 'SbertForSequenceClassificationExporter',
+ 'SbertForZeroShotClassificationExporter'
+ ],
+ 'tf_model_exporter': ['TfModelExporter'],
+ 'torch_model_exporter': ['TorchModelExporter'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/exporters/cv/__init__.py b/modelscope/exporters/cv/__init__.py
index ab80d049..67a406db 100644
--- a/modelscope/exporters/cv/__init__.py
+++ b/modelscope/exporters/cv/__init__.py
@@ -1,7 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from modelscope.utils.import_utils import is_tf_available, is_torch_available
-if is_tf_available():
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
from .cartoon_translation_exporter import CartoonTranslationExporter
-if is_torch_available():
+ from .object_detection_damoyolo_exporter import ObjectDetectionDamoyoloExporter
from .face_detection_scrfd_exporter import FaceDetectionSCRFDExporter
+else:
+ _import_structure = {
+ 'cartoon_translation_exporter': ['CartoonTranslationExporter'],
+ 'object_detection_damoyolo_exporter':
+ ['ObjectDetectionDamoyoloExporter'],
+ 'face_detection_scrfd_exporter': ['FaceDetectionSCRFDExporter'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/exporters/cv/object_detection_damoyolo_exporter.py b/modelscope/exporters/cv/object_detection_damoyolo_exporter.py
new file mode 100644
index 00000000..673811ad
--- /dev/null
+++ b/modelscope/exporters/cv/object_detection_damoyolo_exporter.py
@@ -0,0 +1,42 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from functools import partial
+from typing import Mapping
+
+import numpy as np
+import onnx
+import torch
+
+from modelscope.exporters.builder import EXPORTERS
+from modelscope.exporters.torch_model_exporter import TorchModelExporter
+from modelscope.metainfo import Models
+from modelscope.utils.constant import ModelFile, Tasks
+
+
+@EXPORTERS.register_module(
+ Tasks.image_object_detection, module_name=Models.tinynas_damoyolo)
+class ObjectDetectionDamoyoloExporter(TorchModelExporter):
+
+ def export_onnx(self,
+ output_dir: str,
+ opset=11,
+ input_shape=(1, 3, 640, 640)):
+ onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
+ dummy_input = torch.randn(*input_shape)
+ self.model.head.nms = False
+ self.model.onnx_export = True
+ self.model.eval()
+ _ = self.model(dummy_input)
+ torch.onnx._export(
+ self.model,
+ dummy_input,
+ onnx_file,
+ input_names=[
+ 'images',
+ ],
+ output_names=[
+ 'pred',
+ ],
+ opset_version=opset)
+
+        return {'model': onnx_file}
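+
+
+# Illustrative usage (a sketch; the model id is a placeholder and this assumes the
+# exporter is resolved through the base `Exporter.from_model` helper):
+#   from modelscope.exporters import Exporter
+#   from modelscope.models import Model
+#   model = Model.from_pretrained('<damoyolo-model-id>')
+#   Exporter.from_model(model).export_onnx(output_dir='./export')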
diff --git a/modelscope/exporters/nlp/__init__.py b/modelscope/exporters/nlp/__init__.py
index 731e4bb7..26df5775 100644
--- a/modelscope/exporters/nlp/__init__.py
+++ b/modelscope/exporters/nlp/__init__.py
@@ -1,11 +1,31 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from modelscope.utils.import_utils import is_tf_available, is_torch_available
+from typing import TYPE_CHECKING
-if is_tf_available():
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
from .csanmt_for_translation_exporter import CsanmtForTranslationExporter
-if is_torch_available():
+ from .model_for_token_classification_exporter import ModelForSequenceClassificationExporter
from .sbert_for_sequence_classification_exporter import \
SbertForSequenceClassificationExporter
from .sbert_for_zero_shot_classification_exporter import \
SbertForZeroShotClassificationExporter
+else:
+ _import_structure = {
+ 'csanmt_for_translation_exporter': ['CsanmtForTranslationExporter'],
+ 'model_for_token_classification_exporter':
+ ['ModelForSequenceClassificationExporter'],
+        'sbert_for_sequence_classification_exporter':
+        ['SbertForSequenceClassificationExporter'],
+        'sbert_for_zero_shot_classification_exporter':
+        ['SbertForZeroShotClassificationExporter'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/exporters/nlp/model_for_token_classification_exporter.py b/modelscope/exporters/nlp/model_for_token_classification_exporter.py
new file mode 100644
index 00000000..676615c0
--- /dev/null
+++ b/modelscope/exporters/nlp/model_for_token_classification_exporter.py
@@ -0,0 +1,112 @@
+from collections import OrderedDict
+from typing import Any, Dict, Mapping
+
+import torch
+from torch import nn
+
+from modelscope.exporters.builder import EXPORTERS
+from modelscope.exporters.torch_model_exporter import TorchModelExporter
+from modelscope.metainfo import Models
+from modelscope.outputs import ModelOutputBase
+from modelscope.preprocessors import Preprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from modelscope.utils.regress_test_utils import (compare_arguments_nested,
+                                                 numpify_tensor_nested)
+
+logger = get_logger()
+
+
+@EXPORTERS.register_module(Tasks.transformer_crf, module_name=Models.tcrf)
+@EXPORTERS.register_module(Tasks.token_classification, module_name=Models.tcrf)
+@EXPORTERS.register_module(
+ Tasks.named_entity_recognition, module_name=Models.tcrf)
+@EXPORTERS.register_module(Tasks.part_of_speech, module_name=Models.tcrf)
+@EXPORTERS.register_module(Tasks.word_segmentation, module_name=Models.tcrf)
+class ModelForSequenceClassificationExporter(TorchModelExporter):
+
+ def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
+ """Generate dummy inputs for model exportation to onnx or other formats by tracing.
+
+ Args:
+ shape: A tuple of input shape which should have at most two dimensions.
+ shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor.
+ shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor.
+ pair(bool, `optional`): Whether to generate sentence pairs or single sentences.
+
+ Returns:
+ Dummy inputs.
+ """
+
+ assert hasattr(
+ self.model, 'model_dir'
+ ), 'model_dir attribute is required to build the preprocessor'
+
+ preprocessor = Preprocessor.from_pretrained(
+ self.model.model_dir, return_text=False)
+ return preprocessor('2023')
+
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ dynamic_axis = {0: 'batch', 1: 'sequence'}
+ return OrderedDict([
+ ('input_ids', dynamic_axis),
+ ('attention_mask', dynamic_axis),
+ ('offset_mapping', dynamic_axis),
+ ('label_mask', dynamic_axis),
+ ])
+
+ @property
+ def outputs(self) -> Mapping[str, Mapping[int, str]]:
+ dynamic_axis = {0: 'batch', 1: 'sequence'}
+ return OrderedDict([
+ ('predictions', dynamic_axis),
+ ])
+
+ def _validate_onnx_model(self,
+ dummy_inputs,
+ model,
+ output,
+ onnx_outputs,
+ rtol: float = None,
+ atol: float = None):
+ try:
+ import onnx
+ import onnxruntime as ort
+ except ImportError:
+ logger.warning(
+ 'Cannot validate the exported onnx file, because '
+ 'the installation of onnx or onnxruntime cannot be found')
+ return
+ onnx_model = onnx.load(output)
+ onnx.checker.check_model(onnx_model)
+ ort_session = ort.InferenceSession(output)
+ with torch.no_grad():
+ model.eval()
+ outputs_origin = model.forward(
+ *self._decide_input_format(model, dummy_inputs))
+ if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
+ outputs_origin = list(
+ numpify_tensor_nested(outputs_origin).values())
+ elif isinstance(outputs_origin, (tuple, list)):
+ outputs_origin = list(numpify_tensor_nested(outputs_origin))
+
+            # keep `predictions`, drop other outputs
+            outputs_origin = [outputs_origin[0]]
+
+ np_dummy_inputs = numpify_tensor_nested(dummy_inputs)
+ np_dummy_inputs['label_mask'] = np_dummy_inputs['label_mask'].astype(
+ bool)
+ outputs = ort_session.run(onnx_outputs, np_dummy_inputs)
+ outputs = numpify_tensor_nested(outputs)
+ if isinstance(outputs, dict):
+ outputs = list(outputs.values())
+ elif isinstance(outputs, tuple):
+ outputs = list(outputs)
+
+ tols = {}
+ if rtol is not None:
+ tols['rtol'] = rtol
+ if atol is not None:
+ tols['atol'] = atol
+ if not compare_arguments_nested('Onnx model output match failed',
+ outputs, outputs_origin, **tols):
+ raise RuntimeError(
+ 'export onnx failed because of validation error.')
diff --git a/modelscope/exporters/torch_model_exporter.py b/modelscope/exporters/torch_model_exporter.py
index d203a482..473b9705 100644
--- a/modelscope/exporters/torch_model_exporter.py
+++ b/modelscope/exporters/torch_model_exporter.py
@@ -213,45 +213,58 @@ class TorchModelExporter(Exporter):
)
if validation:
- try:
- import onnx
- import onnxruntime as ort
- except ImportError:
- logger.warning(
- 'Cannot validate the exported onnx file, because '
- 'the installation of onnx or onnxruntime cannot be found')
- return
- onnx_model = onnx.load(output)
- onnx.checker.check_model(onnx_model)
- ort_session = ort.InferenceSession(output)
- with torch.no_grad():
- model.eval()
- outputs_origin = model.forward(
- *self._decide_input_format(model, dummy_inputs))
- if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
- outputs_origin = list(
- numpify_tensor_nested(outputs_origin).values())
- elif isinstance(outputs_origin, (tuple, list)):
- outputs_origin = list(numpify_tensor_nested(outputs_origin))
- outputs = ort_session.run(
- onnx_outputs,
- numpify_tensor_nested(dummy_inputs),
- )
- outputs = numpify_tensor_nested(outputs)
- if isinstance(outputs, dict):
- outputs = list(outputs.values())
- elif isinstance(outputs, tuple):
- outputs = list(outputs)
+ self._validate_onnx_model(dummy_inputs, model, output,
+ onnx_outputs, rtol, atol)
- tols = {}
- if rtol is not None:
- tols['rtol'] = rtol
- if atol is not None:
- tols['atol'] = atol
- if not compare_arguments_nested('Onnx model output match failed',
- outputs, outputs_origin, **tols):
- raise RuntimeError(
- 'export onnx failed because of validation error.')
+ def _validate_onnx_model(self,
+ dummy_inputs,
+ model,
+ output,
+ onnx_outputs,
+ rtol: float = None,
+ atol: float = None):
+ try:
+ import onnx
+ import onnxruntime as ort
+ except ImportError:
+ logger.warning(
+ 'Cannot validate the exported onnx file, because '
+ 'the installation of onnx or onnxruntime cannot be found')
+ return
+ onnx_model = onnx.load(output)
+ onnx.checker.check_model(onnx_model)
+ ort_session = ort.InferenceSession(output)
+ with torch.no_grad():
+ model.eval()
+ outputs_origin = model.forward(
+ *self._decide_input_format(model, dummy_inputs))
+ if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
+ outputs_origin = list(
+ numpify_tensor_nested(outputs_origin).values())
+ elif isinstance(outputs_origin, (tuple, list)):
+ outputs_origin = list(numpify_tensor_nested(outputs_origin))
+
+ outputs = ort_session.run(
+ onnx_outputs,
+ numpify_tensor_nested(dummy_inputs),
+ )
+ outputs = numpify_tensor_nested(outputs)
+ if isinstance(outputs, dict):
+ outputs = list(outputs.values())
+ elif isinstance(outputs, tuple):
+ outputs = list(outputs)
+
+ tols = {}
+ if rtol is not None:
+ tols['rtol'] = rtol
+ if atol is not None:
+ tols['atol'] = atol
+ if not compare_arguments_nested('Onnx model output match failed',
+ outputs, outputs_origin, **tols):
+ raise RuntimeError(
+ 'export onnx failed because of validation error.')
def _torch_export_torch_script(self,
model: nn.Module,
@@ -307,28 +320,33 @@ class TorchModelExporter(Exporter):
torch.jit.save(traced_model, output)
if validation:
- ts_model = torch.jit.load(output)
- with torch.no_grad():
- model.eval()
- ts_model.eval()
- outputs = ts_model.forward(*dummy_inputs)
- outputs = numpify_tensor_nested(outputs)
- outputs_origin = model.forward(*dummy_inputs)
- outputs_origin = numpify_tensor_nested(outputs_origin)
- if isinstance(outputs, dict):
- outputs = list(outputs.values())
- if isinstance(outputs_origin, dict):
- outputs_origin = list(outputs_origin.values())
- tols = {}
- if rtol is not None:
- tols['rtol'] = rtol
- if atol is not None:
- tols['atol'] = atol
- if not compare_arguments_nested(
- 'Torch script model output match failed', outputs,
- outputs_origin, **tols):
- raise RuntimeError(
- 'export torch script failed because of validation error.')
+ self._validate_torch_script_model(dummy_inputs, model, output,
+ rtol, atol)
+
+ def _validate_torch_script_model(self, dummy_inputs, model, output, rtol,
+ atol):
+ ts_model = torch.jit.load(output)
+ with torch.no_grad():
+ model.eval()
+ ts_model.eval()
+ outputs = ts_model.forward(*dummy_inputs)
+ outputs = numpify_tensor_nested(outputs)
+ outputs_origin = model.forward(*dummy_inputs)
+ outputs_origin = numpify_tensor_nested(outputs_origin)
+ if isinstance(outputs, dict):
+ outputs = list(outputs.values())
+ if isinstance(outputs_origin, dict):
+ outputs_origin = list(outputs_origin.values())
+ tols = {}
+ if rtol is not None:
+ tols['rtol'] = rtol
+ if atol is not None:
+ tols['atol'] = atol
+ if not compare_arguments_nested(
+ 'Torch script model output match failed', outputs,
+ outputs_origin, **tols):
+ raise RuntimeError(
+ 'export torch script failed because of validation error.')
@contextmanager
diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py
index 23391073..7a731b79 100644
--- a/modelscope/hub/file_download.py
+++ b/modelscope/hub/file_download.py
@@ -121,7 +121,7 @@ def model_file_download(
if model_file['Path'] == file_path:
if cache.exists(model_file):
- logger.info(
+ logger.debug(
f'File {model_file["Name"]} already in cache, skip downloading!'
)
return cache.get_file_by_info(model_file)
@@ -209,7 +209,7 @@ def http_get_file(
tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
get_headers = {} if headers is None else copy.deepcopy(headers)
with temp_file_manager() as temp_file:
- logger.info('downloading %s to %s', url, temp_file.name)
+ logger.debug('downloading %s to %s', url, temp_file.name)
# retry sleep 0.5s, 1s, 2s, 4s
retry = Retry(
total=API_FILE_DOWNLOAD_RETRY_TIMES,
@@ -248,7 +248,7 @@ def http_get_file(
retry = retry.increment('GET', url, error=e)
retry.sleep()
- logger.info('storing %s in cache at %s', url, local_dir)
+ logger.debug('storing %s in cache at %s', url, local_dir)
downloaded_length = os.path.getsize(temp_file.name)
if total != downloaded_length:
os.remove(temp_file.name)
diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py
index 67492649..25a97975 100644
--- a/modelscope/hub/snapshot_download.py
+++ b/modelscope/hub/snapshot_download.py
@@ -122,7 +122,7 @@ def snapshot_download(model_id: str,
# check model_file is exist in cache, if existed, skip download, otherwise download
if cache.exists(model_file):
file_name = os.path.basename(model_file['Name'])
- logger.info(
+ logger.debug(
f'File {file_name} already in cache, skip downloading!'
)
continue
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 197999bb..360f3241 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -46,6 +46,7 @@ class Models(object):
image_paintbyexample = 'Stablediffusion-Paintbyexample'
video_summarization = 'pgl-video-summarization'
video_panoptic_segmentation = 'swinb-video-panoptic-segmentation'
+ video_instance_segmentation = 'swinb-video-instance-segmentation'
language_guided_video_summarization = 'clip-it-language-guided-video-summarization'
swinL_semantic_segmentation = 'swinL-semantic-segmentation'
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
@@ -78,16 +79,19 @@ class Models(object):
image_body_reshaping = 'image-body-reshaping'
image_skychange = 'image-skychange'
video_human_matting = 'video-human-matting'
+ human_reconstruction = 'human-reconstruction'
video_frame_interpolation = 'video-frame-interpolation'
video_object_segmentation = 'video-object-segmentation'
video_deinterlace = 'video-deinterlace'
quadtree_attention_image_matching = 'quadtree-attention-image-matching'
vision_middleware = 'vision-middleware'
+ vidt = 'vidt'
video_stabilization = 'video-stabilization'
real_basicvsr = 'real-basicvsr'
rcp_sceneflow_estimation = 'rcp-sceneflow-estimation'
image_casmvs_depth_estimation = 'image-casmvs-depth-estimation'
vop_retrieval_model = 'vop-retrieval-model'
+ vop_retrieval_model_se = 'vop-retrieval-model-se'
ddcolor = 'ddcolor'
image_probing_model = 'image-probing-model'
defrcn = 'defrcn'
@@ -100,7 +104,9 @@ class Models(object):
ddpm = 'ddpm'
ocr_recognition = 'OCRRecognition'
ocr_detection = 'OCRDetection'
+ lineless_table_recognition = 'LoreModel'
image_quality_assessment_mos = 'image-quality-assessment-mos'
+ image_quality_assessment_man = 'image-quality-assessment-man'
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
m2fp = 'm2fp'
nerf_recon_acc = 'nerf-recon-acc'
@@ -108,6 +114,7 @@ class Models(object):
vision_efficient_tuning = 'vision-efficient-tuning'
bad_image_detecting = 'bad-image-detecting'
controllable_image_generation = 'controllable-image-generation'
+ longshortnet = 'longshortnet'
# EasyCV models
yolox = 'YOLOX'
@@ -157,10 +164,12 @@ class Models(object):
transformers = 'transformers'
plug_mental = 'plug-mental'
doc2bot = 'doc2bot'
+ peer = 'peer'
# audio models
sambert_hifigan = 'sambert-hifigan'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
+ speech_dfsmn_ans = 'speech_dfsmn_ans'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
speech_mossformer_separation_temporal_8k = 'speech_mossformer_separation_temporal_8k'
@@ -177,14 +186,17 @@ class Models(object):
ofa = 'ofa'
clip = 'clip-multi-modal-embedding'
gemm = 'gemm-generative-multi-modal'
+ rleg = 'rleg-generative-multi-modal'
mplug = 'mplug'
diffusion = 'diffusion-text-to-image-synthesis'
multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis'
+ video_synthesis = 'latent-text-to-video-synthesis'
team = 'team-multi-modal-similarity'
video_clip = 'video-clip-multi-modal-embedding'
mgeo = 'mgeo'
vldoc = 'vldoc'
hitea = 'hitea'
+ soonet = 'soonet'
# science models
unifold = 'unifold'
@@ -242,6 +254,7 @@ class Pipelines(object):
person_image_cartoon = 'unet-person-image-cartoon'
ocr_detection = 'resnet18-ocr-detection'
table_recognition = 'dla34-table-recognition'
+ lineless_table_recognition = 'lore-lineless-table-recognition'
license_plate_detection = 'resnet18-license-plate-detection'
action_recognition = 'TAdaConv_action-recognition'
animal_recognition = 'resnet101-animal-recognition'
@@ -322,6 +335,7 @@ class Pipelines(object):
crowd_counting = 'hrnet-crowd-counting'
action_detection = 'ResNetC3D-action-detection'
video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
+ video_single_object_tracking_procontext = 'procontext-vitb-video-single-object-tracking'
video_multi_object_tracking = 'video-multi-object-tracking'
image_panoptic_segmentation = 'image-panoptic-segmentation'
image_panoptic_segmentation_easycv = 'image-panoptic-segmentation-easycv'
@@ -350,7 +364,9 @@ class Pipelines(object):
referring_video_object_segmentation = 'referring-video-object-segmentation'
image_skychange = 'image-skychange'
video_human_matting = 'video-human-matting'
+ human_reconstruction = 'human-reconstruction'
vision_middleware_multi_task = 'vision-middleware-multi-task'
+ vidt = 'vidt'
video_frame_interpolation = 'video-frame-interpolation'
video_object_segmentation = 'video-object-segmentation'
video_deinterlace = 'video-deinterlace'
@@ -360,7 +376,9 @@ class Pipelines(object):
pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'
image_multi_view_depth_estimation = 'image-multi-view-depth-estimation'
video_panoptic_segmentation = 'video-panoptic-segmentation'
+ video_instance_segmentation = 'video-instance-segmentation'
vop_retrieval = 'vop-video-text-retrieval'
+ vop_retrieval_se = 'vop-video-text-retrieval-se'
ddcolor_image_colorization = 'ddcolor-image-colorization'
image_structured_model_probing = 'image-structured-model-probing'
image_fewshot_detection = 'image-fewshot-detection'
@@ -377,8 +395,10 @@ class Pipelines(object):
controllable_image_generation = 'controllable-image-generation'
image_quality_assessment_mos = 'image-quality-assessment-mos'
+ image_quality_assessment_man = 'image-quality-assessment-man'
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
vision_efficient_tuning = 'vision-efficient-tuning'
+ image_bts_depth_estimation = 'image-bts-depth-estimation'
# nlp tasks
automatic_post_editing = 'automatic-post-editing'
@@ -441,6 +461,7 @@ class Pipelines(object):
sambert_hifigan_tts = 'sambert-hifigan-tts'
speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
+ speech_dfsmn_ans_psm_48k_causal = 'speech_dfsmn_ans_psm_48k_causal'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
speech_separation = 'speech-separation'
kws_kwsbp = 'kws-kwsbp'
@@ -453,6 +474,7 @@ class Pipelines(object):
vad_inference = 'vad-inference'
speaker_verification = 'speaker-verification'
lm_inference = 'language-score-prediction'
+ speech_timestamp_inference = 'speech-timestamp-inference'
# multi-modal tasks
image_captioning = 'image-captioning'
@@ -472,10 +494,13 @@ class Pipelines(object):
video_captioning = 'video-captioning'
video_question_answering = 'video-question-answering'
diffusers_stable_diffusion = 'diffusers-stable-diffusion'
+ disco_guided_diffusion = 'disco_guided_diffusion'
document_vl_embedding = 'document-vl-embedding'
chinese_stable_diffusion = 'chinese-stable-diffusion'
+ text_to_video_synthesis = 'latent-text-to-video-synthesis' # latent-text-to-video-synthesis
gridvlp_multi_modal_classification = 'gridvlp-multi-modal-classification'
gridvlp_multi_modal_embedding = 'gridvlp-multi-modal-embedding'
+ soonet_video_temporal_grounding = 'soonet-video-temporal-grounding'
# science tasks
protein_structure = 'unifold-protein-structure'
@@ -574,6 +599,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.table_recognition:
(Pipelines.table_recognition,
'damo/cv_dla34_table-structure-recognition_cycle-centernet'),
+ Tasks.lineless_table_recognition:
+ (Pipelines.lineless_table_recognition,
+ 'damo/cv_resnet-transformer_table-structure-recognition_lore'),
Tasks.document_vl_embedding:
(Pipelines.document_vl_embedding,
'damo/multi-modal_convnext-roberta-base_vldoc-embedding'),
@@ -611,6 +639,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.text_to_image_synthesis:
(Pipelines.text_to_image_synthesis,
'damo/cv_diffusion_text-to-image-synthesis_tiny'),
+ Tasks.text_to_video_synthesis: (Pipelines.text_to_video_synthesis,
+ 'damo/text-to-video-synthesis'),
Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints,
'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints,
@@ -708,9 +738,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_vitb_video-single-object-tracking_ostrack'),
Tasks.image_reid_person: (Pipelines.image_reid_person,
'damo/cv_passvitb_image-reid-person_market'),
- Tasks.text_driven_segmentation:
- (Pipelines.text_driven_segmentation,
- 'damo/cv_vitl16_segmentation_text-driven-seg'),
+ Tasks.text_driven_segmentation: (
+ Pipelines.text_driven_segmentation,
+ 'damo/cv_vitl16_segmentation_text-driven-seg'),
Tasks.movie_scene_segmentation: (
Pipelines.movie_scene_segmentation,
'damo/cv_resnet50-bert_video-scene-segmentation_movienet'),
@@ -727,6 +757,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_video-inpainting'),
Tasks.video_human_matting: (Pipelines.video_human_matting,
'damo/cv_effnetv2_video-human-matting'),
+ Tasks.human_reconstruction: (Pipelines.human_reconstruction,
+ 'damo/cv_hrnet_image-human-reconstruction'),
Tasks.video_frame_interpolation: (
Pipelines.video_frame_interpolation,
'damo/cv_raft_video-frame-interpolation'),
@@ -805,7 +837,11 @@ class CVTrainers(object):
image_classification_team = 'image-classification-team'
image_classification = 'image-classification'
image_fewshot_detection = 'image-fewshot-detection'
+ ocr_recognition = 'ocr-recognition'
+ ocr_detection_db = 'ocr-detection-db'
nerf_recon_acc = 'nerf-recon-acc'
+ action_detection = 'action-detection'
+ vision_efficient_tuning = 'vision-efficient-tuning'
class NLPTrainers(object):
@@ -826,6 +862,7 @@ class NLPTrainers(object):
document_grounded_dialog_generate_trainer = 'document-grounded-dialog-generate-trainer'
document_grounded_dialog_rerank_trainer = 'document-grounded-dialog-rerank-trainer'
document_grounded_dialog_retrieval_trainer = 'document-grounded-dialog-retrieval-trainer'
+ siamese_uie_trainer = 'siamese-uie-trainer'
class MultiModalTrainers(object):
@@ -904,6 +941,7 @@ class Preprocessors(object):
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
image_driving_perception_preprocessor = 'image-driving-perception-preprocessor'
image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor'
+ image_quality_assessment_man_preprocessor = 'image-quality_assessment-man-preprocessor'
image_quality_assessment_mos_preprocessor = 'image-quality_assessment-mos-preprocessor'
video_summarization_preprocessor = 'video-summarization-preprocessor'
movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'
@@ -916,6 +954,7 @@ class Preprocessors(object):
bad_image_detecting_preprocessor = 'bad-image-detecting-preprocessor'
nerf_recon_acc_preprocessor = 'nerf-recon-acc-preprocessor'
controllable_image_generation_preprocessor = 'controllable-image-generation-preprocessor'
+ image_classification_preprocessor = 'image-classification-preprocessor'
# nlp preprocessor
sen_sim_tokenizer = 'sen-sim-tokenizer'
@@ -1035,6 +1074,9 @@ class Metrics(object):
image_quality_assessment_degradation_metric = 'image-quality-assessment-degradation-metric'
# metric for text-ranking task
text_ranking_metric = 'text-ranking-metric'
+ # metric for image-colorization task
+ image_colorization_metric = 'image-colorization-metric'
+ ocr_recognition_metric = 'ocr-recognition-metric'
class Optimizers(object):
@@ -1087,6 +1129,7 @@ class Hooks(object):
EarlyStopHook = 'EarlyStopHook'
DeepspeedHook = 'DeepspeedHook'
MegatronHook = 'MegatronHook'
+ DDPHook = 'DDPHook'
class LR_Schedulers(object):
@@ -1098,7 +1141,7 @@ class LR_Schedulers(object):
ExponentialWarmup = 'ExponentialWarmup'
-class Datasets(object):
+class CustomDatasets(object):
""" Names for different datasets.
"""
ClsDataset = 'ClsDataset'
@@ -1110,3 +1153,6 @@ class Datasets(object):
DetImagesMixDataset = 'DetImagesMixDataset'
PanopticDataset = 'PanopticDataset'
PairedDataset = 'PairedDataset'
+ SiddDataset = 'SiddDataset'
+ GoproDataset = 'GoproDataset'
+ RedsDataset = 'RedsDataset'
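Most of the metainfo additions above are registry constants that only take effect once a pipeline resolves them. A hedged sketch of how the new `lineless_table_recognition` default would typically be consumed (assumes ModelScope hub access; the input path is a hypothetical local image):

```python
# Sketch: building a pipeline for one of the newly registered tasks.
# With no explicit model argument, DEFAULT_MODEL_FOR_PIPELINE resolves the task to
# 'damo/cv_resnet-transformer_table-structure-recognition_lore' as added in this diff.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

table_recognizer = pipeline(Tasks.lineless_table_recognition)

# 'table.jpg' is a placeholder path, not a file shipped with the repo.
result = table_recognizer('table.jpg')
print(result)
```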
diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py
index e463ea63..17767001 100644
--- a/modelscope/metrics/__init__.py
+++ b/modelscope/metrics/__init__.py
@@ -29,6 +29,8 @@ if TYPE_CHECKING:
from .image_quality_assessment_mos_metric import ImageQualityAssessmentMosMetric
from .text_ranking_metric import TextRankingMetric
from .loss_metric import LossMetric
+ from .image_colorization_metric import ImageColorizationMetric
+ from .ocr_recognition_metric import OCRRecognitionMetric
else:
_import_structure = {
'audio_noise_metric': ['AudioNoiseMetric'],
@@ -58,7 +60,9 @@ else:
'image_quality_assessment_mos_metric':
['ImageQualityAssessmentMosMetric'],
'text_ranking_metric': ['TextRankingMetric'],
- 'loss_metric': ['LossMetric']
+ 'loss_metric': ['LossMetric'],
+ 'image_colorization_metric': ['ImageColorizationMetric'],
+ 'ocr_recognition_metric': ['OCRRecognitionMetric']
}
import sys
diff --git a/modelscope/metrics/action_detection_evaluator.py b/modelscope/metrics/action_detection_evaluator.py
new file mode 100644
index 00000000..24dd51ae
--- /dev/null
+++ b/modelscope/metrics/action_detection_evaluator.py
@@ -0,0 +1,198 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import copy
+import logging
+import os.path as osp
+from collections import OrderedDict
+
+import numpy as np
+import pandas as pd
+from detectron2.evaluation import DatasetEvaluator
+from detectron2.evaluation.pascal_voc_evaluation import voc_ap
+from detectron2.structures.boxes import Boxes, pairwise_iou
+from detectron2.utils import comm
+from scipy import interpolate
+
+
+class DetEvaluator(DatasetEvaluator):
+
+ def __init__(self, class_names, output_dir, distributed=False):
+ self.num_classes = len(class_names)
+ self.class_names = class_names
+ self.output_dir = output_dir
+ self.distributed = distributed
+ self.predictions = []
+ self.gts = []
+
+ def reset(self):
+ self.predictions.clear()
+ self.gts.clear()
+
+ def process(self, input, output):
+ """
+
+ :param input: dataloader
+ :param output: model(input)
+ :return:
+ """
+ gt_instances = [x['instances'].to('cpu') for x in input]
+ pred_instances = [x['instances'].to('cpu') for x in output]
+ self.gts.extend(gt_instances)
+ self.predictions.extend(pred_instances)
+
+ def get_instance_by_class(self, instances, c):
+ instances = copy.deepcopy(instances)
+ name = 'gt_classes' if instances.has('gt_classes') else 'pred_classes'
+ idxs = np.where(instances.get(name).numpy() == c)[0].tolist()
+ data = {}
+ for k, v in instances.get_fields().items():
+ data[k] = [v[i] for i in idxs]
+ return data
+
+ def evaluate(self):
+ if self.distributed:
+ comm.synchronize()
+ self.predictions = sum(comm.gather(self.predictions, dst=0), [])
+ self.gts = sum(comm.gather(self.gts, dst=0), [])
+ if not comm.is_main_process():
+ return
+ logger = logging.getLogger('detectron2.human.' + __name__)
+ logger.info(', '.join([f'{a}' for a in self.class_names]))
+ maps = []
+ precisions = []
+ recalls = []
+ for iou_th in [0.3, 0.5, 0.7]:
+ aps, prs, ths = self.calc_map(iou_th)
+ map = np.nanmean([x for x in aps if x > 0.01])
+ maps.append(map)
+ logger.info(f'iou_th:{iou_th},' + 'Aps:'
+ + ','.join([f'{ap:.2f}'
+ for ap in aps]) + f', {map:.3f}')
+ precision, recall = zip(*prs)
+ logger.info('precision:'
+ + ', '.join([f'{p:.2f}' for p in precision]))
+ logger.info('recall: ' + ', '.join([f'{p:.2f}' for p in recall]))
+ logger.info('score th: ' + ', '.join([f'{p:.2f}' for p in ths]))
+ logger.info(f'mean-precision:{np.nanmean(precision):.3f}')
+ logger.info(f'mean-recall:{np.nanmean(recall):.3f}')
+ precisions.append(np.nanmean(precision))
+ recalls.append(np.nanmean(recall))
+
+ res = OrderedDict({
+ 'det': {
+ 'mAP': np.nanmean(maps),
+ 'precision': np.nanmean(precisions),
+ 'recall': np.nanmean(recalls)
+ }
+ })
+ return res
+
+ def calc_map(self, iou_th):
+ aps = []
+ prs = []
+ ths = []
+ # iterate over each class
+ interpolate_precs = []
+ for c in range(self.num_classes):
+ ap, recalls, precisions, scores = self.det_eval(iou_th, c)
+ if iou_th == 0.3:
+ p1 = interpolate_precision(recalls, precisions)
+ interpolate_precs.append(p1)
+ recalls = np.concatenate(([0.0], recalls, [1.0]))
+ precisions = np.concatenate(([0.0], precisions, [0.0]))
+ scores = np.concatenate(([1.0], scores, [0.0]))
+ t = precisions + recalls
+ t[t == 0] = 1e-5
+ f_score = 2 * precisions * recalls / t
+ f_score[np.isnan(f_score)] = 0
+ idx = np.argmax(f_score)
+ # print(iou_th,c,np.argmax(f_score),np.argmax(t))
+ precision_recall = (precisions[idx], recalls[idx])
+ prs.append(precision_recall)
+ aps.append(ap)
+ ths.append(scores[idx])
+ if iou_th == 0.3:
+ interpolate_precs = np.stack(interpolate_precs, axis=1)
+ df = pd.DataFrame(data=interpolate_precs)
+ df.to_csv(
+ osp.join(self.output_dir, 'pr_data.csv'),
+ index=False,
+ columns=None)
+ return aps, prs, ths
+
+ def det_eval(self, iou_th, class_id):
+ c = class_id
+ class_res_gt = {}
+ npos = 0
+ # iterate over each sample
+ for i, (gt, pred) in enumerate(zip(self.gts, self.predictions)):
+ gt_classes = gt.gt_classes.tolist()
+ pred_classes = pred.pred_classes.tolist()
+ if c not in gt_classes + pred_classes:
+ continue
+ pred_data = self.get_instance_by_class(pred, c)
+ gt_data = self.get_instance_by_class(gt, c)
+ res = {}
+ if c in gt_classes:
+ res.update({
+ 'gt_bbox': Boxes.cat(gt_data['gt_boxes']),
+ 'det': [False] * len(gt_data['gt_classes'])
+ })
+ if c in pred_classes:
+ res.update({'pred_bbox': Boxes.cat(pred_data['pred_boxes'])})
+ res.update(
+ {'pred_score': [s.item() for s in pred_data['scores']]})
+ class_res_gt[i] = res
+ npos += len(gt_data['gt_classes'])
+
+ all_preds = []
+ for img_id, res in class_res_gt.items():
+ if 'pred_bbox' in res:
+ for i in range(len(res['pred_bbox'])):
+ bbox = res['pred_bbox'][i]
+ score = res['pred_score'][i]
+ all_preds.append([img_id, bbox, score])
+ sorted_preds = list(
+ sorted(all_preds, key=lambda x: x[2], reverse=True))
+ scores = [s[-1] for s in sorted_preds]
+ nd = len(sorted_preds)
+ tp = np.zeros(nd)
+ fp = np.zeros(nd)
+ for d in range(nd):
+ img_id, pred_bbox, score = sorted_preds[d]
+ R = class_res_gt[sorted_preds[d][0]]
+ ovmax = -np.inf
+ if 'gt_bbox' in R:
+ gt_bbox = R['gt_bbox']
+ IoUs = pairwise_iou(pred_bbox, gt_bbox).numpy()
+ ovmax = IoUs[0].max()
+ jmax = np.argmax(IoUs[0]) # index of the matched gt box in this image
+ if ovmax > iou_th:
+ if not R['det'][jmax]: # this gt has not been matched yet
+ tp[d] = 1.0
+ R['det'][jmax] = True
+ else: # duplicate detection of the same gt
+ fp[d] = 1.0
+ else:
+ fp[d] = 1.0
+ fp = np.cumsum(fp)
+ tp = np.cumsum(tp)
+ rec = tp / float(npos)
+ # avoid divide by zero in case the first detection matches a difficult
+ # ground truth
+ prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+ ap = voc_ap(rec, prec, False)
+ return ap, rec, prec, scores
+
+
+def interpolate_precision(rec, prec):
+ rec = np.concatenate(([0.0], rec, [1.0, 1.1]))
+ prec = np.concatenate(([1.0], prec, [0.0]))
+ for i in range(prec.size - 1, 0, -1):
+ prec[i - 1] = np.maximum(prec[i - 1], prec[i])
+ i = np.where(rec[1:] != rec[:-1])[0] # sample only at points where recall changes
+ rec, prec = rec[i], prec[i]
+ f = interpolate.interp1d(rec, prec)
+ r1 = np.linspace(0, 1, 101)
+ p1 = f(r1)
+ return p1
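The core of `det_eval` is the classic VOC accumulation: sort detections by score, mark each one as TP or FP against not-yet-matched ground truths, take cumulative sums, and convert them into a precision/recall curve for `voc_ap`. A small numpy-only sketch of that accumulation on toy data, independent of detectron2:

```python
# Toy VOC-style AP accumulation: 3 score-sorted detections, 2 ground truths, 1 unmatched.
import numpy as np

npos = 2                       # number of ground-truth boxes for this class
tp = np.array([1., 1., 0.])    # per-detection true-positive flags
fp = np.array([0., 0., 1.])    # duplicate or unmatched detections count as FP

tp_cum, fp_cum = np.cumsum(tp), np.cumsum(fp)
rec = tp_cum / float(npos)
prec = tp_cum / np.maximum(tp_cum + fp_cum, np.finfo(np.float64).eps)

# Continuous AP: take the monotone precision envelope, then integrate it over
# recall, which is what voc_ap(rec, prec, use_07_metric=False) computes.
mrec = np.concatenate(([0.], rec, [1.]))
mpre = np.concatenate(([0.], prec, [0.]))
for i in range(mpre.size - 1, 0, -1):
    mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
idx = np.where(mrec[1:] != mrec[:-1])[0]
ap = np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1])

print(rec, prec, ap)           # rec=[0.5, 1., 1.], prec=[1., 1., 0.667], ap=1.0
```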
diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py
index 0357fa25..882569dd 100644
--- a/modelscope/metrics/builder.py
+++ b/modelscope/metrics/builder.py
@@ -40,6 +40,7 @@ class MetricKeys(object):
RMSE = 'rmse'
MRR = 'mrr'
NDCG = 'ndcg'
+ AR = 'AR'
task_default_metrics = {
@@ -71,6 +72,7 @@ task_default_metrics = {
Tasks.image_quality_assessment_mos:
[Metrics.image_quality_assessment_mos_metric],
Tasks.bad_image_detecting: [Metrics.accuracy],
+ Tasks.ocr_recognition: [Metrics.ocr_recognition_metric],
}
diff --git a/modelscope/metrics/image_colorization_metric.py b/modelscope/metrics/image_colorization_metric.py
new file mode 100644
index 00000000..bbaf4127
--- /dev/null
+++ b/modelscope/metrics/image_colorization_metric.py
@@ -0,0 +1,56 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Dict
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy import linalg
+
+from modelscope.metainfo import Metrics
+from modelscope.models.cv.image_inpainting.modules.inception import InceptionV3
+from modelscope.utils.registry import default_group
+from modelscope.utils.tensor_utils import (torch_nested_detach,
+ torch_nested_numpify)
+from .base import Metric
+from .builder import METRICS, MetricKeys
+from .image_denoise_metric import calculate_psnr
+from .image_inpainting_metric import FIDScore
+
+
+@METRICS.register_module(
+ group_key=default_group, module_name=Metrics.image_colorization_metric)
+class ImageColorizationMetric(Metric):
+ """The metric computation class for image colorization.
+ """
+
+ def __init__(self):
+ self.preds = []
+ self.targets = []
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.FID = FIDScore().to(device)
+
+ def add(self, outputs: Dict, inputs: Dict):
+ ground_truths = outputs['preds']
+ eval_results = outputs['targets']
+ self.preds.append(eval_results)
+ self.targets.append(ground_truths)
+
+ def evaluate(self):
+ psnr_list = []
+ for (pred, target) in zip(self.preds, self.targets):
+ self.FID(pred, target)
+ psnr_list.append(calculate_psnr(target[0], pred[0], crop_border=0))
+ fid = self.FID.get_value()
+ return {MetricKeys.PSNR: np.mean(psnr_list), MetricKeys.FID: fid}
+
+ def merge(self, other: 'ImageColorizationMetric'):
+ self.preds.extend(other.preds)
+ self.targets.extend(other.targets)
+
+ def __getstate__(self):
+ return self.preds, self.targets
+
+ def __setstate__(self, state):
+ self.__init__()
+ self.preds, self.targets = state
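For reference, the PSNR half of this metric (the FID half needs the InceptionV3 backbone) reduces to the usual peak signal-to-noise ratio over 8-bit images. A standalone sketch of that formula, written from the standard definition rather than the repository's `calculate_psnr` implementation:

```python
# Standalone PSNR on uint8 images: PSNR = 10 * log10(255^2 / MSE).
# This mirrors the standard definition only; border cropping and channel-order
# handling in the repo's calculate_psnr are omitted here.
import numpy as np


def psnr_uint8(pred: np.ndarray, target: np.ndarray) -> float:
    pred = pred.astype(np.float64)
    target = target.astype(np.float64)
    mse = np.mean((pred - target) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10(255.0 ** 2 / mse)


rng = np.random.default_rng(0)
target = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
noise = rng.integers(-5, 6, size=target.shape)
pred = np.clip(target.astype(np.int16) + noise, 0, 255).astype(np.uint8)

print(f'PSNR: {psnr_uint8(pred, target):.2f} dB')   # roughly 38 dB for +/-5 noise
```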
diff --git a/modelscope/metrics/ocr_recognition_metric.py b/modelscope/metrics/ocr_recognition_metric.py
new file mode 100644
index 00000000..41fc28a5
--- /dev/null
+++ b/modelscope/metrics/ocr_recognition_metric.py
@@ -0,0 +1,79 @@
+from typing import Dict
+
+import edit_distance as ed
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from modelscope.metainfo import Metrics
+from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+def cal_distance(label_list, pre_list):
+ y = ed.SequenceMatcher(a=label_list, b=pre_list)
+ yy = y.get_opcodes()
+ insert = 0
+ delete = 0
+ replace = 0
+ for item in yy:
+ if item[0] == 'insert':
+ insert += item[-1] - item[-2]
+ if item[0] == 'delete':
+ delete += item[2] - item[1]
+ if item[0] == 'replace':
+ replace += item[-1] - item[-2]
+ distance = insert + delete + replace
+ return distance, (delete, replace, insert)
+
+
+@METRICS.register_module(
+ group_key=default_group, module_name=Metrics.ocr_recognition_metric)
+class OCRRecognitionMetric(Metric):
+ """The metric computation class for ocr recognition.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self.preds = []
+ self.targets = []
+ self.loss_sum = 0.
+ self.nsample = 0
+ self.iter_sum = 0
+
+ def add(self, outputs: Dict, inputs: Dict):
+ pred = outputs['preds']
+ loss = outputs['loss']
+ target = inputs['labels']
+ self.preds.extend(pred)
+ self.targets.extend(target)
+ self.loss_sum += loss.data.cpu().numpy()
+ self.nsample += len(pred)
+ self.iter_sum += 1
+
+ def evaluate(self):
+ total_chars = 0
+ total_distance = 0
+ total_fullmatch = 0
+ for (pred, target) in zip(self.preds, self.targets):
+ distance, _ = cal_distance(target, pred)
+ total_chars += len(target)
+ total_distance += distance
+ total_fullmatch += (target == pred)
+ accuracy = float(total_fullmatch) / self.nsample
+ AR = 1 - float(total_distance) / total_chars
+ average_loss = self.loss_sum / self.iter_sum if self.iter_sum > 0 else 0
+ return {
+ MetricKeys.ACCURACY: accuracy,
+ MetricKeys.AR: AR,
+ MetricKeys.AVERAGE_LOSS: average_loss
+ }
+
+ def merge(self, other: 'OCRRecognitionMetric'):
+ pass
+
+ def __getstate__(self):
+ pass
+
+ def __setstate__(self, state):
+ pass
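As a quick sanity check of the bookkeeping above: `cal_distance` sums insert, delete and replace counts from the `edit_distance` opcodes, and AR is one minus the character error rate. Assuming the new module can simply be imported in a test script, a toy usage looks like this:

```python
# Toy usage of the character-accuracy (AR) computation defined in this module.
from modelscope.metrics.ocr_recognition_metric import cal_distance

target = 'hello world'
pred = 'helo wrold'    # one missing 'l' plus a swapped 'or'

distance, (delete, replace, insert) = cal_distance(target, pred)
ar = 1 - float(distance) / len(target)
accuracy = float(target == pred)

print(distance, (delete, replace, insert))
print(f'AR = {ar:.3f}, full-string accuracy = {accuracy}')
```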
diff --git a/modelscope/models/audio/ans/__init__.py b/modelscope/models/audio/ans/__init__.py
index afcdf314..b88a787a 100644
--- a/modelscope/models/audio/ans/__init__.py
+++ b/modelscope/models/audio/ans/__init__.py
@@ -9,6 +9,7 @@ if TYPE_CHECKING:
else:
_import_structure = {
'frcrn': ['FRCRNDecorator'],
+ 'denoise_net': ['DfsmnAns'],
}
import sys
diff --git a/modelscope/models/audio/ans/complex_nn.py b/modelscope/models/audio/ans/complex_nn.py
index beaa3187..98bfd8b5 100644
--- a/modelscope/models/audio/ans/complex_nn.py
+++ b/modelscope/models/audio/ans/complex_nn.py
@@ -7,57 +7,8 @@
import torch
import torch.nn as nn
-import torch.nn.functional as F
-
-class UniDeepFsmn(nn.Module):
-
- def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
- super(UniDeepFsmn, self).__init__()
-
- self.input_dim = input_dim
- self.output_dim = output_dim
-
- if lorder is None:
- return
-
- self.lorder = lorder
- self.hidden_size = hidden_size
-
- self.linear = nn.Linear(input_dim, hidden_size)
-
- self.project = nn.Linear(hidden_size, output_dim, bias=False)
-
- self.conv1 = nn.Conv2d(
- output_dim,
- output_dim, [lorder, 1], [1, 1],
- groups=output_dim,
- bias=False)
-
- def forward(self, input):
- r"""
-
- Args:
- input: torch with shape: batch (b) x sequence(T) x feature (h)
-
- Returns:
- batch (b) x channel (c) x sequence(T) x feature (h)
- """
- f1 = F.relu(self.linear(input))
-
- p1 = self.project(f1)
-
- x = torch.unsqueeze(p1, 1)
- # x: batch (b) x channel (c) x sequence(T) x feature (h)
- x_per = x.permute(0, 3, 2, 1)
- # x_per: batch (b) x feature (h) x sequence(T) x channel (c)
- y = F.pad(x_per, [0, 0, self.lorder - 1, 0])
-
- out = x_per + self.conv1(y)
-
- out1 = out.permute(0, 3, 2, 1)
- # out1: batch (b) x channel (c) x sequence(T) x feature (h)
- return input + out1.squeeze()
+from modelscope.models.audio.ans.layers.uni_deep_fsmn import UniDeepFsmn
class ComplexUniDeepFsmn(nn.Module):
diff --git a/modelscope/models/audio/ans/denoise_net.py b/modelscope/models/audio/ans/denoise_net.py
new file mode 100644
index 00000000..9d20074b
--- /dev/null
+++ b/modelscope/models/audio/ans/denoise_net.py
@@ -0,0 +1,73 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Related papers:
+# Shengkui Zhao, Trung Hieu Nguyen, Bin Ma, “Monaural Speech Enhancement with Complex Convolutional
+# Block Attention Module and Joint Time Frequency Losses”, ICASSP 2021.
+# Shiliang Zhang, Ming Lei, Zhijie Yan, Lirong Dai, “Deep-FSMN for Large Vocabulary Continuous Speech
+# Recognition “, arXiv:1803.05030, 2018.
+
+from torch import nn
+
+from modelscope.metainfo import Models
+from modelscope.models import MODELS, TorchModel
+from modelscope.models.audio.ans.layers.activations import (RectifiedLinear,
+ Sigmoid)
+from modelscope.models.audio.ans.layers.affine_transform import AffineTransform
+from modelscope.models.audio.ans.layers.uni_deep_fsmn import UniDeepFsmn
+from modelscope.utils.constant import Tasks
+
+
+@MODELS.register_module(
+ Tasks.acoustic_noise_suppression, module_name=Models.speech_dfsmn_ans)
+class DfsmnAns(TorchModel):
+ """Denoise model with DFSMN.
+
+ Args:
+ model_dir (str): the model path.
+ fsmn_depth (int): the depth of deepfsmn
+ lorder (int):
+ """
+
+ def __init__(self,
+ model_dir: str,
+ fsmn_depth=9,
+ lorder=20,
+ *args,
+ **kwargs):
+ super().__init__(model_dir, *args, **kwargs)
+ self.lorder = lorder
+ self.linear1 = AffineTransform(120, 256)
+ self.relu = RectifiedLinear(256, 256)
+ repeats = [
+ UniDeepFsmn(256, 256, lorder, 256) for i in range(fsmn_depth)
+ ]
+ self.deepfsmn = nn.Sequential(*repeats)
+ self.linear2 = AffineTransform(256, 961)
+ self.sig = Sigmoid(961, 961)
+
+ def forward(self, input):
+ """
+ Args:
+ input: fbank feature [batch_size,number_of_frame,feature_dimension]
+
+ Returns:
+ mask value [batch_size, number_of_frame, FFT_size/2+1]
+ """
+ x1 = self.linear1(input)
+ x2 = self.relu(x1)
+ x3 = self.deepfsmn(x2)
+ x4 = self.linear2(x3)
+ x5 = self.sig(x4)
+ return x5
+
+ def to_kaldi_nnet(self):
+ re_str = ''
+ re_str += '<Nnet>\n'
+ re_str += self.linear1.to_kaldi_nnet()
+ re_str += self.relu.to_kaldi_nnet()
+ for dfsmn in self.deepfsmn:
+ re_str += dfsmn.to_kaldi_nnet()
+ re_str += self.linear2.to_kaldi_nnet()
+ re_str += self.sig.to_kaldi_nnet()
+ re_str += '</Nnet>\n'
+
+ return re_str
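For orientation, the network maps 120-dimensional fbank frames to a 961-dimensional mask per frame (FFT_size/2 + 1 bins). A hedged shape check, assuming the class can be constructed directly with a placeholder `model_dir` and randomly initialized weights:

```python
# Hedged shape check for DfsmnAns. 'placeholder_dir' is an assumption: __init__ only
# builds layers and stores model_dir, no pretrained weights are loaded here.
import torch

from modelscope.models.audio.ans.denoise_net import DfsmnAns

model = DfsmnAns(model_dir='placeholder_dir')
fbank = torch.randn(4, 200, 120)    # [batch_size, number_of_frames, feature_dimension]

with torch.no_grad():
    mask = model.forward(fbank)

print(mask.shape)                               # expected: torch.Size([4, 200, 961])
print(float(mask.min()), float(mask.max()))     # sigmoid output stays inside (0, 1)
```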
diff --git a/modelscope/models/audio/ans/frcrn.py b/modelscope/models/audio/ans/frcrn.py
index 220a14aa..0a83dfae 100644
--- a/modelscope/models/audio/ans/frcrn.py
+++ b/modelscope/models/audio/ans/frcrn.py
@@ -78,7 +78,7 @@ class FRCRN(nn.Module):
win_len=400,
win_inc=100,
fft_len=512,
- win_type='hanning',
+ win_type='hann',
**kwargs):
r"""
Args:
diff --git a/modelscope/models/audio/tts/kantts/datasets/__init__.py b/modelscope/models/audio/ans/layers/__init__.py
similarity index 100%
rename from modelscope/models/audio/tts/kantts/datasets/__init__.py
rename to modelscope/models/audio/ans/layers/__init__.py
diff --git a/modelscope/models/audio/ans/layers/activations.py b/modelscope/models/audio/ans/layers/activations.py
new file mode 100644
index 00000000..406de736
--- /dev/null
+++ b/modelscope/models/audio/ans/layers/activations.py
@@ -0,0 +1,62 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import torch.nn as nn
+
+from modelscope.models.audio.ans.layers.layer_base import LayerBase
+
+
+class RectifiedLinear(LayerBase):
+
+ def __init__(self, input_dim, output_dim):
+ super(RectifiedLinear, self).__init__()
+ self.dim = input_dim
+ self.relu = nn.ReLU()
+
+ def forward(self, input):
+ return self.relu(input)
+
+ def to_kaldi_nnet(self):
+ re_str = ''
+ re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
+ return re_str
+
+ def load_kaldi_nnet(self, instr):
+ return instr
+
+
+class LogSoftmax(LayerBase):
+
+ def __init__(self, input_dim, output_dim):
+ super(LogSoftmax, self).__init__()
+ self.dim = input_dim
+ self.ls = nn.LogSoftmax()
+
+ def forward(self, input):
+ return self.ls(input)
+
+ def to_kaldi_nnet(self):
+ re_str = ''
+ re_str += '<LogSoftmax> %d %d\n' % (self.dim, self.dim)
+ return re_str
+
+ def load_kaldi_nnet(self, instr):
+ return instr
+
+
+class Sigmoid(LayerBase):
+
+ def __init__(self, input_dim, output_dim):
+ super(Sigmoid, self).__init__()
+ self.dim = input_dim
+ self.sig = nn.Sigmoid()
+
+ def forward(self, input):
+ return self.sig(input)
+
+ def to_kaldi_nnet(self):
+ re_str = ''
+ re_str += '<Sigmoid> %d %d\n' % (self.dim, self.dim)
+ return re_str
+
+ def load_kaldi_nnet(self, instr):
+ return instr
diff --git a/modelscope/models/audio/ans/layers/affine_transform.py b/modelscope/models/audio/ans/layers/affine_transform.py
new file mode 100644
index 00000000..d3cad181
--- /dev/null
+++ b/modelscope/models/audio/ans/layers/affine_transform.py
@@ -0,0 +1,86 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import torch as th
+import torch.nn as nn
+
+from modelscope.models.audio.ans.layers.layer_base import (LayerBase,
+ to_kaldi_matrix)
+from modelscope.utils.audio.audio_utils import (expect_kaldi_matrix,
+ expect_token_number)
+
+
+class AffineTransform(LayerBase):
+
+ def __init__(self, input_dim, output_dim):
+ super(AffineTransform, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+
+ self.linear = nn.Linear(input_dim, output_dim)
+
+ def forward(self, input):
+ return self.linear(input)
+
+ def to_kaldi_nnet(self):
+ re_str = ''
+
+ re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
+ self.input_dim)
+
+ re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
+
+ linear_weights = self.state_dict()['linear.weight']
+
+ x = linear_weights.squeeze().numpy()
+
+ re_str += to_kaldi_matrix(x)
+
+ linear_bias = self.state_dict()['linear.bias']
+
+ x = linear_bias.squeeze().numpy()
+
+ re_str += to_kaldi_matrix(x)
+
+ return re_str
+
+ def load_kaldi_nnet(self, instr):
+
+ output = expect_token_number(
+ instr,
+ '<LearnRateCoef>',
+ )
+ if output is None:
+ raise Exception('AffineTransform format error')
+
+ instr, lr = output
+
+ output = expect_token_number(instr, '<BiasLearnRateCoef>')
+ if output is None:
+ raise Exception('AffineTransform format error')
+
+ instr, lr = output
+
+ output = expect_token_number(instr, '<MaxNorm>')
+ if output is None:
+ raise Exception('AffineTransform format error')
+
+ instr, lr = output
+
+ output = expect_kaldi_matrix(instr)
+
+ if output is None:
+ raise Exception('AffineTransform format error')
+
+ instr, mat = output
+
+ self.linear.weight = th.nn.Parameter(
+ th.from_numpy(mat).type(th.FloatTensor))
+
+ output = expect_kaldi_matrix(instr)
+ if output is None:
+ raise Exception('AffineTransform format error')
+
+ instr, mat = output
+ self.linear.bias = th.nn.Parameter(
+ th.from_numpy(mat).type(th.FloatTensor))
+ return instr
diff --git a/modelscope/models/audio/ans/layers/layer_base.py b/modelscope/models/audio/ans/layers/layer_base.py
new file mode 100644
index 00000000..ca713d2f
--- /dev/null
+++ b/modelscope/models/audio/ans/layers/layer_base.py
@@ -0,0 +1,31 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import abc
+
+import numpy as np
+import six
+import torch.nn as nn
+
+
+def to_kaldi_matrix(np_mat):
+ """ function that transform as str numpy mat to standard kaldi str matrix
+
+ Args:
+ np_mat: numpy mat
+ """
+ np.set_printoptions(threshold=np.inf, linewidth=np.nan)
+ out_str = str(np_mat)
+ out_str = out_str.replace('[', '')
+ out_str = out_str.replace(']', '')
+ return '[ %s ]\n' % out_str
+
+
+@six.add_metaclass(abc.ABCMeta)
+class LayerBase(nn.Module):
+
+ def __init__(self):
+ super(LayerBase, self).__init__()
+
+ @abc.abstractmethod
+ def to_kaldi_nnet(self):
+ pass
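`to_kaldi_matrix` just renders the array between `[` and `]` with numpy print options widened so rows never wrap. A small usage sketch, relying on the same numpy print-option behaviour the module itself assumes:

```python
# Usage sketch for the Kaldi text-matrix helper defined above.
import numpy as np

from modelscope.models.audio.ans.layers.layer_base import to_kaldi_matrix

weights = np.array([[0.1, -0.2, 0.3],
                    [0.4, 0.5, -0.6]], dtype=np.float32)

# Prints a bracketed, row-major text matrix of the form
# "[ 0.1 -0.2 0.3\n  0.4 0.5 -0.6 ]", which is what Kaldi-style nnet text expects.
print(to_kaldi_matrix(weights))
```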
diff --git a/modelscope/models/audio/ans/layers/uni_deep_fsmn.py b/modelscope/models/audio/ans/layers/uni_deep_fsmn.py
new file mode 100644
index 00000000..772e6048
--- /dev/null
+++ b/modelscope/models/audio/ans/layers/uni_deep_fsmn.py
@@ -0,0 +1,156 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modelscope.models.audio.ans.layers.layer_base import (LayerBase,
+ to_kaldi_matrix)
+from modelscope.utils.audio.audio_utils import (expect_kaldi_matrix,
+ expect_token_number)
+
+
+class UniDeepFsmn(LayerBase):
+
+ def __init__(self, input_dim, output_dim, lorder=1, hidden_size=None):
+ super(UniDeepFsmn, self).__init__()
+
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.lorder = lorder
+ self.hidden_size = hidden_size
+
+ self.linear = nn.Linear(input_dim, hidden_size)
+ self.project = nn.Linear(hidden_size, output_dim, bias=False)
+ self.conv1 = nn.Conv2d(
+ output_dim,
+ output_dim, (lorder, 1), (1, 1),
+ groups=output_dim,
+ bias=False)
+
+ def forward(self, input):
+ """
+
+ Args:
+ input: torch with shape: batch (b) x sequence(T) x feature (h)
+
+ Returns:
+ batch (b) x channel (c) x sequence(T) x feature (h)
+ """
+ f1 = F.relu(self.linear(input))
+ p1 = self.project(f1)
+ x = torch.unsqueeze(p1, 1)
+ # x: batch (b) x channel (c) x sequence(T) x feature (h)
+ x_per = x.permute(0, 3, 2, 1)
+ # x_per: batch (b) x feature (h) x sequence(T) x channel (c)
+ y = F.pad(x_per, [0, 0, self.lorder - 1, 0])
+
+ out = x_per + self.conv1(y)
+ out1 = out.permute(0, 3, 2, 1)
+ # out1: batch (b) x channel (c) x sequence(T) x feature (h)
+ return input + out1.squeeze()
+
+ def to_kaldi_nnet(self):
+ re_str = ''
+ re_str += '<UniDeepFsmn> %d %d\n'\
+ % (self.output_dim, self.input_dim)
+ re_str += '<LearnRateCoef> %d <HidSize> %d <LOrder> %d <LStride> %d <MaxNorm> 0\n'\
+ % (1, self.hidden_size, self.lorder, 1)
+
+ lfiters = self.state_dict()['conv1.weight']
+ x = np.flipud(lfiters.squeeze().numpy().T)
+ re_str += to_kaldi_matrix(x)
+ proj_weights = self.state_dict()['project.weight']
+ x = proj_weights.squeeze().numpy()
+ re_str += to_kaldi_matrix(x)
+ linear_weights = self.state_dict()['linear.weight']
+ x = linear_weights.squeeze().numpy()
+ re_str += to_kaldi_matrix(x)
+ linear_bias = self.state_dict()['linear.bias']
+ x = linear_bias.squeeze().numpy()
+ re_str += to_kaldi_matrix(x)
+ return re_str
+
+ def load_kaldi_nnet(self, instr):
+ output = expect_token_number(
+ instr,
+ '<LearnRateCoef>',
+ )
+ if output is None:
+ raise Exception('UniDeepFsmn format error')
+ instr, lr = output
+
+ output = expect_token_number(
+ instr,
+ '<HidSize>',
+ )
+ if output is None:
+ raise Exception('UniDeepFsmn format error')
+ instr, hiddensize = output
+ self.hidden_size = int(hiddensize)
+
+ output = expect_token_number(
+ instr,
+ '<LOrder>',
+ )
+ if output is None:
+ raise Exception('UniDeepFsmn format error')
+ instr, lorder = output
+ self.lorder = int(lorder)
+
+ output = expect_token_number(
+ instr,
+ '<LStride>',
+ )
+ if output is None:
+ raise Exception('UniDeepFsmn format error')
+ instr, lstride = output
+ self.lstride = lstride
+
+ output = expect_token_number(
+ instr,
+ '<MaxNorm>',
+ )
+ if output is None:
+ raise Exception('UniDeepFsmn format error')
+
+ output = expect_kaldi_matrix(instr)
+ if output is None:
+ raise Exception('Fsmn format error')
+ instr, mat = output
+ mat1 = np.fliplr(mat.T).copy()
+ self.conv1 = nn.Conv2d(
+ self.output_dim,
+ self.output_dim, (self.lorder, 1), (1, 1),
+ groups=self.output_dim,
+ bias=False)
+ mat_th = torch.from_numpy(mat1).type(torch.FloatTensor)
+ mat_th = mat_th.unsqueeze(1)
+ mat_th = mat_th.unsqueeze(3)
+ self.conv1.weight = torch.nn.Parameter(mat_th)
+
+ output = expect_kaldi_matrix(instr)
+ if output is None:
+ raise Exception('UniDeepFsmn format error')
+ instr, mat = output
+
+ self.project = nn.Linear(self.hidden_size, self.output_dim, bias=False)
+ self.linear = nn.Linear(self.input_dim, self.hidden_size)
+ self.project.weight = torch.nn.Parameter(
+ torch.from_numpy(mat).type(torch.FloatTensor))
+
+ output = expect_kaldi_matrix(instr)
+ if output is None:
+ raise Exception('UniDeepFsmn format error')
+ instr, mat = output
+ self.linear.weight = torch.nn.Parameter(
+ torch.from_numpy(mat).type(torch.FloatTensor))
+
+ output = expect_kaldi_matrix(instr)
+ if output is None:
+ raise Exception('UniDeepFsmn format error')
+ instr, mat = output
+ self.linear.bias = torch.nn.Parameter(
+ torch.from_numpy(mat).type(torch.FloatTensor))
+ return instr
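The layer is residual and causal: the input is projected to `hidden_size`, back to `output_dim`, filtered by a depthwise convolution that looks `lorder - 1` frames into the past, and added back to the input, which is why `input_dim` must equal `output_dim`. A quick shape check under that assumption:

```python
# Quick shape check for UniDeepFsmn with random initialization.
# input_dim == output_dim is required by the residual connection in forward().
import torch

from modelscope.models.audio.ans.layers.uni_deep_fsmn import UniDeepFsmn

layer = UniDeepFsmn(input_dim=256, output_dim=256, lorder=20, hidden_size=256)
x = torch.randn(2, 100, 256)    # [batch, frames, features]

with torch.no_grad():
    y = layer(x)

print(y.shape)    # torch.Size([2, 100, 256]), same shape as the input
# The left-only padding before the depthwise Conv2d makes the memory block causal
# along the time axis: frame t only sees frames t - lorder + 1 .. t.
```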
diff --git a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py
index b66351cc..25de839e 100644
--- a/modelscope/models/audio/asr/generic_automatic_speech_recognition.py
+++ b/modelscope/models/audio/asr/generic_automatic_speech_recognition.py
@@ -17,6 +17,7 @@ __all__ = ['GenericAutomaticSpeechRecognition']
Tasks.voice_activity_detection, module_name=Models.generic_asr)
@MODELS.register_module(
Tasks.language_score_prediction, module_name=Models.generic_asr)
+@MODELS.register_module(Tasks.speech_timestamp, module_name=Models.generic_asr)
class GenericAutomaticSpeechRecognition(Model):
def __init__(self, model_dir: str, am_model_name: str,
diff --git a/modelscope/models/audio/kws/farfield/model.py b/modelscope/models/audio/kws/farfield/model.py
index ee0301f9..fff88805 100644
--- a/modelscope/models/audio/kws/farfield/model.py
+++ b/modelscope/models/audio/kws/farfield/model.py
@@ -68,6 +68,7 @@ class FSMNSeleNetV2Decorator(TorchModel):
'keyword':
self._sc.kwsKeyword(self._sc.kwsSpottedKeywordIndex()),
'offset': self._sc.kwsKeywordOffset(),
+ 'channel': self._sc.kwsBestChannel(),
'length': self._sc.kwsKeywordLength(),
'confidence': self._sc.kwsConfidence()
}
diff --git a/modelscope/models/audio/tts/kantts/__init__.py b/modelscope/models/audio/tts/kantts/__init__.py
deleted file mode 100644
index 2b745d4a..00000000
--- a/modelscope/models/audio/tts/kantts/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-from .datasets.dataset import get_am_datasets, get_voc_datasets
-from .models import model_builder
-from .models.hifigan.hifigan import Generator
-from .train.loss import criterion_builder
-from .train.trainer import GAN_Trainer, Sambert_Trainer
-from .utils.ling_unit.ling_unit import KanTtsLinguisticUnit
diff --git a/modelscope/models/audio/tts/kantts/datasets/data_types.py b/modelscope/models/audio/tts/kantts/datasets/data_types.py
deleted file mode 100644
index 3b41ffff..00000000
--- a/modelscope/models/audio/tts/kantts/datasets/data_types.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import numpy as np
-from scipy.io import wavfile
-
-DATA_TYPE_DICT = {
- 'txt': {
- 'load_func': np.loadtxt,
- 'desc': 'plain txt file or readable by np.loadtxt',
- },
- 'wav': {
- 'load_func': lambda x: wavfile.read(x)[1],
- 'desc': 'wav file or readable by soundfile.read',
- },
- 'npy': {
- 'load_func': np.load,
- 'desc': 'any .npy format file',
- },
- # PCM data type can be loaded by binary format
- 'bin_f32': {
- 'load_func': lambda x: np.fromfile(x, dtype=np.float32),
- 'desc': 'binary file with float32 format',
- },
- 'bin_f64': {
- 'load_func': lambda x: np.fromfile(x, dtype=np.float64),
- 'desc': 'binary file with float64 format',
- },
- 'bin_i32': {
- 'load_func': lambda x: np.fromfile(x, dtype=np.int32),
- 'desc': 'binary file with int32 format',
- },
- 'bin_i16': {
- 'load_func': lambda x: np.fromfile(x, dtype=np.int16),
- 'desc': 'binary file with int16 format',
- },
-}
diff --git a/modelscope/models/audio/tts/kantts/datasets/dataset.py b/modelscope/models/audio/tts/kantts/datasets/dataset.py
deleted file mode 100644
index d5dd4da7..00000000
--- a/modelscope/models/audio/tts/kantts/datasets/dataset.py
+++ /dev/null
@@ -1,1030 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import functools
-import glob
-import math
-import os
-import random
-from multiprocessing import Manager
-
-import librosa
-import numpy as np
-import torch
-from scipy.stats import betabinom
-from tqdm import tqdm
-
-from modelscope.models.audio.tts.kantts.utils.ling_unit.ling_unit import (
- KanTtsLinguisticUnit, emotion_types)
-from modelscope.utils.logger import get_logger
-
-DATASET_RANDOM_SEED = 1234
-torch.multiprocessing.set_sharing_strategy('file_system')
-logging = get_logger()
-
-
-@functools.lru_cache(maxsize=256)
-def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling=1.0):
- P = phoneme_count
- M = mel_count
- x = np.arange(0, P)
- mel_text_probs = []
- for i in range(1, M + 1):
- a, b = scaling * i, scaling * (M + 1 - i)
- rv = betabinom(P, a, b)
- mel_i_prob = rv.pmf(x)
- mel_text_probs.append(mel_i_prob)
- return torch.tensor(np.array(mel_text_probs))
-
-
-class Padder(object):
-
- def __init__(self):
- super(Padder, self).__init__()
- pass
-
- def _pad1D(self, x, length, pad):
- return np.pad(
- x, (0, length - x.shape[0]), mode='constant', constant_values=pad)
-
- def _pad2D(self, x, length, pad):
- return np.pad(
- x, [(0, length - x.shape[0]), (0, 0)],
- mode='constant',
- constant_values=pad)
-
- def _pad_durations(self, duration, max_in_len, max_out_len):
- framenum = np.sum(duration)
- symbolnum = duration.shape[0]
- if framenum < max_out_len:
- padframenum = max_out_len - framenum
- duration = np.insert(
- duration, symbolnum, values=padframenum, axis=0)
- duration = np.insert(
- duration,
- symbolnum + 1,
- values=[0] * (max_in_len - symbolnum - 1),
- axis=0,
- )
- else:
- if symbolnum < max_in_len:
- duration = np.insert(
- duration,
- symbolnum,
- values=[0] * (max_in_len - symbolnum),
- axis=0)
- return duration
-
- def _round_up(self, x, multiple):
- remainder = x % multiple
- return x if remainder == 0 else x + multiple - remainder
-
- def _prepare_scalar_inputs(self, inputs, max_len, pad):
- return torch.from_numpy(
- np.stack([self._pad1D(x, max_len, pad) for x in inputs]))
-
- def _prepare_targets(self, targets, max_len, pad):
- return torch.from_numpy(
- np.stack([self._pad2D(t, max_len, pad) for t in targets])).float()
-
- def _prepare_durations(self, durations, max_in_len, max_out_len):
- return torch.from_numpy(
- np.stack([
- self._pad_durations(t, max_in_len, max_out_len)
- for t in durations
- ])).long()
-
-
-class KanttsDataset(torch.utils.data.Dataset):
-
- def __init__(
- self,
- metafile,
- root_dir,
- ):
- self.meta = []
- if not isinstance(metafile, list):
- metafile = [metafile]
- if not isinstance(root_dir, list):
- root_dir = [root_dir]
-
- for meta_file, data_dir in zip(metafile, root_dir):
- if not os.path.exists(meta_file):
- logging.error('meta file not found: {}'.format(meta_file))
- raise ValueError(
- '[Dataset] meta file: {} not found'.format(meta_file))
- if not os.path.exists(data_dir):
- logging.error('data directory not found: {}'.format(data_dir))
- raise ValueError(
- '[Dataset] data dir: {} not found'.format(data_dir))
- self.meta.extend(self.load_meta(meta_file, data_dir))
-
- def load_meta(self, meta_file, data_dir):
- pass
-
-
-class VocDataset(KanttsDataset):
- """
- provide (mel, audio) data pair
- """
-
- def __init__(
- self,
- metafile,
- root_dir,
- config,
- ):
- self.config = config
- self.sampling_rate = config['audio_config']['sampling_rate']
- self.n_fft = config['audio_config']['n_fft']
- self.hop_length = config['audio_config']['hop_length']
- self.batch_max_steps = config['batch_max_steps']
- self.batch_max_frames = self.batch_max_steps // self.hop_length
- self.aux_context_window = 0
- self.start_offset = self.aux_context_window
- self.end_offset = -(self.batch_max_frames + self.aux_context_window)
- self.nsf_enable = (
- config['Model']['Generator']['params'].get('nsf_params', None)
- is not None)
-
- super().__init__(metafile, root_dir)
-
- # Load from training data directory
- if len(self.meta) == 0 and isinstance(root_dir, str):
- wav_dir = os.path.join(root_dir, 'wav')
- mel_dir = os.path.join(root_dir, 'mel')
- if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
- raise ValueError('wav or mel directory not found')
- self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir))
- elif len(self.meta) == 0 and isinstance(root_dir, list):
- for d in root_dir:
- wav_dir = os.path.join(d, 'wav')
- mel_dir = os.path.join(d, 'mel')
- if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
- raise ValueError('wav or mel directory not found')
- self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir))
-
- self.allow_cache = config['allow_cache']
- if self.allow_cache:
- self.manager = Manager()
- self.caches = self.manager.list()
- self.caches += [() for _ in range(len(self.meta))]
-
- @staticmethod
- def gen_metafile(wav_dir, out_dir, split_ratio=0.98):
- wav_files = glob.glob(os.path.join(wav_dir, '*.wav'))
- frame_f0_dir = os.path.join(out_dir, 'frame_f0')
- frame_uv_dir = os.path.join(out_dir, 'frame_uv')
- mel_dir = os.path.join(out_dir, 'mel')
- random.seed(DATASET_RANDOM_SEED)
- random.shuffle(wav_files)
- num_train = int(len(wav_files) * split_ratio) - 1
- with open(os.path.join(out_dir, 'train.lst'), 'w') as f:
- for wav_file in wav_files[:num_train]:
- index = os.path.splitext(os.path.basename(wav_file))[0]
- if (not os.path.exists(
- os.path.join(frame_f0_dir, index + '.npy'))
- or not os.path.exists(
- os.path.join(frame_uv_dir, index + '.npy'))
- or not os.path.exists(
- os.path.join(mel_dir, index + '.npy'))):
- continue
- f.write('{}\n'.format(index))
-
- with open(os.path.join(out_dir, 'valid.lst'), 'w') as f:
- for wav_file in wav_files[num_train:]:
- index = os.path.splitext(os.path.basename(wav_file))[0]
- if (not os.path.exists(
- os.path.join(frame_f0_dir, index + '.npy'))
- or not os.path.exists(
- os.path.join(frame_uv_dir, index + '.npy'))
- or not os.path.exists(
- os.path.join(mel_dir, index + '.npy'))):
- continue
- f.write('{}\n'.format(index))
-
- def load_meta(self, metafile, data_dir):
- with open(metafile, 'r') as f:
- lines = f.readlines()
- wav_dir = os.path.join(data_dir, 'wav')
- mel_dir = os.path.join(data_dir, 'mel')
- frame_f0_dir = os.path.join(data_dir, 'frame_f0')
- frame_uv_dir = os.path.join(data_dir, 'frame_uv')
- if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
- raise ValueError('wav or mel directory not found')
- items = []
- logging.info('Loading metafile...')
- for name in tqdm(lines):
- name = name.strip()
- mel_file = os.path.join(mel_dir, name + '.npy')
- wav_file = os.path.join(wav_dir, name + '.wav')
- frame_f0_file = os.path.join(frame_f0_dir, name + '.npy')
- frame_uv_file = os.path.join(frame_uv_dir, name + '.npy')
- items.append((wav_file, mel_file, frame_f0_file, frame_uv_file))
- return items
-
- def load_meta_from_dir(self, wav_dir, mel_dir):
- wav_files = glob.glob(os.path.join(wav_dir, '*.wav'))
- items = []
- for wav_file in wav_files:
- mel_file = os.path.join(mel_dir, os.path.basename(wav_file))
- if os.path.exists(mel_file):
- items.append((wav_file, mel_file))
- return items
-
- def __len__(self):
- return len(self.meta)
-
- def __getitem__(self, idx):
- if self.allow_cache and len(self.caches[idx]) != 0:
- return self.caches[idx]
-
- wav_file, mel_file, frame_f0_file, frame_uv_file = self.meta[idx]
-
- wav_data = librosa.core.load(wav_file, sr=self.sampling_rate)[0]
- mel_data = np.load(mel_file)
-
- if self.nsf_enable:
- frame_f0_data = np.load(frame_f0_file).reshape(-1, 1)
- frame_uv_data = np.load(frame_uv_file).reshape(-1, 1)
- mel_data = np.concatenate((mel_data, frame_f0_data, frame_uv_data),
- axis=1)
-
- # make sure mel_data length greater than batch_max_frames at least 1 frame
- if mel_data.shape[0] <= self.batch_max_frames:
- mel_data = np.concatenate(
- (
- mel_data,
- np.zeros((
- self.batch_max_frames - mel_data.shape[0] + 1,
- mel_data.shape[1],
- )),
- ),
- axis=0,
- )
- wav_cache = np.zeros(
- mel_data.shape[0] * self.hop_length, dtype=np.float32)
- wav_cache[:len(wav_data)] = wav_data
- wav_data = wav_cache
- else:
- # make sure the audio length and feature length are matched
- wav_data = np.pad(wav_data, (0, self.n_fft), mode='reflect')
- wav_data = wav_data[:len(mel_data) * self.hop_length]
-
- assert len(mel_data) * self.hop_length == len(wav_data)
-
- if self.allow_cache:
- self.caches[idx] = (wav_data, mel_data)
- return (wav_data, mel_data)
-
- def collate_fn(self, batch):
- wav_data, mel_data = [item[0]
- for item in batch], [item[1] for item in batch]
- mel_lengths = [len(mel) for mel in mel_data]
-
- start_frames = np.array([
- np.random.randint(self.start_offset, length + self.end_offset)
- for length in mel_lengths
- ])
-
- wav_start = start_frames * self.hop_length
- wav_end = wav_start + self.batch_max_steps
-
- # aux window works as padding
- mel_start = start_frames - self.aux_context_window
- mel_end = mel_start + self.batch_max_frames + self.aux_context_window
-
- wav_batch = [
- x[start:end] for x, start, end in zip(wav_data, wav_start, wav_end)
- ]
- mel_batch = [
- c[start:end] for c, start, end in zip(mel_data, mel_start, mel_end)
- ]
-
- # (B, 1, T)
- wav_batch = torch.tensor(
- np.asarray(wav_batch), dtype=torch.float32).unsqueeze(1)
- # (B, C, T)
- mel_batch = torch.tensor(
- np.asarray(mel_batch), dtype=torch.float32).transpose(2, 1)
- return wav_batch, mel_batch
-
-
-def get_voc_datasets(
- config,
- root_dir,
- split_ratio=0.98,
-):
- if isinstance(root_dir, str):
- root_dir = [root_dir]
- train_meta_lst = []
- valid_meta_lst = []
- for data_dir in root_dir:
- train_meta = os.path.join(data_dir, 'train.lst')
- valid_meta = os.path.join(data_dir, 'valid.lst')
- if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
- VocDataset.gen_metafile(
- os.path.join(data_dir, 'wav'), data_dir, split_ratio)
- train_meta_lst.append(train_meta)
- valid_meta_lst.append(valid_meta)
- train_dataset = VocDataset(
- train_meta_lst,
- root_dir,
- config,
- )
-
- valid_dataset = VocDataset(
- valid_meta_lst,
- root_dir,
- config,
- )
-
- return train_dataset, valid_dataset
-
-
-def get_fp_label(aug_ling_txt):
- token_lst = aug_ling_txt.split(' ')
- emo_lst = [token.strip('{}').split('$')[4] for token in token_lst]
- syllable_lst = [token.strip('{}').split('$')[0] for token in token_lst]
-
- # EOS token append
- emo_lst.append(emotion_types[0])
- syllable_lst.append('EOS')
-
- # According to the original emotion tag, set each token's fp label.
- if emo_lst[0] != emotion_types[3]:
- emo_lst[0] = emotion_types[0]
- emo_lst[1] = emotion_types[0]
- for i in range(len(emo_lst) - 2, 1, -1):
- if emo_lst[i] != emotion_types[3] and emo_lst[i
- - 1] != emotion_types[3]:
- emo_lst[i] = emotion_types[0]
- elif emo_lst[i] != emotion_types[3] and emo_lst[
- i - 1] == emotion_types[3]:
- emo_lst[i] = emotion_types[3]
- if syllable_lst[i - 2] == 'ga':
- emo_lst[i + 1] = emotion_types[1]
- elif syllable_lst[i - 2] == 'ge' and syllable_lst[i - 1] == 'en_c':
- emo_lst[i + 1] = emotion_types[2]
- else:
- emo_lst[i + 1] = emotion_types[4]
-
- fp_label = []
- for i in range(len(emo_lst)):
- if emo_lst[i] == emotion_types[0]:
- fp_label.append(0)
- elif emo_lst[i] == emotion_types[1]:
- fp_label.append(1)
- elif emo_lst[i] == emotion_types[2]:
- fp_label.append(2)
- elif emo_lst[i] == emotion_types[3]:
- continue
- elif emo_lst[i] == emotion_types[4]:
- fp_label.append(3)
- else:
- pass
-
- return np.array(fp_label)
-
-
-class AmDataset(KanttsDataset):
- """
- provide (ling, emo, speaker, mel) pair
- """
-
- def __init__(
- self,
- metafile,
- root_dir,
- config,
- lang_dir=None,
- allow_cache=False,
- ):
- self.config = config
- self.with_duration = True
- self.nsf_enable = self.config['Model']['KanTtsSAMBERT']['params'].get(
- 'NSF', False)
- self.fp_enable = self.config['Model']['KanTtsSAMBERT']['params'].get(
- 'FP', False)
-
- super().__init__(metafile, root_dir)
- self.allow_cache = allow_cache
-
- self.ling_unit = KanTtsLinguisticUnit(config, lang_dir)
- self.padder = Padder()
-
- self.r = self.config['Model']['KanTtsSAMBERT']['params'][
- 'outputs_per_step']
-
- if allow_cache:
- self.manager = Manager()
- self.caches = self.manager.list()
- self.caches += [() for _ in range(len(self.meta))]
-
- def __len__(self):
- return len(self.meta)
-
- def __getitem__(self, idx):
- if self.allow_cache and len(self.caches[idx]) != 0:
- return self.caches[idx]
-
- (
- ling_txt,
- mel_file,
- dur_file,
- f0_file,
- energy_file,
- frame_f0_file,
- frame_uv_file,
- aug_ling_txt,
- ) = self.meta[idx]
-
- ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
- mel_data = np.load(mel_file)
- dur_data = np.load(dur_file) if dur_file is not None else None
- f0_data = np.load(f0_file)
- energy_data = np.load(energy_file)
-
- # generate fp position label according to fpadd_meta
- if self.fp_enable and aug_ling_txt is not None:
- fp_label = get_fp_label(aug_ling_txt)
- else:
- fp_label = None
-
- if self.with_duration:
- attn_prior = None
- else:
- attn_prior = beta_binomial_prior_distribution(
- len(ling_data[0]), mel_data.shape[0])
-
- # Concat frame-level f0 and uv to mel_data
- if self.nsf_enable:
- frame_f0_data = np.load(frame_f0_file).reshape(-1, 1)
- frame_uv_data = np.load(frame_uv_file).reshape(-1, 1)
- mel_data = np.concatenate([mel_data, frame_f0_data, frame_uv_data],
- axis=1)
-
- if self.allow_cache:
- self.caches[idx] = (
- ling_data,
- mel_data,
- dur_data,
- f0_data,
- energy_data,
- attn_prior,
- fp_label,
- )
-
- return (
- ling_data,
- mel_data,
- dur_data,
- f0_data,
- energy_data,
- attn_prior,
- fp_label,
- )
-
- def load_meta(self, metafile, data_dir):
- with open(metafile, 'r') as f:
- lines = f.readlines()
-
- aug_ling_dict = {}
- if self.fp_enable:
- add_fp_metafile = metafile.replace('fprm', 'fpadd')
- with open(add_fp_metafile, 'r') as f:
- fpadd_lines = f.readlines()
- for line in fpadd_lines:
- index, aug_ling_txt = line.split('\t')
- aug_ling_dict[index] = aug_ling_txt
-
- mel_dir = os.path.join(data_dir, 'mel')
- dur_dir = os.path.join(data_dir, 'duration')
- f0_dir = os.path.join(data_dir, 'f0')
- energy_dir = os.path.join(data_dir, 'energy')
- frame_f0_dir = os.path.join(data_dir, 'frame_f0')
- frame_uv_dir = os.path.join(data_dir, 'frame_uv')
-
- self.with_duration = os.path.exists(dur_dir)
-
- items = []
- logging.info('Loading metafile...')
- for line in tqdm(lines):
- line = line.strip()
- index, ling_txt = line.split('\t')
- mel_file = os.path.join(mel_dir, index + '.npy')
- if self.with_duration:
- dur_file = os.path.join(dur_dir, index + '.npy')
- else:
- dur_file = None
- f0_file = os.path.join(f0_dir, index + '.npy')
- energy_file = os.path.join(energy_dir, index + '.npy')
- frame_f0_file = os.path.join(frame_f0_dir, index + '.npy')
- frame_uv_file = os.path.join(frame_uv_dir, index + '.npy')
- aug_ling_txt = aug_ling_dict.get(index, None)
- if self.fp_enable and aug_ling_txt is None:
- logging.warning(f'Missing fpadd meta for {index}')
- continue
-
- items.append((
- ling_txt,
- mel_file,
- dur_file,
- f0_file,
- energy_file,
- frame_f0_file,
- frame_uv_file,
- aug_ling_txt,
- ))
-
- return items
-
- def load_fpadd_meta(self, metafile):
- with open(metafile, 'r') as f:
- lines = f.readlines()
-
- items = []
- logging.info('Loading fpadd metafile...')
- for line in tqdm(lines):
- line = line.strip()
- index, ling_txt = line.split('\t')
-
- items.append((ling_txt, ))
-
- return items
-
- @staticmethod
- def gen_metafile(
- raw_meta_file,
- out_dir,
- train_meta_file,
- valid_meta_file,
- badlist=None,
- split_ratio=0.98,
- ):
- with open(raw_meta_file, 'r') as f:
- lines = f.readlines()
- frame_f0_dir = os.path.join(out_dir, 'frame_f0')
- frame_uv_dir = os.path.join(out_dir, 'frame_uv')
- mel_dir = os.path.join(out_dir, 'mel')
- duration_dir = os.path.join(out_dir, 'duration')
- random.seed(DATASET_RANDOM_SEED)
- random.shuffle(lines)
- num_train = int(len(lines) * split_ratio) - 1
- with open(train_meta_file, 'w') as f:
- for line in lines[:num_train]:
- index = line.split('\t')[0]
- if badlist is not None and index in badlist:
- continue
- if (not os.path.exists(
- os.path.join(frame_f0_dir, index + '.npy'))
- or not os.path.exists(
- os.path.join(frame_uv_dir, index + '.npy'))
- or not os.path.exists(
- os.path.join(mel_dir, index + '.npy'))):
- continue
- if os.path.exists(duration_dir) and not os.path.exists(
- os.path.join(duration_dir, index + '.npy')):
- continue
- f.write(line)
-
- with open(valid_meta_file, 'w') as f:
- for line in lines[num_train:]:
- index = line.split('\t')[0]
- if badlist is not None and index in badlist:
- continue
- if (not os.path.exists(
- os.path.join(frame_f0_dir, index + '.npy'))
- or not os.path.exists(
- os.path.join(frame_uv_dir, index + '.npy'))
- or not os.path.exists(
- os.path.join(mel_dir, index + '.npy'))):
- continue
- if os.path.exists(duration_dir) and not os.path.exists(
- os.path.join(duration_dir, index + '.npy')):
- continue
- f.write(line)
-
- def collate_fn(self, batch):
- data_dict = {}
-
- max_input_length = max((len(x[0][0]) for x in batch))
- if self.with_duration:
- max_dur_length = max((x[2].shape[0] for x in batch)) + 1
-
- lfeat_type_index = 0
- lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
- if self.ling_unit.using_byte():
- # for byte-based model only
- inputs_byte_index = self.padder._prepare_scalar_inputs(
- [x[0][lfeat_type_index] for x in batch],
- max_input_length,
- self.ling_unit._sub_unit_pad[lfeat_type],
- ).long()
-
- data_dict['input_lings'] = torch.stack([inputs_byte_index], dim=2)
- else:
- # pure linguistic info: sy|tone|syllable_flag|word_segment
- # sy
- inputs_sy = self.padder._prepare_scalar_inputs(
- [x[0][lfeat_type_index] for x in batch],
- max_input_length,
- self.ling_unit._sub_unit_pad[lfeat_type],
- ).long()
-
- # tone
- lfeat_type_index = lfeat_type_index + 1
- lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
- inputs_tone = self.padder._prepare_scalar_inputs(
- [x[0][lfeat_type_index] for x in batch],
- max_input_length,
- self.ling_unit._sub_unit_pad[lfeat_type],
- ).long()
-
- # syllable_flag
- lfeat_type_index = lfeat_type_index + 1
- lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
- inputs_syllable_flag = self.padder._prepare_scalar_inputs(
- [x[0][lfeat_type_index] for x in batch],
- max_input_length,
- self.ling_unit._sub_unit_pad[lfeat_type],
- ).long()
-
- # word_segment
- lfeat_type_index = lfeat_type_index + 1
- lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
- inputs_ws = self.padder._prepare_scalar_inputs(
- [x[0][lfeat_type_index] for x in batch],
- max_input_length,
- self.ling_unit._sub_unit_pad[lfeat_type],
- ).long()
-
- data_dict['input_lings'] = torch.stack(
- [inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws],
- dim=2)
-
- # emotion category
- lfeat_type_index = lfeat_type_index + 1
- lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
- data_dict['input_emotions'] = self.padder._prepare_scalar_inputs(
- [x[0][lfeat_type_index] for x in batch],
- max_input_length,
- self.ling_unit._sub_unit_pad[lfeat_type],
- ).long()
-
- # speaker category
- lfeat_type_index = lfeat_type_index + 1
- lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
- data_dict['input_speakers'] = self.padder._prepare_scalar_inputs(
- [x[0][lfeat_type_index] for x in batch],
- max_input_length,
- self.ling_unit._sub_unit_pad[lfeat_type],
- ).long()
-
- # fp label category
- if self.fp_enable:
- data_dict['fp_label'] = self.padder._prepare_scalar_inputs(
- [x[6] for x in batch],
- max_input_length,
- 0,
- ).long()
-
-        # The input symbol sequence gets a trailing '~' appended, which would
-        # skew the duration computation, so subtract 1 from the length.
-        data_dict['valid_input_lengths'] = torch.as_tensor(
-            [len(x[0][0]) - 1 for x in batch], dtype=torch.long)
- data_dict['valid_output_lengths'] = torch.as_tensor(
- [len(x[1]) for x in batch], dtype=torch.long)
-
- max_output_length = torch.max(data_dict['valid_output_lengths']).item()
- max_output_round_length = self.padder._round_up(
- max_output_length, self.r)
-
- data_dict['mel_targets'] = self.padder._prepare_targets(
- [x[1] for x in batch], max_output_round_length, 0.0)
- if self.with_duration:
- data_dict['durations'] = self.padder._prepare_durations(
- [x[2] for x in batch], max_dur_length, max_output_round_length)
- else:
- data_dict['durations'] = None
-
- if self.with_duration:
- if self.fp_enable:
- feats_padding_length = max_dur_length
- else:
- feats_padding_length = max_input_length
- else:
- feats_padding_length = max_output_round_length
-
- data_dict['pitch_contours'] = self.padder._prepare_scalar_inputs(
- [x[3] for x in batch], feats_padding_length, 0.0).float()
- data_dict['energy_contours'] = self.padder._prepare_scalar_inputs(
- [x[4] for x in batch], feats_padding_length, 0.0).float()
-
- if self.with_duration:
- data_dict['attn_priors'] = None
- else:
- data_dict['attn_priors'] = torch.zeros(
- len(batch), max_output_round_length, max_input_length)
- for i in range(len(batch)):
- attn_prior = batch[i][5]
- data_dict['attn_priors'][
- i, :attn_prior.shape[0], :attn_prior.shape[1]] = attn_prior
- return data_dict
-
-
-def get_am_datasets(
- metafile,
- root_dir,
- lang_dir,
- config,
- allow_cache,
- split_ratio=0.98,
-):
- if not isinstance(root_dir, list):
- root_dir = [root_dir]
- if not isinstance(metafile, list):
- metafile = [metafile]
-
- train_meta_lst = []
- valid_meta_lst = []
-
- fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False)
-
- if fp_enable:
- am_train_fn = 'am_fprm_train.lst'
- am_valid_fn = 'am_fprm_valid.lst'
- else:
- am_train_fn = 'am_train.lst'
- am_valid_fn = 'am_valid.lst'
-
- for raw_metafile, data_dir in zip(metafile, root_dir):
- train_meta = os.path.join(data_dir, am_train_fn)
- valid_meta = os.path.join(data_dir, am_valid_fn)
- if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
-            AmDataset.gen_metafile(raw_metafile, data_dir, train_meta,
-                                   valid_meta, split_ratio=split_ratio)
- train_meta_lst.append(train_meta)
- valid_meta_lst.append(valid_meta)
-
- train_dataset = AmDataset(train_meta_lst, root_dir, config, lang_dir,
- allow_cache)
-
- valid_dataset = AmDataset(valid_meta_lst, root_dir, config, lang_dir,
- allow_cache)
-
- return train_dataset, valid_dataset
-
-
-class MaskingActor(object):
-
- def __init__(self, mask_ratio=0.15):
- super(MaskingActor, self).__init__()
- self.mask_ratio = mask_ratio
-
- def _get_random_mask(self, length, p1=0.15):
-        # Bernoulli(p1) selection mask: 1 marks a position chosen for masking.
-        mask = (np.random.uniform(0, 1, length) < p1).astype(np.float64)
-
-        return mask
-
- def _input_bert_masking(
- self,
- sequence_array,
- nb_symbol_category,
- mask_symbol_id,
- mask,
- p2=0.8,
- p3=0.1,
- p4=0.1,
- ):
- sequence_array_mask = sequence_array.copy()
- mask_id = np.where(mask == 1)[0]
- mask_len = len(mask_id)
- rand = np.arange(mask_len)
- np.random.shuffle(rand)
-
- # [MASK]
- mask_id_p2 = mask_id[rand[0:int(math.floor(mask_len * p2))]]
- if len(mask_id_p2) > 0:
- sequence_array_mask[mask_id_p2] = mask_symbol_id
-
- # rand
- mask_id_p3 = mask_id[
- rand[int(math.floor(mask_len * p2)):int(math.floor(mask_len * p2))
- + int(math.floor(mask_len * p3))]]
- if len(mask_id_p3) > 0:
- sequence_array_mask[mask_id_p3] = random.randint(
- 0, nb_symbol_category - 1)
-
- # ori
- # do nothing
-
- return sequence_array_mask
-
-
-class BERTTextDataset(torch.utils.data.Dataset):
- """
-    Provide (ling, ling_sy_masked, bert_mask) tuples.
- """
-
- def __init__(
- self,
- config,
- metafile,
- root_dir,
- lang_dir=None,
- allow_cache=False,
- ):
- self.meta = []
- self.config = config
-
- if not isinstance(metafile, list):
- metafile = [metafile]
- if not isinstance(root_dir, list):
- root_dir = [root_dir]
-
- for meta_file, data_dir in zip(metafile, root_dir):
- if not os.path.exists(meta_file):
- logging.error('meta file not found: {}'.format(meta_file))
- raise ValueError(
- '[BERT_Text_Dataset] meta file: {} not found'.format(
- meta_file))
- if not os.path.exists(data_dir):
- logging.error('data dir not found: {}'.format(data_dir))
- raise ValueError(
- '[BERT_Text_Dataset] data dir: {} not found'.format(
- data_dir))
- self.meta.extend(self.load_meta(meta_file, data_dir))
-
- self.allow_cache = allow_cache
-
- self.ling_unit = KanTtsLinguisticUnit(config, lang_dir)
- self.padder = Padder()
- self.masking_actor = MaskingActor(
- self.config['Model']['KanTtsTextsyBERT']['params']['mask_ratio'])
-
- if allow_cache:
- self.manager = Manager()
- self.caches = self.manager.list()
- self.caches += [() for _ in range(len(self.meta))]
-
- def __len__(self):
- return len(self.meta)
-
- def __getitem__(self, idx):
- if self.allow_cache and len(self.caches[idx]) != 0:
- ling_data = self.caches[idx][0]
- bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
- return (ling_data, ling_sy_masked_data, bert_mask)
-
- ling_txt = self.meta[idx]
-
- ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
- bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
-
- if self.allow_cache:
- self.caches[idx] = (ling_data, )
-
- return (ling_data, ling_sy_masked_data, bert_mask)
-
- def load_meta(self, metafile, data_dir):
- with open(metafile, 'r') as f:
- lines = f.readlines()
-
- items = []
- logging.info('Loading metafile...')
- for line in tqdm(lines):
- line = line.strip()
- index, ling_txt = line.split('\t')
-
-            items.append(ling_txt)
-
- return items
-
- @staticmethod
- def gen_metafile(raw_meta_file, out_dir, split_ratio=0.98):
- with open(raw_meta_file, 'r') as f:
- lines = f.readlines()
- random.seed(DATASET_RANDOM_SEED)
- random.shuffle(lines)
- num_train = int(len(lines) * split_ratio) - 1
- with open(os.path.join(out_dir, 'bert_train.lst'), 'w') as f:
- for line in lines[:num_train]:
- f.write(line)
-
- with open(os.path.join(out_dir, 'bert_valid.lst'), 'w') as f:
- for line in lines[num_train:]:
- f.write(line)
-
- def bert_masking(self, ling_data):
- length = len(ling_data[0])
- mask = self.masking_actor._get_random_mask(
- length, p1=self.masking_actor.mask_ratio)
- mask[-1] = 0
-
- # sy_masked
- sy_mask_symbol_id = self.ling_unit.encode_sy([self.ling_unit._mask])[0]
- ling_sy_masked_data = self.masking_actor._input_bert_masking(
- ling_data[0],
- self.ling_unit.get_unit_size()['sy'],
- sy_mask_symbol_id,
- mask,
- p2=0.8,
- p3=0.1,
- p4=0.1,
- )
-
- return (mask, ling_sy_masked_data)
-
- def collate_fn(self, batch):
- data_dict = {}
-
- max_input_length = max((len(x[0][0]) for x in batch))
-
- # pure linguistic info: sy|tone|syllable_flag|word_segment
- # sy
- lfeat_type = self.ling_unit._lfeat_type_list[0]
- targets_sy = self.padder._prepare_scalar_inputs(
- [x[0][0] for x in batch],
- max_input_length,
- self.ling_unit._sub_unit_pad[lfeat_type],
- ).long()
- # sy masked
- inputs_sy = self.padder._prepare_scalar_inputs(
- [x[1] for x in batch],
- max_input_length,
- self.ling_unit._sub_unit_pad[lfeat_type],
- ).long()
- # tone
- lfeat_type = self.ling_unit._lfeat_type_list[1]
- inputs_tone = self.padder._prepare_scalar_inputs(
- [x[0][1] for x in batch],
- max_input_length,
- self.ling_unit._sub_unit_pad[lfeat_type],
- ).long()
-
- # syllable_flag
- lfeat_type = self.ling_unit._lfeat_type_list[2]
- inputs_syllable_flag = self.padder._prepare_scalar_inputs(
- [x[0][2] for x in batch],
- max_input_length,
- self.ling_unit._sub_unit_pad[lfeat_type],
- ).long()
-
- # word_segment
- lfeat_type = self.ling_unit._lfeat_type_list[3]
- inputs_ws = self.padder._prepare_scalar_inputs(
- [x[0][3] for x in batch],
- max_input_length,
- self.ling_unit._sub_unit_pad[lfeat_type],
- ).long()
-
- data_dict['input_lings'] = torch.stack(
- [inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2)
-        # The input symbol sequence gets a trailing '~' appended, which would
-        # skew the duration computation, so subtract 1 from the length.
-        data_dict['valid_input_lengths'] = torch.as_tensor(
-            [len(x[0][0]) - 1 for x in batch], dtype=torch.long)
-
- data_dict['targets'] = targets_sy
- data_dict['bert_masks'] = self.padder._prepare_scalar_inputs(
- [x[2] for x in batch], max_input_length, 0.0)
-
- return data_dict
-
-
-def get_bert_text_datasets(
- metafile,
- root_dir,
- config,
- allow_cache,
- split_ratio=0.98,
-):
- if not isinstance(root_dir, list):
- root_dir = [root_dir]
- if not isinstance(metafile, list):
- metafile = [metafile]
-
- train_meta_lst = []
- valid_meta_lst = []
-
- for raw_metafile, data_dir in zip(metafile, root_dir):
- train_meta = os.path.join(data_dir, 'bert_train.lst')
- valid_meta = os.path.join(data_dir, 'bert_valid.lst')
- if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
- BERTTextDataset.gen_metafile(raw_metafile, data_dir, split_ratio)
- train_meta_lst.append(train_meta)
- valid_meta_lst.append(valid_meta)
-
-    train_dataset = BERTTextDataset(config, train_meta_lst, root_dir,
-                                    allow_cache=allow_cache)
-
-    valid_dataset = BERTTextDataset(config, valid_meta_lst, root_dir,
-                                    allow_cache=allow_cache)
-
- return train_dataset, valid_dataset
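Note: the masking pipeline above (MaskingActor plus BERTTextDataset.bert_masking) is the standard BERT recipe: roughly mask_ratio of the positions are selected, and of those about p2=80% are replaced by the [MASK] symbol, p3=10% by a single random symbol, and the rest are left unchanged. Below is a minimal, self-contained sketch of that recipe; the function name, the toy vocabulary and the fixed seed are illustrative only, not part of the deleted module.

import numpy as np

def bert_style_mask(sequence, vocab_size, mask_id, p1=0.15, p2=0.8, p3=0.1, seed=0):
    """Illustrative helper: return (mask, masked_sequence) under the 80/10/10 scheme."""
    rng = np.random.default_rng(seed)
    sequence = np.asarray(sequence)
    # Step 1: Bernoulli(p1) selection of positions to corrupt.
    mask = (rng.uniform(0, 1, len(sequence)) < p1).astype(np.int64)
    chosen = np.flatnonzero(mask)
    rng.shuffle(chosen)
    masked = sequence.copy()
    n_mask = int(np.floor(len(chosen) * p2))   # -> replaced by [MASK]
    n_rand = int(np.floor(len(chosen) * p3))   # -> replaced by one random symbol
    masked[chosen[:n_mask]] = mask_id
    masked[chosen[n_mask:n_mask + n_rand]] = rng.integers(0, vocab_size)
    # The remaining selected positions keep their original symbol.
    return mask, masked

mask, masked = bert_style_mask(np.arange(20), vocab_size=50, mask_id=49)
print(mask)
print(masked)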
diff --git a/modelscope/models/audio/tts/kantts/models/__init__.py b/modelscope/models/audio/tts/kantts/models/__init__.py
deleted file mode 100644
index 682f1865..00000000
--- a/modelscope/models/audio/tts/kantts/models/__init__.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import torch
-from torch.nn.parallel import DistributedDataParallel
-
-import modelscope.models.audio.tts.kantts.train.scheduler as kantts_scheduler
-from modelscope.models.audio.tts.kantts.utils.ling_unit.ling_unit import \
- get_fpdict
-from .hifigan import (Generator, MultiPeriodDiscriminator,
- MultiScaleDiscriminator, MultiSpecDiscriminator)
-from .pqmf import PQMF
-from .sambert.kantts_sambert import KanTtsSAMBERT, KanTtsTextsyBERT
-
-
-def optimizer_builder(model_params, opt_name, opt_params):
- opt_cls = getattr(torch.optim, opt_name)
- optimizer = opt_cls(model_params, **opt_params)
- return optimizer
-
-
-def scheduler_builder(optimizer, sche_name, sche_params):
- scheduler_cls = getattr(kantts_scheduler, sche_name)
- scheduler = scheduler_cls(optimizer, **sche_params)
- return scheduler
-
-
-def hifigan_model_builder(config, device, rank, distributed):
- model = {}
- optimizer = {}
- scheduler = {}
- model['discriminator'] = {}
- optimizer['discriminator'] = {}
- scheduler['discriminator'] = {}
- for model_name in config['Model'].keys():
- if model_name == 'Generator':
- params = config['Model'][model_name]['params']
- model['generator'] = Generator(**params).to(device)
- optimizer['generator'] = optimizer_builder(
- model['generator'].parameters(),
- config['Model'][model_name]['optimizer'].get('type', 'Adam'),
- config['Model'][model_name]['optimizer'].get('params', {}),
- )
- scheduler['generator'] = scheduler_builder(
- optimizer['generator'],
- config['Model'][model_name]['scheduler'].get('type', 'StepLR'),
- config['Model'][model_name]['scheduler'].get('params', {}),
- )
- else:
- params = config['Model'][model_name]['params']
- model['discriminator'][model_name] = globals()[model_name](
- **params).to(device)
- optimizer['discriminator'][model_name] = optimizer_builder(
- model['discriminator'][model_name].parameters(),
- config['Model'][model_name]['optimizer'].get('type', 'Adam'),
- config['Model'][model_name]['optimizer'].get('params', {}),
- )
- scheduler['discriminator'][model_name] = scheduler_builder(
- optimizer['discriminator'][model_name],
- config['Model'][model_name]['scheduler'].get('type', 'StepLR'),
- config['Model'][model_name]['scheduler'].get('params', {}),
- )
-
- out_channels = config['Model']['Generator']['params']['out_channels']
- if out_channels > 1:
- model['pqmf'] = PQMF(
- subbands=out_channels, **config.get('pqmf', {})).to(device)
-
- # FIXME: pywavelets buffer leads to gradient error in DDP training
- # Solution: https://github.com/pytorch/pytorch/issues/22095
- if distributed:
- model['generator'] = DistributedDataParallel(
- model['generator'],
- device_ids=[rank],
- output_device=rank,
- broadcast_buffers=False,
- )
- for model_name in model['discriminator'].keys():
- model['discriminator'][model_name] = DistributedDataParallel(
- model['discriminator'][model_name],
- device_ids=[rank],
- output_device=rank,
- broadcast_buffers=False,
- )
-
- return model, optimizer, scheduler
-
-
-def sambert_model_builder(config, device, rank, distributed):
- model = {}
- optimizer = {}
- scheduler = {}
-
- model['KanTtsSAMBERT'] = KanTtsSAMBERT(
- config['Model']['KanTtsSAMBERT']['params']).to(device)
-
- fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False)
- if fp_enable:
- fp_dict = {
- k: torch.from_numpy(v).long().unsqueeze(0).to(device)
- for k, v in get_fpdict(config).items()
- }
- model['KanTtsSAMBERT'].fp_dict = fp_dict
-
- optimizer['KanTtsSAMBERT'] = optimizer_builder(
- model['KanTtsSAMBERT'].parameters(),
- config['Model']['KanTtsSAMBERT']['optimizer'].get('type', 'Adam'),
- config['Model']['KanTtsSAMBERT']['optimizer'].get('params', {}),
- )
- scheduler['KanTtsSAMBERT'] = scheduler_builder(
- optimizer['KanTtsSAMBERT'],
- config['Model']['KanTtsSAMBERT']['scheduler'].get('type', 'StepLR'),
- config['Model']['KanTtsSAMBERT']['scheduler'].get('params', {}),
- )
-
- if distributed:
- model['KanTtsSAMBERT'] = DistributedDataParallel(
- model['KanTtsSAMBERT'], device_ids=[rank], output_device=rank)
-
- return model, optimizer, scheduler
-
-
-def sybert_model_builder(config, device, rank, distributed):
- model = {}
- optimizer = {}
- scheduler = {}
-
- model['KanTtsTextsyBERT'] = KanTtsTextsyBERT(
- config['Model']['KanTtsTextsyBERT']['params']).to(device)
- optimizer['KanTtsTextsyBERT'] = optimizer_builder(
- model['KanTtsTextsyBERT'].parameters(),
- config['Model']['KanTtsTextsyBERT']['optimizer'].get('type', 'Adam'),
- config['Model']['KanTtsTextsyBERT']['optimizer'].get('params', {}),
- )
- scheduler['KanTtsTextsyBERT'] = scheduler_builder(
- optimizer['KanTtsTextsyBERT'],
- config['Model']['KanTtsTextsyBERT']['scheduler'].get('type', 'StepLR'),
- config['Model']['KanTtsTextsyBERT']['scheduler'].get('params', {}),
- )
-
- if distributed:
- model['KanTtsTextsyBERT'] = DistributedDataParallel(
- model['KanTtsTextsyBERT'], device_ids=[rank], output_device=rank)
-
- return model, optimizer, scheduler
-
-
-model_dict = {
- 'hifigan': hifigan_model_builder,
- 'sambert': sambert_model_builder,
- 'sybert': sybert_model_builder,
-}
-
-
-def model_builder(config, device='cpu', rank=0, distributed=False):
- builder_func = model_dict[config['model_type']]
- model, optimizer, scheduler = builder_func(config, device, rank,
- distributed)
- return model, optimizer, scheduler
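Note: optimizer_builder and scheduler_builder above are thin name-to-class dispatchers (getattr on torch.optim and on the kantts scheduler module, respectively). The sketch below reproduces the same pattern with stock PyTorch only; the toy nn.Linear model, the hyperparameters, and the use of torch.optim.lr_scheduler.StepLR in place of the kantts scheduler module are assumptions for illustration.

import torch
import torch.nn as nn

def optimizer_builder(model_params, opt_name, opt_params):
    # Resolve the optimizer class by name, e.g. 'Adam' -> torch.optim.Adam.
    return getattr(torch.optim, opt_name)(model_params, **opt_params)

net = nn.Linear(80, 80)                        # toy stand-in for a KanTTS sub-model
optimizer = optimizer_builder(net.parameters(), 'Adam', {'lr': 1e-4})
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)

loss = net(torch.randn(4, 80)).pow(2).mean()   # dummy training step
loss.backward()
optimizer.step()
scheduler.step()
print(scheduler.get_last_lr())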
diff --git a/modelscope/models/audio/tts/kantts/models/hifigan/__init__.py b/modelscope/models/audio/tts/kantts/models/hifigan/__init__.py
deleted file mode 100644
index 8c4f466e..00000000
--- a/modelscope/models/audio/tts/kantts/models/hifigan/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-from .hifigan import (Generator, MultiPeriodDiscriminator,
- MultiScaleDiscriminator, MultiSpecDiscriminator)
diff --git a/modelscope/models/audio/tts/kantts/models/hifigan/hifigan.py b/modelscope/models/audio/tts/kantts/models/hifigan/hifigan.py
deleted file mode 100644
index c21e6714..00000000
--- a/modelscope/models/audio/tts/kantts/models/hifigan/hifigan.py
+++ /dev/null
@@ -1,613 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import copy
-from distutils.version import LooseVersion
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from pytorch_wavelets import DWT1DForward
-from torch.nn.utils import spectral_norm, weight_norm
-
-from modelscope.models.audio.tts.kantts.utils.audio_torch import stft
-from .layers import (CausalConv1d, CausalConvTranspose1d, Conv1d,
- ConvTranspose1d, ResidualBlock, SourceModule)
-
-is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
-
-
-class Generator(torch.nn.Module):
-
- def __init__(
- self,
- in_channels=80,
- out_channels=1,
- channels=512,
- kernel_size=7,
- upsample_scales=(8, 8, 2, 2),
- upsample_kernal_sizes=(16, 16, 4, 4),
- resblock_kernel_sizes=(3, 7, 11),
- resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)],
- repeat_upsample=True,
- bias=True,
- causal=True,
- nonlinear_activation='LeakyReLU',
- nonlinear_activation_params={'negative_slope': 0.1},
- use_weight_norm=True,
- nsf_params=None,
- ):
- super(Generator, self).__init__()
-
- # check hyperparameters are valid
-        assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
- assert len(upsample_scales) == len(upsample_kernal_sizes)
- assert len(resblock_dilations) == len(resblock_kernel_sizes)
-
- self.upsample_scales = upsample_scales
- self.repeat_upsample = repeat_upsample
- self.num_upsamples = len(upsample_kernal_sizes)
- self.num_kernels = len(resblock_kernel_sizes)
- self.out_channels = out_channels
- self.nsf_enable = nsf_params is not None
-
- self.transpose_upsamples = torch.nn.ModuleList()
- self.repeat_upsamples = torch.nn.ModuleList() # for repeat upsampling
- self.conv_blocks = torch.nn.ModuleList()
-
- conv_cls = CausalConv1d if causal else Conv1d
- conv_transposed_cls = CausalConvTranspose1d if causal else ConvTranspose1d
-
- self.conv_pre = conv_cls(
- in_channels,
- channels,
- kernel_size,
- 1,
- padding=(kernel_size - 1) // 2)
-
- for i in range(len(upsample_kernal_sizes)):
- self.transpose_upsamples.append(
- torch.nn.Sequential(
- getattr(
- torch.nn,
- nonlinear_activation)(**nonlinear_activation_params),
- conv_transposed_cls(
- channels // (2**i),
- channels // (2**(i + 1)),
- upsample_kernal_sizes[i],
- upsample_scales[i],
- padding=(upsample_kernal_sizes[i] - upsample_scales[i])
- // 2,
- ),
- ))
-
- if repeat_upsample:
- self.repeat_upsamples.append(
- nn.Sequential(
- nn.Upsample(
- mode='nearest', scale_factor=upsample_scales[i]),
- getattr(torch.nn, nonlinear_activation)(
- **nonlinear_activation_params),
- conv_cls(
- channels // (2**i),
- channels // (2**(i + 1)),
- kernel_size=kernel_size,
- stride=1,
- padding=(kernel_size - 1) // 2,
- ),
- ))
-
- for j in range(len(resblock_kernel_sizes)):
- self.conv_blocks.append(
- ResidualBlock(
- channels=channels // (2**(i + 1)),
- kernel_size=resblock_kernel_sizes[j],
- dilation=resblock_dilations[j],
- nonlinear_activation=nonlinear_activation,
- nonlinear_activation_params=nonlinear_activation_params,
- causal=causal,
- ))
-
- self.conv_post = conv_cls(
- channels // (2**(i + 1)),
- out_channels,
- kernel_size,
- 1,
- padding=(kernel_size - 1) // 2,
- )
-
- if self.nsf_enable:
- self.source_module = SourceModule(
- nb_harmonics=nsf_params['nb_harmonics'],
- upsample_ratio=np.cumprod(self.upsample_scales)[-1],
- sampling_rate=nsf_params['sampling_rate'],
- )
- self.source_downs = nn.ModuleList()
-            self.downsample_rates = [1] + list(self.upsample_scales[::-1][:-1])
- self.downsample_cum_rates = np.cumprod(self.downsample_rates)
-
- for i, u in enumerate(self.downsample_cum_rates[::-1]):
- if u == 1:
- self.source_downs.append(
- Conv1d(1, channels // (2**(i + 1)), 1, 1))
- else:
- self.source_downs.append(
- conv_cls(
- 1,
- channels // (2**(i + 1)),
- u * 2,
- u,
- padding=u // 2,
- ))
-
- def forward(self, x):
- if self.nsf_enable:
- mel = x[:, :-2, :]
- pitch = x[:, -2:-1, :]
- uv = x[:, -1:, :]
- excitation = self.source_module(pitch, uv)
- else:
- mel = x
-
- x = self.conv_pre(mel)
- for i in range(self.num_upsamples):
- # FIXME: sin function here seems to be causing issues
- x = torch.sin(x) + x
- rep = self.repeat_upsamples[i](x)
-
- if self.nsf_enable:
- # Downsampling the excitation signal
- e = self.source_downs[i](excitation)
- # augment inputs with the excitation
- x = rep + e
- else:
- # transconv
- up = self.transpose_upsamples[i](x)
- x = rep + up[:, :, :rep.shape[-1]]
-
- xs = None
- for j in range(self.num_kernels):
- if xs is None:
- xs = self.conv_blocks[i * self.num_kernels + j](x)
- else:
- xs += self.conv_blocks[i * self.num_kernels + j](x)
- x = xs / self.num_kernels
-
- x = F.leaky_relu(x)
- x = self.conv_post(x)
- x = torch.tanh(x)
-
- return x
-
- def remove_weight_norm(self):
- print('Removing weight norm...')
- for layer in self.transpose_upsamples:
- layer[-1].remove_weight_norm()
- for layer in self.repeat_upsamples:
- layer[-1].remove_weight_norm()
- for layer in self.conv_blocks:
- layer.remove_weight_norm()
- self.conv_pre.remove_weight_norm()
- self.conv_post.remove_weight_norm()
- if self.nsf_enable:
- self.source_module.remove_weight_norm()
- for layer in self.source_downs:
- layer.remove_weight_norm()
-
-
-class PeriodDiscriminator(torch.nn.Module):
-
- def __init__(
- self,
- in_channels=1,
- out_channels=1,
- period=3,
- kernel_sizes=[5, 3],
- channels=32,
- downsample_scales=[3, 3, 3, 3, 1],
- max_downsample_channels=1024,
- bias=True,
- nonlinear_activation='LeakyReLU',
- nonlinear_activation_params={'negative_slope': 0.1},
- use_spectral_norm=False,
- ):
- super(PeriodDiscriminator, self).__init__()
- self.period = period
- norm_f = weight_norm if not use_spectral_norm else spectral_norm
- self.convs = nn.ModuleList()
- in_chs, out_chs = in_channels, channels
-
- for downsample_scale in downsample_scales:
- self.convs.append(
- torch.nn.Sequential(
- norm_f(
- nn.Conv2d(
- in_chs,
- out_chs,
- (kernel_sizes[0], 1),
- (downsample_scale, 1),
- padding=((kernel_sizes[0] - 1) // 2, 0),
- )),
- getattr(
- torch.nn,
- nonlinear_activation)(**nonlinear_activation_params),
- ))
- in_chs = out_chs
- out_chs = min(out_chs * 4, max_downsample_channels)
-
- self.conv_post = nn.Conv2d(
- out_chs,
- out_channels,
- (kernel_sizes[1] - 1, 1),
- 1,
- padding=((kernel_sizes[1] - 1) // 2, 0),
- )
-
- def forward(self, x):
- fmap = []
-
- # 1d to 2d
- b, c, t = x.shape
- if t % self.period != 0: # pad first
- n_pad = self.period - (t % self.period)
- x = F.pad(x, (0, n_pad), 'reflect')
- t = t + n_pad
- x = x.view(b, c, t // self.period, self.period)
-
- for layer in self.convs:
- x = layer(x)
- fmap.append(x)
- x = self.conv_post(x)
- fmap.append(x)
- x = torch.flatten(x, 1, -1)
-
- return x, fmap
-
-
-class MultiPeriodDiscriminator(torch.nn.Module):
-
- def __init__(
- self,
- periods=[2, 3, 5, 7, 11],
- discriminator_params={
- 'in_channels': 1,
- 'out_channels': 1,
- 'kernel_sizes': [5, 3],
- 'channels': 32,
- 'downsample_scales': [3, 3, 3, 3, 1],
- 'max_downsample_channels': 1024,
- 'bias': True,
- 'nonlinear_activation': 'LeakyReLU',
- 'nonlinear_activation_params': {
- 'negative_slope': 0.1
- },
- 'use_spectral_norm': False,
- },
- ):
- super(MultiPeriodDiscriminator, self).__init__()
- self.discriminators = nn.ModuleList()
- for period in periods:
- params = copy.deepcopy(discriminator_params)
- params['period'] = period
- self.discriminators += [PeriodDiscriminator(**params)]
-
- def forward(self, y):
- y_d_rs = []
- fmap_rs = []
- for i, d in enumerate(self.discriminators):
- y_d_r, fmap_r = d(y)
- y_d_rs.append(y_d_r)
- fmap_rs.append(fmap_r)
-
- return y_d_rs, fmap_rs
-
-
-class ScaleDiscriminator(torch.nn.Module):
-
- def __init__(
- self,
- in_channels=1,
- out_channels=1,
- kernel_sizes=[15, 41, 5, 3],
- channels=128,
- max_downsample_channels=1024,
- max_groups=16,
- bias=True,
- downsample_scales=[2, 2, 4, 4, 1],
- nonlinear_activation='LeakyReLU',
- nonlinear_activation_params={'negative_slope': 0.1},
- use_spectral_norm=False,
- ):
- super(ScaleDiscriminator, self).__init__()
- norm_f = weight_norm if not use_spectral_norm else spectral_norm
-
- assert len(kernel_sizes) == 4
- for ks in kernel_sizes:
- assert ks % 2 == 1
-
- self.convs = nn.ModuleList()
-
- self.convs.append(
- torch.nn.Sequential(
- norm_f(
- nn.Conv1d(
- in_channels,
- channels,
- kernel_sizes[0],
- bias=bias,
- padding=(kernel_sizes[0] - 1) // 2,
- )),
- getattr(torch.nn,
- nonlinear_activation)(**nonlinear_activation_params),
- ))
- in_chs = channels
- out_chs = channels
- groups = 4
-
- for downsample_scale in downsample_scales:
- self.convs.append(
- torch.nn.Sequential(
- norm_f(
- nn.Conv1d(
- in_chs,
- out_chs,
- kernel_size=kernel_sizes[1],
- stride=downsample_scale,
- padding=(kernel_sizes[1] - 1) // 2,
- groups=groups,
- bias=bias,
- )),
- getattr(
- torch.nn,
- nonlinear_activation)(**nonlinear_activation_params),
- ))
- in_chs = out_chs
- out_chs = min(in_chs * 2, max_downsample_channels)
- groups = min(groups * 4, max_groups)
-
- out_chs = min(in_chs * 2, max_downsample_channels)
- self.convs.append(
- torch.nn.Sequential(
- norm_f(
- nn.Conv1d(
- in_chs,
- out_chs,
- kernel_size=kernel_sizes[2],
- stride=1,
- padding=(kernel_sizes[2] - 1) // 2,
- bias=bias,
- )),
- getattr(torch.nn,
- nonlinear_activation)(**nonlinear_activation_params),
- ))
-
- self.conv_post = norm_f(
- nn.Conv1d(
- out_chs,
- out_channels,
- kernel_size=kernel_sizes[3],
- stride=1,
- padding=(kernel_sizes[3] - 1) // 2,
- bias=bias,
- ))
-
- def forward(self, x):
- fmap = []
- for layer in self.convs:
- x = layer(x)
- fmap.append(x)
- x = self.conv_post(x)
- fmap.append(x)
- x = torch.flatten(x, 1, -1)
-
- return x, fmap
-
-
-class MultiScaleDiscriminator(torch.nn.Module):
-
- def __init__(
- self,
- scales=3,
- downsample_pooling='DWT',
- # follow the official implementation setting
- downsample_pooling_params={
- 'kernel_size': 4,
- 'stride': 2,
- 'padding': 2,
- },
- discriminator_params={
- 'in_channels': 1,
- 'out_channels': 1,
- 'kernel_sizes': [15, 41, 5, 3],
- 'channels': 128,
- 'max_downsample_channels': 1024,
- 'max_groups': 16,
- 'bias': True,
- 'downsample_scales': [2, 2, 4, 4, 1],
- 'nonlinear_activation': 'LeakyReLU',
- 'nonlinear_activation_params': {
- 'negative_slope': 0.1
- },
- },
- follow_official_norm=False,
- ):
- super(MultiScaleDiscriminator, self).__init__()
- self.discriminators = torch.nn.ModuleList()
-
- # add discriminators
- for i in range(scales):
- params = copy.deepcopy(discriminator_params)
- if follow_official_norm:
- params['use_spectral_norm'] = True if i == 0 else False
- self.discriminators += [ScaleDiscriminator(**params)]
-
- if downsample_pooling == 'DWT':
- self.meanpools = nn.ModuleList(
- [DWT1DForward(wave='db3', J=1),
- DWT1DForward(wave='db3', J=1)])
- self.aux_convs = nn.ModuleList([
- weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
- weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
- ])
- else:
- self.meanpools = nn.ModuleList(
- [nn.AvgPool1d(4, 2, padding=2),
- nn.AvgPool1d(4, 2, padding=2)])
- self.aux_convs = None
-
- def forward(self, y):
- y_d_rs = []
- fmap_rs = []
- for i, d in enumerate(self.discriminators):
- if i != 0:
- if self.aux_convs is None:
- y = self.meanpools[i - 1](y)
- else:
- yl, yh = self.meanpools[i - 1](y)
- y = torch.cat([yl, yh[0]], dim=1)
- y = self.aux_convs[i - 1](y)
- y = F.leaky_relu(y, 0.1)
-
- y_d_r, fmap_r = d(y)
- y_d_rs.append(y_d_r)
- fmap_rs.append(fmap_r)
-
- return y_d_rs, fmap_rs
-
-
-class SpecDiscriminator(torch.nn.Module):
-
- def __init__(
- self,
- channels=32,
- init_kernel=15,
- kernel_size=11,
- stride=2,
- use_spectral_norm=False,
- fft_size=1024,
- shift_size=120,
- win_length=600,
- window='hann_window',
- nonlinear_activation='LeakyReLU',
- nonlinear_activation_params={'negative_slope': 0.1},
- ):
- super(SpecDiscriminator, self).__init__()
- self.fft_size = fft_size
- self.shift_size = shift_size
- self.win_length = win_length
- # fft_size // 2 + 1
- norm_f = weight_norm if not use_spectral_norm else spectral_norm
- final_kernel = 5
- post_conv_kernel = 3
- blocks = 3
- self.convs = nn.ModuleList()
- self.convs.append(
- torch.nn.Sequential(
- norm_f(
- nn.Conv2d(
- fft_size // 2 + 1,
- channels,
- (init_kernel, 1),
- (1, 1),
- padding=(init_kernel - 1) // 2,
- )),
- getattr(torch.nn,
- nonlinear_activation)(**nonlinear_activation_params),
- ))
-
- for i in range(blocks):
- self.convs.append(
- torch.nn.Sequential(
- norm_f(
- nn.Conv2d(
- channels,
- channels,
- (kernel_size, 1),
- (stride, 1),
- padding=(kernel_size - 1) // 2,
- )),
- getattr(
- torch.nn,
- nonlinear_activation)(**nonlinear_activation_params),
- ))
-
- self.convs.append(
- torch.nn.Sequential(
- norm_f(
- nn.Conv2d(
- channels,
- channels,
- (final_kernel, 1),
- (1, 1),
- padding=(final_kernel - 1) // 2,
- )),
- getattr(torch.nn,
- nonlinear_activation)(**nonlinear_activation_params),
- ))
-
- self.conv_post = norm_f(
- nn.Conv2d(
- channels,
- 1,
- (post_conv_kernel, 1),
- (1, 1),
- padding=((post_conv_kernel - 1) // 2, 0),
- ))
- self.register_buffer('window', getattr(torch, window)(win_length))
-
- def forward(self, wav):
- with torch.no_grad():
- wav = torch.squeeze(wav, 1)
- x_mag = stft(wav, self.fft_size, self.shift_size, self.win_length,
- self.window)
- x = torch.transpose(x_mag, 2, 1).unsqueeze(-1)
- fmap = []
- for layer in self.convs:
- x = layer(x)
- fmap.append(x)
- x = self.conv_post(x)
- fmap.append(x)
- x = x.squeeze(-1)
-
- return x, fmap
-
-
-class MultiSpecDiscriminator(torch.nn.Module):
-
- def __init__(
- self,
- fft_sizes=[1024, 2048, 512],
- hop_sizes=[120, 240, 50],
- win_lengths=[600, 1200, 240],
- discriminator_params={
- 'channels': 15,
- 'init_kernel': 1,
-            'kernel_size': 11,
- 'stride': 2,
- 'use_spectral_norm': False,
- 'window': 'hann_window',
- 'nonlinear_activation': 'LeakyReLU',
- 'nonlinear_activation_params': {
- 'negative_slope': 0.1
- },
- },
- ):
- super(MultiSpecDiscriminator, self).__init__()
- self.discriminators = nn.ModuleList()
- for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes,
- win_lengths):
- params = copy.deepcopy(discriminator_params)
- params['fft_size'] = fft_size
- params['shift_size'] = hop_size
- params['win_length'] = win_length
- self.discriminators += [SpecDiscriminator(**params)]
-
- def forward(self, y):
- y_d = []
- fmap = []
- for i, d in enumerate(self.discriminators):
- x, x_map = d(y)
- y_d.append(x)
- fmap.append(x_map)
-
- return y_d, fmap
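Note: every PeriodDiscriminator above relies on the same 1D-to-2D folding step: the waveform is reflect-padded up to a multiple of the period and then viewed as (B, C, T // period, period), so ordinary Conv2d layers see period-spaced samples stacked in one column. A self-contained sketch of just that reshaping (tensor sizes chosen arbitrarily):

import torch
import torch.nn.functional as F

def fold_by_period(x, period):
    """Illustrative helper: reshape a waveform (B, C, T) into (B, C, T // period, period)."""
    b, c, t = x.shape
    if t % period != 0:                       # pad first, as in PeriodDiscriminator
        n_pad = period - (t % period)
        x = F.pad(x, (0, n_pad), 'reflect')
        t = t + n_pad
    return x.view(b, c, t // period, period)

wav = torch.randn(2, 1, 22050)
print(fold_by_period(wav, 3).shape)   # torch.Size([2, 1, 7350, 3])
print(fold_by_period(wav, 4).shape)   # torch.Size([2, 1, 5513, 4])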
diff --git a/modelscope/models/audio/tts/kantts/models/hifigan/layers.py b/modelscope/models/audio/tts/kantts/models/hifigan/layers.py
deleted file mode 100644
index 78887417..00000000
--- a/modelscope/models/audio/tts/kantts/models/hifigan/layers.py
+++ /dev/null
@@ -1,288 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.distributions.normal import Normal
-from torch.distributions.uniform import Uniform
-from torch.nn.utils import remove_weight_norm, weight_norm
-
-from modelscope.models.audio.tts.kantts.models.utils import init_weights
-
-
-def get_padding(kernel_size, dilation=1):
- return int((kernel_size * dilation - dilation) / 2)
-
-
-class Conv1d(torch.nn.Module):
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- bias=True,
- padding_mode='zeros',
- ):
- super(Conv1d, self).__init__()
- self.conv1d = weight_norm(
- nn.Conv1d(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding=padding,
- dilation=dilation,
- groups=groups,
- bias=bias,
- padding_mode=padding_mode,
- ))
- self.conv1d.apply(init_weights)
-
- def forward(self, x):
- x = self.conv1d(x)
- return x
-
- def remove_weight_norm(self):
- remove_weight_norm(self.conv1d)
-
-
-class CausalConv1d(torch.nn.Module):
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- bias=True,
- padding_mode='zeros',
- ):
- super(CausalConv1d, self).__init__()
- self.pad = (kernel_size - 1) * dilation
- self.conv1d = weight_norm(
- nn.Conv1d(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding=0,
- dilation=dilation,
- groups=groups,
- bias=bias,
- padding_mode=padding_mode,
- ))
- self.conv1d.apply(init_weights)
-
- def forward(self, x): # bdt
-        # F.pad widths are given starting from the last dimension, so this
-        # left-pads only the time axis, keeping the convolution causal.
-        x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), 'constant')
- # x = F.pad(x, (self.pad, self.pad, 0, 0, 0, 0), "constant")
- x = self.conv1d(x)[:, :, :x.size(2)]
- return x
-
- def remove_weight_norm(self):
- remove_weight_norm(self.conv1d)
-
-
-class ConvTranspose1d(torch.nn.Module):
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding=0,
- output_padding=0,
- ):
- super(ConvTranspose1d, self).__init__()
- self.deconv = weight_norm(
- nn.ConvTranspose1d(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding=padding,
-                output_padding=output_padding,
- ))
- self.deconv.apply(init_weights)
-
- def forward(self, x):
- return self.deconv(x)
-
- def remove_weight_norm(self):
- remove_weight_norm(self.deconv)
-
-
-# FIXME: HACK to get shape right
-class CausalConvTranspose1d(torch.nn.Module):
- """CausalConvTranspose1d module with customized initialization."""
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding=0,
- output_padding=0,
- ):
- """Initialize CausalConvTranspose1d module."""
- super(CausalConvTranspose1d, self).__init__()
- self.deconv = weight_norm(
- nn.ConvTranspose1d(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding=0,
- output_padding=0,
- ))
- self.stride = stride
- self.deconv.apply(init_weights)
- self.pad = kernel_size - stride
-
- def forward(self, x):
- """Calculate forward propagation.
- Args:
- x (Tensor): Input tensor (B, in_channels, T_in).
- Returns:
- Tensor: Output tensor (B, out_channels, T_out).
- """
- # x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant")
- return self.deconv(x)[:, :, :-self.pad]
- # return self.deconv(x)
-
- def remove_weight_norm(self):
- remove_weight_norm(self.deconv)
-
-
-class ResidualBlock(torch.nn.Module):
-
- def __init__(
- self,
- channels,
- kernel_size=3,
- dilation=(1, 3, 5),
- nonlinear_activation='LeakyReLU',
- nonlinear_activation_params={'negative_slope': 0.1},
- causal=False,
- ):
- super(ResidualBlock, self).__init__()
-        assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
- conv_cls = CausalConv1d if causal else Conv1d
- self.convs1 = nn.ModuleList([
- conv_cls(
- channels,
- channels,
- kernel_size,
- 1,
- dilation=dilation[i],
- padding=get_padding(kernel_size, dilation[i]),
- ) for i in range(len(dilation))
- ])
-
- self.convs2 = nn.ModuleList([
- conv_cls(
- channels,
- channels,
- kernel_size,
- 1,
- dilation=1,
- padding=get_padding(kernel_size, 1),
- ) for i in range(len(dilation))
- ])
-
- self.activation = getattr(
- torch.nn, nonlinear_activation)(**nonlinear_activation_params)
-
- def forward(self, x):
- for c1, c2 in zip(self.convs1, self.convs2):
- xt = self.activation(x)
- xt = c1(xt)
- xt = self.activation(xt)
- xt = c2(xt)
- x = xt + x
- return x
-
- def remove_weight_norm(self):
- for layer in self.convs1:
- layer.remove_weight_norm()
- for layer in self.convs2:
- layer.remove_weight_norm()
-
-
-class SourceModule(torch.nn.Module):
-
- def __init__(self,
- nb_harmonics,
- upsample_ratio,
- sampling_rate,
- alpha=0.1,
- sigma=0.003):
- super(SourceModule, self).__init__()
-
- self.nb_harmonics = nb_harmonics
- self.upsample_ratio = upsample_ratio
- self.sampling_rate = sampling_rate
- self.alpha = alpha
- self.sigma = sigma
-
- self.ffn = nn.Sequential(
- weight_norm(
- nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)),
- nn.Tanh(),
- )
-
- def forward(self, pitch, uv):
- """
- :param pitch: [B, 1, frame_len], Hz
- :param uv: [B, 1, frame_len] vuv flag
- :return: [B, 1, sample_len]
- """
- with torch.no_grad():
- pitch_samples = F.interpolate(
- pitch, scale_factor=(self.upsample_ratio), mode='nearest')
- uv_samples = F.interpolate(
- uv, scale_factor=(self.upsample_ratio), mode='nearest')
-
- F_mat = torch.zeros(
- (pitch_samples.size(0), self.nb_harmonics + 1,
- pitch_samples.size(-1))).to(pitch_samples.device)
- for i in range(self.nb_harmonics + 1):
-                F_mat[:, i:i + 1, :] = (
-                    pitch_samples * (i + 1) / self.sampling_rate)
-
- theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
- u_dist = Uniform(low=-np.pi, high=np.pi)
- phase_vec = u_dist.sample(
- sample_shape=(pitch.size(0), self.nb_harmonics + 1,
- 1)).to(F_mat.device)
- phase_vec[:, 0, :] = 0
-
- n_dist = Normal(loc=0.0, scale=self.sigma)
- noise = n_dist.sample(
- sample_shape=(
- pitch_samples.size(0),
- self.nb_harmonics + 1,
- pitch_samples.size(-1),
- )).to(F_mat.device)
-
- e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise
- e_unvoice = self.alpha / 3 / self.sigma * noise
-
- e = e_voice * uv_samples + e_unvoice * (1 - uv_samples)
-
- return self.ffn(e)
-
- def remove_weight_norm(self):
- remove_weight_norm(self.ffn[0])
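Note: CausalConv1d above uses the usual causal-padding recipe: left-pad the time axis by (kernel_size - 1) * dilation, convolve with no internal padding, and trim so the output length matches the input. A small sketch demonstrating the length preservation and the causality property (sizes are arbitrary):

import torch
import torch.nn as nn
import torch.nn.functional as F

kernel_size, dilation = 7, 2
pad = (kernel_size - 1) * dilation
conv = nn.Conv1d(1, 1, kernel_size, dilation=dilation, padding=0, bias=False)

x = torch.randn(1, 1, 100)
y = conv(F.pad(x, (pad, 0)))[:, :, :x.size(2)]
print(y.shape)                                        # torch.Size([1, 1, 100])

# Causality: perturbing a future sample leaves earlier outputs untouched.
x2 = x.clone()
x2[:, :, 50] += 1.0
y2 = conv(F.pad(x2, (pad, 0)))[:, :, :x2.size(2)]
print(torch.allclose(y[:, :, :50], y2[:, :, :50]))    # True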
diff --git a/modelscope/models/audio/tts/kantts/models/pqmf.py b/modelscope/models/audio/tts/kantts/models/pqmf.py
deleted file mode 100644
index d4679af2..00000000
--- a/modelscope/models/audio/tts/kantts/models/pqmf.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# The implementation is adopted from kan-bayashi's ParallelWaveGAN,
-# made publicly available under the MIT License at https://github.com/kan-bayashi/ParallelWaveGAN
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from scipy.signal import kaiser
-
-
-def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
- """Design prototype filter for PQMF.
-
- This method is based on `A Kaiser window approach for the design of prototype
- filters of cosine modulated filterbanks`_.
-
- Args:
- taps (int): The number of filter taps.
- cutoff_ratio (float): Cut-off frequency ratio.
- beta (float): Beta coefficient for kaiser window.
-
- Returns:
-        ndarray: Impulse response of prototype filter (taps + 1,).
-
- .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
- https://ieeexplore.ieee.org/abstract/document/681427
-
- """
- # check the arguments are valid
-    assert taps % 2 == 0, 'The number of taps must be an even number.'
- assert 0.0 < cutoff_ratio < 1.0, 'Cutoff ratio must be > 0.0 and < 1.0.'
-
- # make initial filter
- omega_c = np.pi * cutoff_ratio
- with np.errstate(invalid='ignore'):
- h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
- np.pi * (np.arange(taps + 1) - 0.5 * taps))
- h_i[taps
- // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
-
- # apply kaiser window
- w = kaiser(taps + 1, beta)
- h = h_i * w
-
- return h
-
-
-class PQMF(torch.nn.Module):
- """PQMF module.
-
- This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
-
- .. _`Near-perfect-reconstruction pseudo-QMF banks`:
- https://ieeexplore.ieee.org/document/258122
-
- """
-
- def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
- """Initilize PQMF module.
-
- The cutoff_ratio and beta parameters are optimized for #subbands = 4.
-        See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.
-
- Args:
- subbands (int): The number of subbands.
- taps (int): The number of filter taps.
- cutoff_ratio (float): Cut-off frequency ratio.
- beta (float): Beta coefficient for kaiser window.
-
- """
- super(PQMF, self).__init__()
-
- # build analysis & synthesis filter coefficients
- h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
- h_analysis = np.zeros((subbands, len(h_proto)))
- h_synthesis = np.zeros((subbands, len(h_proto)))
- for k in range(subbands):
- h_analysis[k] = (
- 2 * h_proto * np.cos((2 * k + 1) * # noqa W504
- (np.pi / (2 * subbands)) * # noqa W504
- (np.arange(taps + 1) - (taps / 2))
- + (-1)**k * np.pi / 4))
- h_synthesis[k] = (
- 2 * h_proto * np.cos((2 * k + 1) * # noqa W504
- (np.pi / (2 * subbands)) * # noqa W504
- (np.arange(taps + 1) - (taps / 2))
- - (-1)**k * np.pi / 4))
-
- # convert to tensor
- analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
- synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
-
-        # register coefficients as buffer
- self.register_buffer('analysis_filter', analysis_filter)
- self.register_buffer('synthesis_filter', synthesis_filter)
-
- # filter for downsampling & upsampling
- updown_filter = torch.zeros((subbands, subbands, subbands)).float()
- for k in range(subbands):
- updown_filter[k, k, 0] = 1.0
- self.register_buffer('updown_filter', updown_filter)
- self.subbands = subbands
-
- # keep padding info
- self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
-
- def analysis(self, x):
- """Analysis with PQMF.
-
- Args:
- x (Tensor): Input tensor (B, 1, T).
-
- Returns:
- Tensor: Output tensor (B, subbands, T // subbands).
-
- """
- x = F.conv1d(self.pad_fn(x), self.analysis_filter)
- return F.conv1d(x, self.updown_filter, stride=self.subbands)
-
- def synthesis(self, x):
- """Synthesis with PQMF.
-
- Args:
- x (Tensor): Input tensor (B, subbands, T // subbands).
-
- Returns:
- Tensor: Output tensor (B, 1, T).
-
- """
-        # NOTE(kan-bayashi): Power will be decreased, so multiply by # subbands here.
-        # Not sure this is the correct way; it is better to check again.
- x = F.conv_transpose1d(
- x, self.updown_filter * self.subbands, stride=self.subbands)
- return F.conv1d(self.pad_fn(x), self.synthesis_filter)
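Note: scipy.signal.kaiser imported above has moved to scipy.signal.windows.kaiser in recent SciPy releases; the same window is also available as numpy's np.kaiser. The windowed-sinc design used by design_prototype_filter can be sketched with NumPy alone as below (an illustration, not a drop-in replacement for the deleted module):

import numpy as np

def prototype_lowpass(taps=62, cutoff_ratio=0.142, beta=9.0):
    """Illustrative helper: ideal lowpass impulse response times a Kaiser window."""
    omega_c = np.pi * cutoff_ratio
    n = np.arange(taps + 1) - 0.5 * taps
    with np.errstate(invalid='ignore'):
        h_i = np.sin(omega_c * n) / (np.pi * n)
    h_i[taps // 2] = cutoff_ratio             # limit of sin(x)/x at the center tap
    return h_i * np.kaiser(taps + 1, beta)

h = prototype_lowpass()
print(h.shape)                    # (63,)
print(np.allclose(h, h[::-1]))    # True: symmetric, i.e. linear-phase filter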
diff --git a/modelscope/models/audio/tts/kantts/models/sambert/__init__.py b/modelscope/models/audio/tts/kantts/models/sambert/__init__.py
deleted file mode 100644
index bd2939e2..00000000
--- a/modelscope/models/audio/tts/kantts/models/sambert/__init__.py
+++ /dev/null
@@ -1,372 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class ScaledDotProductAttention(nn.Module):
- """ Scaled Dot-Product Attention """
-
- def __init__(self, temperature, dropatt=0.0):
- super().__init__()
- self.temperature = temperature
- self.softmax = nn.Softmax(dim=2)
- self.dropatt = nn.Dropout(dropatt)
-
- def forward(self, q, k, v, mask=None):
-
- attn = torch.bmm(q, k.transpose(1, 2))
- attn = attn / self.temperature
-
- if mask is not None:
- attn = attn.masked_fill(mask, -np.inf)
-
- attn = self.softmax(attn)
- attn = self.dropatt(attn)
- output = torch.bmm(attn, v)
-
- return output, attn
-
-
-class Prenet(nn.Module):
-
- def __init__(self, in_units, prenet_units, out_units=0):
- super(Prenet, self).__init__()
-
- self.fcs = nn.ModuleList()
- for in_dim, out_dim in zip([in_units] + prenet_units[:-1],
- prenet_units):
- self.fcs.append(nn.Linear(in_dim, out_dim))
- self.fcs.append(nn.ReLU())
- self.fcs.append(nn.Dropout(0.5))
-
- if out_units:
- self.fcs.append(nn.Linear(prenet_units[-1], out_units))
-
- def forward(self, input):
- output = input
- for layer in self.fcs:
- output = layer(output)
- return output
-
-
-class MultiHeadSelfAttention(nn.Module):
- """ Multi-Head SelfAttention module """
-
- def __init__(self, n_head, d_in, d_model, d_head, dropout, dropatt=0.0):
- super().__init__()
-
- self.n_head = n_head
- self.d_head = d_head
- self.d_in = d_in
- self.d_model = d_model
-
- self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
- self.w_qkv = nn.Linear(d_in, 3 * n_head * d_head)
-
- self.attention = ScaledDotProductAttention(
- temperature=np.power(d_head, 0.5), dropatt=dropatt)
-
- self.fc = nn.Linear(n_head * d_head, d_model)
-
- self.dropout = nn.Dropout(dropout)
-
- def forward(self, input, mask=None):
- d_head, n_head = self.d_head, self.n_head
-
- sz_b, len_in, _ = input.size()
-
- residual = input
-
- x = self.layer_norm(input)
- qkv = self.w_qkv(x)
- q, k, v = qkv.chunk(3, -1)
-
- q = q.view(sz_b, len_in, n_head, d_head)
- k = k.view(sz_b, len_in, n_head, d_head)
- v = v.view(sz_b, len_in, n_head, d_head)
-
- q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_in,
- d_head) # (n*b) x l x d
- k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_in,
- d_head) # (n*b) x l x d
- v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_in,
- d_head) # (n*b) x l x d
-
- if mask is not None:
- mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
- output, attn = self.attention(q, k, v, mask=mask)
-
- output = output.view(n_head, sz_b, len_in, d_head)
- output = (output.permute(1, 2, 0,
- 3).contiguous().view(sz_b, len_in,
- -1)) # b x l x (n*d)
-
- output = self.dropout(self.fc(output))
- if output.size(-1) == residual.size(-1):
- output = output + residual
-
- return output, attn
-
-
-class PositionwiseConvFeedForward(nn.Module):
- """ A two-feed-forward-layer module """
-
- def __init__(self,
- d_in,
- d_hid,
- kernel_size=(3, 1),
- dropout_inner=0.1,
- dropout=0.1):
- super().__init__()
- # Use Conv1D
- # position-wise
- self.w_1 = nn.Conv1d(
- d_in,
- d_hid,
- kernel_size=kernel_size[0],
- padding=(kernel_size[0] - 1) // 2,
- )
- # position-wise
- self.w_2 = nn.Conv1d(
- d_hid,
- d_in,
- kernel_size=kernel_size[1],
- padding=(kernel_size[1] - 1) // 2,
- )
-
- self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
- self.dropout_inner = nn.Dropout(dropout_inner)
- self.dropout = nn.Dropout(dropout)
-
- def forward(self, x, mask=None):
- residual = x
- x = self.layer_norm(x)
-
- output = x.transpose(1, 2)
- output = F.relu(self.w_1(output))
- if mask is not None:
- output = output.masked_fill(mask.unsqueeze(1), 0)
- output = self.dropout_inner(output)
- output = self.w_2(output)
- output = output.transpose(1, 2)
- output = self.dropout(output)
-
- output = output + residual
-
- return output
-
-
-class FFTBlock(nn.Module):
- """FFT Block"""
-
- def __init__(
- self,
- d_in,
- d_model,
- n_head,
- d_head,
- d_inner,
- kernel_size,
- dropout,
- dropout_attn=0.0,
- dropout_relu=0.0,
- ):
- super(FFTBlock, self).__init__()
- self.slf_attn = MultiHeadSelfAttention(
- n_head,
- d_in,
- d_model,
- d_head,
- dropout=dropout,
- dropatt=dropout_attn)
- self.pos_ffn = PositionwiseConvFeedForward(
- d_model,
- d_inner,
- kernel_size,
- dropout_inner=dropout_relu,
- dropout=dropout)
-
- def forward(self, input, mask=None, slf_attn_mask=None):
- output, slf_attn = self.slf_attn(input, mask=slf_attn_mask)
- if mask is not None:
- output = output.masked_fill(mask.unsqueeze(-1), 0)
-
- output = self.pos_ffn(output, mask=mask)
- if mask is not None:
- output = output.masked_fill(mask.unsqueeze(-1), 0)
-
- return output, slf_attn
-
-
-class MultiHeadPNCAAttention(nn.Module):
- """ Multi-Head Attention PNCA module """
-
- def __init__(self, n_head, d_model, d_mem, d_head, dropout, dropatt=0.0):
- super().__init__()
-
- self.n_head = n_head
- self.d_head = d_head
- self.d_model = d_model
- self.d_mem = d_mem
-
- self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
-
- self.w_x_qkv = nn.Linear(d_model, 3 * n_head * d_head)
- self.fc_x = nn.Linear(n_head * d_head, d_model)
-
- self.w_h_kv = nn.Linear(d_mem, 2 * n_head * d_head)
- self.fc_h = nn.Linear(n_head * d_head, d_model)
-
- self.attention = ScaledDotProductAttention(
- temperature=np.power(d_head, 0.5), dropatt=dropatt)
-
- self.dropout = nn.Dropout(dropout)
-
- def update_x_state(self, x):
- d_head, n_head = self.d_head, self.n_head
-
- sz_b, len_x, _ = x.size()
-
- x_qkv = self.w_x_qkv(x)
- x_q, x_k, x_v = x_qkv.chunk(3, -1)
-
- x_q = x_q.view(sz_b, len_x, n_head, d_head)
- x_k = x_k.view(sz_b, len_x, n_head, d_head)
- x_v = x_v.view(sz_b, len_x, n_head, d_head)
-
- x_q = x_q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
- x_k = x_k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
- x_v = x_v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
-
- if self.x_state_size:
- self.x_k = torch.cat([self.x_k, x_k], dim=1)
- self.x_v = torch.cat([self.x_v, x_v], dim=1)
- else:
- self.x_k = x_k
- self.x_v = x_v
-
- self.x_state_size += len_x
-
- return x_q, x_k, x_v
-
- def update_h_state(self, h):
- if self.h_state_size == h.size(1):
- return None, None
-
- d_head, n_head = self.d_head, self.n_head
-
- # H
- sz_b, len_h, _ = h.size()
-
- h_kv = self.w_h_kv(h)
- h_k, h_v = h_kv.chunk(2, -1)
-
- h_k = h_k.view(sz_b, len_h, n_head, d_head)
- h_v = h_v.view(sz_b, len_h, n_head, d_head)
-
- self.h_k = h_k.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head)
- self.h_v = h_v.permute(2, 0, 1, 3).contiguous().view(-1, len_h, d_head)
-
- self.h_state_size += len_h
-
- return h_k, h_v
-
- def reset_state(self):
- self.h_k = None
- self.h_v = None
- self.h_state_size = 0
- self.x_k = None
- self.x_v = None
- self.x_state_size = 0
-
- def forward(self, x, h, mask_x=None, mask_h=None):
- residual = x
- self.update_h_state(h)
- x_q, x_k, x_v = self.update_x_state(self.layer_norm(x))
-
- d_head, n_head = self.d_head, self.n_head
-
- sz_b, len_in, _ = x.size()
-
- # X
- if mask_x is not None:
- mask_x = mask_x.repeat(n_head, 1, 1) # (n*b) x .. x ..
- output_x, attn_x = self.attention(x_q, self.x_k, self.x_v, mask=mask_x)
-
- output_x = output_x.view(n_head, sz_b, len_in, d_head)
- output_x = (output_x.permute(1, 2, 0,
- 3).contiguous().view(sz_b, len_in,
- -1)) # b x l x (n*d)
- output_x = self.fc_x(output_x)
-
- # H
- if mask_h is not None:
- mask_h = mask_h.repeat(n_head, 1, 1)
- output_h, attn_h = self.attention(x_q, self.h_k, self.h_v, mask=mask_h)
-
- output_h = output_h.view(n_head, sz_b, len_in, d_head)
- output_h = (output_h.permute(1, 2, 0,
- 3).contiguous().view(sz_b, len_in,
- -1)) # b x l x (n*d)
- output_h = self.fc_h(output_h)
-
- output = output_x + output_h
-
- output = self.dropout(output)
-
- output = output + residual
-
- return output, attn_x, attn_h
-
-
-class PNCABlock(nn.Module):
- """PNCA Block"""
-
- def __init__(
- self,
- d_model,
- d_mem,
- n_head,
- d_head,
- d_inner,
- kernel_size,
- dropout,
- dropout_attn=0.0,
- dropout_relu=0.0,
- ):
- super(PNCABlock, self).__init__()
- self.pnca_attn = MultiHeadPNCAAttention(
- n_head,
- d_model,
- d_mem,
- d_head,
- dropout=dropout,
- dropatt=dropout_attn)
- self.pos_ffn = PositionwiseConvFeedForward(
- d_model,
- d_inner,
- kernel_size,
- dropout_inner=dropout_relu,
- dropout=dropout)
-
- def forward(self,
- input,
- memory,
- mask=None,
- pnca_x_attn_mask=None,
- pnca_h_attn_mask=None):
- output, pnca_attn_x, pnca_attn_h = self.pnca_attn(
- input, memory, pnca_x_attn_mask, pnca_h_attn_mask)
- if mask is not None:
- output = output.masked_fill(mask.unsqueeze(-1), 0)
-
- output = self.pos_ffn(output, mask=mask)
- if mask is not None:
- output = output.masked_fill(mask.unsqueeze(-1), 0)
-
- return output, pnca_attn_x, pnca_attn_h
-
- def reset_state(self):
- self.pnca_attn.reset_state()
diff --git a/modelscope/models/audio/tts/kantts/models/sambert/adaptors.py b/modelscope/models/audio/tts/kantts/models/sambert/adaptors.py
deleted file mode 100644
index bd7edd6e..00000000
--- a/modelscope/models/audio/tts/kantts/models/sambert/adaptors.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from . import Prenet
-from .fsmn import FsmnEncoderV2
-
-
-class LengthRegulator(nn.Module):
-
- def __init__(self, r=1):
- super(LengthRegulator, self).__init__()
-
- self.r = r
-
- def forward(self, inputs, durations, masks=None):
- reps = (durations + 0.5).long()
- output_lens = reps.sum(dim=1)
- max_len = output_lens.max()
- reps_cumsum = torch.cumsum(
- F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
- range_ = torch.arange(max_len).to(inputs.device)[None, :, None]
- mult = (reps_cumsum[:, :, :-1] <= range_) & (
- reps_cumsum[:, :, 1:] > range_)
- mult = mult.float()
- out = torch.matmul(mult, inputs)
-
- if masks is not None:
- out = out.masked_fill(masks.unsqueeze(-1), 0.0)
-
- seq_len = out.size(1)
- padding = self.r - int(seq_len) % self.r
- if padding < self.r:
- out = F.pad(
- out.transpose(1, 2), (0, padding, 0, 0, 0, 0), value=0.0)
- out = out.transpose(1, 2)
-
- return out, output_lens
-
-
-class VarRnnARPredictor(nn.Module):
-
- def __init__(self, cond_units, prenet_units, rnn_units):
- super(VarRnnARPredictor, self).__init__()
-
- self.prenet = Prenet(1, prenet_units)
- self.lstm = nn.LSTM(
- prenet_units[-1] + cond_units,
- rnn_units,
- num_layers=2,
- batch_first=True,
- bidirectional=False,
- )
- self.fc = nn.Linear(rnn_units, 1)
-
- def forward(self, inputs, cond, h=None, masks=None):
- x = torch.cat([self.prenet(inputs), cond], dim=-1)
-        # The input could also be packed as a variable-length sequence; we skip
-        # that here for simplicity, since masking plus the unidirectional LSTM
-        # already handles the padded frames.
- x, h_new = self.lstm(x, h)
-
- x = self.fc(x).squeeze(-1)
- x = F.relu(x)
-
- if masks is not None:
- x = x.masked_fill(masks, 0.0)
-
- return x, h_new
-
- def infer(self, cond, masks=None):
- batch_size, length = cond.size(0), cond.size(1)
-
- output = []
- x = torch.zeros((batch_size, 1)).to(cond.device)
- h = None
-
- for i in range(length):
- x, h = self.forward(x.unsqueeze(1), cond[:, i:i + 1, :], h=h)
- output.append(x)
-
- output = torch.cat(output, dim=-1)
-
- if masks is not None:
- output = output.masked_fill(masks, 0.0)
-
- return output
-
-
-class VarFsmnRnnNARPredictor(nn.Module):
-
- def __init__(
- self,
- in_dim,
- filter_size,
- fsmn_num_layers,
- num_memory_units,
- ffn_inner_dim,
- dropout,
- shift,
- lstm_units,
- ):
- super(VarFsmnRnnNARPredictor, self).__init__()
-
- self.fsmn = FsmnEncoderV2(
- filter_size,
- fsmn_num_layers,
- in_dim,
- num_memory_units,
- ffn_inner_dim,
- dropout,
- shift,
- )
- self.blstm = nn.LSTM(
- num_memory_units,
- lstm_units,
- num_layers=1,
- batch_first=True,
- bidirectional=True,
- )
- self.fc = nn.Linear(2 * lstm_units, 1)
-
- def forward(self, inputs, masks=None):
- input_lengths = None
- if masks is not None:
- input_lengths = torch.sum((~masks).float(), dim=1).long()
-
- x = self.fsmn(inputs, masks)
-
- if input_lengths is not None:
- x = nn.utils.rnn.pack_padded_sequence(
- x,
- input_lengths.tolist(),
- batch_first=True,
- enforce_sorted=False)
- x, _ = self.blstm(x)
- x, _ = nn.utils.rnn.pad_packed_sequence(
- x, batch_first=True, total_length=inputs.size(1))
- else:
- x, _ = self.blstm(x)
-
- x = self.fc(x).squeeze(-1)
-
- if masks is not None:
- x = x.masked_fill(masks, 0.0)
-
- return x
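
For orientation, the deleted LengthRegulator expands phoneme-level features to frame level by comparing cumulative duration sums against a frame index and multiplying the resulting 0/1 alignment matrix with the inputs. A minimal, self-contained sketch of that expansion (illustrative only; expand_by_duration is an invented name, not part of the removed module):

import torch
import torch.nn.functional as F

def expand_by_duration(inputs, durations):
    """Repeat each phoneme vector durations[b, t] times (rounded)."""
    reps = (durations + 0.5).long()            # round to the nearest frame count
    out_lens = reps.sum(dim=1)
    max_len = int(out_lens.max())
    # frame i belongs to phoneme t iff cum[t] <= i < cum[t + 1]
    cums = torch.cumsum(F.pad(reps.float(), (1, 0)), dim=1)[:, None, :]
    frames = torch.arange(max_len, device=inputs.device)[None, :, None]
    align = ((cums[:, :, :-1] <= frames) & (cums[:, :, 1:] > frames)).float()
    return torch.matmul(align, inputs), out_lens   # (B, max_len, C)

# e.g. two phonemes with durations [2, 3] expand to 5 output frames
x = torch.randn(1, 2, 4)
out, lens = expand_by_duration(x, torch.tensor([[2.0, 3.0]]))
print(out.shape, lens)   # torch.Size([1, 5, 4]) tensor([5])
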
diff --git a/modelscope/models/audio/tts/kantts/models/sambert/alignment.py b/modelscope/models/audio/tts/kantts/models/sambert/alignment.py
deleted file mode 100644
index 9bbec753..00000000
--- a/modelscope/models/audio/tts/kantts/models/sambert/alignment.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import numba as nb
-import numpy as np
-
-
-@nb.jit(nopython=True)
-def mas(attn_map, width=1):
- # assumes mel x text
- opt = np.zeros_like(attn_map)
- attn_map = np.log(attn_map)
- attn_map[0, 1:] = -np.inf
- log_p = np.zeros_like(attn_map)
- log_p[0, :] = attn_map[0, :]
- prev_ind = np.zeros_like(attn_map, dtype=np.int64)
- for i in range(1, attn_map.shape[0]):
- for j in range(attn_map.shape[1]): # for each text dim
- prev_j = np.arange(max(0, j - width), j + 1)
- prev_log = np.array(
- [log_p[i - 1, prev_idx] for prev_idx in prev_j])
-
- ind = np.argmax(prev_log)
- log_p[i, j] = attn_map[i, j] + prev_log[ind]
- prev_ind[i, j] = prev_j[ind]
-
- # now backtrack
- curr_text_idx = attn_map.shape[1] - 1
- for i in range(attn_map.shape[0] - 1, -1, -1):
- opt[i, curr_text_idx] = 1
- curr_text_idx = prev_ind[i, curr_text_idx]
- opt[0, curr_text_idx] = 1
- return opt
-
-
-@nb.jit(nopython=True)
-def mas_width1(attn_map):
- """mas with hardcoded width=1"""
- # assumes mel x text
- opt = np.zeros_like(attn_map)
- attn_map = np.log(attn_map)
- attn_map[0, 1:] = -np.inf
- log_p = np.zeros_like(attn_map)
- log_p[0, :] = attn_map[0, :]
- prev_ind = np.zeros_like(attn_map, dtype=np.int64)
- for i in range(1, attn_map.shape[0]):
- for j in range(attn_map.shape[1]): # for each text dim
- prev_log = log_p[i - 1, j]
- prev_j = j
-
- if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
- prev_log = log_p[i - 1, j - 1]
- prev_j = j - 1
-
- log_p[i, j] = attn_map[i, j] + prev_log
- prev_ind[i, j] = prev_j
-
- # now backtrack
- curr_text_idx = attn_map.shape[1] - 1
- for i in range(attn_map.shape[0] - 1, -1, -1):
- opt[i, curr_text_idx] = 1
- curr_text_idx = prev_ind[i, curr_text_idx]
- opt[0, curr_text_idx] = 1
- return opt
-
-
-@nb.jit(nopython=True, parallel=True)
-def b_mas(b_attn_map, in_lens, out_lens, width=1):
- assert width == 1
- attn_out = np.zeros_like(b_attn_map)
-
- for b in nb.prange(b_attn_map.shape[0]):
- out = mas_width1(b_attn_map[b, 0, :out_lens[b], :in_lens[b]])
- attn_out[b, 0, :out_lens[b], :in_lens[b]] = out
- return attn_out
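
The removed alignment.py implements Monotonic Alignment Search (MAS): a Viterbi-style dynamic program that converts a soft mel-to-text attention map into a hard monotonic path, where each mel frame either stays on the current text token or advances by one. A NumPy-only sketch of the width-1 case (mas_width1_sketch is an illustrative restatement, not the numba-jitted original):

import numpy as np

def mas_width1_sketch(attn_map):
    """Hard monotonic alignment: at each mel frame stay on the same
    text token or advance by one, maximizing total log-probability."""
    log_attn = np.log(attn_map)
    log_attn[0, 1:] = -np.inf                 # the path must start at the first token
    T_mel, T_text = log_attn.shape
    log_p = np.full_like(log_attn, -np.inf)
    prev = np.zeros_like(log_attn, dtype=np.int64)
    log_p[0] = log_attn[0]
    for i in range(1, T_mel):
        for j in range(T_text):
            best_j = j
            if j > 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
                best_j = j - 1
            log_p[i, j] = log_attn[i, j] + log_p[i - 1, best_j]
            prev[i, j] = best_j
    # backtrack from the last text token
    path = np.zeros_like(attn_map)
    j = T_text - 1
    for i in range(T_mel - 1, -1, -1):
        path[i, j] = 1.0
        j = prev[i, j]
    return path

soft = np.random.rand(6, 3) + 1e-3            # 6 mel frames, 3 text tokens
hard = mas_width1_sketch(soft)
print(hard.sum(axis=1))                       # all ones: each frame maps to exactly one token
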
diff --git a/modelscope/models/audio/tts/kantts/models/sambert/attention.py b/modelscope/models/audio/tts/kantts/models/sambert/attention.py
deleted file mode 100644
index 5ae32f7e..00000000
--- a/modelscope/models/audio/tts/kantts/models/sambert/attention.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import numpy as np
-import torch
-from torch import nn
-
-
-class ConvNorm(torch.nn.Module):
-
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size=1,
- stride=1,
- padding=None,
- dilation=1,
- bias=True,
- w_init_gain='linear',
- ):
- super(ConvNorm, self).__init__()
- if padding is None:
- assert kernel_size % 2 == 1
- padding = int(dilation * (kernel_size - 1) / 2)
-
- self.conv = torch.nn.Conv1d(
- in_channels,
- out_channels,
- kernel_size=kernel_size,
- stride=stride,
- padding=padding,
- dilation=dilation,
- bias=bias,
- )
-
- torch.nn.init.xavier_uniform_(
- self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
-
- def forward(self, signal):
- conv_signal = self.conv(signal)
- return conv_signal
-
-
-class ConvAttention(torch.nn.Module):
-
- def __init__(
- self,
- n_mel_channels=80,
- n_text_channels=512,
- n_att_channels=80,
- temperature=1.0,
- use_query_proj=True,
- ):
- super(ConvAttention, self).__init__()
- self.temperature = temperature
- self.att_scaling_factor = np.sqrt(n_att_channels)
- self.softmax = torch.nn.Softmax(dim=3)
- self.log_softmax = torch.nn.LogSoftmax(dim=3)
- self.attn_proj = torch.nn.Conv2d(n_att_channels, 1, kernel_size=1)
- self.use_query_proj = bool(use_query_proj)
-
- self.key_proj = nn.Sequential(
- ConvNorm(
- n_text_channels,
- n_text_channels * 2,
- kernel_size=3,
- bias=True,
- w_init_gain='relu',
- ),
- torch.nn.ReLU(),
- ConvNorm(
- n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
- )
-
- self.query_proj = nn.Sequential(
- ConvNorm(
- n_mel_channels,
- n_mel_channels * 2,
- kernel_size=3,
- bias=True,
- w_init_gain='relu',
- ),
- torch.nn.ReLU(),
- ConvNorm(
- n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
- torch.nn.ReLU(),
- ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
- )
-
- def forward(self, queries, keys, mask=None, attn_prior=None):
- """Attention mechanism for flowtron parallel
- Unlike in Flowtron, we have no restrictions such as causality etc,
- since we only need this during training.
-
- Args:
- queries (torch.tensor): B x C x T1 tensor
- (probably going to be mel data)
- keys (torch.tensor): B x C2 x T2 tensor (text data)
- mask (torch.tensor): uint8 binary mask for variable length entries
- (should be in the T2 domain)
- Output:
- attn (torch.tensor): B x 1 x T1 x T2 attention mask.
- Final dim T2 should sum to 1
- """
- keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
-
- # Note: this is only valid because query_dim = attn_dim = n_mel_channels
- if self.use_query_proj:
- queries_enc = self.query_proj(queries)
- else:
- queries_enc = queries
-
- # There are different ways of computing attn;
- # this one uses isotropic Gaussians (per phoneme):
- # a simplistic Gaussian isotropic attention.
-
- # B x n_attn_dims x T1 x T2
- attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None])**2
- # compute log likelihood from a gaussian
- attn = -0.0005 * attn.sum(1, keepdim=True)
- if attn_prior is not None:
- attn = self.log_softmax(attn) + torch.log(attn_prior[:, None]
- + 1e-8)
-
- attn_logprob = attn.clone()
-
- if mask is not None:
- attn.data.masked_fill_(
- mask.unsqueeze(1).unsqueeze(1), -float('inf'))
-
- attn = self.softmax(attn) # Softmax along T2
- return attn, attn_logprob
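
The removed ConvAttention scores each (mel frame, text token) pair with a negative, scaled squared distance between projected queries and keys, optionally adds a log prior, and normalizes with a softmax over the text axis. A toy sketch of that scoring rule (distance_attention and the constants are illustrative; shapes follow the docstring above):

import torch

def distance_attention(queries_enc, keys_enc, prior=None, scale=5e-4):
    """queries_enc: B x C x T1 (mel), keys_enc: B x C x T2 (text)."""
    # B x C x T1 x T2 squared differences, summed over channels
    dist = (queries_enc[:, :, :, None] - keys_enc[:, :, None, :]) ** 2
    scores = -scale * dist.sum(1, keepdim=True)           # B x 1 x T1 x T2
    if prior is not None:                                  # prior: B x T1 x T2
        scores = torch.log_softmax(scores, dim=3) + torch.log(prior[:, None] + 1e-8)
    return torch.softmax(scores, dim=3)                    # sums to 1 over T2

q = torch.randn(2, 80, 50)   # projected mel queries
k = torch.randn(2, 80, 7)    # projected text keys
attn = distance_attention(q, k)
print(attn.shape, attn[0, 0, 0].sum())   # torch.Size([2, 1, 50, 7]), sums to ~1.0
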
diff --git a/modelscope/models/audio/tts/kantts/models/sambert/fsmn.py b/modelscope/models/audio/tts/kantts/models/sambert/fsmn.py
deleted file mode 100644
index d438537c..00000000
--- a/modelscope/models/audio/tts/kantts/models/sambert/fsmn.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class FeedForwardNet(nn.Module):
- """ A two-feed-forward-layer module """
-
- def __init__(self, d_in, d_hid, d_out, kernel_size=[1, 1], dropout=0.1):
- super().__init__()
-
- # Use Conv1D
- # position-wise
- self.w_1 = nn.Conv1d(
- d_in,
- d_hid,
- kernel_size=kernel_size[0],
- padding=(kernel_size[0] - 1) // 2,
- )
- # position-wise
- self.w_2 = nn.Conv1d(
- d_hid,
- d_out,
- kernel_size=kernel_size[1],
- padding=(kernel_size[1] - 1) // 2,
- bias=False,
- )
-
- self.dropout = nn.Dropout(dropout)
-
- def forward(self, x):
- output = x.transpose(1, 2)
- output = F.relu(self.w_1(output))
- output = self.dropout(output)
- output = self.w_2(output)
- output = output.transpose(1, 2)
-
- return output
-
-
-class MemoryBlockV2(nn.Module):
-
- def __init__(self, d, filter_size, shift, dropout=0.0):
- super(MemoryBlockV2, self).__init__()
-
- left_padding = int(round((filter_size - 1) / 2))
- right_padding = int((filter_size - 1) / 2)
- if shift > 0:
- left_padding += shift
- right_padding -= shift
-
- self.lp, self.rp = left_padding, right_padding
-
- self.conv_dw = nn.Conv1d(d, d, filter_size, 1, 0, groups=d, bias=False)
- self.dropout = nn.Dropout(dropout)
-
- def forward(self, input, mask=None):
- if mask is not None:
- input = input.masked_fill(mask.unsqueeze(-1), 0)
-
- x = F.pad(
- input, (0, 0, self.lp, self.rp, 0, 0), mode='constant', value=0.0)
- output = (
- self.conv_dw(x.contiguous().transpose(1,
- 2)).contiguous().transpose(
- 1, 2))
- output += input
- output = self.dropout(output)
-
- if mask is not None:
- output = output.masked_fill(mask.unsqueeze(-1), 0)
-
- return output
-
-
-class FsmnEncoderV2(nn.Module):
-
- def __init__(
- self,
- filter_size,
- fsmn_num_layers,
- input_dim,
- num_memory_units,
- ffn_inner_dim,
- dropout=0.0,
- shift=0,
- ):
- super(FsmnEncoderV2, self).__init__()
-
- self.filter_size = filter_size
- self.fsmn_num_layers = fsmn_num_layers
- self.num_memory_units = num_memory_units
- self.ffn_inner_dim = ffn_inner_dim
- self.dropout = dropout
- self.shift = shift
- if not isinstance(shift, list):
- self.shift = [shift for _ in range(self.fsmn_num_layers)]
-
- self.ffn_lst = nn.ModuleList()
- self.ffn_lst.append(
- FeedForwardNet(
- input_dim, ffn_inner_dim, num_memory_units, dropout=dropout))
- for i in range(1, fsmn_num_layers):
- self.ffn_lst.append(
- FeedForwardNet(
- num_memory_units,
- ffn_inner_dim,
- num_memory_units,
- dropout=dropout))
-
- self.memory_block_lst = nn.ModuleList()
- for i in range(fsmn_num_layers):
- self.memory_block_lst.append(
- MemoryBlockV2(num_memory_units, filter_size, self.shift[i],
- dropout))
-
- def forward(self, input, mask=None):
- x = F.dropout(input, self.dropout, self.training)
- for (ffn, memory_block) in zip(self.ffn_lst, self.memory_block_lst):
- context = ffn(x)
- memory = memory_block(context, mask)
- memory = F.dropout(memory, self.dropout, self.training)
- if memory.size(-1) == x.size(-1):
- memory += x
- x = memory
-
- return x
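
The removed MemoryBlockV2 is essentially a depthwise 1-D convolution over an asymmetrically padded sequence plus a residual connection; the shift parameter trades left context for lookahead. A small sketch of the padding arithmetic and the depthwise convolution (dimensions are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

d, filter_size, shift = 8, 5, 1
left = round((filter_size - 1) / 2) + shift    # 2 + 1 = 3 frames of left context
right = (filter_size - 1) // 2 - shift         # 2 - 1 = 1 frame of lookahead
conv_dw = nn.Conv1d(d, d, filter_size, stride=1, padding=0, groups=d, bias=False)

x = torch.randn(2, 10, d)                      # B x T x d
padded = F.pad(x, (0, 0, left, right))         # pad the time axis only
y = conv_dw(padded.transpose(1, 2)).transpose(1, 2) + x   # residual, still B x T x d
print(y.shape)                                 # torch.Size([2, 10, 8])
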
diff --git a/modelscope/models/audio/tts/kantts/models/sambert/kantts_sambert.py b/modelscope/models/audio/tts/kantts/models/sambert/kantts_sambert.py
deleted file mode 100644
index 46939cad..00000000
--- a/modelscope/models/audio/tts/kantts/models/sambert/kantts_sambert.py
+++ /dev/null
@@ -1,1043 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from modelscope.models.audio.tts.kantts.models.utils import \
- get_mask_from_lengths
-from . import FFTBlock, PNCABlock, Prenet
-from .adaptors import (LengthRegulator, VarFsmnRnnNARPredictor,
- VarRnnARPredictor)
-from .alignment import b_mas
-from .attention import ConvAttention
-from .fsmn import FsmnEncoderV2
-from .positions import DurSinusoidalPositionEncoder, SinusoidalPositionEncoder
-
-
-class SelfAttentionEncoder(nn.Module):
-
- def __init__(
- self,
- n_layer,
- d_in,
- d_model,
- n_head,
- d_head,
- d_inner,
- dropout,
- dropout_att,
- dropout_relu,
- position_encoder,
- ):
- super(SelfAttentionEncoder, self).__init__()
-
- self.d_in = d_in
- self.d_model = d_model
- self.dropout = dropout
- d_in_lst = [d_in] + [d_model] * (n_layer - 1)
- self.fft = nn.ModuleList([
- FFTBlock(
- d,
- d_model,
- n_head,
- d_head,
- d_inner,
- (3, 1),
- dropout,
- dropout_att,
- dropout_relu,
- ) for d in d_in_lst
- ])
- self.ln = nn.LayerNorm(d_model, eps=1e-6)
- self.position_enc = position_encoder
-
- def forward(self, input, mask=None, return_attns=False):
- input *= self.d_model**0.5
- if isinstance(self.position_enc, SinusoidalPositionEncoder):
- input = self.position_enc(input)
- else:
- raise NotImplementedError
-
- input = F.dropout(input, p=self.dropout, training=self.training)
-
- enc_slf_attn_list = []
- max_len = input.size(1)
- if mask is not None:
- slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
- else:
- slf_attn_mask = None
-
- enc_output = input
- for id, layer in enumerate(self.fft):
- enc_output, enc_slf_attn = layer(
- enc_output, mask=mask, slf_attn_mask=slf_attn_mask)
- if return_attns:
- enc_slf_attn_list += [enc_slf_attn]
-
- enc_output = self.ln(enc_output)
-
- return enc_output, enc_slf_attn_list
-
-
-class HybridAttentionDecoder(nn.Module):
-
- def __init__(
- self,
- d_in,
- prenet_units,
- n_layer,
- d_model,
- d_mem,
- n_head,
- d_head,
- d_inner,
- dropout,
- dropout_att,
- dropout_relu,
- d_out,
- ):
- super(HybridAttentionDecoder, self).__init__()
-
- self.d_model = d_model
- self.dropout = dropout
- self.prenet = Prenet(d_in, prenet_units, d_model)
- self.dec_in_proj = nn.Linear(d_model + d_mem, d_model)
- self.pnca = nn.ModuleList([
- PNCABlock(
- d_model,
- d_mem,
- n_head,
- d_head,
- d_inner,
- (1, 1),
- dropout,
- dropout_att,
- dropout_relu,
- ) for _ in range(n_layer)
- ])
- self.ln = nn.LayerNorm(d_model, eps=1e-6)
- self.dec_out_proj = nn.Linear(d_model, d_out)
-
- def reset_state(self):
- for layer in self.pnca:
- layer.reset_state()
-
- def get_pnca_attn_mask(self,
- device,
- max_len,
- x_band_width,
- h_band_width,
- mask=None):
- if mask is not None:
- pnca_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
- else:
- pnca_attn_mask = None
-
- range_ = torch.arange(max_len).to(device)
- x_start = torch.clamp_min(range_ - x_band_width, 0)[None, None, :]
- x_end = (range_ + 1)[None, None, :]
- h_start = range_[None, None, :]
- h_end = torch.clamp_max(range_ + h_band_width + 1,
- max_len + 1)[None, None, :]
-
- pnca_x_attn_mask = ~((x_start <= range_[None, :, None])
- & # noqa W504
- (x_end > range_[None, :, None])).transpose(1, 2)
- pnca_h_attn_mask = ~((h_start <= range_[None, :, None])
- & # noqa W504
- (h_end > range_[None, :, None])).transpose(1, 2)
-
- if pnca_attn_mask is not None:
- pnca_x_attn_mask = pnca_x_attn_mask | pnca_attn_mask
- pnca_h_attn_mask = pnca_h_attn_mask | pnca_attn_mask
- pnca_x_attn_mask = pnca_x_attn_mask.masked_fill(
- pnca_attn_mask.transpose(1, 2), False)
- pnca_h_attn_mask = pnca_h_attn_mask.masked_fill(
- pnca_attn_mask.transpose(1, 2), False)
-
- return pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask
-
- # must call reset_state before
- def forward(self,
- input,
- memory,
- x_band_width,
- h_band_width,
- mask=None,
- return_attns=False):
- input = self.prenet(input)
- input = torch.cat([memory, input], dim=-1)
- input = self.dec_in_proj(input)
-
- if mask is not None:
- input = input.masked_fill(mask.unsqueeze(-1), 0)
-
- input *= self.d_model**0.5
- input = F.dropout(input, p=self.dropout, training=self.training)
-
- max_len = input.size(1)
- pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask(
- input.device, max_len, x_band_width, h_band_width, mask)
-
- dec_pnca_attn_x_list = []
- dec_pnca_attn_h_list = []
- dec_output = input
- for id, layer in enumerate(self.pnca):
- dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer(
- dec_output,
- memory,
- mask=mask,
- pnca_x_attn_mask=pnca_x_attn_mask,
- pnca_h_attn_mask=pnca_h_attn_mask,
- )
- if return_attns:
- dec_pnca_attn_x_list += [dec_pnca_attn_x]
- dec_pnca_attn_h_list += [dec_pnca_attn_h]
-
- dec_output = self.ln(dec_output)
- dec_output = self.dec_out_proj(dec_output)
-
- return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
-
- # must call reset_state before when step == 0
- def infer(
- self,
- step,
- input,
- memory,
- x_band_width,
- h_band_width,
- mask=None,
- return_attns=False,
- ):
- max_len = memory.size(1)
-
- input = self.prenet(input)
- input = torch.cat([memory[:, step:step + 1, :], input], dim=-1)
- input = self.dec_in_proj(input)
-
- input *= self.d_model**0.5
- input = F.dropout(input, p=self.dropout, training=self.training)
-
- pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask(
- input.device, max_len, x_band_width, h_band_width, mask)
-
- dec_pnca_attn_x_list = []
- dec_pnca_attn_h_list = []
- dec_output = input
- for id, layer in enumerate(self.pnca):
- if mask is not None:
- mask_step = mask[:, step:step + 1]
- else:
- mask_step = None
- dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer(
- dec_output,
- memory,
- mask=mask_step,
- pnca_x_attn_mask=pnca_x_attn_mask[:,
- step:step + 1, :(step + 1)],
- pnca_h_attn_mask=pnca_h_attn_mask[:, step:step + 1, :],
- )
- if return_attns:
- dec_pnca_attn_x_list += [dec_pnca_attn_x]
- dec_pnca_attn_h_list += [dec_pnca_attn_h]
-
- dec_output = self.ln(dec_output)
- dec_output = self.dec_out_proj(dec_output)
-
- return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
-
-
-class TextFftEncoder(nn.Module):
-
- def __init__(self, config):
- super(TextFftEncoder, self).__init__()
-
- d_emb = config['embedding_dim']
- self.using_byte = False
- if config.get('using_byte', False):
- self.using_byte = True
- nb_ling_byte_index = config['byte_index']
- self.byte_index_emb = nn.Embedding(nb_ling_byte_index, d_emb)
- else:
- # linguistic unit lookup table
- nb_ling_sy = config['sy']
- nb_ling_tone = config['tone']
- nb_ling_syllable_flag = config['syllable_flag']
- nb_ling_ws = config['word_segment']
- self.sy_emb = nn.Embedding(nb_ling_sy, d_emb)
- self.tone_emb = nn.Embedding(nb_ling_tone, d_emb)
- self.syllable_flag_emb = nn.Embedding(nb_ling_syllable_flag, d_emb)
- self.ws_emb = nn.Embedding(nb_ling_ws, d_emb)
-
- max_len = config['max_len']
-
- nb_layers = config['encoder_num_layers']
- nb_heads = config['encoder_num_heads']
- d_model = config['encoder_num_units']
- d_head = d_model // nb_heads
- d_inner = config['encoder_ffn_inner_dim']
- dropout = config['encoder_dropout']
- dropout_attn = config['encoder_attention_dropout']
- dropout_relu = config['encoder_relu_dropout']
- d_proj = config['encoder_projection_units']
-
- self.d_model = d_model
-
- position_enc = SinusoidalPositionEncoder(max_len, d_emb)
-
- self.ling_enc = SelfAttentionEncoder(
- nb_layers,
- d_emb,
- d_model,
- nb_heads,
- d_head,
- d_inner,
- dropout,
- dropout_attn,
- dropout_relu,
- position_enc,
- )
-
- self.ling_proj = nn.Linear(d_model, d_proj, bias=False)
-
- def forward(self, inputs_ling, masks=None, return_attns=False):
- # Parse inputs_ling_seq
- if self.using_byte:
- inputs_byte_index = inputs_ling[:, :, 0]
- byte_index_embedding = self.byte_index_emb(inputs_byte_index)
- ling_embedding = byte_index_embedding
- else:
- inputs_sy = inputs_ling[:, :, 0]
- inputs_tone = inputs_ling[:, :, 1]
- inputs_syllable_flag = inputs_ling[:, :, 2]
- inputs_ws = inputs_ling[:, :, 3]
-
- # Lookup table
- sy_embedding = self.sy_emb(inputs_sy)
- tone_embedding = self.tone_emb(inputs_tone)
- syllable_flag_embedding = self.syllable_flag_emb(
- inputs_syllable_flag)
- ws_embedding = self.ws_emb(inputs_ws)
-
- ling_embedding = (
- sy_embedding + tone_embedding + syllable_flag_embedding
- + ws_embedding)
-
- enc_output, enc_slf_attn_list = self.ling_enc(ling_embedding, masks,
- return_attns)
-
- if hasattr(self, 'ling_proj'):
- enc_output = self.ling_proj(enc_output)
-
- return enc_output, enc_slf_attn_list, ling_embedding
-
-
-class VarianceAdaptor(nn.Module):
-
- def __init__(self, config):
- super(VarianceAdaptor, self).__init__()
-
- input_dim = (
- config['encoder_projection_units'] + config['emotion_units']
- + config['speaker_units'])
- filter_size = config['predictor_filter_size']
- fsmn_num_layers = config['predictor_fsmn_num_layers']
- num_memory_units = config['predictor_num_memory_units']
- ffn_inner_dim = config['predictor_ffn_inner_dim']
- dropout = config['predictor_dropout']
- shift = config['predictor_shift']
- lstm_units = config['predictor_lstm_units']
-
- dur_pred_prenet_units = config['dur_pred_prenet_units']
- dur_pred_lstm_units = config['dur_pred_lstm_units']
-
- self.pitch_predictor = VarFsmnRnnNARPredictor(
- input_dim,
- filter_size,
- fsmn_num_layers,
- num_memory_units,
- ffn_inner_dim,
- dropout,
- shift,
- lstm_units,
- )
- self.energy_predictor = VarFsmnRnnNARPredictor(
- input_dim,
- filter_size,
- fsmn_num_layers,
- num_memory_units,
- ffn_inner_dim,
- dropout,
- shift,
- lstm_units,
- )
- self.duration_predictor = VarRnnARPredictor(input_dim,
- dur_pred_prenet_units,
- dur_pred_lstm_units)
-
- self.length_regulator = LengthRegulator(config['outputs_per_step'])
- self.dur_position_encoder = DurSinusoidalPositionEncoder(
- config['encoder_projection_units'], config['outputs_per_step'])
-
- self.pitch_emb = nn.Conv1d(
- 1, config['encoder_projection_units'], kernel_size=9, padding=4)
- self.energy_emb = nn.Conv1d(
- 1, config['encoder_projection_units'], kernel_size=9, padding=4)
-
- def forward(
- self,
- inputs_text_embedding,
- inputs_emo_embedding,
- inputs_spk_embedding,
- masks=None,
- output_masks=None,
- duration_targets=None,
- pitch_targets=None,
- energy_targets=None,
- ):
-
- batch_size = inputs_text_embedding.size(0)
-
- variance_predictor_inputs = torch.cat([
- inputs_text_embedding, inputs_spk_embedding, inputs_emo_embedding
- ],
- dim=-1) # noqa
-
- pitch_predictions = self.pitch_predictor(variance_predictor_inputs,
- masks)
- energy_predictions = self.energy_predictor(variance_predictor_inputs,
- masks)
-
- if pitch_targets is not None:
- pitch_embeddings = self.pitch_emb(
- pitch_targets.unsqueeze(1)).transpose(1, 2)
- else:
- pitch_embeddings = self.pitch_emb(
- pitch_predictions.unsqueeze(1)).transpose(1, 2)
-
- if energy_targets is not None:
- energy_embeddings = self.energy_emb(
- energy_targets.unsqueeze(1)).transpose(1, 2)
- else:
- energy_embeddings = self.energy_emb(
- energy_predictions.unsqueeze(1)).transpose(1, 2)
-
- inputs_text_embedding_aug = (
- inputs_text_embedding + pitch_embeddings + energy_embeddings)
- duration_predictor_cond = torch.cat(
- [
- inputs_text_embedding_aug, inputs_spk_embedding,
- inputs_emo_embedding
- ],
- dim=-1,
- )
- if duration_targets is not None:
- duration_predictor_go_frame = torch.zeros(batch_size, 1).to(
- inputs_text_embedding.device)
- duration_predictor_input = torch.cat([
- duration_predictor_go_frame, duration_targets[:, :-1].float()
- ],
- dim=-1) # noqa
- duration_predictor_input = torch.log(duration_predictor_input + 1)
- log_duration_predictions, _ = self.duration_predictor(
- duration_predictor_input.unsqueeze(-1),
- duration_predictor_cond,
- masks=masks,
- )
- duration_predictions = torch.exp(log_duration_predictions) - 1
- else:
- log_duration_predictions = self.duration_predictor.infer(
- duration_predictor_cond, masks=masks)
- duration_predictions = torch.exp(log_duration_predictions) - 1
-
- if duration_targets is not None:
- LR_text_outputs, LR_length_rounded = self.length_regulator(
- inputs_text_embedding_aug,
- duration_targets,
- masks=output_masks)
- LR_position_embeddings = self.dur_position_encoder(
- duration_targets, masks=output_masks)
- LR_emo_outputs, _ = self.length_regulator(
- inputs_emo_embedding, duration_targets, masks=output_masks)
- LR_spk_outputs, _ = self.length_regulator(
- inputs_spk_embedding, duration_targets, masks=output_masks)
-
- else:
- LR_text_outputs, LR_length_rounded = self.length_regulator(
- inputs_text_embedding_aug,
- duration_predictions,
- masks=output_masks)
- LR_position_embeddings = self.dur_position_encoder(
- duration_predictions, masks=output_masks)
- LR_emo_outputs, _ = self.length_regulator(
- inputs_emo_embedding, duration_predictions, masks=output_masks)
- LR_spk_outputs, _ = self.length_regulator(
- inputs_spk_embedding, duration_predictions, masks=output_masks)
-
- LR_text_outputs = LR_text_outputs + LR_position_embeddings
-
- return (
- LR_text_outputs,
- LR_emo_outputs,
- LR_spk_outputs,
- LR_length_rounded,
- log_duration_predictions,
- pitch_predictions,
- energy_predictions,
- )
-
-
-class MelPNCADecoder(nn.Module):
-
- def __init__(self, config):
- super(MelPNCADecoder, self).__init__()
-
- prenet_units = config['decoder_prenet_units']
- nb_layers = config['decoder_num_layers']
- nb_heads = config['decoder_num_heads']
- d_model = config['decoder_num_units']
- d_head = d_model // nb_heads
- d_inner = config['decoder_ffn_inner_dim']
- dropout = config['decoder_dropout']
- dropout_attn = config['decoder_attention_dropout']
- dropout_relu = config['decoder_relu_dropout']
- outputs_per_step = config['outputs_per_step']
-
- d_mem = (
- config['encoder_projection_units'] * outputs_per_step
- + config['emotion_units'] + config['speaker_units'])
- d_mel = config['num_mels']
-
- self.d_mel = d_mel
- self.r = outputs_per_step
- self.nb_layers = nb_layers
-
- self.mel_dec = HybridAttentionDecoder(
- d_mel,
- prenet_units,
- nb_layers,
- d_model,
- d_mem,
- nb_heads,
- d_head,
- d_inner,
- dropout,
- dropout_attn,
- dropout_relu,
- d_mel * outputs_per_step,
- )
-
- def forward(
- self,
- memory,
- x_band_width,
- h_band_width,
- target=None,
- mask=None,
- return_attns=False,
- ):
- batch_size = memory.size(0)
- go_frame = torch.zeros((batch_size, 1, self.d_mel)).to(memory.device)
-
- if target is not None:
- self.mel_dec.reset_state()
- input = target[:, self.r - 1::self.r, :]
- input = torch.cat([go_frame, input], dim=1)[:, :-1, :]
- dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list = self.mel_dec(
- input,
- memory,
- x_band_width,
- h_band_width,
- mask=mask,
- return_attns=return_attns,
- )
-
- else:
- dec_output = []
- dec_pnca_attn_x_list = [[] for _ in range(self.nb_layers)]
- dec_pnca_attn_h_list = [[] for _ in range(self.nb_layers)]
- self.mel_dec.reset_state()
- input = go_frame
- for step in range(memory.size(1)):
- (
- dec_output_step,
- dec_pnca_attn_x_step,
- dec_pnca_attn_h_step,
- ) = self.mel_dec.infer(
- step,
- input,
- memory,
- x_band_width,
- h_band_width,
- mask=mask,
- return_attns=return_attns,
- )
- input = dec_output_step[:, :, -self.d_mel:]
-
- dec_output.append(dec_output_step)
- for layer_id, (pnca_x_attn, pnca_h_attn) in enumerate(
- zip(dec_pnca_attn_x_step, dec_pnca_attn_h_step)):
- left = memory.size(1) - pnca_x_attn.size(-1)
- if left > 0:
- padding = torch.zeros(
- (pnca_x_attn.size(0), 1, left)).to(pnca_x_attn)
- pnca_x_attn = torch.cat([pnca_x_attn, padding], dim=-1)
- dec_pnca_attn_x_list[layer_id].append(pnca_x_attn)
- dec_pnca_attn_h_list[layer_id].append(pnca_h_attn)
- dec_output = torch.cat(dec_output, dim=1)
- for layer_id in range(self.nb_layers):
- dec_pnca_attn_x_list[layer_id] = torch.cat(
- dec_pnca_attn_x_list[layer_id], dim=1)
- dec_pnca_attn_h_list[layer_id] = torch.cat(
- dec_pnca_attn_h_list[layer_id], dim=1)
-
- return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
-
-
-class PostNet(nn.Module):
-
- def __init__(self, config):
- super(PostNet, self).__init__()
-
- self.filter_size = config['postnet_filter_size']
- self.fsmn_num_layers = config['postnet_fsmn_num_layers']
- self.num_memory_units = config['postnet_num_memory_units']
- self.ffn_inner_dim = config['postnet_ffn_inner_dim']
- self.dropout = config['postnet_dropout']
- self.shift = config['postnet_shift']
- self.lstm_units = config['postnet_lstm_units']
- self.num_mels = config['num_mels']
-
- self.fsmn = FsmnEncoderV2(
- self.filter_size,
- self.fsmn_num_layers,
- self.num_mels,
- self.num_memory_units,
- self.ffn_inner_dim,
- self.dropout,
- self.shift,
- )
- self.lstm = nn.LSTM(
- self.num_memory_units,
- self.lstm_units,
- num_layers=1,
- batch_first=True)
- self.fc = nn.Linear(self.lstm_units, self.num_mels)
-
- def forward(self, x, mask=None):
- postnet_fsmn_output = self.fsmn(x, mask)
- # The input could also be packed as a variable-length sequence; packing is
- # omitted for simplicity, since the mask and the uni-directional LSTM keep
- # padded frames from affecting the valid outputs.
- postnet_lstm_output, _ = self.lstm(postnet_fsmn_output)
- mel_residual_output = self.fc(postnet_lstm_output)
-
- return mel_residual_output
-
-
-def average_frame_feat(pitch, durs):
- durs_cums_ends = torch.cumsum(durs, dim=1).long()
- durs_cums_starts = F.pad(durs_cums_ends[:, :-1], (1, 0))
- pitch_nonzero_cums = F.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0))
- pitch_cums = F.pad(torch.cumsum(pitch, dim=2), (1, 0))
-
- bs, lengths = durs_cums_ends.size()
- n_formants = pitch.size(1)
- dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, lengths)
- dce = durs_cums_ends[:, None, :].expand(bs, n_formants, lengths)
-
- pitch_sums = (torch.gather(pitch_cums, 2, dce)
- - torch.gather(pitch_cums, 2, dcs)).float()
- pitch_nelems = (torch.gather(pitch_nonzero_cums, 2, dce)
- - torch.gather(pitch_nonzero_cums, 2, dcs)).float()
-
- pitch_avg = torch.where(pitch_nelems == 0.0, pitch_nelems,
- pitch_sums / pitch_nelems)
- return pitch_avg
-
-
-class FP_Predictor(nn.Module):
-
- def __init__(self, config):
- super(FP_Predictor, self).__init__()
-
- self.w_1 = nn.Conv1d(
- config['encoder_projection_units'],
- config['embedding_dim'] // 2,
- kernel_size=3,
- padding=1,
- )
- self.w_2 = nn.Conv1d(
- config['embedding_dim'] // 2,
- config['encoder_projection_units'],
- kernel_size=1,
- padding=0,
- )
- self.layer_norm1 = nn.LayerNorm(config['embedding_dim'] // 2, eps=1e-6)
- self.layer_norm2 = nn.LayerNorm(
- config['encoder_projection_units'], eps=1e-6)
- self.dropout_inner = nn.Dropout(0.1)
- self.dropout = nn.Dropout(0.1)
- self.fc = nn.Linear(config['encoder_projection_units'], 4)
-
- def forward(self, x):
- x = x.transpose(1, 2)
- x = F.relu(self.w_1(x))
- x = x.transpose(1, 2)
- x = self.dropout_inner(self.layer_norm1(x))
- x = x.transpose(1, 2)
- x = F.relu(self.w_2(x))
- x = x.transpose(1, 2)
- x = self.dropout(self.layer_norm2(x))
- output = F.softmax(self.fc(x), dim=2)
- return output
-
-
-class KanTtsSAMBERT(nn.Module):
-
- def __init__(self, config):
- super(KanTtsSAMBERT, self).__init__()
-
- self.text_encoder = TextFftEncoder(config)
- self.spk_tokenizer = nn.Embedding(config['speaker'],
- config['speaker_units'])
- self.emo_tokenizer = nn.Embedding(config['emotion'],
- config['emotion_units'])
- self.variance_adaptor = VarianceAdaptor(config)
- self.mel_decoder = MelPNCADecoder(config)
- self.mel_postnet = PostNet(config)
- self.MAS = False
- if config.get('MAS', False):
- self.MAS = True
- self.align_attention = ConvAttention(
- n_mel_channels=config['num_mels'],
- n_text_channels=config['embedding_dim'],
- n_att_channels=config['num_mels'],
- )
- self.fp_enable = config.get('FP', False)
- if self.fp_enable:
- self.FP_predictor = FP_Predictor(config)
-
- def get_lfr_mask_from_lengths(self, lengths, max_len):
- batch_size = lengths.size(0)
- # padding according to the outputs_per_step
- padded_lr_lengths = torch.zeros_like(lengths)
- for i in range(batch_size):
- len_item = int(lengths[i].item())
- padding = self.mel_decoder.r - len_item % self.mel_decoder.r
- if padding < self.mel_decoder.r:
- padded_lr_lengths[i] = (len_item
- + padding) // self.mel_decoder.r
- else:
- padded_lr_lengths[i] = len_item // self.mel_decoder.r
-
- return get_mask_from_lengths(
- padded_lr_lengths, max_len=max_len // self.mel_decoder.r)
-
- def binarize_attention_parallel(self, attn, in_lens, out_lens):
- """For training purposes only. Binarizes attention with MAS.
- These will no longer receive a gradient.
-
- Args:
- attn: B x 1 x max_mel_len x max_text_len
- """
- with torch.no_grad():
- attn_cpu = attn.data.cpu().numpy()
- attn_out = b_mas(
- attn_cpu,
- in_lens.cpu().numpy(),
- out_lens.cpu().numpy(),
- width=1)
- return torch.from_numpy(attn_out).to(attn.get_device())
-
- def insert_fp(
- self,
- text_hid,
- FP_p,
- fp_label,
- fp_dict,
- inputs_emotion,
- inputs_speaker,
- input_lengths,
- input_masks,
- ):
-
- en, _, _ = self.text_encoder(fp_dict[1], return_attns=True)
- a, _, _ = self.text_encoder(fp_dict[2], return_attns=True)
- e, _, _ = self.text_encoder(fp_dict[3], return_attns=True)
-
- en = en.squeeze()
- a = a.squeeze()
- e = e.squeeze()
-
- max_len_ori = max(input_lengths)
- if fp_label is None:
- input_masks_r = ~input_masks
- fp_mask = (FP_p == FP_p.max(dim=2,
- keepdim=True)[0]).to(dtype=torch.int32)
- fp_mask = fp_mask[:, :, 1:] * input_masks_r.unsqueeze(2).expand(
- -1, -1, 3)
- fp_number = torch.sum(torch.sum(fp_mask, dim=2), dim=1)
- else:
- fp_number = torch.sum((fp_label > 0), dim=1)
- inter_lengths = input_lengths + 3 * fp_number
- max_len = max(inter_lengths)
-
- delta = max_len - max_len_ori
- if delta > 0:
- if delta > text_hid.shape[1]:
- nrepeat = delta // text_hid.shape[1]
- bias = delta % text_hid.shape[1]
- text_hid = torch.cat((text_hid, text_hid.repeat(
- 1, nrepeat, 1), text_hid[:, :bias, :]), 1)
- inputs_emotion = torch.cat(
- (
- inputs_emotion,
- inputs_emotion.repeat(1, nrepeat),
- inputs_emotion[:, :bias],
- ),
- 1,
- )
- inputs_speaker = torch.cat(
- (
- inputs_speaker,
- inputs_speaker.repeat(1, nrepeat),
- inputs_speaker[:, :bias],
- ),
- 1,
- )
- else:
- text_hid = torch.cat((text_hid, text_hid[:, :delta, :]), 1)
- inputs_emotion = torch.cat(
- (inputs_emotion, inputs_emotion[:, :delta]), 1)
- inputs_speaker = torch.cat(
- (inputs_speaker, inputs_speaker[:, :delta]), 1)
-
- if fp_label is None:
- for i in range(fp_mask.shape[0]):
- for j in range(fp_mask.shape[1] - 1, -1, -1):
- if fp_mask[i][j][0] == 1:
- text_hid[i] = torch.cat(
- (text_hid[i][:j], en, text_hid[i][j:-3]), 0)
- elif fp_mask[i][j][1] == 1:
- text_hid[i] = torch.cat(
- (text_hid[i][:j], a, text_hid[i][j:-3]), 0)
- elif fp_mask[i][j][2] == 1:
- text_hid[i] = torch.cat(
- (text_hid[i][:j], e, text_hid[i][j:-3]), 0)
- else:
- for i in range(fp_label.shape[0]):
- for j in range(fp_label.shape[1] - 1, -1, -1):
- if fp_label[i][j] == 1:
- text_hid[i] = torch.cat(
- (text_hid[i][:j], en, text_hid[i][j:-3]), 0)
- elif fp_label[i][j] == 2:
- text_hid[i] = torch.cat(
- (text_hid[i][:j], a, text_hid[i][j:-3]), 0)
- elif fp_label[i][j] == 3:
- text_hid[i] = torch.cat(
- (text_hid[i][:j], e, text_hid[i][j:-3]), 0)
- return text_hid, inputs_emotion, inputs_speaker, inter_lengths
-
- def forward(
- self,
- inputs_ling,
- inputs_emotion,
- inputs_speaker,
- input_lengths,
- output_lengths=None,
- mel_targets=None,
- duration_targets=None,
- pitch_targets=None,
- energy_targets=None,
- attn_priors=None,
- fp_label=None,
- ):
- batch_size = inputs_ling.size(0)
-
- is_training = mel_targets is not None
- input_masks = get_mask_from_lengths(
- input_lengths, max_len=inputs_ling.size(1))
-
- text_hid, enc_sla_attn_lst, ling_embedding = self.text_encoder(
- inputs_ling, input_masks, return_attns=True)
-
- inter_lengths = input_lengths
- FP_p = None
- if self.fp_enable:
- FP_p = self.FP_predictor(text_hid)
- fp_dict = self.fp_dict
- text_hid, inputs_emotion, inputs_speaker, inter_lengths = self.insert_fp(
- text_hid,
- FP_p,
- fp_label,
- fp_dict,
- inputs_emotion,
- inputs_speaker,
- input_lengths,
- input_masks,
- )
-
- # Monotonic-Alignment-Search
- if self.MAS and is_training:
- attn_soft, attn_logprob = self.align_attention(
- mel_targets.permute(0, 2, 1),
- ling_embedding.permute(0, 2, 1),
- input_masks,
- attn_priors,
- )
- attn_hard = self.binarize_attention_parallel(
- attn_soft, input_lengths, output_lengths)
- attn_hard_dur = attn_hard.sum(2)[:, 0, :]
- duration_targets = attn_hard_dur
- assert torch.all(
- torch.eq(duration_targets.sum(dim=1), output_lengths))
- pitch_targets = average_frame_feat(
- pitch_targets.unsqueeze(1), duration_targets).squeeze(1)
- energy_targets = average_frame_feat(
- energy_targets.unsqueeze(1), duration_targets).squeeze(1)
- # Pad the duration targets so each sequence sums to the max rounded output length
- for i in range(batch_size):
- len_item = int(output_lengths[i].item())
- padding = mel_targets.size(1) - len_item
- duration_targets[i, input_lengths[i]] = padding
-
- emo_hid = self.emo_tokenizer(inputs_emotion)
- spk_hid = self.spk_tokenizer(inputs_speaker)
-
- inter_masks = get_mask_from_lengths(
- inter_lengths, max_len=text_hid.size(1))
-
- if output_lengths is not None:
- output_masks = get_mask_from_lengths(
- output_lengths, max_len=mel_targets.size(1))
- else:
- output_masks = None
-
- (
- LR_text_outputs,
- LR_emo_outputs,
- LR_spk_outputs,
- LR_length_rounded,
- log_duration_predictions,
- pitch_predictions,
- energy_predictions,
- ) = self.variance_adaptor(
- text_hid,
- emo_hid,
- spk_hid,
- masks=inter_masks,
- output_masks=output_masks,
- duration_targets=duration_targets,
- pitch_targets=pitch_targets,
- energy_targets=energy_targets,
- )
-
- if output_lengths is not None:
- lfr_masks = self.get_lfr_mask_from_lengths(
- output_lengths, max_len=LR_text_outputs.size(1))
- else:
- output_masks = get_mask_from_lengths(
- LR_length_rounded, max_len=LR_text_outputs.size(1))
- lfr_masks = None
-
- # LFR with the factor of outputs_per_step
- LFR_text_inputs = LR_text_outputs.contiguous().view(
- batch_size, -1, self.mel_decoder.r * text_hid.shape[-1])
- LFR_emo_inputs = LR_emo_outputs.contiguous().view(
- batch_size, -1,
- self.mel_decoder.r * emo_hid.shape[-1])[:, :, :emo_hid.shape[-1]]
- LFR_spk_inputs = LR_spk_outputs.contiguous().view(
- batch_size, -1,
- self.mel_decoder.r * spk_hid.shape[-1])[:, :, :spk_hid.shape[-1]]
-
- memory = torch.cat([LFR_text_inputs, LFR_spk_inputs, LFR_emo_inputs],
- dim=-1)
-
- if duration_targets is not None:
- x_band_width = int(
- duration_targets.float().masked_fill(inter_masks, 0).max()
- / self.mel_decoder.r + 0.5)
- h_band_width = x_band_width
- else:
- x_band_width = int((torch.exp(log_duration_predictions) - 1).max()
- / self.mel_decoder.r + 0.5)
- h_band_width = x_band_width
-
- dec_outputs, pnca_x_attn_lst, pnca_h_attn_lst = self.mel_decoder(
- memory,
- x_band_width,
- h_band_width,
- target=mel_targets,
- mask=lfr_masks,
- return_attns=True,
- )
-
- # De-LFR with the factor of outputs_per_step
- dec_outputs = dec_outputs.contiguous().view(batch_size, -1,
- self.mel_decoder.d_mel)
-
- if output_masks is not None:
- dec_outputs = dec_outputs.masked_fill(
- output_masks.unsqueeze(-1), 0)
-
- postnet_outputs = self.mel_postnet(dec_outputs,
- output_masks) + dec_outputs
- if output_masks is not None:
- postnet_outputs = postnet_outputs.masked_fill(
- output_masks.unsqueeze(-1), 0)
-
- res = {
- 'x_band_width': x_band_width,
- 'h_band_width': h_band_width,
- 'enc_slf_attn_lst': enc_sla_attn_lst,
- 'pnca_x_attn_lst': pnca_x_attn_lst,
- 'pnca_h_attn_lst': pnca_h_attn_lst,
- 'dec_outputs': dec_outputs,
- 'postnet_outputs': postnet_outputs,
- 'LR_length_rounded': LR_length_rounded,
- 'log_duration_predictions': log_duration_predictions,
- 'pitch_predictions': pitch_predictions,
- 'energy_predictions': energy_predictions,
- 'duration_targets': duration_targets,
- 'pitch_targets': pitch_targets,
- 'energy_targets': energy_targets,
- 'fp_predictions': FP_p,
- 'valid_inter_lengths': inter_lengths,
- }
-
- res['LR_text_outputs'] = LR_text_outputs
- res['LR_emo_outputs'] = LR_emo_outputs
- res['LR_spk_outputs'] = LR_spk_outputs
-
- if self.MAS and is_training:
- res['attn_soft'] = attn_soft
- res['attn_hard'] = attn_hard
- res['attn_logprob'] = attn_logprob
-
- return res
-
-
-class KanTtsTextsyBERT(nn.Module):
-
- def __init__(self, config):
- super(KanTtsTextsyBERT, self).__init__()
-
- self.text_encoder = TextFftEncoder(config)
- delattr(self.text_encoder, 'ling_proj')
- self.fc = nn.Linear(self.text_encoder.d_model, config['sy'])
-
- def forward(self, inputs_ling, input_lengths):
- res = {}
-
- input_masks = get_mask_from_lengths(
- input_lengths, max_len=inputs_ling.size(1))
-
- text_hid, enc_sla_attn_lst = self.text_encoder(
- inputs_ling, input_masks, return_attns=True)
- logits = self.fc(text_hid)
-
- res['logits'] = logits
- res['enc_slf_attn_lst'] = enc_sla_attn_lst
-
- return res
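
Within the removed KanTtsSAMBERT, the length-regulated features are grouped into low-frame-rate (LFR) steps of outputs_per_step (r) consecutive frames before the PNCA decoder, and the decoder output is reshaped back to per-frame mels afterwards. A minimal sketch of those two reshapes (sizes and names are illustrative; T is assumed to already be a multiple of r, which the model ensures by padding):

import torch

r, d_text, d_mel = 3, 16, 80
B, T = 2, 12

lr_text = torch.randn(B, T, d_text)            # length-regulated text features
# LFR: fold r frames into the channel axis -> one decoder step per r frames
lfr_text = lr_text.contiguous().view(B, T // r, r * d_text)

# De-LFR: the decoder predicts r * d_mel values per step; unfold back to per-frame mels
dec_steps = torch.randn(B, T // r, r * d_mel)
frames = dec_steps.contiguous().view(B, -1, d_mel)
print(lfr_text.shape, frames.shape)            # torch.Size([2, 4, 48]) torch.Size([2, 12, 80])
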
diff --git a/modelscope/models/audio/tts/kantts/models/sambert/positions.py b/modelscope/models/audio/tts/kantts/models/sambert/positions.py
deleted file mode 100644
index a055f2c1..00000000
--- a/modelscope/models/audio/tts/kantts/models/sambert/positions.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class SinusoidalPositionEncoder(nn.Module):
-
- def __init__(self, max_len, depth):
- super(SinusoidalPositionEncoder, self).__init__()
-
- self.max_len = max_len
- self.depth = depth
- self.position_enc = nn.Parameter(
- self.get_sinusoid_encoding_table(max_len, depth).unsqueeze(0),
- requires_grad=False,
- )
-
- def forward(self, input):
- bz_in, len_in, _ = input.size()
- if len_in > self.max_len:
- self.max_len = len_in
- self.position_enc.data = (
- self.get_sinusoid_encoding_table(
- self.max_len, self.depth).unsqueeze(0).to(input.device))
-
- output = input + self.position_enc[:, :len_in, :].expand(bz_in, -1, -1)
-
- return output
-
- @staticmethod
- def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
- """ Sinusoid position encoding table """
-
- def cal_angle(position, hid_idx):
- return position / np.power(10000, hid_idx / float(d_hid / 2 - 1))
-
- def get_posi_angle_vec(position):
- return [cal_angle(position, hid_j) for hid_j in range(d_hid // 2)]
-
- scaled_time_table = np.array(
- [get_posi_angle_vec(pos_i + 1) for pos_i in range(n_position)])
-
- sinusoid_table = np.zeros((n_position, d_hid))
- sinusoid_table[:, :d_hid // 2] = np.sin(scaled_time_table)
- sinusoid_table[:, d_hid // 2:] = np.cos(scaled_time_table)
-
- if padding_idx is not None:
- # zero vector for padding dimension
- sinusoid_table[padding_idx] = 0.0
-
- return torch.FloatTensor(sinusoid_table)
-
-
-class DurSinusoidalPositionEncoder(nn.Module):
-
- def __init__(self, depth, outputs_per_step):
- super(DurSinusoidalPositionEncoder, self).__init__()
-
- self.depth = depth
- self.outputs_per_step = outputs_per_step
-
- inv_timescales = [
- np.power(10000, 2 * (hid_idx // 2) / depth)
- for hid_idx in range(depth)
- ]
- self.inv_timescales = nn.Parameter(
- torch.FloatTensor(inv_timescales), requires_grad=False)
-
- def forward(self, durations, masks=None):
- reps = (durations + 0.5).long()
- output_lens = reps.sum(dim=1)
- max_len = output_lens.max()
- reps_cumsum = torch.cumsum(
- F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
- range_ = torch.arange(max_len).to(durations.device)[None, :, None]
- mult = (reps_cumsum[:, :, :-1] <= range_) & (
- reps_cumsum[:, :, 1:] > range_)
- mult = mult.float()
- offsets = torch.matmul(mult,
- reps_cumsum[:,
- 0, :-1].unsqueeze(-1)).squeeze(-1)
- dur_pos = range_[:, :, 0] - offsets + 1
-
- if masks is not None:
- assert masks.size(1) == dur_pos.size(1)
- dur_pos = dur_pos.masked_fill(masks, 0.0)
-
- seq_len = dur_pos.size(1)
- padding = self.outputs_per_step - int(seq_len) % self.outputs_per_step
- if padding < self.outputs_per_step:
- dur_pos = F.pad(dur_pos, (0, padding, 0, 0), value=0.0)
-
- position_embedding = dur_pos[:, :, None] / self.inv_timescales[None,
- None, :]
- position_embedding[:, :, 0::2] = torch.sin(position_embedding[:, :,
- 0::2])
- position_embedding[:, :, 1::2] = torch.cos(position_embedding[:, :,
- 1::2])
-
- return position_embedding
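
The removed SinusoidalPositionEncoder fills the first half of the channels with sines and the second half with cosines of position-dependent angles, with positions starting at 1. A compact NumPy sketch of the same table (sinusoid_table is an illustrative restatement):

import numpy as np

def sinusoid_table(n_position, d_hid):
    pos = np.arange(1, n_position + 1)[:, None]                # 1-based positions
    idx = np.arange(d_hid // 2)[None, :]
    angles = pos / np.power(10000, idx / (d_hid / 2 - 1))      # (n_position, d_hid // 2)
    table = np.zeros((n_position, d_hid))
    table[:, :d_hid // 2] = np.sin(angles)                     # sines in the first half
    table[:, d_hid // 2:] = np.cos(angles)                     # cosines in the second half
    return table

print(sinusoid_table(4, 8).shape)   # (4, 8)
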
diff --git a/modelscope/models/audio/tts/kantts/models/utils.py b/modelscope/models/audio/tts/kantts/models/utils.py
deleted file mode 100644
index e75e5f4f..00000000
--- a/modelscope/models/audio/tts/kantts/models/utils.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-from distutils.version import LooseVersion
-
-import torch
-
-is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
-
-
-def init_weights(m, mean=0.0, std=0.01):
- classname = m.__class__.__name__
- if classname.find('Conv') != -1:
- m.weight.data.normal_(mean, std)
-
-
-def get_mask_from_lengths(lengths, max_len=None):
- batch_size = lengths.shape[0]
- if max_len is None:
- max_len = torch.max(lengths).item()
-
- ids = (
- torch.arange(0, max_len).unsqueeze(0).expand(batch_size,
- -1).to(lengths.device))
- mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
-
- return mask
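
The removed get_mask_from_lengths returns True at padded positions, which is why callers pair it with masked_fill. A short usage sketch (the helper is restated here for illustration):

import torch

def get_mask_from_lengths(lengths, max_len=None):
    if max_len is None:
        max_len = int(lengths.max().item())
    ids = torch.arange(max_len, device=lengths.device)[None, :]
    return ids >= lengths[:, None]              # True where the sequence is padding

lengths = torch.tensor([3, 5])
mask = get_mask_from_lengths(lengths, max_len=5)
print(mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])

x = torch.ones(2, 5, 4)
x = x.masked_fill(mask.unsqueeze(-1), 0.0)      # zero out padded frames
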
diff --git a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/audio_processor.py b/modelscope/models/audio/tts/kantts/preprocess/audio_processor/audio_processor.py
deleted file mode 100644
index 343cfd9c..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/audio_processor.py
+++ /dev/null
@@ -1,774 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import argparse
-import os
-from concurrent.futures import ProcessPoolExecutor
-from glob import glob
-
-import numpy as np
-import yaml
-from tqdm import tqdm
-
-from modelscope.utils.logger import get_logger
-from .core.dsp import (load_wav, melspectrogram, save_wav, trim_silence,
- trim_silence_with_interval)
-from .core.utils import (align_length, average_by_duration, compute_mean,
- compute_std, encode_16bits, f0_norm_mean_std,
- get_energy, get_pitch, norm_mean_std,
- parse_interval_file, volume_normalize)
-
-logging = get_logger()
-
-default_audio_config = {
- # Preprocess
- 'wav_normalize': True,
- 'trim_silence': True,
- 'trim_silence_threshold_db': 60,
- 'preemphasize': False,
- # Feature extraction
- 'sampling_rate': 24000,
- 'hop_length': 240,
- 'win_length': 1024,
- 'n_mels': 80,
- 'n_fft': 1024,
- 'fmin': 50.0,
- 'fmax': 7600.0,
- 'min_level_db': -100,
- 'ref_level_db': 20,
- 'phone_level_feature': True,
- 'num_workers': 16,
- # Normalization
- 'norm_type': 'mean_std', # 'mean_std', 'global norm'
- 'max_norm': 1.0,
- 'symmetric': False,
-}
-
-
-class AudioProcessor:
-
- def __init__(self, config=None):
- if not isinstance(config, dict):
- logging.warning(
- '[AudioProcessor] config is not a dict, falling back to the default config.'
- )
- self.config = default_audio_config
- else:
- self.config = config
-
- for key in self.config:
- setattr(self, key, self.config[key])
-
- self.min_wav_length = int(self.config['sampling_rate'] * 0.5)
-
- self.badcase_list = []
- self.pcm_dict = {}
- self.mel_dict = {}
- self.f0_dict = {}
- self.uv_dict = {}
- self.nccf_dict = {}
- self.f0uv_dict = {}
- self.energy_dict = {}
- self.dur_dict = {}
- logging.info('[AudioProcessor] Initialize AudioProcessor.')
- logging.info('[AudioProcessor] config params:')
- for key in self.config:
- logging.info('[AudioProcessor] %s: %s', key, self.config[key])
-
- def calibrate_SyllableDuration(self, raw_dur_dir, raw_metafile,
- out_cali_duration_dir):
- with open(raw_metafile, 'r') as f:
- lines = f.readlines()
-
- output_dur_dir = out_cali_duration_dir
- os.makedirs(output_dur_dir, exist_ok=True)
-
- for line in lines:
- line = line.strip()
- index, symbols = line.split('\t')
- symbols = [
- symbol.strip('{').strip('}').split('$')[0]
- for symbol in symbols.strip().split(' ')
- ]
- dur_file = os.path.join(raw_dur_dir, index + '.npy')
- phone_file = os.path.join(raw_dur_dir, index + '.phone')
- if not os.path.exists(dur_file) or not os.path.exists(phone_file):
- logging.warning(
- '[AudioProcessor] dur file or phone file does not exist: %s',
- index)
- continue
- with open(phone_file, 'r') as f:
- phones = f.readlines()
- dur = np.load(dur_file)
- cali_duration = []
-
- dur_idx = 0
- syll_idx = 0
-
- while dur_idx < len(dur) and syll_idx < len(symbols):
- if phones[dur_idx].strip() == 'sil':
- dur_idx += 1
- continue
-
- if phones[dur_idx].strip(
- ) == 'sp' and symbols[syll_idx][0] != '#':
- dur_idx += 1
- continue
-
- if symbols[syll_idx] in ['ga', 'go', 'ge']:
- cali_duration.append(0)
- syll_idx += 1
- # print("NONE", symbols[syll_idx], 0)
- continue
-
- if symbols[syll_idx][0] == '#':
- if phones[dur_idx].strip() != 'sp':
- cali_duration.append(0)
- # print("NONE", symbols[syll_idx], 0)
- syll_idx += 1
- continue
- else:
- cali_duration.append(dur[dur_idx])
- # print(phones[dur_idx].strip(), symbols[syll_idx], dur[dur_idx])
- dur_idx += 1
- syll_idx += 1
- continue
- # A corresponding phone is found
- cali_duration.append(dur[dur_idx])
- # print(phones[dur_idx].strip(), symbols[syll_idx], dur[dur_idx])
- dur_idx += 1
- syll_idx += 1
- # Add #4 phone duration
- cali_duration.append(0)
- if len(cali_duration) != len(symbols):
- logging.error(
- '[Duration Calibrating] Syllable duration {} is not equal to '
- 'the number of symbols {}, index: {}'.format(
- len(cali_duration), len(symbols), index))
- continue
-
- # Align with mel frames
- durs = np.array(cali_duration)
- if len(self.mel_dict) > 0:
- pair_mel = self.mel_dict.get(index, None)
- if pair_mel is None:
- logging.warning(
- '[AudioProcessor] Interval file %s has no corresponding mel',
- index,
- )
- continue
- mel_frames = pair_mel.shape[0]
- dur_frames = np.sum(durs)
- if np.sum(durs) > mel_frames:
- durs[-2] -= dur_frames - mel_frames
- elif np.sum(durs) < mel_frames:
- durs[-2] += mel_frames - np.sum(durs)
-
- if durs[-2] < 0:
- logging.error(
- '[AudioProcessor] Duration calibrating failed for %s, mismatch frames %s',
- index,
- durs[-2],
- )
- self.badcase_list.append(index)
- continue
-
- self.dur_dict[index] = durs
-
- np.save(
- os.path.join(output_dur_dir, index + '.npy'),
- self.dur_dict[index])
-
- def amp_normalize(self, src_wav_dir, out_wav_dir):
- if self.wav_normalize:
- logging.info('[AudioProcessor] Amplitude normalization started')
- os.makedirs(out_wav_dir, exist_ok=True)
- res = volume_normalize(src_wav_dir, out_wav_dir)
- logging.info('[AudioProcessor] Amplitude normalization finished')
- return res
- else:
- logging.info('[AudioProcessor] No amplitude normalization')
- os.symlink(src_wav_dir, out_wav_dir, target_is_directory=True)
- return True
-
- def get_pcm_dict(self, src_wav_dir):
- wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
- if len(self.pcm_dict) > 0:
- return self.pcm_dict
-
- logging.info('[AudioProcessor] Start to load pcm from %s', src_wav_dir)
- with ProcessPoolExecutor(
- max_workers=self.num_workers) as executor, tqdm(
- total=len(wav_list)) as progress:
- futures = []
- for wav_path in wav_list:
- future = executor.submit(load_wav, wav_path,
- self.sampling_rate)
- future.add_done_callback(lambda p: progress.update())
- wav_name = os.path.splitext(os.path.basename(wav_path))[0]
- futures.append((future, wav_name))
- for future, wav_name in futures:
- pcm = future.result()
- if len(pcm) < self.min_wav_length:
- logging.warning('[AudioProcessor] %s is too short, skip',
- wav_name)
- self.badcase_list.append(wav_name)
- continue
- self.pcm_dict[wav_name] = pcm
-
- return self.pcm_dict
-
- def trim_silence_wav(self, src_wav_dir, out_wav_dir=None):
- wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
- logging.info('[AudioProcessor] Trim silence started')
- if out_wav_dir is None:
- out_wav_dir = src_wav_dir
- else:
- os.makedirs(out_wav_dir, exist_ok=True)
- pcm_dict = self.get_pcm_dict(src_wav_dir)
- with ProcessPoolExecutor(
- max_workers=self.num_workers) as executor, tqdm(
- total=len(wav_list)) as progress:
- futures = []
- for wav_basename, pcm_data in pcm_dict.items():
- future = executor.submit(
- trim_silence,
- pcm_data,
- self.trim_silence_threshold_db,
- self.hop_length,
- self.win_length,
- )
- future.add_done_callback(lambda p: progress.update())
- futures.append((future, wav_basename))
- for future, wav_basename in tqdm(futures):
- pcm = future.result()
- if len(pcm) < self.min_wav_length:
- logging.warning('[AudioProcessor] %s is too short, skip',
- wav_basename)
- self.badcase_list.append(wav_basename)
- self.pcm_dict.pop(wav_basename)
- continue
- self.pcm_dict[wav_basename] = pcm
- save_wav(
- self.pcm_dict[wav_basename],
- os.path.join(out_wav_dir, wav_basename + '.wav'),
- self.sampling_rate,
- )
-
- logging.info('[AudioProcessor] Trim silence finished')
- return True
-
- def trim_silence_wav_with_interval(self,
- src_wav_dir,
- dur_dir,
- out_wav_dir=None):
- wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
- logging.info('[AudioProcessor] Trim silence with interval started')
- if out_wav_dir is None:
- out_wav_dir = src_wav_dir
- else:
- os.makedirs(out_wav_dir, exist_ok=True)
- pcm_dict = self.get_pcm_dict(src_wav_dir)
- with ProcessPoolExecutor(
- max_workers=self.num_workers) as executor, tqdm(
- total=len(wav_list)) as progress:
- futures = []
- for wav_basename, pcm_data in pcm_dict.items():
- future = executor.submit(
- trim_silence_with_interval,
- pcm_data,
- self.dur_dict.get(wav_basename, None),
- self.hop_length,
- )
- future.add_done_callback(lambda p: progress.update())
- futures.append((future, wav_basename))
- for future, wav_basename in tqdm(futures):
- trimed_pcm = future.result()
- if trimed_pcm is None:
- continue
- if len(trimed_pcm) < self.min_wav_length:
- logging.warning('[AudioProcessor] %s is too short, skip',
- wav_basename)
- self.badcase_list.append(wav_basename)
- self.pcm_dict.pop(wav_basename)
- continue
- self.pcm_dict[wav_basename] = trimed_pcm
- save_wav(
- self.pcm_dict[wav_basename],
- os.path.join(out_wav_dir, wav_basename + '.wav'),
- self.sampling_rate,
- )
-
- logging.info('[AudioProcessor] Trim silence finished')
- return True
-
- def mel_extract(self, src_wav_dir, out_feature_dir):
- os.makedirs(out_feature_dir, exist_ok=True)
- wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
- pcm_dict = self.get_pcm_dict(src_wav_dir)
-
- logging.info('[AudioProcessor] Melspec extraction started')
-
- # Get global normed mel spec
- with ProcessPoolExecutor(
- max_workers=self.num_workers) as executor, tqdm(
- total=len(wav_list)) as progress:
- futures = []
- for wav_basename, pcm_data in pcm_dict.items():
- future = executor.submit(
- melspectrogram,
- pcm_data,
- self.sampling_rate,
- self.n_fft,
- self.hop_length,
- self.win_length,
- self.n_mels,
- self.max_norm,
- self.min_level_db,
- self.ref_level_db,
- self.fmin,
- self.fmax,
- self.symmetric,
- self.preemphasize,
- )
- future.add_done_callback(lambda p: progress.update())
- futures.append((future, wav_basename))
-
- for future, wav_basename in futures:
- result = future.result()
- if result is None:
- logging.warning(
- '[AudioProcessor] Melspec extraction failed for %s',
- wav_basename,
- )
- self.badcase_list.append(wav_basename)
- else:
- melspec = result
- self.mel_dict[wav_basename] = melspec
-
- logging.info('[AudioProcessor] Melspec extraction finished')
-
- # FIXME: is this step necessary?
- # Do mean std norm on global-normed melspec
- logging.info('Computing melspec statistics...')
- mel_mean = compute_mean(list(self.mel_dict.values()), dims=self.n_mels)
- mel_std = compute_std(
- list(self.mel_dict.values()), mel_mean, dims=self.n_mels)
- logging.info('Melspec statistics done')
- np.savetxt(
- os.path.join(out_feature_dir, 'mel_mean.txt'),
- mel_mean,
- fmt='%.6f')
- np.savetxt(
- os.path.join(out_feature_dir, 'mel_std.txt'), mel_std, fmt='%.6f')
- logging.info(
- '[AudioProcessor] melspec mean and std saved to:\n{},\n{}'.format(
- os.path.join(out_feature_dir, 'mel_mean.txt'),
- os.path.join(out_feature_dir, 'mel_std.txt'),
- ))
-
- logging.info('[AudioProcessor] Melspec mean std norm is proceeding...')
- for wav_basename in self.mel_dict:
- melspec = self.mel_dict[wav_basename]
- norm_melspec = norm_mean_std(melspec, mel_mean, mel_std)
- np.save(
- os.path.join(out_feature_dir, wav_basename + '.npy'),
- norm_melspec)
-
- logging.info('[AudioProcessor] Melspec normalization finished')
- logging.info('[AudioProcessor] Normed Melspec saved to %s',
- out_feature_dir)
-
- return True
-
- def duration_generate(self, src_interval_dir, out_feature_dir):
- os.makedirs(out_feature_dir, exist_ok=True)
- interval_list = glob(os.path.join(src_interval_dir, '*.interval'))
-
- logging.info('[AudioProcessor] Duration generation started')
- with ProcessPoolExecutor(
- max_workers=self.num_workers) as executor, tqdm(
- total=len(interval_list)) as progress:
- futures = []
- for interval_file_path in interval_list:
- future = executor.submit(
- parse_interval_file,
- interval_file_path,
- self.sampling_rate,
- self.hop_length,
- )
- future.add_done_callback(lambda p: progress.update())
- futures.append((future,
- os.path.splitext(
- os.path.basename(interval_file_path))[0]))
-
- logging.info(
- '[AudioProcessor] Duration align with mel is proceeding...')
- for future, wav_basename in futures:
- result = future.result()
- if result is None:
- logging.warning(
- '[AudioProcessor] Duration generate failed for %s',
- wav_basename)
- self.badcase_list.append(wav_basename)
- else:
- durs, phone_list = result
- # Align length with melspec
- if len(self.mel_dict) > 0:
- pair_mel = self.mel_dict.get(wav_basename, None)
- if pair_mel is None:
- logging.warning(
- '[AudioProcessor] Interval file %s has no corresponding mel',
- wav_basename,
- )
- continue
- mel_frames = pair_mel.shape[0]
- dur_frames = np.sum(durs)
- if np.sum(durs) > mel_frames:
- durs[-1] -= dur_frames - mel_frames
- elif np.sum(durs) < mel_frames:
- durs[-1] += mel_frames - np.sum(durs)
-
- if durs[-1] < 0:
- logging.error(
- '[AudioProcessor] Duration align failed for %s, mismatch frames %s',
- wav_basename,
- durs[-1],
- )
- self.badcase_list.append(wav_basename)
- continue
-
- self.dur_dict[wav_basename] = durs
-
- np.save(
- os.path.join(out_feature_dir, wav_basename + '.npy'),
- durs)
- with open(
- os.path.join(out_feature_dir,
- wav_basename + '.phone'), 'w') as f:
- f.write('\n'.join(phone_list))
- logging.info('[AudioProcessor] Duration generate finished')
-
- return True
-
- def pitch_extract(self, src_wav_dir, out_f0_dir, out_frame_f0_dir,
- out_frame_uv_dir):
- os.makedirs(out_f0_dir, exist_ok=True)
- os.makedirs(out_frame_f0_dir, exist_ok=True)
- os.makedirs(out_frame_uv_dir, exist_ok=True)
- wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
- pcm_dict = self.get_pcm_dict(src_wav_dir)
- mel_dict = self.mel_dict
-
- logging.info('[AudioProcessor] Pitch extraction started')
- # Get raw pitch
- with ProcessPoolExecutor(
- max_workers=self.num_workers) as executor, tqdm(
- total=len(wav_list)) as progress:
- futures = []
- for wav_basename, pcm_data in pcm_dict.items():
- future = executor.submit(
- get_pitch,
- encode_16bits(pcm_data),
- self.sampling_rate,
- self.hop_length,
- )
- future.add_done_callback(lambda p: progress.update())
- futures.append((future, wav_basename))
-
- logging.info(
- '[AudioProcessor] Pitch align with mel is proceeding...')
- for future, wav_basename in futures:
- result = future.result()
- if result is None:
- logging.warning(
- '[AudioProcessor] Pitch extraction failed for %s',
- wav_basename)
- self.badcase_list.append(wav_basename)
- else:
- f0, uv, f0uv = result
- if len(mel_dict) > 0:
- f0 = align_length(f0, mel_dict.get(wav_basename, None))
- uv = align_length(uv, mel_dict.get(wav_basename, None))
- f0uv = align_length(f0uv,
- mel_dict.get(wav_basename, None))
-
- if f0 is None or uv is None or f0uv is None:
- logging.warning(
- '[AudioProcessor] Pitch length mismatch with mel in %s',
- wav_basename,
- )
- self.badcase_list.append(wav_basename)
- continue
- self.f0_dict[wav_basename] = f0
- self.uv_dict[wav_basename] = uv
- self.f0uv_dict[wav_basename] = f0uv
-
- # Normalize f0
- logging.info('[AudioProcessor] Pitch normalization is proceeding...')
- f0_mean = compute_mean(list(self.f0uv_dict.values()), dims=1)
- f0_std = compute_std(list(self.f0uv_dict.values()), f0_mean, dims=1)
- np.savetxt(
- os.path.join(out_f0_dir, 'f0_mean.txt'), f0_mean, fmt='%.6f')
- np.savetxt(os.path.join(out_f0_dir, 'f0_std.txt'), f0_std, fmt='%.6f')
- logging.info(
- '[AudioProcessor] f0 mean and std saved to:\n{},\n{}'.format(
- os.path.join(out_f0_dir, 'f0_mean.txt'),
- os.path.join(out_f0_dir, 'f0_std.txt'),
- ))
- logging.info('[AudioProcessor] Pitch mean std norm is proceeding...')
- for wav_basename in self.f0uv_dict:
- f0 = self.f0uv_dict[wav_basename]
- norm_f0 = f0_norm_mean_std(f0, f0_mean, f0_std)
- self.f0uv_dict[wav_basename] = norm_f0
-
- for wav_basename in self.f0_dict:
- f0 = self.f0_dict[wav_basename]
- norm_f0 = f0_norm_mean_std(f0, f0_mean, f0_std)
- self.f0_dict[wav_basename] = norm_f0
-
- # save frame f0 to a specific dir
- for wav_basename in self.f0_dict:
- np.save(
- os.path.join(out_frame_f0_dir, wav_basename + '.npy'),
- self.f0_dict[wav_basename].reshape(-1),
- )
-
- for wav_basename in self.uv_dict:
- np.save(
- os.path.join(out_frame_uv_dir, wav_basename + '.npy'),
- self.uv_dict[wav_basename].reshape(-1),
- )
-
- # phone level average
- # if there is no duration then save the frame-level f0
- if self.phone_level_feature and len(self.dur_dict) > 0:
- logging.info(
- '[AudioProcessor] Pitch turn to phone-level is proceeding...')
- with ProcessPoolExecutor(
- max_workers=self.num_workers) as executor, tqdm(
- total=len(self.f0uv_dict)) as progress:
- futures = []
- for wav_basename in self.f0uv_dict:
- future = executor.submit(
- average_by_duration,
- self.f0uv_dict.get(wav_basename, None),
- self.dur_dict.get(wav_basename, None),
- )
- future.add_done_callback(lambda p: progress.update())
- futures.append((future, wav_basename))
-
- for future, wav_basename in futures:
- result = future.result()
- if result is None:
- logging.warning(
- '[AudioProcessor] Pitch extraction failed in phone level avg for: %s',
- wav_basename,
- )
- self.badcase_list.append(wav_basename)
- else:
- avg_f0 = result
- self.f0uv_dict[wav_basename] = avg_f0
-
- for wav_basename in self.f0uv_dict:
- np.save(
- os.path.join(out_f0_dir, wav_basename + '.npy'),
- self.f0uv_dict[wav_basename].reshape(-1),
- )
-
- logging.info('[AudioProcessor] Pitch normalization finished')
- logging.info('[AudioProcessor] Normed f0 saved to %s', out_f0_dir)
- logging.info('[AudioProcessor] Pitch extraction finished')
-
- return True
-
- def energy_extract(self, src_wav_dir, out_energy_dir,
- out_frame_energy_dir):
- os.makedirs(out_energy_dir, exist_ok=True)
- os.makedirs(out_frame_energy_dir, exist_ok=True)
- wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
- pcm_dict = self.get_pcm_dict(src_wav_dir)
- mel_dict = self.mel_dict
-
- logging.info('[AudioProcessor] Energy extraction started')
- # Get raw energy
- with ProcessPoolExecutor(
- max_workers=self.num_workers) as executor, tqdm(
- total=len(wav_list)) as progress:
- futures = []
- for wav_basename, pcm_data in pcm_dict.items():
- future = executor.submit(get_energy, pcm_data, self.hop_length,
- self.win_length, self.n_fft)
- future.add_done_callback(lambda p: progress.update())
- futures.append((future, wav_basename))
-
- for future, wav_basename in futures:
- result = future.result()
- if result is None:
- logging.warning(
- '[AudioProcessor] Energy extraction failed for %s',
- wav_basename)
- self.badcase_list.append(wav_basename)
- else:
- energy = result
- if len(mel_dict) > 0:
- energy = align_length(energy,
- mel_dict.get(wav_basename, None))
- if energy is None:
- logging.warning(
- '[AudioProcessor] Energy length mismatch with mel in %s',
- wav_basename,
- )
- self.badcase_list.append(wav_basename)
- continue
- self.energy_dict[wav_basename] = energy
-
- logging.info('Energy statistic proceeding...')
- # Normalize energy
- energy_mean = compute_mean(list(self.energy_dict.values()), dims=1)
- energy_std = compute_std(
- list(self.energy_dict.values()), energy_mean, dims=1)
- np.savetxt(
- os.path.join(out_energy_dir, 'energy_mean.txt'),
- energy_mean,
- fmt='%.6f')
- np.savetxt(
- os.path.join(out_energy_dir, 'energy_std.txt'),
- energy_std,
- fmt='%.6f')
- logging.info(
- '[AudioProcessor] energy mean and std saved to:\n{},\n{}'.format(
- os.path.join(out_energy_dir, 'energy_mean.txt'),
- os.path.join(out_energy_dir, 'energy_std.txt'),
- ))
-
- logging.info('[AudioProcessor] Energy mean std norm is proceeding...')
- for wav_basename in self.energy_dict:
- energy = self.energy_dict[wav_basename]
- norm_energy = f0_norm_mean_std(energy, energy_mean, energy_std)
- self.energy_dict[wav_basename] = norm_energy
-
- # save frame energy to a specific dir
- for wav_basename in self.energy_dict:
- np.save(
- os.path.join(out_frame_energy_dir, wav_basename + '.npy'),
- self.energy_dict[wav_basename].reshape(-1),
- )
-
- # phone level average
- # if there is no duration then save the frame-level energy
- if self.phone_level_feature and len(self.dur_dict) > 0:
- with ProcessPoolExecutor(
- max_workers=self.num_workers) as executor, tqdm(
- total=len(self.energy_dict)) as progress:
- futures = []
- for wav_basename in self.energy_dict:
- future = executor.submit(
- average_by_duration,
- self.energy_dict.get(wav_basename, None),
- self.dur_dict.get(wav_basename, None),
- )
- future.add_done_callback(lambda p: progress.update())
- futures.append((future, wav_basename))
-
- for future, wav_basename in futures:
- result = future.result()
- if result is None:
- logging.warning(
- '[AudioProcessor] Energy extraction failed in phone level avg for: %s',
- wav_basename,
- )
- self.badcase_list.append(wav_basename)
- else:
- avg_energy = result
- self.energy_dict[wav_basename] = avg_energy
-
- for wav_basename in self.energy_dict:
- np.save(
- os.path.join(out_energy_dir, wav_basename + '.npy'),
- self.energy_dict[wav_basename].reshape(-1),
- )
-
- logging.info('[AudioProcessor] Energy normalization finished')
- logging.info('[AudioProcessor] Normed Energy saved to %s',
- out_energy_dir)
- logging.info('[AudioProcessor] Energy extraction finished')
-
- return True
-
- def process(self, src_voice_dir, out_data_dir, aux_metafile=None):
- succeed = True
-
- raw_wav_dir = os.path.join(src_voice_dir, 'wav')
- src_interval_dir = os.path.join(src_voice_dir, 'interval')
-
- out_mel_dir = os.path.join(out_data_dir, 'mel')
- out_f0_dir = os.path.join(out_data_dir, 'f0')
- out_frame_f0_dir = os.path.join(out_data_dir, 'frame_f0')
- out_frame_uv_dir = os.path.join(out_data_dir, 'frame_uv')
- out_energy_dir = os.path.join(out_data_dir, 'energy')
- out_frame_energy_dir = os.path.join(out_data_dir, 'frame_energy')
- out_duration_dir = os.path.join(out_data_dir, 'raw_duration')
- out_cali_duration_dir = os.path.join(out_data_dir, 'duration')
-
- os.makedirs(out_data_dir, exist_ok=True)
-
- with_duration = os.path.exists(src_interval_dir)
- train_wav_dir = os.path.join(out_data_dir, 'wav')
-
- succeed = self.amp_normalize(raw_wav_dir, train_wav_dir)
- if not succeed:
- logging.error('[AudioProcessor] amp_normalize failed, exit')
- return False
-
- if with_duration:
- # Raw duration, non-trimmed
- succeed = self.duration_generate(src_interval_dir,
- out_duration_dir)
- if not succeed:
- logging.error(
- '[AudioProcessor] duration_generate failed, exit')
- return False
-
- if self.trim_silence:
- if with_duration:
- succeed = self.trim_silence_wav_with_interval(
- train_wav_dir, out_duration_dir)
- if not succeed:
- logging.error(
- '[AudioProcessor] trim_silence_wav_with_interval failed, exit'
- )
- return False
- else:
- succeed = self.trim_silence_wav(train_wav_dir)
- if not succeed:
- logging.error(
- '[AudioProcessor] trim_silence_wav failed, exit')
- return False
-
- succeed = self.mel_extract(train_wav_dir, out_mel_dir)
- if not succeed:
- logging.error('[AudioProcessor] mel_extract failed, exit')
- return False
-
- if aux_metafile is not None and with_duration:
- self.calibrate_SyllableDuration(out_duration_dir, aux_metafile,
- out_cali_duration_dir)
-
- succeed = self.pitch_extract(train_wav_dir, out_f0_dir,
- out_frame_f0_dir, out_frame_uv_dir)
- if not succeed:
- logging.error('[AudioProcessor] pitch_extract failed, exit')
- return False
-
- succeed = self.energy_extract(train_wav_dir, out_energy_dir,
- out_frame_energy_dir)
- if not succeed:
- logging.error('[AudioProcessor] energy_extract failed, exit')
- return False
-
- # recording badcase list
- with open(os.path.join(out_data_dir, 'badlist.txt'), 'w') as f:
- f.write('\n'.join(self.badcase_list))
-
- logging.info('[AudioProcessor] All features extracted successfully!')
-
- return succeed
diff --git a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/dsp.py b/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/dsp.py
deleted file mode 100644
index 04bacb28..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/dsp.py
+++ /dev/null
@@ -1,240 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import librosa
-import librosa.filters
-import numpy as np
-from scipy import signal
-from scipy.io import wavfile
-
-
-def _stft(y, hop_length, win_length, n_fft):
- return librosa.stft(
- y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
-
-
-def _istft(y, hop_length, win_length):
- return librosa.istft(y, hop_length=hop_length, win_length=win_length)
-
-
-def _db_to_amp(x):
- return np.power(10.0, x * 0.05)
-
-
-def _amp_to_db(x):
- return 20 * np.log10(np.maximum(1e-5, x))
-
-
-def load_wav(path, sr):
- return librosa.load(path, sr=sr)[0]
-
-
-def save_wav(wav, path, sr):
- if wav.dtype == np.float32 or wav.dtype == np.float64:
- quant_wav = 32767 * wav
- else:
- quant_wav = wav
- # maximize the volume to avoid clipping
- # wav *= 32767 / max(0.01, np.max(np.abs(wav)))
- wavfile.write(path, sr, quant_wav.astype(np.int16))
-
-
-def trim_silence(wav, top_db, hop_length, win_length):
- trimmed_wav, _ = librosa.effects.trim(
- wav, top_db=top_db, frame_length=win_length, hop_length=hop_length)
- return trimmed_wav
-
-
-def trim_silence_with_interval(wav, interval, hop_length):
- if interval is None:
- return None
- leading_sil = interval[0]
- tailing_sil = interval[-1]
- trim_wav = wav[leading_sil * hop_length:-tailing_sil * hop_length]
- return trim_wav
-
-
-def preemphasis(wav, k=0.98, preemphasize=False):
- if preemphasize:
- return signal.lfilter([1, -k], [1], wav)
- return wav
-
-
-def inv_preemphasis(wav, k=0.98, inv_preemphasize=False):
- if inv_preemphasize:
- return signal.lfilter([1], [1, -k], wav)
- return wav
-
-
-def _normalize(S, max_norm=1.0, min_level_db=-100, symmetric=False):
- if symmetric:
- return np.clip(
- (2 * max_norm) * ((S - min_level_db) / (-min_level_db)) - max_norm,
- -max_norm,
- max_norm,
- )
- else:
- return np.clip(max_norm * ((S - min_level_db) / (-min_level_db)), 0,
- max_norm)
-
-
-def _denormalize(D, max_norm=1.0, min_level_db=-100, symmetric=False):
- if symmetric:
- return ((np.clip(D, -max_norm, max_norm) + max_norm) * -min_level_db
- / # noqa W504
- (2 * max_norm)) + min_level_db
- else:
- return (np.clip(D, 0, max_norm) * -min_level_db
- / max_norm) + min_level_db
-
-
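-# Griffin-Lim phase reconstruction: start from random phase and alternate
-# ISTFT/STFT, keeping the given magnitude and re-estimating phase each pass.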
-def _griffin_lim(S, n_fft, hop_length, win_length, griffin_lim_iters=60):
- angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
- S_complex = np.abs(S).astype(np.complex128)
- y = _istft(
- S_complex * angles, hop_length=hop_length, win_length=win_length)
- for i in range(griffin_lim_iters):
- angles = np.exp(1j * np.angle(
- _stft(
- y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)))
- y = _istft(
- S_complex * angles, hop_length=hop_length, win_length=win_length)
- return y
-
-
-def spectrogram(
- y,
- n_fft=1024,
- hop_length=256,
- win_length=1024,
- max_norm=1.0,
- min_level_db=-100,
- ref_level_db=20,
- symmetric=False,
-):
- D = _stft(preemphasis(y), hop_length, win_length, n_fft)
- S = _amp_to_db(np.abs(D)) - ref_level_db
- return _normalize(S, max_norm, min_level_db, symmetric)
-
-
-def inv_spectrogram(
- spectrogram,
- n_fft=1024,
- hop_length=256,
- win_length=1024,
- max_norm=1.0,
- min_level_db=-100,
- ref_level_db=20,
- symmetric=False,
- power=1.5,
-):
- S = _db_to_amp(
- _denormalize(spectrogram, max_norm, min_level_db, symmetric)
- + ref_level_db)
- return _griffin_lim(S**power, n_fft, hop_length, win_length)
-
-
-def _build_mel_basis(sample_rate, n_fft=1024, fmin=50, fmax=8000, n_mels=80):
- assert fmax <= sample_rate // 2
- return librosa.filters.mel(
- sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
-
-
-# Mel <-> linear conversions
-_mel_basis = None
-_inv_mel_basis = None
-
-
-def _linear_to_mel(spectrogram,
- sample_rate,
- n_fft=1024,
- fmin=50,
- fmax=8000,
- n_mels=80):
- global _mel_basis
- if _mel_basis is None:
- _mel_basis = _build_mel_basis(sample_rate, n_fft, fmin, fmax, n_mels)
- return np.dot(_mel_basis, spectrogram)
-
-
-def _mel_to_linear(mel_spectrogram,
- sample_rate,
- n_fft=1024,
- fmin=50,
- fmax=8000,
- n_mels=80):
- global _inv_mel_basis
- if _inv_mel_basis is None:
- _inv_mel_basis = np.linalg.pinv(
- _build_mel_basis(sample_rate, n_fft, fmin, fmax, n_mels))
- return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
-
-
-def melspectrogram(
- y,
- sample_rate,
- n_fft=1024,
- hop_length=256,
- win_length=1024,
- n_mels=80,
- max_norm=1.0,
- min_level_db=-100,
- ref_level_db=20,
- fmin=50,
- fmax=8000,
- symmetric=False,
- preemphasize=False,
-):
- D = _stft(
- preemphasis(y, preemphasize=preemphasize),
- hop_length=hop_length,
- win_length=win_length,
- n_fft=n_fft,
- )
- S = (
- _amp_to_db(
- _linear_to_mel(
- np.abs(D),
- sample_rate=sample_rate,
- n_fft=n_fft,
- fmin=fmin,
- fmax=fmax,
- n_mels=n_mels,
- )) - ref_level_db)
- return _normalize(
- S, max_norm=max_norm, min_level_db=min_level_db, symmetric=symmetric).T
-
-
-def inv_mel_spectrogram(
- mel_spectrogram,
- sample_rate,
- n_fft=1024,
- hop_length=256,
- win_length=1024,
- n_mels=80,
- max_norm=1.0,
- min_level_db=-100,
- ref_level_db=20,
- fmin=50,
- fmax=8000,
- power=1.5,
- symmetric=False,
- preemphasize=False,
-):
- D = _denormalize(
- mel_spectrogram,
- max_norm=max_norm,
- min_level_db=min_level_db,
- symmetric=symmetric,
- )
- S = _mel_to_linear(
- _db_to_amp(D + ref_level_db),
- sample_rate=sample_rate,
- n_fft=n_fft,
- fmin=fmin,
- fmax=fmax,
- n_mels=n_mels,
- )
- return inv_preemphasis(
- _griffin_lim(S**power, n_fft, hop_length, win_length),
- preemphasize=preemphasize,
- )
diff --git a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/utils.py b/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/utils.py
deleted file mode 100644
index 0004458c..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/utils.py
+++ /dev/null
@@ -1,531 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import os
-from concurrent.futures import ProcessPoolExecutor
-from glob import glob
-
-import librosa
-import numpy as np
-import pysptk
-import sox
-from scipy.io import wavfile
-from tqdm import tqdm
-
-from modelscope.utils.logger import get_logger
-from .dsp import _stft
-
-logging = get_logger()
-
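-# Target (anchor) cumulative histogram and its RMS-amplitude bin edges; they
-# serve as the reference distribution for volume_normalize() below.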
-anchor_hist = np.array([
- 0.0,
- 0.00215827,
- 0.00354383,
- 0.00442313,
- 0.00490274,
- 0.00532907,
- 0.00602185,
- 0.00690115,
- 0.00810019,
- 0.00948574,
- 0.0120437,
- 0.01489475,
- 0.01873168,
- 0.02302158,
- 0.02872369,
- 0.03669065,
- 0.04636291,
- 0.05843325,
- 0.07700506,
- 0.11052491,
- 0.16802558,
- 0.25997868,
- 0.37942979,
- 0.50730083,
- 0.62006395,
- 0.71092459,
- 0.76877165,
- 0.80762057,
- 0.83458566,
- 0.85672795,
- 0.87660538,
- 0.89251266,
- 0.90578204,
- 0.91569411,
- 0.92541966,
- 0.93383959,
- 0.94162004,
- 0.94940048,
- 0.95539568,
- 0.96136424,
- 0.9670397,
- 0.97290168,
- 0.97705835,
- 0.98116174,
- 0.98465228,
- 0.98814282,
- 0.99152678,
- 0.99421796,
- 0.9965894,
- 0.99840128,
- 1.0,
-])
-
-anchor_bins = np.array([
- 0.033976,
- 0.03529014,
- 0.03660428,
- 0.03791842,
- 0.03923256,
- 0.0405467,
- 0.04186084,
- 0.04317498,
- 0.04448912,
- 0.04580326,
- 0.0471174,
- 0.04843154,
- 0.04974568,
- 0.05105982,
- 0.05237396,
- 0.0536881,
- 0.05500224,
- 0.05631638,
- 0.05763052,
- 0.05894466,
- 0.0602588,
- 0.06157294,
- 0.06288708,
- 0.06420122,
- 0.06551536,
- 0.0668295,
- 0.06814364,
- 0.06945778,
- 0.07077192,
- 0.07208606,
- 0.0734002,
- 0.07471434,
- 0.07602848,
- 0.07734262,
- 0.07865676,
- 0.0799709,
- 0.08128504,
- 0.08259918,
- 0.08391332,
- 0.08522746,
- 0.0865416,
- 0.08785574,
- 0.08916988,
- 0.09048402,
- 0.09179816,
- 0.0931123,
- 0.09442644,
- 0.09574058,
- 0.09705472,
- 0.09836886,
- 0.099683,
-])
-
-hist_bins = 50
-
-
-def amp_info(wav_file_path):
- """
- Returns the amplitude info of the wav file.
- """
- stats = sox.file_info.stat(wav_file_path)
- amp_rms = stats['RMS amplitude']
- amp_max = stats['Maximum amplitude']
- amp_mean = stats['Mean amplitude']
- length = stats['Length (seconds)']
-
- return {
- 'amp_rms': amp_rms,
- 'amp_max': amp_max,
- 'amp_mean': amp_mean,
- 'length': length,
- 'basename': os.path.basename(wav_file_path),
- }
-
-
-def statistic_amplitude(src_wav_dir):
- """
- Returns amplitude info for every wav file in the directory, sorted by RMS amplitude.
- """
- wav_lst = glob(os.path.join(src_wav_dir, '*.wav'))
- with ProcessPoolExecutor(max_workers=8) as executor, tqdm(
- total=len(wav_lst)) as progress:
- futures = []
- for wav_file_path in wav_lst:
- future = executor.submit(amp_info, wav_file_path)
- future.add_done_callback(lambda p: progress.update())
- futures.append(future)
-
- amp_info_lst = [future.result() for future in futures]
-
- amp_info_lst = sorted(amp_info_lst, key=lambda x: x['amp_rms'])
-
- logging.info('Average amplitude RMS : {}'.format(
- np.mean([x['amp_rms'] for x in amp_info_lst])))
-
- return amp_info_lst
-
-
-def volume_normalize(src_wav_dir, out_wav_dir):
- logging.info('Volume statistic proceeding...')
- amp_info_lst = statistic_amplitude(src_wav_dir)
- logging.info('Volume statistic done.')
-
- rms_amp_lst = [x['amp_rms'] for x in amp_info_lst]
- src_hist, src_bins = np.histogram(
- rms_amp_lst, bins=hist_bins, density=True)
- src_hist = src_hist / np.sum(src_hist)
- src_hist = np.cumsum(src_hist)
- src_hist = np.insert(src_hist, 0, 0.0)
-
- logging.info('Volume normalization proceeding...')
- for amp_info in tqdm(amp_info_lst):
- rms_amp = amp_info['amp_rms']
- rms_amp = np.clip(rms_amp, src_bins[0], src_bins[-1])
-
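- # Histogram matching: locate this file's RMS in the source CDF, then map
- # it to the same quantile of the anchor distribution (with linear
- # interpolation inside the bin).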
- src_idx = np.where(rms_amp >= src_bins)[0][-1]
- src_pos = src_hist[src_idx]
- anchor_idx = np.where(src_pos >= anchor_hist)[0][-1]
-
- if src_idx == hist_bins or anchor_idx == hist_bins:
- rms_amp = anchor_bins[-1]
- else:
- rms_amp = (rms_amp - src_bins[src_idx]) / (
- src_bins[src_idx + 1] - src_bins[src_idx]) * (
- anchor_bins[anchor_idx + 1]
- - anchor_bins[anchor_idx]) + anchor_bins[anchor_idx]
-
- scale = rms_amp / amp_info['amp_rms']
-
- # FIXME: This is a hack to avoid sound clipping.
- sr, data = wavfile.read(
- os.path.join(src_wav_dir, amp_info['basename']))
- wavfile.write(
- os.path.join(out_wav_dir, amp_info['basename']),
- sr,
- (data * scale).astype(np.int16),
- )
- logging.info('Volume normalization done.')
-
- return True
-
-
-def interp_f0(f0_data):
- """
- Linearly interpolate F0 over unvoiced (zero-valued) frames.
- """
- f0_data[f0_data < 1] = 0
- xp = np.nonzero(f0_data)
- yp = f0_data[xp]
- x = np.arange(f0_data.size)
- contour_f0 = np.interp(x, xp[0], yp).astype(np.float32)
- return contour_f0
-
-
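-# Normalized cross-correlation of two frames, rescaled from [-1, 1] to [0, 1].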
-def frame_nccf(x, y):
- norm_coef = (np.sum(x**2.0) * np.sum(y**2.0) + 1e-30)**0.5
- return (np.sum(x * y) / norm_coef + 1.0) / 2.0
-
-
-def get_nccf(pcm_data, f0, min_f0=40, max_f0=800, fs=160, sr=16000):
- if pcm_data.dtype == np.int16:
- pcm_data = pcm_data.astype(np.float32) / 32768
- frame_len = int(sr / 200)
- frame_num = int(len(pcm_data) // fs)
- frame_num = min(frame_num, len(f0))
-
- pad_len = int(sr / min_f0) + frame_len
-
- pad_zeros = np.zeros([pad_len], dtype=np.float32)
- data = np.hstack((pad_zeros, pcm_data.astype(np.float32), pad_zeros))
-
- nccf = np.zeros((frame_num), dtype=np.float32)
-
- for i in range(frame_num):
- curr_f0 = np.clip(f0[i], min_f0, max_f0)
- lag = int(sr / curr_f0 + 0.5)
- j = i * fs + pad_len - frame_len // 2
-
- l_data = data[j:j + frame_len]
- l_data -= l_data.mean()
-
- r_data = data[j + lag:j + lag + frame_len]
- r_data -= r_data.mean()
-
- nccf[i] = frame_nccf(l_data, r_data)
-
- return nccf
-
-
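-# Hanning-weighted moving average; edge padding keeps the output length
-# equal to the input length.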
-def smooth(data, win_len):
- if win_len % 2 == 0:
- win_len += 1
- hwin = win_len // 2
- win = np.hanning(win_len)
- win /= win.sum()
- data = data.reshape([-1])
- pad_data = np.pad(data, hwin, mode='edge')
-
- for i in range(data.shape[0]):
- data[i] = np.dot(win, pad_data[i:i + win_len])
-
- return data.reshape([-1, 1])
-
-
-# Supported pitch trackers: rapt, swipe
-# Unsupported: reaper, world(DIO)
-def RAPT_FUNC(v1, v2, v3, v4, v5):
- return pysptk.sptk.rapt(
- v1.astype(np.float32), fs=v2, hopsize=v3, min=v4, max=v5)
-
-
-def SWIPE_FUNC(v1, v2, v3, v4, v5):
- return pysptk.sptk.swipe(
- v1.astype(np.float64), fs=v2, hopsize=v3, min=v4, max=v5)
-
-
-def PYIN_FUNC(v1, v2, v3, v4, v5):
- f0_mel = librosa.pyin(
- v1.astype(np.float32), sr=v2, frame_length=v3 * 4, fmin=v4, fmax=v5)[0]
- f0_mel = np.where(np.isnan(f0_mel), 0.0, f0_mel)
- return f0_mel
-
-
-def get_pitch(pcm_data, sampling_rate=16000, hop_length=160):
- log_f0_list = []
- uv_list = []
- low, high = 40, 800
-
- cali_f0 = pysptk.sptk.rapt(
- pcm_data.astype(np.float32),
- fs=sampling_rate,
- hopsize=hop_length,
- min=low,
- max=high,
- )
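- # Narrow the F0 search range with an initial RAPT pass before running the
- # individual estimators.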
- f0_range = np.sort(np.unique(cali_f0))
-
- if len(f0_range) > 20:
- low = max(f0_range[10] - 50, low)
- high = min(f0_range[-10] + 50, high)
-
- func_dict = {'rapt': RAPT_FUNC, 'swipe': SWIPE_FUNC}
-
- for func_name in func_dict:
- f0 = func_dict[func_name](pcm_data, sampling_rate, hop_length, low,
- high)
- uv = f0 > 0
-
- if len(f0) < 10 or f0.max() < low:
- logging.error('{} method: estimated F0 is too short or too low.'.format(func_name))
- continue
- else:
- f0 = np.clip(f0, 1e-30, high)
- log_f0 = np.log(f0)
- contour_log_f0 = interp_f0(log_f0)
-
- log_f0_list.append(contour_log_f0)
- uv_list.append(uv)
-
- if len(log_f0_list) == 0:
- logging.error('F0 estimation failed.')
- return None
-
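- # Fuse the per-estimator contours: truncate to the shortest one, take the
- # framewise median, then smooth; UV is a smoothed median vote.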
- min_len = float('inf')
- for log_f0 in log_f0_list:
- min_len = min(min_len, log_f0.shape[0])
-
- multi_log_f0 = np.zeros([len(log_f0_list), min_len], dtype=np.float32)
- multi_uv = np.zeros([len(log_f0_list), min_len], dtype=np.float32)
-
- for i in range(len(log_f0_list)):
- multi_log_f0[i, :] = log_f0_list[i][:min_len]
- multi_uv[i, :] = uv_list[i][:min_len]
-
- log_f0 = smooth(np.median(multi_log_f0, axis=0), 5)
- uv = (smooth(np.median(multi_uv, axis=0), 5) > 0.5).astype(np.float32)
-
- f0 = np.exp(log_f0)
-
- min_len = min(f0.shape[0], uv.shape[0])
-
- return f0[:min_len], uv[:min_len], f0[:min_len] * uv[:min_len]
-
-
-def get_energy(pcm_data, hop_length, win_length, n_fft):
- D = _stft(pcm_data, hop_length, win_length, n_fft)
- S, _ = librosa.magphase(D)
- energy = np.sqrt(np.sum(S**2, axis=0))
-
- return energy.reshape((-1, 1))
-
-
-def align_length(in_data, tgt_data, basename=None):
- if in_data is None or tgt_data is None:
- logging.error('{}: Input data is None.'.format(basename))
- return None
- in_len = in_data.shape[0]
- tgt_len = tgt_data.shape[0]
- if abs(in_len - tgt_len) > 20:
- logging.error(
- '{}: Input data length differs from target data length by too much.'
- .format(basename))
- return None
-
- if in_len < tgt_len:
- out_data = np.pad(
- in_data, ((0, tgt_len - in_len), (0, 0)),
- 'constant',
- constant_values=0.0)
- else:
- out_data = in_data[:tgt_len]
-
- return out_data
-
-
-def compute_mean(data_list, dims=80):
- mean_vector = np.zeros((1, dims))
- all_frame_number = 0
-
- for data in tqdm(data_list):
- if data is None:
- continue
- features = data.reshape((-1, dims))
- current_frame_number = np.shape(features)[0]
- mean_vector += np.sum(features[:, :], axis=0)
- all_frame_number += current_frame_number
-
- mean_vector /= float(all_frame_number)
- return mean_vector
-
-
-def compute_std(data_list, mean_vector, dims=80):
- std_vector = np.zeros((1, dims))
- all_frame_number = 0
-
- for data in tqdm(data_list):
- if data is None:
- continue
- features = data.reshape((-1, dims))
- current_frame_number = np.shape(features)[0]
- mean_matrix = np.tile(mean_vector, (current_frame_number, 1))
- std_vector += np.sum((features[:, :] - mean_matrix)**2, axis=0)
- all_frame_number += current_frame_number
-
- std_vector /= float(all_frame_number)
- std_vector = std_vector**0.5
- return std_vector
-
-
-F0_MIN = 0.0
-F0_MAX = 800.0
-
-ENERGY_MIN = 0.0
-ENERGY_MAX = 200.0
-
-CLIP_FLOOR = 1e-3
-
-
-def f0_norm_min_max(f0):
- zero_idxs = np.where(f0 <= CLIP_FLOOR)[0]
- res = (2 * f0 - F0_MIN - F0_MAX) / (F0_MAX - F0_MIN)
- res[zero_idxs] = 0.0
- return res
-
-
-def f0_denorm_min_max(f0):
- zero_idxs = np.where(f0 == 0.0)[0]
- res = (f0 * (F0_MAX - F0_MIN) + F0_MIN + F0_MAX) / 2
- res[zero_idxs] = 0.0
- return res
-
-
-def energy_norm_min_max(energy):
- zero_idxs = np.where(energy == 0.0)[0]
- res = (2 * energy - ENERGY_MIN - ENERGY_MAX) / (ENERGY_MAX - ENERGY_MIN)
- res[zero_idxs] = 0.0
- return res
-
-
-def energy_denorm_min_max(energy):
- zero_idxs = np.where(energy == 0.0)[0]
- res = (energy * (ENERGY_MAX - ENERGY_MIN) + ENERGY_MIN + ENERGY_MAX) / 2
- res[zero_idxs] = 0.0
- return res
-
-
-def norm_log(x):
- zero_idxs = np.where(x <= CLIP_FLOOR)[0]
- x[zero_idxs] = 1.0
- res = np.log(x)
- return res
-
-
-def denorm_log(x):
- zero_idxs = np.where(x == 0.0)[0]
- res = np.exp(x)
- res[zero_idxs] = 0.0
- return res
-
-
-def f0_norm_mean_std(x, mean, std):
- zero_idxs = np.where(x == 0.0)[0]
- x = (x - mean) / std
- x[zero_idxs] = 0.0
- return x
-
-
-def norm_mean_std(x, mean, std):
- x = (x - mean) / std
- return x
-
-
-def parse_interval_file(file_path, sampling_rate, hop_length):
- with open(file_path, 'r') as f:
- lines = f.readlines()
- # second
- frame_intervals = 1.0 * hop_length / sampling_rate
- skip_lines = 12
- dur_list = []
- phone_list = []
-
- line_index = skip_lines
-
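- # After the 12-line header, each record spans three lines: begin time,
- # end time and a quoted phone label; times are converted to frame counts
- # via hop_length / sampling_rate.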
- while line_index < len(lines):
- phone_begin = float(lines[line_index])
- phone_end = float(lines[line_index + 1])
- phone = lines[line_index + 2].strip()[1:-1]
- dur_list.append(
- int(round((phone_end - phone_begin) / frame_intervals)))
- phone_list.append(phone)
- line_index += 3
-
- if len(dur_list) == 0 or len(phone_list) == 0:
- return None
-
- return np.array(dur_list), phone_list
-
-
-def average_by_duration(x, durs):
- if x is None or durs is None:
- return None
- durs_cum = np.cumsum(np.pad(durs, (1, 0), 'constant'))
-
- # average over each symbol's duration
- x_symbol = np.zeros((durs.shape[0], ), dtype=np.float32)
- for idx, start, end in zip(
- range(durs.shape[0]), durs_cum[:-1], durs_cum[1:]):
- values = x[start:end][np.where(x[start:end] != 0.0)[0]]
- x_symbol[idx] = np.mean(values) if len(values) > 0 else 0.0
-
- return x_symbol.astype(np.float32)
-
-
-def encode_16bits(x):
- if x.min() > -1.0 and x.max() < 1.0:
- return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
- else:
- return x
diff --git a/modelscope/models/audio/tts/kantts/preprocess/data_process.py b/modelscope/models/audio/tts/kantts/preprocess/data_process.py
deleted file mode 100644
index 68025375..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/data_process.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import argparse
-import codecs
-import os
-import sys
-import time
-
-import yaml
-
-from modelscope import __version__
-from modelscope.models.audio.tts.kantts.datasets.dataset import (AmDataset,
- VocDataset)
-from modelscope.utils.logger import get_logger
-from .audio_processor.audio_processor import AudioProcessor
-from .fp_processor import FpProcessor, is_fp_line
-from .languages import languages
-from .script_convertor.text_script_convertor import TextScriptConvertor
-
-ROOT_PATH = os.path.dirname(os.path.dirname(
- os.path.abspath(__file__))) # NOQA: E402
-sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
-
-logging = get_logger()
-
-
-def gen_metafile(
- voice_output_dir,
- fp_enable=False,
- badlist=None,
- split_ratio=0.98,
-):
-
- voc_train_meta = os.path.join(voice_output_dir, 'train.lst')
- voc_valid_meta = os.path.join(voice_output_dir, 'valid.lst')
- if not os.path.exists(voc_train_meta) or not os.path.exists(
- voc_valid_meta):
- VocDataset.gen_metafile(
- os.path.join(voice_output_dir, 'wav'),
- voice_output_dir,
- split_ratio,
- )
- logging.info('Voc metafile generated.')
-
- raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt')
- am_train_meta = os.path.join(voice_output_dir, 'am_train.lst')
- am_valid_meta = os.path.join(voice_output_dir, 'am_valid.lst')
- if not os.path.exists(am_train_meta) or not os.path.exists(am_valid_meta):
- AmDataset.gen_metafile(
- raw_metafile,
- voice_output_dir,
- am_train_meta,
- am_valid_meta,
- badlist,
- split_ratio,
- )
- logging.info('AM metafile generated.')
-
- if fp_enable:
- fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt')
- am_train_meta = os.path.join(voice_output_dir, 'am_fpadd_train.lst')
- am_valid_meta = os.path.join(voice_output_dir, 'am_fpadd_valid.lst')
- if not os.path.exists(am_train_meta) or not os.path.exists(
- am_valid_meta):
- AmDataset.gen_metafile(
- fpadd_metafile,
- voice_output_dir,
- am_train_meta,
- am_valid_meta,
- badlist,
- split_ratio,
- )
- logging.info('AM fpadd metafile generated.')
-
- fprm_metafile = os.path.join(voice_output_dir, 'fprm_metafile.txt')
- am_train_meta = os.path.join(voice_output_dir, 'am_fprm_train.lst')
- am_valid_meta = os.path.join(voice_output_dir, 'am_fprm_valid.lst')
- if not os.path.exists(am_train_meta) or not os.path.exists(
- am_valid_meta):
- AmDataset.gen_metafile(
- fprm_metafile,
- voice_output_dir,
- am_train_meta,
- am_valid_meta,
- badlist,
- split_ratio,
- )
- logging.info('AM fprm metafile generated.')
-
-
-def process_data(
- voice_input_dir,
- voice_output_dir,
- language_dir,
- audio_config,
- speaker_name=None,
- targetLang='PinYin',
- skip_script=False,
-):
- foreignLang = 'EnUS'
- emo_tag_path = None
-
- phoneset_path = os.path.join(language_dir, targetLang,
- languages[targetLang]['phoneset_path'])
- posset_path = os.path.join(language_dir, targetLang,
- languages[targetLang]['posset_path'])
- f2t_map_path = os.path.join(language_dir, targetLang,
- languages[targetLang]['f2t_map_path'])
- s2p_map_path = os.path.join(language_dir, targetLang,
- languages[targetLang]['s2p_map_path'])
-
- logging.info(f'phoneset_path={phoneset_path}')
- # dir of plain text/sentences for training byte based model
- plain_text_dir = os.path.join(voice_input_dir, 'text')
-
- if speaker_name is None:
- speaker_name = os.path.basename(voice_input_dir)
-
- if audio_config is not None:
- with open(audio_config, 'r') as f:
- config = yaml.load(f, Loader=yaml.Loader)
-
- config['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S',
- time.localtime())
- config['modelscope_version'] = __version__
-
- with open(os.path.join(voice_output_dir, 'audio_config.yaml'), 'w') as f:
- yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
-
- if skip_script:
- logging.info('Skip script conversion')
- raw_metafile = None
- # Script processor
- if not skip_script:
- if os.path.exists(plain_text_dir):
- TextScriptConvertor.turn_text_into_bytes(
- os.path.join(plain_text_dir, 'text.txt'),
- os.path.join(voice_output_dir, 'raw_metafile.txt'),
- speaker_name,
- )
- fp_enable = False
- else:
- tsc = TextScriptConvertor(
- phoneset_path,
- posset_path,
- targetLang,
- foreignLang,
- f2t_map_path,
- s2p_map_path,
- emo_tag_path,
- speaker_name,
- )
- tsc.process(
- os.path.join(voice_input_dir, 'prosody', 'prosody.txt'),
- os.path.join(voice_output_dir, 'Script.xml'),
- os.path.join(voice_output_dir, 'raw_metafile.txt'),
- )
- prosody = os.path.join(voice_input_dir, 'prosody', 'prosody.txt')
- # FP processor
- with codecs.open(prosody, 'r', 'utf-8') as f:
- lines = f.readlines()
- fp_enable = is_fp_line(lines[1])
- raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt')
-
- if fp_enable:
- FP = FpProcessor()
-
- FP.process(
- voice_output_dir,
- prosody,
- raw_metafile,
- )
- logging.info('Processing fp done.')
-
- # Audio processor
- ap = AudioProcessor(config['audio_config'])
- ap.process(
- voice_input_dir,
- voice_output_dir,
- raw_metafile,
- )
-
- logging.info('Processing done.')
-
- # Generate Voc&AM metafile
- gen_metafile(voice_output_dir, fp_enable, ap.badcase_list)
diff --git a/modelscope/models/audio/tts/kantts/preprocess/fp_processor.py b/modelscope/models/audio/tts/kantts/preprocess/fp_processor.py
deleted file mode 100644
index 910a374c..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/fp_processor.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import os
-import random
-
-from modelscope.utils.logger import get_logger
-
-logging = get_logger()
-
-
-def is_fp_line(line):
- fp_category_list = ['FP', 'I', 'N', 'Q']
- elements = line.strip().split(' ')
- res = True
- for ele in elements:
- if ele not in fp_category_list:
- res = False
- break
- return res
-
-
-class FpProcessor:
-
- def __init__(self):
- # TODO: Add more fp processing methods.
- self.res = []
-
- @staticmethod
- def is_fp_line(line):
- fp_category_list = ['FP', 'I', 'N', 'Q']
- elements = line.strip().split(' ')
- res = True
- for ele in elements:
- if ele not in fp_category_list:
- res = False
- break
- return res
-
- # TODO: adjust idx judgment rule
- def addfp(self, voice_output_dir, prosody, raw_metafile_lines):
-
- fp_category_list = ['FP', 'I', 'N']
-
- f = open(prosody)
- prosody_lines = f.readlines()
- f.close()
-
- idx = ''
- fp = ''
- fp_label_dict = {}
- i = 0
- while i < len(prosody_lines):
- if len(prosody_lines[i].strip().split('\t')) == 2:
- idx = prosody_lines[i].strip().split('\t')[0]
- i += 1
- else:
- fp_enable = is_fp_line(prosody_lines[i])
- if fp_enable:
- fp = prosody_lines[i].strip().split('\t')[0].split(' ')
- for label in fp:
- if label not in fp_category_list:
- logging.warning('fp label not in fp_category_list')
- break
- i += 4
- else:
- fp = [
- 'N' for _ in range(
- len(prosody_lines[i].strip().split('\t')
- [0].replace('/ ', '').replace('. ', '').split(
- ' ')))
- ]
- i += 1
- fp_label_dict[idx] = fp
-
- fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt')
- f_out = open(fpadd_metafile, 'w')
- for line in raw_metafile_lines:
- tokens = line.strip().split('\t')
- if len(tokens) == 2:
- uttname = tokens[0]
- symbol_sequences = tokens[1].split(' ')
-
- error_flag = False
- idx = 0
- out_str = uttname + '\t'
-
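- # Mark FP (filled-pause) positions by switching the emotion tag from
- # neutral to disgust; the fp-label index advances at syllable-final
- # tokens (s_both / s_end).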
- for this_symbol_sequence in symbol_sequences:
- emotion = this_symbol_sequence.split('$')[4]
- this_symbol_sequence = this_symbol_sequence.replace(
- emotion, 'emotion_neutral')
-
- if idx < len(fp_label_dict[uttname]):
- if fp_label_dict[uttname][idx] == 'FP':
- if 'none' not in this_symbol_sequence:
- this_symbol_sequence = this_symbol_sequence.replace(
- 'emotion_neutral', 'emotion_disgust')
- syllable_label = this_symbol_sequence.split('$')[2]
- if syllable_label == 's_both' or syllable_label == 's_end':
- idx += 1
- elif idx > len(fp_label_dict[uttname]):
- logging.warning(uttname + ' does not match')
- error_flag = True
- out_str = out_str + this_symbol_sequence + ' '
-
- # if idx != len(fp_label_dict[uttname]):
- # logging.warning(
- # "{} length mismatch, length: {} ".format(
- # idx, len(fp_label_dict[uttname])
- # )
- # )
-
- if not error_flag:
- f_out.write(out_str.strip() + '\n')
- f_out.close()
- return fpadd_metafile
-
- def removefp(self, voice_output_dir, fpadd_metafile, raw_metafile_lines):
-
- f = open(fpadd_metafile)
- fpadd_metafile_lines = f.readlines()
- f.close()
-
- fprm_metafile = os.path.join(voice_output_dir, 'fprm_metafile.txt')
- f_out = open(fprm_metafile, 'w')
- for i in range(len(raw_metafile_lines)):
- tokens = raw_metafile_lines[i].strip().split('\t')
- symbol_sequences = tokens[1].split(' ')
- fpadd_tokens = fpadd_metafile_lines[i].strip().split('\t')
- fpadd_symbol_sequences = fpadd_tokens[1].split(' ')
-
- error_flag = False
- out_str = tokens[0] + '\t'
- idx = 0
- length = len(symbol_sequences)
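- # Skip tokens tagged as FP (emotion_disgust) together with a trailing
- # 'none' token, keeping only the non-FP symbol sequence.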
- while idx < length:
- if '$emotion_disgust' in fpadd_symbol_sequences[idx]:
- if idx + 1 < length and 'none' in fpadd_symbol_sequences[
- idx + 1]:
- idx = idx + 2
- else:
- idx = idx + 1
- continue
- out_str = out_str + symbol_sequences[idx] + ' '
- idx = idx + 1
-
- if not error_flag:
- f_out.write(out_str.strip() + '\n')
- f_out.close()
-
- def process(self, voice_output_dir, prosody, raw_metafile):
-
- with open(raw_metafile, 'r') as f:
- lines = f.readlines()
- random.shuffle(lines)
-
- fpadd_metafile = self.addfp(voice_output_dir, prosody, lines)
- self.removefp(voice_output_dir, fpadd_metafile, lines)
diff --git a/modelscope/models/audio/tts/kantts/preprocess/languages/__init__.py b/modelscope/models/audio/tts/kantts/preprocess/languages/__init__.py
deleted file mode 100644
index 3363e64a..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/languages/__init__.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-languages = {
- 'PinYin': {
- 'phoneset_path': 'PhoneSet.xml',
- 'posset_path': 'PosSet.xml',
- 'f2t_map_path': 'En2ChPhoneMap.txt',
- 's2p_map_path': 'py2phoneMap.txt',
- 'tonelist_path': 'tonelist.txt',
- },
- 'ZhHK': {
- 'phoneset_path': 'PhoneSet.xml',
- 'posset_path': 'PosSet.xml',
- 'f2t_map_path': 'En2ChPhoneMap.txt',
- 's2p_map_path': 'py2phoneMap.txt',
- 'tonelist_path': 'tonelist.txt',
- },
- 'WuuShanghai': {
- 'phoneset_path': 'PhoneSet.xml',
- 'posset_path': 'PosSet.xml',
- 'f2t_map_path': 'En2ChPhoneMap.txt',
- 's2p_map_path': 'py2phoneMap.txt',
- 'tonelist_path': 'tonelist.txt',
- },
- 'Sichuan': {
- 'phoneset_path': 'PhoneSet.xml',
- 'posset_path': 'PosSet.xml',
- 'f2t_map_path': 'En2ChPhoneMap.txt',
- 's2p_map_path': 'py2phoneMap.txt',
- 'tonelist_path': 'tonelist.txt',
- },
- 'EnGB': {
- 'phoneset_path': 'PhoneSet.xml',
- 'posset_path': 'PosSet.xml',
- 'f2t_map_path': '',
- 's2p_map_path': '',
- 'tonelist_path': 'tonelist.txt',
- },
- 'EnUS': {
- 'phoneset_path': 'PhoneSet.xml',
- 'posset_path': 'PosSet.xml',
- 'f2t_map_path': '',
- 's2p_map_path': '',
- 'tonelist_path': 'tonelist.txt',
- }
-}
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/core_types.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/core_types.py
deleted file mode 100644
index ce7f6080..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/core_types.py
+++ /dev/null
@@ -1,242 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-from enum import Enum
-
-
-class Tone(Enum):
- UnAssigned = -1
- NoneTone = 0
- YinPing = 1 # ZhHK: YinPingYinRu EnUS: primary stress
- YangPing = 2 # ZhHK: YinShang EnUS: secondary stress
- ShangSheng = 3 # ZhHK: YinQuZhongRu
- QuSheng = 4 # ZhHK: YangPing
- QingSheng = 5 # ZhHK: YangShang
- YangQuYangRu = 6 # ZhHK: YangQuYangRu
-
- @classmethod
- def parse(cls, in_str):
- if not isinstance(in_str, str):
- return super(Tone, cls).__new__(cls, in_str)
-
- if in_str in ['UnAssigned', '-1']:
- return Tone.UnAssigned
- elif in_str in ['NoneTone', '0']:
- return Tone.NoneTone
- elif in_str in ['YinPing', '1']:
- return Tone.YinPing
- elif in_str in ['YangPing', '2']:
- return Tone.YangPing
- elif in_str in ['ShangSheng', '3']:
- return Tone.ShangSheng
- elif in_str in ['QuSheng', '4']:
- return Tone.QuSheng
- elif in_str in ['QingSheng', '5']:
- return Tone.QingSheng
- elif in_str in ['YangQuYangRu', '6']:
- return Tone.YangQuYangRu
- else:
- return Tone.NoneTone
-
-
-class BreakLevel(Enum):
- UnAssigned = -1
- L0 = 0
- L1 = 1
- L2 = 2
- L3 = 3
- L4 = 4
-
- @classmethod
- def parse(cls, in_str):
- if not isinstance(in_str, str):
- return super(BreakLevel, cls).__new__(cls, in_str)
-
- if in_str in ['UnAssigned', '-1']:
- return BreakLevel.UnAssigned
- elif in_str in ['L0', '0']:
- return BreakLevel.L0
- elif in_str in ['L1', '1']:
- return BreakLevel.L1
- elif in_str in ['L2', '2']:
- return BreakLevel.L2
- elif in_str in ['L3', '3']:
- return BreakLevel.L3
- elif in_str in ['L4', '4']:
- return BreakLevel.L4
- else:
- return BreakLevel.UnAssigned
-
-
-class SentencePurpose(Enum):
- Declarative = 0
- Interrogative = 1
- Exclamatory = 2
- Imperative = 3
-
-
-class Language(Enum):
- Neutral = 0
- EnUS = 1033
- EnGB = 2057
- ZhCN = 2052
- PinYin = 2053
- WuuShanghai = 2054
- Sichuan = 2055
- ZhHK = 3076
- ZhEn = ZhCN | EnUS
-
- @classmethod
- def parse(cls, in_str):
- if not isinstance(in_str, str):
- return super(Language, cls).__new__(cls, in_str)
-
- if in_str in ['Neutral', '0']:
- return Language.Neutral
- elif in_str in ['EnUS', '1033']:
- return Language.EnUS
- elif in_str in ['EnGB', '2057']:
- return Language.EnGB
- elif in_str in ['ZhCN', '2052']:
- return Language.ZhCN
- elif in_str in ['PinYin', '2053']:
- return Language.PinYin
- elif in_str in ['WuuShanghai', '2054']:
- return Language.WuuShanghai
- elif in_str in ['Sichuan', '2055']:
- return Language.Sichuan
- elif in_str in ['ZhHK', '3076']:
- return Language.ZhHK
- elif in_str in ['ZhEn', '2052|1033']:
- return Language.ZhEn
- else:
- return Language.Neutral
-
-
-"""
-Phone Types
-"""
-
-
-class PhoneCVType(Enum):
- NULL = -1
- Consonant = 1
- Vowel = 2
-
- @classmethod
- def parse(cls, in_str):
- if not isinstance(in_str, str):
- return super(PhoneCVType, cls).__new__(cls, in_str)
-
- if in_str in ['consonant', 'Consonant']:
- return PhoneCVType.Consonant
- elif in_str in ['vowel', 'Vowel']:
- return PhoneCVType.Vowel
- else:
- return PhoneCVType.NULL
-
-
-class PhoneIFType(Enum):
- NULL = -1
- Initial = 1
- Final = 2
-
- @classmethod
- def parse(cls, in_str):
- if not isinstance(in_str, str):
- return super(PhoneIFType, cls).__new__(cls, in_str)
- if in_str in ['initial', 'Initial']:
- return PhoneIFType.Initial
- elif in_str in ['final', 'Final']:
- return PhoneIFType.Final
- else:
- return PhoneIFType.NULL
-
-
-class PhoneUVType(Enum):
- NULL = -1
- Voiced = 1
- UnVoiced = 2
-
- @classmethod
- def parse(cls, in_str):
- if not isinstance(in_str, str):
- return super(PhoneUVType, cls).__new__(cls, in_str)
- if in_str in ['voiced', 'Voiced']:
- return PhoneUVType.Voiced
- elif in_str in ['unvoiced', 'UnVoiced']:
- return PhoneUVType.UnVoiced
- else:
- return PhoneUVType.NULL
-
-
-class PhoneAPType(Enum):
- NULL = -1
- DoubleLips = 1
- LipTooth = 2
- FrontTongue = 3
- CentralTongue = 4
- BackTongue = 5
- Dorsal = 6
- Velar = 7
- Low = 8
- Middle = 9
- High = 10
-
- @classmethod
- def parse(cls, in_str):
- if not isinstance(in_str, str):
- return super(PhoneAPType, cls).__new__(cls, in_str)
- if in_str in ['doublelips', 'DoubleLips']:
- return PhoneAPType.DoubleLips
- elif in_str in ['liptooth', 'LipTooth']:
- return PhoneAPType.LipTooth
- elif in_str in ['fronttongue', 'FrontTongue']:
- return PhoneAPType.FrontTongue
- elif in_str in ['centraltongue', 'CentralTongue']:
- return PhoneAPType.CentralTongue
- elif in_str in ['backtongue', 'BackTongue']:
- return PhoneAPType.BackTongue
- elif in_str in ['dorsal', 'Dorsal']:
- return PhoneAPType.Dorsal
- elif in_str in ['velar', 'Velar']:
- return PhoneAPType.Velar
- elif in_str in ['low', 'Low']:
- return PhoneAPType.Low
- elif in_str in ['middle', 'Middle']:
- return PhoneAPType.Middle
- elif in_str in ['high', 'High']:
- return PhoneAPType.High
- else:
- return PhoneAPType.NULL
-
-
-class PhoneAMType(Enum):
- NULL = -1
- Stop = 1
- Affricate = 2
- Fricative = 3
- Nasal = 4
- Lateral = 5
- Open = 6
- Close = 7
-
- @classmethod
- def parse(cls, in_str):
- if not isinstance(in_str, str):
- return super(PhoneAMType, cls).__new__(cls, in_str)
- if in_str in ['stop', 'Stop']:
- return PhoneAMType.Stop
- elif in_str in ['affricate', 'Affricate']:
- return PhoneAMType.Affricate
- elif in_str in ['fricative', 'Fricative']:
- return PhoneAMType.Fricative
- elif in_str in ['nasal', 'Nasal']:
- return PhoneAMType.Nasal
- elif in_str in ['lateral', 'Lateral']:
- return PhoneAMType.Lateral
- elif in_str in ['open', 'Open']:
- return PhoneAMType.Open
- elif in_str in ['close', 'Close']:
- return PhoneAMType.Close
- else:
- return PhoneAMType.NULL
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/phone.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/phone.py
deleted file mode 100644
index cdefe37f..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/phone.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-from .core_types import (PhoneAMType, PhoneAPType, PhoneCVType, PhoneIFType,
- PhoneUVType)
-from .xml_obj import XmlObj
-
-
-class Phone(XmlObj):
-
- def __init__(self):
- self.m_id = None
- self.m_name = None
- self.m_cv_type = PhoneCVType.NULL
- self.m_if_type = PhoneIFType.NULL
- self.m_uv_type = PhoneUVType.NULL
- self.m_ap_type = PhoneAPType.NULL
- self.m_am_type = PhoneAMType.NULL
- self.m_bnd = False
-
- def __str__(self):
- return self.m_name
-
- def save(self):
- pass
-
- def load(self, phone_node):
- ns = '{http://schemas.alibaba-inc.com/tts}'
-
- id_node = phone_node.find(ns + 'id')
- self.m_id = int(id_node.text)
-
- name_node = phone_node.find(ns + 'name')
- self.m_name = name_node.text
-
- cv_node = phone_node.find(ns + 'cv')
- self.m_cv_type = PhoneCVType.parse(cv_node.text)
-
- if_node = phone_node.find(ns + 'if')
- self.m_if_type = PhoneIFType.parse(if_node.text)
-
- uv_node = phone_node.find(ns + 'uv')
- self.m_uv_type = PhoneUVType.parse(uv_node.text)
-
- ap_node = phone_node.find(ns + 'ap')
- self.m_ap_type = PhoneAPType.parse(ap_node.text)
-
- am_node = phone_node.find(ns + 'am')
- self.m_am_type = PhoneAMType.parse(am_node.text)
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/phone_set.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/phone_set.py
deleted file mode 100644
index defe4e30..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/phone_set.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import xml.etree.ElementTree as ET
-
-from modelscope.utils.logger import get_logger
-from .phone import Phone
-from .xml_obj import XmlObj
-
-logging = get_logger()
-
-
-class PhoneSet(XmlObj):
-
- def __init__(self, phoneset_path):
- self.m_phone_list = []
- self.m_id_map = {}
- self.m_name_map = {}
- self.load(phoneset_path)
-
- def load(self, file_path):
- # alibaba tts xml namespace
- ns = '{http://schemas.alibaba-inc.com/tts}'
-
- phoneset_root = ET.parse(file_path).getroot()
- for phone_node in phoneset_root.findall(ns + 'phone'):
- phone = Phone()
- phone.load(phone_node)
- self.m_phone_list.append(phone)
- if phone.m_id in self.m_id_map:
- logging.error('PhoneSet.Load: duplicate id: %d', phone.m_id)
- self.m_id_map[phone.m_id] = phone
-
- if phone.m_name in self.m_name_map:
- logging.error('PhoneSet.Load: duplicate name: %s',
- phone.m_name)
- self.m_name_map[phone.m_name] = phone
-
- def save(self):
- pass
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/pos.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/pos.py
deleted file mode 100644
index 2a4563dd..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/pos.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-from .xml_obj import XmlObj
-
-
-class Pos(XmlObj):
-
- def __init__(self):
- self.m_id = None
- self.m_name = None
- self.m_desc = None
- self.m_level = 1
- self.m_parent = None
- self.m_sub_pos_list = []
-
- def __str__(self):
- return self.m_name
-
- def save(self):
- pass
-
- def load(self, pos_node):
- ns = '{http://schemas.alibaba-inc.com/tts}'
-
- id_node = pos_node.find(ns + 'id')
- self.m_id = int(id_node.text)
-
- name_node = pos_node.find(ns + 'name')
- self.m_name = name_node.text
-
- desc_node = pos_node.find(ns + 'desc')
- self.m_desc = desc_node.text
-
- sub_node = pos_node.find(ns + 'sub')
- if sub_node is not None:
- for sub_pos_node in sub_node.findall(ns + 'pos'):
- sub_pos = Pos()
- sub_pos.load(sub_pos_node)
- sub_pos.m_parent = self
- sub_pos.m_level = self.m_level + 1
- self.m_sub_pos_list.append(sub_pos)
-
- return
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/pos_set.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/pos_set.py
deleted file mode 100644
index 26d170b5..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/pos_set.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import logging
-import xml.etree.ElementTree as ET
-
-from .pos import Pos
-from .xml_obj import XmlObj
-
-
-class PosSet(XmlObj):
-
- def __init__(self, posset_path):
- self.m_pos_list = []
- self.m_id_map = {}
- self.m_name_map = {}
- self.load(posset_path)
-
- def load(self, file_path):
- # alibaba tts xml namespace
- ns = '{http://schemas.alibaba-inc.com/tts}'
-
- posset_root = ET.parse(file_path).getroot()
- for pos_node in posset_root.findall(ns + 'pos'):
- pos = Pos()
- pos.load(pos_node)
- self.m_pos_list.append(pos)
- if pos.m_id in self.m_id_map:
- logging.error('PosSet.Load: duplicate id: %d', pos.m_id)
- self.m_id_map[pos.m_id] = pos
-
- if pos.m_name in self.m_name_map:
- logging.error('PosSet.Load: duplicate name: %s',
- pos.m_name)
- self.m_name_map[pos.m_name] = pos
-
- if len(pos.m_sub_pos_list) > 0:
- for sub_pos in pos.m_sub_pos_list:
- self.m_pos_list.append(sub_pos)
- if sub_pos.m_id in self.m_id_map:
- logging.error('PosSet.Load: duplicate id: %d',
- sub_pos.m_id)
- self.m_id_map[sub_pos.m_id] = sub_pos
-
- if sub_pos.m_name in self.m_name_map:
- logging.error('PosSet.Load: duplicate name: %s',
- sub_pos.m_name)
- self.m_name_map[sub_pos.m_name] = sub_pos
-
- def save(self):
- pass
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script.py
deleted file mode 100644
index 76ff1ffa..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import xml.etree.ElementTree as ET
-from xml.dom import minidom
-
-from .xml_obj import XmlObj
-
-
-class Script(XmlObj):
-
- def __init__(self, phoneset, posset):
- self.m_phoneset = phoneset
- self.m_posset = posset
- self.m_items = []
-
- def save(self, outputXMLPath):
- root = ET.Element('script')
-
- root.set('uttcount', str(len(self.m_items)))
- root.set('xmlns', 'http://schemas.alibaba-inc.com/tts')
- for item in self.m_items:
- item.save(root)
-
- xmlstr = minidom.parseString(ET.tostring(root)).toprettyxml(
- indent=' ', encoding='utf-8')
- with open(outputXMLPath, 'wb') as f:
- f.write(xmlstr)
-
- def save_meta_file(self):
- meta_lines = []
-
- for item in self.m_items:
- meta_lines.append(item.save_metafile())
-
- return meta_lines
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_item.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_item.py
deleted file mode 100644
index a0e75c57..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_item.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import xml.etree.ElementTree as ET
-
-from .xml_obj import XmlObj
-
-
-class ScriptItem(XmlObj):
-
- def __init__(self, phoneset, posset):
- if phoneset is None or posset is None:
- raise Exception('ScriptItem.__init__: phoneset or posset is None')
- self.m_phoneset = phoneset
- self.m_posset = posset
-
- self.m_id = None
- self.m_text = ''
- self.m_scriptSentence_list = []
- self.m_status = None
-
- def load(self):
- pass
-
- def save(self, parent_node):
- utterance_node = ET.SubElement(parent_node, 'utterance')
- utterance_node.set('id', self.m_id)
-
- text_node = ET.SubElement(utterance_node, 'text')
- text_node.text = self.m_text
-
- for sentence in self.m_scriptSentence_list:
- sentence.save(utterance_node)
-
- def save_metafile(self):
- meta_line = self.m_id + '\t'
-
- for sentence in self.m_scriptSentence_list:
- meta_line += sentence.save_metafile()
-
- return meta_line
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_sentence.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_sentence.py
deleted file mode 100644
index 473d34d2..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_sentence.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import xml.etree.ElementTree as ET
-
-from .xml_obj import XmlObj
-
-
-class WrittenSentence(XmlObj):
-
- def __init__(self, posset):
- self.m_written_word_list = []
- self.m_written_mark_list = []
- self.m_posset = posset
- self.m_align_list = []
- self.m_alignCursor = 0
- self.m_accompanyIndex = 0
- self.m_sequence = ''
- self.m_text = ''
-
- def add_host(self, writtenWord):
- self.m_written_word_list.append(writtenWord)
- self.m_align_list.append(self.m_alignCursor)
-
- def load_host(self):
- pass
-
- def save_host(self):
- pass
-
- def add_accompany(self, writtenMark):
- self.m_written_mark_list.append(writtenMark)
- self.m_alignCursor += 1
- self.m_accompanyIndex += 1
-
- def save_accompany(self):
- pass
-
- def load_accompany(self):
- pass
-
- # Get the mark span corresponding to specific spoken word
- def get_accompany_span(self, host_index):
- if host_index == -1:
- return (0, self.m_align_list[0])
-
- accompany_begin = self.m_align_list[host_index]
- accompany_end = (
- self.m_align_list[host_index + 1]
- if host_index + 1 < len(self.m_written_word_list) else len(
- self.m_written_mark_list))
-
- return (accompany_begin, accompany_end)
-
- def get_elements(self):
- accompany_begin, accompany_end = self.get_accompany_span(-1)
- res_lst = [
- self.m_written_mark_list[i]
- for i in range(accompany_begin, accompany_end)
- ]
-
- for j in range(len(self.m_written_word_list)):
- accompany_begin, accompany_end = self.get_accompany_span(j)
- res_lst.extend([self.m_written_word_list[j]])
- res_lst.extend([
- self.m_written_mark_list[i]
- for i in range(accompany_begin, accompany_end)
- ])
-
- return res_lst
-
- def build_sequence(self):
- self.m_sequence = ' '.join([str(ele) for ele in self.get_elements()])
-
- def build_text(self):
- self.m_text = ''.join([str(ele) for ele in self.get_elements()])
-
-
-class SpokenSentence(XmlObj):
-
- def __init__(self, phoneset):
- self.m_spoken_word_list = []
- self.m_spoken_mark_list = []
- self.m_phoneset = phoneset
- self.m_align_list = []
- self.m_alignCursor = 0
- self.m_accompanyIndex = 0
- self.m_sequence = ''
- self.m_text = ''
-
- def __len__(self):
- return len(self.m_spoken_word_list)
-
- def add_host(self, spokenWord):
- self.m_spoken_word_list.append(spokenWord)
- self.m_align_list.append(self.m_alignCursor)
-
- def save_host(self):
- pass
-
- def load_host(self):
- pass
-
- def add_accompany(self, spokenMark):
- self.m_spoken_mark_list.append(spokenMark)
- self.m_alignCursor += 1
- self.m_accompanyIndex += 1
-
- def save_accompany(self):
- pass
-
- # Get the mark span corresponding to specific spoken word
- def get_accompany_span(self, host_index):
- if host_index == -1:
- return (0, self.m_align_list[0])
-
- accompany_begin = self.m_align_list[host_index]
- accompany_end = (
- self.m_align_list[host_index + 1]
- if host_index + 1 < len(self.m_spoken_word_list) else len(
- self.m_spoken_mark_list))
-
- return (accompany_begin, accompany_end)
-
- def get_elements(self):
- accompany_begin, accompany_end = self.get_accompany_span(-1)
- res_lst = [
- self.m_spoken_mark_list[i]
- for i in range(accompany_begin, accompany_end)
- ]
-
- for j in range(len(self.m_spoken_word_list)):
- accompany_begin, accompany_end = self.get_accompany_span(j)
- res_lst.extend([self.m_spoken_word_list[j]])
- res_lst.extend([
- self.m_spoken_mark_list[i]
- for i in range(accompany_begin, accompany_end)
- ])
-
- return res_lst
-
- def load_accompany(self):
- pass
-
- def build_sequence(self):
- self.m_sequence = ' '.join([str(ele) for ele in self.get_elements()])
-
- def build_text(self):
- self.m_text = ''.join([str(ele) for ele in self.get_elements()])
-
- def save(self, parent_node):
- spoken_node = ET.SubElement(parent_node, 'spoken')
- spoken_node.set('wordcount', str(len(self.m_spoken_word_list)))
-
- text_node = ET.SubElement(spoken_node, 'text')
- text_node.text = self.m_sequence
-
- for word in self.m_spoken_word_list:
- word.save(spoken_node)
-
- def save_metafile(self):
- meta_line_list = [
- word.save_metafile() for word in self.m_spoken_word_list
- ]
-
- return ' '.join(meta_line_list)
-
-
-class ScriptSentence(XmlObj):
-
- def __init__(self, phoneset, posset):
- self.m_phoneset = phoneset
- self.m_posset = posset
- self.m_writtenSentence = WrittenSentence(posset)
- self.m_spokenSentence = SpokenSentence(phoneset)
- self.m_text = ''
-
- def save(self, parent_node):
- if len(self.m_spokenSentence) > 0:
- self.m_spokenSentence.save(parent_node)
-
- def save_metafile(self):
- if len(self.m_spokenSentence) > 0:
- return self.m_spokenSentence.save_metafile()
- else:
- return ''
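
The sentence classes above keep hosts (words) and accompanies (marks) in two parallel lists, with m_align_list recording how many marks had been added before each word; get_accompany_span() then recovers the marks that follow a given word. Below is a standalone sketch of that bookkeeping with illustrative data, assuming the same cursor semantics as add_host()/add_accompany().

# m_align_list[i] stores how many marks were appended before the i-th word, so the
# marks that follow word i occupy the slice marks[align[i]:align[i + 1]].
def interleave(words, marks, align):
    """Rebuild the original word/mark order from the alignment cursor."""
    out = list(marks[:align[0]]) if align else list(marks)
    for i, word in enumerate(words):
        end = align[i + 1] if i + 1 < len(words) else len(marks)
        out.append(word)
        out.extend(marks[align[i]:end])
    return out

words = ['你好', '世界']
marks = [',', '。']
align = [0, 1]   # no mark before word 0, one mark added before word 1
print(interleave(words, marks, align))   # ['你好', ',', '世界', '。']
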
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_word.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_word.py
deleted file mode 100644
index 80d9c2fb..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/script_word.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import xml.etree.ElementTree as ET
-
-from .core_types import Language
-from .syllable import SyllableList
-from .xml_obj import XmlObj
-
-
-class WrittenWord(XmlObj):
-
- def __init__(self):
- self.m_name = None
- self.m_POS = None
-
- def __str__(self):
- return self.m_name
-
- def load(self):
- pass
-
- def save(self):
- pass
-
-
-class WrittenMark(XmlObj):
-
- def __init__(self):
- self.m_punctuation = None
-
- def __str__(self):
- return self.m_punctuation
-
- def load(self):
- pass
-
- def save(self):
- pass
-
-
-class SpokenWord(XmlObj):
-
- def __init__(self):
- self.m_name = None
- self.m_language = None
- self.m_syllable_list = []
- self.m_breakText = '1'
- self.m_POS = '0'
-
- def __str__(self):
- return self.m_name
-
- def load(self):
- pass
-
- def save(self, parent_node):
-
- word_node = ET.SubElement(parent_node, 'word')
-
- name_node = ET.SubElement(word_node, 'name')
- name_node.text = self.m_name
-
- if (len(self.m_syllable_list) > 0
- and self.m_syllable_list[0].m_language != Language.Neutral):
- language_node = ET.SubElement(word_node, 'lang')
- language_node.text = self.m_syllable_list[0].m_language.name
-
- SyllableList(self.m_syllable_list).save(word_node)
-
- break_node = ET.SubElement(word_node, 'break')
- break_node.text = self.m_breakText
-
- POS_node = ET.SubElement(word_node, 'POS')
- POS_node.text = self.m_POS
-
- return
-
- def save_metafile(self):
- word_phone_cnt = sum(
- [syllable.phone_count() for syllable in self.m_syllable_list])
- word_syllable_cnt = len(self.m_syllable_list)
- single_syllable_word = word_syllable_cnt == 1
- meta_line_list = []
-
- for idx, syll in enumerate(self.m_syllable_list):
- if word_phone_cnt == 1:
- word_pos = 'word_both'
- elif idx == 0:
- word_pos = 'word_begin'
- elif idx == len(self.m_syllable_list) - 1:
- word_pos = 'word_end'
- else:
- word_pos = 'word_middle'
- meta_line_list.append(
- syll.save_metafile(
- word_pos, single_syllable_word=single_syllable_word))
-
- if self.m_breakText != '0' and self.m_breakText is not None:
- meta_line_list.append('{{#{}$tone_none$s_none$word_none}}'.format(
- self.m_breakText))
-
- return ' '.join(meta_line_list)
-
-
-class SpokenMark(XmlObj):
-
- def __init__(self):
- self.m_breakLevel = None
-
- def break_level2text(self):
- return '#' + str(self.m_breakLevel.value)
-
- def __str__(self):
- return self.break_level2text()
-
- def load(self):
- pass
-
- def save(self):
- pass
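
SpokenWord.save_metafile() above tags each syllable with a positional label (word_begin/word_middle/word_end, or word_both) and appends a break token when m_breakText is set. The helper below is only a generic sketch of that begin/middle/end/both labelling; note that the original keys the "both" case off the word's phone count rather than this simple length check.

def position_tags(units, prefix):
    # Label every unit in a span by its position; a single-unit span is "both".
    tags = []
    for idx, _ in enumerate(units):
        if len(units) == 1:
            tag = prefix + '_both'
        elif idx == 0:
            tag = prefix + '_begin'
        elif idx == len(units) - 1:
            tag = prefix + '_end'
        else:
            tag = prefix + '_middle'
        tags.append(tag)
    return tags

print(position_tags(['n', 'i3'], 's'))                  # ['s_begin', 's_end']
print(position_tags(['syl1', 'syl2', 'syl3'], 'word'))  # ['word_begin', 'word_middle', 'word_end']
print(position_tags(['one'], 'word'))                   # ['word_both']
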
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/syllable.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/syllable.py
deleted file mode 100644
index 684976dd..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/syllable.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import xml.etree.ElementTree as ET
-
-from .xml_obj import XmlObj
-
-
-class Syllable(XmlObj):
-
- def __init__(self):
- self.m_phone_list = []
- self.m_tone = None
- self.m_language = None
- self.m_breaklevel = None
-
- def pronunciation_text(self):
- return ' '.join([str(phone) for phone in self.m_phone_list])
-
- def phone_count(self):
- return len(self.m_phone_list)
-
- def tone_text(self):
- return str(self.m_tone.value)
-
- def save(self):
- pass
-
- def load(self):
- pass
-
- def get_phone_meta(self,
- phone_name,
- word_pos,
- syll_pos,
- tone_text,
- single_syllable_word=False):
- # Special case: for a single-syllable word, the last phone's word_pos should be "word_end"
- if word_pos == 'word_begin' and syll_pos == 's_end' and single_syllable_word:
- word_pos = 'word_end'
- elif word_pos == 'word_begin' and syll_pos not in [
- 's_begin',
- 's_both',
- ]: # FIXME: keep accord with Engine logic
- word_pos = 'word_middle'
- elif word_pos == 'word_end' and syll_pos not in ['s_end', 's_both']:
- word_pos = 'word_middle'
- else:
- pass
-
- return '{{{}$tone{}${}${}}}'.format(phone_name, tone_text, syll_pos,
- word_pos)
-
- def save_metafile(self, word_pos, single_syllable_word=False):
- syllable_phone_cnt = len(self.m_phone_list)
-
- meta_line_list = []
-
- for idx, phone in enumerate(self.m_phone_list):
- if syllable_phone_cnt == 1:
- syll_pos = 's_both'
- elif idx == 0:
- syll_pos = 's_begin'
- elif idx == len(self.m_phone_list) - 1:
- syll_pos = 's_end'
- else:
- syll_pos = 's_middle'
- meta_line_list.append(
- self.get_phone_meta(
- phone,
- word_pos,
- syll_pos,
- self.tone_text(),
- single_syllable_word=single_syllable_word,
- ))
-
- return ' '.join(meta_line_list)
-
-
-class SyllableList(XmlObj):
-
- def __init__(self, syllables):
- self.m_syllable_list = syllables
-
- def __len__(self):
- return len(self.m_syllable_list)
-
- def __getitem__(self, index):
- return self.m_syllable_list[index]
-
- def pronunciation_text(self):
- return ' - '.join([
- syllable.pronunciation_text() for syllable in self.m_syllable_list
- ])
-
- def tone_text(self):
- return ''.join(
- [syllable.tone_text() for syllable in self.m_syllable_list])
-
- def save(self, parent_node):
- syllable_node = ET.SubElement(parent_node, 'syllable')
- syllable_node.set('syllcount', str(len(self.m_syllable_list)))
-
- phone_node = ET.SubElement(syllable_node, 'phone')
- phone_node.text = self.pronunciation_text()
-
- tone_node = ET.SubElement(syllable_node, 'tone')
- tone_node.text = self.tone_text()
-
- return
-
- def load(self):
- pass
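
Each phone is rendered by get_phone_meta() above as a brace-delimited token carrying the phone name, tone, syllable position, and word position. A tiny reproduction of that format string follows; the phone names 'n' and 'i3' are illustrative.

def phone_meta(phone_name, tone_text, syll_pos, word_pos):
    # "{<phone>$tone<T>$<syllable position>$<word position>}"
    return '{{{}$tone{}${}${}}}'.format(phone_name, tone_text, syll_pos, word_pos)

print(phone_meta('n', '3', 's_begin', 'word_begin'))   # {n$tone3$s_begin$word_begin}
print(phone_meta('i3', '3', 's_end', 'word_end'))      # {i3$tone3$s_end$word_end}
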
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/syllable_formatter.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/syllable_formatter.py
deleted file mode 100644
index dce2b65b..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/syllable_formatter.py
+++ /dev/null
@@ -1,322 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import re
-
-from modelscope.utils.logger import get_logger
-from .core_types import Language, PhoneCVType, Tone
-from .syllable import Syllable
-from .utils import NgBreakPattern
-
-logging = get_logger()
-
-
-class DefaultSyllableFormatter:
-
- def __init__(self):
- return
-
- def format(self, phoneset, pronText, syllable_list):
- logging.warning('Using DefaultSyllableFormatter dry run: %s', pronText)
- return True
-
-
-RegexNg2en = re.compile(NgBreakPattern)
-RegexQingSheng = re.compile(r'([1-5]5)')
-RegexPron = re.compile(r'(?P<Pron>[a-z]+)(?P<Tone>[1-6])')
-
-
-class ZhCNSyllableFormatter:
-
- def __init__(self, sy2ph_map):
- self.m_sy2ph_map = sy2ph_map
-
- def normalize_pron(self, pronText):
- # Replace Qing Sheng
- newPron = pronText.replace('6', '2')
- newPron = re.sub(RegexQingSheng, '5', newPron)
-
- # FIXME(Jin): ng case overrides newPron
- match = RegexNg2en.search(newPron)
- if match:
- newPron = 'en' + match.group('break')
-
- return newPron
-
- def format(self, phoneset, pronText, syllable_list):
- if phoneset is None or syllable_list is None or pronText is None:
- logging.error('ZhCNSyllableFormatter.Format: invalid input')
- return False
- pronText = self.normalize_pron(pronText)
-
- if pronText in self.m_sy2ph_map:
- phone_list = self.m_sy2ph_map[pronText].split(' ')
- if len(phone_list) == 3:
- syll = Syllable()
- for phone in phone_list:
- syll.m_phone_list.append(phone)
- syll.m_tone = Tone.parse(
- pronText[-1]) # FIXME(Jin): assume tone is the last char
- syll.m_language = Language.ZhCN
- syllable_list.append(syll)
- return True
- else:
- logging.error(
- 'ZhCNSyllableFormatter.Format: invalid pronText: %s',
- pronText)
- return False
- else:
- logging.error(
- 'ZhCNSyllableFormatter.Format: syllable to phone map missing key: %s',
- pronText,
- )
- return False
-
-
-class PinYinSyllableFormatter:
-
- def __init__(self, sy2ph_map):
- self.m_sy2ph_map = sy2ph_map
-
- def normalize_pron(self, pronText):
- newPron = pronText.replace('6', '2')
- newPron = re.sub(RegexQingSheng, '5', newPron)
-
- # FIXME(Jin): ng case overrides newPron
- match = RegexNg2en.search(newPron)
- if match:
- newPron = 'en' + match.group('break')
-
- return newPron
-
- def format(self, phoneset, pronText, syllable_list):
- if phoneset is None or syllable_list is None or pronText is None:
- logging.error('PinYinSyllableFormatter.Format: invalid input')
- return False
- pronText = self.normalize_pron(pronText)
-
- match = RegexPron.search(pronText)
-
- if match:
- pron = match.group('Pron')
- tone = match.group('Tone')
- else:
- logging.error(
- 'PinYinSyllableFormatter.Format: pronunciation is not valid: %s',
- pronText,
- )
- return False
-
- if pron in self.m_sy2ph_map:
- phone_list = self.m_sy2ph_map[pron].split(' ')
- if len(phone_list) in [1, 2]:
- syll = Syllable()
- for phone in phone_list:
- syll.m_phone_list.append(phone)
- syll.m_tone = Tone.parse(tone)
- syll.m_language = Language.PinYin
- syllable_list.append(syll)
- return True
- else:
- logging.error(
- 'PinYinSyllableFormatter.Format: invalid phone: %s', pron)
- return False
- else:
- logging.error(
- 'PinYinSyllableFormatter.Format: syllable to phone map missing key: %s',
- pron,
- )
- return False
-
-
-class ZhHKSyllableFormatter:
-
- def __init__(self, sy2ph_map):
- self.m_sy2ph_map = sy2ph_map
-
- def format(self, phoneset, pronText, syllable_list):
- if phoneset is None or syllable_list is None or pronText is None:
- logging.error('ZhHKSyllableFormatter.Format: invalid input')
- return False
-
- match = RegexPron.search(pronText)
- if match:
- pron = match.group('Pron')
- tone = match.group('Tone')
- else:
- logging.error(
- 'ZhHKSyllableFormatter.Format: pronunciation is not valid: %s',
- pronText)
- return False
-
- if pron in self.m_sy2ph_map:
- phone_list = self.m_sy2ph_map[pron].split(' ')
- if len(phone_list) in [1, 2]:
- syll = Syllable()
- for phone in phone_list:
- syll.m_phone_list.append(phone)
- syll.m_tone = Tone.parse(tone)
- syll.m_language = Language.ZhHK
- syllable_list.append(syll)
- return True
- else:
- logging.error(
- 'ZhHKSyllableFormatter.Format: invalid phone: %s', pron)
- return False
- else:
- logging.error(
- 'ZhHKSyllableFormatter.Format: syllable to phone map missing key: %s',
- pron,
- )
- return False
-
-
-class WuuShanghaiSyllableFormatter:
-
- def __init__(self, sy2ph_map):
- self.m_sy2ph_map = sy2ph_map
-
- def format(self, phoneset, pronText, syllable_list):
- if phoneset is None or syllable_list is None or pronText is None:
- logging.error('WuuShanghaiSyllableFormatter.Format: invalid input')
- return False
-
- match = RegexPron.search(pronText)
- if match:
- pron = match.group('Pron')
- tone = match.group('Tone')
- else:
- logging.error(
- 'WuuShanghaiSyllableFormatter.Format: pronunciation is not valid: %s',
- pronText,
- )
- return False
-
- if pron in self.m_sy2ph_map:
- phone_list = self.m_sy2ph_map[pron].split(' ')
- if len(phone_list) in [1, 2]:
- syll = Syllable()
- for phone in phone_list:
- syll.m_phone_list.append(phone)
- syll.m_tone = Tone.parse(tone)
- syll.m_language = Language.WuuShanghai
- syllable_list.append(syll)
- return True
- else:
- logging.error(
- 'WuuShanghaiSyllableFormatter.Format: invalid phone: %s',
- pron)
- return False
- else:
- logging.error(
- 'WuuShanghaiSyllableFormatter.Format: syllable to phone map missing key: %s',
- pron,
- )
- return False
-
-
-class SichuanSyllableFormatter:
-
- def __init__(self, sy2ph_map):
- self.m_sy2ph_map = sy2ph_map
-
- def format(self, phoneset, pronText, syllable_list):
- if phoneset is None or syllable_list is None or pronText is None:
- logging.error('SichuanSyllableFormatter.Format: invalid input')
- return False
-
- match = RegexPron.search(pronText)
- if match:
- pron = match.group('Pron')
- tone = match.group('Tone')
- else:
- logging.error(
- 'SichuanSyllableFormatter.Format: pronunciation is not valid: %s',
- pronText,
- )
- return False
-
- if pron in self.m_sy2ph_map:
- phone_list = self.m_sy2ph_map[pron].split(' ')
- if len(phone_list) in [1, 2]:
- syll = Syllable()
- for phone in phone_list:
- syll.m_phone_list.append(phone)
- syll.m_tone = Tone.parse(tone)
- syll.m_language = Language.Sichuan
- syllable_list.append(syll)
- return True
- else:
- logging.error(
- 'SichuanSyllableFormatter.Format: invalid phone: %s', pron)
- return False
- else:
- logging.error(
- 'SichuanSyllableFormatter.Format: syllable to phone map missing key: %s',
- pron,
- )
- return False
-
-
-class EnXXSyllableFormatter:
-
- def __init__(self, language):
- self.m_f2t_map = None
- self.m_language = language
-
- def normalize_pron(self, pronText):
- newPron = pronText.replace('#', '.')
- newPron = (
- newPron.replace('03',
- '0').replace('13',
- '1').replace('23',
- '2').replace('3', ''))
- newPron = newPron.replace('2', '0')
-
- return newPron
-
- def format(self, phoneset, pronText, syllable_list):
- if phoneset is None or syllable_list is None or pronText is None:
- logging.error('EnXXSyllableFormatter.Format: invalid input')
- return False
- pronText = self.normalize_pron(pronText)
-
- syllables = [ele.strip() for ele in pronText.split('.')]
-
- for i in range(len(syllables)):
- syll = Syllable()
- syll.m_language = self.m_language
- syll.m_tone = Tone.parse('0')
-
- phones = re.split(r'[\s]+', syllables[i])
-
- for j in range(len(phones)):
- phoneName = phones[j].lower()
- toneName = '0'
-
- if '0' in phoneName or '1' in phoneName or '2' in phoneName:
- toneName = phoneName[-1]
- phoneName = phoneName[:-1]
-
- phoneName_lst = None
- if self.m_f2t_map is not None:
- phoneName_lst = self.m_f2t_map.get(phoneName, None)
- if phoneName_lst is None:
- phoneName_lst = [phoneName]
-
- for new_phoneName in phoneName_lst:
- phone_obj = phoneset.m_name_map.get(new_phoneName, None)
- if phone_obj is None:
- logging.error(
- 'EnXXSyllableFormatter.Format: phone %s not found',
- new_phoneName,
- )
- return False
- phone_obj.m_name = new_phoneName
- syll.m_phone_list.append(phone_obj)
- if phone_obj.m_cv_type == PhoneCVType.Vowel:
- syll.m_tone = Tone.parse(toneName)
-
- if j == len(phones) - 1:
- phone_obj.m_bnd = True
- syllable_list.append(syll)
- return True
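
The Chinese-dialect formatters above all share one shape: parse "<pronunciation><tone>" with the Pron/Tone regex, look the pronunciation up in a syllable-to-phone map, and build a Syllable from the phones plus the parsed tone. Below is a minimal standalone sketch of that flow; the toy_sy2ph entries are made up and are not the real syllable-to-phone table.

import re

RegexPron = re.compile(r'(?P<Pron>[a-z]+)(?P<Tone>[1-6])')
toy_sy2ph = {'ni': 'n i', 'hao': 'h ao'}   # illustrative mapping only

def toy_format(pron_text):
    # Split "pron+tone", resolve the pronunciation to phones, keep the tone.
    match = RegexPron.search(pron_text)
    if match is None or match.group('Pron') not in toy_sy2ph:
        return None
    phones = toy_sy2ph[match.group('Pron')].split(' ')
    return {'phones': phones, 'tone': match.group('Tone')}

print(toy_format('ni3'))   # {'phones': ['n', 'i'], 'tone': '3'}
print(toy_format('hao3'))  # {'phones': ['h', 'ao'], 'tone': '3'}
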
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/utils.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/utils.py
deleted file mode 100644
index 0b8bee0b..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/utils.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import codecs
-import re
-import unicodedata
-
-WordPattern = r'((?P<Word>\w+)(\(\w+\))?)'
-BreakPattern = r'(?P<Break>(\*?#(?P<BreakLevel>[0-4])))'
-MarkPattern = r'(?P<Mark>[、,。!?:“”《》·])'
-POSPattern = r'(?P<POS>(\*?\|(?P<POSClass>[1-9])))'
-PhraseTonePattern = r'(?P<PhraseTone>(\*?%([L|H])))'
-
-NgBreakPattern = r'^ng(?P<break>\d)'
-
-RegexWord = re.compile(WordPattern + r'\s*')
-RegexBreak = re.compile(BreakPattern + r'\s*')
-RegexID = re.compile(r'^(?P<ID>[a-zA-Z\-_0-9\.]+)\s*')
-RegexSentence = re.compile(r'({}|{}|{}|{}|{})\s*'.format(
- WordPattern, BreakPattern, MarkPattern, POSPattern, PhraseTonePattern))
-RegexForeignLang = re.compile(r'[A-Z@]')
-RegexSpace = re.compile(r'^\s*')
-RegexNeutralTone = re.compile(r'[1-5]5')
-
-
-def do_character_normalization(line):
- return unicodedata.normalize('NFKC', line)
-
-
-def do_prosody_text_normalization(line):
- tokens = line.split('\t')
- text = tokens[1]
- # Remove punctuations
- text = text.replace(u'。', ' ')
- text = text.replace(u'、', ' ')
- text = text.replace(u'“', ' ')
- text = text.replace(u'”', ' ')
- text = text.replace(u'‘', ' ')
- text = text.replace(u'’', ' ')
- text = text.replace(u'|', ' ')
- text = text.replace(u'《', ' ')
- text = text.replace(u'》', ' ')
- text = text.replace(u'【', ' ')
- text = text.replace(u'】', ' ')
- text = text.replace(u'—', ' ')
- text = text.replace(u'―', ' ')
- text = text.replace('.', ' ')
- text = text.replace('!', ' ')
- text = text.replace('?', ' ')
- text = text.replace('(', ' ')
- text = text.replace(')', ' ')
- text = text.replace('[', ' ')
- text = text.replace(']', ' ')
- text = text.replace('{', ' ')
- text = text.replace('}', ' ')
- text = text.replace('~', ' ')
- text = text.replace(':', ' ')
- text = text.replace(';', ' ')
- text = text.replace('+', ' ')
- text = text.replace(',', ' ')
- # text = text.replace('·', ' ')
- text = text.replace('"', ' ')
- text = text.replace(
- '-',
- '') # not replaced by a space, so compound words like two-year-old stay joined
- text = text.replace(
- "'", '') # not replaced by a space, so contractions like that's keep the apostrophe
-
- # Replace break
- text = text.replace('/', '#2')
- text = text.replace('%', '#3')
- # Remove useless spaces surround #2 #3 #4
- text = re.sub(r'(#\d)[ ]+', r'\1', text)
- text = re.sub(r'[ ]+(#\d)', r'\1', text)
- # Replace space by #1
- text = re.sub('[ ]+', '#1', text)
-
- # Remove break at the end of the text
- text = re.sub(r'#\d$', '', text)
-
- # Add #1 between target language and foreign language
- text = re.sub(r"([a-zA-Z])([^a-zA-Z\d\#\s\'\%\/\-])", r'\1#1\2', text)
- text = re.sub(r"([^a-zA-Z\d\#\s\'\%\/\-])([a-zA-Z])", r'\1#1\2', text)
-
- return tokens[0] + '\t' + text
-
-
-def is_fp_line(line):
- fp_category_list = ['FP', 'I', 'N', 'Q']
- elements = line.strip().split(' ')
- res = True
- for ele in elements:
- if ele not in fp_category_list:
- res = False
- break
- return res
-
-
-def format_prosody(src_prosody):
- formatted_lines = []
- with codecs.open(src_prosody, 'r', 'utf-8') as f:
- lines = f.readlines()
-
- idx = 0
- while idx < len(lines):
- line = do_character_normalization(lines[idx])
-
- if len(line.strip().split('\t')) == 2:
- line = do_prosody_text_normalization(line)
- else:
- fp_enable = is_fp_line(line)
- if fp_enable:
- idx += 3
- continue
- formatted_lines.append(line)
- idx += 1
- return formatted_lines
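
The break handling inside do_prosody_text_normalization() above boils down to a few regex steps: '/' becomes a #2 break, '%' a #3 break, spaces adjacent to an explicit break are dropped, remaining spaces become #1, and a trailing break is removed. A self-contained walk-through on an illustrative line:

import re

text = '今天天气 很好 / 适合 出门 %'
text = text.replace('/', '#2')
text = text.replace('%', '#3')
text = re.sub(r'(#\d)[ ]+', r'\1', text)   # drop spaces after an explicit break
text = re.sub(r'[ ]+(#\d)', r'\1', text)   # drop spaces before an explicit break
text = re.sub('[ ]+', '#1', text)          # remaining spaces become #1 breaks
text = re.sub(r'#\d$', '', text)           # drop a break at the end of the line
print(text)  # 今天天气#1很好#2适合#1出门
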
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/xml_obj.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/xml_obj.py
deleted file mode 100644
index 21f05e10..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/xml_obj.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-
-class XmlObj:
-
- def __init__(self):
- pass
-
- def load(self):
- pass
-
- def save(self):
- pass
-
- def load_data(self):
- pass
-
- def save_data(self):
- pass
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/text_script_convertor.py b/modelscope/models/audio/tts/kantts/preprocess/script_convertor/text_script_convertor.py
deleted file mode 100644
index 8bb0f45a..00000000
--- a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/text_script_convertor.py
+++ /dev/null
@@ -1,500 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import argparse
-import os
-import re
-
-from bitstring import BitArray
-from tqdm import tqdm
-
-from modelscope.utils.logger import get_logger
-from .core.core_types import BreakLevel, Language
-from .core.phone_set import PhoneSet
-from .core.pos_set import PosSet
-from .core.script import Script
-from .core.script_item import ScriptItem
-from .core.script_sentence import ScriptSentence
-from .core.script_word import SpokenMark, SpokenWord, WrittenMark, WrittenWord
-from .core.utils import (RegexForeignLang, RegexID, RegexSentence,
- format_prosody)
-
-from .core.utils import RegexNeutralTone # isort:skip
-
-from .core.syllable_formatter import ( # isort:skip
- EnXXSyllableFormatter, PinYinSyllableFormatter, # isort:skip
- SichuanSyllableFormatter, # isort:skip
- WuuShanghaiSyllableFormatter, ZhCNSyllableFormatter, # isort:skip
- ZhHKSyllableFormatter) # isort:skip
-
-logging = get_logger()
-
-
-class TextScriptConvertor:
-
- def __init__(
- self,
- phoneset_path,
- posset_path,
- target_lang,
- foreign_lang,
- f2t_map_path,
- s2p_map_path,
- m_emo_tag_path,
- m_speaker,
- ):
- self.m_f2p_map = {}
- self.m_s2p_map = {}
- self.m_phoneset = PhoneSet(phoneset_path)
- self.m_posset = PosSet(posset_path)
- self.m_target_lang = Language.parse(target_lang)
- self.m_foreign_lang = Language.parse(foreign_lang)
- self.m_emo_tag_path = m_emo_tag_path
- self.m_speaker = m_speaker
-
- self.load_f2tmap(f2t_map_path)
- self.load_s2pmap(s2p_map_path)
-
- self.m_target_lang_syllable_formatter = self.init_syllable_formatter(
- self.m_target_lang)
- self.m_foreign_lang_syllable_formatter = self.init_syllable_formatter(
- self.m_foreign_lang)
-
- def parse_sentence(self, sentence, line_num):
- script_item = ScriptItem(self.m_phoneset, self.m_posset)
- script_sentence = ScriptSentence(self.m_phoneset, self.m_posset)
- script_item.m_scriptSentence_list.append(script_sentence)
-
- written_sentence = script_sentence.m_writtenSentence
- spoken_sentence = script_sentence.m_spokenSentence
-
- position = 0
-
- sentence = sentence.strip()
-
- # Get ID
- match = re.search(RegexID, sentence)
- if match is None:
- logging.error(
- 'TextScriptConvertor.parse_sentence: invalid line: %s,\
- line ID is needed',
- line_num,
- )
- return None
- else:
- sentence_id = match.group('ID')
- script_item.m_id = sentence_id
- position += match.end()
-
- prevSpokenWord = SpokenWord()
-
- prevWord = False
- lastBreak = False
-
- for m in re.finditer(RegexSentence, sentence[position:]):
- if m is None:
- logging.error(
- 'TextScriptConvertor.parse_sentence:\
- invalid line: %s, there is no matched pattern',
- line_num,
- )
- return None
-
- if m.group('Word') is not None:
- wordName = m.group('Word')
- written_word = WrittenWord()
- written_word.m_name = wordName
- written_sentence.add_host(written_word)
-
- spoken_word = SpokenWord()
- spoken_word.m_name = wordName
- prevSpokenWord = spoken_word
- prevWord = True
- lastBreak = False
- elif m.group('Break') is not None:
- breakText = m.group('BreakLevel')
- if len(breakText) == 0:
- breakLevel = BreakLevel.L1
- else:
- breakLevel = BreakLevel.parse(breakText)
- if prevWord:
- prevSpokenWord.m_breakText = breakText
- spoken_sentence.add_host(prevSpokenWord)
-
- if breakLevel != BreakLevel.L1:
- spokenMark = SpokenMark()
- spokenMark.m_breakLevel = breakLevel
- spoken_sentence.add_accompany(spokenMark)
-
- lastBreak = True
-
- elif m.group('PhraseTone') is not None:
- pass
- elif m.group('POS') is not None:
- POSClass = m.group('POSClass')
- if prevWord:
- prevSpokenWord.m_POS = POSClass
- prevWord = False
- elif m.group('Mark') is not None:
- markText = m.group('Mark')
-
- writtenMark = WrittenMark()
- writtenMark.m_punctuation = markText
- written_sentence.add_accompany(writtenMark)
- else:
- logging.error(
- 'TextScriptConvertor.parse_sentence:\
- invalid line: %s, matched pattern is unrecognized',
- line_num,
- )
- return None
-
- if not lastBreak:
- prevSpokenWord.m_breakText = '4'
- spoken_sentence.add_host(prevSpokenWord)
-
- spoken_word_cnt = len(spoken_sentence.m_spoken_word_list)
- spoken_mark_cnt = len(spoken_sentence.m_spoken_mark_list)
- if (spoken_word_cnt > 0
- and spoken_sentence.m_align_list[spoken_word_cnt - 1]
- == spoken_mark_cnt):
- spokenMark = SpokenMark()
- spokenMark.m_breakLevel = BreakLevel.L4
- spoken_sentence.add_accompany(spokenMark)
-
- written_sentence.build_sequence()
- spoken_sentence.build_sequence()
- written_sentence.build_text()
- spoken_sentence.build_text()
-
- script_sentence.m_text = written_sentence.m_text
- script_item.m_text = written_sentence.m_text
-
- return script_item
-
- def format_syllable(self, pron, syllable_list):
- isForeign = RegexForeignLang.search(pron) is not None
- if self.m_foreign_lang_syllable_formatter is not None and isForeign:
- return self.m_foreign_lang_syllable_formatter.format(
- self.m_phoneset, pron, syllable_list)
- else:
- return self.m_target_lang_syllable_formatter.format(
- self.m_phoneset, pron, syllable_list)
-
- def get_word_prons(self, pronText):
- prons = pronText.split('/')
- res = []
-
- for pron in prons:
- if re.search(RegexForeignLang, pron):
- res.append(pron.strip())
- else:
- res.extend(pron.strip().split(' '))
- return res
-
- def is_erhuayin(self, pron):
- pron = RegexNeutralTone.sub('5', pron)
- pron = pron[:-1]
-
- return pron[-1] == 'r' and pron != 'er'
-
- def parse_pronunciation(self, script_item, pronunciation, line_num):
- spoken_sentence = script_item.m_scriptSentence_list[0].m_spokenSentence
-
- wordProns = self.get_word_prons(pronunciation)
-
- wordIndex = 0
- pronIndex = 0
- succeed = True
-
- while pronIndex < len(wordProns):
- language = Language.Neutral
- syllable_list = []
-
- pron = wordProns[pronIndex].strip()
-
- succeed = self.format_syllable(pron, syllable_list)
- if not succeed:
- logging.error(
- 'TextScriptConvertor.parse_pronunciation:\
- invalid line: %s, error pronunciation: %s,\
- syllable format error',
- line_num,
- pron,
- )
- return False
- language = syllable_list[0].m_language
-
- if wordIndex < len(spoken_sentence.m_spoken_word_list):
- if language in [Language.EnGB, Language.EnUS]:
- spoken_sentence.m_spoken_word_list[
- wordIndex].m_syllable_list.extend(syllable_list)
- wordIndex += 1
- pronIndex += 1
- elif language in [
- Language.ZhCN,
- Language.PinYin,
- Language.ZhHK,
- Language.WuuShanghai,
- Language.Sichuan,
- ]:
- charCount = len(
- spoken_sentence.m_spoken_word_list[wordIndex].m_name)
- if (language in [
- Language.ZhCN, Language.PinYin, Language.Sichuan
- ] and self.is_erhuayin(pron) and '儿' in spoken_sentence.
- m_spoken_word_list[wordIndex].m_name):
- spoken_sentence.m_spoken_word_list[
- wordIndex].m_name = spoken_sentence.m_spoken_word_list[
- wordIndex].m_name.replace('儿', '')
- charCount -= 1
- if charCount == 1:
- spoken_sentence.m_spoken_word_list[
- wordIndex].m_syllable_list.extend(syllable_list)
- wordIndex += 1
- pronIndex += 1
- else:
- # FIXME(Jin): Just skip the first char then match the rest char.
- i = 1
- while i >= 1 and i < charCount:
- pronIndex += 1
- if pronIndex < len(wordProns):
- pron = wordProns[pronIndex].strip()
- succeed = self.format_syllable(
- pron, syllable_list)
- if not succeed:
- logging.error(
- 'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
- error pronunciation: %s, syllable format error',
- line_num,
- pron,
- )
- return False
- if (language in [
- Language.ZhCN,
- Language.PinYin,
- Language.Sichuan,
- ] and self.is_erhuayin(pron)
- and '儿' in spoken_sentence.
- m_spoken_word_list[wordIndex].m_name):
- spoken_sentence.m_spoken_word_list[
- wordIndex].m_name = spoken_sentence.m_spoken_word_list[
- wordIndex].m_name.replace('儿', '')
- charCount -= 1
- else:
- logging.error(
- 'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
- error pronunciation: %s, Word count mismatch with Pron count',
- line_num,
- pron,
- )
- return False
- i += 1
- spoken_sentence.m_spoken_word_list[
- wordIndex].m_syllable_list.extend(syllable_list)
- wordIndex += 1
- pronIndex += 1
- else:
- logging.error(
- 'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
- unsupported language: %s',
- line_num,
- language.name,
- )
- return False
-
- else:
- logging.error(
- 'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
- error pronunciation: %s, word index is out of range',
- line_num,
- pron,
- )
- return False
- if pronIndex != len(wordProns):
- logging.error(
- 'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
- error pronunciation: %s, pron count mismatch with word count',
- line_num,
- pron,
- )
- return False
-
- if wordIndex != len(spoken_sentence.m_spoken_word_list):
- logging.error(
- 'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
- error pronunciation: %s, word count mismatch with word index',
- line_num,
- pron,
- )
- return False
-
- return True
-
- def load_f2tmap(self, file_path):
- with open(file_path, 'r') as f:
- for line in f.readlines():
- line = line.strip()
- elements = line.split('\t')
- if len(elements) != 2:
- logging.error(
- 'TextScriptConvertor.LoadF2TMap: invalid line: %s',
- line)
- continue
- key = elements[0]
- value = elements[1]
- value_list = value.split(' ')
- if key in self.m_f2p_map:
- logging.error(
- 'TextScriptConvertor.LoadF2TMap: duplicate key: %s',
- key)
- self.m_f2p_map[key] = value_list
-
- def load_s2pmap(self, file_path):
- with open(file_path, 'r') as f:
- for line in f.readlines():
- line = line.strip()
- elements = line.split('\t')
- if len(elements) != 2:
- logging.error(
- 'TextScriptConvertor.LoadS2PMap: invalid line: %s',
- line)
- continue
- key = elements[0]
- value = elements[1]
- if key in self.m_s2p_map:
- logging.error(
- 'TextScriptConvertor.LoadS2PMap: duplicate key: %s',
- key)
- self.m_s2p_map[key] = value
-
- def init_syllable_formatter(self, targetLang):
- if targetLang == Language.ZhCN:
- if len(self.m_s2p_map) == 0:
- logging.error(
- 'TextScriptConvertor.InitSyllableFormatter: ZhCN syllable to phone map is empty'
- )
- return None
- return ZhCNSyllableFormatter(self.m_s2p_map)
- elif targetLang == Language.PinYin:
- if len(self.m_s2p_map) == 0:
- logging.error(
- 'TextScriptConvertor.InitSyllableFormatter: PinYin syllable to phone map is empty'
- )
- return None
- return PinYinSyllableFormatter(self.m_s2p_map)
- elif targetLang == Language.ZhHK:
- if len(self.m_s2p_map) == 0:
- logging.error(
- 'TextScriptConvertor.InitSyllableFormatter: ZhHK syllable to phone map is empty'
- )
- return None
- return ZhHKSyllableFormatter(self.m_s2p_map)
- elif targetLang == Language.WuuShanghai:
- if len(self.m_s2p_map) == 0:
- logging.error(
- 'TextScriptConvertor.InitSyllableFormatter: WuuShanghai syllable to phone map is empty'
- )
- return None
- return WuuShanghaiSyllableFormatter(self.m_s2p_map)
- elif targetLang == Language.Sichuan:
- if len(self.m_s2p_map) == 0:
- logging.error(
- 'TextScriptConvertor.InitSyllableFormatter: Sichuan syllable to phone map is empty'
- )
- return None
- return SichuanSyllableFormatter(self.m_s2p_map)
- elif targetLang == Language.EnGB:
- formatter = EnXXSyllableFormatter(Language.EnGB)
- if len(self.m_f2p_map) != 0:
- formatter.m_f2t_map = self.m_f2p_map
- return formatter
- elif targetLang == Language.EnUS:
- formatter = EnXXSyllableFormatter(Language.EnUS)
- if len(self.m_f2p_map) != 0:
- formatter.m_f2t_map = self.m_f2p_map
- return formatter
- else:
- logging.error(
- 'TextScriptConvertor.InitSyllableFormatter: unsupported language: %s',
- targetLang,
- )
- return None
-
- def process(self, textScriptPath, outputXMLPath, outputMetafile):
- script = Script(self.m_phoneset, self.m_posset)
- formatted_lines = format_prosody(textScriptPath)
- line_num = 0
- for line in tqdm(formatted_lines):
- if line_num % 2 == 0:
- sentence = line.strip()
- item = self.parse_sentence(sentence, line_num)
- else:
- if item is not None:
- pronunciation = line.strip()
- res = self.parse_pronunciation(item, pronunciation,
- line_num)
- if res:
- script.m_items.append(item)
-
- line_num += 1
-
- script.save(outputXMLPath)
- logging.info('TextScriptConvertor.process:\nSave script to: %s',
- outputXMLPath)
-
- meta_lines = script.save_meta_file()
- emo = 'emotion_neutral'
- speaker = self.m_speaker
-
- meta_lines_tagged = []
- for line in meta_lines:
- line_id, line_text = line.split('\t')
- syll_items = line_text.split(' ')
- syll_items_tagged = []
- for syll_item in syll_items:
- syll_item_tagged = syll_item[:-1] + '$' + emo + '$' + speaker + '}'
- syll_items_tagged.append(syll_item_tagged)
- meta_lines_tagged.append(line_id + '\t'
- + ' '.join(syll_items_tagged))
- with open(outputMetafile, 'w') as f:
- for line in meta_lines_tagged:
- f.write(line + '\n')
-
- logging.info('TextScriptConvertor.process:\nSave metafile to: %s',
- outputMetafile)
-
- @staticmethod
- def turn_text_into_bytes(plain_text_path, output_meta_file_path, speaker):
- meta_lines = []
- with open(plain_text_path, 'r') as in_file:
- for text_line in in_file:
- [sentence_id, sentence] = text_line.strip().split('\t')
- sequence = []
- for character in sentence:
- hex_string = character.encode('utf-8').hex()
- i = 0
- while i < len(hex_string):
- byte_hex = hex_string[i:i + 2]
- bit_array = BitArray(hex=byte_hex)
- integer = bit_array.uint
- if integer > 255:
- logging.error(
- 'TextScriptConvertor.turn_text_into_bytes: invalid byte conversion in sentence {} \
- character {}: (uint) {} - (hex) {}'.
- format(
- sentence_id,
- character,
- integer,
- character.encode('utf-8').hex(),
- ))
- continue
- sequence.append('{{{}$emotion_neutral${}}}'.format(
- integer, speaker))
- i += 2
- if sequence[-1][1:].split('$')[0] not in ['33', '46', '63']:
- sequence.append(
- '{{46$emotion_neutral${}}}'.format(speaker))
- meta_lines.append('{}\t{}\n'.format(sentence_id,
- ' '.join(sequence)))
- with open(output_meta_file_path, 'w') as out_file:
- out_file.writelines(meta_lines)
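
turn_text_into_bytes() above maps every character to its UTF-8 bytes and emits one '{byte$emotion_neutral$speaker}' symbol per byte, appending a full stop (byte 46) when the sentence does not already end in '!', '.', or '?'. The sketch below reproduces that output by iterating over the encoded bytes directly instead of going through BitArray; the speaker id 'F7' is illustrative.

def text_to_byte_tokens(sentence, speaker):
    tokens = []
    for character in sentence:
        for byte in character.encode('utf-8'):
            # One symbol per UTF-8 byte, tagged with emotion and speaker.
            tokens.append('{{{}$emotion_neutral${}}}'.format(byte, speaker))
    # Append a full stop (byte 46) unless the last byte is already ! . or ?
    if tokens and tokens[-1][1:].split('$')[0] not in ['33', '46', '63']:
        tokens.append('{{46$emotion_neutral${}}}'.format(speaker))
    return tokens

print(' '.join(text_to_byte_tokens('Hi.', 'F7')))
# {72$emotion_neutral$F7} {105$emotion_neutral$F7} {46$emotion_neutral$F7}
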
diff --git a/modelscope/models/audio/tts/kantts/train/loss.py b/modelscope/models/audio/tts/kantts/train/loss.py
deleted file mode 100644
index f56c56b0..00000000
--- a/modelscope/models/audio/tts/kantts/train/loss.py
+++ /dev/null
@@ -1,562 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import torch
-import torch.nn.functional as F
-
-from modelscope.models.audio.tts.kantts.models.utils import \
- get_mask_from_lengths
-from modelscope.models.audio.tts.kantts.utils.audio_torch import (
- MelSpectrogram, stft)
-
-
-class MelReconLoss(torch.nn.Module):
-
- def __init__(self, loss_type='mae'):
- super(MelReconLoss, self).__init__()
- self.loss_type = loss_type
- if loss_type == 'mae':
- self.criterion = torch.nn.L1Loss(reduction='none')
- elif loss_type == 'mse':
- self.criterion = torch.nn.MSELoss(reduction='none')
- else:
- raise ValueError('Unknown loss type: {}'.format(loss_type))
-
- def forward(self,
- output_lengths,
- mel_targets,
- dec_outputs,
- postnet_outputs=None):
- output_masks = get_mask_from_lengths(
- output_lengths, max_len=mel_targets.size(1))
- output_masks = ~output_masks
- valid_outputs = output_masks.sum()
-
- mel_loss_ = torch.sum(
- self.criterion(mel_targets, dec_outputs)
- * output_masks.unsqueeze(-1)) / (
- valid_outputs * mel_targets.size(-1))
-
- if postnet_outputs is not None:
- mel_loss = torch.sum(
- self.criterion(mel_targets, postnet_outputs)
- * output_masks.unsqueeze(-1)) / (
- valid_outputs * mel_targets.size(-1))
- else:
- mel_loss = 0.0
-
- return mel_loss_, mel_loss
-
-
-class ProsodyReconLoss(torch.nn.Module):
-
- def __init__(self, loss_type='mae'):
- super(ProsodyReconLoss, self).__init__()
- self.loss_type = loss_type
- if loss_type == 'mae':
- self.criterion = torch.nn.L1Loss(reduction='none')
- elif loss_type == 'mse':
- self.criterion = torch.nn.MSELoss(reduction='none')
- else:
- raise ValueError('Unknown loss type: {}'.format(loss_type))
-
- def forward(
- self,
- input_lengths,
- duration_targets,
- pitch_targets,
- energy_targets,
- log_duration_predictions,
- pitch_predictions,
- energy_predictions,
- ):
- input_masks = get_mask_from_lengths(
- input_lengths, max_len=duration_targets.size(1))
- input_masks = ~input_masks
- valid_inputs = input_masks.sum()
-
- dur_loss = (
- torch.sum(
- self.criterion(
- torch.log(duration_targets.float() + 1),
- log_duration_predictions) * input_masks) / valid_inputs)
- pitch_loss = (
- torch.sum(
- self.criterion(pitch_targets, pitch_predictions) * input_masks)
- / valid_inputs)
- energy_loss = (
- torch.sum(
- self.criterion(energy_targets, energy_predictions)
- * input_masks) / valid_inputs)
-
- return dur_loss, pitch_loss, energy_loss
-
-
-class FpCELoss(torch.nn.Module):
-
- def __init__(self, loss_type='ce', weight=[1, 4, 4, 8]):
- super(FpCELoss, self).__init__()
- self.loss_type = loss_type
- weight_ce = torch.FloatTensor(weight).cuda()
- self.criterion = torch.nn.CrossEntropyLoss(
- weight=weight_ce, reduction='none')
-
- def forward(self, input_lengths, fp_pd, fp_label):
- input_masks = get_mask_from_lengths(
- input_lengths, max_len=fp_label.size(1))
- input_masks = ~input_masks
- valid_inputs = input_masks.sum()
-
- fp_loss = (
- torch.sum(
- self.criterion(fp_pd.transpose(2, 1), fp_label) * input_masks)
- / valid_inputs)
-
- return fp_loss
-
-
-class GeneratorAdversarialLoss(torch.nn.Module):
- """Generator adversarial loss module."""
-
- def __init__(
- self,
- average_by_discriminators=True,
- loss_type='mse',
- ):
- """Initialize GeneratorAdversarialLoss module."""
- super().__init__()
- self.average_by_discriminators = average_by_discriminators
- assert loss_type in ['mse', 'hinge'], f'{loss_type} is not supported.'
- if loss_type == 'mse':
- self.criterion = self._mse_loss
- else:
- self.criterion = self._hinge_loss
-
- def forward(self, outputs):
- """Calculate generator adversarial loss.
-
- Args:
- outputs (Tensor or list): Discriminator outputs or list of
- discriminator outputs.
-
- Returns:
- Tensor: Generator adversarial loss value.
-
- """
- if isinstance(outputs, (tuple, list)):
- adv_loss = 0.0
- for i, outputs_ in enumerate(outputs):
- adv_loss += self.criterion(outputs_)
- if self.average_by_discriminators:
- adv_loss /= i + 1
- else:
- adv_loss = self.criterion(outputs)
-
- return adv_loss
-
- def _mse_loss(self, x):
- return F.mse_loss(x, x.new_ones(x.size()))
-
- def _hinge_loss(self, x):
- return -x.mean()
-
-
-class DiscriminatorAdversarialLoss(torch.nn.Module):
- """Discriminator adversarial loss module."""
-
- def __init__(
- self,
- average_by_discriminators=True,
- loss_type='mse',
- ):
- """Initialize DiscriminatorAdversarialLoss module."""
- super().__init__()
- self.average_by_discriminators = average_by_discriminators
- assert loss_type in ['mse', 'hinge'], f'{loss_type} is not supported.'
- if loss_type == 'mse':
- self.fake_criterion = self._mse_fake_loss
- self.real_criterion = self._mse_real_loss
- else:
- self.fake_criterion = self._hinge_fake_loss
- self.real_criterion = self._hinge_real_loss
-
- def forward(self, outputs_hat, outputs):
- """Calculate discriminator adversarial loss.
-
- Args:
- outputs_hat (Tensor or list): Discriminator outputs or list of
- discriminator outputs calculated from generator outputs.
- outputs (Tensor or list): Discriminator outputs or list of
- discriminator outputs calculated from groundtruth.
-
- Returns:
- Tensor: Discriminator real loss value.
- Tensor: Discriminator fake loss value.
-
- """
- if isinstance(outputs, (tuple, list)):
- real_loss = 0.0
- fake_loss = 0.0
- for i, (outputs_hat_,
- outputs_) in enumerate(zip(outputs_hat, outputs)):
- if isinstance(outputs_hat_, (tuple, list)):
- # NOTE(kan-bayashi): case including feature maps
- outputs_hat_ = outputs_hat_[-1]
- outputs_ = outputs_[-1]
- real_loss += self.real_criterion(outputs_)
- fake_loss += self.fake_criterion(outputs_hat_)
- if self.average_by_discriminators:
- fake_loss /= i + 1
- real_loss /= i + 1
- else:
- real_loss = self.real_criterion(outputs)
- fake_loss = self.fake_criterion(outputs_hat)
-
- return real_loss, fake_loss
-
- def _mse_real_loss(self, x):
- return F.mse_loss(x, x.new_ones(x.size()))
-
- def _mse_fake_loss(self, x):
- return F.mse_loss(x, x.new_zeros(x.size()))
-
- def _hinge_real_loss(self, x):
- return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
-
- def _hinge_fake_loss(self, x):
- return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
-
-
-class FeatureMatchLoss(torch.nn.Module):
- """Feature matching loss module."""
-
- def __init__(
- self,
- average_by_layers=True,
- average_by_discriminators=True,
- ):
- """Initialize FeatureMatchLoss module."""
- super().__init__()
- self.average_by_layers = average_by_layers
- self.average_by_discriminators = average_by_discriminators
-
- def forward(self, feats_hat, feats):
- """Calculate feature matching loss.
-
- Args:
- feats_hat (list): List of list of discriminator outputs
- calculated from generator outputs.
- feats (list): List of list of discriminator outputs
- calculated from groundtruth.
-
- Returns:
- Tensor: Feature matching loss value.
-
- """
- feat_match_loss = 0.0
- for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
- feat_match_loss_ = 0.0
- for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
- feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
- if self.average_by_layers:
- feat_match_loss_ /= j + 1
- feat_match_loss += feat_match_loss_
- if self.average_by_discriminators:
- feat_match_loss /= i + 1
-
- return feat_match_loss
-
-
-class MelSpectrogramLoss(torch.nn.Module):
- """Mel-spectrogram loss."""
-
- def __init__(
- self,
- fs=22050,
- fft_size=1024,
- hop_size=256,
- win_length=None,
- window='hann',
- num_mels=80,
- fmin=80,
- fmax=7600,
- center=True,
- normalized=False,
- onesided=True,
- eps=1e-10,
- log_base=10.0,
- ):
- """Initialize Mel-spectrogram loss."""
- super().__init__()
- self.mel_spectrogram = MelSpectrogram(
- fs=fs,
- fft_size=fft_size,
- hop_size=hop_size,
- win_length=win_length,
- window=window,
- num_mels=num_mels,
- fmin=fmin,
- fmax=fmax,
- center=center,
- normalized=normalized,
- onesided=onesided,
- eps=eps,
- log_base=log_base,
- )
-
- def forward(self, y_hat, y):
- """Calculate Mel-spectrogram loss.
-
- Args:
- y_hat (Tensor): Generated single tensor (B, 1, T).
- y (Tensor): Groundtruth single tensor (B, 1, T).
-
- Returns:
- Tensor: Mel-spectrogram loss value.
-
- """
- mel_hat = self.mel_spectrogram(y_hat)
- mel = self.mel_spectrogram(y)
- mel_loss = F.l1_loss(mel_hat, mel)
-
- return mel_loss
-
-
-class SpectralConvergenceLoss(torch.nn.Module):
- """Spectral convergence loss module."""
-
- def __init__(self):
- """Initialize spectral convergence loss module."""
- super(SpectralConvergenceLoss, self).__init__()
-
- def forward(self, x_mag, y_mag):
- """Calculate forward propagation.
-
- Args:
- x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
- y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
-
- Returns:
- Tensor: Spectral convergence loss value.
-
- """
- return torch.norm(y_mag - x_mag, p='fro') / torch.norm(y_mag, p='fro')
-
-
-class LogSTFTMagnitudeLoss(torch.nn.Module):
- """Log STFT magnitude loss module."""
-
- def __init__(self):
- """Initialize log STFT magnitude loss module."""
- super(LogSTFTMagnitudeLoss, self).__init__()
-
- def forward(self, x_mag, y_mag):
- """Calculate forward propagation.
-
- Args:
- x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
- y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
-
- Returns:
- Tensor: Log STFT magnitude loss value.
-
- """
- return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
-
-
-class STFTLoss(torch.nn.Module):
- """STFT loss module."""
-
- def __init__(self,
- fft_size=1024,
- shift_size=120,
- win_length=600,
- window='hann_window'):
- """Initialize STFT loss module."""
- super(STFTLoss, self).__init__()
- self.fft_size = fft_size
- self.shift_size = shift_size
- self.win_length = win_length
- self.spectral_convergence_loss = SpectralConvergenceLoss()
- self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
- # NOTE(kan-bayashi): Use register_buffer to fix #223
- self.register_buffer('window', getattr(torch, window)(win_length))
-
- def forward(self, x, y):
- """Calculate forward propagation.
-
- Args:
- x (Tensor): Predicted signal (B, T).
- y (Tensor): Groundtruth signal (B, T).
-
- Returns:
- Tensor: Spectral convergence loss value.
- Tensor: Log STFT magnitude loss value.
-
- """
- x_mag = stft(x, self.fft_size, self.shift_size, self.win_length,
- self.window)
- y_mag = stft(y, self.fft_size, self.shift_size, self.win_length,
- self.window)
- sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
- mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
-
- return sc_loss, mag_loss
-
-
-class MultiResolutionSTFTLoss(torch.nn.Module):
- """Multi resolution STFT loss module."""
-
- def __init__(
- self,
- fft_sizes=[1024, 2048, 512],
- hop_sizes=[120, 240, 50],
- win_lengths=[600, 1200, 240],
- window='hann_window',
- ):
- """Initialize Multi resolution STFT loss module.
-
- Args:
- fft_sizes (list): List of FFT sizes.
- hop_sizes (list): List of hop sizes.
- win_lengths (list): List of window lengths.
- window (str): Window function type.
-
- """
- super(MultiResolutionSTFTLoss, self).__init__()
- assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
- self.stft_losses = torch.nn.ModuleList()
- for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
- self.stft_losses += [STFTLoss(fs, ss, wl, window)]
-
- def forward(self, x, y):
- """Calculate forward propagation.
-
- Args:
- x (Tensor): Predicted signal (B, T) or (B, #subband, T).
- y (Tensor): Groundtruth signal (B, T) or (B, #subband, T).
-
- Returns:
- Tensor: Multi resolution spectral convergence loss value.
- Tensor: Multi resolution log STFT magnitude loss value.
-
- """
- if len(x.shape) == 3:
- x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T)
- y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T)
- sc_loss = 0.0
- mag_loss = 0.0
- for f in self.stft_losses:
- sc_l, mag_l = f(x, y)
- sc_loss += sc_l
- mag_loss += mag_l
- sc_loss /= len(self.stft_losses)
- mag_loss /= len(self.stft_losses)
-
- return sc_loss, mag_loss
-
-
-class SeqCELoss(torch.nn.Module):
-
- def __init__(self, loss_type='ce'):
- super(SeqCELoss, self).__init__()
- self.loss_type = loss_type
- self.criterion = torch.nn.CrossEntropyLoss(reduction='none')
-
- def forward(self, logits, targets, masks):
- loss = self.criterion(logits.contiguous().view(-1, logits.size(-1)),
- targets.contiguous().view(-1))
- preds = torch.argmax(logits, dim=-1).contiguous().view(-1)
- masks = masks.contiguous().view(-1)
-
- loss = (loss * masks).sum() / masks.sum()
- err = torch.sum((preds != targets.view(-1)) * masks) / masks.sum()
-
- return loss, err
-
-
-class AttentionBinarizationLoss(torch.nn.Module):
-
- def __init__(self, start_epoch=0, warmup_epoch=100):
- super(AttentionBinarizationLoss, self).__init__()
- self.start_epoch = start_epoch
- self.warmup_epoch = warmup_epoch
-
- def forward(self, epoch, hard_attention, soft_attention, eps=1e-12):
- log_sum = torch.log(
- torch.clamp(soft_attention[hard_attention == 1], min=eps)).sum()
- kl_loss = -log_sum / hard_attention.sum()
- if epoch < self.start_epoch:
- warmup_ratio = 0
- else:
- warmup_ratio = min(1.0,
- (epoch - self.start_epoch) / self.warmup_epoch)
- return kl_loss * warmup_ratio
-
-
-class AttentionCTCLoss(torch.nn.Module):
-
- def __init__(self, blank_logprob=-1):
- super(AttentionCTCLoss, self).__init__()
- self.log_softmax = torch.nn.LogSoftmax(dim=3)
- self.blank_logprob = blank_logprob
- self.CTCLoss = torch.nn.CTCLoss(zero_infinity=True)
-
- def forward(self, attn_logprob, in_lens, out_lens):
- key_lens = in_lens
- query_lens = out_lens
- attn_logprob_padded = F.pad(
- input=attn_logprob,
- pad=(1, 0, 0, 0, 0, 0, 0, 0),
- value=self.blank_logprob)
- cost_total = 0.0
- for bid in range(attn_logprob.shape[0]):
- target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
- curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)
- curr_logprob = curr_logprob[:query_lens[bid], :, :key_lens[bid]
- + 1]
- curr_logprob = self.log_softmax(curr_logprob[None])[0]
- ctc_cost = self.CTCLoss(
- curr_logprob,
- target_seq,
- input_lengths=query_lens[bid:bid + 1],
- target_lengths=key_lens[bid:bid + 1],
- )
- cost_total += ctc_cost
- cost = cost_total / attn_logprob.shape[0]
- return cost
-
-
-loss_dict = {
- 'generator_adv_loss': GeneratorAdversarialLoss,
- 'discriminator_adv_loss': DiscriminatorAdversarialLoss,
- 'stft_loss': MultiResolutionSTFTLoss,
- 'mel_loss': MelSpectrogramLoss,
- 'subband_stft_loss': MultiResolutionSTFTLoss,
- 'feat_match_loss': FeatureMatchLoss,
- 'MelReconLoss': MelReconLoss,
- 'ProsodyReconLoss': ProsodyReconLoss,
- 'SeqCELoss': SeqCELoss,
- 'AttentionBinarizationLoss': AttentionBinarizationLoss,
- 'AttentionCTCLoss': AttentionCTCLoss,
- 'FpCELoss': FpCELoss,
-}
-
-
-def criterion_builder(config, device='cpu'):
- """Criterion builder.
- Args:
- config (dict): Config dictionary.
- Returns:
- criterion (dict): Loss dictionary
- """
- criterion = {}
- for key, value in config['Loss'].items():
- if key in loss_dict:
- if value['enable']:
- criterion[key] = loss_dict[key](
- **value.get('params', {})).to(device)
- setattr(criterion[key], 'weights', value.get('weights', 1.0))
- else:
- raise NotImplementedError('{} is not implemented'.format(key))
-
- return criterion
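
criterion_builder() above expects a config with a 'Loss' section in which each entry can be enabled, carry constructor params, and define a scalar weight that is attached to the built criterion. The fragment below only sketches that shape; the keys and weight values are illustrative and not taken from a real KAN-TTS config.

toy_config = {
    'Loss': {
        'MelReconLoss': {'enable': True, 'params': {'loss_type': 'mae'}, 'weights': 1.0},
        'mel_loss': {'enable': True, 'params': {'fft_size': 1024}, 'weights': 45.0},
        'feat_match_loss': {'enable': False},
    }
}

# Same selection logic as the builder loop: only enabled losses are instantiated.
for name, cfg in toy_config['Loss'].items():
    if cfg['enable']:
        print(name, cfg.get('params', {}), 'weight =', cfg.get('weights', 1.0))
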
diff --git a/modelscope/models/audio/tts/kantts/train/scheduler.py b/modelscope/models/audio/tts/kantts/train/scheduler.py
deleted file mode 100644
index 5fcfeb11..00000000
--- a/modelscope/models/audio/tts/kantts/train/scheduler.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler
-
-
-class FindLR(_LRScheduler):
- """
- inspired by fast.ai @https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
- """
-
- def __init__(self, optimizer, max_steps, max_lr=10):
- self.max_steps = max_steps
- self.max_lr = max_lr
- super().__init__(optimizer)
-
- def get_lr(self):
- return [
- base_lr * ((self.max_lr / base_lr)**(
- self.last_epoch / # noqa W504
- (self.max_steps - 1))) for base_lr in self.base_lrs
- ]
-
-
-class NoamLR(_LRScheduler):
- """
- Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate
- linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally
- to the inverse square root of the step number, scaled by the inverse square root of the
- dimensionality of the model. Time will tell if this is just madness or it's actually important.
- Parameters
- ----------
- warmup_steps: ``int``, required.
- The number of steps to linearly increase the learning rate.
- """
-
- def __init__(self, optimizer, warmup_steps):
- self.warmup_steps = warmup_steps
- super().__init__(optimizer)
-
- def get_lr(self):
- last_epoch = max(1, self.last_epoch)
- scale = self.warmup_steps**0.5 * min(
- last_epoch**(-0.5), last_epoch * self.warmup_steps**(-1.5))
- return [base_lr * scale for base_lr in self.base_lrs]
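
As a quick sanity check of the NoamLR schedule above, the scale warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5) grows linearly until warmup_steps and then decays as 1/sqrt(step). A torch-free calculation with illustrative values:

warmup_steps, base_lr = 4000, 1e-3
for step in (1, 1000, 4000, 16000):
    scale = warmup_steps ** 0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    print(step, base_lr * scale)   # peaks at base_lr when step == warmup_steps
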
diff --git a/modelscope/models/audio/tts/kantts/train/trainer.py b/modelscope/models/audio/tts/kantts/train/trainer.py
deleted file mode 100644
index 628b0503..00000000
--- a/modelscope/models/audio/tts/kantts/train/trainer.py
+++ /dev/null
@@ -1,1201 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import os
-import sys
-from collections import defaultdict
-
-import numpy as np
-import soundfile as sf
-import torch
-from tensorboardX import SummaryWriter
-from tqdm import tqdm
-
-from modelscope.models.audio.tts.kantts.utils.plot import (plot_alignment,
- plot_spectrogram)
-from modelscope.utils.logger import get_logger
-
-logging = get_logger()
-
-
-def traversal_dict(d, func):
- if not isinstance(d, dict):
- logging.error('Not a dict: {}'.format(d))
- return
- for k, v in d.items():
- if isinstance(v, dict):
- traversal_dict(v, func)
- else:
- func(k, v)
-
-
-def distributed_init():
- world_size = int(os.environ.get('WORLD_SIZE', 1))
- local_rank = int(os.environ.get('RANK', 0))
- distributed = world_size > 1
- device = torch.device('cuda', local_rank)
- if distributed:
- torch.distributed.init_process_group(
- backend='nccl', init_method='env://')
- logging.info(
- 'Distributed training, global world size: {}, local world size: {}, global rank: {}, local rank: {}'
- .format(
- world_size,
- torch.cuda.device_count(),
- torch.distributed.get_rank(),
- local_rank,
- ))
- logging.info('nccl backend: {}'.format(
- torch.distributed.is_nccl_available()))
- logging.info('mpi backend: {}'.format(
- torch.distributed.is_mpi_available()))
- device_ids = list(range(torch.cuda.device_count()))
- logging.info(
- '[{}] rank = {}, world_size = {}, n_gpus = {}, device_ids = {}'.
- format(
- os.getpid(),
- torch.distributed.get_rank(),
- torch.distributed.get_world_size(),
- torch.cuda.device_count(),
- device_ids,
- ))
- return distributed, device, local_rank, world_size
-
-
-class Trainer(object):
-
- def __init__(
- self,
- config,
- model,
- optimizer,
- scheduler,
- criterion,
- device,
- sampler,
- train_loader,
- valid_loader,
- max_epochs=None,
- max_steps=None,
- save_dir=None,
- save_interval=1,
- valid_interval=1,
- log_interval=10,
- grad_clip=None,
- ):
- self.model = model
- self.optimizer = optimizer
- self.scheduler = scheduler
- self.criterion = criterion
- self.device = device
- self.sampler = sampler
- self.train_loader = train_loader
- self.valid_loader = valid_loader
- self.max_epochs = max_epochs
- self.steps = 1
- self.epoch = 0
- self.save_dir = save_dir
- self.save_interval = save_interval
- self.valid_interval = valid_interval
- self.log_interval = log_interval
- self.grad_clip = grad_clip
- self.total_train_loss = defaultdict(float)
- self.total_eval_loss = defaultdict(float)
- self.config = config
- self.distributed = self.config.get('distributed', False)
- self.rank = self.config.get('rank', 0)
-
- self.log_dir = os.path.join(save_dir, 'log')
- self.ckpt_dir = os.path.join(save_dir, 'ckpt')
- os.makedirs(self.log_dir, exist_ok=True)
- os.makedirs(self.ckpt_dir, exist_ok=True)
-
- self.writer = SummaryWriter(self.log_dir)
-
- if max_epochs is None:
- self.max_epochs = sys.maxsize
- else:
- self.max_epochs = int(max_epochs)
- if max_steps is None:
- self.max_steps = sys.maxsize
- else:
- self.max_steps = int(max_steps)
-
- self.finish_training = False
-
- def set_model_state(self, state='train'):
- if state == 'train':
- if isinstance(self.model, dict):
- for key in self.model.keys():
- self.model[key].train()
- else:
- self.model.train()
- elif state == 'eval':
- if isinstance(self.model, dict):
- for key in self.model.keys():
- self.model[key].eval()
- else:
- self.model.eval()
- else:
- raise ValueError("state must be either 'train' or 'eval'.")
-
- def write_to_tensorboard(self, loss):
- """Write to tensorboard."""
- for key, value in loss.items():
- self.writer.add_scalar(key, value, self.steps)
-
- def save_checkpoint(self, checkpoint_path):
- state_dict = {
- 'optimizer': self.optimizer.state_dict(),
- 'scheduler': self.scheduler.state_dict(),
- 'steps': self.steps,
- 'model': self.model.state_dict(),
- }
-
- os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
- torch.save(state_dict, checkpoint_path)
-
- def load_checkpoint(self,
- checkpoint_path,
- restore_training_state=False,
- strict=True):
- state_dict = torch.load(checkpoint_path)
- self.model.load_state_dict(state_dict['model'], strict=strict)
- if restore_training_state:
- self.optimizer.load_state_dict(state_dict['optimizer'])
- self.scheduler.load_state_dict(state_dict['scheduler'])
- self.steps = state_dict['steps']
-
- def check_save_interval(self):
- if self.ckpt_dir is not None and (
- self.steps) % self.save_interval == 0:
- self.save_checkpoint(
- os.path.join(self.ckpt_dir,
- 'checkpoint_{}.pth'.format(self.steps)))
- logging.info('Checkpoint saved at step {}'.format(self.steps))
-
- def check_log_interval(self):
- if self.writer is not None and (self.steps) % self.log_interval == 0:
- for key in self.total_train_loss.keys():
- self.total_train_loss[key] /= self.config['log_interval_steps']
- logging.info(
- f'(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}.'
- )
- self.write_to_tensorboard(self.total_train_loss)
- self.total_train_loss = defaultdict(float)
-
- def log_learning_rate(key, sche):
- logging.info('{} learning rate: {:.6f}'.format(
- key,
- sche.get_lr()[0]))
- self.write_to_tensorboard(
- {'{}_lr'.format(key): sche.get_lr()[0]})
-
- traversal_dict(self.scheduler, log_learning_rate)
-
- def check_eval_interval(self):
- if self.valid_interval > 0 and (self.steps) % self.valid_interval == 0:
- self.eval_epoch()
-
- def check_stop_training(self):
- if self.steps >= self.max_steps or self.epoch >= self.max_epochs:
- self.finish_training = True
-
- def train(self):
- self.set_model_state('train')
-
- while True:
- self.train_epoch()
- self.epoch += 1
- self.check_stop_training()
- if self.finish_training:
- break
-
- def train_epoch(self):
- for batch in tqdm(self.train_loader):
- self.train_step(batch)
-
- if self.rank == 0:
- self.check_eval_interval()
- self.check_save_interval()
- self.check_log_interval()
-
- self.steps += 1
- self.check_stop_training()
- if self.finish_training:
- break
-
- logging.info('Epoch {} finished'.format(self.epoch))
-
- if self.distributed:
- self.sampler['train'].set_epoch(self.epoch)
-
- def train_step(self, batch):
- data, target = batch
- data, target = data.to(self.device), target.to(self.device)
- self.optimizer.zero_grad()
- output = self.model(data)
- loss = self.criterion(output, target)
- loss.backward()
- if self.grad_clip is not None:
- torch.nn.utils.clip_grad_norm_(self.model.parameters(),
- self.grad_clip)
- self.optimizer.step()
-
- @torch.no_grad()
- def eval_step(self, batch):
- pass
-
- def eval_epoch(self):
- logging.info(f'(Epoch: {self.epoch}) Start evaluation.')
-        # switch model to eval mode
- self.set_model_state('eval')
-
- self.total_eval_loss = defaultdict(float)
- rand_idx = np.random.randint(0, len(self.valid_loader))
- idx = 0
- logging.info('Valid data size: {}'.format(len(self.valid_loader)))
- for batch in tqdm(self.valid_loader):
- self.eval_step(batch)
- if idx == rand_idx:
- logging.info(
- f'(Epoch: {self.epoch}) Random batch: {idx}, generating image.'
- )
- self.genearete_and_save_intermediate_result(batch)
- idx += 1
-
- for key in self.total_eval_loss.keys():
- self.total_eval_loss[key] /= idx + 1
- logging.info(
- f'(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}.'
- )
- self.write_to_tensorboard(self.total_eval_loss)
-
- logging.info('Epoch {} evaluation finished'.format(self.epoch))
-
- self.set_model_state('train')
-
- @torch.no_grad()
- def genearete_and_save_intermediate_result(self, batch):
- pass
-
-
-class GAN_Trainer(Trainer):
-
- def __init__(
- self,
- config,
- model,
- optimizer,
- scheduler,
- criterion,
- device,
- sampler,
- train_loader,
- valid_loader,
- max_epochs=None,
- max_steps=None,
- save_dir=None,
- save_interval=1,
- valid_interval=1,
- log_interval=10,
- grad_clip=None,
- ):
- super().__init__(
- config,
- model,
- optimizer,
- scheduler,
- criterion,
- device,
- sampler,
- train_loader,
- valid_loader,
- max_epochs,
- max_steps,
- save_dir,
- save_interval,
- valid_interval,
- log_interval,
- grad_clip,
- )
-
- def set_model_state(self, state='train'):
- if state == 'train':
- if isinstance(self.model, dict):
- self.model['generator'].train()
- for key in self.model['discriminator'].keys():
- self.model['discriminator'][key].train()
- else:
- self.model.train()
- elif state == 'eval':
- if isinstance(self.model, dict):
- self.model['generator'].eval()
- for key in self.model['discriminator'].keys():
- self.model['discriminator'][key].eval()
- else:
- self.model.eval()
- else:
- raise ValueError("state must be either 'train' or 'eval'.")
-
- @torch.no_grad()
- def genearete_and_save_intermediate_result(self, batch):
- """Generate and save intermediate result."""
-        # delayed import to avoid errors related to the matplotlib backend
- import matplotlib.pyplot as plt
-
- # generate
- y_batch, x_batch = batch
- y_batch, x_batch = y_batch.to(self.device), x_batch.to(self.device)
- y_batch_ = self.model['generator'](x_batch)
- if self.model.get('pqmf', None):
- y_mb_ = y_batch_
- y_batch_ = self.model['pqmf'].synthesis(y_mb_)
-
- # check directory
- dirname = os.path.join(self.log_dir, f'predictions/{self.steps}steps')
- if not os.path.exists(dirname):
- os.makedirs(dirname)
-
- for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 1):
- # convert to ndarray
- y, y_ = y.view(-1).cpu().numpy(), y_.view(-1).cpu().numpy()
-
- # plot figure and save it
- figname = os.path.join(dirname, f'{idx}.png')
- plt.subplot(2, 1, 1)
- plt.plot(y)
- plt.title('groundtruth speech')
- plt.subplot(2, 1, 2)
- plt.plot(y_)
- plt.title(f'generated speech @ {self.steps} steps')
- plt.tight_layout()
- plt.savefig(figname)
- plt.close()
-
- # save as wavfile
- y = np.clip(y, -1, 1)
- y_ = np.clip(y_, -1, 1)
- sf.write(
- figname.replace('.png', '_ref.wav'),
- y,
- self.config['audio_config']['sampling_rate'],
- 'PCM_16',
- )
- sf.write(
- figname.replace('.png', '_gen.wav'),
- y_,
- self.config['audio_config']['sampling_rate'],
- 'PCM_16',
- )
-
- if idx >= self.config['num_save_intermediate_results']:
- break
-
- @torch.no_grad()
- def eval_step(self, batch):
- y, x = batch
- y, x = y.to(self.device), x.to(self.device)
-
- y_ = self.model['generator'](x)
-        # reconstruct the full-band signal from the multi-band signal
- if self.model.get('pqmf', None):
- y_mb_ = y_
- y_ = self.model['pqmf'].synthesis(y_mb_)
-
- aux_loss = 0.0
-
-        # multi-resolution stft loss
- if self.criterion.get('stft_loss', None):
- sc_loss, mag_loss = self.criterion['stft_loss'](y_, y)
- aux_loss += (sc_loss
- + mag_loss) * self.criterion['stft_loss'].weights
- self.total_eval_loss[
- 'eval/spectral_convergence_loss'] += sc_loss.item()
-
- # subband multi-resolution stft loss
- if self.criterion.get('subband_stft_loss', None):
- aux_loss *= 0.5 # for balancing with subband stft loss
- y_mb = self.model['pqmf'].analysis(y)
- sub_sc_loss, sub_mag_loss = self.criterion['sub_stft'](y_mb_, y_mb)
- self.total_eval_loss[
- 'eval/sub_spectral_convergence_loss'] += sub_sc_loss.item()
- self.total_eval_loss[
- 'eval/sub_log_stft_magnitude_loss'] += sub_mag_loss.item()
- aux_loss += (0.5 * (sub_sc_loss + sub_mag_loss)
- * self.criterion['sub_stft'].weights)
-
- # mel spectrogram loss
- if self.criterion.get('mel_loss', None):
- mel_loss = self.criterion['mel_loss'](y_, y)
- aux_loss += mel_loss * self.criterion['mel_loss'].weights
- self.total_eval_loss['eval/mel_loss'] += mel_loss.item()
-
- fmap_lst_ = []
- adv_loss = 0.0
-        # adversarial loss
- for discriminator in self.model['discriminator'].keys():
- p_, fmap_ = self.model['discriminator'][discriminator](y_)
- fmap_lst_.append(fmap_)
- adv_loss += (
- self.criterion['generator_adv_loss'](p_)
- * self.criterion['generator_adv_loss'].weights)
-
- gen_loss = aux_loss + adv_loss
-
- if self.criterion.get('feat_match_loss', None):
- fmap_lst = []
- # no need to track gradients
- for discriminator in self.model['discriminator'].keys():
- with torch.no_grad():
- p, fmap = self.model['discriminator'][discriminator](y)
- fmap_lst.append(fmap)
-
- fm_loss = 0.0
- for fmap_, fmap in zip(fmap_lst, fmap_lst_):
- fm_loss += self.criterion['feat_match_loss'](fmap_, fmap)
- self.total_eval_loss['eval/feature_matching_loss'] += fm_loss.item(
- )
-
- gen_loss += fm_loss * self.criterion['feat_match_loss'].weights
-
- dis_loss = 0.0
- for discriminator in self.model['discriminator'].keys():
- p, fmap = self.model['discriminator'][discriminator](y)
- p_, fmap_ = self.model['discriminator'][discriminator](y_.detach())
- real_loss, fake_loss = self.criterion['discriminator_adv_loss'](p_,
- p)
- dis_loss += real_loss + fake_loss
- self.total_eval_loss['eval/real_loss'] += real_loss.item()
- self.total_eval_loss['eval/fake_loss'] += fake_loss.item()
-
- self.total_eval_loss['eval/discriminator_loss'] += dis_loss.item()
- self.total_eval_loss['eval/adversarial_loss'] += adv_loss.item()
- self.total_eval_loss['eval/generator_loss'] += gen_loss.item()
-
- def train_step(self, batch):
- y, x = batch
- y, x = y.to(self.device), x.to(self.device)
-
- if self.steps >= self.config.get('generator_train_start_steps', 0):
- y_ = self.model['generator'](x)
-            # reconstruct the full-band signal from the multi-band signal
- if self.model.get('pqmf', None):
- y_mb_ = y_
- y_ = self.model['pqmf'].synthesis(y_mb_)
-
-            # initialize generator loss
- gen_loss = 0.0
-
-            # multi-resolution stft loss
- if self.criterion.get('stft_loss', None):
- sc_loss, mag_loss = self.criterion['stft_loss'](y_, y)
- gen_loss += (sc_loss
- + mag_loss) * self.criterion['stft_loss'].weights
- self.total_train_loss[
- 'train/spectral_convergence_loss'] += sc_loss.item()
- self.total_train_loss[
- 'train/log_stft_magnitude_loss'] += mag_loss.item()
-
- # subband multi-resolution stft loss
- if self.criterion.get('subband_stft_loss', None):
- gen_loss *= 0.5 # for balancing with subband stft loss
- y_mb = self.model['pqmf'].analysis(y)
- sub_sc_loss, sub_mag_loss = self.criterion['sub_stft'](y_mb_,
- y_mb)
- gen_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
- self.total_train_loss[
- 'train/sub_spectral_convergence_loss'] += sub_sc_loss.item(
- ) # noqa E123
- self.total_train_loss[
- 'train/sub_log_stft_magnitude_loss'] += sub_mag_loss.item(
- ) # noqa E123
-
- # mel spectrogram loss
- if self.criterion.get('mel_loss', None):
- mel_loss = self.criterion['mel_loss'](y_, y)
- gen_loss += mel_loss * self.criterion['mel_loss'].weights
- self.total_train_loss['train/mel_loss'] += mel_loss.item()
-
- # adversarial loss
- if self.steps > self.config['discriminator_train_start_steps']:
- adv_loss = 0.0
- fmap_lst_ = []
- for discriminator in self.model['discriminator'].keys():
- p_, fmap_ = self.model['discriminator'][discriminator](y_)
- fmap_lst_.append(fmap_)
- adv_loss += self.criterion['generator_adv_loss'](p_)
- self.total_train_loss[
- 'train/adversarial_loss'] += adv_loss.item()
-
- gen_loss += adv_loss * self.criterion[
- 'generator_adv_loss'].weights
-
- # feature matching loss
- if self.criterion.get('feat_match_loss', None):
- fmap_lst = []
- # no need to track gradients
- for discriminator in self.model['discriminator'].keys():
- with torch.no_grad():
- p, fmap = self.model['discriminator'][
- discriminator](
- y)
- fmap_lst.append(fmap)
-
- fm_loss = 0.0
- for fmap_, fmap in zip(fmap_lst, fmap_lst_):
- fm_loss += self.criterion['feat_match_loss'](fmap_,
- fmap)
- self.total_train_loss[
- 'train/feature_matching_loss'] += fm_loss.item()
- gen_loss += fm_loss * self.criterion[
- 'feat_match_loss'].weights
-
- self.total_train_loss['train/generator_loss'] += gen_loss.item()
- # update generator
- self.optimizer['generator'].zero_grad()
- gen_loss.backward()
- if self.config['generator_grad_norm'] > 0:
- torch.nn.utils.clip_grad_norm_(
- self.model['generator'].parameters(),
- self.config['generator_grad_norm'],
- )
- self.optimizer['generator'].step()
- self.scheduler['generator'].step()
-
- # update discriminator
- if self.steps > self.config['discriminator_train_start_steps']:
-            # re-compute y_, which leads to better quality
- with torch.no_grad():
- y_ = self.model['generator'](x)
-
- if self.model.get('pqmf', None):
- y_ = self.model['pqmf'].synthesis(y_)
-
- # discriminator loss
- dis_loss = 0.0
- for discriminator in self.model['discriminator'].keys():
- p, fmap = self.model['discriminator'][discriminator](y)
- p_, fmap_ = self.model['discriminator'][discriminator](
- y_.detach())
- real_loss, fake_loss = self.criterion[
- 'discriminator_adv_loss'](p_, p)
- dis_loss += real_loss + fake_loss
- self.total_train_loss['train/real_loss'] += real_loss.item()
- self.total_train_loss['train/fake_loss'] += fake_loss.item()
-
- self.total_train_loss['train/discriminator_loss'] += dis_loss.item(
- )
-
- # update discriminator
- for key in self.optimizer['discriminator'].keys():
- self.optimizer['discriminator'][key].zero_grad()
-
- dis_loss.backward()
- if self.config['discriminator_grad_norm'] > 0:
- torch.nn.utils.clip_grad_norm_(
- self.model['discriminator'].parameters(),
- self.config['discriminator_grad_norm'],
- )
- for key in self.optimizer['discriminator'].keys():
- self.optimizer['discriminator'][key].step()
- for key in self.scheduler['discriminator'].keys():
- self.scheduler['discriminator'][key].step()
-
- def save_checkpoint(self, checkpoint_path):
- state_dict = {
- 'optimizer': {
- 'generator': self.optimizer['generator'].state_dict(),
- 'discriminator': {},
- },
- 'scheduler': {
- 'generator': self.scheduler['generator'].state_dict(),
- 'discriminator': {},
- },
- 'steps': self.steps,
- }
- for model_name in self.optimizer['discriminator'].keys():
- state_dict['optimizer']['discriminator'][
- model_name] = self.optimizer['discriminator'][
- model_name].state_dict()
-
- for model_name in self.scheduler['discriminator'].keys():
- state_dict['scheduler']['discriminator'][
- model_name] = self.scheduler['discriminator'][
- model_name].state_dict()
-
- if not self.distributed:
- model_state = self.model['generator'].state_dict()
- else:
- model_state = self.model['generator'].module.state_dict()
- state_dict['model'] = {
- 'generator': model_state,
- 'discriminator': {},
- }
- for model_name in self.model['discriminator'].keys():
- if not self.distributed:
- model_state = self.model['discriminator'][
- model_name].state_dict()
- else:
- model_state = self.model['discriminator'][
- model_name].module.state_dict()
- state_dict['model']['discriminator'][model_name] = model_state
-
- if not os.path.exists(os.path.dirname(checkpoint_path)):
- os.makedirs(os.path.dirname(checkpoint_path))
- torch.save(state_dict, checkpoint_path)
-
- def load_checkpoint(self,
- checkpoint_path,
- restore_training_state=False,
- strict=True):
- state_dict = torch.load(checkpoint_path, map_location='cpu')
- if not self.distributed:
- self.model['generator'].load_state_dict(
- state_dict['model']['generator'], strict=strict)
- else:
- self.model['generator'].module.load_state_dict(
- state_dict['model']['generator'], strict=strict)
- for model_name in state_dict['model']['discriminator']:
- if not self.distributed:
- self.model['discriminator'][model_name].load_state_dict(
- state_dict['model']['discriminator'][model_name],
- strict=strict)
- else:
- self.model['discriminator'][model_name].module.load_state_dict(
- state_dict['model']['discriminator'][model_name],
- strict=strict)
-
- if restore_training_state:
- self.steps = state_dict['steps']
- self.optimizer['generator'].load_state_dict(
- state_dict['optimizer']['generator'])
- self.scheduler['generator'].load_state_dict(
- state_dict['scheduler']['generator'])
- for model_name in state_dict['optimizer']['discriminator'].keys():
- self.optimizer['discriminator'][model_name].load_state_dict(
- state_dict['optimizer']['discriminator'][model_name])
- for model_name in state_dict['scheduler']['discriminator'].keys():
- self.scheduler['discriminator'][model_name].load_state_dict(
- state_dict['scheduler']['discriminator'][model_name])
-
-
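For reference, the generator objective above combines the full-band and subband multi-resolution STFT terms with a fixed 0.5/0.5 balance once the subband loss is enabled. A minimal sketch of that weighting, mirroring the arithmetic in GAN_Trainer.train_step (plain Python; the function name and toy values are illustrative, not the actual loss modules):

def combined_generator_stft_loss(sc, mag, stft_weight, sub_sc=None, sub_mag=None):
    # full-band multi-resolution STFT term, scaled by its configured weight
    gen_loss = (sc + mag) * stft_weight
    if sub_sc is not None:
        # halve the full-band term so it stays balanced against the subband term
        gen_loss = 0.5 * gen_loss + 0.5 * (sub_sc + sub_mag)
    return gen_loss

print(combined_generator_stft_loss(0.4, 0.6, 1.0, sub_sc=0.3, sub_mag=0.5))  # 0.9
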
-class Sambert_Trainer(Trainer):
-
- def __init__(
- self,
- config,
- model,
- optimizer,
- scheduler,
- criterion,
- device,
- sampler,
- train_loader,
- valid_loader,
- max_epochs=None,
- max_steps=None,
- save_dir=None,
- save_interval=1,
- valid_interval=1,
- log_interval=10,
- grad_clip=None,
- ):
- super().__init__(
- config,
- model,
- optimizer,
- scheduler,
- criterion,
- device,
- sampler,
- train_loader,
- valid_loader,
- max_epochs,
- max_steps,
- save_dir,
- save_interval,
- valid_interval,
- log_interval,
- grad_clip,
- )
- self.with_MAS = config['Model']['KanTtsSAMBERT']['params'].get(
- 'MAS', False)
- self.fp_enable = config['Model']['KanTtsSAMBERT']['params'].get(
- 'FP', False)
-
- @torch.no_grad()
- def genearete_and_save_intermediate_result(self, batch):
- inputs_ling = batch['input_lings'].to(self.device)
- inputs_emotion = batch['input_emotions'].to(self.device)
- inputs_speaker = batch['input_speakers'].to(self.device)
- valid_input_lengths = batch['valid_input_lengths'].to(self.device)
- mel_targets = batch['mel_targets'].to(self.device)
- # generate mel spectrograms
- res = self.model['KanTtsSAMBERT'](
- inputs_ling[0:1],
- inputs_emotion[0:1],
- inputs_speaker[0:1],
- valid_input_lengths[0:1],
- )
- x_band_width = res['x_band_width']
- h_band_width = res['h_band_width']
- enc_slf_attn_lst = res['enc_slf_attn_lst']
- pnca_x_attn_lst = res['pnca_x_attn_lst']
- pnca_h_attn_lst = res['pnca_h_attn_lst']
- dec_outputs = res['dec_outputs']
- postnet_outputs = res['postnet_outputs']
-
- dirname = os.path.join(self.log_dir, f'predictions/{self.steps}steps')
- if not os.path.exists(dirname):
- os.makedirs(dirname)
-
- for layer_id, slf_attn in enumerate(enc_slf_attn_lst):
- for head_id in range(self.config['Model']['KanTtsSAMBERT']
- ['params']['encoder_num_heads']):
- fig = plot_alignment(
- slf_attn[head_id, :valid_input_lengths[0], :
- valid_input_lengths[0]].cpu().numpy(),
- info='valid_len_{}'.format(valid_input_lengths[0].item()),
- )
- fig.savefig(
- os.path.join(
- dirname,
- 'enc_slf_attn_dev_layer{}_head{}'.format(
- layer_id, head_id),
- ))
- for layer_id, (pnca_x_attn, pnca_h_attn) in enumerate(
- zip(pnca_x_attn_lst, pnca_h_attn_lst)):
- for head_id in range(self.config['Model']['KanTtsSAMBERT']
- ['params']['decoder_num_heads']):
- fig = plot_alignment(
- pnca_x_attn[head_id, :, :].cpu().numpy(),
- info='x_band_width_{}'.format(x_band_width),
- )
- fig.savefig(
- os.path.join(
- dirname,
- 'pnca_x_attn_dev_layer{}_head{}'.format(
- layer_id, head_id),
- ))
- fig = plot_alignment(
- pnca_h_attn[head_id, :, :].cpu().numpy(),
- info='h_band_width_{}'.format(h_band_width),
- )
- fig.savefig(
- os.path.join(
- dirname,
- 'pnca_h_attn_dev_layer{}_head{}'.format(
- layer_id, head_id),
- ))
-
- target_mel = mel_targets[0].cpu().numpy()
- coarse_mel = dec_outputs.squeeze(0).cpu().numpy()
- output_mel = postnet_outputs.squeeze(0).cpu().numpy()
- np.save(os.path.join(dirname, 'coarse_mel.npy'), coarse_mel)
- np.save(os.path.join(dirname, 'output_mel.npy'), output_mel)
- np.save(os.path.join(dirname, 'target_mel.npy'), target_mel)
- fig = plot_spectrogram(coarse_mel.T)
- fig.savefig(os.path.join(dirname, 'mel_dec_outputs'))
- fig = plot_spectrogram(output_mel.T)
- fig.savefig(os.path.join(dirname, 'mel_postnet_outputs'))
-
- @torch.no_grad()
- def eval_step(self, batch):
- inputs_ling = batch['input_lings'].to(self.device)
- inputs_emotion = batch['input_emotions'].to(self.device)
- inputs_speaker = batch['input_speakers'].to(self.device)
- valid_input_lengths = batch['valid_input_lengths'].to(self.device)
- valid_output_lengths = batch['valid_output_lengths'].to(self.device)
- mel_targets = batch['mel_targets'].to(self.device)
- durations = (
- batch['durations'].to(self.device)
- if batch['durations'] is not None else None)
- pitch_contours = batch['pitch_contours'].to(self.device)
- energy_contours = batch['energy_contours'].to(self.device)
- attn_priors = (
- batch['attn_priors'].to(self.device)
- if batch['attn_priors'] is not None else None)
- fp_label = None
- if self.fp_enable:
- fp_label = batch['fp_label'].to(self.device)
- # generate mel spectrograms
- res = self.model['KanTtsSAMBERT'](
- inputs_ling,
- inputs_emotion,
- inputs_speaker,
- valid_input_lengths,
- output_lengths=valid_output_lengths,
- mel_targets=mel_targets,
- duration_targets=durations,
- pitch_targets=pitch_contours,
- energy_targets=energy_contours,
- attn_priors=attn_priors,
- fp_label=fp_label,
- )
-
- x_band_width = res['x_band_width']
- h_band_width = res['h_band_width']
- dec_outputs = res['dec_outputs']
- postnet_outputs = res['postnet_outputs']
- log_duration_predictions = res['log_duration_predictions']
- pitch_predictions = res['pitch_predictions']
- energy_predictions = res['energy_predictions']
- duration_targets = res['duration_targets']
- pitch_targets = res['pitch_targets']
- energy_targets = res['energy_targets']
- fp_predictions = res['fp_predictions']
- valid_inter_lengths = res['valid_inter_lengths']
-
- mel_loss_, mel_loss = self.criterion['MelReconLoss'](
- valid_output_lengths, mel_targets, dec_outputs, postnet_outputs)
-
- dur_loss, pitch_loss, energy_loss = self.criterion['ProsodyReconLoss'](
- valid_inter_lengths,
- duration_targets,
- pitch_targets,
- energy_targets,
- log_duration_predictions,
- pitch_predictions,
- energy_predictions,
- )
- loss_total = mel_loss_ + mel_loss + dur_loss + pitch_loss + energy_loss
- if self.fp_enable:
- fp_loss = self.criterion['FpCELoss'](valid_input_lengths,
- fp_predictions, fp_label)
- loss_total = loss_total + fp_loss
-
- if self.with_MAS:
- attn_soft = res['attn_soft']
- attn_hard = res['attn_hard']
- attn_logprob = res['attn_logprob']
- attn_ctc_loss = self.criterion['AttentionCTCLoss'](
- attn_logprob, valid_input_lengths, valid_output_lengths)
- attn_kl_loss = self.criterion['AttentionBinarizationLoss'](
- self.epoch, attn_hard, attn_soft)
-
- loss_total += attn_ctc_loss + attn_kl_loss
- self.total_eval_loss['eval/attn_ctc_loss'] += attn_ctc_loss.item()
- self.total_eval_loss['eval/attn_kl_loss'] += attn_kl_loss.item()
-
- self.total_eval_loss['eval/TotalLoss'] += loss_total.item()
- self.total_eval_loss['eval/mel_loss_'] += mel_loss_.item()
- self.total_eval_loss['eval/mel_loss'] += mel_loss.item()
- self.total_eval_loss['eval/dur_loss'] += dur_loss.item()
- self.total_eval_loss['eval/pitch_loss'] += pitch_loss.item()
- self.total_eval_loss['eval/energy_loss'] += energy_loss.item()
- if self.fp_enable:
- self.total_eval_loss['eval/fp_loss'] += fp_loss.item()
- self.total_eval_loss['eval/batch_size'] += mel_targets.size(0)
- self.total_eval_loss['eval/x_band_width'] += x_band_width
- self.total_eval_loss['eval/h_band_width'] += h_band_width
-
- def train_step(self, batch):
- inputs_ling = batch['input_lings'].to(self.device)
- inputs_emotion = batch['input_emotions'].to(self.device)
- inputs_speaker = batch['input_speakers'].to(self.device)
- valid_input_lengths = batch['valid_input_lengths'].to(self.device)
- valid_output_lengths = batch['valid_output_lengths'].to(self.device)
- mel_targets = batch['mel_targets'].to(self.device)
- durations = (
- batch['durations'].to(self.device)
- if batch['durations'] is not None else None)
- pitch_contours = batch['pitch_contours'].to(self.device)
- energy_contours = batch['energy_contours'].to(self.device)
- attn_priors = (
- batch['attn_priors'].to(self.device)
- if batch['attn_priors'] is not None else None)
- fp_label = None
- if self.fp_enable:
- fp_label = batch['fp_label'].to(self.device)
-
- # generate mel spectrograms
- res = self.model['KanTtsSAMBERT'](
- inputs_ling,
- inputs_emotion,
- inputs_speaker,
- valid_input_lengths,
- output_lengths=valid_output_lengths,
- mel_targets=mel_targets,
- duration_targets=durations,
- pitch_targets=pitch_contours,
- energy_targets=energy_contours,
- attn_priors=attn_priors,
- fp_label=fp_label,
- )
-
- x_band_width = res['x_band_width']
- h_band_width = res['h_band_width']
- dec_outputs = res['dec_outputs']
- postnet_outputs = res['postnet_outputs']
- log_duration_predictions = res['log_duration_predictions']
- pitch_predictions = res['pitch_predictions']
- energy_predictions = res['energy_predictions']
-
- duration_targets = res['duration_targets']
- pitch_targets = res['pitch_targets']
- energy_targets = res['energy_targets']
- fp_predictions = res['fp_predictions']
- valid_inter_lengths = res['valid_inter_lengths']
-
- mel_loss_, mel_loss = self.criterion['MelReconLoss'](
- valid_output_lengths, mel_targets, dec_outputs, postnet_outputs)
-
- dur_loss, pitch_loss, energy_loss = self.criterion['ProsodyReconLoss'](
- valid_inter_lengths,
- duration_targets,
- pitch_targets,
- energy_targets,
- log_duration_predictions,
- pitch_predictions,
- energy_predictions,
- )
- loss_total = mel_loss_ + mel_loss + dur_loss + pitch_loss + energy_loss
- if self.fp_enable:
- fp_loss = self.criterion['FpCELoss'](valid_input_lengths,
- fp_predictions, fp_label)
- loss_total = loss_total + fp_loss
-
- if self.with_MAS:
- attn_soft = res['attn_soft']
- attn_hard = res['attn_hard']
- attn_logprob = res['attn_logprob']
- attn_ctc_loss = self.criterion['AttentionCTCLoss'](
- attn_logprob, valid_input_lengths, valid_output_lengths)
- attn_kl_loss = self.criterion['AttentionBinarizationLoss'](
- self.epoch, attn_hard, attn_soft)
-
- loss_total += attn_ctc_loss + attn_kl_loss
- self.total_train_loss['train/attn_ctc_loss'] += attn_ctc_loss.item(
- )
- self.total_train_loss['train/attn_kl_loss'] += attn_kl_loss.item()
-
- self.total_train_loss['train/TotalLoss'] += loss_total.item()
- self.total_train_loss['train/mel_loss_'] += mel_loss_.item()
- self.total_train_loss['train/mel_loss'] += mel_loss.item()
- self.total_train_loss['train/dur_loss'] += dur_loss.item()
- self.total_train_loss['train/pitch_loss'] += pitch_loss.item()
- self.total_train_loss['train/energy_loss'] += energy_loss.item()
- if self.fp_enable:
- self.total_train_loss['train/fp_loss'] += fp_loss.item()
- self.total_train_loss['train/batch_size'] += mel_targets.size(0)
- self.total_train_loss['train/x_band_width'] += x_band_width
- self.total_train_loss['train/h_band_width'] += h_band_width
-
- self.optimizer['KanTtsSAMBERT'].zero_grad()
- loss_total.backward()
-
- if self.grad_clip is not None:
- torch.nn.utils.clip_grad_norm_(
- self.model['KanTtsSAMBERT'].parameters(), self.grad_clip)
- self.optimizer['KanTtsSAMBERT'].step()
- self.scheduler['KanTtsSAMBERT'].step()
-
- def save_checkpoint(self, checkpoint_path):
- if not self.distributed:
- model_state = self.model['KanTtsSAMBERT'].state_dict()
- else:
- model_state = self.model['KanTtsSAMBERT'].module.state_dict()
- state_dict = {
- 'optimizer': self.optimizer['KanTtsSAMBERT'].state_dict(),
- 'scheduler': self.scheduler['KanTtsSAMBERT'].state_dict(),
- 'steps': self.steps,
- 'model': model_state,
- }
-
- if not os.path.exists(checkpoint_path):
- os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
- torch.save(state_dict, checkpoint_path)
-
- def load_checkpoint(self,
- checkpoint_path,
- restore_training_state=False,
- strict=True):
- state_dict = torch.load(checkpoint_path)
- if not self.distributed:
- self.model['KanTtsSAMBERT'].load_state_dict(
- state_dict['model'], strict=strict)
- else:
- self.model['KanTtsSAMBERT'].module.load_state_dict(
- state_dict['model'], strict=strict)
-
- if restore_training_state:
- self.optimizer['KanTtsSAMBERT'].load_state_dict(
- state_dict['optimizer'])
- self.scheduler['KanTtsSAMBERT'].load_state_dict(
- state_dict['scheduler'])
- self.steps = state_dict['steps']
-
-
-class Textsy_BERT_Trainer(Trainer):
-
- def __init__(
- self,
- config,
- model,
- optimizer,
- scheduler,
- criterion,
- device,
- sampler,
- train_loader,
- valid_loader,
- max_epochs=None,
- max_steps=None,
- save_dir=None,
- save_interval=1,
- valid_interval=1,
- log_interval=10,
- grad_clip=None,
- ):
- super().__init__(
- config,
- model,
- optimizer,
- scheduler,
- criterion,
- device,
- sampler,
- train_loader,
- valid_loader,
- max_epochs,
- max_steps,
- save_dir,
- save_interval,
- valid_interval,
- log_interval,
- grad_clip,
- )
-
- @torch.no_grad()
- def genearete_and_save_intermediate_result(self, batch):
- inputs_ling = batch['input_lings'].to(self.device)
- valid_input_lengths = batch['valid_input_lengths'].to(self.device)
- bert_masks = batch['bert_masks'].to(self.device)
- targets = batch['targets'].to(self.device)
-
- res = self.model['KanTtsTextsyBERT'](
- inputs_ling[0:1],
- valid_input_lengths[0:1],
- )
-
- logits = res['logits']
- enc_slf_attn_lst = res['enc_slf_attn_lst']
- preds = torch.argmax(logits, dim=-1).contiguous().view(-1)
-
- dirname = os.path.join(self.log_dir, f'predictions/{self.steps}steps')
- if not os.path.exists(dirname):
- os.makedirs(dirname)
-
- for layer_id, slf_attn in enumerate(enc_slf_attn_lst):
- for head_id in range(self.config['Model']['KanTtsTextsyBERT']
- ['params']['encoder_num_heads']):
- fig = plot_alignment(
- slf_attn[head_id, :valid_input_lengths[0], :
- valid_input_lengths[0]].cpu().numpy(),
- info='valid_len_{}'.format(valid_input_lengths[0].item()),
- )
- fig.savefig(
- os.path.join(
- dirname,
- 'enc_slf_attn_dev_layer{}_head{}'.format(
- layer_id, head_id),
- ))
-
- target = targets[0].cpu().numpy()
- bert_mask = bert_masks[0].cpu().numpy()
- pred = preds.cpu().numpy()
- np.save(os.path.join(dirname, 'pred.npy'), pred)
- np.save(os.path.join(dirname, 'target.npy'), target)
- np.save(os.path.join(dirname, 'bert_mask.npy'), bert_mask)
-
- @torch.no_grad()
- def eval_step(self, batch):
- inputs_ling = batch['input_lings'].to(self.device)
- valid_input_lengths = batch['valid_input_lengths'].to(self.device)
- bert_masks = batch['bert_masks'].to(self.device)
- targets = batch['targets'].to(self.device)
-
- res = self.model['KanTtsTextsyBERT'](
- inputs_ling,
- valid_input_lengths,
- )
-
- logits = res['logits']
- loss_total, err = self.criterion['SeqCELoss'](
- logits,
- targets,
- bert_masks,
- )
- loss_total = loss_total / logits.size(-1)
-
- self.total_eval_loss['eval/TotalLoss'] += loss_total.item()
- self.total_eval_loss['eval/Error'] += err.item()
- self.total_eval_loss['eval/batch_size'] += targets.size(0)
-
- def train_step(self, batch):
- inputs_ling = batch['input_lings'].to(self.device)
- valid_input_lengths = batch['valid_input_lengths'].to(self.device)
- bert_masks = batch['bert_masks'].to(self.device)
- targets = batch['targets'].to(self.device)
-
- res = self.model['KanTtsTextsyBERT'](
- inputs_ling,
- valid_input_lengths,
- )
-
- logits = res['logits']
- loss_total, err = self.criterion['SeqCELoss'](
- logits,
- targets,
- bert_masks,
- )
- loss_total = loss_total / logits.size(-1)
-
- self.optimizer['KanTtsTextsyBERT'].zero_grad()
- loss_total.backward()
-
- if self.grad_clip is not None:
- torch.nn.utils.clip_grad_norm_(
- self.model['KanTtsTextsyBERT'].parameters(), self.grad_clip)
- self.optimizer['KanTtsTextsyBERT'].step()
- self.scheduler['KanTtsTextsyBERT'].step()
-
- self.total_train_loss['train/TotalLoss'] += loss_total.item()
- self.total_train_loss['train/Error'] += err.item()
- self.total_train_loss['train/batch_size'] += targets.size(0)
-
- def save_checkpoint(self, checkpoint_path):
- if not self.distributed:
- model_state = self.model['KanTtsTextsyBERT'].state_dict()
- else:
- model_state = self.model['KanTtsTextsyBERT'].module.state_dict()
- state_dict = {
- 'optimizer': self.optimizer['KanTtsTextsyBERT'].state_dict(),
- 'scheduler': self.scheduler['KanTtsTextsyBERT'].state_dict(),
- 'steps': self.steps,
- 'model': model_state,
- }
-
- if not os.path.exists(checkpoint_path):
- os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
- torch.save(state_dict, checkpoint_path)
-
- def load_checkpoint(self,
- checkpoint_path,
- restore_training_state=False,
- strict=True):
- state_dict = torch.load(checkpoint_path)
- if not self.distributed:
- self.model['KanTtsTextsyBERT'].load_state_dict(
- state_dict['model'], strict=strict)
- else:
- self.model['KanTtsTextsyBERT'].module.load_state_dict(
- state_dict['model'], strict=strict)
-
- if restore_training_state:
- self.optimizer['KanTtsTextsyBERT'].load_state_dict(
- state_dict['optimizer'])
- self.scheduler['KanTtsTextsyBERT'].load_state_dict(
- state_dict['scheduler'])
- self.steps = state_dict['steps']
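All four trainers above share the same step-driven loop: every batch increments `steps`, and rank 0 fires the eval/save/log hooks whenever the matching interval divides the counter. A compact, self-contained sketch of that control flow (class and hook names here are illustrative, not from the deleted file):

class MiniLoop:

    def __init__(self, save_interval=4, log_interval=2, valid_interval=3):
        self.steps = 1
        self.save_interval = save_interval
        self.log_interval = log_interval
        self.valid_interval = valid_interval

    def run(self, loader, train_step):
        for batch in loader:
            train_step(batch)
            # interval checks run before the step counter is advanced
            if self.steps % self.valid_interval == 0:
                print(f'eval @ step {self.steps}')
            if self.steps % self.save_interval == 0:
                print(f'save @ step {self.steps}')
            if self.steps % self.log_interval == 0:
                print(f'log @ step {self.steps}')
            self.steps += 1


MiniLoop().run(range(6), train_step=lambda batch: None)
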
diff --git a/modelscope/models/audio/tts/kantts/utils/audio_torch.py b/modelscope/models/audio/tts/kantts/utils/audio_torch.py
deleted file mode 100644
index e9f07ec3..00000000
--- a/modelscope/models/audio/tts/kantts/utils/audio_torch.py
+++ /dev/null
@@ -1,188 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-from distutils.version import LooseVersion
-
-import librosa
-import torch
-
-is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
-
-
-def stft(x, fft_size, hop_size, win_length, window):
- """Perform STFT and convert to magnitude spectrogram.
-
- Args:
- x (Tensor): Input signal tensor (B, T).
- fft_size (int): FFT size.
- hop_size (int): Hop size.
- win_length (int): Window length.
- window (str): Window function type.
-
- Returns:
- Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
-
- """
- if is_pytorch_17plus:
- x_stft = torch.stft(
- x, fft_size, hop_size, win_length, window, return_complex=False)
- else:
- x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
- real = x_stft[..., 0]
- imag = x_stft[..., 1]
-
- return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
-
-
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
- return 20 * torch.log10(torch.clamp(x, min=clip_val) * C)
-
-
-def dynamic_range_decompression_torch(x, C=1):
- return torch.pow(10.0, x * 0.05) / C
-
-
-def spectral_normalize_torch(
- magnitudes,
- min_level_db=-100.0,
- ref_level_db=20.0,
- norm_abs_value=4.0,
- symmetric=True,
-):
- output = dynamic_range_compression_torch(magnitudes) - ref_level_db
-
- if symmetric:
- return torch.clamp(
- 2 * norm_abs_value * ((output - min_level_db) / # noqa W504
- (-min_level_db)) - norm_abs_value,
- min=-norm_abs_value,
- max=norm_abs_value)
- else:
- return torch.clamp(
- norm_abs_value * ((output - min_level_db) / (-min_level_db)),
- min=0.0,
- max=norm_abs_value)
-
-
-def spectral_de_normalize_torch(
- magnitudes,
- min_level_db=-100.0,
- ref_level_db=20.0,
- norm_abs_value=4.0,
- symmetric=True,
-):
- if symmetric:
- magnitudes = torch.clamp(
- magnitudes, min=-norm_abs_value, max=norm_abs_value)
- magnitudes = (magnitudes + norm_abs_value) * (-min_level_db) / (
- 2 * norm_abs_value) + min_level_db
- else:
- magnitudes = torch.clamp(magnitudes, min=0.0, max=norm_abs_value)
- magnitudes = (magnitudes) * (-min_level_db) / (
- norm_abs_value) + min_level_db
-
- output = dynamic_range_decompression_torch(magnitudes + ref_level_db)
- return output
-
-
-class MelSpectrogram(torch.nn.Module):
- """Calculate Mel-spectrogram."""
-
- def __init__(
- self,
- fs=22050,
- fft_size=1024,
- hop_size=256,
- win_length=None,
- window='hann',
- num_mels=80,
- fmin=80,
- fmax=7600,
- center=True,
- normalized=False,
- onesided=True,
- eps=1e-10,
- log_base=10.0,
- pad_mode='constant',
- ):
- """Initialize MelSpectrogram module."""
- super().__init__()
- self.fft_size = fft_size
- if win_length is None:
- self.win_length = fft_size
- else:
- self.win_length = win_length
- self.hop_size = hop_size
- self.center = center
- self.normalized = normalized
- self.onesided = onesided
- if window is not None and not hasattr(torch, f'{window}_window'):
- raise ValueError(f'{window} window is not implemented')
- self.window = window
- self.eps = eps
- self.pad_mode = pad_mode
-
- fmin = 0 if fmin is None else fmin
- fmax = fs / 2 if fmax is None else fmax
- melmat = librosa.filters.mel(
- sr=fs,
- n_fft=fft_size,
- n_mels=num_mels,
- fmin=fmin,
- fmax=fmax,
- )
- self.register_buffer('melmat', torch.from_numpy(melmat.T).float())
- self.stft_params = {
- 'n_fft': self.fft_size,
- 'win_length': self.win_length,
- 'hop_length': self.hop_size,
- 'center': self.center,
- 'normalized': self.normalized,
- 'onesided': self.onesided,
- 'pad_mode': self.pad_mode,
- }
- if is_pytorch_17plus:
- self.stft_params['return_complex'] = False
-
- self.log_base = log_base
- if self.log_base is None:
- self.log = torch.log
- elif self.log_base == 2.0:
- self.log = torch.log2
- elif self.log_base == 10.0:
- self.log = torch.log10
- else:
- raise ValueError(f'log_base: {log_base} is not supported.')
-
- def forward(self, x):
- """Calculate Mel-spectrogram.
-
- Args:
- x (Tensor): Input waveform tensor (B, T) or (B, 1, T).
-
- Returns:
- Tensor: Mel-spectrogram (B, #mels, #frames).
-
- """
- if x.dim() == 3:
- # (B, C, T) -> (B*C, T)
- x = x.reshape(-1, x.size(2))
-
- if self.window is not None:
- window_func = getattr(torch, f'{self.window}_window')
- window = window_func(
- self.win_length, dtype=x.dtype, device=x.device)
- else:
- window = None
-
- x_stft = torch.stft(x, window=window, **self.stft_params)
-        # (B, #freqs, #frames, 2) -> (B, #frames, #freqs, 2)
- x_stft = x_stft.transpose(1, 2)
- x_power = x_stft[..., 0]**2 + x_stft[..., 1]**2
- x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps))
-
- x_mel = torch.matmul(x_amp, self.melmat)
- x_mel = torch.clamp(x_mel, min=self.eps)
- x_mel = spectral_normalize_torch(x_mel)
-
- # return self.log(x_mel).transpose(1, 2)
- return x_mel.transpose(1, 2)
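The normalize/de-normalize pair above is an affine map between dB magnitudes and the symmetric range [-4, 4]. A small NumPy sketch of the same arithmetic, using the default constants from this file (illustrative; the round trip is exact only where no clipping occurs):

import numpy as np


def norm(x, min_db=-100.0, ref_db=20.0, a=4.0):
    db = 20.0 * np.log10(np.clip(x, 1e-5, None)) - ref_db
    return np.clip(2 * a * (db - min_db) / -min_db - a, -a, a)


def denorm(m, min_db=-100.0, ref_db=20.0, a=4.0):
    db = (np.clip(m, -a, a) + a) * -min_db / (2 * a) + min_db
    return 10.0 ** ((db + ref_db) * 0.05)


x = np.array([1e-3, 0.1, 1.0])
print(norm(x))          # magnitudes mapped into [-4, 4]
print(denorm(norm(x)))  # ~= x for inputs that were not clipped
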
diff --git a/modelscope/models/audio/tts/kantts/utils/ling_unit/__init__.py b/modelscope/models/audio/tts/kantts/utils/ling_unit/__init__.py
deleted file mode 100644
index b3a29992..00000000
--- a/modelscope/models/audio/tts/kantts/utils/ling_unit/__init__.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import ttsfrd
-
-
-def text_to_mit_symbols(texts, resources_dir, speaker):
- fe = ttsfrd.TtsFrontendEngine()
- fe.initialize(resources_dir)
- fe.set_lang_type('Zh-CN')
-
- symbols_lst = []
- for idx, text in enumerate(texts):
- text = text.strip()
- res = fe.gen_tacotron_symbols(text)
- res = res.replace('F7', speaker)
- sentences = res.split('\n')
- for sentence in sentences:
- arr = sentence.split('\t')
- # skip the empty line
- if len(arr) != 2:
- continue
- sub_index, symbols = sentence.split('\t')
- symbol_str = '{}_{}\t{}\n'.format(idx, sub_index, symbols)
- symbols_lst.append(symbol_str)
-
- return symbols_lst
diff --git a/modelscope/models/audio/tts/kantts/utils/ling_unit/cleaners.py b/modelscope/models/audio/tts/kantts/utils/ling_unit/cleaners.py
deleted file mode 100644
index 8697efd2..00000000
--- a/modelscope/models/audio/tts/kantts/utils/ling_unit/cleaners.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# from https://github.com/keithito/tacotron
-# Cleaners are transformations that run over the input text at both training and eval time.
-#
-# Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
-# hyperparameter. Some cleaners are English-specific. You'll typically want to use:
-# 1. "english_cleaners" for English text
-# 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
-# the Unidecode library (https://pypi.python.org/pypi/Unidecode)
-# 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
-# the symbols in symbols.py to match your data).
-
-import re
-
-from unidecode import unidecode
-
-from .numbers import normalize_numbers
-
-# Regular expression matching whitespace:
-_whitespace_re = re.compile(r'\s+')
-
-# List of (regular expression, replacement) pairs for abbreviations:
-_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
-                  for x in [
-                      ('mrs', 'misess'),
-                      ('mr', 'mister'),
-                      ('dr', 'doctor'),
-                      ('st', 'saint'),
-                      ('co', 'company'),
-                      ('jr', 'junior'),
-                      ('maj', 'major'),
-                      ('gen', 'general'),
-                      ('drs', 'doctors'),
-                      ('rev', 'reverend'),
-                      ('lt', 'lieutenant'),
-                      ('hon', 'honorable'),
-                      ('sgt', 'sergeant'),
-                      ('capt', 'captain'),
-                      ('esq', 'esquire'),
-                      ('ltd', 'limited'),
-                      ('col', 'colonel'),
-                      ('ft', 'fort'),
-                  ]]
-
-
-def expand_abbreviations(text):
- for regex, replacement in _abbreviations:
- text = re.sub(regex, replacement, text)
- return text
-
-
-def expand_numbers(text):
- return normalize_numbers(text)
-
-
-def lowercase(text):
- return text.lower()
-
-
-def collapse_whitespace(text):
- return re.sub(_whitespace_re, ' ', text)
-
-
-def convert_to_ascii(text):
- return unidecode(text)
-
-
-def basic_cleaners(text):
- """Basic pipeline that lowercases and collapses whitespace without transliteration."""
- text = lowercase(text)
- text = collapse_whitespace(text)
- return text
-
-
-def transliteration_cleaners(text):
- """Pipeline for non-English text that transliterates to ASCII."""
- text = convert_to_ascii(text)
- text = lowercase(text)
- text = collapse_whitespace(text)
- return text
-
-
-def english_cleaners(text):
- """Pipeline for English text, including number and abbreviation expansion."""
- text = convert_to_ascii(text)
- text = lowercase(text)
- text = expand_numbers(text)
- text = expand_abbreviations(text)
- text = collapse_whitespace(text)
- return text
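These cleaners are looked up by name at runtime (see `_clean_text` in ling_unit.py further down). A self-contained sketch of that dispatch pattern, restricted to the basic steps so it runs without unidecode or inflect (the registry dict and `clean` helper are illustrative):

import re


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(r'\s+', ' ', text)


CLEANERS = {'lowercase': lowercase, 'collapse_whitespace': collapse_whitespace}


def clean(text, names):
    for name in names.split(','):
        cleaner = CLEANERS.get(name.strip())
        if cleaner is None:
            raise Exception('Unknown cleaner: %s' % name)
        text = cleaner(text)
    return text


print(clean('  Hello   WORLD  ', 'lowercase, collapse_whitespace'))  # ' hello world '
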
diff --git a/modelscope/models/audio/tts/kantts/utils/ling_unit/emotion_types.py b/modelscope/models/audio/tts/kantts/utils/ling_unit/emotion_types.py
deleted file mode 100644
index 3ae328de..00000000
--- a/modelscope/models/audio/tts/kantts/utils/ling_unit/emotion_types.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-emotion_types = [
- 'emotion_none',
- 'emotion_neutral',
- 'emotion_angry',
- 'emotion_disgust',
- 'emotion_fear',
- 'emotion_happy',
- 'emotion_sad',
- 'emotion_surprise',
- 'emotion_calm',
- 'emotion_gentle',
- 'emotion_relax',
- 'emotion_lyrical',
- 'emotion_serious',
- 'emotion_disgruntled',
- 'emotion_satisfied',
- 'emotion_disappointed',
- 'emotion_excited',
- 'emotion_anxiety',
- 'emotion_jealousy',
- 'emotion_hate',
- 'emotion_pity',
- 'emotion_pleasure',
- 'emotion_arousal',
- 'emotion_dominance',
- 'emotion_placeholder1',
- 'emotion_placeholder2',
- 'emotion_placeholder3',
- 'emotion_placeholder4',
- 'emotion_placeholder5',
- 'emotion_placeholder6',
- 'emotion_placeholder7',
- 'emotion_placeholder8',
- 'emotion_placeholder9',
-]
diff --git a/modelscope/models/audio/tts/kantts/utils/ling_unit/lang_symbols.py b/modelscope/models/audio/tts/kantts/utils/ling_unit/lang_symbols.py
deleted file mode 100644
index e7b3399c..00000000
--- a/modelscope/models/audio/tts/kantts/utils/ling_unit/lang_symbols.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import os
-import xml.etree.ElementTree as ET
-
-from modelscope.models.audio.tts.kantts.preprocess.languages import languages
-from modelscope.utils.logger import get_logger
-
-logging = get_logger()
-
-syllable_flags = [
- 's_begin',
- 's_end',
- 's_none',
- 's_both',
- 's_middle',
-]
-
-word_segments = [
- 'word_begin',
- 'word_end',
- 'word_middle',
- 'word_both',
- 'word_none',
-]
-
-
-def parse_phoneset(phoneset_file):
- """Parse a phoneset file and return a list of symbols.
- Args:
- phoneset_file (str): Path to the phoneset file.
-
- Returns:
- list: A list of phones.
- """
- ns = '{http://schemas.alibaba-inc.com/tts}'
-
- phone_lst = []
- phoneset_root = ET.parse(phoneset_file).getroot()
- for phone_node in phoneset_root.findall(ns + 'phone'):
- phone_lst.append(phone_node.find(ns + 'name').text)
-
- for i in range(1, 5):
- phone_lst.append('#{}'.format(i))
-
- return phone_lst
-
-
-def parse_tonelist(tonelist_file):
- """Parse a tonelist file and return a list of tones.
- Args:
- tonelist_file (str): Path to the tonelist file.
-
- Returns:
-        list: A list of tones.
- """
- tone_lst = []
- with open(tonelist_file, 'r') as f:
- lines = f.readlines()
- for line in lines:
- tone = line.strip()
- if tone != '':
- tone_lst.append('tone{}'.format(tone))
- else:
- tone_lst.append('tone_none')
-
- return tone_lst
-
-
-def get_language_symbols(language, language_dir):
- """Get symbols of a language.
- Args:
- language (str): Language name.
- """
- language_dict = languages.get(language, None)
- if language_dict is None:
- logging.error('Language %s not supported. Using PinYin as default',
- language)
- language_dict = languages['PinYin']
- language = 'PinYin'
-
- language_dir = os.path.join(language_dir, language)
- phoneset_file = os.path.join(language_dir, language_dict['phoneset_path'])
- tonelist_file = os.path.join(language_dir, language_dict['tonelist_path'])
- phones = parse_phoneset(phoneset_file)
- tones = parse_tonelist(tonelist_file)
-
- return phones, tones, syllable_flags, word_segments
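`parse_phoneset` reads phone names from a namespaced XML file and appends the prosodic break symbols #1 to #4. A self-contained sketch with a made-up minimal phoneset string (only the namespace is taken from the code above):

import xml.etree.ElementTree as ET

NS = '{http://schemas.alibaba-inc.com/tts}'
SAMPLE = (
    '<phoneset xmlns="http://schemas.alibaba-inc.com/tts">'
    '<phone><name>a</name></phone>'
    '<phone><name>ai</name></phone>'
    '</phoneset>'
)

root = ET.fromstring(SAMPLE)
phones = [p.find(NS + 'name').text for p in root.findall(NS + 'phone')]
phones += ['#{}'.format(i) for i in range(1, 5)]
print(phones)  # ['a', 'ai', '#1', '#2', '#3', '#4']
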
diff --git a/modelscope/models/audio/tts/kantts/utils/ling_unit/ling_unit.py b/modelscope/models/audio/tts/kantts/utils/ling_unit/ling_unit.py
deleted file mode 100644
index a1a9ffdb..00000000
--- a/modelscope/models/audio/tts/kantts/utils/ling_unit/ling_unit.py
+++ /dev/null
@@ -1,422 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import abc
-import os
-import re
-import shutil
-
-import numpy as np
-
-from . import cleaners as cleaners
-from .emotion_types import emotion_types
-from .lang_symbols import get_language_symbols
-
-# Regular expression matching text enclosed in curly braces:
-_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
-
-
-def _clean_text(text, cleaner_names):
- for name in cleaner_names:
- cleaner = getattr(cleaners, name)
- if not cleaner:
- raise Exception('Unknown cleaner: %s' % name)
- text = cleaner(text)
- return text
-
-
-def get_fpdict(config):
-    # emotion_neutral (F7) can be replaced by other emotion (speaker) types listed in the config file.
- default_sp = config['linguistic_unit']['speaker_list'].split(',')[0]
- en_sy = f'{{ge$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{en_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}' # NOQA: E501
- a_sy = f'{{ga$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{a_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}' # NOQA: E501
- e_sy = f'{{ge$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{e_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}' # NOQA: E501
- ling_unit = KanTtsLinguisticUnit(config)
-
- en_lings = ling_unit.encode_symbol_sequence(en_sy)
- a_lings = ling_unit.encode_symbol_sequence(a_sy)
- e_lings = ling_unit.encode_symbol_sequence(e_sy)
-
- en_ling = np.stack(en_lings, axis=1)[:3, :4]
- a_ling = np.stack(a_lings, axis=1)[:3, :4]
- e_ling = np.stack(e_lings, axis=1)[:3, :4]
-
- fp_dict = {1: en_ling, 2: a_ling, 3: e_ling}
- return fp_dict
-
-
-class LinguisticBaseUnit(abc.ABC):
-
- def set_config_params(self, config_params):
- self.config_params = config_params
-
- def save(self, config, config_name, path):
- """Save config to file"""
- t_path = os.path.join(path, config_name)
- if config != t_path:
- os.makedirs(path, exist_ok=True)
- shutil.copyfile(config, os.path.join(path, config_name))
-
-
-class KanTtsLinguisticUnit(LinguisticBaseUnit):
-
- def __init__(self, config, lang_dir=None):
- super(KanTtsLinguisticUnit, self).__init__()
-
- # special symbol
- self._pad = '_'
- self._eos = '~'
- self._mask = '@[MASK]'
-
- self.unit_config = config['linguistic_unit']
- self.has_mask = self.unit_config.get('has_mask', True)
- self.lang_type = self.unit_config.get('language', 'PinYin')
- (
- self.lang_phones,
- self.lang_tones,
- self.lang_syllable_flags,
- self.lang_word_segments,
- ) = get_language_symbols(self.lang_type, lang_dir)
-
- self._cleaner_names = [
- x.strip() for x in self.unit_config['cleaners'].split(',')
- ]
- _lfeat_type_list = self.unit_config['lfeat_type_list'].strip().split(
- ',')
- self._lfeat_type_list = _lfeat_type_list
-
- self.fp_enable = config['Model']['KanTtsSAMBERT']['params'].get(
- 'FP', False)
- if self.fp_enable:
- self._fpadd_lfeat_type_list = [
- _lfeat_type_list[0], _lfeat_type_list[4]
- ]
-
- self.build()
-
- def using_byte(self):
- return 'byte_index' in self._lfeat_type_list
-
- def get_unit_size(self):
- ling_unit_size = {}
- if self.using_byte():
- ling_unit_size['byte_index'] = len(self.byte_index)
- else:
- ling_unit_size['sy'] = len(self.sy)
- ling_unit_size['tone'] = len(self.tone)
- ling_unit_size['syllable_flag'] = len(self.syllable_flag)
- ling_unit_size['word_segment'] = len(self.word_segment)
-
- if 'emo_category' in self._lfeat_type_list:
- ling_unit_size['emotion'] = len(self.emo_category)
- if 'speaker_category' in self._lfeat_type_list:
- ling_unit_size['speaker'] = len(self.speaker)
-
- return ling_unit_size
-
- def build(self):
- self._sub_unit_dim = {}
- self._sub_unit_pad = {}
- if self.using_byte():
- # Export all byte indices:
- self.byte_index = ['@' + str(idx) for idx in range(256)] + [
- self._pad,
- self._eos,
- ]
- if self.has_mask:
- self.byte_index.append(self._mask)
- self._byte_index_to_id = {
- s: i
- for i, s in enumerate(self.byte_index)
- }
- self._id_to_byte_index = {
- i: s
- for i, s in enumerate(self.byte_index)
- }
- self._sub_unit_dim['byte_index'] = len(self.byte_index)
- self._sub_unit_pad['byte_index'] = self._byte_index_to_id['_']
- else:
- # sy sub-unit
- _characters = ''
-
- # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
- # _arpabet = ['@' + s for s in cmudict.valid_symbols]
- _arpabet = ['@' + s for s in self.lang_phones]
-
- # Export all symbols:
- self.sy = list(_characters) + _arpabet + [self._pad, self._eos]
- if self.has_mask:
- self.sy.append(self._mask)
- self._sy_to_id = {s: i for i, s in enumerate(self.sy)}
- self._id_to_sy = {i: s for i, s in enumerate(self.sy)}
- self._sub_unit_dim['sy'] = len(self.sy)
- self._sub_unit_pad['sy'] = self._sy_to_id['_']
-
- # tone sub-unit
- _characters = ''
-
- # Export all tones:
- self.tone = (
- list(_characters) + self.lang_tones + [self._pad, self._eos])
- if self.has_mask:
- self.tone.append(self._mask)
- self._tone_to_id = {s: i for i, s in enumerate(self.tone)}
- self._id_to_tone = {i: s for i, s in enumerate(self.tone)}
- self._sub_unit_dim['tone'] = len(self.tone)
- self._sub_unit_pad['tone'] = self._tone_to_id['_']
-
- # syllable flag sub-unit
- _characters = ''
-
- # Export all syllable_flags:
- self.syllable_flag = (
- list(_characters) + self.lang_syllable_flags
- + [self._pad, self._eos])
- if self.has_mask:
- self.syllable_flag.append(self._mask)
- self._syllable_flag_to_id = {
- s: i
- for i, s in enumerate(self.syllable_flag)
- }
- self._id_to_syllable_flag = {
- i: s
- for i, s in enumerate(self.syllable_flag)
- }
- self._sub_unit_dim['syllable_flag'] = len(self.syllable_flag)
- self._sub_unit_pad['syllable_flag'] = self._syllable_flag_to_id[
- '_']
-
- # word segment sub-unit
- _characters = ''
-
-            # Export all word segments:
- self.word_segment = (
- list(_characters) + self.lang_word_segments
- + [self._pad, self._eos])
- if self.has_mask:
- self.word_segment.append(self._mask)
- self._word_segment_to_id = {
- s: i
- for i, s in enumerate(self.word_segment)
- }
- self._id_to_word_segment = {
- i: s
- for i, s in enumerate(self.word_segment)
- }
- self._sub_unit_dim['word_segment'] = len(self.word_segment)
- self._sub_unit_pad['word_segment'] = self._word_segment_to_id['_']
-
- if 'emo_category' in self._lfeat_type_list:
- # emotion category sub-unit
- _characters = ''
-
- self.emo_category = (
- list(_characters) + emotion_types + [self._pad, self._eos])
- if self.has_mask:
- self.emo_category.append(self._mask)
- self._emo_category_to_id = {
- s: i
- for i, s in enumerate(self.emo_category)
- }
- self._id_to_emo_category = {
- i: s
- for i, s in enumerate(self.emo_category)
- }
- self._sub_unit_dim['emo_category'] = len(self.emo_category)
- self._sub_unit_pad['emo_category'] = self._emo_category_to_id['_']
-
- if 'speaker_category' in self._lfeat_type_list:
- # speaker category sub-unit
- _characters = ''
-
- _ch_speakers = self.unit_config['speaker_list'].strip().split(',')
-
-            # Export all speakers:
- self.speaker = (
- list(_characters) + _ch_speakers + [self._pad, self._eos])
- if self.has_mask:
- self.speaker.append(self._mask)
- self._speaker_to_id = {s: i for i, s in enumerate(self.speaker)}
- self._id_to_speaker = {i: s for i, s in enumerate(self.speaker)}
- self._sub_unit_dim['speaker_category'] = len(self._speaker_to_id)
- self._sub_unit_pad['speaker_category'] = self._speaker_to_id['_']
-
- def encode_symbol_sequence(self, lfeat_symbol):
- lfeat_symbol = lfeat_symbol.strip().split(' ')
-
- lfeat_symbol_separate = [''] * int(len(self._lfeat_type_list))
- for this_lfeat_symbol in lfeat_symbol:
- this_lfeat_symbol = this_lfeat_symbol.strip('{').strip('}').split(
- '$')
- index = 0
- while index < len(lfeat_symbol_separate):
- lfeat_symbol_separate[index] = (
- lfeat_symbol_separate[index] + this_lfeat_symbol[index]
- + ' ')
- index = index + 1
-
- input_and_label_data = []
- index = 0
- while index < len(self._lfeat_type_list):
- sequence = self.encode_sub_unit(
- lfeat_symbol_separate[index].strip(),
- self._lfeat_type_list[index])
- sequence_array = np.asarray(sequence, dtype=np.int32)
- input_and_label_data.append(sequence_array)
- index = index + 1
-
- return input_and_label_data
-
- def decode_symbol_sequence(self, sequence):
- result = []
- for i, lfeat_type in enumerate(self._lfeat_type_list):
- s = ''
- sequence_item = sequence[i].tolist()
- if lfeat_type == 'sy':
- s = self.decode_sy(sequence_item)
- elif lfeat_type == 'byte_index':
- s = self.decode_byte_index(sequence_item)
- elif lfeat_type == 'tone':
- s = self.decode_tone(sequence_item)
- elif lfeat_type == 'syllable_flag':
- s = self.decode_syllable_flag(sequence_item)
- elif lfeat_type == 'word_segment':
- s = self.decode_word_segment(sequence_item)
- elif lfeat_type == 'emo_category':
- s = self.decode_emo_category(sequence_item)
- elif lfeat_type == 'speaker_category':
- s = self.decode_speaker_category(sequence_item)
- else:
- raise Exception('Unknown lfeat type: %s' % lfeat_type)
- result.append('%s:%s' % (lfeat_type, s))
-
-        return result
-
- def encode_sub_unit(self, this_lfeat_symbol, lfeat_type):
- sequence = []
- if lfeat_type == 'sy':
- this_lfeat_symbol = this_lfeat_symbol.strip().split(' ')
- this_lfeat_symbol_format = ''
- index = 0
- while index < len(this_lfeat_symbol):
- this_lfeat_symbol_format = (
- this_lfeat_symbol_format + '{' + this_lfeat_symbol[index]
- + '}' + ' ')
- index = index + 1
- sequence = self.encode_text(this_lfeat_symbol_format,
- self._cleaner_names)
- elif lfeat_type == 'byte_index':
- sequence = self.encode_byte_index(this_lfeat_symbol)
- elif lfeat_type == 'tone':
- sequence = self.encode_tone(this_lfeat_symbol)
- elif lfeat_type == 'syllable_flag':
- sequence = self.encode_syllable_flag(this_lfeat_symbol)
- elif lfeat_type == 'word_segment':
- sequence = self.encode_word_segment(this_lfeat_symbol)
- elif lfeat_type == 'emo_category':
- sequence = self.encode_emo_category(this_lfeat_symbol)
- elif lfeat_type == 'speaker_category':
- sequence = self.encode_speaker_category(this_lfeat_symbol)
- else:
- raise Exception('Unknown lfeat type: %s' % lfeat_type)
- return sequence
-
- def encode_text(self, text, cleaner_names):
- sequence = []
-
- # Check for curly braces and treat their contents as ARPAbet:
- while len(text):
- m = _curly_re.match(text)
- if not m:
- sequence += self.encode_sy(_clean_text(text, cleaner_names))
- break
- sequence += self.encode_sy(_clean_text(m.group(1), cleaner_names))
- sequence += self.encode_arpanet(m.group(2))
- text = m.group(3)
-
- # Append EOS token
- sequence.append(self._sy_to_id['~'])
- return sequence
-
- def encode_sy(self, sy):
- return [self._sy_to_id[s] for s in sy if self.should_keep_sy(s)]
-
- def decode_sy(self, id):
- s = self._id_to_sy[id]
- if len(s) > 1 and s[0] == '@':
- s = s[1:]
- return s
-
- def should_keep_sy(self, s):
- return s in self._sy_to_id and s != '_' and s != '~'
-
- def encode_arpanet(self, text):
- return self.encode_sy(['@' + s for s in text.split()])
-
- def encode_byte_index(self, byte_index):
- byte_indices = ['@' + s for s in byte_index.strip().split(' ')]
- sequence = []
- for this_byte_index in byte_indices:
- sequence.append(self._byte_index_to_id[this_byte_index])
- sequence.append(self._byte_index_to_id['~'])
- return sequence
-
- def decode_byte_index(self, id):
- s = self._id_to_byte_index[id]
- if len(s) > 1 and s[0] == '@':
- s = s[1:]
- return s
-
- def encode_tone(self, tone):
- tones = tone.strip().split(' ')
- sequence = []
- for this_tone in tones:
- sequence.append(self._tone_to_id[this_tone])
- sequence.append(self._tone_to_id['~'])
- return sequence
-
- def decode_tone(self, id):
- return self._id_to_tone[id]
-
- def encode_syllable_flag(self, syllable_flag):
- syllable_flags = syllable_flag.strip().split(' ')
- sequence = []
- for this_syllable_flag in syllable_flags:
- sequence.append(self._syllable_flag_to_id[this_syllable_flag])
- sequence.append(self._syllable_flag_to_id['~'])
- return sequence
-
- def decode_syllable_flag(self, id):
- return self._id_to_syllable_flag[id]
-
- def encode_word_segment(self, word_segment):
- word_segments = word_segment.strip().split(' ')
- sequence = []
- for this_word_segment in word_segments:
- sequence.append(self._word_segment_to_id[this_word_segment])
- sequence.append(self._word_segment_to_id['~'])
- return sequence
-
- def decode_word_segment(self, id):
- return self._id_to_word_segment[id]
-
- def encode_emo_category(self, emo_type):
- emo_categories = emo_type.strip().split(' ')
- sequence = []
- for this_category in emo_categories:
- sequence.append(self._emo_category_to_id[this_category])
- sequence.append(self._emo_category_to_id['~'])
- return sequence
-
- def decode_emo_category(self, id):
- return self._id_to_emo_category[id]
-
- def encode_speaker_category(self, speaker):
- speakers = speaker.strip().split(' ')
- sequence = []
- for this_speaker in speakers:
- sequence.append(self._speaker_to_id[this_speaker])
- sequence.append(self._speaker_to_id['~'])
- return sequence
-
- def decode_speaker_category(self, id):
- return self._id_to_speaker[id]
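
The deleted encode_* helpers above all share one pattern: split the feature string on spaces, map each token through a symbol table, and append the EOS marker '~'. A standalone sketch of that pattern, using a hypothetical tone inventory rather than the model's real symbol set:

tones = ['~', '0', '1', '2', '3', '4', '5']          # hypothetical inventory; '~' is EOS
tone_to_id = {t: i for i, t in enumerate(tones)}

def encode_tone(tone_str):
    seq = [tone_to_id[t] for t in tone_str.strip().split(' ')]
    seq.append(tone_to_id['~'])                      # every sequence ends with EOS
    return seq

print(encode_tone('3 3 1 0'))                        # [4, 4, 2, 1, 0]
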
diff --git a/modelscope/models/audio/tts/kantts/utils/ling_unit/numbers.py b/modelscope/models/audio/tts/kantts/utils/ling_unit/numbers.py
deleted file mode 100644
index 60814a2e..00000000
--- a/modelscope/models/audio/tts/kantts/utils/ling_unit/numbers.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import re
-
-import inflect
-
-_inflect = inflect.engine()
-_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
-_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
-_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
-_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
-_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
-_number_re = re.compile(r'[0-9]+')
-
-
-def _remove_commas(m):
- return m.group(1).replace(',', '')
-
-
-def _expand_decimal_point(m):
- return m.group(1).replace('.', ' point ')
-
-
-def _expand_dollars(m):
- match = m.group(1)
- parts = match.split('.')
- if len(parts) > 2:
- return match + ' dollars' # Unexpected format
- dollars = int(parts[0]) if parts[0] else 0
- cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
- if dollars and cents:
- dollar_unit = 'dollar' if dollars == 1 else 'dollars'
- cent_unit = 'cent' if cents == 1 else 'cents'
- return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
- elif dollars:
- dollar_unit = 'dollar' if dollars == 1 else 'dollars'
- return '%s %s' % (dollars, dollar_unit)
- elif cents:
- cent_unit = 'cent' if cents == 1 else 'cents'
- return '%s %s' % (cents, cent_unit)
- else:
- return 'zero dollars'
-
-
-def _expand_ordinal(m):
- return _inflect.number_to_words(m.group(0))
-
-
-def _expand_number(m):
- num = int(m.group(0))
- if num > 1000 and num < 3000:
- if num == 2000:
- return 'two thousand'
- elif num > 2000 and num < 2010:
- return 'two thousand ' + _inflect.number_to_words(num % 100)
- elif num % 100 == 0:
- return _inflect.number_to_words(num // 100) + ' hundred'
- else:
- return _inflect.number_to_words(
- num, andword='', zero='oh', group=2).replace(', ', ' ')
- else:
- return _inflect.number_to_words(num, andword='')
-
-
-def normalize_numbers(text):
- text = re.sub(_comma_number_re, _remove_commas, text)
- text = re.sub(_pounds_re, r'\1 pounds', text)
- text = re.sub(_dollars_re, _expand_dollars, text)
- text = re.sub(_decimal_number_re, _expand_decimal_point, text)
- text = re.sub(_ordinal_re, _expand_ordinal, text)
- text = re.sub(_number_re, _expand_number, text)
- return text
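
normalize_numbers above is a chain of re.sub calls whose replacements are callbacks; the inflect-based steps spell numbers out, the others just rewrite the match. A minimal self-contained sketch of the same pipeline with the inflect-dependent steps omitted:

import re

_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')

def normalize_numbers_lite(text):
    # each re.sub replacement is a function that receives the Match object
    text = re.sub(_comma_number_re, lambda m: m.group(1).replace(',', ''), text)
    text = re.sub(_decimal_number_re, lambda m: m.group(1).replace('.', ' point '), text)
    return text

print(normalize_numbers_lite('1,234.5 metres'))      # '1234 point 5 metres'
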
diff --git a/modelscope/models/audio/tts/kantts/utils/log.py b/modelscope/models/audio/tts/kantts/utils/log.py
deleted file mode 100644
index 58d36124..00000000
--- a/modelscope/models/audio/tts/kantts/utils/log.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import logging
-import subprocess
-
-
-def logging_to_file(log_file):
- logger = logging.getLogger()
- handler = logging.FileHandler(log_file)
- formatter = logging.Formatter(
- '%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s',
- datefmt='%Y-%m-%d:%H:%M:%S',
- )
- handler.setFormatter(formatter)
- logger.addHandler(handler)
- logger.setLevel(logging.INFO)
-
-
-def get_git_revision_short_hash():
- return (subprocess.check_output(['git', 'rev-parse', '--short',
- 'HEAD']).decode('ascii').strip())
-
-
-def get_git_revision_hash():
- return subprocess.check_output(['git', 'rev-parse',
- 'HEAD']).decode('ascii').strip()
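
The deleted logging helper attaches a FileHandler to the root logger. An equivalent standalone setup, with a duplicate-handler guard added here that the original did not have:

import logging

def logging_to_file(log_file):
    logger = logging.getLogger()
    if any(isinstance(h, logging.FileHandler) for h in logger.handlers):
        return                                       # avoid stacking handlers on repeated calls
    handler = logging.FileHandler(log_file)
    handler.setFormatter(logging.Formatter(
        '%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s',
        datefmt='%Y-%m-%d:%H:%M:%S'))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

logging_to_file('train.log')
logging.info('training started')                     # written to train.log
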
diff --git a/modelscope/models/audio/tts/kantts/utils/plot.py b/modelscope/models/audio/tts/kantts/utils/plot.py
deleted file mode 100644
index c1f2a601..00000000
--- a/modelscope/models/audio/tts/kantts/utils/plot.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import matplotlib
-
-matplotlib.use('Agg')
-try:
- import matplotlib.pyplot as plt
-except ImportError:
- raise ImportError('Please install matplotlib.')
-
-plt.set_loglevel('info')
-
-
-def plot_spectrogram(spectrogram):
- fig, ax = plt.subplots(figsize=(12, 8))
- im = ax.imshow(
- spectrogram, aspect='auto', origin='lower', interpolation='none')
- plt.colorbar(im, ax=ax)
-
- fig.canvas.draw()
- plt.close()
-
- return fig
-
-
-def plot_alignment(alignment, info=None):
- fig, ax = plt.subplots()
- im = ax.imshow(
- alignment, aspect='auto', origin='lower', interpolation='none')
- fig.colorbar(im, ax=ax)
- xlabel = 'Input timestep'
- if info is not None:
- xlabel += '\t' + info
- plt.xlabel(xlabel)
- plt.ylabel('Output timestep')
- fig.canvas.draw()
- plt.close()
-
- return fig
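
For reference, the deleted plot_spectrogram renders a spectrogram to a matplotlib Figure under the Agg backend (so it works headless) and returns the figure after closing it from pyplot. A minimal runnable equivalent:

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

def plot_spectrogram(spectrogram):
    fig, ax = plt.subplots(figsize=(12, 8))
    im = ax.imshow(spectrogram, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im, ax=ax)
    fig.canvas.draw()
    plt.close(fig)                                   # the Figure object stays usable
    return fig

plot_spectrogram(np.random.rand(80, 200)).savefig('spectrogram.png')
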
diff --git a/modelscope/models/audio/tts/sambert_hifi.py b/modelscope/models/audio/tts/sambert_hifi.py
index 0c5da33f..6df9ec97 100644
--- a/modelscope/models/audio/tts/sambert_hifi.py
+++ b/modelscope/models/audio/tts/sambert_hifi.py
@@ -9,13 +9,15 @@ import wave
import zipfile
import json
+import matplotlib.pyplot as plt
import numpy as np
import yaml
from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
-from modelscope.utils.audio.audio_utils import TtsTrainType, ndarray_pcm_to_wav
+from modelscope.utils.audio.audio_utils import (TtsCustomParams, TtsTrainType,
+ ndarray_pcm_to_wav)
from modelscope.utils.audio.tts_exceptions import (
TtsFrontendInitializeFailedException,
TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationException,
@@ -35,74 +37,111 @@ class SambertHifigan(Model):
def __init__(self, model_dir, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
- self.__model_dir = model_dir
- self.__sample_rate = kwargs.get('sample_rate', 16000)
- self.__is_train = False
+ self.model_dir = model_dir
+ self.sample_rate = kwargs.get('sample_rate', 16000)
+ self.is_train = False
if 'is_train' in kwargs:
is_train = kwargs['is_train']
if isinstance(is_train, bool):
- self.__is_train = is_train
+ self.is_train = is_train
+ # check legacy modelcard
+ self.ignore_mask = False
+ if 'am' in kwargs:
+ if 'linguistic_unit' in kwargs['am']:
+ self.ignore_mask = not kwargs['am']['linguistic_unit'].get(
+ 'has_mask', True)
+ self.voices, self.voice_cfg, self.lang_type = self.load_voice(
+ model_dir, kwargs.get('custom_ckpt', {}))
+ if len(self.voices) == 0 or len(self.voice_cfg.get('voices', [])) == 0:
+ raise TtsVoiceNotExistsException('modelscope error: voices empty')
+ if self.voice_cfg['voices']:
+ self.default_voice_name = self.voice_cfg['voices'][0]
+ else:
+ raise TtsVoiceNotExistsException(
+ 'modelscope error: voices is empty in voices.json')
# initialize frontend
import ttsfrd
frontend = ttsfrd.TtsFrontendEngine()
zip_file = os.path.join(model_dir, 'resource.zip')
- self.__res_path = os.path.join(model_dir, 'resource')
+ self.res_path = os.path.join(model_dir, 'resource')
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
zip_ref.extractall(model_dir)
- if not frontend.initialize(self.__res_path):
+ if not frontend.initialize(self.res_path):
raise TtsFrontendInitializeFailedException(
- 'modelscope error: resource invalid: {}'.format(
- self.__res_path))
- if not frontend.set_lang_type(kwargs['lang_type']):
+ 'modelscope error: resource invalid: {}'.format(self.res_path))
+ if not frontend.set_lang_type(self.lang_type):
raise TtsFrontendLanguageTypeInvalidException(
'modelscope error: language type invalid: {}'.format(
- kwargs['lang_type']))
- self.__frontend = frontend
- self.__voices, self.__voice_cfg = self.load_voice(model_dir)
- if len(self.__voices) == 0 or len(self.__voice_cfg) == 0:
- raise TtsVoiceNotExistsException('modelscope error: voices empty')
- if self.__voice_cfg['voices']:
- self.__default_voice_name = self.__voice_cfg['voices'][0]
- else:
- raise TtsVoiceNotExistsException(
- 'modelscope error: voices is empty in voices.json')
+ self.lang_type))
+ self.frontend = frontend
- def load_voice(self, model_dir):
+ def build_voice_from_custom(self, model_dir, custom_ckpt):
+ necessary_files = (TtsCustomParams.VOICE_NAME, TtsCustomParams.AM_CKPT,
+ TtsCustomParams.VOC_CKPT, TtsCustomParams.AM_CONFIG,
+ TtsCustomParams.VOC_CONFIG)
+ voices = {}
+ voices_cfg = {}
+ lang_type = 'PinYin'
+ for k in necessary_files:
+ if k not in custom_ckpt:
+ raise TtsModelNotExistsException(
+ f'custom ckpt must have: {necessary_files}')
+ voice_name = custom_ckpt[TtsCustomParams.VOICE_NAME]
+ voice = Voice(
+ voice_name=voice_name,
+ voice_path=model_dir,
+ custom_ckpt=custom_ckpt,
+ ignore_mask=self.ignore_mask,
+ is_train=self.is_train)
+ voices[voice_name] = voice
+ voices_cfg['voices'] = [voice_name]
+ lang_type = voice.lang_type
+ return voices, voices_cfg, lang_type
+
+ def load_voice(self, model_dir, custom_ckpt):
voices = {}
voices_path = os.path.join(model_dir, 'voices')
voices_json_path = os.path.join(voices_path, 'voices.json')
+ lang_type = 'PinYin'
+ if len(custom_ckpt) != 0:
+ return self.build_voice_from_custom(model_dir, custom_ckpt)
if not os.path.exists(voices_path) or not os.path.exists(
voices_json_path):
- return voices, []
+ return voices, {}, lang_type
with open(voices_json_path, 'r', encoding='utf-8') as f:
voice_cfg = json.load(f)
if 'voices' not in voice_cfg:
- return voices, []
+ return voices, {}, lang_type
for name in voice_cfg['voices']:
voice_path = os.path.join(voices_path, name)
if not os.path.exists(voice_path):
continue
- voices[name] = Voice(name, voice_path)
- return voices, voice_cfg
+ voices[name] = Voice(
+ name,
+ voice_path,
+ ignore_mask=self.ignore_mask,
+ is_train=self.is_train)
+ lang_type = voices[name].lang_type
+ return voices, voice_cfg, lang_type
def save_voices(self):
- voices_json_path = os.path.join(self.__model_dir, 'voices',
+ voices_json_path = os.path.join(self.model_dir, 'voices',
'voices.json')
if os.path.exists(voices_json_path):
os.remove(voices_json_path)
save_voices = {}
save_voices['voices'] = []
- for k in self.__voices.keys():
+ for k in self.voices.keys():
save_voices['voices'].append(k)
with open(voices_json_path, 'w', encoding='utf-8') as f:
json.dump(save_voices, f)
def get_voices(self):
- return self.__voices, self.__voice_cfg
+ return self.voices, self.voice_cfg
def create_empty_voice(self, voice_name, audio_config, am_config_path,
voc_config_path):
- voice_name_path = os.path.join(self.__model_dir, 'voices', voice_name)
+ voice_name_path = os.path.join(self.model_dir, 'voices', voice_name)
if os.path.exists(voice_name_path):
shutil.rmtree(voice_name_path)
os.makedirs(voice_name_path, exist_ok=True)
@@ -123,63 +162,76 @@ class SambertHifigan(Model):
voc_ckpt_path = os.path.join(voice_voc_path, 'ckpt')
os.makedirs(am_ckpt_path, exist_ok=True)
os.makedirs(voc_ckpt_path, exist_ok=True)
- self.__voices[voice_name] = Voice(
+ self.voices[voice_name] = Voice(
voice_name=voice_name,
voice_path=voice_name_path,
allow_empty=True)
def get_voice_audio_config_path(self, voice):
- if voice not in self.__voices:
+ if voice not in self.voices:
+ return ''
+ return self.voices[voice].audio_config
+
+ def get_voice_se_model_path(self, voice):
+ if voice not in self.voices:
+ return ''
+ if self.voices[voice].se_enable:
+ return self.voices[voice].se_model_path
+ else:
return ''
- return self.__voices[voice].audio_config
def get_voice_lang_path(self, voice):
- if voice not in self.__voices:
+ if voice not in self.voices:
return ''
- return self.__voices[voice].lang_dir
+ return self.voices[voice].lang_dir
- def __synthesis_one_sentences(self, voice_name, text):
- if voice_name not in self.__voices:
+ def synthesis_one_sentences(self, voice_name, text):
+ if voice_name not in self.voices:
raise TtsVoiceNotExistsException(
f'modelscope error: Voice {voice_name} not exists')
- return self.__voices[voice_name].forward(text)
+ return self.voices[voice_name].forward(text)
def train(self,
voice,
dirs,
train_type,
- configs_path=None,
+ configs_path_dict=None,
ignore_pretrain=False,
create_if_not_exists=False,
hparam=None):
+ plt.set_loglevel('info')
work_dir = dirs['work_dir']
am_dir = dirs['am_tmp_dir']
voc_dir = dirs['voc_tmp_dir']
data_dir = dirs['data_dir']
-
- if voice not in self.__voices:
+ target_voice = None
+ if voice not in self.voices:
if not create_if_not_exists:
raise TtsVoiceNotExistsException(
f'modelscope error: Voice {voice_name} not exists')
- am_config = configs_path.get('am_config', None)
- voc_config = configs_path.get('voc_config', None)
+ am_config_path = configs_path_dict.get('am_config',
+ 'am_config.yaml')
+ voc_config_path = configs_path_dict.get('voc_config',
+ 'voc_config.yaml')
if TtsTrainType.TRAIN_TYPE_SAMBERT in train_type and not am_config:
raise TtsTrainingCfgNotExistsException(
'training new voice am with empty am_config')
if TtsTrainType.TRAIN_TYPE_VOC in train_type and not voc_config:
raise TtsTrainingCfgNotExistsException(
'training new voice voc with empty voc_config')
-
- target_voice = self.__voices[voice]
- am_config_path = target_voice.am_config
- voc_config_path = target_voice.voc_config
- if not configs_path:
- am_config = configs_path.get('am_config', None)
- if am_config:
- am_config_path = am_config
- voc_config = configs_path.get('voc_config', None)
- if voc_config:
- voc_config_path = voc_config
+ else:
+ target_voice = self.voices[voice]
+ am_config_path = target_voice.am_config_path
+ voc_config_path = target_voice.voc_config_path
+ if configs_path_dict:
+ if 'am_config' in configs_path_dict:
+ am_override = configs_path_dict['am_config']
+ if os.path.exists(am_override):
+ am_config_path = am_override
+ if 'voc_config' in configs_path_dict:
+ voc_override = configs_path_dict['voc_config']
+ if os.path.exists(voc_override):
+ voc_config_path = voc_override
logger.info('Start training....')
if TtsTrainType.TRAIN_TYPE_SAMBERT in train_type:
@@ -209,15 +261,15 @@ class SambertHifigan(Model):
logger.info('skip HIFIGAN training...')
def forward(self, text: str, voice_name: str = None):
- voice = self.__default_voice_name
+ voice = self.default_voice_name
if voice_name is not None:
voice = voice_name
- result = self.__frontend.gen_tacotron_symbols(text)
+ result = self.frontend.gen_tacotron_symbols(text)
texts = [s for s in result.splitlines() if s != '']
audio_total = np.empty((0), dtype='int16')
for line in texts:
line = line.strip().split('\t')
- audio = self.__synthesis_one_sentences(voice, line[1])
+ audio = self.synthesis_one_sentences(voice, line[1])
audio = 32768.0 * audio
audio_total = np.append(audio_total, audio.astype('int16'), axis=0)
- return ndarray_pcm_to_wav(self.__sample_rate, audio_total)
+ return ndarray_pcm_to_wav(self.sample_rate, audio_total)
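
The new custom_ckpt path lets SambertHifigan be built from user-supplied checkpoints instead of the packaged voices/ directory; relative paths are resolved against model_dir, as Voice.__init__ below shows. A hedged usage sketch — the file names are placeholders, and running it still needs the ttsfrd frontend plus the model's resource.zip:

from modelscope.models.audio.tts.sambert_hifi import SambertHifigan
from modelscope.utils.audio.audio_utils import TtsCustomParams

custom_ckpt = {
    TtsCustomParams.VOICE_NAME: 'my_voice',
    TtsCustomParams.AM_CKPT: 'am/ckpt/checkpoint_2400000.pth',    # placeholder paths,
    TtsCustomParams.VOC_CKPT: 'voc/ckpt/checkpoint_2400000.pth',  # relative to model_dir
    TtsCustomParams.AM_CONFIG: 'am/config.yaml',
    TtsCustomParams.VOC_CONFIG: 'voc/config.yaml',
}
model = SambertHifigan('/path/to/model_dir', custom_ckpt=custom_ckpt)
wav_data = model.forward('今天天气怎么样')            # WAV data at the configured sample_rate
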
diff --git a/modelscope/models/audio/tts/voice.py b/modelscope/models/audio/tts/voice.py
index b7b91a9e..645a528f 100644
--- a/modelscope/models/audio/tts/voice.py
+++ b/modelscope/models/audio/tts/voice.py
@@ -10,18 +10,20 @@ import json
import numpy as np
import torch
import yaml
+from kantts.datasets.dataset import get_am_datasets, get_voc_datasets
+from kantts.models import model_builder
+from kantts.train.loss import criterion_builder
+from kantts.train.trainer import GAN_Trainer, Sambert_Trainer, distributed_init
+from kantts.utils.ling_unit.ling_unit import KanTtsLinguisticUnit
from torch.utils.data import DataLoader
from modelscope import __version__
+from modelscope.utils.audio.audio_utils import TtsCustomParams
from modelscope.utils.audio.tts_exceptions import (
TtsModelConfigurationException, TtsModelNotExistsException)
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
-from modelscope.models.audio.tts.kantts import ( # isort:skip; isort:skip
- GAN_Trainer, Generator, KanTtsLinguisticUnit, Sambert_Trainer,
- criterion_builder, get_am_datasets, get_voc_datasets, model_builder)
-
logger = get_logger()
@@ -29,59 +31,201 @@ def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
+def denorm_f0(mel,
+ f0_threshold=30,
+ uv_threshold=0.6,
+ norm_type='mean_std',
+ f0_feature=None):
+ if norm_type == 'mean_std':
+ f0_mvn = f0_feature
+
+ f0 = mel[:, -2]
+ uv = mel[:, -1]
+
+ uv[uv < uv_threshold] = 0.0
+ uv[uv >= uv_threshold] = 1.0
+
+ f0 = f0 * f0_mvn[1:, :] + f0_mvn[0:1, :]
+ f0[f0 < f0_threshold] = f0_threshold
+
+ mel[:, -2] = f0
+ mel[:, -1] = uv
+ else: # global
+ f0_global_max_min = f0_feature
+
+ f0 = mel[:, -2]
+ uv = mel[:, -1]
+
+ uv[uv < uv_threshold] = 0.0
+ uv[uv >= uv_threshold] = 1.0
+
+ f0 = f0 * (f0_global_max_min[0]
+ - f0_global_max_min[1]) + f0_global_max_min[1]
+ f0[f0 < f0_threshold] = f0_threshold
+
+ mel[:, -2] = f0
+ mel[:, -1] = uv
+
+ return mel
+
+
+def binarize(mel, threshold=0.6):
+ # vuv binarize
+ res_mel = mel.clone()
+ index = torch.where(mel[:, -1] < threshold)[0]
+ res_mel[:, -1] = 1.0
+ res_mel[:, -1][index] = 0.0
+ return res_mel
+
+
class Voice:
- def __init__(self, voice_name, voice_path, allow_empty=False):
- self.__voice_name = voice_name
- self.__voice_path = voice_path
- self.distributed = False
- self.local_rank = 0
- am_config_path = os.path.join(
- os.path.join(voice_path, 'am'), 'config.yaml')
- voc_config_path = os.path.join(
- os.path.join(voice_path, 'voc'), 'config.yaml')
+ def __init__(self,
+ voice_name,
+ voice_path=None,
+ custom_ckpt={},
+ ignore_mask=True,
+ is_train=False):
+ self.voice_name = voice_name
+ self.voice_path = voice_path
+ self.ignore_mask = ignore_mask
+ self.is_train = is_train
+ if not torch.cuda.is_available():
+ self.device = torch.device('cpu')
+ self.distributed = False
+ else:
+ torch.backends.cudnn.benchmark = True
+ self.distributed, self.device, self.local_rank, self.world_size = distributed_init(
+ )
- self.audio_config = os.path.join(voice_path, 'audio_config.yaml')
- self.lang_dir = os.path.join(voice_path, 'dict')
- self.am_config = am_config_path
- self.voc_config = voc_config_path
+ if len(custom_ckpt) != 0:
+ self.am_config_path = custom_ckpt[TtsCustomParams.AM_CONFIG]
+ self.voc_config_path = custom_ckpt[TtsCustomParams.VOC_CONFIG]
+ if not os.path.isabs(self.am_config_path):
+ self.am_config_path = os.path.join(voice_path,
+ self.am_config_path)
+ if not os.path.isabs(self.voc_config_path):
+ self.voc_config_path = os.path.join(voice_path,
+ self.voc_config_path)
+ am_ckpt = custom_ckpt[TtsCustomParams.AM_CKPT]
+ voc_ckpt = custom_ckpt[TtsCustomParams.VOC_CKPT]
+ if not os.path.isabs(am_ckpt):
+ am_ckpt = os.path.join(voice_path, am_ckpt)
+ if not os.path.isabs(voc_ckpt):
+ voc_ckpt = os.path.join(voice_path, voc_ckpt)
+ self.am_ckpts = self.scan_ckpt(am_ckpt)
+ self.voc_ckpts = self.scan_ckpt(voc_ckpt)
+ self.se_path = custom_ckpt.get(TtsCustomParams.SE_FILE, 'se.npy')
+ if not os.path.isabs(self.se_path):
+ self.se_path = os.path.join(voice_path, self.se_path)
+ self.se_model_path = custom_ckpt.get(TtsCustomParams.SE_MODEL,
+ 'se.onnx')
+ if not os.path.isabs(self.se_model_path):
+ self.se_model_path = os.path.join(voice_path,
+ self.se_model_path)
+ self.audio_config = custom_ckpt.get(TtsCustomParams.AUIDO_CONFIG,
+ 'audio_config.yaml')
+ if not os.path.isabs(self.audio_config):
+ self.audio_config = os.path.join(voice_path, self.audio_config)
+ self.mvn_path = custom_ckpt.get(TtsCustomParams.MVN_FILE,
+ 'mvn.npy')
+ if not os.path.isabs(self.mvn_path):
+ self.mvn_path = os.path.join(voice_path, self.mvn_path)
+ else:
+ self.audio_config = os.path.join(voice_path, 'audio_config.yaml')
+ self.am_config_path = os.path.join(voice_path, 'am', 'config.yaml')
+ self.voc_config_path = os.path.join(voice_path, 'voc',
+ 'config.yaml')
- am_ckpt = os.path.join(os.path.join(voice_path, 'am'), 'ckpt')
- voc_ckpt = os.path.join(os.path.join(voice_path, 'voc'), 'ckpt')
+ self.se_path = os.path.join(voice_path, 'am', 'se.npy')
+ self.am_ckpts = self.scan_ckpt(
+ os.path.join(voice_path, 'am', 'ckpt'))
+ self.voc_ckpts = self.scan_ckpt(
+ os.path.join(voice_path, 'voc', 'ckpt'))
+ self.mvn_path = os.path.join(voice_path, 'am', 'mvn.npy')
+ self.se_model_path = os.path.join(voice_path, 'se', 'ckpt',
+ 'se.onnx')
- self.__am_ckpts = self.scan_ckpt(am_ckpt)
- self.__voc_ckpts = self.scan_ckpt(voc_ckpt)
+ logger.info(
+ f'am_config={self.am_config_path} voc_config={self.voc_config_path}'
+ )
+ logger.info(f'audio_config={self.audio_config}')
+ logger.info(f'am_ckpts={self.am_ckpts}')
+ logger.info(f'voc_ckpts={self.voc_ckpts}')
+ logger.info(
+ f'se_path={self.se_path} se_model_path={self.se_model_path}')
+ logger.info(f'mvn_path={self.mvn_path}')
- if not os.path.exists(am_config_path):
+ if not os.path.exists(self.am_config_path):
raise TtsModelConfigurationException(
'modelscope error: am configuration not found')
- if not os.path.exists(voc_config_path):
+ if not os.path.exists(self.voc_config_path):
raise TtsModelConfigurationException(
'modelscope error: voc configuration not found')
- if not allow_empty:
- if len(self.__am_ckpts) == 0:
- raise TtsModelNotExistsException(
- 'modelscope error: am model file not found')
- if len(self.__voc_ckpts) == 0:
- raise TtsModelNotExistsException(
- 'modelscope error: voc model file not found')
- with open(am_config_path, 'r') as f:
- self.__am_config = yaml.load(f, Loader=yaml.Loader)
- with open(voc_config_path, 'r') as f:
- self.__voc_config = yaml.load(f, Loader=yaml.Loader)
- self.__model_loaded = False
- self.__lock = Lock()
- self.__ling_unit = KanTtsLinguisticUnit(self.__am_config,
- self.lang_dir)
- self.__ling_unit_size = self.__ling_unit.get_unit_size()
- self.__am_config['Model']['KanTtsSAMBERT']['params'].update(
- self.__ling_unit_size)
- if torch.cuda.is_available():
- self.__device = torch.device('cuda')
- else:
- self.__device = torch.device('cpu')
+ if len(self.am_ckpts) == 0:
+ raise TtsModelNotExistsException(
+ 'modelscope error: am model file not found')
+ if len(self.voc_ckpts) == 0:
+ raise TtsModelNotExistsException(
+ 'modelscope error: voc model file not found')
+ with open(self.am_config_path, 'r') as f:
+ self.am_config = yaml.load(f, Loader=yaml.Loader)
+ with open(self.voc_config_path, 'r') as f:
+ self.voc_config = yaml.load(f, Loader=yaml.Loader)
+ if 'linguistic_unit' not in self.am_config:
+ raise TtsModelConfigurationException(
+ 'no linguistic_unit in am config')
+ self.lang_type = self.am_config['linguistic_unit'].get(
+ 'language', 'PinYin')
+ self.model_loaded = False
+ self.lock = Lock()
+ self.ling_unit = KanTtsLinguisticUnit(self.am_config)
+ self.ling_unit_size = self.ling_unit.get_unit_size()
+ if self.ignore_mask:
+ target_set = set(('sy', 'tone', 'syllable_flag', 'word_segment',
+ 'emotion', 'speaker'))
+ for k, v in self.ling_unit_size.items():
+ if k in target_set:
+ self.ling_unit_size[k] = v - 1
+
+ self.am_config['Model']['KanTtsSAMBERT']['params'].update(
+ self.ling_unit_size)
+
+ self.se_enable = self.am_config['Model']['KanTtsSAMBERT'][
+ 'params'].get('SE', False)
+ if self.se_enable and not self.is_train:
+ if not os.path.exists(self.se_path):
+ raise TtsModelConfigurationException(
+ f'se enabled but se_file:{self.se_path} not exists')
+ self.se = np.load(self.se_path)
+
+ self.nsf_enable = self.am_config['Model']['KanTtsSAMBERT'][
+ 'params'].get('NSF', False)
+ if self.nsf_enable and not self.is_train:
+ self.nsf_norm_type = self.am_config['Model']['KanTtsSAMBERT'][
+ 'params'].get('nsf_norm_type', 'mean_std')
+ if self.nsf_norm_type == 'mean_std':
+ if not os.path.exists(self.mvn_path):
+ raise TtsModelNotExistsException(
+ f'f0_mvn_file: {self.mvn_path} not exists')
+ self.f0_feature = np.load(self.mvn_path)
+ else: # global
+ nsf_f0_global_minimum = self.am_config['Model'][
+ 'KanTtsSAMBERT']['params'].get('nsf_f0_global_minimum',
+ 30.0)
+ nsf_f0_global_maximum = self.am_config['Model'][
+ 'KanTtsSAMBERT']['params'].get('nsf_f0_global_maximum',
+ 730.0)
+ self.f0_feature = [
+ nsf_f0_global_maximum, nsf_f0_global_minimum
+ ]
def scan_ckpt(self, ckpt_path):
+ select_target = ckpt_path
+ input_not_dir = False
+ if not os.path.isdir(ckpt_path):
+ input_not_dir = True
+ ckpt_path = os.path.dirname(ckpt_path)
filelist = os.listdir(ckpt_path)
if len(filelist) == 0:
return {}
@@ -94,66 +238,68 @@ class Voice:
filename_prefix = filename.split('.')[0]
idx = int(filename_prefix.split('_')[-1])
path = os.path.join(ckpt_path, filename)
+ if input_not_dir and path != select_target:
+ continue
ckpts[idx] = path
od = OrderedDict(sorted(ckpts.items()))
return od
- def __load_am(self):
- self.__am_model, _, _ = model_builder(self.__am_config, self.__device)
- self.__am = self.__am_model['KanTtsSAMBERT']
+ def load_am(self):
+ self.am_model, _, _ = model_builder(self.am_config, self.device)
+ self.am = self.am_model['KanTtsSAMBERT']
state_dict = torch.load(
- self.__am_ckpts[next(reversed(self.__am_ckpts))],
- map_location=self.__device)
- self.__am.load_state_dict(state_dict['model'], strict=False)
- self.__am.eval()
+ self.am_ckpts[next(reversed(self.am_ckpts))],
+ map_location=self.device)
+ self.am.load_state_dict(state_dict['model'], strict=False)
+ self.am.eval()
- def __load_vocoder(self):
- self.__voc_model = Generator(
- **self.__voc_config['Model']['Generator']['params'])
+ def load_vocoder(self):
+ from kantts.models.hifigan.hifigan import Generator
+ self.voc_model = Generator(
+ **self.voc_config['Model']['Generator']['params'])
states = torch.load(
- self.__voc_ckpts[next(reversed(self.__voc_ckpts))],
- map_location=self.__device)
- self.__voc_model.load_state_dict(states['model']['generator'])
- if self.__voc_config['Model']['Generator']['params'][
- 'out_channels'] > 1:
- from .kantts.models.pqmf import PQMF
- self.__voc_model = PQMF()
- self.__voc_model.remove_weight_norm()
- self.__voc_model.eval().to(self.__device)
+ self.voc_ckpts[next(reversed(self.voc_ckpts))],
+ map_location=self.device)
+ self.voc_model.load_state_dict(states['model']['generator'])
+ if self.voc_config['Model']['Generator']['params']['out_channels'] > 1:
+ from kantts.models.pqmf import PQMF
+ self.voc_model = PQMF()
+ self.voc_model.remove_weight_norm()
+ self.voc_model.eval().to(self.device)
- def __am_forward(self, symbol_seq):
- with self.__lock:
+ def am_forward(self, symbol_seq):
+ with self.lock:
with torch.no_grad():
- inputs_feat_lst = self.__ling_unit.encode_symbol_sequence(
+ inputs_feat_lst = self.ling_unit.encode_symbol_sequence(
symbol_seq)
inputs_feat_index = 0
- if self.__ling_unit.using_byte():
+ if self.ling_unit.using_byte():
inputs_byte_index = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
- self.__device))
+ self.device))
inputs_ling = torch.stack([inputs_byte_index],
dim=-1).unsqueeze(0)
else:
inputs_sy = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
- self.__device))
+ self.device))
inputs_feat_index = inputs_feat_index + 1
inputs_tone = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
- self.__device))
+ self.device))
inputs_feat_index = inputs_feat_index + 1
inputs_syllable = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
- self.__device))
+ self.device))
inputs_feat_index = inputs_feat_index + 1
inputs_ws = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
- self.__device))
+ self.device))
inputs_ling = torch.stack(
[inputs_sy, inputs_tone, inputs_syllable, inputs_ws],
dim=-1).unsqueeze(0)
@@ -161,39 +307,44 @@ class Voice:
inputs_emo = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
- self.__device).unsqueeze(0))
+ self.device).unsqueeze(0))
inputs_feat_index = inputs_feat_index + 1
- inputs_spk = (
- torch.from_numpy(
- inputs_feat_lst[inputs_feat_index]).long().to(
- self.__device).unsqueeze(0))
- inputs_len = (torch.zeros(1).to(self.__device).long()
+ if self.se_enable:
+ inputs_spk = (
+ torch.from_numpy(
+ self.se.repeat(
+ len(inputs_feat_lst[inputs_feat_index]),
+ axis=0)).float().to(
+ self.device).unsqueeze(0)[:, :-1, :])
+ else:
+ inputs_spk = (
+ torch.from_numpy(
+ inputs_feat_lst[inputs_feat_index]).long().to(
+ self.device).unsqueeze(0)[:, :-1])
+ inputs_len = (torch.zeros(1).to(self.device).long()
+ inputs_emo.size(1) - 1) # minus 1 for "~"
- res = self.__am(inputs_ling[:, :-1, :], inputs_emo[:, :-1],
- inputs_spk[:, :-1], inputs_len)
+ res = self.am(inputs_ling[:, :-1, :], inputs_emo[:, :-1],
+ inputs_spk, inputs_len)
postnet_outputs = res['postnet_outputs']
LR_length_rounded = res['LR_length_rounded']
valid_length = int(LR_length_rounded[0].item())
- postnet_outputs = postnet_outputs[0, :valid_length, :].cpu()
- return postnet_outputs
+ mel_post = postnet_outputs[0, :valid_length, :].cpu()
+ if self.nsf_enable:
+ mel_post = denorm_f0(
+ mel_post,
+ norm_type=self.nsf_norm_type,
+ f0_feature=self.f0_feature)
+ return mel_post
- def __binarize(mel, threshold=0.6):
- # vuv binarize
- res_mel = mel.clone()
- index = torch.where(mel[:, -1] < threshold)[0]
- res_mel[:, -1] = 1.0
- res_mel[:, -1][index] = 0.0
- return res_mel
-
- def __vocoder_forward(self, melspec):
+ def vocoder_forward(self, melspec):
with torch.no_grad():
- x = melspec.to(self.__device)
- if self.__voc_model.nsf_enable:
- x = self.__binarize(x)
+ x = melspec.to(self.device)
+ if self.voc_model.nsf_enable:
+ x = binarize(x)
x = x.transpose(1, 0).unsqueeze(0)
- y = self.__voc_model(x)
- if hasattr(self.__voc_model, 'pqmf'):
- y = self.__voc_model.synthesis(y)
+ y = self.voc_model(x)
+ if hasattr(self.voc_model, 'pqmf'):
+ y = self.voc_model.synthesis(y)
y = y.view(-1).cpu().numpy()
return y
@@ -205,7 +356,7 @@ class Voice:
ignore_pretrain=False,
hparams=dict()):
logger.info('TRAIN SAMBERT....')
- if len(self.__am_ckpts) == 0:
+ if len(self.am_ckpts) == 0:
raise TtsTrainingInvalidModelException(
'resume pretrain but model is empty')
@@ -218,24 +369,23 @@ class Voice:
with open(self.audio_config, 'r') as f:
config = yaml.load(f, Loader=yaml.Loader)
-
with open(config_path, 'r') as f:
config.update(yaml.load(f, Loader=yaml.Loader))
config.update(hparams)
resume_from = None
if from_latest:
- from_steps = next(reversed(self.__am_ckpts))
- resume_from = self.__am_ckpts[from_steps]
+ from_steps = next(reversed(self.am_ckpts))
+ resume_from = self.am_ckpts[from_steps]
if not os.path.exists(resume_from):
raise TtsTrainingInvalidModelException(
f'latest model:{resume_from} not exists')
else:
- if from_steps not in self.__am_ckpts:
+ if from_steps not in self.am_ckpts:
raise TtsTrainingInvalidModelException(
f'no such model from steps:{from_steps}')
else:
- resume_from = self.__am_ckpts[from_steps]
+ resume_from = self.am_ckpts[from_steps]
if train_steps > 0:
train_max_steps = train_steps + from_steps
@@ -252,6 +402,17 @@ class Voice:
for key, value in config.items():
logger.info(f'{key} = {value}')
+ if self.distributed:
+ config['rank'] = torch.distributed.get_rank()
+ config['distributed'] = True
+
+ if self.se_enable:
+ valid_enable = False
+ valid_split_ratio = 0.00
+ else:
+ valid_enable = True
+ valid_split_ratio = 0.02
+
fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False)
meta_file = [
os.path.join(
@@ -260,15 +421,33 @@ class Voice:
for d in data_dir
]
- train_dataset, valid_dataset = get_am_datasets(meta_file, data_dir,
- self.lang_dir, config,
- config['allow_cache'])
+ train_dataset, valid_dataset = get_am_datasets(
+ meta_file,
+ data_dir,
+ config,
+ config['allow_cache'],
+ split_ratio=1.0 - valid_split_ratio)
logger.info(f'The number of training files = {len(train_dataset)}.')
logger.info(f'The number of validation files = {len(valid_dataset)}.')
sampler = {'train': None, 'valid': None}
+ if self.distributed:
+ # setup sampler for distributed training
+ from torch.utils.data.distributed import DistributedSampler
+
+ sampler['train'] = DistributedSampler(
+ dataset=train_dataset,
+ num_replicas=self.world_size,
+ shuffle=True,
+ )
+ sampler['valid'] = DistributedSampler(
+ dataset=valid_dataset,
+ num_replicas=self.world_size,
+ shuffle=False,
+ ) if valid_enable else None
+
train_dataloader = DataLoader(
train_dataset,
shuffle=False if self.distributed else True,
@@ -287,16 +466,16 @@ class Voice:
num_workers=config['num_workers'],
sampler=sampler['valid'],
pin_memory=config['pin_memory'],
- )
+ ) if valid_enable else None
ling_unit_size = train_dataset.ling_unit.get_unit_size()
config['Model']['KanTtsSAMBERT']['params'].update(ling_unit_size)
- model, optimizer, scheduler = model_builder(config, self.__device,
+ model, optimizer, scheduler = model_builder(config, self.device,
self.local_rank,
self.distributed)
- criterion = criterion_builder(config, self.__device)
+ criterion = criterion_builder(config, self.device)
trainer = Sambert_Trainer(
config=config,
@@ -304,7 +483,7 @@ class Voice:
optimizer=optimizer,
scheduler=scheduler,
criterion=criterion,
- device=self.__device,
+ device=self.device,
sampler=sampler,
train_loader=train_dataloader,
valid_loader=valid_dataloader,
@@ -339,7 +518,7 @@ class Voice:
ignore_pretrain=False,
hparams=dict()):
logger.info('TRAIN HIFIGAN....')
- if len(self.__voc_ckpts) == 0:
+ if len(self.voc_ckpts) == 0:
raise TtsTrainingInvalidModelException(
'resume pretrain but model is empty')
@@ -359,17 +538,17 @@ class Voice:
resume_from = None
if from_latest:
- from_steps = next(reversed(self.__voc_ckpts))
- resume_from = self.__voc_ckpts[from_steps]
+ from_steps = next(reversed(self.voc_ckpts))
+ resume_from = self.voc_ckpts[from_steps]
if not os.path.exists(resume_from):
raise TtsTrainingInvalidModelException(
f'latest model:{resume_from} not exists')
else:
- if from_steps not in self.__voc_ckpts:
+ if from_steps not in self.voc_ckpts:
raise TtsTrainingInvalidModelException(
f'no such model from steps:{from_steps}')
else:
- resume_from = self.__voc_ckpts[from_steps]
+ resume_from = self.voc_ckpts[from_steps]
if train_steps > 0:
train_max_steps = train_steps
@@ -393,6 +572,20 @@ class Voice:
logger.info(f'The number of validation files = {len(valid_dataset)}.')
sampler = {'train': None, 'valid': None}
+ if self.distributed:
+ # setup sampler for distributed training
+ from torch.utils.data.distributed import DistributedSampler
+
+ sampler['train'] = DistributedSampler(
+ dataset=train_dataset,
+ num_replicas=self.world_size,
+ shuffle=True,
+ )
+ sampler['valid'] = DistributedSampler(
+ dataset=valid_dataset,
+ num_replicas=self.world_size,
+ shuffle=False,
+ )
train_dataloader = DataLoader(
train_dataset,
@@ -414,18 +607,18 @@ class Voice:
pin_memory=config['pin_memory'],
)
- model, optimizer, scheduler = model_builder(config, self.__device,
+ model, optimizer, scheduler = model_builder(config, self.device,
self.local_rank,
self.distributed)
- criterion = criterion_builder(config, self.__device)
+ criterion = criterion_builder(config, self.device)
trainer = GAN_Trainer(
config=config,
model=model,
optimizer=optimizer,
scheduler=scheduler,
criterion=criterion,
- device=self.__device,
+ device=self.device,
sampler=sampler,
train_loader=train_dataloader,
valid_loader=valid_dataloader,
@@ -452,9 +645,9 @@ class Voice:
f'Successfully saved checkpoint @ {trainer.steps}steps.')
def forward(self, symbol_seq):
- with self.__lock:
- if not self.__model_loaded:
- self.__load_am()
- self.__load_vocoder()
- self.__model_loaded = True
- return self.__vocoder_forward(self.__am_forward(symbol_seq))
+ with self.lock:
+ if not self.model_loaded:
+ self.load_am()
+ self.load_vocoder()
+ self.model_loaded = True
+ return self.vocoder_forward(self.am_forward(symbol_seq))
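
When NSF is enabled, am_forward above denormalizes the last two mel channels (F0 and a voiced/unvoiced probability) with denorm_f0, and vocoder_forward then binarizes the V/UV channel via binarize(). A self-contained toy check of that thresholding (the 82-dim mel width is only an example):

import torch

def binarize(mel, threshold=0.6):
    # same logic as the module-level helper above: clamp the V/UV channel to {0, 1}
    res_mel = mel.clone()
    index = torch.where(mel[:, -1] < threshold)[0]
    res_mel[:, -1] = 1.0
    res_mel[:, -1][index] = 0.0
    return res_mel

mel = torch.zeros(4, 82)
mel[:, -1] = torch.tensor([0.10, 0.59, 0.60, 0.90])
print(binarize(mel)[:, -1])                          # tensor([0., 0., 1., 1.])
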
diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py
index 18855829..0edb740e 100644
--- a/modelscope/models/base/base_model.py
+++ b/modelscope/models/base/base_model.py
@@ -12,6 +12,8 @@ from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
from modelscope.utils.device import verify_device
from modelscope.utils.logger import get_logger
+from modelscope.utils.plugins import (register_modelhub_repo,
+ register_plugins_repo)
logger = get_logger()
@@ -126,6 +128,11 @@ class Model(ABC):
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
model_cfg.type = model_cfg.model_type
model_cfg.model_dir = local_model_dir
+
+ # install and import remote repos before build
+ register_plugins_repo(cfg.safe_get('plugins'))
+ register_modelhub_repo(local_model_dir, cfg.get('allow_remote', False))
+
for k, v in kwargs.items():
model_cfg[k] = v
if device is not None:
diff --git a/modelscope/models/base/base_torch_model.py b/modelscope/models/base/base_torch_model.py
index b358c944..2caeb41b 100644
--- a/modelscope/models/base/base_torch_model.py
+++ b/modelscope/models/base/base_torch_model.py
@@ -6,6 +6,7 @@ from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
import torch
+from packaging import version
from torch import nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
@@ -128,3 +129,19 @@ class TorchModel(Model, torch.nn.Module):
if config is not None:
save_config_function(target_folder, config)
+
+ def compile(self, **kwargs):
+ """Compile torch model with torch>=2.0
+
+ Args:
+ kwargs:
+ backend: The backend param of torch.compile
+ mode: The mode param of torch.compile
+ """
+ if version.parse(torch.__version__) >= version.parse('2.0.0.dev'):
+ return torch.compile(self, **kwargs)
+ else:
+ logger.warning(
+ f'torch.compile requires torch >= 2.0.0, but the installed torch version is {torch.__version__}; '
+ f'returning the original model')
+ return self
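
A short usage sketch of the new TorchModel.compile wrapper. TinyModel is a hypothetical subclass; on torch >= 2.0 the call returns the torch.compile-wrapped module (actual compilation happens on first use and needs a working inductor toolchain), and on older torch it simply returns the model, as the warning branch states:

import torch
from modelscope.models.base import TorchModel

class TinyModel(TorchModel):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

model = TinyModel().compile()
print(model(torch.randn(1, 4)).shape)                # torch.Size([1, 2])
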
diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py
index fdb8801a..3c9ea753 100644
--- a/modelscope/models/cv/__init__.py
+++ b/modelscope/models/cv/__init__.py
@@ -5,26 +5,27 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
body_2d_keypoints, body_3d_keypoints, cartoon,
cmdssl_video_embedding, controllable_image_generation,
crowd_counting, face_2d_keypoints, face_detection,
- face_generation, face_reconstruction, human_wholebody_keypoint,
- image_classification, image_color_enhance, image_colorization,
- image_defrcn_fewshot, image_denoise, image_inpainting,
- image_instance_segmentation, image_matching,
- image_mvs_depth_estimation, image_panoptic_segmentation,
- image_portrait_enhancement, image_probing_model,
- image_quality_assessment_degradation,
- image_quality_assessment_mos, image_reid_person,
- image_restoration, image_semantic_segmentation,
- image_to_image_generation, image_to_image_translation,
- language_guided_video_summarization, movie_scene_segmentation,
- object_detection, panorama_depth_estimation,
- pointcloud_sceneflow_estimation, product_retrieval_embedding,
+ face_generation, face_reconstruction, human_reconstruction,
+ human_wholebody_keypoint, image_classification,
+ image_color_enhance, image_colorization, image_defrcn_fewshot,
+ image_denoise, image_inpainting, image_instance_segmentation,
+ image_matching, image_mvs_depth_estimation,
+ image_panoptic_segmentation, image_portrait_enhancement,
+ image_probing_model, image_quality_assessment_degradation,
+ image_quality_assessment_man, image_quality_assessment_mos,
+ image_reid_person, image_restoration,
+ image_semantic_segmentation, image_to_image_generation,
+ image_to_image_translation, language_guided_video_summarization,
+ movie_scene_segmentation, object_detection,
+ panorama_depth_estimation, pointcloud_sceneflow_estimation,
+ product_retrieval_embedding,
referring_video_object_segmentation,
robust_image_classification, salient_detection,
shop_segmentation, stream_yolo, super_resolution,
- video_deinterlace, video_frame_interpolation,
+ table_recognition, video_deinterlace, video_frame_interpolation,
video_object_segmentation, video_panoptic_segmentation,
video_single_object_tracking, video_stabilization,
- video_summarization, video_super_resolution, virual_tryon,
+ video_summarization, video_super_resolution, vidt, virual_tryon,
vision_middleware, vop_retrieval)
# yapf: enable
diff --git a/modelscope/models/audio/tts/kantts/preprocess/__init__.py b/modelscope/models/cv/action_detection/modules/__init__.py
similarity index 100%
rename from modelscope/models/audio/tts/kantts/preprocess/__init__.py
rename to modelscope/models/cv/action_detection/modules/__init__.py
diff --git a/modelscope/models/cv/action_detection/modules/action_detection_pytorch.py b/modelscope/models/cv/action_detection/modules/action_detection_pytorch.py
new file mode 100644
index 00000000..a8600ae0
--- /dev/null
+++ b/modelscope/models/cv/action_detection/modules/action_detection_pytorch.py
@@ -0,0 +1,232 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import logging
+from typing import Dict, List
+
+import torch
+import torch.nn as nn
+from detectron2.layers import ShapeSpec
+from detectron2.modeling import postprocessing
+from detectron2.modeling.backbone.fpn import FPN, LastLevelP6P7
+from detectron2.modeling.box_regression import _dense_box_regression_loss
+from detectron2.modeling.meta_arch.fcos import FCOS, FCOSHead
+from detectron2.structures import (Boxes, ImageList, Instances,
+ pairwise_point_box_distance)
+from fvcore.nn import sigmoid_focal_loss_jit
+from torch.nn import functional as F
+
+from modelscope.models.base import TorchModel
+from .resnet import Bottleneck3D, ResNet3D
+
+logger = logging.getLogger('detectron2.modelscope.' + __name__)
+
+
+class ActionDetector(FCOS, TorchModel):
+
+ def __init__(self, **kargs):
+ super().__init__(**kargs)
+
+ @torch.no_grad()
+ def load_init_backbone(self, path):
+ from fvcore.common import checkpoint
+ state = torch.load(path, map_location=torch.device('cpu'))
+ model_state = state.pop('model')
+ prefix = 'backbone.bottom_up.'
+ keys = sorted(model_state.keys())
+ for k in keys:
+ if not k.startswith(prefix):
+ model_state.pop(k)
+ checkpoint._strip_prefix_if_present(model_state, prefix)
+ t = self.backbone.bottom_up.load_state_dict(model_state, strict=False)
+ logger.info(str(t))
+ logger.info(f'Load pretrained backbone weights from {path}')
+
+ def preprocess_image(self, batched_inputs):
+ """
+ Normalize, pad and batch the input images.
+ """
+ images = [x['frames'].to(self.device) for x in batched_inputs]
+ images = [x.float() / 255.0 for x in images]
+ images = ImageList.from_tensors(images,
+ self.backbone.size_divisibility)
+ return images
+
+ @torch.no_grad()
+ def match_anchors(self, anchors: List[Boxes],
+ gt_instances: List[Instances]):
+ """
+ Match anchors with ground truth boxes.
+
+ Args:
+ anchors: #level boxes, from the highest resolution to lower resolution
+ gt_instances: ground truth instances per image
+
+ Returns:
+ List[Tensor]:
+ #image tensors, each is a vector of matched gt
+ indices (or -1 for unmatched anchors) for all anchors.
+ """
+ num_anchors_per_level = [len(x) for x in anchors]
+ anchors = Boxes.cat(anchors) # Rx4
+ anchor_centers = anchors.get_centers() # Rx2
+ anchor_sizes = anchors.tensor[:, 2] - anchors.tensor[:, 0] # R
+
+ lower_bound = anchor_sizes * 4
+ lower_bound[:num_anchors_per_level[0]] = 0
+ upper_bound = anchor_sizes * 8
+ upper_bound[-num_anchors_per_level[-1]:] = float('inf')
+
+ matched_indices = []
+ for gt_per_image in gt_instances:
+ if len(gt_per_image) == 0:
+ matched_indices.append(
+ torch.full((len(anchors), ),
+ -1,
+ dtype=torch.int64,
+ device=anchors.tensor.device))
+ continue
+ gt_centers = gt_per_image.gt_boxes.get_centers() # Nx2
+ # FCOS with center sampling: anchor point must be close enough to gt center.
+ center_dist = (anchor_centers[:, None, :]
+ - gt_centers[None, :, :]).abs_().max(dim=2).values
+ pairwise_match = center_dist < self.center_sampling_radius * anchor_sizes[:,
+ None]
+ pairwise_dist = pairwise_point_box_distance(
+ anchor_centers, gt_per_image.gt_boxes)
+
+ # The original FCOS anchor matching rule: anchor point must be inside gt
+ pairwise_match &= pairwise_dist.min(dim=2).values > 0
+
+ # Multilevel anchor matching in FCOS: each anchor is only responsible
+ # for certain scale range.
+ pairwise_dist = pairwise_dist.max(dim=2).values
+ pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (
+ pairwise_dist < upper_bound[:, None])
+
+ # Match the GT box with minimum area, if there are multiple GT matches
+ gt_areas = gt_per_image.gt_boxes.area() # N
+ pairwise_match = pairwise_match.to(
+ torch.float32) * (1e8 - gt_areas[None, :])
+ min_values, matched_idx = pairwise_match.max(
+ dim=1) # R, per-anchor match
+ matched_idx[
+ min_values < 1e-5] = -1 # Unmatched anchors are assigned -1
+
+ matched_indices.append(matched_idx)
+ return matched_indices
+
+ def losses(self, anchors, pred_logits, gt_labels, pred_anchor_deltas,
+ gt_boxes, pred_centerness):
+ """
+ This method is almost identical to :meth:`RetinaNet.losses`, with an extra
+ "loss_centerness" in the returned dict.
+ """
+ gt_labels = torch.stack(gt_labels) # (N, R)
+
+ pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
+ num_pos_anchors = pos_mask.sum().item()
+ normalizer = self._ema_update('loss_normalizer',
+ max(num_pos_anchors, 1), 300)
+
+ # classification and regression loss
+ gt_labels_target = F.one_hot(
+ gt_labels, num_classes=self.num_classes
+ + 1)[:, :, :-1] # no loss for the last (background) class
+ loss_cls = sigmoid_focal_loss_jit(
+ torch.cat(pred_logits, dim=1),
+ gt_labels_target.to(pred_logits[0].dtype),
+ alpha=self.focal_loss_alpha,
+ gamma=self.focal_loss_gamma,
+ reduction='sum',
+ )
+
+ loss_box_reg = _dense_box_regression_loss(
+ anchors,
+ self.box2box_transform,
+ pred_anchor_deltas,
+ [x.tensor for x in gt_boxes],
+ pos_mask,
+ box_reg_loss_type='giou',
+ )
+
+ ctrness_targets = self.compute_ctrness_targets(anchors,
+ gt_boxes) # NxR
+ pred_centerness = torch.cat(
+ pred_centerness, dim=1).squeeze(dim=2) # NxR
+ ctrness_loss = F.binary_cross_entropy_with_logits(
+ pred_centerness[pos_mask],
+ ctrness_targets[pos_mask],
+ reduction='sum')
+ return {
+ 'loss_fcos_cls': loss_cls / normalizer,
+ 'loss_fcos_loc': loss_box_reg / normalizer,
+ 'loss_fcos_ctr': ctrness_loss / normalizer,
+ }
+
+ @torch.no_grad()
+ def label_anchors(self, anchors, gt_instances):
+ """
+ Same interface as :meth:`RetinaNet.label_anchors`, but implemented with FCOS
+ anchor matching rule.
+
+ Unlike RetinaNet, there are no ignored anchors.
+ """
+ matched_indices = self.match_anchors(anchors, gt_instances)
+
+ matched_labels, matched_boxes = [], []
+ for gt_index, gt_per_image in zip(matched_indices, gt_instances):
+ if len(gt_per_image) > 0:
+ label = gt_per_image.gt_classes[gt_index.clip(min=0)]
+ matched_gt_boxes = gt_per_image.gt_boxes[gt_index.clip(min=0)]
+ else:
+ label = gt_per_image.gt_classes.new_zeros((len(gt_index), ))
+ matched_gt_boxes = Boxes(
+ gt_per_image.gt_boxes.tensor.new_zeros((len(gt_index), 4)))
+ label[gt_index < 0] = self.num_classes # background
+ matched_labels.append(label)
+ matched_boxes.append(matched_gt_boxes)
+ return matched_labels, matched_boxes
+
+ def compute_ctrness_targets(self, anchors, gt_boxes): # NxR
+ anchors = Boxes.cat(anchors).tensor # Rx4
+ reg_targets = [
+ self.box2box_transform.get_deltas(anchors, m.tensor)
+ for m in gt_boxes
+ ]
+ reg_targets = torch.stack(reg_targets, dim=0) # NxRx4
+ if len(reg_targets) == 0:
+ # return reg_targets.new_zeros(len(reg_targets))
+ return reg_targets.new_zeros(reg_targets.size()[:-1])
+ left_right = reg_targets[:, :, [0, 2]]
+ top_bottom = reg_targets[:, :, [1, 3]]
+ ctrness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
+ top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
+ return torch.sqrt(ctrness)
+
+
+def build_action_detection_model(num_classes, device='cpu'):
+ backbone = ResNet3D(
+ Bottleneck3D, [3, 4, 6, 3],
+ ops=['c2d', 'p3d'] * 8,
+ t_stride=[1, 1, 1, 1, 1],
+ num_classes=None)
+ in_features = ['res3', 'res4', 'res5']
+ out_channels = 512
+ top_block = LastLevelP6P7(out_channels, out_channels, in_feature='p5')
+ fpnbackbone = FPN(
+ bottom_up=backbone,
+ in_features=in_features,
+ out_channels=out_channels,
+ top_block=top_block,
+ )
+ head = FCOSHead(
+ input_shape=[ShapeSpec(channels=out_channels)] * 5,
+ conv_dims=[out_channels] * 2,
+ num_classes=num_classes)
+ model = ActionDetector(
+ backbone=fpnbackbone,
+ head=head,
+ num_classes=num_classes,
+ pixel_mean=[0, 0, 0],
+ pixel_std=[0, 0, 0])
+ return model
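
build_action_detection_model assembles a 3D ResNet-50 bottom-up, an FPN with extra P6/P7 levels, and an FCOS head; as preprocess_image shows, inference takes a list of dicts whose 'frames' entry is a uint8 (C, T, H, W) clip. A hedged construction sketch (requires detectron2; no pretrained weights are loaded here):

import torch
from modelscope.models.cv.action_detection.modules.action_detection_pytorch import \
    build_action_detection_model

model = build_action_detection_model(num_classes=3).eval()
n_params = sum(p.numel() for p in model.parameters())
print(f'{n_params / 1e6:.1f}M parameters')
# expected inference input, per preprocess_image above:
#   [{'frames': uint8 tensor of shape (3, T, H, W), 'height': H, 'width': W}, ...]
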
diff --git a/modelscope/models/cv/action_detection/modules/resnet.py b/modelscope/models/cv/action_detection/modules/resnet.py
new file mode 100644
index 00000000..7f5529a4
--- /dev/null
+++ b/modelscope/models/cv/action_detection/modules/resnet.py
@@ -0,0 +1,382 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import torch
+import torch.nn as nn
+from detectron2.modeling import Backbone
+
+
+def conv1x3x3(in_planes, out_planes, stride=(1, 1, 1)):
+ return nn.Conv3d(
+ in_planes,
+ out_planes,
+ kernel_size=(1, 3, 3),
+ stride=stride,
+ padding=(0, 1, 1),
+ bias=False)
+
+
+def conv3x3x3(in_planes, out_planes, stride=(1, 1, 1)):
+ return nn.Conv3d(
+ in_planes,
+ out_planes,
+ kernel_size=(3, 3, 3),
+ stride=stride,
+ padding=(1, 1, 1),
+ bias=False)
+
+
+def conv1x1x1(in_planes, out_planes, stride=(1, 1, 1)):
+ return nn.Conv3d(
+ in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+def conv3x1x1(in_planes, out_planes, stride=(1, 1, 1)):
+ return nn.Conv3d(
+ in_planes,
+ out_planes,
+ kernel_size=(3, 1, 1),
+ stride=stride,
+ padding=(1, 0, 0),
+ bias=False)
+
+
+class BasicBlock3D(nn.Module):
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ op='c2d',
+ downsample=None,
+ base_width=64,
+ norm_layer=None):
+ super(BasicBlock3D, self).__init__()
+ dilation = 1
+ groups = 1
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm3d
+ if groups != 1 or base_width != 64:
+ raise ValueError(
+ 'BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError(
+ 'Dilation > 1 not supported in BasicBlock')
+ stride = [stride] * 3 if isinstance(stride, int) else stride
+ self.t_stride = stride[0]
+ self.stride = stride
+ stride = [1] + list(stride[1:])
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ # self.conv1 = conv3x3(inplanes, planes, stride)
+ self.conv1 = conv1x3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv1d = conv3x1x1(planes, planes)
+ # self.conv2 = conv3x3(planes, planes)
+ self.conv2 = conv1x3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ if self.t_stride > 1:
+ out = torch.max_pool3d(
+ out, [self.t_stride, 1, 1], stride=[self.t_stride, 1, 1])
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv1d(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+ @property
+ def out_channels(self):
+ return self.conv2.out_channels
+
+
+class Bottleneck3D(nn.Module):
+ expansion = 2
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ op='c2d',
+ downsample=None,
+ base_width=64,
+ norm_layer=None):
+ super(Bottleneck3D, self).__init__()
+ self.op = op
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm3d
+ width = int(planes * (base_width / 64.))
+ stride = [stride] * 3 if isinstance(stride, int) else stride
+ self.conv1 = conv3x1x1(inplanes, width) if op == 'p3d' else conv1x1x1(
+ inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3x3(width, width,
+ stride) if op == 'c3d' else conv1x3x3(
+ width, width, stride)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ if self.op == 'tsm':
+ out = self.tsm(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+ @property
+ def out_channels(self):
+ return self.conv3.out_channels
+
+
+class ResNet3D(Backbone):
+
+ def __init__(self,
+ block,
+ layers,
+ ops,
+ t_stride,
+ num_classes=1000,
+ zero_init_residual=False,
+ groups=1,
+ width_per_group=64,
+ replace_stride_with_dilation=None,
+ reduce_dim=0,
+ norm_layer=None):
+ self.reduce_dim = reduce_dim
+ self.num_classes = num_classes
+ super(ResNet3D, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm3d
+ self._norm_layer = norm_layer
+ self._out_feature_strides = {'res3': 8, 'res4': 16, 'res5': 32}
+ self._out_features = ['res3', 'res4', 'res5']
+ self._out_feature_channels = {}
+ self.outputs = {}
+ self.inplanes = 64
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError('replace_stride_with_dilation should be None '
+ 'or a 3-element tuple, got {}'.format(
+ replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv3d(
+ 3,
+ self.inplanes, (1, 7, 7),
+ stride=(t_stride[0], 2, 2),
+ padding=(0, 3, 3),
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool3d(
+ kernel_size=(1, 3, 3),
+ stride=(t_stride[1], 2, 2),
+ padding=(0, 1, 1))
+ self.layer1 = self._make_layer(
+ block, 64, layers[0], stride=(1, 1, 1), ops=ops[:layers[0]])
+ self.layer2 = self._make_layer(
+ block,
+ 128,
+ layers[1],
+ stride=(t_stride[2], 2, 2),
+ ops=ops[layers[0]:][:layers[1]])
+ self.layer3 = self._make_layer(
+ block,
+ 256,
+ layers[2],
+ stride=(t_stride[3], 2, 2),
+ ops=ops[sum(layers[:2], 0):][:layers[2]])
+ self.layer4 = self._make_layer(
+ block,
+ 512,
+ layers[3],
+ stride=(t_stride[4], 2, 2),
+ ops=ops[sum(layers[:3], 0):][:layers[3]])
+ if num_classes is not None:
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.sptial_atten = nn.Conv2d(2, 1, kernel_size=7, padding=3)
+ self.drop = nn.Dropout(0.5)
+ if reduce_dim > 0:
+ self.rd_conv = nn.Conv2d(
+ 512 * block.expansion, reduce_dim, kernel_size=1)
+ self.clc = nn.Conv2d(reduce_dim, num_classes, kernel_size=1)
+ else:
+ self.clc = nn.Conv2d(
+ 512 * block.expansion, num_classes, kernel_size=1)
+
+ self._out_feature_channels['res3'] = self.layer2[-1].out_channels
+ self._out_feature_channels['res4'] = self.layer3[-1].out_channels
+ self._out_feature_channels['res5'] = self.layer4[-1].out_channels
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv3d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck3D):
+ nn.init.constant_(m.bn3.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, ops, stride=(1, 1, 1)):
+ norm_layer = self._norm_layer
+ downsample = None
+ if stride != (1, 1, 1) or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(self.inplanes, planes, stride, ops[0], downsample,
+ self.base_width, norm_layer))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(
+ block(
+ self.inplanes,
+ planes,
+ op=ops[i],
+ base_width=self.base_width,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def features(self, x):
+ x = self.norm_x(x)
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ self.outputs['res3'] = x.mean(dim=2)
+ x = self.layer3(x)
+ self.outputs['res4'] = x.mean(dim=2)
+ x = self.layer4(x) # N,C,T,H,W
+ self.outputs['res5'] = x.mean(dim=2)
+ if self.num_classes is not None:
+ x = torch.mean(x, dim=2)  # collapse the temporal dimension, N,C,H,W
+ # spatial attention
+ ftr = torch.cat(
+ (x.max(dim=1, keepdim=True)[0], x.mean(dim=1, keepdim=True)),
+ dim=1)
+ score = self.sptial_atten(ftr) # N,1,H,W
+ x = x * torch.sigmoid(score) # N,C,H,W
+ self.score = score
+
+ x = self.avgpool(x) # ,N,C,1,1
+ if self.reduce_dim > 0:
+ x = self.rd_conv(x)
+ self.outputs['ftr'] = x.mean(dim=(2, 3))
+ return x
+
+ def logits(self, x):
+ x = self.features(x)
+ x = self.clc(x)
+ return x
+
+ def forward(self, x):
+ ftr = self.features(x)
+ if self.num_classes is not None:
+ x = self.drop(ftr)
+ x = self.clc(x)
+ x = torch.mean(x, (2, 3))
+ return x
+
+ return self.outputs
+
+ @torch.no_grad()
+ def norm_x(self, x):
+ m = x.new_tensor([0.485, 0.456, 0.406]).reshape([1, 3, 1, 1, 1])
+ s = x.new_tensor([0.229, 0.224, 0.225]).reshape([1, 3, 1, 1, 1])
+ x -= m
+ x /= s
+ return x
+
+
+def resnet101_3d(ops, t_stride, num_class, reduce_dim=0):
+ net = ResNet3D(
+ Bottleneck3D, [3, 4, 23, 3],
+ ops=ops,
+ t_stride=t_stride,
+ num_classes=num_class,
+ reduce_dim=reduce_dim)
+ return net
+
+
+def resnet50_3d(ops, t_stride, num_class, reduce_dim=0):
+ net = ResNet3D(
+ Bottleneck3D, [3, 4, 6, 3],
+ ops=ops,
+ t_stride=t_stride,
+ num_classes=num_class,
+ reduce_dim=reduce_dim)
+ return net
+
+
+def resnet34_3d(ops, t_stride, num_class, reduce_dim=0):
+ net = ResNet3D(
+ BasicBlock3D, [3, 4, 6, 3],
+ ops=ops,
+ t_stride=t_stride,
+ num_classes=num_class,
+ reduce_dim=reduce_dim)
+ return net
+
+
+def resnet18_3d(ops, t_stride, num_class, reduce_dim=0):
+ net = ResNet3D(
+ BasicBlock3D, [2, 2, 2, 2],
+ ops=ops,
+ t_stride=t_stride,
+ num_classes=num_class,
+ reduce_dim=reduce_dim)
+ return net
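Before pooling, the classification head in ResNet3D.features gates the temporally averaged features with a small spatial-attention layer. A minimal standalone sketch of that gate, with illustrative tensor shapes:

    import torch
    import torch.nn as nn

    # Channel-wise max and mean maps are stacked, scored with a 7x7 conv, and the
    # sigmoid of that score rescales the feature map (as in ResNet3D above).
    atten = nn.Conv2d(2, 1, kernel_size=7, padding=3)
    x = torch.randn(2, 2048, 7, 7)                         # N, C, H, W after temporal mean
    ftr = torch.cat((x.max(dim=1, keepdim=True)[0],
                     x.mean(dim=1, keepdim=True)), dim=1)  # N, 2, H, W
    gated = x * torch.sigmoid(atten(ftr))                  # N, C, H, W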
diff --git a/modelscope/models/cv/human_reconstruction/Reconstruction.py b/modelscope/models/cv/human_reconstruction/Reconstruction.py
new file mode 100644
index 00000000..4140565e
--- /dev/null
+++ b/modelscope/models/cv/human_reconstruction/Reconstruction.py
@@ -0,0 +1,137 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+from typing import Optional
+
+import cv2
+import numpy as np
+import PIL.Image as Image
+import torch
+import torchvision.transforms as transforms
+from skimage.io import imread
+from skimage.transform import estimate_transform, warp
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Tensor, TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.cv.human_reconstruction.models.detectors import \
+ FasterRCNN
+from modelscope.models.cv.human_reconstruction.models.human_segmenter import \
+ human_segmenter
+from modelscope.models.cv.human_reconstruction.models.networks import define_G
+from modelscope.models.cv.human_reconstruction.models.PixToMesh import \
+ Pixto3DNet
+from modelscope.models.cv.human_reconstruction.utils import create_grid
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@MODELS.register_module(
+ Tasks.human_reconstruction, module_name=Models.human_reconstruction)
+class HumanReconstruction(TorchModel):
+
+ def __init__(self, model_dir, modelconfig, *args, **kwargs):
+ """The HumanReconstruction is modified based on PiFuHD and pix2pixhd, publicly available at
+ https://shunsukesaito.github.io/PIFuHD/ &
+ https://github.com/NVIDIA/pix2pixHD
+
+ Args:
+ model_dir: the root directory of the model files
+ modelconfig: the config param path of the model
+ """
+ super().__init__(model_dir=model_dir, *args, **kwargs)
+ if torch.cuda.is_available():
+ self.device = torch.device('cuda')
+ logger.info('Use GPU: {}'.format(self.device))
+ else:
+ self.device = torch.device('cpu')
+ logger.info('Use CPU: {}'.format(self.device))
+
+ model_path = '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_FILE)
+ normal_back_model = '{}/{}'.format(model_dir, 'Norm_B_GAN.pth')
+ normal_front_model = '{}/{}'.format(model_dir, 'Norm_F_GAN.pth')
+ human_seg_model = '{}/{}'.format(model_dir, ModelFile.TF_GRAPH_FILE)
+ fastrcnn_ckpt = '{}/{}'.format(model_dir, 'fasterrcnn_resnet50.pth')
+ self.meshmodel = Pixto3DNet(**modelconfig['model'])
+ self.detector = FasterRCNN(ckpt=fastrcnn_ckpt, device=self.device)
+ self.meshmodel.load_state_dict(
+ torch.load(model_path, map_location='cpu'))
+ self.netB = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance')
+ self.netF = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance')
+ self.netF.load_state_dict(torch.load(normal_front_model))
+ self.netB.load_state_dict(torch.load(normal_back_model))
+ self.netF = self.netF.to(self.device)
+ self.netB = self.netB.to(self.device)
+ self.netF.eval()
+ self.netB.eval()
+ self.meshmodel = self.meshmodel.to(self.device).eval()
+ self.portrait_matting = human_segmenter(model_path=human_seg_model)
+ b_min = np.array([-1, -1, -1])
+ b_max = np.array([1, 1, 1])
+ self.coords, self.mat = create_grid(modelconfig['resolution'], b_min,
+ b_max)
+ projection_matrix = np.identity(4)
+ projection_matrix[1, 1] = -1
+ self.calib = torch.Tensor(projection_matrix).float().to(self.device)
+ self.calib = self.calib[:3, :4].unsqueeze(0)
+ logger.info('model load over')
+
+ def get_mask(self, img):
+ result = self.portrait_matting.run(img)
+ result = result[..., None]
+ mask = result.repeat(3, axis=2)
+ return img, mask
+
+ @torch.no_grad()
+ def crop_img(self, img_url):
+ image = imread(img_url)[:, :, :3] / 255.
+ h, w, _ = image.shape
+ image_size = 512
+ image_tensor = torch.tensor(
+ image.transpose(2, 0, 1), dtype=torch.float32)[None, ...]
+ bbox = self.detector.run(image_tensor)
+ left = bbox[0]
+ right = bbox[2]
+ top = bbox[1]
+ bottom = bbox[3]
+
+ old_size = max(right - left, bottom - top)
+ center = np.array(
+ [right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
+ size = int(old_size * 1.1)
+ src_pts = np.array([[center[0] - size / 2, center[1] - size / 2],
+ [center[0] - size / 2, center[1] + size / 2],
+ [center[0] + size / 2, center[1] - size / 2]])
+ DST_PTS = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]])
+ tform = estimate_transform('similarity', src_pts, DST_PTS)
+ dst_image = warp(
+ image, tform.inverse, output_shape=(image_size, image_size))
+ dst_image = (dst_image[:, :, ::-1] * 255).astype(np.uint8)
+ return dst_image
+
+ @torch.no_grad()
+ def generation_normal(self, img, mask):
+ to_tensor = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+ ])
+ im_512 = cv2.resize(img, (512, 512))
+ image_512 = Image.fromarray(im_512).convert('RGB')
+ image_512 = to_tensor(image_512).unsqueeze(0)
+ img = image_512.to(self.device)
+ nml_f = self.netF.forward(img)
+ nml_b = self.netB.forward(img)
+ mask = cv2.resize(mask, (512, 512))
+ mask = transforms.ToTensor()(mask).unsqueeze(0)
+ nml_f = (nml_f.cpu() * mask).detach().cpu().numpy()[0]
+ nml_f = (np.transpose(nml_f,
+ (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
+ nml_b = (nml_b.cpu() * mask).detach().cpu().numpy()[0]
+ nml_b = (np.transpose(nml_b,
+ (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
+ nml_f = nml_f.astype(np.uint8)
+ nml_b = nml_b.astype(np.uint8)
+ return nml_f, nml_b
+
+ # def forward(self, img, mask, normal_f, normal_b):
diff --git a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/__init__.py b/modelscope/models/cv/human_reconstruction/__init__.py
similarity index 100%
rename from modelscope/models/audio/tts/kantts/preprocess/audio_processor/__init__.py
rename to modelscope/models/cv/human_reconstruction/__init__.py
diff --git a/modelscope/models/cv/human_reconstruction/models/Embedding.py b/modelscope/models/cv/human_reconstruction/models/Embedding.py
new file mode 100644
index 00000000..a2ec1877
--- /dev/null
+++ b/modelscope/models/cv/human_reconstruction/models/Embedding.py
@@ -0,0 +1,32 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+from torch import nn
+
+
+class Embedding(nn.Module):
+
+ def __init__(self, in_channels, N_freqs, logscale=True):
+ """
+ Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
+ in_channels: number of input channels (3 for both xyz and direction)
+ """
+ super(Embedding, self).__init__()
+ self.N_freqs = N_freqs
+ self.in_channels = in_channels
+ self.name = 'Embedding'
+ self.funcs = [torch.sin, torch.cos]
+ self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1)
+ self.input_para = dict(in_channels=in_channels, N_freqs=N_freqs)
+
+ if logscale:
+ self.freq_bands = 2**torch.linspace(0, N_freqs - 1, N_freqs)
+ else:
+ self.freq_bands = torch.linspace(1, 2**(N_freqs - 1), N_freqs)
+
+ def forward(self, x):
+ out = [x]
+ for freq in self.freq_bands:
+ for func in self.funcs:
+ out += [func(freq * x)]
+
+ return torch.cat(out, 1)
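The class above is the standard NeRF-style positional encoding. A quick shape check, assuming Embedding is imported from this module:

    import torch

    emb = Embedding(in_channels=3, N_freqs=10)
    xyz = torch.randn(4, 3, 4096)      # B, 3, N sampled point coordinates
    out = emb(xyz)                     # identity + sin/cos at 10 frequencies
    assert out.shape == (4, 3 * (2 * 10 + 1), 4096)   # 63 output channels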
diff --git a/modelscope/models/cv/human_reconstruction/models/PixToMesh.py b/modelscope/models/cv/human_reconstruction/models/PixToMesh.py
new file mode 100644
index 00000000..0299bf82
--- /dev/null
+++ b/modelscope/models/cv/human_reconstruction/models/PixToMesh.py
@@ -0,0 +1,142 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+import torch.nn as nn
+
+from .Embedding import Embedding
+from .geometry import index, orthogonal, perspective
+from .Res_backbone import Res_hournet
+from .Surface_head import Surface_Head
+
+
+class Pixto3DNet(nn.Module):
+
+ def __init__(self,
+ backbone,
+ head,
+ rgbhead,
+ embedding,
+ projection_mode: str = 'orthogonal',
+ error_term: str = 'mse',
+ num_views: int = 1):
+ """
+ Parameters:
+ backbone: parameter of networks to extract image features
+ head: parameter of networks to predict value in surface
+ rgbhead: parameter of networks to predict rgb of point
+ embedding: parameter of networks to normalize depth of camera coordinate
+ projection_mode: how to render your 3d model to images
+ error_term: train loss
+ num_view: how many images from which you want to reconstruct model
+ """
+ super(Pixto3DNet, self).__init__()
+
+ self.backbone = Res_hournet(**backbone)
+ self.head = Surface_Head(**head)
+ self.rgbhead = Surface_Head(**rgbhead)
+ self.depth = Embedding(**embedding)
+
+ if error_term == 'mse':
+ self.error_term = nn.MSELoss(reduction='none')
+ elif error_term == 'bce':
+ self.error_term = nn.BCELoss(reduction='none')
+ elif error_term == 'l1':
+ self.error_term = nn.L1Loss(reduction='none')
+ else:
+ raise NotImplementedError
+
+ self.index = index
+ self.projection = orthogonal if projection_mode == 'orthogonal' else perspective
+
+ self.num_views = num_views
+ self.im_feat_list = []
+ self.intermediate_preds_list = []
+
+ def extract_features(self, images: torch.Tensor):
+ self.im_feat_list = self.backbone(images)
+
+ def query(self, points, calibs, transforms=None, labels=None):
+ if labels is not None:
+ self.labels = labels
+
+ xyz = self.projection(points, calibs, transforms)
+
+ xy = xyz[:, :2, :]
+ xyz_feat = self.depth(xyz)
+
+ self.intermediate_preds_list = []
+
+ im_feat_256 = self.im_feat_list[0]
+ im_feat_512 = self.im_feat_list[1]
+
+ point_local_feat_list = [
+ self.index(im_feat_256, xy),
+ self.index(im_feat_512, xy), xyz_feat
+ ]
+ point_local_feat = torch.cat(point_local_feat_list, 1)
+
+ pred, phi = self.head(point_local_feat)
+ self.intermediate_preds_list.append(pred)
+ self.phi = phi
+
+ self.preds = self.intermediate_preds_list[-1]
+
+ def get_preds(self):
+ return self.preds
+
+ def query_rgb(self, points, calibs, transforms=None):
+ xyz = self.projection(points, calibs, transforms)
+
+ xy = xyz[:, :2, :]
+ xyz_feat = self.depth(xyz)
+
+ self.intermediate_preds_list = []
+
+ im_feat_256 = self.im_feat_list[0]
+ im_feat_512 = self.im_feat_list[1]
+
+ point_local_feat_list = [
+ self.index(im_feat_256, xy),
+ self.index(im_feat_512, xy), xyz_feat
+ ]
+ point_local_feat = torch.cat(point_local_feat_list, 1)
+
+ pred, phi = self.head(point_local_feat)
+ rgb_point_feat = torch.cat([point_local_feat, phi], 1)
+ rgb, phi = self.rgbhead(rgb_point_feat)
+ return rgb
+
+ def get_error(self):
+ error = 0
+ lc = torch.tensor(self.labels.shape[0] * self.labels.shape[1]
+ * self.labels.shape[2])
+ inw = torch.sum(self.labels)
+ weight_in = inw / lc
+ weight = torch.abs(self.labels - weight_in)
+ lamda = 1 / torch.mean(weight)
+ for preds in self.intermediate_preds_list:
+ error += lamda * torch.mean(
+ self.error_term(preds, self.labels) * weight)
+ error /= len(self.intermediate_preds_list)
+
+ return error
+
+ def forward(self,
+ images,
+ points,
+ calibs,
+ surpoint=None,
+ transforms=None,
+ labels=None):
+ self.extract_features(images)
+
+ self.query(
+ points=points, calibs=calibs, transforms=transforms, labels=labels)
+
+ if surpoint is not None:
+ rgb = self.query_rgb(
+ points=surpoint, calibs=calibs, transforms=transforms)
+ else:
+ rgb = None
+ res = self.preds
+
+ return res, rgb
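get_error above balances inside/outside points with a data-dependent weight; a small numeric sketch of that weighting on a dummy label tensor:

    import torch

    # Inside points (label 1) get weight (1 - inside_ratio), outside points get
    # inside_ratio, so the rarer class is up-weighted; lamda rescales the loss
    # so that the average weight is 1.
    labels = torch.tensor([[[1., 1., 0., 0., 0., 0., 0., 0.]]])  # B, C, N occupancy labels
    inside_ratio = labels.sum() / labels.numel()                 # 0.25
    weight = torch.abs(labels - inside_ratio)                    # 0.75 inside, 0.25 outside
    lamda = 1.0 / weight.mean()                                  # 1 / 0.375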
diff --git a/modelscope/models/cv/human_reconstruction/models/Res_backbone.py b/modelscope/models/cv/human_reconstruction/models/Res_backbone.py
new file mode 100644
index 00000000..b14ae772
--- /dev/null
+++ b/modelscope/models/cv/human_reconstruction/models/Res_backbone.py
@@ -0,0 +1,330 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BlurPool(nn.Module):
+
+ def __init__(self,
+ channels,
+ pad_type='reflect',
+ filt_size=4,
+ stride=2,
+ pad_off=0):
+ super(BlurPool, self).__init__()
+ self.filt_size = filt_size
+ self.pad_off = pad_off
+ self.pad_sizes = [
+ int(1. * (filt_size - 1) / 2),
+ int(np.ceil(1. * (filt_size - 1) / 2)),
+ int(1. * (filt_size - 1) / 2),
+ int(np.ceil(1. * (filt_size - 1) / 2))
+ ]
+ self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
+ self.stride = stride
+ self.off = int((self.stride - 1) / 2.)
+ self.channels = channels
+
+ if (self.filt_size == 1):
+ a = np.array([
+ 1.,
+ ])
+ elif (self.filt_size == 2):
+ a = np.array([1., 1.])
+ elif (self.filt_size == 3):
+ a = np.array([1., 2., 1.])
+ elif (self.filt_size == 4):
+ a = np.array([1., 3., 3., 1.])
+ elif (self.filt_size == 5):
+ a = np.array([1., 4., 6., 4., 1.])
+ elif (self.filt_size == 6):
+ a = np.array([1., 5., 10., 10., 5., 1.])
+ elif (self.filt_size == 7):
+ a = np.array([1., 6., 15., 20., 15., 6., 1.])
+
+ filt = torch.Tensor(a[:, None] * a[None, :])
+ filt = filt / torch.sum(filt)
+ self.register_buffer(
+ 'filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
+
+ self.pad = get_pad_layer(pad_type)(self.pad_sizes)
+
+ def forward(self, inp):
+ if (self.filt_size == 1):
+ if (self.pad_off == 0):
+ return inp[:, :, ::self.stride, ::self.stride]
+ else:
+ return self.pad(inp)[:, :, ::self.stride, ::self.stride]
+ else:
+ return F.conv2d(
+ self.pad(inp),
+ self.filt,
+ stride=self.stride,
+ groups=inp.shape[1])
+
+
+def get_pad_layer(pad_type):
+ if (pad_type in ['refl', 'reflect']):
+ PadLayer = nn.ReflectionPad2d
+ elif (pad_type in ['repl', 'replicate']):
+ PadLayer = nn.ReplicationPad2d
+ elif (pad_type == 'zero'):
+ PadLayer = nn.ZeroPad2d
+ else:
+ print('Pad type [%s] not recognized' % pad_type)
+ return PadLayer
+
+
+class ConvBlockv1(nn.Module):
+
+ def __init__(self, in_planes, out_planes, norm='batch'):
+ super(ConvBlockv1, self).__init__()
+ self.conv1 = nn.Conv2d(
+ in_planes,
+ int(out_planes / 2),
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.conv2 = nn.Conv2d(
+ int(out_planes / 2),
+ int(out_planes / 4),
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.conv3 = nn.Conv2d(
+ int(out_planes / 4),
+ int(out_planes / 4),
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+
+ if norm == 'batch':
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
+ self.bn4 = nn.BatchNorm2d(out_planes)
+ elif norm == 'group':
+ self.bn2 = nn.GroupNorm(32, int(out_planes / 2))
+ self.bn3 = nn.GroupNorm(32, int(out_planes / 4))
+ self.bn4 = nn.GroupNorm(32, out_planes)
+
+ if in_planes != out_planes:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(
+ in_planes, out_planes, kernel_size=1, stride=1,
+ bias=False), )
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ residual = x
+ out1 = self.conv1(x)
+ out2 = self.bn2(out1)
+ out2 = F.relu(out2, True)
+ out2 = self.conv2(out2)
+
+ out3 = self.bn3(out2)
+ out3 = F.relu(out3, True)
+ out3 = self.conv3(out3)
+ out3 = torch.cat((out1, out2, out3), 1)
+
+ if self.downsample is not None:
+ residual = self.downsample(residual)
+ out3 += residual
+ out4 = self.bn4(out3)
+ out4 = F.relu(out4, True)
+ return out4
+
+
+class Conv2(nn.Module):
+
+ def __init__(self, in_planes, out_planes, norm='batch'):
+ super(Conv2, self).__init__()
+ self.conv1 = nn.Conv2d(
+ in_planes,
+ int(out_planes / 4),
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.conv2 = nn.Conv2d(
+ in_planes,
+ int(out_planes / 4),
+ kernel_size=5,
+ stride=1,
+ padding=2,
+ bias=False)
+ self.conv3 = nn.Conv2d(
+ in_planes,
+ int(out_planes / 2),
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False)
+ self.conv4 = nn.Conv2d(
+ out_planes,
+ out_planes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+
+ if norm == 'batch':
+ self.bn1 = nn.BatchNorm2d(int(out_planes / 4))
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 4))
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 2))
+ self.bn4 = nn.BatchNorm2d(out_planes)
+ elif norm == 'group':
+ self.bn1 = nn.GroupNorm(32, int(out_planes / 4))
+ self.bn2 = nn.GroupNorm(32, int(out_planes / 4))
+ self.bn3 = nn.GroupNorm(32, int(out_planes / 2))
+ self.bn4 = nn.GroupNorm(32, out_planes)
+
+ if in_planes != out_planes:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(
+ in_planes, out_planes, kernel_size=1, stride=1,
+ bias=False), )
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ residual = x
+ out1 = self.conv1(x)
+ out1 = self.bn1(out1)
+ out1 = F.relu(out1, True)
+
+ out2 = self.conv2(x)
+ out2 = self.bn2(out2)
+ out2 = F.relu(out2, True)
+
+ out3 = self.conv3(x)
+ out3 = self.bn3(out3)
+ out3 = F.relu(out3, True)
+ out3 = torch.cat((out1, out2, out3), 1)
+ if self.downsample is not None:
+ residual = self.downsample(residual)
+ out = out3 + residual
+ out = self.conv4(out)
+ out = self.bn4(out)
+ out = F.relu(out, True)
+ return out
+
+
+class Res_hournet(nn.Module):
+
+ def __init__(self, norm: str = 'group', use_front=False, use_back=False):
+ """
+ Defines a backbone of human reconstruction
+ use_front & use_back is the normal map of input image
+ """
+ super(Res_hournet, self).__init__()
+ self.name = 'Res Backbone'
+ self.norm = norm
+ inc = 3
+ self.use_front = use_front
+ self.use_back = use_back
+ if self.use_front:
+ inc += 3
+ if self.use_back:
+ inc += 3
+ self.conv1 = nn.Conv2d(inc, 64, kernel_size=7, stride=1, padding=3)
+ if self.norm == 'batch':
+ self.bn1 = nn.BatchNorm2d(64)
+ elif self.norm == 'group':
+ self.bn1 = nn.GroupNorm(32, 64)
+ self.down_conv1 = BlurPool(
+ 64, pad_type='reflect', filt_size=7, stride=2)
+ self.conv2 = ConvBlockv1(64, 128, self.norm)
+ self.down_conv2 = BlurPool(
+ 128, pad_type='reflect', filt_size=7, stride=2)
+ self.conv3 = ConvBlockv1(128, 128, self.norm)
+ self.conv5 = ConvBlockv1(128, 256, self.norm)
+ self.conv6 = ConvBlockv1(256, 256, self.norm)
+ self.down_conv3 = BlurPool(
+ 256, pad_type='reflect', filt_size=5, stride=2)
+ self.conv7 = ConvBlockv1(256, 256, self.norm)
+ self.conv8 = ConvBlockv1(256, 256, self.norm)
+ self.conv9 = ConvBlockv1(256, 256, self.norm)
+ self.conv10 = ConvBlockv1(256, 256, self.norm)
+ self.conv10_1 = ConvBlockv1(256, 512, self.norm)
+ self.conv10_2 = Conv2(512, 512, self.norm)
+ self.down_conv4 = BlurPool(
+ 512, pad_type='reflect', filt_size=5, stride=2)
+ self.conv11 = Conv2(512, 512, self.norm)
+ self.conv12 = ConvBlockv1(512, 512, self.norm)
+ self.conv13 = Conv2(512, 512, self.norm)
+ self.conv14 = ConvBlockv1(512, 512, self.norm)
+ self.conv15 = Conv2(512, 512, self.norm)
+ self.conv16 = ConvBlockv1(512, 512, self.norm)
+ self.conv17 = Conv2(512, 512, self.norm)
+ self.conv18 = ConvBlockv1(512, 512, self.norm)
+ self.conv19 = Conv2(512, 512, self.norm)
+ self.conv20 = ConvBlockv1(512, 512, self.norm)
+ self.conv21 = Conv2(512, 512, self.norm)
+ self.conv22 = ConvBlockv1(512, 512, self.norm)
+
+ self.up_down1 = nn.Conv2d(1024, 512, 3, 1, 1, bias=False)
+ self.upconv1 = ConvBlockv1(512, 512, self.norm)
+ self.upconv1_1 = ConvBlockv1(512, 512, self.norm)
+ self.up_down2 = nn.Conv2d(768, 512, 3, 1, 1, bias=False)
+ self.upconv2 = ConvBlockv1(512, 256, self.norm)
+ self.upconv2_1 = ConvBlockv1(256, 256, self.norm)
+ self.up_down3 = nn.Conv2d(384, 256, 3, 1, 1, bias=False)
+ self.upconv3 = ConvBlockv1(256, 256, self.norm)
+ self.upconv3_4 = nn.Conv2d(256, 128, 3, 1, 1, bias=False)
+ self.up_down4 = nn.Conv2d(192, 64, 3, 1, 1, bias=False)
+ self.upconv4 = ConvBlockv1(64, 64, 'batch')
+
+ def forward(self, x):
+ out0 = self.bn1(self.conv1(x))
+ out1 = self.down_conv1(out0)
+ out1 = self.conv2(out1)
+ out2 = self.down_conv2(out1)
+ out2 = self.conv3(out2)
+ out2 = self.conv5(out2)
+ out2 = self.conv6(out2)
+ out3 = self.down_conv3(out2)
+ out3 = self.conv7(out3)
+ out3 = self.conv9(self.conv8(out3))
+ out3 = self.conv10(out3)
+ out3 = self.conv10_2(self.conv10_1(out3))
+ out4 = self.down_conv4(out3)
+ out4 = self.conv12(self.conv11(out4))
+ out4 = self.conv14(self.conv13(out4))
+ out4 = self.conv16(self.conv15(out4))
+ out4 = self.conv18(self.conv17(out4))
+ out4 = self.conv20(self.conv19(out4))
+ out4 = self.conv22(self.conv21(out4))
+
+ up1 = F.interpolate(
+ out4, scale_factor=2, mode='bicubic', align_corners=True)
+ up1 = torch.cat((up1, out3), 1)
+ up1 = self.up_down1(up1)
+ up1 = self.upconv1(up1)
+ up1 = self.upconv1_1(up1)
+
+ up2 = F.interpolate(
+ up1, scale_factor=2, mode='bicubic', align_corners=True)
+ up2 = torch.cat((up2, out2), 1)
+ up2 = self.up_down2(up2)
+ up2 = self.upconv2(up2)
+ up2 = self.upconv2_1(up2)
+
+ up3 = F.interpolate(
+ up2, scale_factor=2, mode='bicubic', align_corners=True)
+ up3 = torch.cat((up3, out1), 1)
+ up3 = self.up_down3(up3)
+ up3 = self.upconv3(up3)
+
+ up34 = self.upconv3_4(up3)
+ up4 = F.interpolate(
+ up34, scale_factor=2, mode='bicubic', align_corners=True)
+ up4 = torch.cat((up4, out0), 1)
+ up4 = self.up_down4(up4)
+ up4 = self.upconv4(up4)
+ return up3, up4
diff --git a/modelscope/models/cv/human_reconstruction/models/Surface_head.py b/modelscope/models/cv/human_reconstruction/models/Surface_head.py
new file mode 100644
index 00000000..47c0cccb
--- /dev/null
+++ b/modelscope/models/cv/human_reconstruction/models/Surface_head.py
@@ -0,0 +1,73 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Surface_Head(nn.Module):
+ """
+    MLP head that learns the implicit iso-surface function.
+ """
+
+ def __init__(self,
+ filter_channels,
+ merge_layer=0,
+ res_layers=[],
+ norm='group',
+ last_op=None):
+ super(Surface_Head, self).__init__()
+ if last_op == 'sigmoid':
+ self.last_op = nn.Sigmoid()
+ elif last_op == 'tanh':
+ self.last_op = nn.Tanh()
+ else:
+ raise NotImplementedError(
+ 'only sigmoid/tanh function could be used')
+
+ self.filters = nn.ModuleList()
+ self.norms = nn.ModuleList()
+ self.merge_layer = merge_layer if merge_layer > 0 else len(
+ filter_channels) // 2
+
+ self.res_layers = res_layers
+ self.norm = norm
+
+ for i in range(0, len(filter_channels) - 1):
+ if i in self.res_layers:
+ self.filters.append(
+ nn.Conv1d(filter_channels[i] + filter_channels[0],
+ filter_channels[i + 1], 1))
+ else:
+ self.filters.append(
+ nn.Conv1d(filter_channels[i], filter_channels[i + 1], 1))
+ if i != len(filter_channels) - 2:
+ if norm == 'group':
+ self.norms.append(nn.GroupNorm(32, filter_channels[i + 1]))
+ elif norm == 'batch':
+ self.norms.append(nn.BatchNorm1d(filter_channels[i + 1]))
+
+ def forward(self, feature):
+ """feature may include multiple view inputs
+ Parameters:
+ feature: [B, C_in, N]
+ return:
+ prediction: [B, C_out, N] and merge layer features
+ """
+
+ y = feature
+ tmpy = feature
+ phi = None
+
+ for i, f in enumerate(self.filters):
+ y = f(y if i not in self.res_layers else torch.cat([y, tmpy], 1))
+ if i != len(self.filters) - 1:
+ if self.norm not in ['batch', 'group']:
+ y = F.leaky_relu(y)
+ else:
+ y = F.leaky_relu(self.norms[i](y))
+ if i == self.merge_layer:
+ phi = y.clone()
+
+ if self.last_op is not None:
+ y = self.last_op(y)
+ return y, phi
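A hedged usage sketch for Surface_Head: the head is a per-point MLP built from 1x1 Conv1d layers with skip connections back to the input features. The filter_channels below are illustrative, not the shipped configuration:

    import torch

    head = Surface_Head(filter_channels=[323, 512, 256, 128, 1],
                        res_layers=[1, 2], norm='group', last_op='sigmoid')
    feat = torch.randn(2, 323, 4096)   # B, C_in, N per-point features
    pred, phi = head(feat)             # pred: B, 1, N values in (0, 1); phi: merge-layer features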
diff --git a/modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/__init__.py b/modelscope/models/cv/human_reconstruction/models/__init__.py
similarity index 100%
rename from modelscope/models/audio/tts/kantts/preprocess/audio_processor/core/__init__.py
rename to modelscope/models/cv/human_reconstruction/models/__init__.py
diff --git a/modelscope/models/cv/human_reconstruction/models/detectors.py b/modelscope/models/cv/human_reconstruction/models/detectors.py
new file mode 100644
index 00000000..4f63dd8c
--- /dev/null
+++ b/modelscope/models/cv/human_reconstruction/models/detectors.py
@@ -0,0 +1,66 @@
+# The implementation here is modified based on PyTorch, originally BSD License and publicly available at
+# https://github.com/pytorch/pytorch
+import numpy as np
+import torch
+
+
+class FasterRCNN(object):
+ ''' detect body
+ COCO_INSTANCE_CATEGORY_NAMES = [
+ '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+ 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+ 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+ 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+ 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+ 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
+ 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+ 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
+ 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+ ]
+ '''
+
+ def __init__(self, ckpt=None, device='cuda:0'):
+ """
+ https://pytorch.org/docs/stable/torchvision/models.html#faster-r-cnn
+ """
+ import torchvision
+ if ckpt is None:
+ self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
+ pretrained=True)
+ else:
+ self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
+ pretrained=False)
+ state_dict = torch.load(ckpt, map_location='cpu')
+ self.model.load_state_dict(state_dict)
+ self.model.to(device)
+ self.model.eval()
+ self.device = device
+
+ @torch.no_grad()
+ def run(self, input):
+ """
+ return: detected box, [x1, y1, x2, y2]
+ """
+ prediction = self.model(input.to(self.device))[0]
+ inds = (prediction['labels'] == 1) * (prediction['scores'] > 0.5)
+        if inds.sum() < 1:
+ return None
+ else:
+ bbox = prediction['boxes'][inds][0].cpu().numpy()
+ return bbox
+
+ @torch.no_grad()
+ def run_multi(self, input):
+ """
+ return: detected box, [x1, y1, x2, y2]
+ """
+ prediction = self.model(input.to(self.device))[0]
+ inds = (prediction['labels'] == 1) * (prediction['scores'] > 0.9)
+        if inds.sum() < 1:
+ return None
+ else:
+ bbox = prediction['boxes'][inds].cpu().numpy()
+ return bbox
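A minimal usage sketch for the detector above (with ckpt=None it falls back to the torchvision pretrained weights, which requires network access):

    import torch

    detector = FasterRCNN(ckpt=None, device='cpu')
    image = torch.rand(1, 3, 512, 512)   # N, C, H, W, values in [0, 1]
    bbox = detector.run(image)           # [x1, y1, x2, y2] of the top-scoring person, or None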
diff --git a/modelscope/models/cv/human_reconstruction/models/geometry.py b/modelscope/models/cv/human_reconstruction/models/geometry.py
new file mode 100644
index 00000000..fa4a00a6
--- /dev/null
+++ b/modelscope/models/cv/human_reconstruction/models/geometry.py
@@ -0,0 +1,61 @@
+# The implementation here is modified based on PIFu, originally MIT License and publicly available at
+# https://github.com/shunsukesaito/PIFu/blob/master/lib/geometry.py
+import torch
+
+
+def index(feat, uv):
+ """
+ extract image features at floating coordinates with bilinear interpolation
+ args:
+ feat: [B, C, H, W] image features
+ uv: [B, 2, N] normalized image coordinates ranged in [-1, 1]
+ return:
+ [B, C, N] sampled pixel values
+ """
+ uv = uv.transpose(1, 2)
+ uv = uv.unsqueeze(2)
+ samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True)
+ return samples[:, :, :, 0]
+
+
+def orthogonal(points, calib, transform=None):
+ """
+ project points onto screen space using orthogonal projection
+ args:
+ points: [B, 3, N] 3d points in world coordinates
+ calib: [B, 3, 4] projection matrix
+ transform: [B, 2, 3] screen space transformation
+ return:
+ [B, 3, N] 3d coordinates in screen space
+ """
+ rot = calib[:, :3, :3]
+ trans = calib[:, :3, 3:4]
+ pts = torch.baddbmm(trans, rot, points)
+ if transform is not None:
+ scale = transform[:2, :2]
+ shift = transform[:2, 2:3]
+ pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
+ return pts
+
+
+def perspective(points, calib, transform=None):
+ """
+ project points onto screen space using perspective projection
+ args:
+ points: [B, 3, N] 3d points in world coordinates
+ calib: [B, 3, 4] projection matrix
+        transform: [B, 2, 3] screen space transformation
+ return:
+ [B, 3, N] 3d coordinates in screen space
+ """
+ rot = calib[:, :3, :3]
+ trans = calib[:, :3, 3:4]
+ homo = torch.baddbmm(trans, rot, points)
+ xy = homo[:, :2, :] / homo[:, 2:3, :]
+ if transform is not None:
+ scale = transform[:2, :2]
+ shift = transform[:2, 2:3]
+ xy = torch.baddbmm(shift, scale, xy)
+
+ xyz = torch.cat([xy, homo[:, 2:3, :]], 1)
+ return xyz
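The two helpers above are typically chained: project sampled 3D points into the image plane, then bilinearly sample pixel-aligned features at the projected xy. A minimal sketch with an identity calibration:

    import torch

    feat = torch.randn(1, 256, 128, 128)       # B, C, H, W feature map
    points = torch.rand(1, 3, 5000) * 2 - 1    # B, 3, N points in [-1, 1]
    calib = torch.eye(4)[:3, :].unsqueeze(0)   # B, 3, 4 orthogonal projection matrix
    xyz = orthogonal(points, calib)            # B, 3, N screen-space coordinates
    samples = index(feat, xyz[:, :2, :])       # B, C, N pixel-aligned features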
diff --git a/modelscope/models/cv/human_reconstruction/models/human_segmenter.py b/modelscope/models/cv/human_reconstruction/models/human_segmenter.py
new file mode 100644
index 00000000..3f0261e7
--- /dev/null
+++ b/modelscope/models/cv/human_reconstruction/models/human_segmenter.py
@@ -0,0 +1,60 @@
+# The implementation is also open-sourced by the authors, and available at
+# https://www.modelscope.cn/models/damo/cv_unet_image-matting/summary
+import cv2
+import numpy as np
+import tensorflow as tf
+
+if tf.__version__ >= '2.0':
+ tf = tf.compat.v1
+
+
+class human_segmenter(object):
+
+ def __init__(self, model_path):
+ super(human_segmenter, self).__init__()
+ f = tf.gfile.FastGFile(model_path, 'rb')
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+ persisted_graph = tf.import_graph_def(graph_def, name='')
+
+ config = tf.ConfigProto()
+        config.gpu_options.per_process_gpu_memory_fraction = 0.3  # use at most 30% of GPU memory
+ self.sess = tf.InteractiveSession(graph=persisted_graph, config=config)
+
+ self.image_node = self.sess.graph.get_tensor_by_name('input_image:0')
+ self.output_node = self.sess.graph.get_tensor_by_name('output_png:0')
+ self.logits_node = self.sess.graph.get_tensor_by_name('if_person:0')
+ print('human_segmenter init done')
+
+ def image_preprocess(self, img):
+ if len(img.shape) == 2:
+ img = np.dstack((img, img, img))
+ elif img.shape[2] == 4:
+ img = img[:, :, :3]
+        img = img.astype(float)  # the np.float alias is removed in recent NumPy
+ return img
+
+ def run(self, img):
+ image_feed = self.image_preprocess(img)
+ output_img_value, logits_value = self.sess.run(
+ [self.output_node, self.logits_node],
+ feed_dict={self.image_node: image_feed})
+ mask = output_img_value[:, :, -1]
+ return mask
+
+ def get_human_bbox(self, mask):
+ print('dtype:{}, max:{},shape:{}'.format(mask.dtype, np.max(mask),
+ mask.shape))
+ ret, thresh = cv2.threshold(mask, 127, 255, 0)
+ contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE,
+ cv2.CHAIN_APPROX_SIMPLE)
+ if len(contours) == 0:
+ return None
+
+ contoursArea = [cv2.contourArea(c) for c in contours]
+ max_area_index = contoursArea.index(max(contoursArea))
+ bbox = cv2.boundingRect(contours[max_area_index])
+ return bbox
+
+ def release(self):
+ self.sess.close()
diff --git a/modelscope/models/cv/human_reconstruction/models/networks.py b/modelscope/models/cv/human_reconstruction/models/networks.py
new file mode 100644
index 00000000..266237b6
--- /dev/null
+++ b/modelscope/models/cv/human_reconstruction/models/networks.py
@@ -0,0 +1,366 @@
+# The implementation here is modified based on pix2pixHD, originally BSD License and publicly available at
+# https://github.com/NVIDIA/pix2pixHD
+import functools
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ m.weight.data.normal_(0.0, 0.02)
+ elif classname.find('BatchNorm2d') != -1:
+ m.weight.data.normal_(1.0, 0.02)
+ m.bias.data.fill_(0)
+
+
+def get_norm_layer(norm_type='instance'):
+ if norm_type == 'batch':
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
+ elif norm_type == 'instance':
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
+ else:
+ raise NotImplementedError('normalization layer [%s] is not found'
+ % norm_type)
+ return norm_layer
+
+
+def define_G(input_nc,
+ output_nc,
+ ngf,
+ netG,
+ n_downsample_global=3,
+ n_blocks_global=9,
+ n_local_enhancers=1,
+ n_blocks_local=3,
+ norm='instance',
+ gpu_ids=[],
+ last_op=nn.Tanh()):
+ norm_layer = get_norm_layer(norm_type=norm)
+ if netG == 'global':
+ netG = GlobalGenerator(
+ input_nc,
+ output_nc,
+ ngf,
+ n_downsample_global,
+ n_blocks_global,
+ norm_layer,
+ last_op=last_op)
+ elif netG == 'local':
+ netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global,
+ n_blocks_global, n_local_enhancers,
+ n_blocks_local, norm_layer)
+ elif netG == 'encoder':
+ netG = Encoder(input_nc, output_nc, ngf, n_downsample_global,
+ norm_layer)
+ else:
+        raise NotImplementedError('generator not implemented!')
+ if len(gpu_ids) > 0:
+ assert (torch.cuda.is_available())
+ netG.cuda(gpu_ids[0])
+ netG.apply(weights_init)
+ return netG
+
+
+def print_network(net):
+ if isinstance(net, list):
+ net = net[0]
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ print(net)
+ print('Total number of parameters: %d' % num_params)
+
+
+"""
+ Generator code
+"""
+
+
+class LocalEnhancer(nn.Module):
+
+ def __init__(self,
+ input_nc,
+ output_nc,
+ ngf=32,
+ n_downsample_global=3,
+ n_blocks_global=9,
+ n_local_enhancers=1,
+ n_blocks_local=3,
+ norm_layer=nn.BatchNorm2d,
+ padding_type='reflect'):
+ super(LocalEnhancer, self).__init__()
+ self.n_local_enhancers = n_local_enhancers
+
+ ngf_global = ngf * (2**n_local_enhancers)
+ model_global = GlobalGenerator(input_nc, output_nc, ngf_global,
+ n_downsample_global, n_blocks_global,
+ norm_layer).model
+ model_global = [model_global[i] for i in range(len(model_global) - 3)
+ ] # get rid of final convolution layers
+ self.model = nn.Sequential(*model_global)
+
+ for n in range(1, n_local_enhancers + 1):
+ ngf_global = ngf * (2**(n_local_enhancers - n))
+ model_downsample = [
+ nn.ReflectionPad2d(3),
+ nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
+ norm_layer(ngf_global),
+ nn.ReLU(True),
+ nn.Conv2d(
+ ngf_global,
+ ngf_global * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1),
+ norm_layer(ngf_global * 2),
+ nn.ReLU(True)
+ ]
+ model_upsample = []
+ for i in range(n_blocks_local):
+ model_upsample += [
+ ResnetBlock(
+ ngf_global * 2,
+ padding_type=padding_type,
+ norm_layer=norm_layer)
+ ]
+
+ model_upsample += [
+ nn.ConvTranspose2d(
+ ngf_global * 2,
+ ngf_global,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1),
+ norm_layer(ngf_global),
+ nn.ReLU(True)
+ ]
+
+ if n == n_local_enhancers:
+ model_upsample += [
+ nn.ReflectionPad2d(3),
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
+ nn.Tanh()
+ ]
+
+ setattr(self, 'model' + str(n) + '_1',
+ nn.Sequential(*model_downsample))
+ setattr(self, 'model' + str(n) + '_2',
+ nn.Sequential(*model_upsample))
+
+ self.downsample = nn.AvgPool2d(
+ 3, stride=2, padding=[1, 1], count_include_pad=False)
+
+ def forward(self, input):
+ input_downsampled = [input]
+ for i in range(self.n_local_enhancers):
+ input_downsampled.append(self.downsample(input_downsampled[-1]))
+
+ output_prev = self.model(input_downsampled[-1])
+ for n_local_enhancers in range(1, self.n_local_enhancers + 1):
+ model_downsample = getattr(self,
+ 'model' + str(n_local_enhancers) + '_1')
+ model_upsample = getattr(self,
+ 'model' + str(n_local_enhancers) + '_2')
+ input_i = input_downsampled[self.n_local_enhancers
+ - n_local_enhancers]
+ output_prev = model_upsample(
+ model_downsample(input_i) + output_prev)
+ return output_prev
+
+
+class GlobalGenerator(nn.Module):
+
+ def __init__(self,
+ input_nc,
+ output_nc,
+ ngf=64,
+ n_downsampling=3,
+ n_blocks=9,
+ norm_layer=nn.BatchNorm2d,
+ padding_type='reflect',
+ last_op=nn.Tanh()):
+ assert (n_blocks >= 0)
+ super(GlobalGenerator, self).__init__()
+ activation = nn.ReLU(True)
+
+ model = [
+ nn.ReflectionPad2d(3),
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
+ norm_layer(ngf), activation
+ ]
+ for i in range(n_downsampling):
+ mult = 2**i
+ model += [
+ nn.Conv2d(
+ ngf * mult,
+ ngf * mult * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1),
+ norm_layer(ngf * mult * 2), activation
+ ]
+
+ mult = 2**n_downsampling
+ for i in range(n_blocks):
+ model += [
+ ResnetBlock(
+ ngf * mult,
+ padding_type=padding_type,
+ activation=activation,
+ norm_layer=norm_layer)
+ ]
+
+ for i in range(n_downsampling):
+ mult = 2**(n_downsampling - i)
+ model += [
+ nn.ConvTranspose2d(
+ ngf * mult,
+ int(ngf * mult / 2),
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1),
+ norm_layer(int(ngf * mult / 2)), activation
+ ]
+ model += [
+ nn.ReflectionPad2d(3),
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)
+ ]
+ if last_op is not None:
+ model += [last_op]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input):
+ return self.model(input)
+
+
+"""
+ Define a resnet block
+"""
+
+
+class ResnetBlock(nn.Module):
+
+ def __init__(self,
+ dim,
+ padding_type,
+ norm_layer,
+ activation=nn.ReLU(True),
+ use_dropout=False):
+ super(ResnetBlock, self).__init__()
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer,
+ activation, use_dropout)
+
+ def build_conv_block(self, dim, padding_type, norm_layer, activation,
+ use_dropout):
+ conv_block = []
+ p = 0
+ if padding_type == 'reflect':
+ conv_block += [nn.ReflectionPad2d(1)]
+ elif padding_type == 'replicate':
+ conv_block += [nn.ReplicationPad2d(1)]
+ elif padding_type == 'zero':
+ p = 1
+ else:
+ raise NotImplementedError('padding [%s] is not implemented'
+ % padding_type)
+
+ conv_block += [
+ nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+ norm_layer(dim), activation
+ ]
+ if use_dropout:
+ conv_block += [nn.Dropout(0.5)]
+
+ p = 0
+ if padding_type == 'reflect':
+ conv_block += [nn.ReflectionPad2d(1)]
+ elif padding_type == 'replicate':
+ conv_block += [nn.ReplicationPad2d(1)]
+ elif padding_type == 'zero':
+ p = 1
+ else:
+ raise NotImplementedError('padding [%s] is not implemented'
+ % padding_type)
+ conv_block += [
+ nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+ norm_layer(dim)
+ ]
+
+ return nn.Sequential(*conv_block)
+
+ def forward(self, x):
+ out = x + self.conv_block(x)
+ return out
+
+
+class Encoder(nn.Module):
+
+ def __init__(self,
+ input_nc,
+ output_nc,
+ ngf=32,
+ n_downsampling=4,
+ norm_layer=nn.BatchNorm2d):
+ super(Encoder, self).__init__()
+ self.output_nc = output_nc
+
+ model = [
+ nn.ReflectionPad2d(3),
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
+ norm_layer(ngf),
+ nn.ReLU(True)
+ ]
+ for i in range(n_downsampling):
+ mult = 2**i
+ model += [
+ nn.Conv2d(
+ ngf * mult,
+ ngf * mult * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1),
+ norm_layer(ngf * mult * 2),
+ nn.ReLU(True)
+ ]
+
+ for i in range(n_downsampling):
+ mult = 2**(n_downsampling - i)
+ model += [
+ nn.ConvTranspose2d(
+ ngf * mult,
+ int(ngf * mult / 2),
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1),
+ norm_layer(int(ngf * mult / 2)),
+ nn.ReLU(True)
+ ]
+
+ model += [
+ nn.ReflectionPad2d(3),
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
+ nn.Tanh()
+ ]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input, inst):
+ outputs = self.model(input)
+
+ outputs_mean = outputs.clone()
+ inst_list = np.unique(inst.cpu().numpy().astype(int))
+ for i in inst_list:
+ for b in range(input.size()[0]):
+ indices = (inst[b:b + 1] == int(i)).nonzero()
+ for j in range(self.output_nc):
+ output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j,
+ indices[:, 2], indices[:, 3]]
+ mean_feat = torch.mean(output_ins).expand_as(output_ins)
+ outputs_mean[indices[:, 0] + b, indices[:, 1] + j,
+ indices[:, 2], indices[:, 3]] = mean_feat
+ return outputs_mean
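A sketch of how the generator factory above is used for normal-map prediction; the arguments mirror the define_G call in Reconstruction.py, and the input size is illustrative:

    import torch

    netF = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance')  # pix2pixHD global generator
    img = torch.randn(1, 3, 512, 512)   # normalized RGB input
    normal = netF(img)                  # 1, 3, 512, 512, Tanh output in [-1, 1]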
diff --git a/modelscope/models/cv/human_reconstruction/utils.py b/modelscope/models/cv/human_reconstruction/utils.py
new file mode 100644
index 00000000..45653dc6
--- /dev/null
+++ b/modelscope/models/cv/human_reconstruction/utils.py
@@ -0,0 +1,178 @@
+import os
+
+import mcubes
+import numpy as np
+import torch
+
+
+def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
+ file = open(mesh_path, 'w')
+ for idx, v in enumerate(verts):
+ c = colors[idx]
+ file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' %
+ (v[0], v[1], v[2], c[0], c[1], c[2]))
+ for f in faces:
+ f_plus = f + 1
+ file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
+ file.close()
+
+
+def save_obj_mesh(mesh_path, verts, faces):
+ file = open(mesh_path, 'w')
+ for idx, v in enumerate(verts):
+ file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
+ for f in faces:
+ f_plus = f + 1
+ file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
+ file.close()
+
+
+def to_tensor(img):
+ if len(img.shape) == 2:
+ img = img[:, :, np.newaxis]
+ img = torch.from_numpy(img.transpose(2, 0, 1)).float()
+ img = img / 255.
+ return img
+
+
+def reconstruction(net, calib_tensor, coords, mat, num_samples=50000):
+
+ def eval_func(points):
+ points = np.expand_dims(points, axis=0)
+ points = np.repeat(points, 1, axis=0)
+ samples = torch.from_numpy(points).cuda().float()
+ net.query(samples, calib_tensor)
+ pred = net.get_preds()
+ pred = pred[0]
+ return pred.detach().cpu().numpy()
+
+ sdf = eval_grid(coords, eval_func, num_samples=num_samples)
+ vertices, faces = mcubes.marching_cubes(sdf, 0.5)
+ verts = np.matmul(mat[:3, :3], vertices.T) + mat[:3, 3:4]
+ verts = verts.T
+ return verts, faces
+
+
+def keep_largest(mesh_big):
+ mesh_lst = mesh_big.split(only_watertight=False)
+ keep_mesh = mesh_lst[0]
+ for mesh in mesh_lst:
+ if mesh.vertices.shape[0] > keep_mesh.vertices.shape[0]:
+ keep_mesh = mesh
+ return keep_mesh
+
+
+def eval_grid(coords,
+ eval_func,
+ init_resolution=64,
+ threshold=0.01,
+ num_samples=512 * 512 * 512):
+ resolution = coords.shape[1:4]
+ sdf = np.zeros(resolution)
+    dirty = np.ones(resolution, dtype=bool)
+    grid_mask = np.zeros(resolution, dtype=bool)
+ reso = resolution[0] // init_resolution
+
+ while reso > 0:
+ grid_mask[0:resolution[0]:reso, 0:resolution[1]:reso,
+ 0:resolution[2]:reso] = True
+ test_mask = np.logical_and(grid_mask, dirty)
+ points = coords[:, test_mask]
+
+ sdf[test_mask] = batch_eval(points, eval_func, num_samples=num_samples)
+ dirty[test_mask] = False
+
+ if reso <= 1:
+ break
+ for x in range(0, resolution[0] - reso, reso):
+ for y in range(0, resolution[1] - reso, reso):
+ for z in range(0, resolution[2] - reso, reso):
+ if not dirty[x + reso // 2, y + reso // 2, z + reso // 2]:
+ continue
+ v0 = sdf[x, y, z]
+ v1 = sdf[x, y, z + reso]
+ v2 = sdf[x, y + reso, z]
+ v3 = sdf[x, y + reso, z + reso]
+ v4 = sdf[x + reso, y, z]
+ v5 = sdf[x + reso, y, z + reso]
+ v6 = sdf[x + reso, y + reso, z]
+ v7 = sdf[x + reso, y + reso, z + reso]
+ v = np.array([v0, v1, v2, v3, v4, v5, v6, v7])
+ v_min = v.min()
+ v_max = v.max()
+ if (v_max - v_min) < threshold:
+ sdf[x:x + reso, y:y + reso,
+ z:z + reso] = (v_max + v_min) / 2
+ dirty[x:x + reso, y:y + reso, z:z + reso] = False
+ reso //= 2
+
+ return sdf.reshape(resolution)
+
+
+def batch_eval(points, eval_func, num_samples=512 * 512 * 512):
+ num_pts = points.shape[1]
+ sdf = np.zeros(num_pts)
+
+ num_batches = num_pts // num_samples
+ for i in range(num_batches):
+ sdf[i * num_samples:i * num_samples + num_samples] = eval_func(
+ points[:, i * num_samples:i * num_samples + num_samples])
+ if num_pts % num_samples:
+ sdf[num_batches * num_samples:] = eval_func(points[:, num_batches
+ * num_samples:])
+ return sdf
+
+
+def create_grid(res,
+ b_min=np.array([0, 0, 0]),
+ b_max=np.array([1, 1, 1]),
+ transform=None):
+ coords = np.mgrid[:res, :res, :res]
+
+ coords = coords.reshape(3, -1)
+ coords_matrix = np.eye(4)
+ length = b_max - b_min
+
+ coords_matrix[0, 0] = length[0] / res
+ coords_matrix[1, 1] = length[1] / res
+ coords_matrix[2, 2] = length[2] / res
+ coords_matrix[0:3, 3] = b_min
+
+ coords = np.matmul(coords_matrix[:3, :3], coords) + coords_matrix[:3, 3:4]
+ if transform is not None:
+ coords = np.matmul(transform[:3, :3], coords) + transform[:3, 3:4]
+ coords_matrix = np.matmul(transform, coords_matrix)
+ coords = coords.reshape(3, res, res, res)
+ return coords, coords_matrix
+
+
+def get_submesh(verts,
+ faces,
+ color,
+ verts_retained=None,
+ faces_retained=None,
+ min_vert_in_face=2):
+    colors = color
+ if verts_retained is not None:
+ if verts_retained.dtype != 'bool':
+ vert_mask = np.zeros(len(verts), dtype=bool)
+ vert_mask[verts_retained] = True
+ else:
+ vert_mask = verts_retained
+ bool_faces = np.sum(
+ vert_mask[faces.ravel()].reshape(-1, 3), axis=1) > min_vert_in_face
+ elif faces_retained is not None:
+ if faces_retained.dtype != 'bool':
+ bool_faces = np.zeros(len(faces_retained), dtype=bool)
+ else:
+ bool_faces = faces_retained
+ new_faces = faces[bool_faces]
+ vertex_ids = list(set(new_faces.ravel()))
+ oldtonew = -1 * np.ones([len(verts)])
+ oldtonew[vertex_ids] = range(0, len(vertex_ids))
+ new_verts = verts[vertex_ids]
+ new_colors = colors[vertex_ids]
+ new_faces = oldtonew[new_faces].astype('int32')
+ return (new_verts, new_faces, new_colors, bool_faces, vertex_ids)
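create_grid above builds the dense sampling lattice that reconstruction() evaluates and runs marching cubes on; a quick check of its outputs at a small resolution:

    import numpy as np

    coords, mat = create_grid(128, b_min=np.array([-1, -1, -1]),
                              b_max=np.array([1, 1, 1]))
    print(coords.shape)      # (3, 128, 128, 128) world-space sample coordinates
    print(np.diag(mat)[:3])  # per-axis voxel size: 2 / 128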
diff --git a/modelscope/models/cv/image_colorization/ddcolor/ddcolor_for_image_colorization.py b/modelscope/models/cv/image_colorization/ddcolor/ddcolor_for_image_colorization.py
index 0d2acbd2..e4479f13 100644
--- a/modelscope/models/cv/image_colorization/ddcolor/ddcolor_for_image_colorization.py
+++ b/modelscope/models/cv/image_colorization/ddcolor/ddcolor_for_image_colorization.py
@@ -1,21 +1,87 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
+from copy import deepcopy
from typing import Dict, Union
+import numpy as np
+import torch
+
from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .ddcolor import DDColor
+from .loss import L1Loss
logger = get_logger()
__all__ = ['DDColorForImageColorization']
+def tensor_lab2rgb(labs, illuminant='D65', observer='2'):
+ """
+ Args:
+ lab : (B, C, H, W)
+ Returns:
+ tuple : (C, H, W)
+ """
+ illuminants = \
+ {'A': {'2': (1.098466069456375, 1, 0.3558228003436005),
+ '10': (1.111420406956693, 1, 0.3519978321919493)},
+ 'D50': {'2': (0.9642119944211994, 1, 0.8251882845188288),
+ '10': (0.9672062750333777, 1, 0.8142801513128616)},
+ 'D55': {'2': (0.956797052643698, 1, 0.9214805860173273),
+ '10': (0.9579665682254781, 1, 0.9092525159847462)},
+ 'D65': {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white`
+ '10': (0.94809667673716, 1, 1.0730513595166162)},
+ 'D75': {'2': (0.9497220898840717, 1, 1.226393520724154),
+ '10': (0.9441713925645873, 1, 1.2064272211720228)},
+ 'E': {'2': (1.0, 1.0, 1.0),
+ '10': (1.0, 1.0, 1.0)}}
+ rgb_from_xyz = np.array([[3.240481340, -0.96925495, 0.055646640],
+ [-1.53715152, 1.875990000, -0.20404134],
+ [-0.49853633, 0.041555930, 1.057311070]])
+ B, C, H, W = labs.shape
+ arrs = labs.permute(
+ (0, 2, 3, 1)).contiguous() # (B, 3, H, W) -> (B, H, W, 3)
+ L, a, b = arrs[:, :, :, 0:1], arrs[:, :, :, 1:2], arrs[:, :, :, 2:]
+ y = (L + 16.) / 116.
+ x = (a / 500.) + y
+ z = y - (b / 200.)
+ invalid = z.data < 0
+ z[invalid] = 0
+ xyz = torch.cat([x, y, z], dim=3)
+ mask = xyz.data > 0.2068966
+ mask_xyz = xyz.clone()
+ mask_xyz[mask] = torch.pow(xyz[mask], 3.0)
+ mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.) / 7.787
+ xyz_ref_white = illuminants[illuminant][observer]
+ for i in range(C):
+ mask_xyz[:, :, :, i] = mask_xyz[:, :, :, i] * xyz_ref_white[i]
+
+ rgb_trans = torch.mm(
+ mask_xyz.view(-1, 3),
+ torch.from_numpy(rgb_from_xyz).type_as(xyz)).view(B, H, W, C)
+ rgb = rgb_trans.permute((0, 3, 1, 2)).contiguous()
+ mask = rgb.data > 0.0031308
+ mask_rgb = rgb.clone()
+ mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055
+ mask_rgb[~mask] = rgb[~mask] * 12.92
+ neg_mask = mask_rgb.data < 0
+ large_mask = mask_rgb.data > 1
+ mask_rgb[neg_mask] = 0
+ mask_rgb[large_mask] = 1
+ return mask_rgb
+
+
@MODELS.register_module(Tasks.image_colorization, module_name=Models.ddcolor)
class DDColorForImageColorization(TorchModel):
+ """DDColor model for Image Colorization:
+ Colorize an image using unet with dual decoders,
+ while the image decoder restores the spatial resolution,
+ and the color decoder learn adaptive color queries.
+ """
def __init__(self,
model_dir,
@@ -38,6 +104,56 @@ class DDColorForImageColorization(TorchModel):
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
self.model = self._load_pretrained(self.model, model_path)
+ self.loss = L1Loss(loss_weight=0.1)
+
+ def _load_pretrained(self,
+ net,
+ load_path,
+ strict=True,
+ param_key='params'):
+ load_net = torch.load(
+ load_path, map_location=lambda storage, loc: storage)
+ if param_key is not None:
+ if param_key not in load_net and 'params' in load_net:
+ param_key = 'params'
+ logger.info(
+ f'Loading: {param_key} does not exist, use params.')
+ if param_key in load_net:
+ load_net = load_net[param_key]
+ logger.info(
+ f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].'
+ )
+ # remove unnecessary 'module.' or 'model.'
+ for k, v in deepcopy(load_net).items():
+ if k.startswith('module.'):
+ load_net[k[7:]] = v
+ load_net.pop(k)
+ elif k.startswith('model.'):
+ load_net[k[6:]] = v
+ load_net.pop(k)
+ net.load_state_dict(load_net, strict=strict)
+ logger.info('load model done.')
+ return net
+
+ def _train_forward(self, input: Tensor,
+ target: Tensor) -> Dict[str, Tensor]:
+ preds = self.model(input)
+ return {'loss': self.loss(preds, target)}
+
+ def _evaluate_postprocess(self, input: Tensor, target: Tensor,
+ img_l: Tensor,
+ gt_rgb: Tensor) -> Dict[str, list]:
+ preds = self.model(input) # (n, 2, h, w)
+
+ preds_lab = torch.cat((img_l, preds), 1) # (n, 3, h, w)
+ preds_rgb = tensor_lab2rgb(preds_lab)
+
+ # preds = list(torch.split(preds_rgb, 1, 0))
+ # targets = list(torch.split(gt_rgb, 1, 0))
+ preds = preds_rgb
+ targets = gt_rgb
+
+ return {'preds': preds, 'targets': targets}
def forward(self, input: Dict[str,
Tensor]) -> Dict[str, Union[list, Tensor]]:
@@ -49,4 +165,9 @@ class DDColorForImageColorization(TorchModel):
Returns:
Dict[str, Tensor]: results
"""
- return self.model(**input)
+ if self.training:
+ return self._train_forward(**input)
+ elif 'target' in input:
+ return self._evaluate_postprocess(**input)
+ else:
+ return self.model(**input)
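A quick sanity check for tensor_lab2rgb defined in this file: a constant mid-gray Lab image maps to a constant gray RGB image with values in [0, 1]:

    import torch

    lab = torch.zeros(2, 3, 64, 64)   # B, 3, H, W: L in [0, 100], a/b around 0
    lab[:, 0] = 50.0                  # mid-gray lightness, a = b = 0
    rgb = tensor_lab2rgb(lab)         # B, 3, H, W RGB, clamped to [0, 1]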
diff --git a/modelscope/models/cv/image_colorization/ddcolor/loss.py b/modelscope/models/cv/image_colorization/ddcolor/loss.py
new file mode 100644
index 00000000..db2be54e
--- /dev/null
+++ b/modelscope/models/cv/image_colorization/ddcolor/loss.py
@@ -0,0 +1,270 @@
+# The implementation here is modified based on BasicSR, originally Apache 2.0 license and publicly available at
+# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/losses/basic_loss.py
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from .utils.vgg import VGGFeatureExtractor
+
+
+def l1_loss(pred, target, reduction):
+ return F.l1_loss(pred, target, reduction=reduction)
+
+
+class L1Loss(nn.Module):
+ """L1 (mean absolute error, MAE) loss.
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean'):
+ super(L1Loss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(
+                f'Unsupported reduction mode: {reduction}. Supported ones are: none | mean | sum'
+ )
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * l1_loss(
+ pred, target, reduction=self.reduction)
+
+
+class PerceptualLoss(nn.Module):
+ """Perceptual loss with commonly used style loss.
+
+ Args:
+ layer_weights (dict): The weight for each layer of vgg feature.
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
+ feature layer (before relu5_4) will be extracted with weight
+ 1.0 in calculating losses.
+ vgg_type (str): The type of vgg network used as feature extractor.
+ Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image in vgg.
+ Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
+            loss will be calculated and multiplied by this weight.
+            Default: 1.0.
+        style_weight (float): If `style_weight > 0`, the style loss will be
+            calculated and multiplied by this weight.
+            Default: 0.
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
+ """
+
+ def __init__(self,
+ layer_weights,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ perceptual_weight=1.0,
+ style_weight=0.,
+ criterion='l1'):
+ super(PerceptualLoss, self).__init__()
+ self.perceptual_weight = perceptual_weight
+ self.style_weight = style_weight
+ self.layer_weights = layer_weights
+ self.vgg = VGGFeatureExtractor(
+ layer_name_list=list(layer_weights.keys()),
+ vgg_type=vgg_type,
+ use_input_norm=use_input_norm,
+ range_norm=range_norm)
+
+ self.criterion_type = criterion
+ if self.criterion_type == 'l1':
+ self.criterion = torch.nn.L1Loss()
+ elif self.criterion_type == 'l2':
+            self.criterion = torch.nn.MSELoss()
+ elif self.criterion_type == 'fro':
+ self.criterion = None
+ else:
+ raise NotImplementedError(
+ f'{criterion} criterion has not been supported.')
+
+ def forward(self, x, gt):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ # extract vgg features
+ x_features = self.vgg(x)
+ gt_features = self.vgg(gt.detach())
+
+ # calculate perceptual loss
+ if self.perceptual_weight > 0:
+ percep_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ percep_loss += torch.norm(
+ x_features[k] - gt_features[k],
+ p='fro') * self.layer_weights[k]
+ else:
+ percep_loss += self.criterion(
+ x_features[k], gt_features[k]) * self.layer_weights[k]
+ percep_loss *= self.perceptual_weight
+ else:
+ percep_loss = None
+
+ # calculate style loss
+ if self.style_weight > 0:
+ style_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ style_loss += torch.norm(
+ self._gram_mat(x_features[k])
+ - self._gram_mat(gt_features[k]),
+ p='fro') * self.layer_weights[k]
+ else:
+ style_loss += self.criterion(
+ self._gram_mat(x_features[k]),
+ self._gram_mat(gt_features[k])) * self.layer_weights[k]
+ style_loss *= self.style_weight
+ else:
+ style_loss = None
+
+ return percep_loss, style_loss
+
+ def _gram_mat(self, x):
+ """Calculate Gram matrix.
+
+ Args:
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+ Returns:
+ torch.Tensor: Gram matrix.
+ """
+ n, c, h, w = x.size()
+ features = x.view(n, c, w * h)
+ features_t = features.transpose(1, 2)
+ gram = features.bmm(features_t) / (c * h * w)
+ return gram
+
+
+class GANLoss(nn.Module):
+ """Define GAN loss.
+
+ Args:
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
+ real_label_val (float): The value for real label. Default: 1.0.
+ fake_label_val (float): The value for fake label. Default: 0.0.
+ loss_weight (float): Loss weight. Default: 1.0.
+ Note that loss_weight is only for generators; and it is always 1.0
+ for discriminators.
+ """
+
+ def __init__(self,
+ gan_type,
+ real_label_val=1.0,
+ fake_label_val=0.0,
+ loss_weight=1.0):
+ super(GANLoss, self).__init__()
+ self.gan_type = gan_type
+ self.loss_weight = loss_weight
+ self.real_label_val = real_label_val
+ self.fake_label_val = fake_label_val
+
+ if self.gan_type == 'vanilla':
+ self.loss = nn.BCEWithLogitsLoss()
+ elif self.gan_type == 'lsgan':
+ self.loss = nn.MSELoss()
+ elif self.gan_type == 'wgan':
+ self.loss = self._wgan_loss
+ elif self.gan_type == 'wgan_softplus':
+ self.loss = self._wgan_softplus_loss
+ elif self.gan_type == 'hinge':
+ self.loss = nn.ReLU()
+ else:
+ raise NotImplementedError(
+ f'GAN type {self.gan_type} is not implemented.')
+
+ def _wgan_loss(self, input, target):
+ """wgan loss.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return -input.mean() if target else input.mean()
+
+ def _wgan_softplus_loss(self, input, target):
+        """wgan loss with softplus. Softplus is a smooth approximation to the
+        ReLU function.
+
+ In StyleGAN2, it is called:
+ Logistic loss for discriminator;
+ Non-saturating loss for generator.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return F.softplus(-input).mean() if target else F.softplus(
+ input).mean()
+
+ def get_target_label(self, input, target_is_real):
+ """Get target label.
+
+ Args:
+ input (Tensor): Input tensor.
+ target_is_real (bool): Whether the target is real or fake.
+
+ Returns:
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
+ return Tensor.
+ """
+
+ if self.gan_type in ['wgan', 'wgan_softplus']:
+ return target_is_real
+ target_val = (
+ self.real_label_val if target_is_real else self.fake_label_val)
+ return input.new_ones(input.size()) * target_val
+
+ def forward(self, input, target_is_real, is_disc=False):
+ """
+ Args:
+ input (Tensor): The input for the loss module, i.e., the network
+ prediction.
+            target_is_real (bool): Whether the target is real or fake.
+            is_disc (bool): Whether the loss is computed for discriminators.
+ Default: False.
+
+ Returns:
+ Tensor: GAN loss value.
+ """
+ target_label = self.get_target_label(input, target_is_real)
+ if self.gan_type == 'hinge':
+ if is_disc: # for discriminators in hinge-gan
+ input = -input if target_is_real else input
+ loss = self.loss(1 + input).mean()
+ else: # for generators in hinge-gan
+ loss = -input.mean()
+ else: # other gan types
+ loss = self.loss(input, target_label)
+
+ # loss_weight is always 1.0 for discriminators
+ return loss if is_disc else loss * self.loss_weight
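+
+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch (illustrative only): the layer weights and tensor
+    # shapes below are arbitrary example values, not values required by the module.
+    # Instantiating PerceptualLoss downloads the torchvision VGG19 weights if the
+    # local pretrain path used by VGGFeatureExtractor is absent.
+    layer_weights = {'relu1_1': 0.1, 'relu2_1': 0.1, 'relu3_1': 1.0}
+    perceptual = PerceptualLoss(layer_weights=layer_weights, style_weight=1.0)
+    gan = GANLoss('vanilla')
+    pred = torch.rand(1, 3, 64, 64)
+    gt = torch.rand(1, 3, 64, 64)
+    percep_loss, style_loss = perceptual(pred, gt)
+    g_loss = gan(torch.randn(1, 1), target_is_real=True, is_disc=False)
+    print(percep_loss.item(), style_loss.item(), g_loss.item())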
diff --git a/modelscope/models/cv/image_colorization/ddcolor/utils/vgg.py b/modelscope/models/cv/image_colorization/ddcolor/utils/vgg.py
new file mode 100644
index 00000000..dc0125c2
--- /dev/null
+++ b/modelscope/models/cv/image_colorization/ddcolor/utils/vgg.py
@@ -0,0 +1,180 @@
+# The implementation here is modified based on BasicSR, originally Apache 2.0 license and publicly available at
+# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/archs/vgg_arch.py
+
+import os
+from collections import OrderedDict
+
+import torch
+from torch import nn as nn
+from torchvision.models import vgg as vgg
+
+VGG_PRETRAIN_PATH = {
+ 'vgg19': './pretrain/vgg19-dcbb9e9d.pth',
+ 'vgg16_bn': './pretrain/vgg16_bn-6c64b313.pth'
+}
+
+NAMES = {
+ 'vgg11': [
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1',
+ 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1',
+ 'conv5_2', 'relu5_2', 'pool5'
+ ],
+ 'vgg13': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
+ 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
+ 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
+ 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
+ ],
+ 'vgg16': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
+ 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
+ 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1',
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4',
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
+ 'pool5'
+ ],
+ 'vgg19': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
+ 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
+ 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4',
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
+ 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4',
+ 'pool5'
+ ]
+}
+
+
+def insert_bn(names):
+ """Insert bn layer after each conv.
+
+ Args:
+ names (list): The list of layer names.
+
+ Returns:
+ list: The list of layer names with bn layers.
+ """
+ names_bn = []
+ for name in names:
+ names_bn.append(name)
+ if 'conv' in name:
+ position = name.replace('conv', '')
+ names_bn.append('bn' + position)
+ return names_bn
+
+
+class VGGFeatureExtractor(nn.Module):
+ """VGG network for feature extraction.
+
+    In this implementation, we allow users to choose whether to use normalization
+    of the input feature and which type of vgg network to use. Note that the
+    pretrained path must fit the vgg type.
+
+ Args:
+        layer_name_list (list[str]): Forward function returns the corresponding
+            features according to the layer_name_list.
+            Example: ['relu1_1', 'relu2_1', 'relu3_1'].
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+        use_input_norm (bool): If True, normalize the input image. Importantly,
+            the input feature must be in the range [0, 1]. Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+ requires_grad (bool): If true, the parameters of VGG network will be
+ optimized. Default: False.
+ remove_pooling (bool): If true, the max pooling operations in VGG net
+ will be removed. Default: False.
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
+ """
+
+ def __init__(self,
+ layer_name_list,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ requires_grad=False,
+ remove_pooling=False,
+ pooling_stride=2):
+ super(VGGFeatureExtractor, self).__init__()
+
+ self.layer_name_list = layer_name_list
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ self.names = NAMES[vgg_type.replace('_bn', '')]
+ if 'bn' in vgg_type:
+ self.names = insert_bn(self.names)
+
+ # only borrow layers that will be used to avoid unused params
+ max_idx = 0
+ for v in layer_name_list:
+ idx = self.names.index(v)
+ if idx > max_idx:
+ max_idx = idx
+
+ if os.path.exists(VGG_PRETRAIN_PATH[vgg_type]):
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
+ state_dict = torch.load(
+ VGG_PRETRAIN_PATH[vgg_type],
+ map_location=lambda storage, loc: storage)
+ vgg_net.load_state_dict(state_dict)
+ else:
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
+
+ features = vgg_net.features[:max_idx + 1]
+
+ modified_net = OrderedDict()
+ for k, v in zip(self.names, features):
+ if 'pool' in k:
+ # if remove_pooling is true, pooling operation will be removed
+ if remove_pooling:
+ continue
+ else:
+ # in some cases, we may want to change the default stride
+ modified_net[k] = nn.MaxPool2d(
+ kernel_size=2, stride=pooling_stride)
+ else:
+ modified_net[k] = v
+
+ self.vgg_net = nn.Sequential(modified_net)
+
+ if not requires_grad:
+ self.vgg_net.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+ else:
+ self.vgg_net.train()
+ for param in self.parameters():
+ param.requires_grad = True
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer(
+ 'mean',
+ torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer(
+ 'std',
+ torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ if self.range_norm:
+ x = (x + 1) / 2
+ if self.use_input_norm:
+ x = (x - self.mean) / self.std
+
+ output = {}
+ for key, layer in self.vgg_net._modules.items():
+ x = layer(x)
+ if key in self.layer_name_list:
+ output[key] = x.clone()
+
+ return output
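+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative only): extract a few relu features from a
+    # random image in [0, 1]; the layer names must exist in NAMES for the chosen vgg_type.
+    extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu2_1', 'relu3_1'])
+    feats = extractor(torch.rand(1, 3, 224, 224))
+    for name, feat in feats.items():
+        print(name, tuple(feat.shape))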
diff --git a/modelscope/models/cv/image_depth_estimation_bts/__init__.py b/modelscope/models/cv/image_depth_estimation_bts/__init__.py
new file mode 100644
index 00000000..29b18261
--- /dev/null
+++ b/modelscope/models/cv/image_depth_estimation_bts/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .depth_estimation_bts_model import DepthEstimationBtsModel
+
+else:
+ _import_structure = {
+ 'depth_estimation_bts_model': ['DepthEstimationBtsModel']
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/cv/image_depth_estimation_bts/depth_estimation_bts_model.py b/modelscope/models/cv/image_depth_estimation_bts/depth_estimation_bts_model.py
new file mode 100644
index 00000000..08e04220
--- /dev/null
+++ b/modelscope/models/cv/image_depth_estimation_bts/depth_estimation_bts_model.py
@@ -0,0 +1,69 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+from .networks.bts_model import BtsModel
+
+logger = get_logger()
+__all__ = ['DepthEstimationBtsModel']
+
+
+@MODELS.register_module(
+ Tasks.image_depth_estimation, module_name=Models.bts_depth_estimation)
+class DepthEstimationBtsModel(TorchModel):
+    """Depth estimation model bts, implementing the paper https://arxiv.org/pdf/1907.10326.pdf.
+    The network utilizes novel local planar guidance layers located at multiple stages in the
+    decoding phase. The bts model is composed of an encoder for dense feature extraction and a
+    decoder for predicting the desired depth.
+ """
+
+ def __init__(self, model_dir: str, **kwargs):
+ """initialize the bts model from the `model_dir` path.
+
+ Args:
+ model_dir (str): the model path.
+            focal: focal length of the camera that captured the input images; if not
+                given, a dataset-dependent default is used
+            dataset: used to set the focal value according to the dataset type; 'kitti'
+                and 'nyu' are supported
+ """
+ super().__init__(model_dir, **kwargs)
+ self.focal = 715.0873 # focal length, different dataset has different value
+ if 'focal' in kwargs:
+ self.focal = kwargs['focal']
+ elif 'dataset' in kwargs:
+ if kwargs['dataset'] == 'nyu':
+ self.focal = 518.8579
+ elif kwargs['dataset'] == 'kitti':
+ self.focal = 715.0873
+
+ self.model = BtsModel(focal=self.focal)
+
+ model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
+ checkpoint = torch.load(model_path)
+
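+        # checkpoints saved from a DataParallel/DistributedDataParallel model prefix
+        # parameter names with 'module.'; strip it before loading into the plain model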
+ state_dict = {}
+ for k in checkpoint['model_state_dict'].keys():
+ if k.startswith('module.'):
+ state_dict[k[7:]] = checkpoint['model_state_dict'][k]
+ else:
+ state_dict[k] = checkpoint['model_state_dict'][k]
+ self.model.load_state_dict(state_dict)
+ self.model.eval()
+
+ def forward(self, inputs):
+ return self.model(inputs['imgs'])
+
+ def postprocess(self, inputs):
+ results = {OutputKeys.DEPTHS: inputs}
+ return results
+
+ def inference(self, data):
+ results = self.forward(data)
+ return results
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/__init__.py b/modelscope/models/cv/image_depth_estimation_bts/networks/__init__.py
similarity index 100%
rename from modelscope/models/audio/tts/kantts/preprocess/script_convertor/__init__.py
rename to modelscope/models/cv/image_depth_estimation_bts/networks/__init__.py
diff --git a/modelscope/models/cv/image_depth_estimation_bts/networks/bts_model.py b/modelscope/models/cv/image_depth_estimation_bts/networks/bts_model.py
new file mode 100644
index 00000000..776f3074
--- /dev/null
+++ b/modelscope/models/cv/image_depth_estimation_bts/networks/bts_model.py
@@ -0,0 +1,42 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+import torch.nn as nn
+
+from .decoder import Decoder
+from .encoder import Encoder
+
+
+class BtsModel(nn.Module):
+    """Depth estimation model bts, implementing the paper https://arxiv.org/pdf/1907.10326.pdf.
+    The network utilizes novel local planar guidance layers located at multiple stages in the
+    decoding phase. The bts model is composed of an encoder for dense feature extraction and a
+    decoder for predicting the desired depth.
+ """
+
+ def __init__(self, focal=715.0873):
+        """initialize the bts model.
+
+ Args:
+            focal (float): focal length of the camera that captured the input images;
+                defaults to the KITTI value 715.0873 if not given
+ """
+ super(BtsModel, self).__init__()
+ self.focal = focal
+ self.encoder = Encoder()
+ self.decoder = Decoder()
+
+ def forward(self, x, focal=None):
+        """forward to estimate depth.
+
+ Args:
+ x (Tensor): input image data
+            focal (float): The focal length used when the picture was taken. Defaults to
+                the focal length set when the model was created.
+
+ Returns:
+ Tensor: Depth estimation image
+ """
+ focal_run = focal if focal else self.focal
+ skip_feat = self.encoder(x)
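+        # the focal scalar is moved to CUDA here, so this model currently assumes a GPU device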
+ depth = self.decoder(skip_feat, torch.tensor(focal_run).cuda())
+ return depth
diff --git a/modelscope/models/cv/image_depth_estimation_bts/networks/decoder.py b/modelscope/models/cv/image_depth_estimation_bts/networks/decoder.py
new file mode 100644
index 00000000..e9cf9fa7
--- /dev/null
+++ b/modelscope/models/cv/image_depth_estimation_bts/networks/decoder.py
@@ -0,0 +1,72 @@
+# The implementation is modified from ErenBalatkan/Bts-PyTorch
+# made publicly available under the MIT license
+# https://github.com/ErenBalatkan/Bts-PyTorch/blob/master/BTS.py
+
+import torch
+import torch.nn as nn
+
+from .utils import (MAX_DEPTH, ASSPBlock, LPGBlock, Reduction, UpscaleBlock,
+ UpscaleLayer, UpscaleNetwork, activation_fn)
+
+
+class Decoder(nn.Module):
+
+ def __init__(self, dataset='kitti'):
+ super(Decoder, self).__init__()
+ self.UpscaleNet = UpscaleNetwork()
+ self.DenseASSPNet = ASSPBlock()
+
+ self.upscale_block3 = UpscaleBlock(64, 96, 128) # H4
+ self.upscale_block4 = UpscaleBlock(128, 96, 128) # H2
+
+ self.LPGBlock8 = LPGBlock(8, 128)
+ self.LPGBlock4 = LPGBlock(4, 64) # 64 Filter
+ self.LPGBlock2 = LPGBlock(2, 64) # 64 Filter
+
+ self.upconv_h4 = UpscaleLayer(128, 64)
+ self.upconv_h2 = UpscaleLayer(64, 32) # 64 Filter
+ self.upconv_h = UpscaleLayer(64, 32) # 32 filter
+
+ self.conv_h4 = nn.Conv2d(161, 64, 3, 1, 1, bias=True) # 64 Filter
+ self.conv_h2 = nn.Conv2d(129, 64, 3, 1, 1, bias=True) # 64 Filter
+ self.conv_h1 = nn.Conv2d(36, 32, 3, 1, 1, bias=True)
+
+ self.reduction1x1 = Reduction(1, 32, True)
+
+ self.final_conv = nn.Conv2d(32, 1, 3, 1, 1, bias=True)
+
+ self.dataset = dataset
+
+ def forward(self, joint_input, focal):
+ (dense_features, dense_op_h2, dense_op_h4, dense_op_h8,
+ dense_op_h16) = joint_input
+ upscaled_out = self.UpscaleNet(joint_input)
+
+ dense_assp_out = self.DenseASSPNet(upscaled_out)
+
+ upconv_h4 = self.upconv_h4(dense_assp_out)
+ depth_8x8 = self.LPGBlock8(dense_assp_out) / MAX_DEPTH
+ depth_8x8_ds = nn.functional.interpolate(
+ depth_8x8, scale_factor=1 / 4, mode='nearest')
+ depth_concat_4x4 = torch.cat((depth_8x8_ds, dense_op_h4, upconv_h4), 1)
+
+ conv_h4 = activation_fn(self.conv_h4(depth_concat_4x4))
+ upconv_h2 = self.upconv_h2(conv_h4)
+ depth_4x4 = self.LPGBlock4(conv_h4) / MAX_DEPTH
+
+ depth_4x4_ds = nn.functional.interpolate(
+ depth_4x4, scale_factor=1 / 2, mode='nearest')
+ depth_concat_2x2 = torch.cat((depth_4x4_ds, dense_op_h2, upconv_h2), 1)
+
+ conv_h2 = activation_fn(self.conv_h2(depth_concat_2x2))
+ upconv_h = self.upconv_h(conv_h2)
+ depth_1x1 = self.reduction1x1(upconv_h)
+ depth_2x2 = self.LPGBlock2(conv_h2) / MAX_DEPTH
+
+ depth_concat = torch.cat(
+ (upconv_h, depth_1x1, depth_2x2, depth_4x4, depth_8x8), 1)
+ depth = activation_fn(self.conv_h1(depth_concat))
+ depth = self.final_conv(depth).sigmoid() * MAX_DEPTH + 0.1
+ if self.dataset == 'kitti':
+ depth *= focal.view(-1, 1, 1, 1) / 715.0873
+ return depth
diff --git a/modelscope/models/cv/image_depth_estimation_bts/networks/encoder.py b/modelscope/models/cv/image_depth_estimation_bts/networks/encoder.py
new file mode 100644
index 00000000..784064c2
--- /dev/null
+++ b/modelscope/models/cv/image_depth_estimation_bts/networks/encoder.py
@@ -0,0 +1,70 @@
+# The implementation is modified from ErenBalatkan/Bts-PyTorch
+# made publicly available under the MIT license
+# https://github.com/ErenBalatkan/Bts-PyTorch/blob/master/BTS.py
+
+import torch.nn as nn
+import torchvision.models as models
+
+
+class Encoder(nn.Module):
+
+ def __init__(self, pretrained=False):
+ super(Encoder, self).__init__()
+ self.dense_op_h2 = None
+ self.dense_op_h4 = None
+ self.dense_op_h8 = None
+ self.dense_op_h16 = None
+ self.dense_features = None
+
+ self.dense_feature_extractor = self.initial_feature_extractor(
+ pretrained)
+ self.freeze_batch_norm()
+ self.initialize_hooks()
+
+ def freeze_batch_norm(self):
+ for module in self.dense_feature_extractor.modules():
+ if isinstance(module, nn.modules.BatchNorm2d):
+ module.track_running_stats = True
+ module.eval()
+ module.affine = True
+ module.requires_grad = True
+
+ def initial_feature_extractor(self, pretrained=False):
+ dfe = models.densenet161(pretrained=pretrained)
+ dfe.features.denseblock1.requires_grad = False
+ dfe.features.denseblock2.requires_grad = False
+ dfe.features.conv0.requires_grad = False
+ return dfe
+
+ def set_h2(self, module, input_, output):
+ self.dense_op_h2 = output
+
+ def set_h4(self, module, input_, output):
+ self.dense_op_h4 = output
+
+ def set_h8(self, module, input_, output):
+ self.dense_op_h8 = output
+
+ def set_h16(self, module, input_, output):
+ self.dense_op_h16 = output
+
+ def set_dense_features(self, module, input_, output):
+ self.dense_features = output
+
+ def initialize_hooks(self):
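+        # register forward hooks that capture multi-scale DenseNet features:
+        # relu0 -> H/2, pool0 -> H/4, transition1 -> H/8, transition2 -> H/16,
+        # norm5 -> final dense features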
+ self.dense_feature_extractor.features.relu0.register_forward_hook(
+ self.set_h2)
+ self.dense_feature_extractor.features.pool0.register_forward_hook(
+ self.set_h4)
+ self.dense_feature_extractor.features.transition1.register_forward_hook(
+ self.set_h8)
+ self.dense_feature_extractor.features.transition2.register_forward_hook(
+ self.set_h16)
+ self.dense_feature_extractor.features.norm5.register_forward_hook(
+ self.set_dense_features)
+
+ def forward(self, x):
+ _ = self.dense_feature_extractor(x)
+ joint_input = (self.dense_features.relu(), self.dense_op_h2,
+ self.dense_op_h4, self.dense_op_h8, self.dense_op_h16)
+ return joint_input
diff --git a/modelscope/models/cv/image_depth_estimation_bts/networks/utils.py b/modelscope/models/cv/image_depth_estimation_bts/networks/utils.py
new file mode 100644
index 00000000..6566a511
--- /dev/null
+++ b/modelscope/models/cv/image_depth_estimation_bts/networks/utils.py
@@ -0,0 +1,246 @@
+# The implementation is modified from ErenBalatkan/Bts-PyTorch
+# made publicly available under the MIT license
+# https://github.com/ErenBalatkan/Bts-PyTorch/blob/master/BTS.py
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+activation_fn = nn.ELU()
+MAX_DEPTH = 81
+
+
+class UpscaleLayer(nn.Module):
+
+ def __init__(self, in_channels, out_channels):
+ super(UpscaleLayer, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels, out_channels, 3, padding=1, bias=True)
+ self.bn = nn.BatchNorm2d(out_channels, momentum=0.005)
+
+ def forward(self, input):
+ input = nn.functional.interpolate(
+ input, scale_factor=2, mode='nearest')
+ input = activation_fn(self.conv(input))
+ input = self.bn(input)
+ return input
+
+
+class UpscaleBlock(nn.Module):
+
+ def __init__(self, in_channels, skip_channels, out_channels):
+ super(UpscaleBlock, self).__init__()
+ self.uplayer = UpscaleLayer(in_channels, out_channels)
+ self.conv = nn.Conv2d(
+ out_channels + skip_channels,
+ out_channels,
+ 3,
+ padding=1,
+ bias=True)
+ self.bn2 = nn.BatchNorm2d(out_channels, 0.005)
+
+ def forward(self, input_j):
+ input, skip = input_j
+ input = self.uplayer(input)
+ cat = torch.cat((input, skip), 1)
+ input = activation_fn(self.conv(cat))
+ input = self.bn2(input)
+ return input, cat
+
+
+class UpscaleNetwork(nn.Module):
+
+ def __init__(self, filters=[512, 256]):
+        super(UpscaleNetwork, self).__init__()
+ self.upscale_block1 = UpscaleBlock(2208, 384, filters[0]) # H16
+ self.upscale_block2 = UpscaleBlock(filters[0], 192, filters[1]) # H8
+
+ def forward(self, raw_input):
+ input, h2, h4, h8, h16 = raw_input
+ input, _ = self.upscale_block1((input, h16))
+ input, cat = self.upscale_block2((input, h8))
+ return input, cat
+
+
+class AtrousBlock(nn.Module):
+
+ def __init__(self,
+ input_filters,
+ filters,
+ dilation,
+ apply_initial_bn=True):
+ super(AtrousBlock, self).__init__()
+
+ self.initial_bn = nn.BatchNorm2d(input_filters, 0.005)
+ self.apply_initial_bn = apply_initial_bn
+
+ self.conv1 = nn.Conv2d(input_filters, filters * 2, 1, 1, 0, bias=False)
+ self.norm1 = nn.BatchNorm2d(filters * 2, 0.005)
+
+ self.atrous_conv = nn.Conv2d(
+ filters * 2, filters, 3, 1, dilation, dilation, bias=False)
+ self.norm2 = nn.BatchNorm2d(filters, 0.005)
+
+ def forward(self, input):
+ if self.apply_initial_bn:
+ input = self.initial_bn(input)
+
+ input = self.conv1(input.relu())
+ input = self.norm1(input)
+ input = self.atrous_conv(input.relu())
+ input = self.norm2(input)
+ return input
+
+
+class ASSPBlock(nn.Module):
+
+ def __init__(self, input_filters=256, cat_filters=448, atrous_filters=128):
+ super(ASSPBlock, self).__init__()
+
+ self.atrous_conv_r3 = AtrousBlock(
+ input_filters, atrous_filters, 3, apply_initial_bn=False)
+ self.atrous_conv_r6 = AtrousBlock(cat_filters + atrous_filters,
+ atrous_filters, 6)
+ self.atrous_conv_r12 = AtrousBlock(cat_filters + atrous_filters * 2,
+ atrous_filters, 12)
+ self.atrous_conv_r18 = AtrousBlock(cat_filters + atrous_filters * 3,
+ atrous_filters, 18)
+ self.atrous_conv_r24 = AtrousBlock(cat_filters + atrous_filters * 4,
+ atrous_filters, 24)
+
+ self.conv = nn.Conv2d(
+ 5 * atrous_filters + cat_filters,
+ atrous_filters,
+ 3,
+ 1,
+ 1,
+ bias=True)
+
+ def forward(self, input):
+ input, cat = input
+ layer1_out = self.atrous_conv_r3(input)
+ concat1 = torch.cat((cat, layer1_out), 1)
+
+ layer2_out = self.atrous_conv_r6(concat1)
+ concat2 = torch.cat((concat1, layer2_out), 1)
+
+ layer3_out = self.atrous_conv_r12(concat2)
+ concat3 = torch.cat((concat2, layer3_out), 1)
+
+ layer4_out = self.atrous_conv_r18(concat3)
+ concat4 = torch.cat((concat3, layer4_out), 1)
+
+ layer5_out = self.atrous_conv_r24(concat4)
+ concat5 = torch.cat((concat4, layer5_out), 1)
+
+ features = activation_fn(self.conv(concat5))
+ return features
+
+
+class Reduction(nn.Module):
+
+ def __init__(self, scale, input_filters, is_final=False):
+ super(Reduction, self).__init__()
+ reduction_count = int(math.log(input_filters, 2)) - 2
+ self.reductions = torch.nn.Sequential()
+ for i in range(reduction_count):
+ if i != reduction_count - 1:
+ self.reductions.add_module(
+ '1x1_reduc_%d_%d' % (scale, i),
+ nn.Sequential(
+ nn.Conv2d(
+ int(input_filters / math.pow(2, i)),
+ int(input_filters / math.pow(2, i + 1)),
+ 1,
+ 1,
+ 0,
+ bias=True), activation_fn))
+ else:
+ if not is_final:
+ self.reductions.add_module(
+ '1x1_reduc_%d_%d' % (scale, i),
+ nn.Sequential(
+ nn.Conv2d(
+ int(input_filters / math.pow(2, i)),
+ int(input_filters / math.pow(2, i + 1)),
+ 1,
+ 1,
+ 0,
+ bias=True)))
+ else:
+ self.reductions.add_module(
+ '1x1_reduc_%d_%d' % (scale, i),
+ nn.Sequential(
+ nn.Conv2d(
+ int(input_filters / math.pow(2, i)),
+ 1,
+ 1,
+ 1,
+ 0,
+ bias=True), nn.Sigmoid()))
+
+ def forward(self, ip):
+ return self.reductions(ip)
+
+
+class LPGBlock(nn.Module):
+
+ def __init__(self, scale, input_filters=128):
+ super(LPGBlock, self).__init__()
+ self.scale = scale
+
+ self.reduction = Reduction(scale, input_filters)
+ self.conv = nn.Conv2d(4, 3, 1, 1, 0)
+
+ self.u = torch.arange(self.scale).reshape([1, 1, self.scale]).float()
+ self.v = torch.arange(int(self.scale)).reshape([1, self.scale,
+ 1]).float()
+
+ def forward(self, input):
+ input = self.reduction(input)
+
+ plane_parameters = torch.zeros_like(input)
+ input = self.conv(input)
+
+ theta = input[:, 0, :, :].sigmoid() * 3.1415926535 / 6
+ phi = input[:, 1, :, :].sigmoid() * 3.1415926535 * 2
+ dist = input[:, 2, :, :].sigmoid() * MAX_DEPTH
+
+ plane_parameters[:, 0, :, :] = torch.sin(theta) * torch.cos(phi)
+ plane_parameters[:, 1, :, :] = torch.sin(theta) * torch.sin(phi)
+ plane_parameters[:, 2, :, :] = torch.cos(theta)
+ plane_parameters[:, 3, :, :] = dist
+
+ plane_parameters[:, 0:3, :, :] = F.normalize(
+ plane_parameters.clone()[:, 0:3, :, :], 2, 1)
+
+ plane_eq = plane_parameters.float()
+
+ plane_eq_expanded = torch.repeat_interleave(plane_eq, int(self.scale),
+ 2)
+ plane_eq_expanded = torch.repeat_interleave(plane_eq_expanded,
+ int(self.scale), 3)
+
+ n1 = plane_eq_expanded[:, 0, :, :]
+ n2 = plane_eq_expanded[:, 1, :, :]
+ n3 = plane_eq_expanded[:, 2, :, :]
+ n4 = plane_eq_expanded[:, 3, :, :]
+
+ u = self.u.repeat(
+ plane_eq.size(0),
+ plane_eq.size(2) * int(self.scale), plane_eq.size(3)).cuda()
+ u = (u - (self.scale - 1) * 0.5) / self.scale
+
+ v = self.v.repeat(
+ plane_eq.size(0), plane_eq.size(2),
+ plane_eq.size(3) * int(self.scale)).cuda()
+ v = (v - (self.scale - 1) * 0.5) / self.scale
+
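+        # local planar guidance: the estimated plane satisfies
+        # n1 * u + n2 * v + n3 = n4 / depth, so depth = n4 / (n1 * u + n2 * v + n3)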
+ depth = n4 / (n1 * u + n2 * v + n3)
+ depth = depth.unsqueeze(1)
+ return depth
diff --git a/modelscope/models/cv/image_quality_assessment_man/__init__.py b/modelscope/models/cv/image_quality_assessment_man/__init__.py
new file mode 100644
index 00000000..f29b90a7
--- /dev/null
+++ b/modelscope/models/cv/image_quality_assessment_man/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .image_quality_assessment_man import ImageQualityAssessmentMAN
+
+else:
+ _import_structure = {
+ 'image_quality_assessment_man': ['ImageQualityAssessmentMAN']
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/cv/image_quality_assessment_man/image_quality_assessment_man.py b/modelscope/models/cv/image_quality_assessment_man/image_quality_assessment_man.py
new file mode 100644
index 00000000..e290909e
--- /dev/null
+++ b/modelscope/models/cv/image_quality_assessment_man/image_quality_assessment_man.py
@@ -0,0 +1,79 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from typing import Any, Dict, Union
+
+import torch.cuda
+import torch.nn.functional as F
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Tensor
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.cv.image_quality_assessment_man.maniqa import MANIQA
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+__all__ = ['ImageQualityAssessmentMAN']
+
+
+@MODELS.register_module(
+ Tasks.image_quality_assessment_mos,
+ module_name=Models.image_quality_assessment_man)
+class ImageQualityAssessmentMAN(TorchModel):
+
+ def __init__(self, model_dir: str, *args, **kwargs):
+ """initialize the image_quality_assessment_man model from the `model_dir` path.
+
+ Args:
+ model_dir (str): the model path.
+
+ """
+ super().__init__(model_dir, *args, **kwargs)
+ self.model_dir = model_dir
+ self.config = Config.from_file(
+ os.path.join(self.model_dir, ModelFile.CONFIGURATION))
+ model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
+
+ self.model = MANIQA()
+ self.model = self._load_pretrained(self.model, model_path)
+ self.model.eval()
+
+ def _train_forward(self, input: Tensor,
+ target: Tensor) -> Dict[str, Tensor]:
+ losses = dict()
+ return losses
+
+ def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]:
+ return {'output': self.model(input).clamp(0, 1)}
+
+ def _evaluate_postprocess(self, input: Tensor,
+ target: Tensor) -> Dict[str, list]:
+
+ torch.cuda.empty_cache()
+ with torch.no_grad():
+ preds = self.model(input)
+ preds = preds.clamp(0, 1).cpu()
+ del input
+ target = target.cpu()
+ torch.cuda.empty_cache()
+ return {'pred': preds, 'target': target}
+
+ def forward(self, inputs: Dict[str,
+ Tensor]) -> Dict[str, Union[list, Tensor]]:
+ """return the result by the model
+
+ Args:
+            inputs (Dict[str, Tensor]): the preprocessed data
+
+ Returns:
+ Dict[str, Tensor]: results
+ """
+ if self.training:
+ return self._train_forward(**inputs)
+ elif 'target' in inputs:
+ return self._evaluate_postprocess(**inputs)
+ else:
+ return self._inference_forward(**inputs)
diff --git a/modelscope/models/cv/image_quality_assessment_man/maniqa.py b/modelscope/models/cv/image_quality_assessment_man/maniqa.py
new file mode 100644
index 00000000..8c924309
--- /dev/null
+++ b/modelscope/models/cv/image_quality_assessment_man/maniqa.py
@@ -0,0 +1,161 @@
+# This implementation is adopted from MANIQA, made publicly available under the Apache License 2.0 at
+# https://github.com/IIGROUP/MANIQA/blob/master/models/maniqa.py
+
+import timm
+import torch
+import torch.nn as nn
+from einops import rearrange
+from timm.models.vision_transformer import Block
+
+from .swin import SwinTransformer
+
+
+class TABlock(nn.Module):
+
+ def __init__(self, dim, drop=0.1):
+ super().__init__()
+ self.c_q = nn.Linear(dim, dim)
+ self.c_k = nn.Linear(dim, dim)
+ self.c_v = nn.Linear(dim, dim)
+ self.norm_fact = dim**-0.5
+ self.softmax = nn.Softmax(dim=-1)
+ self.proj_drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ _x = x
+ B, C, N = x.shape
+ q = self.c_q(x)
+ k = self.c_k(x)
+ v = self.c_v(x)
+
+ attn = q @ k.transpose(-2, -1) * self.norm_fact
+ attn = self.softmax(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, C, N)
+ x = self.proj_drop(x)
+ x = x + _x
+ return x
+
+
+class SaveOutput:
+
+ def __init__(self):
+ self.outputs = []
+
+ def __call__(self, module, module_in, module_out):
+ self.outputs.append(module_out)
+
+ def clear(self):
+ self.outputs = []
+
+
+class MANIQA(nn.Module):
+
+ def __init__(self,
+ embed_dim=768,
+ num_outputs=1,
+ patch_size=8,
+ drop=0.1,
+ depths=[2, 2],
+ window_size=4,
+ dim_mlp=768,
+ num_heads=[4, 4],
+ img_size=224,
+ num_tab=2,
+ scale=0.13,
+ **kwargs):
+ super().__init__()
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.input_size = img_size // patch_size
+ self.patches_resolution = (img_size // patch_size,
+ img_size // patch_size)
+
+ self.vit = timm.create_model('vit_base_patch8_224', pretrained=False)
+ self.save_output = SaveOutput()
+ hook_handles = []
+ for layer in self.vit.modules():
+ if isinstance(layer, Block):
+ handle = layer.register_forward_hook(self.save_output)
+ hook_handles.append(handle)
+
+ self.tablock1 = nn.ModuleList()
+ for i in range(num_tab):
+ tab = TABlock(self.input_size**2)
+ self.tablock1.append(tab)
+
+ self.conv1 = nn.Conv2d(embed_dim * 4, embed_dim, 1, 1, 0)
+ self.swintransformer1 = SwinTransformer(
+ patches_resolution=self.patches_resolution,
+ depths=depths,
+ num_heads=num_heads,
+ embed_dim=embed_dim,
+ window_size=window_size,
+ dim_mlp=dim_mlp,
+ scale=scale)
+
+ self.tablock2 = nn.ModuleList()
+ for i in range(num_tab):
+ tab = TABlock(self.input_size**2)
+ self.tablock2.append(tab)
+
+ self.conv2 = nn.Conv2d(embed_dim, embed_dim // 2, 1, 1, 0)
+ self.swintransformer2 = SwinTransformer(
+ patches_resolution=self.patches_resolution,
+ depths=depths,
+ num_heads=num_heads,
+ embed_dim=embed_dim // 2,
+ window_size=window_size,
+ dim_mlp=dim_mlp,
+ scale=scale)
+
+ self.fc_score = nn.Sequential(
+ nn.Linear(embed_dim // 2, embed_dim // 2), nn.ReLU(),
+ nn.Dropout(drop), nn.Linear(embed_dim // 2, num_outputs),
+ nn.ReLU())
+ self.fc_weight = nn.Sequential(
+ nn.Linear(embed_dim // 2, embed_dim // 2), nn.ReLU(),
+ nn.Dropout(drop), nn.Linear(embed_dim // 2, num_outputs),
+ nn.Sigmoid())
+
+ def extract_feature(self, save_output):
+ x6 = save_output.outputs[6][:, 1:]
+ x7 = save_output.outputs[7][:, 1:]
+ x8 = save_output.outputs[8][:, 1:]
+ x9 = save_output.outputs[9][:, 1:]
+ x = torch.cat((x6, x7, x8, x9), dim=2)
+ return x
+
+ def forward(self, x):
+ self.vit(x)
+ x = self.extract_feature(self.save_output)
+ self.save_output.outputs.clear()
+
+ # stage 1
+ x = rearrange(
+ x, 'b (h w) c -> b c (h w)', h=self.input_size, w=self.input_size)
+ for tab in self.tablock1:
+ x = tab(x)
+ x = rearrange(
+ x, 'b c (h w) -> b c h w', h=self.input_size, w=self.input_size)
+ x = self.conv1(x)
+ x = self.swintransformer1(x)
+
+ # stage2
+ x = rearrange(
+ x, 'b c h w -> b c (h w)', h=self.input_size, w=self.input_size)
+ for tab in self.tablock2:
+ x = tab(x)
+ x = rearrange(
+ x, 'b c (h w) -> b c h w', h=self.input_size, w=self.input_size)
+ x = self.conv2(x)
+ x = self.swintransformer2(x)
+
+ x = rearrange(
+ x, 'b c h w -> b (h w) c', h=self.input_size, w=self.input_size)
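+        # aggregate per-patch quality scores with learned attention weights:
+        # final score = sum(f * w) / sum(w)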
+ score = torch.tensor([]).cuda()
+ for i in range(x.shape[0]):
+ f = self.fc_score(x[i])
+ w = self.fc_weight(x[i])
+ _s = torch.sum(f * w) / torch.sum(w)
+ score = torch.cat((score, _s.unsqueeze(0)), 0)
+ return score
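+
+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch (illustrative only). The ViT-B/8 backbone expects
+    # 224x224 inputs, and forward() builds the score tensor with .cuda(), so this
+    # sketch assumes a CUDA device is available.
+    model = MANIQA().cuda().eval()
+    with torch.no_grad():
+        score = model(torch.rand(1, 3, 224, 224).cuda())
+    print(score)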
diff --git a/modelscope/models/cv/image_quality_assessment_man/swin.py b/modelscope/models/cv/image_quality_assessment_man/swin.py
new file mode 100644
index 00000000..df58277f
--- /dev/null
+++ b/modelscope/models/cv/image_quality_assessment_man/swin.py
@@ -0,0 +1,632 @@
+# This implementation is adopted from SwinTransformer, made publicly available under the MIT License at
+# https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
+
+import collections.abc
+import math
+import warnings
+from itertools import repeat
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from einops import rearrange
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ lower = norm_cdf((a - mean) / std)
+ upper = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ # type: (Tensor, float, float, float, float) -> Tensor
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+# From PyTorch internals
+def _ntuple(n):
+
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_2tuple = _ntuple(2)
+
+
+def drop_path(x,
+ drop_prob: float = 0.,
+ training: bool = False,
+ scale_by_keep: bool = True):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0], ) + (1, ) * (
+ x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+
+ def __init__(self, drop_prob=None, scale_by_keep=True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+
+class Mlp(nn.Module):
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
+ C)
+ windows = x.permute(0, 1, 3, 2, 4,
+ 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size,
+ window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both shifted and non-shifted windows.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self,
+ dim,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :,
+ None] - coords_flatten[:,
+ None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :,
+ 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer('relative_position_index',
+ relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[
+ 2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1],
+ self.window_size[0] * self.window_size[1],
+ -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
+ N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+
+class SwinBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ dim_mlp=1024.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.dim_mlp = dim_mlp
+ self.mlp_ratio = self.dim_mlp // dim
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, 'shift_size must be in 0-window_size'
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop)
+
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = self.dim_mlp
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(
+ img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(
+ -1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
+ float(-100.0)).masked_fill(
+ attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ self.register_buffer('attn_mask', attn_mask)
+
+ def forward(self, x):
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, 'input feature has wrong size'
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(
+ shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
+ C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(
+ x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size,
+ self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H,
+ W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(
+ shifted_x,
+ shifts=(self.shift_size, self.shift_size),
+ dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
+ f'window_size={self.window_size}, shift_size={self.shift_size}'
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size=7,
+ dim_mlp=1024,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False):
+
+ super().__init__()
+ self.dim = dim
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinBlock(
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ dim_mlp=dim_mlp,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i]
+ if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer) for i in range(depth)
+ ])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(
+ input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x = rearrange(
+ x,
+ 'b (h w) c -> b c h w',
+ h=self.input_resolution[0],
+ w=self.input_resolution[1])
+ x = F.relu(self.conv(x))
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+
+class SwinTransformer(nn.Module):
+
+ def __init__(self,
+ patches_resolution,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ embed_dim=256,
+ drop=0.1,
+ drop_rate=0.,
+ drop_path_rate=0.1,
+ dropout=0.,
+ window_size=7,
+ dim_mlp=1024,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ scale=0.8,
+ **kwargs):
+ super().__init__()
+ self.scale = scale
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.dropout = nn.Dropout(p=drop)
+ self.num_features = embed_dim
+ self.num_layers = len(depths)
+ self.patches_resolution = (patches_resolution[0],
+ patches_resolution[1])
+ self.downsample = nn.Conv2d(
+ self.embed_dim, self.embed_dim, kernel_size=3, stride=2, padding=1)
+ # stochastic depth
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ]
+
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=self.embed_dim,
+ input_resolution=patches_resolution,
+ depth=self.depths[i_layer],
+ num_heads=self.num_heads[i_layer],
+ window_size=self.window_size,
+ dim_mlp=dim_mlp,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=dropout,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(self.depths[:i_layer]
+ ):sum(self.depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=downsample,
+ use_checkpoint=use_checkpoint)
+ self.layers.append(layer)
+
+ def forward(self, x):
+ x = self.dropout(x)
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ for layer in self.layers:
+ _x = x
+ x = layer(x)
+ x = self.scale * x + _x
+ x = rearrange(
+ x,
+ 'b (h w) c -> b c h w',
+ h=self.patches_resolution[0],
+ w=self.patches_resolution[1])
+ return x
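+
+
+if __name__ == '__main__':
+    # Minimal sanity-check sketch (illustrative only): window_partition followed by
+    # window_reverse is an exact inverse, so the original tensor is recovered.
+    x = torch.rand(2, 8, 8, 3)  # (B, H, W, C)
+    windows = window_partition(x, window_size=4)  # (num_windows * B, 4, 4, C)
+    restored = window_reverse(windows, window_size=4, H=8, W=8)
+    assert torch.equal(x, restored)
+    print(windows.shape, restored.shape)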
diff --git a/modelscope/models/cv/ocr_detection/model.py b/modelscope/models/cv/ocr_detection/model.py
index fdb4f8a1..712973ce 100644
--- a/modelscope/models/cv/ocr_detection/model.py
+++ b/modelscope/models/cv/ocr_detection/model.py
@@ -46,7 +46,7 @@ class OCRDetection(TorchModel):
)
if model_path != '':
self.detector.load_state_dict(
- torch.load(model_path, map_location='cpu'))
+ torch.load(model_path, map_location='cpu'), strict=False)
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""
diff --git a/modelscope/models/cv/ocr_detection/modules/dbnet.py b/modelscope/models/cv/ocr_detection/modules/dbnet.py
index 33888324..82b0e512 100644
--- a/modelscope/models/cv/ocr_detection/modules/dbnet.py
+++ b/modelscope/models/cv/ocr_detection/modules/dbnet.py
@@ -2,10 +2,10 @@
# Part of implementation is adopted from ViLT,
# made publicly available under the Apache License 2.0 at https://github.com/dandelin/ViLT.
# ------------------------------------------------------------------------------
-
import math
import os
import sys
+from collections import OrderedDict
import torch
import torch.nn as nn
@@ -413,12 +413,48 @@ class SegDetector(nn.Module):
# this is the pred module, not binarization module;
# We do not correct the name due to the trained model.
binary = self.binarize(fuse)
- return binary
+ if self.training:
+ result = OrderedDict(binary=binary)
+ else:
+ return binary
+ if self.adaptive and self.training:
+ if self.serial:
+ fuse = torch.cat(
+ (fuse, nn.functional.interpolate(binary, fuse.shape[2:])),
+ 1)
+ thresh = self.thresh(fuse)
+ thresh_binary = self.step_function(binary, thresh)
+ result.update(thresh=thresh, thresh_binary=thresh_binary)
+ return result
def step_function(self, x, y):
return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
+class BasicModel(nn.Module):
+
+ def __init__(self, *args, **kwargs):
+ nn.Module.__init__(self)
+
+ self.backbone = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+ self.decoder = SegDetector(
+ in_channels=[64, 128, 256, 512], adaptive=True, k=50, **kwargs)
+
+ def forward(self, data, *args, **kwargs):
+ return self.decoder(self.backbone(data), *args, **kwargs)
+
+
+def parallelize(model, distributed, local_rank):
+ if distributed:
+ return nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[local_rank],
+ output_device=[local_rank],
+ find_unused_parameters=True)
+ else:
+ return nn.DataParallel(model)
+
+
class VLPTModel(nn.Module):
def __init__(self, *args, **kwargs):
@@ -449,3 +485,44 @@ class DBModel(nn.Module):
def forward(self, x):
return self.decoder(self.backbone(x))
+
+
+class DBModel_v2(nn.Module):
+
+ def __init__(self,
+ device,
+ distributed: bool = False,
+ local_rank: int = 0,
+ *args,
+ **kwargs):
+ """
+ DBNet-resnet18 model without deformable conv,
+ paper reference: https://arxiv.org/pdf/1911.08947.pdf
+ """
+ super(DBModel_v2, self).__init__()
+ from .seg_detector_loss import L1BalanceCELoss
+
+ self.model = BasicModel(*args, **kwargs)
+ self.model = parallelize(self.model, distributed, local_rank)
+ self.criterion = L1BalanceCELoss()
+ self.criterion = parallelize(self.criterion, distributed, local_rank)
+ self.device = device
+ self.to(self.device)
+
+ def forward(self, batch, training=False):
+ if isinstance(batch, dict):
+ data = batch['image'].to(self.device)
+ else:
+ data = batch.to(self.device)
+ data = data.float()
+ pred = self.model(data, training=self.training)
+
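+        # in training mode, move the ground-truth tensors to the model device and
+        # compute the balanced CE + L1 loss (L1BalanceCELoss)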
+ if self.training:
+ for key, value in batch.items():
+ if value is not None:
+ if hasattr(value, 'to'):
+ batch[key] = value.to(self.device)
+ loss_with_metrics = self.criterion(pred, batch)
+ loss, metrics = loss_with_metrics
+ return loss, pred, metrics
+ return pred
diff --git a/modelscope/models/cv/ocr_detection/modules/seg_detector_loss.py b/modelscope/models/cv/ocr_detection/modules/seg_detector_loss.py
new file mode 100644
index 00000000..38446e35
--- /dev/null
+++ b/modelscope/models/cv/ocr_detection/modules/seg_detector_loss.py
@@ -0,0 +1,257 @@
+# ------------------------------------------------------------------------------
+# The implementation is adopted from DBNet,
+# made publicly available under the Apache License 2.0 at https://github.com/MhLiao/DB.
+# ------------------------------------------------------------------------------
+import sys
+
+import torch
+import torch.nn as nn
+
+
+class SegDetectorLossBuilder():
+ '''
+ Build loss functions for SegDetector.
+ Details about the built functions:
+ Input:
+ pred: A dict which contains predictions.
+ thresh: The threshold prediction
+ binary: The text segmentation prediction.
+ thresh_binary: Value produced by `step_function(binary - thresh)`.
+ batch:
+ gt: Text regions bitmap gt.
+                mask: Ignore mask,
+                    pixels with value 0 make no contribution to the loss.
+                thresh_mask: Mask indicating regions supervised by thresh.
+ thresh_map: Threshold gt.
+ Return:
+ (loss, metrics).
+ loss: A scalar loss value.
+                metrics: A dict containing partial loss values.
+ '''
+
+ def __init__(self, loss_class, *args, **kwargs):
+ self.loss_class = loss_class
+ self.loss_args = args
+ self.loss_kwargs = kwargs
+
+ def build(self):
+ return getattr(sys.modules[__name__],
+ self.loss_class)(*self.loss_args, **self.loss_kwargs)
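
A minimal sketch of the builder contract (illustration only): the loss class is looked up by name in this module and instantiated with the stored arguments.

    builder = SegDetectorLossBuilder('L1BalanceCELoss', eps=1e-6, l1_scale=10)
    criterion = builder.build()  # returns an L1BalanceCELoss instance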
+
+
+def _neg_loss(pred, gt):
+ ''' Modified focal loss. Exactly the same as CornerNet.
+ Runs faster and costs a little bit more memory
+ Arguments:
+ pred (batch x c x h x w)
+        gt (batch x c x h x w)
+ '''
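+    # Per-pixel terms: positives (gt == 1) contribute -(1 - p)^2 * log(p), negatives
+    # contribute -(1 - gt)^4 * p^2 * log(1 - p); the sum is normalised by the number
+    # of positives and averaged over the batch.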
+ pos_inds = gt.eq(1).float()
+ neg_inds = gt.lt(1).float()
+
+ neg_weights = torch.pow(1 - gt, 4)
+
+ loss = 0
+
+ pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
+ neg_loss = torch.log(1 - pred) * torch.pow(pred,
+ 2) * neg_weights * neg_inds
+
+ num_pos = pos_inds.float().sum()
+ pos_loss = pos_loss.sum()
+ neg_loss = neg_loss.sum()
+
+ if num_pos == 0:
+ loss = loss - neg_loss
+ else:
+ loss = loss - (pos_loss + neg_loss) / num_pos
+
+ b = pred.shape[0]
+ loss = loss / b
+    if loss > 10:
+        print('Loss', loss)
+        loss /= 1000
+        print('HM Loss > 10\n')
+
+ return loss
+
+
+class FocalLoss(nn.Module):
+    '''nn.Module wrapper for focal loss'''
+
+ def __init__(self):
+ super(FocalLoss, self).__init__()
+ self.neg_loss = _neg_loss
+
+ def forward(self, out, target):
+ return self.neg_loss(out, target)
+
+
+class DiceLoss(nn.Module):
+ '''
+ Loss function from https://arxiv.org/abs/1707.03237,
+    where the IoU computation is applied in a heatmap manner to measure the
+    diversity between two heatmaps.
+ '''
+
+ def __init__(self, eps=1e-6):
+ super(DiceLoss, self).__init__()
+ self.eps = eps
+
+ def forward(self, pred: torch.Tensor, gt, mask, weights=None):
+ '''
+ pred: one or two heatmaps of shape (N, 1, H, W),
+            the losses of two heatmaps are added together.
+ gt: (N, 1, H, W)
+ mask: (N, H, W)
+ '''
+ assert pred.dim() == 4, pred.dim()
+ return self._compute(pred, gt, mask, weights)
+
+ def _compute(self, pred, gt, mask, weights):
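+        # soft dice restricted to the mask:
+        # loss = 1 - 2 * sum(pred * gt * mask) / (sum(pred * mask) + sum(gt * mask) + eps)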
+ if pred.dim() == 4:
+ pred = pred[:, 0, :, :]
+ gt = gt[:, 0, :, :]
+ assert pred.shape == gt.shape
+ assert pred.shape == mask.shape
+ if weights is not None:
+ assert weights.shape == mask.shape
+ mask = weights * mask
+
+ intersection = (pred * gt * mask).sum()
+ union = (pred * mask).sum() + (gt * mask).sum() + self.eps
+ loss = 1 - 2.0 * intersection / union
+ assert loss <= 1
+ return loss
+
+
+class MaskL1Loss(nn.Module):
+
+ def __init__(self):
+ super(MaskL1Loss, self).__init__()
+
+ def forward(self, pred: torch.Tensor, gt, mask):
+ mask_sum = mask.sum()
+ if mask_sum.item() == 0:
+ return mask_sum, dict(l1_loss=mask_sum)
+ else:
+ loss = (torch.abs(pred[:, 0] - gt) * mask).sum() / mask_sum
+ return loss, dict(l1_loss=loss)
+
+
+class MaskL2Loss(nn.Module):
+
+ def __init__(self):
+ super(MaskL2Loss, self).__init__()
+
+ def forward(self, pred: torch.Tensor, gt, mask):
+ mask_sum = mask.sum()
+ if mask_sum.item() == 0:
+            return mask_sum, dict(l2_loss=mask_sum)
+        else:
+            loss = (((pred[:, 0] - gt)**2) * mask).sum() / mask_sum
+            return loss, dict(l2_loss=loss)
+
+
+class BalanceCrossEntropyLoss(nn.Module):
+ '''
+ Balanced cross entropy loss.
+ Shape:
+ - Input: :math:`(N, 1, H, W)`
+ - GT: :math:`(N, 1, H, W)`, same shape as the input
+ - Mask: :math:`(N, H, W)`, same spatial shape as the input
+ - Output: scalar.
+
+ Examples::
+
+        >>> loss_fn = BalanceCrossEntropyLoss()
+        >>> pred = torch.rand(1, 1, 32, 32)
+        >>> gt = torch.randint(0, 2, (1, 1, 32, 32)).float()
+        >>> mask = torch.ones(1, 32, 32)
+        >>> loss = loss_fn(pred, gt, mask)
+ '''
+
+ def __init__(self, negative_ratio=3.0, eps=1e-6):
+ super(BalanceCrossEntropyLoss, self).__init__()
+ self.negative_ratio = negative_ratio
+ self.eps = eps
+
+ def forward(self,
+ pred: torch.Tensor,
+ gt: torch.Tensor,
+ mask: torch.Tensor,
+ return_origin=False):
+ '''
+ Args:
+ pred: shape :math:`(N, 1, H, W)`, the prediction of network
+ gt: shape :math:`(N, 1, H, W)`, the target
+ mask: shape :math:`(N, H, W)`, the mask indicates positive regions
+ '''
+ positive = (gt * mask).byte()
+ negative = ((1 - gt) * mask).byte()
+ positive_count = int(positive.float().sum())
+ negative_count = min(
+ int(negative.float().sum()),
+ int(positive_count * self.negative_ratio))
+ loss = nn.functional.binary_cross_entropy(
+ pred, gt, reduction='none')[:, 0, :, :]
+ positive_loss = loss * positive.float()
+ negative_loss = loss * negative.float()
+ negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
+
+ balance_loss = (positive_loss.sum() + negative_loss.sum()) /\
+ (positive_count + negative_count + self.eps)
+
+ if return_origin:
+ return balance_loss, loss
+ return balance_loss
+
+
+class L1BalanceCELoss(nn.Module):
+ '''
+ Balanced CrossEntropy Loss on `binary`,
+ MaskL1Loss on `thresh`,
+ DiceLoss on `thresh_binary`.
+    Note: the meaning of the inputs is documented in `SegDetectorLossBuilder`.
+ '''
+
+ def __init__(self, eps=1e-6, l1_scale=10, bce_scale=5, hm_scale=10):
+ super(L1BalanceCELoss, self).__init__()
+ self.dice_loss = DiceLoss(eps=eps)
+ self.l1_loss = MaskL1Loss()
+ self.bce_loss = BalanceCrossEntropyLoss()
+
+ self.l2_loss = MaskL2Loss()
+ self.hm_loss = FocalLoss()
+
+ self.l1_scale = l1_scale
+ self.bce_scale = bce_scale
+ self.hm_scale = hm_scale
+
+ def forward(self, pred, batch):
+
+ bce_loss = self.bce_loss(pred['binary'], batch['gt'], batch['mask'])
+ metrics = dict(bce_loss=bce_loss)
+ if 'thresh' in pred:
+ l1_loss, l1_metric = self.l1_loss(pred['thresh'],
+ batch['thresh_map'],
+ batch['thresh_mask'])
+ dice_loss = self.dice_loss(pred['thresh_binary'], batch['gt'],
+ batch['mask'])
+ metrics['thresh_loss'] = dice_loss
+ loss = dice_loss + self.l1_scale * l1_loss + bce_loss * self.bce_scale
+ metrics.update(**l1_metric)
+ else:
+ loss = bce_loss
+
+ if 'hm' in pred:
+ hm_loss, _ = self.l2_loss(pred['hm'], batch['heatmap'],
+ batch['mask'])
+
+ metrics['hm_loss'] = hm_loss
+ loss = loss + self.hm_scale * hm_loss
+
+ return loss, metrics
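
A minimal sketch of the loss contract (illustration only; batch size 1 with synthetic 64x64 maps, shapes as documented in SegDetectorLossBuilder above):

    import torch

    crit = L1BalanceCELoss()
    pred = {
        'binary': torch.rand(1, 1, 64, 64),
        'thresh': torch.rand(1, 1, 64, 64),
        'thresh_binary': torch.rand(1, 1, 64, 64),
    }
    batch = {
        'gt': torch.randint(0, 2, (1, 1, 64, 64)).float(),
        'mask': torch.ones(1, 64, 64),
        'thresh_map': torch.rand(1, 64, 64),
        'thresh_mask': torch.ones(1, 64, 64),
    }
    loss, metrics = crit(pred, batch)  # metrics: bce_loss, thresh_loss, l1_loss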
diff --git a/modelscope/models/cv/ocr_detection/utils.py b/modelscope/models/cv/ocr_detection/utils.py
index 6de22b3f..81dbb076 100644
--- a/modelscope/models/cv/ocr_detection/utils.py
+++ b/modelscope/models/cv/ocr_detection/utils.py
@@ -180,7 +180,7 @@ def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height):
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
- for contour in contours[:100]:
+ for contour in contours[:1000]:
points, sside = get_mini_boxes(contour)
if sside < 3:
continue
diff --git a/modelscope/models/cv/ocr_recognition/model.py b/modelscope/models/cv/ocr_recognition/model.py
index 7d76f8e8..6eb13403 100644
--- a/modelscope/models/cv/ocr_recognition/model.py
+++ b/modelscope/models/cv/ocr_recognition/model.py
@@ -16,6 +16,53 @@ from .modules.crnn import CRNN
LOGGER = get_logger()
+def flatten_label(target):
+ label_flatten = []
+ label_length = []
+ label_dict = []
+ for i in range(0, target.size()[0]):
+ cur_label = target[i].tolist()
+ temp_label = cur_label[:cur_label.index(0)]
+ label_flatten += temp_label
+ label_dict.append(temp_label)
+ label_length.append(len(temp_label))
+ label_flatten = torch.LongTensor(label_flatten)
+ label_length = torch.IntTensor(label_length)
+ return (label_dict, label_length, label_flatten)
+
+
+class cha_encdec():
+
+ def __init__(self, charMapping, case_sensitive=True):
+ self.case_sensitive = case_sensitive
+ self.text_seq_len = 160
+ self.charMapping = charMapping
+
+ def encode(self, label_batch):
+ max_len = max([len(s) for s in label_batch])
+ out = torch.zeros(len(label_batch), max_len + 1).long()
+ for i in range(0, len(label_batch)):
+ if not self.case_sensitive:
+ cur_encoded = torch.tensor([
+ self.charMapping[char.lower()] - 1 if char.lower()
+ in self.charMapping else len(self.charMapping)
+ for char in label_batch[i]
+ ]) + 1
+ else:
+ cur_encoded = torch.tensor([
+ self.charMapping[char]
+ - 1 if char in self.charMapping else len(self.charMapping)
+ for char in label_batch[i]
+ ]) + 1
+ out[i][0:len(cur_encoded)] = cur_encoded
+ out = torch.cat(
+ (out, torch.zeros(
+ (out.size(0), self.text_seq_len - out.size(1))).type_as(out)),
+ dim=1)
+ label_dict, label_length, label_flatten = flatten_label(out)
+ return label_dict, label_length, label_flatten
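
A small sketch of the encoder contract (illustration only, with a hypothetical two-character vocabulary):

    charMapping = {'a': 1, 'b': 2}  # hypothetical vocabulary
    enc = cha_encdec(charMapping)
    label_dict, label_length, label_flatten = enc.encode(['ab', 'ba'])
    # label_flatten -> tensor([1, 2, 2, 1]), label_length -> tensor([2, 2], dtype=torch.int32)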
+
+
@MODELS.register_module(
Tasks.ocr_recognition, module_name=Models.ocr_recognition)
class OCRRecognition(TorchModel):
@@ -27,11 +74,12 @@ class OCRRecognition(TorchModel):
model_dir (str): the model path.
"""
super().__init__(model_dir, **kwargs)
-
model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
cfgs = Config.from_file(
os.path.join(model_dir, ModelFile.CONFIGURATION))
self.do_chunking = cfgs.model.inference_kwargs.do_chunking
+ self.target_height = cfgs.model.inference_kwargs.img_height
+ self.target_width = cfgs.model.inference_kwargs.img_width
self.recognizer = None
if cfgs.model.recognizer == 'ConvNextViT':
self.recognizer = ConvNextViT()
@@ -47,14 +95,22 @@ class OCRRecognition(TorchModel):
dict_path = os.path.join(model_dir, ModelFile.VOCAB_FILE)
self.labelMapping = dict()
+ self.charMapping = dict()
with open(dict_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
cnt = 1
+            # label indices of the ConvNextViT (chunking) model start from 2
+ if self.do_chunking:
+ cnt += 1
for line in lines:
line = line.strip('\n')
self.labelMapping[cnt] = line
+ self.charMapping[line] = cnt
cnt += 1
+ self.encdec = cha_encdec(self.charMapping)
+ self.criterion_CTC = torch.nn.CTCLoss(zero_infinity=True)
+
def forward(self, inputs):
"""
Args:
@@ -66,44 +122,37 @@ class OCRRecognition(TorchModel):
"""
return self.recognizer(inputs)
- def postprocess(self, inputs):
- # naive decoder
+ def do_step(self, batch):
+ inputs = batch['images']
+ labels = batch['labels']
+ bs = inputs.shape[0]
if self.do_chunking:
- preds = inputs
- batchSize, length = preds.shape
- PRED_LENTH = 75
- PRED_PAD = 6
- pred_idx = []
- if batchSize == 1:
- pred_idx = preds[0].cpu().data.tolist()
- else:
- for idx in range(batchSize):
- if idx == 0:
- pred_idx.extend(
- preds[idx].cpu().data[:PRED_LENTH
- - PRED_PAD].tolist())
- elif idx == batchSize - 1:
- pred_idx.extend(
- preds[idx].cpu().data[PRED_PAD:].tolist())
- else:
- pred_idx.extend(
- preds[idx].cpu().data[PRED_PAD:PRED_LENTH
- - PRED_PAD].tolist())
- pred_idx = [its - 1 for its in pred_idx if its > 0]
+ inputs = inputs.view(bs * 3, 1, self.target_height, 300)
else:
- outprobs = inputs
- outprobs = F.softmax(outprobs, dim=-1)
- preds = torch.argmax(outprobs, -1)
- length, batchSize = preds.shape
- assert batchSize == 1, 'only support onesample inference'
- pred_idx = preds[:, 0].cpu().data.tolist()
+ inputs = inputs.view(bs, 1, self.target_height, self.target_width)
+ output = self(inputs)
+ probs = output['probs'].permute(1, 0, 2)
+ _, label_length, label_flatten = self.encdec.encode(labels)
+ probs_sizes = torch.IntTensor([probs.size(0)] * probs.size(1))
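+        # nn.CTCLoss expects log-probabilities of shape (T, N, C), a flattened target
+        # sequence, per-sample input lengths (all equal to T here) and per-sample
+        # target lengths.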
+ loss = self.criterion_CTC(
+ probs.log_softmax(2), label_flatten, probs_sizes, label_length)
+ output = dict(loss=loss, preds=output['preds'])
+ return output
- pred_idx = pred_idx
- last_p = 0
- str_pred = []
- for p in pred_idx:
- if p != last_p and p != 0:
- str_pred.append(self.labelMapping[p])
- last_p = p
- final_str = ''.join(str_pred)
- return final_str
+ def postprocess(self, inputs):
+ outprobs = inputs
+ outprobs = F.softmax(outprobs, dim=-1)
+ preds = torch.argmax(outprobs, -1)
+ batchSize, length = preds.shape
+ final_str_list = []
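+        # greedy CTC decoding: collapse consecutive repeats and drop blanks (index 0),
+        # e.g. the frame-wise argmax sequence [0, 7, 7, 0, 3] yields the characters
+        # labelMapping[7] and labelMapping[3]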
+ for i in range(batchSize):
+ pred_idx = preds[i].cpu().data.tolist()
+ last_p = 0
+ str_pred = []
+ for p in pred_idx:
+ if p != last_p and p != 0:
+ str_pred.append(self.labelMapping[p])
+ last_p = p
+ final_str = ''.join(str_pred)
+ final_str_list.append(final_str)
+ return {'preds': final_str_list, 'probs': inputs}
diff --git a/modelscope/models/cv/ocr_recognition/modules/convnextvit.py b/modelscope/models/cv/ocr_recognition/modules/convnextvit.py
index aaedb697..4e7900b1 100644
--- a/modelscope/models/cv/ocr_recognition/modules/convnextvit.py
+++ b/modelscope/models/cv/ocr_recognition/modules/convnextvit.py
@@ -16,8 +16,5 @@ class ConvNextViT(nn.Module):
def forward(self, input):
""" Transformation stage """
features = self.cnn_model(input)
- prediction = self.vitstr(features)
- prediction = torch.nn.functional.softmax(prediction, dim=-1)
-
- output = torch.argmax(prediction, -1)
+ output = self.vitstr(features)
return output
diff --git a/modelscope/models/cv/ocr_recognition/modules/crnn.py b/modelscope/models/cv/ocr_recognition/modules/crnn.py
index e0e489e9..3de8304d 100644
--- a/modelscope/models/cv/ocr_recognition/modules/crnn.py
+++ b/modelscope/models/cv/ocr_recognition/modules/crnn.py
@@ -96,4 +96,5 @@ class CRNN(nn.Module):
rnnfeats = self.rnn(convfeats)
output = self.cls(rnnfeats)
+ output = output.permute(1, 0, 2) # [b, w, c]
return output
diff --git a/modelscope/models/cv/ocr_recognition/modules/vitstr.py b/modelscope/models/cv/ocr_recognition/modules/vitstr.py
index 5ce3aeca..c9fa0693 100644
--- a/modelscope/models/cv/ocr_recognition/modules/vitstr.py
+++ b/modelscope/models/cv/ocr_recognition/modules/vitstr.py
@@ -39,6 +39,13 @@ class ViTSTR(VisionTransformer):
def forward(self, x):
x = self.forward_features(x)
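+        # stitch the 3 chunk outputs (75 tokens each) back into one sequence:
+        # keep tokens [0:69] of the first, [6:69] of the second and [6:75] of the
+        # third chunk, dropping the overlapping borders: 69 + 63 + 69 = 201 tokens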
+ ap = x.view(x.shape[0] // 3, 3, 75, x.shape[2])
+ features_1d_concat = torch.ones(x.shape[0] // 3, 201,
+ x.shape[2]).type_as(x)
+ features_1d_concat[:, :69, :] = ap[:, 0, :69, :]
+ features_1d_concat[:, 69:69 + 63, :] = ap[:, 1, 6:-6, :]
+ features_1d_concat[:, 69 + 63:, :] = ap[:, 2, 6:, :]
+ x = features_1d_concat
b, s, e = x.size()
x = x.reshape(b * s, e)
x = self.head(x).view(b, s, self.num_classes)
diff --git a/modelscope/models/cv/ocr_recognition/preprocessor.py b/modelscope/models/cv/ocr_recognition/preprocessor.py
index 6405e3e4..d327cd3c 100644
--- a/modelscope/models/cv/ocr_recognition/preprocessor.py
+++ b/modelscope/models/cv/ocr_recognition/preprocessor.py
@@ -1,6 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-import math
import os
import cv2
@@ -32,6 +31,8 @@ class OCRRecognitionPreprocessor(Preprocessor):
self.do_chunking = cfgs.model.inference_kwargs.do_chunking
self.target_height = cfgs.model.inference_kwargs.img_height
self.target_width = cfgs.model.inference_kwargs.img_width
+ if self.do_chunking:
+ self.target_width = 804
def keepratio_resize(self, img):
cur_ratio = img.shape[1] / float(img.shape[0])
@@ -59,49 +60,34 @@ class OCRRecognitionPreprocessor(Preprocessor):
Returns:
outputs: the preprocessed image
"""
- if isinstance(inputs, str):
- img = np.array(load_image(inputs).convert('L'))
- elif isinstance(inputs, PIL.Image.Image):
- img = np.array(inputs.convert('L'))
- elif isinstance(inputs, np.ndarray):
- if len(inputs.shape) == 3:
- img = cv2.cvtColor(inputs, cv2.COLOR_RGB2GRAY)
- else:
- raise TypeError(
- f'inputs should be either str, PIL.Image, np.array, but got {type(inputs)}'
- )
-
- if self.do_chunking:
- PRED_LENTH = 75
- PRED_PAD = 6
- data = []
- img_h, img_w = img.shape
- wh_ratio = img_w / img_h
- true_w = int(self.target_height * wh_ratio)
- split_batch_cnt = 1
- if true_w < self.target_width * 1.2:
- img = cv2.resize(
- img, (min(true_w, self.target_width), self.target_height))
+ if not isinstance(inputs, list):
+ inputs = [inputs]
+ data_batch = []
+ for item in inputs:
+ if isinstance(item, str):
+ img = np.array(load_image(item).convert('L'))
+ elif isinstance(item, PIL.Image.Image):
+ img = np.array(item.convert('L'))
+ elif isinstance(item, np.ndarray):
+ if len(item.shape) == 3:
+ img = cv2.cvtColor(item, cv2.COLOR_RGB2GRAY)
else:
- split_batch_cnt = math.ceil((true_w - 48) * 1.0 / 252)
- img = cv2.resize(img, (true_w, self.target_height))
+ raise TypeError(
+                    f'inputs should be (a list of) str, PIL.Image or np.ndarray, but got {type(item)}'
+ )
- if split_batch_cnt == 1:
- mask = np.zeros((self.target_height, self.target_width))
- mask[:, :img.shape[1]] = img
- data.append(mask)
+ img = self.keepratio_resize(img)
+ img = torch.FloatTensor(img)
+ if self.do_chunking:
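+                # with chunking the resized width is 804; the three 300-wide crops
+                # start at columns 0, 252 and 504 (stride 300 - 48), so neighbouring
+                # crops overlap by 48 pixels and the last crop ends at column 804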
+ chunk_img = []
+ for i in range(3):
+ left = (300 - 48) * i
+ chunk_img.append(img[:, left:left + 300])
+ merge_img = torch.cat(chunk_img, 0)
+ data = merge_img.view(3, 1, self.target_height, 300) / 255.
else:
- for idx in range(split_batch_cnt):
- mask = np.zeros((self.target_height, self.target_width))
- left = (PRED_LENTH * 4 - PRED_PAD * 4) * idx
- trunk_img = img[:, left:min(left + PRED_LENTH * 4, true_w)]
- mask[:, :trunk_img.shape[1]] = trunk_img
- data.append(mask)
-
- data = torch.FloatTensor(data).view(
- len(data), 1, self.target_height, self.target_width) / 255.
- else:
- data = self.keepratio_resize(img)
- data = torch.FloatTensor(data).view(1, 1, self.target_height,
- self.target_width) / 255.
- return data
+ data = img.view(1, 1, self.target_height,
+ self.target_width) / 255.
+ data_batch.append(data)
+ data_batch = torch.cat(data_batch, 0)
+ return data_batch
diff --git a/modelscope/models/cv/salient_detection/models/__init__.py b/modelscope/models/cv/salient_detection/models/__init__.py
index 8ea7a5d3..6df5101a 100644
--- a/modelscope/models/cv/salient_detection/models/__init__.py
+++ b/modelscope/models/cv/salient_detection/models/__init__.py
@@ -1,3 +1,4 @@
# The implementation is adopted from U-2-Net, made publicly available under the Apache 2.0 License
# source code avaiable via https://github.com/xuebinqin/U-2-Net
+from .senet import SENet
from .u2net import U2NET
diff --git a/modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py b/modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py
index 46c950bf..5f92ed8b 100644
--- a/modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py
+++ b/modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py
@@ -1,6 +1,5 @@
-# Implementation in this file is modified based on Res2Net-PretrainedModels
-# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
-# publicly available at https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net_v1b.py
+# Implementation in this file is modified based on SINet-V2, made publicly available under the Apache 2.0 License
+# publicly available at https://github.com/GewelsJI/SINet-V2
import math
import torch
diff --git a/modelscope/models/cv/salient_detection/models/backbone/__init__.py b/modelscope/models/cv/salient_detection/models/backbone/__init__.py
index ab4029e8..5a97ef5d 100644
--- a/modelscope/models/cv/salient_detection/models/backbone/__init__.py
+++ b/modelscope/models/cv/salient_detection/models/backbone/__init__.py
@@ -1,6 +1,5 @@
-# Implementation in this file is modified based on Res2Net-PretrainedModels
-# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
-# publicly available at https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net_v1b.py
+# Implementation in this file is modified based on SINet-V2, made publicly available under the Apache 2.0 License
+# publicly available at https://github.com/GewelsJI/SINet-V2
from .Res2Net_v1b import res2net50_v1b_26w_4s
__all__ = ['res2net50_v1b_26w_4s']
diff --git a/modelscope/models/cv/salient_detection/models/modules.py b/modelscope/models/cv/salient_detection/models/modules.py
new file mode 100644
index 00000000..09796bd3
--- /dev/null
+++ b/modelscope/models/cv/salient_detection/models/modules.py
@@ -0,0 +1,178 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import ConvBNReLU
+
+
+class AreaLayer(nn.Module):
+
+ def __init__(self, in_channel, out_channel):
+ super(AreaLayer, self).__init__()
+ self.lbody = nn.Sequential(
+ nn.Conv2d(out_channel, out_channel, 1),
+ nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))
+ self.hbody = nn.Sequential(
+ nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel),
+ nn.ReLU(inplace=True))
+ self.body = nn.Sequential(
+ nn.Conv2d(2 * out_channel, out_channel, 3, 1, 1),
+ nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
+ nn.Conv2d(out_channel, out_channel, 3, 1, 1),
+ nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
+ nn.Conv2d(out_channel, 1, 1))
+
+ def forward(self, xl, xh):
+ xl1 = self.lbody(xl)
+ xl1 = F.interpolate(
+ xl1, size=xh.size()[2:], mode='bilinear', align_corners=True)
+ xh1 = self.hbody(xh)
+ x = torch.cat((xl1, xh1), dim=1)
+ x_out = self.body(x)
+ return x_out
+
+
+class EdgeLayer(nn.Module):
+
+ def __init__(self, in_channel, out_channel):
+ super(EdgeLayer, self).__init__()
+ self.lbody = nn.Sequential(
+ nn.Conv2d(out_channel, out_channel, 1),
+ nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))
+ self.hbody = nn.Sequential(
+ nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel),
+ nn.ReLU(inplace=True))
+ self.bodye = nn.Sequential(
+ nn.Conv2d(2 * out_channel, out_channel, 3, 1, 1),
+ nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
+ nn.Conv2d(out_channel, out_channel, 3, 1, 1),
+ nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
+ nn.Conv2d(out_channel, 1, 1))
+
+ def forward(self, xl, xh):
+ xl1 = self.lbody(xl)
+ xh1 = self.hbody(xh)
+ xh1 = F.interpolate(
+ xh1, size=xl.size()[2:], mode='bilinear', align_corners=True)
+ x = torch.cat((xl1, xh1), dim=1)
+ x_out = self.bodye(x)
+ return x_out
+
+
+class EBlock(nn.Module):
+
+ def __init__(self, inchs, outchs):
+ super(EBlock, self).__init__()
+ self.elayer = nn.Sequential(
+ ConvBNReLU(inchs + 1, outchs, kernel_size=3, padding=1, stride=1),
+ ConvBNReLU(outchs, outchs, 1))
+ self.salayer = nn.Sequential(
+ nn.Conv2d(2, 1, 3, 1, 1, bias=False),
+ nn.BatchNorm2d(1, momentum=0.01), nn.Sigmoid())
+
+ def forward(self, x, edgeAtten):
+ x = torch.cat((x, edgeAtten), dim=1)
+ ex = self.elayer(x)
+ ex_max = torch.max(ex, 1, keepdim=True)[0]
+ ex_mean = torch.mean(ex, dim=1, keepdim=True)
+ xei_compress = torch.cat((ex_max, ex_mean), dim=1)
+
+ scale = self.salayer(xei_compress)
+ x_out = ex * scale
+ return x_out
+
+
+class StructureE(nn.Module):
+
+ def __init__(self, inchs, outchs, EM):
+ super(StructureE, self).__init__()
+ self.ne_modules = int(inchs / EM)
+ NM = int(outchs / self.ne_modules)
+ elayes = []
+ for i in range(self.ne_modules):
+ emblock = EBlock(EM, NM)
+ elayes.append(emblock)
+ self.emlayes = nn.ModuleList(elayes)
+ self.body = nn.Sequential(
+ ConvBNReLU(outchs, outchs, 3, 1, 1), ConvBNReLU(outchs, outchs, 1))
+
+ def forward(self, x, edgeAtten):
+ if edgeAtten.size() != x.size():
+ edgeAtten = F.interpolate(
+ edgeAtten, x.size()[2:], mode='bilinear', align_corners=False)
+ xx = torch.chunk(x, self.ne_modules, dim=1)
+ efeas = []
+ for i in range(self.ne_modules):
+ xei = self.emlayes[i](xx[i], edgeAtten)
+ efeas.append(xei)
+ efeas = torch.cat(efeas, dim=1)
+ x_out = self.body(efeas)
+ return x_out
+
+
+class ABlock(nn.Module):
+
+ def __init__(self, inchs, outchs, k):
+ super(ABlock, self).__init__()
+ self.alayer = nn.Sequential(
+ ConvBNReLU(inchs, outchs, k, 1, k // 2),
+ ConvBNReLU(outchs, outchs, 1))
+ self.arlayer = nn.Sequential(
+ ConvBNReLU(inchs, outchs, k, 1, k // 2),
+ ConvBNReLU(outchs, outchs, 1))
+ self.fusion = ConvBNReLU(2 * outchs, outchs, 1)
+
+ def forward(self, x, areaAtten):
+ xa = x * areaAtten
+ xra = x * (1 - areaAtten)
+ xout = self.fusion(torch.cat((xa, xra), dim=1))
+ return xout
+
+
+class AMFusion(nn.Module):
+
+ def __init__(self, inchs, outchs, AM):
+ super(AMFusion, self).__init__()
+ self.k = [3, 3, 5, 5]
+ self.conv_up = ConvBNReLU(inchs, outchs, 3, 1, 1)
+ self.up = nn.Upsample(
+ scale_factor=2, mode='bilinear', align_corners=True)
+ self.na_modules = int(outchs / AM)
+ alayers = []
+ for i in range(self.na_modules):
+ layer = ABlock(AM, AM, self.k[i])
+ alayers.append(layer)
+ self.alayers = nn.ModuleList(alayers)
+ self.fusion_0 = ConvBNReLU(outchs, outchs, 3, 1, 1)
+ self.fusion_e = nn.Sequential(
+ nn.Conv2d(
+ outchs, outchs, kernel_size=(3, 1), padding=(1, 0),
+ bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True),
+ nn.Conv2d(
+ outchs, outchs, kernel_size=(1, 3), padding=(0, 1),
+ bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True))
+ self.fusion_e1 = nn.Sequential(
+ nn.Conv2d(
+ outchs, outchs, kernel_size=(5, 1), padding=(2, 0),
+ bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True),
+ nn.Conv2d(
+ outchs, outchs, kernel_size=(1, 5), padding=(0, 2),
+ bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True))
+ self.fusion = ConvBNReLU(3 * outchs, outchs, 1)
+
+ def forward(self, xl, xh, xhm):
+ xh1 = self.up(self.conv_up(xh))
+ x = xh1 + xl
+ xm = self.up(torch.sigmoid(xhm))
+ xx = torch.chunk(x, self.na_modules, dim=1)
+ xxmids = []
+ for i in range(self.na_modules):
+ xi = self.alayers[i](xx[i], xm)
+ xxmids.append(xi)
+ xfea = torch.cat(xxmids, dim=1)
+ x0 = self.fusion_0(xfea)
+ x1 = self.fusion_e(xfea)
+ x2 = self.fusion_e1(xfea)
+ x_out = self.fusion(torch.cat((x0, x1, x2), dim=1))
+ return x_out
diff --git a/modelscope/models/cv/salient_detection/models/senet.py b/modelscope/models/cv/salient_detection/models/senet.py
new file mode 100644
index 00000000..37cf42be
--- /dev/null
+++ b/modelscope/models/cv/salient_detection/models/senet.py
@@ -0,0 +1,74 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .backbone import res2net50_v1b_26w_4s as res2net
+from .modules import AMFusion, AreaLayer, EdgeLayer, StructureE
+from .utils import ASPP, CBAM, ConvBNReLU
+
+
+class SENet(nn.Module):
+
+ def __init__(self, backbone_path=None, pretrained=False):
+ super(SENet, self).__init__()
+ resnet50 = res2net(backbone_path, pretrained)
+ self.layer0_1 = nn.Sequential(resnet50.conv1, resnet50.bn1,
+ resnet50.relu)
+ self.maxpool = resnet50.maxpool
+ self.layer1 = resnet50.layer1
+ self.layer2 = resnet50.layer2
+ self.layer3 = resnet50.layer3
+ self.layer4 = resnet50.layer4
+ self.aspp3 = ASPP(1024, 256)
+ self.aspp4 = ASPP(2048, 256)
+ self.cbblock3 = CBAM(inchs=256, kernel_size=5)
+ self.cbblock4 = CBAM(inchs=256, kernel_size=5)
+ self.up = nn.Upsample(
+ mode='bilinear', scale_factor=2, align_corners=False)
+ self.conv_up = ConvBNReLU(512, 512, 1)
+ self.aux_edge = EdgeLayer(512, 256)
+ self.aux_area = AreaLayer(512, 256)
+ self.layer1_enhance = StructureE(256, 128, 128)
+ self.layer2_enhance = StructureE(512, 256, 128)
+ self.layer3_decoder = AMFusion(512, 256, 128)
+ self.layer2_decoder = AMFusion(256, 128, 128)
+ self.out_conv_8 = nn.Conv2d(256, 1, 1)
+ self.out_conv_4 = nn.Conv2d(128, 1, 1)
+
+ def forward(self, x):
+ layer0 = self.layer0_1(x)
+ layer0s = self.maxpool(layer0)
+ layer1 = self.layer1(layer0s)
+ layer2 = self.layer2(layer1)
+ layer3 = self.layer3(layer2)
+ layer4 = self.layer4(layer3)
+ layer3_eh = self.cbblock3(self.aspp3(layer3))
+ layer4_eh = self.cbblock4(self.aspp4(layer4))
+ layer34 = self.conv_up(
+ torch.cat((self.up(layer4_eh), layer3_eh), dim=1))
+ edge_atten = self.aux_edge(layer1, layer34)
+ area_atten = self.aux_area(layer1, layer34)
+ edge_atten_ = torch.sigmoid(edge_atten)
+ layer1_eh = self.layer1_enhance(layer1, edge_atten_)
+ layer2_eh = self.layer2_enhance(layer2, edge_atten_)
+ layer2_fu = self.layer3_decoder(layer2_eh, layer34, area_atten)
+ out_8 = self.out_conv_8(layer2_fu)
+ layer1_fu = self.layer2_decoder(layer1_eh, layer2_fu, out_8)
+ out_4 = self.out_conv_4(layer1_fu)
+ out_16 = F.interpolate(
+ area_atten,
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=False)
+ out_8 = F.interpolate(
+ out_8, size=x.size()[2:], mode='bilinear', align_corners=False)
+ out_4 = F.interpolate(
+ out_4, size=x.size()[2:], mode='bilinear', align_corners=False)
+ edge_out = F.interpolate(
+ edge_atten_,
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=False)
+
+ return out_4.sigmoid(), out_8.sigmoid(), out_16.sigmoid(), edge_out
diff --git a/modelscope/models/cv/salient_detection/salient_model.py b/modelscope/models/cv/salient_detection/salient_model.py
index 73c3c3fb..e25166c8 100644
--- a/modelscope/models/cv/salient_detection/salient_model.py
+++ b/modelscope/models/cv/salient_detection/salient_model.py
@@ -2,7 +2,6 @@
import os.path as osp
import cv2
-import numpy as np
import torch
from PIL import Image
from torchvision import transforms
@@ -10,8 +9,9 @@ from torchvision import transforms
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
+from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
-from .models import U2NET
+from .models import U2NET, SENet
@MODELS.register_module(
@@ -22,13 +22,25 @@ class SalientDetection(TorchModel):
"""str -- model file root."""
super().__init__(model_dir, *args, **kwargs)
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
- self.model = U2NET(3, 1)
+
+ self.norm_mean = [0.485, 0.456, 0.406]
+ self.norm_std = [0.229, 0.224, 0.225]
+ self.norm_size = (320, 320)
+
+ config_path = osp.join(model_dir, 'config.py')
+ if osp.exists(config_path) is False:
+ self.model = U2NET(3, 1)
+ else:
+ self.model = SENet(backbone_path=None, pretrained=False)
+ config = Config.from_file(config_path)
+ self.norm_mean = config.norm_mean
+ self.norm_std = config.norm_std
+ self.norm_size = config.norm_size
checkpoint = torch.load(model_path, map_location='cpu')
self.transform_input = transforms.Compose([
- transforms.Resize((320, 320)),
+ transforms.Resize(self.norm_size),
transforms.ToTensor(),
- transforms.Normalize(
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ transforms.Normalize(mean=self.norm_mean, std=self.norm_std)
])
self.model.load_state_dict(checkpoint)
self.model.eval()
diff --git a/modelscope/models/cv/table_recognition/__init__.py b/modelscope/models/cv/table_recognition/__init__.py
new file mode 100644
index 00000000..e88f7f67
--- /dev/null
+++ b/modelscope/models/cv/table_recognition/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .model_lore import LoreModel
+
+else:
+ _import_structure = {'model_lore': ['LoreModel']}
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/cv/table_recognition/lineless_table_process.py b/modelscope/models/cv/table_recognition/lineless_table_process.py
new file mode 100644
index 00000000..0d7fcfb5
--- /dev/null
+++ b/modelscope/models/cv/table_recognition/lineless_table_process.py
@@ -0,0 +1,439 @@
+# ------------------------------------------------------------------------------
+# Part of implementation is adopted from CenterNet,
+# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git
+# ------------------------------------------------------------------------------
+
+import cv2
+import numpy as np
+import shapely
+import torch
+import torch.nn as nn
+from shapely.geometry import MultiPoint, Point, Polygon
+
+
+def _gather_feat(feat, ind, mask=None):
+    # gather features at the flattened spatial indices: (B, H*W, C) -> (B, K, C)
+ dim = feat.size(2)
+ ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
+ feat = feat.gather(1, ind)
+ if mask is not None:
+ mask = mask.unsqueeze(2).expand_as(feat)
+ feat = feat[mask]
+ feat = feat.view(-1, dim)
+ return feat
+
+
+def _tranpose_and_gather_feat(feat, ind):
+    # reshape (B, C, H, W) to (B, H*W, C), then gather the K indexed locations
+ feat = feat.permute(0, 2, 3, 1).contiguous()
+ feat = feat.view(feat.size(0), -1, feat.size(3))
+ feat = _gather_feat(feat, ind)
+ return feat
+
+
+def _get_4ps_feat(cc_match, output):
+    # gather the features of the 4 matched corner positions for each detected cell
+ if isinstance(output, dict):
+ feat = output['cr']
+ else:
+ feat = output
+ feat = feat.permute(0, 2, 3, 1).contiguous()
+ feat = feat.contiguous().view(feat.size(0), -1, feat.size(3))
+ feat = feat.unsqueeze(3).expand(
+ feat.size(0), feat.size(1), feat.size(2), 4)
+
+ dim = feat.size(2)
+
+ cc_match = cc_match.unsqueeze(2).expand(
+ cc_match.size(0), cc_match.size(1), dim, cc_match.size(2))
+ if not (isinstance(output, dict)):
+ cc_match = torch.where(
+ cc_match < feat.shape[1], cc_match, (feat.shape[0] - 1)
+ * torch.ones(cc_match.shape).to(torch.int64).cuda())
+ cc_match = torch.where(
+ cc_match >= 0, cc_match,
+ torch.zeros(cc_match.shape).to(torch.int64).cuda())
+ feat = feat.gather(1, cc_match)
+ return feat
+
+
+def _nms(heat, name, kernel=3):
+ pad = (kernel - 1) // 2
+
+ hmax = nn.functional.max_pool2d(
+ heat, (kernel, kernel), stride=1, padding=pad)
+ # save_map(hmax.cpu().numpy()[0],name)
+ keep = (hmax == heat).float()
+ return heat * keep, keep
+
+
+def _topk(scores, K=40):
+ batch, cat, height, width = scores.size()
+
+ topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)
+
+ topk_inds = topk_inds % (
+ torch.Tensor([height]).to(torch.int64).cuda()
+ * torch.Tensor([width]).to(torch.int64).cuda())
+ topk_ys = (topk_inds / torch.Tensor([width]).cuda()).int().float()
+ topk_xs = (topk_inds
+ % torch.Tensor([width]).to(torch.int64).cuda()).int().float()
+
+ topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
+ topk_clses = (topk_ind // K).int()
+ topk_inds = _gather_feat(topk_inds.view(batch, -1, 1),
+ topk_ind).view(batch, K)
+ topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
+ topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
+
+ return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
+
+
+def corner_decode(mk, st_reg, mk_reg=None, K=400):
+ batch, cat, height, width = mk.size()
+ mk, keep = _nms(mk, 'mk.0.maxpool')
+ scores, inds, clses, ys, xs = _topk(mk, K=K)
+ if mk_reg is not None:
+ reg = _tranpose_and_gather_feat(mk_reg, inds)
+ reg = reg.view(batch, K, 2)
+ xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
+ ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
+ else:
+ xs = xs.view(batch, K, 1) + 0.5
+ ys = ys.view(batch, K, 1) + 0.5
+ scores = scores.view(batch, K, 1)
+ st_Reg = _tranpose_and_gather_feat(st_reg, inds)
+ bboxes_vec = [
+ xs - st_Reg[..., 0:1], ys - st_Reg[..., 1:2], xs - st_Reg[..., 2:3],
+ ys - st_Reg[..., 3:4], xs - st_Reg[..., 4:5], ys - st_Reg[..., 5:6],
+ xs - st_Reg[..., 6:7], ys - st_Reg[..., 7:8]
+ ]
+ bboxes = torch.cat(bboxes_vec, dim=2)
+ corner_dict = {
+ 'scores': scores,
+ 'inds': inds,
+ 'ys': ys,
+ 'xs': xs,
+ 'gboxes': bboxes
+ }
+ return scores, inds, ys, xs, bboxes, corner_dict
+
+
+def ctdet_4ps_decode(heat,
+ wh,
+ ax,
+ cr,
+ corner_dict=None,
+ reg=None,
+ cat_spec_wh=False,
+ K=100,
+ wiz_rev=False):
+
+ batch, cat, height, width = heat.size()
+ # heat = torch.sigmoid(heat)
+ # perform nms on heatmaps
+ heat, keep = _nms(heat, 'hm.0.maxpool')
+
+ scores, inds, clses, ys, xs = _topk(heat, K=K)
+ if reg is not None:
+ reg = _tranpose_and_gather_feat(reg, inds)
+ reg = reg.view(batch, K, 2)
+ xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
+ ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
+ else:
+ xs = xs.view(batch, K, 1) + 0.5
+ ys = ys.view(batch, K, 1) + 0.5
+ wh = _tranpose_and_gather_feat(wh, inds)
+ ax = _tranpose_and_gather_feat(ax, inds)
+
+ if cat_spec_wh:
+ wh = wh.view(batch, K, cat, 8)
+ clses_ind = clses.view(batch, K, 1, 1).expand(batch, K, 1, 8).long()
+ wh = wh.gather(2, clses_ind).view(batch, K, 8)
+ else:
+ wh = wh.view(batch, K, 8)
+ clses = clses.view(batch, K, 1).float()
+ scores = scores.view(batch, K, 1)
+
+ bboxes_vec = [
+ xs - wh[..., 0:1], ys - wh[..., 1:2], xs - wh[..., 2:3],
+ ys - wh[..., 3:4], xs - wh[..., 4:5], ys - wh[..., 5:6],
+ xs - wh[..., 6:7], ys - wh[..., 7:8]
+ ]
+ bboxes = torch.cat(bboxes_vec, dim=2)
+
+ cc_match = torch.cat(
+ [(xs - wh[..., 0:1]) + width * torch.round(ys - wh[..., 1:2]),
+ (xs - wh[..., 2:3]) + width * torch.round(ys - wh[..., 3:4]),
+ (xs - wh[..., 4:5]) + width * torch.round(ys - wh[..., 5:6]),
+ (xs - wh[..., 6:7]) + width * torch.round(ys - wh[..., 7:8])],
+ dim=2)
+
+ cc_match = torch.round(cc_match).to(torch.int64)
+
+ cr_feat = _get_4ps_feat(cc_match, cr)
+ cr_feat = cr_feat.sum(axis=3)
+
+ detections = torch.cat([bboxes, scores, clses], dim=2)
+
+ return detections, keep, ax, cr_feat
+
+
+def get_3rd_point(a, b):
+ direct = a - b
+ return b + np.array([-direct[1], direct[0]], dtype=np.float32)
+
+
+def affine_transform(pt, t):
+ new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32).T
+ new_pt = np.dot(t, new_pt)
+ return new_pt[:2]
+
+
+def get_dir(src_point, rot_rad):
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+
+ src_result = [0, 0]
+ src_result[0] = src_point[0] * cs - src_point[1] * sn
+ src_result[1] = src_point[0] * sn + src_point[1] * cs
+
+ return src_result
+
+
+def get_affine_transform(center,
+ scale,
+ rot,
+ output_size,
+ shift=np.array([0, 0], dtype=np.float32),
+ inv=0):
+ if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
+ scale = np.array([scale, scale], dtype=np.float32)
+
+ scale_tmp = scale
+ src_w = scale_tmp[0]
+ dst_w = output_size[0]
+ dst_h = output_size[1]
+
+ rot_rad = np.pi * rot / 180
+ src_dir = get_dir([0, src_w * -0.5], rot_rad)
+ dst_dir = np.array([0, dst_w * -0.5], np.float32)
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ dst = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = center + scale_tmp * shift # [0,0] #
+ src[1, :] = center + src_dir + scale_tmp * shift # scale #
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5] # [0,0] #
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5],
+ np.float32) + dst_dir # output_size #
+
+ src[2:, :] = get_3rd_point(src[0, :], src[1, :])
+ dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
+
+ if inv:
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return trans
+
+
+def get_affine_transform_upper_left(center,
+ scale,
+ rot,
+ output_size,
+ shift=np.array([0, 0], dtype=np.float32),
+ inv=0):
+ if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
+ scale = np.array([scale, scale], dtype=np.float32)
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ dst = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = center
+ dst[0, :] = [0, 0]
+ if center[0] < center[1]:
+ src[1, :] = [scale[0], center[1]]
+ dst[1, :] = [output_size[0], 0]
+ else:
+ src[1, :] = [center[0], scale[0]]
+ dst[1, :] = [0, output_size[0]]
+ src[2:, :] = get_3rd_point(src[0, :], src[1, :])
+ dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
+
+ if inv:
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return trans
+
+
+def transform_preds(coords, center, scale, output_size, rot=0):
+ target_coords = np.zeros(coords.shape)
+ trans = get_affine_transform(center, scale, rot, output_size, inv=1)
+ for p in range(coords.shape[0]):
+ target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
+ return target_coords
+
+
+def transform_preds_upper_left(coords, center, scale, output_size, rot=0):
+ target_coords = np.zeros(coords.shape)
+
+ trans = get_affine_transform_upper_left(
+ center, scale, rot, output_size, inv=1)
+ for p in range(coords.shape[0]):
+ target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
+ return target_coords
+
+
+def ctdet_4ps_post_process_upper_left(dets, c, s, h, w, num_classes, rot=0):
+ # dets: batch x max_dets x dim
+ # return 1-based class det dict
+ ret = []
+ for i in range(dets.shape[0]):
+ top_preds = {}
+ dets[i, :, 0:2] = transform_preds_upper_left(dets[i, :, 0:2], c[i],
+ s[i], (w, h), rot)
+ dets[i, :, 2:4] = transform_preds_upper_left(dets[i, :, 2:4], c[i],
+ s[i], (w, h), rot)
+ dets[i, :, 4:6] = transform_preds_upper_left(dets[i, :, 4:6], c[i],
+ s[i], (w, h), rot)
+ dets[i, :, 6:8] = transform_preds_upper_left(dets[i, :, 6:8], c[i],
+ s[i], (w, h), rot)
+ classes = dets[i, :, -1]
+ for j in range(num_classes):
+ inds = (classes == j)
+ tmp_top_pred = [
+ dets[i, inds, :8].astype(np.float32),
+ dets[i, inds, 8:9].astype(np.float32)
+ ]
+ top_preds[j + 1] = np.concatenate(tmp_top_pred, axis=1).tolist()
+ ret.append(top_preds)
+ return ret
+
+
+def ctdet_corner_post_process(corner_st_reg, c, s, h, w, num_classes):
+ for i in range(corner_st_reg.shape[0]):
+ corner_st_reg[i, :, 0:2] = transform_preds(corner_st_reg[i, :, 0:2],
+ c[i], s[i], (w, h))
+ corner_st_reg[i, :, 2:4] = transform_preds(corner_st_reg[i, :, 2:4],
+ c[i], s[i], (w, h))
+ corner_st_reg[i, :, 4:6] = transform_preds(corner_st_reg[i, :, 4:6],
+ c[i], s[i], (w, h))
+ corner_st_reg[i, :, 6:8] = transform_preds(corner_st_reg[i, :, 6:8],
+ c[i], s[i], (w, h))
+ corner_st_reg[i, :, 8:10] = transform_preds(corner_st_reg[i, :, 8:10],
+ c[i], s[i], (w, h))
+ return corner_st_reg
+
+
+def merge_outputs(detections):
+ # thresh_conf, thresh_min, thresh_max = 0.1, 0.5, 0.7
+ num_classes, max_per_image = 2, 3000
+ results = {}
+ for j in range(1, num_classes + 1):
+ results[j] = np.concatenate([detection[j] for detection in detections],
+ axis=0).astype(np.float32)
+ scores = np.hstack([results[j][:, 8] for j in range(1, num_classes + 1)])
+ if len(scores) > max_per_image:
+ kth = len(scores) - max_per_image
+ thresh = np.partition(scores, kth)[kth]
+ for j in range(1, num_classes + 1):
+ keep_inds = (results[j][:, 8] >= thresh)
+ results[j] = results[j][keep_inds]
+ return results
+
+
+def filter(results, logi, ps):
+ # this function select boxes
+    # keep detections whose score is at least 0.15 and gather their features
+ num_valid = sum(results[1][:, 8] >= 0.15)
+
+ slct_logi = np.zeros((batch_size, num_valid, feat_dim), dtype=np.float32)
+ slct_dets = np.zeros((batch_size, num_valid, 8), dtype=np.int32)
+ for i in range(batch_size):
+ for j in range(num_valid):
+ slct_logi[i, j, :] = logi[i, j, :].cpu()
+ slct_dets[i, j, :] = ps[i, j, :].cpu()
+
+ return torch.Tensor(slct_logi).cuda(), torch.Tensor(slct_dets).cuda()
+
+
+def process_detect_output(output, meta):
+ K, MK = 3000, 5000
+ num_classes = 2
+ scale = 1.0
+
+ hm = output['hm'].sigmoid_()
+ wh = output['wh']
+ reg = output['reg']
+ st = output['st']
+ ax = output['ax']
+ cr = output['cr']
+
+ scores, inds, ys, xs, st_reg, corner_dict = corner_decode(
+ hm[:, 1:2, :, :], st, reg, K=MK)
+ dets, keep, logi, cr = ctdet_4ps_decode(
+ hm[:, 0:1, :, :], wh, ax, cr, corner_dict, reg=reg, K=K, wiz_rev=False)
+ raw_dets = dets
+ dets = dets.detach().cpu().numpy()
+ dets = dets.reshape(1, -1, dets.shape[2])
+ dets = ctdet_4ps_post_process_upper_left(dets.copy(),
+ [meta['c'].cpu().numpy()],
+ [meta['s']], meta['out_height'],
+ meta['out_width'], 2)
+ for j in range(1, num_classes + 1):
+ dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 9)
+ dets[0][j][:, :8] /= scale
+ dets = dets[0]
+ detections = [dets]
+
+ logi = logi + cr
+ results = merge_outputs(detections)
+ slct_logi_feat, slct_dets_feat = filter(results, logi, raw_dets[:, :, :8])
+ slct_output_dets = results[1][:slct_logi_feat.shape[1], :8]
+
+ return slct_logi_feat, slct_dets_feat, slct_output_dets
+
+
+def process_logic_output(logi):
+ logi_floor = logi.floor()
+ dev = logi - logi_floor
+ logi = torch.where(dev > 0.5, logi_floor + 1, logi_floor)
+
+ return logi
+
+
+def load_lore_model(model, checkpoint, mtype):
+ state_dict_ = checkpoint['state_dict']
+ state_dict = {}
+ # convert data_parallal to model
+ for k in state_dict_:
+ if k.startswith('module') and not k.startswith('module_list'):
+ state_dict[k[7:]] = state_dict_[k]
+ else:
+ if mtype == 'model':
+ if k.startswith('model'):
+ state_dict[k[6:]] = state_dict_[k]
+ else:
+ continue
+ else:
+ if k.startswith('processor'):
+ state_dict[k[10:]] = state_dict_[k]
+ else:
+ continue
+ model_state_dict = model.state_dict()
+ # check loaded parameters and created model parameters
+ for k in state_dict:
+ if k in model_state_dict:
+ if state_dict[k].shape != model_state_dict[k].shape:
+ print('Skip loading parameter {}, required shape{}, '
+ 'loaded shape{}.'.format(k, model_state_dict[k].shape,
+ state_dict[k].shape))
+ state_dict[k] = model_state_dict[k]
+ else:
+ print('Drop parameter {}.'.format(k))
+ for k in model_state_dict:
+ if not (k in state_dict):
+ print('No param {}.'.format(k))
+ state_dict[k] = model_state_dict[k]
+ model.load_state_dict(state_dict, strict=False)
diff --git a/modelscope/models/cv/table_recognition/model_lore.py b/modelscope/models/cv/table_recognition/model_lore.py
new file mode 100644
index 00000000..d21b3fbe
--- /dev/null
+++ b/modelscope/models/cv/table_recognition/model_lore.py
@@ -0,0 +1,88 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import copy
+import math
+from os.path import join
+from typing import Any, Dict
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from modelscope.metainfo import Models
+from modelscope.models import MODELS, TorchModel
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+from .lineless_table_process import (get_affine_transform,
+ get_affine_transform_upper_left,
+ load_lore_model, process_detect_output,
+ process_logic_output)
+from .modules.lore_detector import LoreDetectModel
+from .modules.lore_processor import LoreProcessModel
+
+LOGGER = get_logger()
+
+
+@MODELS.register_module(Tasks.lineless_table_recognition,
+ Models.lineless_table_recognition)
+class LoreModel(TorchModel):
+ '''
+ The model first locates table cells in the input image by key point segmentation.
+ Then the logical locations are predicted along with the spatial locations
+ employing two cascading regressors.
+ See details in paper "LORE: Logical Location Regression Network for Table Structure Recognition"
+ (https://arxiv.org/abs/2303.03730).
+ '''
+
+ def __init__(self, model_dir: str, **kwargs):
+ '''initialize the LORE model from the `model_dir` path.
+
+ Args:
+ model_dir (str): the model path.
+ '''
+ super(LoreModel, self).__init__()
+
+ model_path = join(model_dir, ModelFile.TORCH_MODEL_FILE)
+ checkpoint = torch.load(model_path, map_location='cpu')
+ # init detect infer model
+ self.detect_infer_model = LoreDetectModel()
+ load_lore_model(self.detect_infer_model, checkpoint, 'model')
+ # init process infer model
+ self.process_infer_model = LoreProcessModel()
+ load_lore_model(self.process_infer_model, checkpoint, 'processor')
+
+ def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Args:
+ img (`torch.Tensor`): image tensor,
+ shape of each tensor is [3, H, W].
+
+ Return:
+ dets (`torch.Tensor`): the locations of detected table cells,
+ shape of each tensor is [N_cell, 8].
+            logi (`torch.Tensor`): the logical coordinates of detected table cells,
+ shape of each tensor is [N_cell, 4].
+ meta (`Dict`): the meta info of original image.
+ """
+ outputs = self.detect_infer_model(input['img'])
+ output = outputs[-1]
+ meta = input['meta']
+ slct_logi_feat, slct_dets_feat, slct_output_dets = process_detect_output(
+ output, meta)
+ _, slct_logi = self.process_infer_model(
+ slct_logi_feat, dets=slct_dets_feat.to(torch.int64))
+ return {
+ 'dets': slct_output_dets,
+ 'logi': slct_logi,
+ 'meta': input['meta']
+ }
+
+ def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ slct_dets = inputs['dets']
+ slct_logi = process_logic_output(inputs['logi'])
+ result = {
+ OutputKeys.POLYGONS: slct_dets,
+ OutputKeys.BOXES: np.array(slct_logi[0].cpu().numpy())
+ }
+ return result
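
A rough inference sketch (illustration only): it assumes the table-recognition pipeline's preprocessor has already produced the resized image tensor `img` and the `meta` dict (keys 'c', 's', 'out_height', 'out_width'), and that a GPU is available, since the decoding utilities call `.cuda()`.

    model = LoreModel(model_dir='/path/to/model')  # hypothetical local model directory
    outputs = model({'img': img, 'meta': meta})    # img / meta from the preprocessor
    result = model.postprocess(outputs)            # OutputKeys.POLYGONS / OutputKeys.BOXES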
diff --git a/modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/__init__.py b/modelscope/models/cv/table_recognition/modules/__init__.py
similarity index 100%
rename from modelscope/models/audio/tts/kantts/preprocess/script_convertor/core/__init__.py
rename to modelscope/models/cv/table_recognition/modules/__init__.py
diff --git a/modelscope/models/cv/table_recognition/modules/lore_detector.py b/modelscope/models/cv/table_recognition/modules/lore_detector.py
new file mode 100644
index 00000000..d7e75a4f
--- /dev/null
+++ b/modelscope/models/cv/table_recognition/modules/lore_detector.py
@@ -0,0 +1,385 @@
+# ------------------------------------------------------------------------------
+# Part of implementation is adopted from CenterNet,
+# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git
+# ------------------------------------------------------------------------------
+
+import copy
+import math
+from os.path import join
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution without padding"""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=0,
+ bias=False)
+
+
+class ChannelAttention(nn.Module):
+
+ def __init__(self, in_planes, ratio=16):
+ super(ChannelAttention, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
+
+ self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
+ self.relu1 = nn.ReLU()
+ self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
+
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
+ max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
+
+ out = avg_out + max_out
+
+ return self.sigmoid(out)
+
+
+class SpatialAttention(nn.Module):
+
+ def __init__(self):
+ super(SpatialAttention, self).__init__()
+
+ self.conv1 = nn.Conv2d(2, 1, 3, padding=1, bias=False)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ avg_out = torch.mean(x, dim=1, keepdim=True)
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
+ x = torch.cat([avg_out, max_out], dim=1)
+ x = self.conv1(x)
+ return self.sigmoid(x)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.BN_MOMENTUM = 0.1
+ self.conv1 = nn.Conv2d(
+ inplanes, planes, kernel_size=3, stride=stride, padding=1)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=self.BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=self.BN_MOMENTUM)
+ self.downsample = downsample
+ self.stride = stride
+ self.planes = planes
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.bn2(out)
+ if self.downsample is not None:
+ residual = self.downsample(residual)
+
+ out += residual
+ out = self.relu(out)
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.BN_MOMENTUM = 0.1
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=self.BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=self.BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(
+ planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(
+ planes * self.expansion, momentum=self.BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class LoreDetectModel(nn.Module):
+ """
+    A keypoint-based detector with a ResNet backbone, trained here for table cell detection.
+ See details in paper "LORE: Logical Location Regression Network for Table Structure Recognition"
+ (https://arxiv.org/abs/2303.03730)
+ """
+
+ def __init__(self, **kwargs):
+ '''
+ Args:
+ '''
+ self.BN_MOMENTUM = 0.1
+ self.inplanes = 64
+ self.deconv_with_bias = False
+ self.block = BasicBlock
+ self.layers = [2, 2, 2, 2]
+ self.heads = {
+ 'hm': 2,
+ 'st': 8,
+ 'wh': 8,
+ 'ax': 256,
+ 'cr': 256,
+ 'reg': 2
+ }
+ self.head_conv = 64
+
+ super(LoreDetectModel, self).__init__()
+ self.conv1 = nn.Conv2d(
+ 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64, momentum=self.BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(
+ self.block, 64, self.layers[0], stride=2)
+ self.layer2 = self._make_layer(
+ self.block, 128, self.layers[1], stride=2)
+ self.layer3 = self._make_layer(
+ self.block, 256, self.layers[2], stride=2)
+ self.layer4 = self._make_layer(
+ self.block, 256, self.layers[3], stride=2)
+
+ self.adaption3 = nn.Conv2d(
+ 256, 256, kernel_size=1, stride=1, padding=0, bias=False)
+ self.adaption2 = nn.Conv2d(
+ 128, 256, kernel_size=1, stride=1, padding=0, bias=False)
+ self.adaption1 = nn.Conv2d(
+ 64, 256, kernel_size=1, stride=1, padding=0, bias=False)
+ self.adaption0 = nn.Conv2d(
+ 64, 256, kernel_size=1, stride=1, padding=0, bias=False)
+
+ self.adaptionU1 = nn.Conv2d(
+ 256, 256, kernel_size=1, stride=1, padding=0, bias=False)
+
+ # used for deconv layers
+ self.deconv_layers1 = self._make_deconv_layer(
+ 1,
+ [256],
+ [4],
+ )
+ self.deconv_layers2 = self._make_deconv_layer(
+ 1,
+ [256],
+ [4],
+ )
+ self.deconv_layers3 = self._make_deconv_layer(
+ 1,
+ [256],
+ [4],
+ )
+ self.deconv_layers4 = self._make_deconv_layer(
+ 1,
+ [256],
+ [4],
+ )
+
+ self.hm_maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
+ self.hm_sigmoid = nn.Sigmoid()
+ self.mk_maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
+ self.mk_sigmoid = nn.Sigmoid()
+
+ for head in sorted(self.heads):
+ num_output = self.heads[head]
+ if self.head_conv > 0 and (head == 'reg' or head == 'mk_reg'):
+ inchannel = 256
+ fc = nn.Sequential(
+ nn.Conv2d(
+ inchannel,
+ self.head_conv,
+ kernel_size=3,
+ padding=1,
+ bias=True), nn.ReLU(inplace=True),
+ nn.Conv2d(
+ self.head_conv,
+ num_output,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+ elif self.head_conv > 0:
+ inchannel = 256
+ fc = nn.Sequential(
+ nn.Conv2d(
+ inchannel,
+ self.head_conv,
+ kernel_size=3,
+ padding=1,
+ bias=True), nn.ReLU(inplace=True),
+ nn.Conv2d(
+ self.head_conv,
+ self.head_conv,
+ kernel_size=3,
+ padding=1,
+ bias=True), nn.ReLU(inplace=True),
+ nn.Conv2d(
+ self.head_conv,
+ self.head_conv,
+ kernel_size=3,
+ padding=1,
+ bias=True), nn.ReLU(inplace=True),
+ nn.Conv2d(
+ self.head_conv,
+ self.head_conv,
+ kernel_size=3,
+ padding=1,
+ bias=True), nn.ReLU(inplace=True),
+ nn.Conv2d(
+ self.head_conv,
+ num_output,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+ else:
+ inchannel = 256
+ fc = nn.Conv2d(
+ in_channels=inchannel,
+ out_channels=num_output,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.__setattr__(head, fc)
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(
+ planes * block.expansion, momentum=self.BN_MOMENTUM),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def _get_deconv_cfg(self, deconv_kernel, index):
+ if deconv_kernel == 4:
+ padding = 1
+ output_padding = 0
+ elif deconv_kernel == 3:
+ padding = 1
+ output_padding = 1
+ elif deconv_kernel == 2:
+ padding = 0
+ output_padding = 0
+ elif deconv_kernel == 7:
+ padding = 3
+ output_padding = 0
+
+ return deconv_kernel, padding, output_padding
+
+ def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
+        assert num_layers == len(num_filters), \
+            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
+        assert num_layers == len(num_kernels), \
+            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
+
+ layers = []
+ for i in range(num_layers):
+ kernel, padding, output_padding = \
+ self._get_deconv_cfg(num_kernels[i], i)
+
+ planes = num_filters[i]
+ layers.append(
+ nn.ConvTranspose2d(
+ in_channels=self.inplanes,
+ out_channels=planes,
+ kernel_size=kernel,
+ stride=2,
+ padding=padding,
+ output_padding=output_padding,
+ bias=self.deconv_with_bias))
+ layers.append(nn.BatchNorm2d(planes, momentum=self.BN_MOMENTUM))
+ layers.append(nn.ReLU(inplace=True))
+ self.inplanes = planes
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ """
+ Args:
+ x : Input image, a tensor of [batch_size, channel, w, h].
+
+        Returns:
+            ret : A dict of tensors whose keys correspond to the heads as initialized,
+            and whose values are tensors of [batch_size, dim_key, w, h],
+            where dim_key differs from key to key. For example,
+            in this implementation, the dim_keys are 2, 8, 8, 256, 256, 2.
+ """
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x0 = self.maxpool(x)
+ x1 = self.layer1(x0)
+ x2 = self.layer2(x1)
+ x3 = self.layer3(x2)
+ x4 = self.layer4(x3)
+
+ x3_ = self.deconv_layers1(x4)
+ x3_ = self.adaption3(x3) + x3_
+
+ x2_ = self.deconv_layers2(x3_)
+ x2_ = self.adaption2(x2) + x2_
+
+ x1_ = self.deconv_layers3(x2_)
+ x1_ = self.adaption1(x1) + x1_
+
+ x0_ = self.deconv_layers4(x1_) + self.adaption0(x0)
+ x0_ = self.adaptionU1(x0_)
+
+ ret = {}
+
+ for head in self.heads:
+ ret[head] = self.__getattr__(head)(x0_)
+ return [ret]
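The decoder in forward() above repeatedly upsamples the deeper feature with a learned deconvolution (kernel 4, stride 2, padding 1, as built by _make_deconv_layer) and adds a 1x1 'adaption' projection of the shallower feature. A toy sketch of one such fusion step; the channel counts and spatial sizes here are illustrative assumptions, not the model's actual configuration:

import torch
import torch.nn as nn

# one fusion step, as in: x3_ = self.deconv_layers1(x4); x3_ = self.adaption3(x3) + x3_
deconv = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False)
adaption = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0, bias=False)

deep = torch.randn(1, 512, 16, 16)      # deeper, lower-resolution feature
shallow = torch.randn(1, 256, 32, 32)   # shallower feature at twice the resolution
fused = deconv(deep) + adaption(shallow)
print(fused.shape)  # torch.Size([1, 256, 32, 32])

Applying this step four times (deconv_layers1..4 with adaption3..0) walks the resolution back up from layer4 to the stem output before the per-head convolutions are applied.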
diff --git a/modelscope/models/cv/table_recognition/modules/lore_processor.py b/modelscope/models/cv/table_recognition/modules/lore_processor.py
new file mode 100644
index 00000000..c643b334
--- /dev/null
+++ b/modelscope/models/cv/table_recognition/modules/lore_processor.py
@@ -0,0 +1,440 @@
+# ------------------------------------------------------------------------------
+# Part of the implementation is adopted from CenterNet,
+# made publicly available under the MIT License at https://github.com/xingyizhou/CenterNet.git
+# ------------------------------------------------------------------------------
+
+import copy
+import math
+from os.path import join
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+class Encoder(nn.Module):
+
+ def __init__(self, input_size, hidden_size, N, heads, dropout):
+ super().__init__()
+ self.N = N
+ self.pe = PositionalEncoder(hidden_size, dropout=dropout)
+ self.layers = get_clones(EncoderLayer(hidden_size, heads, dropout), N)
+ self.norm = Norm(hidden_size)
+
+ def forward(self, x, mask=None, require_att=False):
+ att = None
+ for i in range(self.N):
+ if mask is None:
+ if i == (self.N - 1):
+ x, att = self.layers[i](x, require_att=True)
+ else:
+ x = self.layers[i](x)
+ else:
+ x = self.layers[i](x, mask)
+ if require_att:
+ return x, att
+ else:
+ return x
+
+
+class Decoder(nn.Module):
+
+ def __init__(self, hidden_size, output_size):
+ super(Decoder, self).__init__()
+ self.linear = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size),
+ nn.ReLU(inplace=True),
+ nn.Linear(hidden_size, output_size),
+ nn.ReLU(inplace=True) # newly added
+ )
+
+ def forward(self, x):
+ out = self.linear(x)
+ return out
+
+
+class Transformer(nn.Module):
+
+ def __init__(self, input_size, hidden_size, output_size, n_layers, heads,
+ dropout):
+ super().__init__()
+ self.linear = nn.Linear(input_size, hidden_size)
+ self.encoder = Encoder(input_size, hidden_size, n_layers, heads,
+ dropout)
+ self.decoder = Decoder(hidden_size, output_size)
+
+ def forward(self, x, mask=None, require_att=False):
+ x = self.linear(x)
+ att = None
+ if mask is None:
+            # evaluation mode
+ if require_att:
+ embedding, att = self.encoder(x, require_att=True)
+ else:
+ embedding = self.encoder(x)
+
+ output = self.decoder(embedding)
+
+ if require_att:
+ return output, att
+ else:
+ return output
+ else:
+ if require_att:
+ embedding, att = self.encoder(x, mask, require_att=True)
+ else:
+ embedding = self.encoder(x, mask)
+
+ output = self.decoder(embedding)
+ return output
+
+
+class Norm(nn.Module):
+
+ def __init__(self, d_model, eps=1e-6):
+ super().__init__()
+
+ self.size = d_model
+ # create two learnable parameters to calibrate normalisation
+ self.alpha = nn.Parameter(torch.ones(self.size))
+ self.bias = nn.Parameter(torch.zeros(self.size))
+ self.eps = eps
+
+ def forward(self, x):
+ norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \
+ / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
+ return norm
+
+
+def attention(q, k, v, d_k, mask=None, dropout=None):
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
+
+ if mask is not None:
+ if len(mask.shape) == 2:
+ mask = mask.unsqueeze(1)
+ mask = mask.unsqueeze(3)
+ mask = mask.to(torch.float32)
+ mask2d = torch.matmul(mask, mask.transpose(-2, -1)).expand(
+ scores.shape[0], scores.shape[1], scores.shape[2],
+ scores.shape[3])
+ elif len(mask.shape) == 3:
+ mask = mask.unsqueeze(1)
+ mask = mask.to(torch.float32)
+ mask2d = mask.expand(scores.shape[0], scores.shape[1],
+ scores.shape[2], scores.shape[3])
+
+ scores = scores.masked_fill(mask2d == 0, -1e9)
+
+ scores = F.softmax(scores, dim=-1)
+ if dropout is not None:
+ scores = dropout(scores)
+
+ output = torch.matmul(scores, v)
+ return output
+
+
+def attention_score(q, k, v, d_k):
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
+ scores = F.softmax(scores, dim=-1)
+ return scores
+
+
+class MultiHeadAttention(nn.Module):
+
+ def __init__(self, heads, d_model, dropout=0.1):
+ super().__init__()
+
+ self.d_model = d_model
+ self.d_k = d_model // heads
+ self.h = heads
+
+ self.q_linear = nn.Linear(d_model, d_model)
+ self.v_linear = nn.Linear(d_model, d_model)
+ self.k_linear = nn.Linear(d_model, d_model)
+
+ self.dropout = nn.Dropout(dropout)
+ self.out = nn.Linear(d_model, d_model)
+
+ def attention_map(self, q, k, v, mask=None):
+ bs = q.size(0)
+
+ # perform linear operation and split into N heads
+ k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
+ q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
+ v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
+
+        # transpose to get dimensions bs * heads * sl * d_k
+ k = k.transpose(1, 2)
+ q = q.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ scores = attention_score(q, k, v, self.d_k)
+
+ return scores
+
+ def forward(self, q, k, v, mask=None):
+
+ bs = q.size(0)
+
+ # perform linear operation and split into N heads
+ k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
+ q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
+ v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
+
+        # transpose to get dimensions bs * heads * sl * d_k
+ k = k.transpose(1, 2)
+ q = q.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+        # calculate attention using the attention function defined above
+ scores = attention(q, k, v, self.d_k, mask, self.dropout)
+ # concatenate heads and put through final linear layer
+
+ concat = scores.transpose(1, 2).contiguous() \
+ .view(bs, -1, self.d_model)
+ output = self.out(concat)
+
+ return output
+
+
+class FeedForward(nn.Module):
+
+ def __init__(self, d_model, d_ff=2048, dropout=0.1):
+ super().__init__()
+
+        # d_ff defaults to 2048
+ self.linear_1 = nn.Linear(d_model, d_ff)
+ self.dropout = nn.Dropout(dropout)
+ self.linear_2 = nn.Linear(d_ff, d_model)
+
+ def forward(self, x):
+ x = self.dropout(F.relu(self.linear_1(x)))
+ x = self.linear_2(x)
+ return x
+
+
+class Embedder(nn.Module):
+
+ def __init__(self, vocab_size, d_model):
+ super().__init__()
+ self.d_model = d_model
+ self.embed = nn.Embedding(vocab_size, d_model)
+
+ def forward(self, x):
+ return self.embed(x)
+
+
+class PositionalEncoder(nn.Module):
+
+ def __init__(self, d_model, max_seq_len=900, dropout=0.1):
+ super().__init__()
+ self.d_model = d_model
+ self.dropout = nn.Dropout(dropout)
+        # create constant 'pe' matrix with values dependent on
+        # pos and i
+ pe = torch.zeros(max_seq_len, d_model)
+ for pos in range(max_seq_len):
+ for i in range(0, d_model, 2):
+ sin_coef = 10000**((2 * i) / d_model)
+ cos_coef = 10000**((2 * (i + 1)) / d_model)
+ pe[pos, i] = math.sin(pos / sin_coef)
+ pe[pos, i + 1] = math.cos(pos / cos_coef)
+ pe = pe.unsqueeze(0)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ # make embeddings relatively larger
+ x = x * math.sqrt(self.d_model)
+        # add the constant positional encoding; self.pe is a registered
+        # buffer, so it follows the module's device and needs no gradient
+        seq_len = x.size(1)
+        x = x + self.pe[:, :seq_len]
+ return self.dropout(x)
+
+
+class EncoderLayer(nn.Module):
+
+ def __init__(self, d_model, heads, dropout=0.1):
+ super().__init__()
+ self.norm_1 = Norm(d_model)
+ self.norm_2 = Norm(d_model)
+ self.attn = MultiHeadAttention(heads, d_model, dropout=dropout)
+ self.ff = FeedForward(d_model, dropout=dropout)
+ self.dropout_1 = nn.Dropout(dropout)
+ self.dropout_2 = nn.Dropout(dropout)
+
+ def forward(self, x, mask=None, require_att=False):
+ x2 = self.norm_1(x)
+ xc = x2.clone()
+
+ if mask is None:
+ x = x + self.dropout_1(self.attn(x2, x2, x2))
+ else:
+ x = x + self.dropout_1(self.attn(x2, x2, x2, mask))
+
+ x2 = self.norm_2(x)
+ x = x + self.dropout_2(self.ff(x2))
+
+ if require_att:
+ att = self.attn.attention_map(xc, xc, xc)
+ return x, att
+ else:
+ return x
+
+
+class DecoderLayer(nn.Module):
+
+ def __init__(self, d_model, heads, dropout=0.1):
+ super().__init__()
+ self.norm_1 = Norm(d_model)
+ self.norm_2 = Norm(d_model)
+ self.norm_3 = Norm(d_model)
+
+ self.dropout_1 = nn.Dropout(dropout)
+ self.dropout_2 = nn.Dropout(dropout)
+ self.dropout_3 = nn.Dropout(dropout)
+
+ self.attn_1 = MultiHeadAttention(heads, d_model, dropout=dropout)
+ self.attn_2 = MultiHeadAttention(heads, d_model, dropout=dropout)
+ self.ff = FeedForward(d_model, dropout=dropout)
+
+ def forward(self, x, e_outputs, src_mask, trg_mask):
+ x2 = self.norm_1(x)
+ x = x + self.dropout_1(self.attn_1(x2, x2, x2, trg_mask))
+ x2 = self.norm_2(x)
+ x = x + self.dropout_2(self.attn_2(x2, e_outputs, e_outputs, src_mask))
+ x2 = self.norm_3(x)
+ x = x + self.dropout_3(self.ff(x2))
+ return x
+
+
+class Stacker(nn.Module):
+ '''
+ The architecture of the stacking regressor, which takes the dense representations and
+    logical locations of table cells to make a more accurate prediction of the logical locations.
+ '''
+
+ def __init__(self,
+ input_size,
+ hidden_size,
+ output_size,
+ layers,
+ heads=8,
+ dropout=0.1):
+ """
+ Args:
+ input_size : The dim of logical locations which is always 4.
+ hidden_size : The dim of hidden states which is 256 by default.
+ output_size : The dim of logical locations which is always 4.
+ layers : Number of layers of self-attention mechanism, which is 4 in this implementation.
+ """
+ super(Stacker, self).__init__()
+ self.logi_encoder = nn.Sequential(
+ nn.Linear(input_size, hidden_size), nn.ReLU(inplace=True),
+ nn.Linear(hidden_size, hidden_size), nn.ReLU(inplace=True))
+ self.tsfm = Transformer(2 * hidden_size, hidden_size, output_size,
+ layers, heads, dropout)
+
+ def forward(self, outputs, logi, mask=None, require_att=False):
+ """
+ Args:
+ outputs : The dense representation of table cells, a tensor of [batch_size, number_of_objects, hidden_size].
+ logi : The logical location of table cells, a tensor of [batch_size, number_of_objects, 4].
+ mask : The mask of cells, a tensor of [batch_size, number_of_objects], not None only in training stage.
+ require_att : If True, the model will also generate the attention maps of table cells.
+
+ Returns:
+ stacked_axis : The predicted logical location of cells, a tensor of [batch_size, number_of_objects, 4].
+ att : The attention map of table cells.
+ """
+ logi_embeddings = self.logi_encoder(logi)
+
+ cat_embeddings = torch.cat((logi_embeddings, outputs), dim=2)
+
+ if mask is None:
+ if require_att:
+ stacked_axis, att = self.tsfm(cat_embeddings)
+ else:
+ stacked_axis = self.tsfm(cat_embeddings)
+ else:
+ stacked_axis = self.tsfm(cat_embeddings, mask=mask)
+
+ if require_att:
+ return stacked_axis, att
+ else:
+ return stacked_axis
+
+
+class LoreProcessModel(nn.Module):
+ '''
+ The logical location prediction head of LORE. It contains a base regressor and a stacking regressor.
+ They both consist of several self-attention blocks.
+ See details in paper "LORE: Logical Location Regression Network for Table Structure Recognition"
+ (https://arxiv.org/abs/2303.03730).
+ '''
+
+ def __init__(self, **kwargs):
+        '''
+        Build the base regressor, the stacking regressor and the positional
+        embeddings with the fixed hyper-parameters defined below.
+        '''
+ super(LoreProcessModel, self).__init__()
+
+ self.input_size = 256
+ self.output_size = 4
+ self.hidden_size = 256
+ self.max_fmp_size = 256
+ self.stacking_layers = 4
+ self.tsfm_layers = 4
+ self.num_heads = 8
+ self.att_dropout = 0.1
+ self.stacker = Stacker(self.output_size, self.hidden_size,
+ self.output_size, self.stacking_layers)
+ self.tsfm_axis = Transformer(self.input_size, self.hidden_size,
+ self.output_size, self.tsfm_layers,
+ self.num_heads, self.att_dropout)
+ self.x_position_embeddings = nn.Embedding(self.max_fmp_size,
+ self.hidden_size)
+ self.y_position_embeddings = nn.Embedding(self.max_fmp_size,
+ self.hidden_size)
+
+ def forward(self, outputs, batch=None, cc_match=None, dets=None):
+ """
+ Args:
+ outputs : The dense representation of table cells from the detection part of LORE,
+ a tensor of [batch_size, number_of_objects, hidden_size].
+            batch : The detection results from other sources, such as external OCR systems.
+            cc_match : Not used in this implementation.
+            dets : The detection results of each table cell, a tensor of [batch_size, number_of_objects, 8].
+
+ Returns:
+ logi_axis : The output logical location of base regressor,
+ a tensor of [batch_size, number_of_objects, 4].
+ stacked_axis : The output logical location of stacking regressor,
+ a tensor of [batch_size, number_of_objects, 4].
+ """
+        if batch is None:
+            # evaluation mode: only the visual features (and optionally the
+            # detected cell coordinates) are available
+            vis_feat = outputs
+            if dets is None:
+                logic_axis = self.tsfm_axis(vis_feat)
+                stacked_axis = self.stacker(vis_feat, logic_axis)
+            else:
+                # embed the detected cell corners as positional features
+                left_pe = self.x_position_embeddings(dets[:, :, 0])
+                upper_pe = self.y_position_embeddings(dets[:, :, 1])
+                right_pe = self.x_position_embeddings(dets[:, :, 2])
+                lower_pe = self.y_position_embeddings(dets[:, :, 5])
+                feat = vis_feat + left_pe + upper_pe + right_pe + lower_pe
+
+                logic_axis = self.tsfm_axis(feat)
+                stacked_axis = self.stacker(feat, logic_axis)
+
+        return logic_axis, stacked_axis
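A minimal smoke test for the module above in evaluation mode. It is only a sketch that checks the shape contract stated in the docstrings, and it assumes the file is importable at the path introduced by this diff; dets must hold integer coordinates below max_fmp_size (256), since they index the nn.Embedding tables.

import torch

from modelscope.models.cv.table_recognition.modules.lore_processor import \
    LoreProcessModel

model = LoreProcessModel().eval()
outputs = torch.randn(1, 10, 256)          # dense features of 10 table cells
dets = torch.randint(0, 256, (1, 10, 8))   # quantized cell corner coordinates

with torch.no_grad():
    logic_axis, stacked_axis = model(outputs, dets=dets)

print(logic_axis.shape, stacked_axis.shape)  # both torch.Size([1, 10, 4])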
diff --git a/modelscope/models/cv/tinynas_detection/damo/apis/detector_evaluater.py b/modelscope/models/cv/tinynas_detection/damo/apis/detector_evaluater.py
index 82ffb567..6ff194f6 100644
--- a/modelscope/models/cv/tinynas_detection/damo/apis/detector_evaluater.py
+++ b/modelscope/models/cv/tinynas_detection/damo/apis/detector_evaluater.py
@@ -8,8 +8,8 @@ from modelscope.models.cv.tinynas_detection.damo.apis.detector_inference import
inference
from modelscope.models.cv.tinynas_detection.damo.detectors.detector import \
build_local_model
-from modelscope.msdatasets.task_datasets.damoyolo import (build_dataloader,
- build_dataset)
+from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import (
+ build_dataloader, build_dataset)
def mkdir(path):
diff --git a/modelscope/models/cv/tinynas_detection/damo/apis/detector_inference.py b/modelscope/models/cv/tinynas_detection/damo/apis/detector_inference.py
index 47c1fb1b..dcd33834 100644
--- a/modelscope/models/cv/tinynas_detection/damo/apis/detector_inference.py
+++ b/modelscope/models/cv/tinynas_detection/damo/apis/detector_inference.py
@@ -5,7 +5,7 @@ import os
import torch
from tqdm import tqdm
-from modelscope.msdatasets.task_datasets.damoyolo.evaluation import evaluate
+from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import evaluate
from modelscope.utils.logger import get_logger
from modelscope.utils.timer import Timer, get_time_str
from modelscope.utils.torch_utils import (all_gather, get_world_size,
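The two hunks above only change import paths: the DAMO-YOLO dataset helpers are now resolved from dataset_cls.custom_datasets instead of task_datasets. Downstream code that still uses the old path would need the same one-line update, e.g.:

# before (old path previously used by these files)
# from modelscope.msdatasets.task_datasets.damoyolo import build_dataloader, build_dataset

# after (path used by this change)
from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import (
    build_dataloader, build_dataset)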
diff --git a/modelscope/models/cv/tinynas_detection/detector.py b/modelscope/models/cv/tinynas_detection/detector.py
index 94599dcc..4ece5f91 100644
--- a/modelscope/models/cv/tinynas_detection/detector.py
+++ b/modelscope/models/cv/tinynas_detection/detector.py
@@ -49,6 +49,7 @@ class SingleStageDetector(TorchModel):
self.head = build_head(self.cfg.model.head)
self.head.nms = False
self.apply(self.init_bn)
+ self.onnx_export = False
self.load_pretrain_model(model_path)
@@ -82,6 +83,9 @@ class SingleStageDetector(TorchModel):
return prediction
def postprocess(self, preds):
+ if self.onnx_export:
+ return preds
+
bboxes, scores, labels_idx = postprocess_gfocal(
preds, self.num_classes, self.conf_thre, self.nms_thre)
bboxes = bboxes.cpu().numpy()
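The new onnx_export flag short-circuits postprocess() so tracing sees the raw head outputs instead of the NMS results. A hedged sketch of how an export helper might use it; the input resolution, the opset and the assumption that the detector's forward accepts a plain image tensor are illustrative, not taken from this diff:

import torch


def export_to_onnx(detector, onnx_path, input_size=(1, 3, 640, 640)):
    """Export a SingleStageDetector with raw (pre-NMS) outputs enabled."""
    detector.eval()
    detector.onnx_export = True  # postprocess() now returns preds unchanged
    dummy = torch.randn(*input_size)
    torch.onnx.export(detector, dummy, onnx_path, opset_version=11)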
diff --git a/modelscope/models/cv/video_instance_segmentation/__init__.py b/modelscope/models/cv/video_instance_segmentation/__init__.py
new file mode 100644
index 00000000..294c00f7
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .video_knet import (
+ KNetTrack, )
+ from .neck import MSDeformAttnPixelDecoder
+
+else:
+ _import_structure = {
+ 'video_knet': ['KNetTrack'],
+ 'neck': ['MSDeformAttnPixelDecoder']
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
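The package registers itself as a LazyImportModule, so importing it stays cheap and the sub-modules (with their mmdet/mmcv dependencies) are only loaded when one of the exported names is first accessed. A usage sketch, assuming those dependencies are installed:

from modelscope.models.cv import video_instance_segmentation as vis

# attribute access triggers the real import of .video_knet / .neck
tracker_cls = vis.KNetTrack
decoder_cls = vis.MSDeformAttnPixelDecoder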
diff --git a/modelscope/models/cv/video_instance_segmentation/head/__init__.py b/modelscope/models/cv/video_instance_segmentation/head/__init__.py
new file mode 100644
index 00000000..b937315b
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/head/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
diff --git a/modelscope/models/cv/video_instance_segmentation/head/kernel_frame_iter_head.py b/modelscope/models/cv/video_instance_segmentation/head/kernel_frame_iter_head.py
new file mode 100644
index 00000000..dcab7dba
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/head/kernel_frame_iter_head.py
@@ -0,0 +1,414 @@
+# The implementation is adopted from Video-K-Net,
+# made publicly available at https://github.com/lxtGH/Video-K-Net under the MIT License
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmdet.core import build_assigner, build_sampler
+from mmdet.models.builder import HEADS, build_head
+from mmdet.models.roi_heads import BaseRoIHead
+from mmdet.utils import get_root_logger
+
+
+@HEADS.register_module()
+class KernelFrameIterHeadVideo(BaseRoIHead):
+
+ def __init__(self,
+ mask_head=None,
+ with_mask_init=False,
+ num_stages=3,
+ stage_loss_weights=(1, 1, 1),
+ proposal_feature_channel=256,
+ assign_stages=5,
+ num_proposals=100,
+ num_thing_classes=80,
+ num_stuff_classes=53,
+ query_merge_method='mean',
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None,
+ **kwargs):
+ assert len(stage_loss_weights) == num_stages
+ self.num_stages = num_stages
+ self.stage_loss_weights = stage_loss_weights
+ self.assign_stages = assign_stages
+ self.num_proposals = num_proposals
+ self.num_thing_classes = num_thing_classes
+ self.num_stuff_classes = num_stuff_classes
+ self.query_merge_method = query_merge_method
+ self.proposal_feature_channel = proposal_feature_channel
+ super().__init__(
+ mask_head=mask_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ init_cfg=init_cfg,
+ **kwargs)
+ if self.query_merge_method == 'attention':
+ self.init_query = nn.Embedding(self.num_proposals,
+ self.proposal_feature_channel)
+ _num_head = 8
+ _drop_out = 0.
+ self.query_merge_attn = MultiheadAttention(
+ self.proposal_feature_channel,
+ _num_head,
+ _drop_out,
+ batch_first=True)
+ self.query_merge_norm = build_norm_layer(
+ dict(type='LN'), self.proposal_feature_channel)[1]
+ self.query_merge_ffn = FFN(
+ self.proposal_feature_channel,
+ self.proposal_feature_channel * 8,
+ num_ffn_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_drop=0.)
+ self.query_merge_ffn_norm = build_norm_layer(
+ dict(type='LN'), self.proposal_feature_channel)[1]
+ elif self.query_merge_method == 'attention_pos':
+ self.init_query = nn.Embedding(self.num_proposals,
+ self.proposal_feature_channel)
+ self.query_pos = nn.Embedding(self.num_proposals,
+ self.proposal_feature_channel)
+ _num_head = 8
+ _drop_out = 0.
+ self.query_merge_attn = MultiheadAttention(
+ self.proposal_feature_channel,
+ _num_head,
+ _drop_out,
+ batch_first=True)
+ self.query_merge_norm = build_norm_layer(
+ dict(type='LN'), self.proposal_feature_channel)[1]
+ self.query_merge_ffn = FFN(
+ self.proposal_feature_channel,
+ self.proposal_feature_channel * 8,
+ num_ffn_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_drop=0.)
+ self.query_merge_ffn_norm = build_norm_layer(
+ dict(type='LN'), self.proposal_feature_channel)[1]
+
+ self.with_mask_init = with_mask_init
+ if self.with_mask_init:
+ self.fc_mask = nn.Linear(proposal_feature_channel,
+ proposal_feature_channel)
+
+ self.logger = get_root_logger()
+
+ def init_mask_head(self, bbox_roi_extractor=None, mask_head=None):
+ assert bbox_roi_extractor is None
+ self.mask_head = nn.ModuleList()
+ if not isinstance(mask_head, list):
+ mask_head = [mask_head for _ in range(self.num_stages)]
+ assert len(mask_head) == self.num_stages
+ for idx, head in enumerate(mask_head):
+ head.update(with_cls=(idx < self.assign_stages))
+ self.mask_head.append(build_head(head))
+
+ def init_assigner_sampler(self):
+ """Initialize assigner and sampler for each stage."""
+ self.mask_assigner = []
+ self.mask_sampler = []
+ if self.train_cfg is not None:
+ for i in range(self.num_stages):
+ self.mask_assigner.append(
+ build_assigner(self.train_cfg['assigner']))
+ self.current_stage = i
+ self.mask_sampler.append(
+ build_sampler(self.train_cfg['sampler'], context=self))
+
+ def init_bbox_head(self, mask_roi_extractor, mask_head):
+ """Initialize box head and box roi extractor.
+
+ Args:
+ mask_roi_extractor (dict): Config of box roi extractor.
+ mask_head (dict): Config of box in box head.
+ """
+ raise NotImplementedError
+
+ def _mask_forward(self, stage, x, object_feats, mask_preds):
+ mask_head = self.mask_head[stage]
+ cls_score, mask_preds, object_feats = mask_head(
+ x,
+ object_feats,
+ mask_preds,
+ img_metas=None,
+ pos=self.query_pos.weight
+ if self.query_merge_method == 'attention_pos' else None)
+ if mask_head.mask_upsample_stride > 1 and (stage == self.num_stages - 1
+ or self.training):
+ scaled_mask_preds = [
+ F.interpolate(
+ mask_preds[i],
+ scale_factor=mask_head.mask_upsample_stride,
+ align_corners=False,
+ mode='bilinear') for i in range(mask_preds.size(0))
+ ]
+ scaled_mask_preds = torch.stack(scaled_mask_preds)
+ else:
+ scaled_mask_preds = mask_preds
+
+ mask_results = dict(
+ cls_score=cls_score,
+ mask_preds=mask_preds,
+ scaled_mask_preds=scaled_mask_preds,
+ object_feats=object_feats)
+ return mask_results
+
+ def _query_fusion(self, obj_feats, num_imgs, num_frames):
+ if self.query_merge_method == 'mean':
+ object_feats = obj_feats.mean(1)
+ elif self.query_merge_method == 'attention':
+ assert obj_feats.size()[-2:] == (
+ 1, 1), 'Only supporting kernel size = 1'
+ obj_feats = obj_feats.reshape(
+ (num_imgs, num_frames * self.num_proposals,
+ self.proposal_feature_channel))
+ init_query = self.init_query.weight.expand(
+ num_imgs, *self.init_query.weight.size())
+ obj_feats = self.query_merge_attn(
+ query=init_query, key=obj_feats, value=obj_feats)
+ obj_feats = self.query_merge_norm(obj_feats)
+ object_feats = self.query_merge_ffn_norm(
+ self.query_merge_ffn(obj_feats))
+ object_feats = object_feats[..., None, None]
+ elif self.query_merge_method == 'attention_pos':
+ assert obj_feats.size()[-2:] == (
+ 1, 1), 'Only supporting kernel size = 1'
+ obj_feats = obj_feats.reshape(
+ (num_imgs, num_frames * self.num_proposals,
+ self.proposal_feature_channel))
+ init_query = self.init_query.weight.expand(
+ num_imgs, *self.init_query.weight.size())
+ query_pos = self.query_pos.weight.repeat(num_imgs, 1, 1)
+ key_pos = query_pos.repeat(1, num_frames, 1)
+ obj_feats = self.query_merge_attn(
+ query=init_query,
+ key=obj_feats,
+ value=obj_feats,
+ query_pos=query_pos,
+ key_pos=key_pos)
+ obj_feats = self.query_merge_norm(obj_feats)
+ object_feats = self.query_merge_ffn_norm(
+ self.query_merge_ffn(obj_feats))
+ object_feats = object_feats[..., None, None]
+
+ return object_feats
+
+ def _mask_init(self, object_feats, x_feats, num_imgs):
+ assert object_feats.size()[-2:] == (
+ 1, 1), 'Only supporting kernel size = 1'
+ object_feats = object_feats.flatten(-3, -1) # BNCKK -> BNC
+ mask_feat = self.fc_mask(object_feats)[..., None, None]
+ mask_preds = []
+ for i in range(num_imgs):
+ mask_preds.append(F.conv2d(x_feats[i], mask_feat[i], padding=0))
+
+ mask_preds = torch.stack(mask_preds, dim=0)
+
+ return mask_preds
+
+ def forward_train(self, x, ref_img_metas, cls_scores, masks, obj_feats,
+ ref_gt_masks, ref_gt_labels, ref_gt_instance_ids,
+ **kwargs):
+ num_imgs = len(ref_img_metas)
+ num_frames = len(ref_img_metas[0])
+ if len(obj_feats.size()) == 6:
+ object_feats = self._query_fusion(obj_feats, num_imgs, num_frames)
+ else:
+ object_feats = obj_feats
+
+ all_stage_loss = {}
+ if self.with_mask_init:
+ mask_preds = self._mask_init(object_feats, x, num_imgs)
+ assert self.training
+ if self.mask_head[0].mask_upsample_stride > 1:
+ scaled_mask_preds = [
+ F.interpolate(
+ mask_preds[i],
+ scale_factor=self.mask_head[0].mask_upsample_stride,
+ align_corners=False,
+ mode='bilinear') for i in range(mask_preds.size(0))
+ ]
+ scaled_mask_preds = torch.stack(scaled_mask_preds)
+ else:
+ scaled_mask_preds = mask_preds
+ _gt_masks_matches = []
+ _assign_results = []
+ _sampling_results = []
+ _pred_masks_concat = []
+ for i in range(num_imgs):
+ mask_for_assign = scaled_mask_preds[i][:self.
+ num_proposals].detach()
+ cls_for_assign = None
+ assign_result, gt_masks_match = self.mask_assigner[0].assign(
+ mask_for_assign, cls_for_assign, ref_gt_masks[i],
+ ref_gt_labels[i], ref_gt_instance_ids[i])
+ _gt_masks_matches.append(gt_masks_match)
+ _assign_results.append(assign_result)
+ num_bboxes = scaled_mask_preds.size(2)
+ h, w = scaled_mask_preds.shape[-2:]
+ pred_masks_match = torch.einsum('fqhw->qfhw',
+ scaled_mask_preds[i]).reshape(
+ (num_bboxes, -1, w))
+ sampling_result = self.mask_sampler[0].sample(
+ assign_result, pred_masks_match, gt_masks_match)
+ _sampling_results.append(sampling_result)
+ _pred_masks_concat.append(pred_masks_match)
+ pred_masks_concat = torch.stack(_pred_masks_concat)
+ mask_targets = self.mask_head[0].get_targets(
+ _sampling_results,
+ self.train_cfg,
+ True,
+ gt_sem_seg=None,
+ gt_sem_cls=None)
+
+ single_stage_loss = self.mask_head[0].loss(object_feats, None,
+ pred_masks_concat,
+ *mask_targets)
+ for key, value in single_stage_loss.items():
+ all_stage_loss[
+ f'tracker_init_{key}'] = value * self.stage_loss_weights[0]
+ else:
+ mask_preds = masks
+
+ assign_results = []
+ for stage in range(self.num_stages):
+ if stage == self.assign_stages:
+ object_feats = object_feats[:, None].repeat(
+ 1, num_frames, 1, 1, 1, 1)
+ mask_results = self._mask_forward(stage, x, object_feats,
+ mask_preds)
+ mask_preds = mask_results['mask_preds']
+ scaled_mask_preds = mask_results['scaled_mask_preds']
+ cls_score = mask_results['cls_score']
+ object_feats = mask_results['object_feats']
+
+ prev_mask_preds = scaled_mask_preds.detach()
+ prev_cls_score = cls_score.detach(
+ ) if cls_score is not None else None
+
+ sampling_results = []
+ pred_masks_concat = []
+ if stage < self.assign_stages:
+ assign_results = []
+ gt_masks_matches = []
+ for i in range(num_imgs):
+ if stage < self.assign_stages:
+ mask_for_assign = prev_mask_preds[i][:, :self.
+ num_proposals]
+ if prev_cls_score is not None:
+ cls_for_assign = prev_cls_score[
+ i][:self.num_proposals, :self.num_thing_classes]
+ else:
+ cls_for_assign = None
+ assign_result, gt_masks_match = self.mask_assigner[
+ stage].assign(mask_for_assign, cls_for_assign,
+ ref_gt_masks[i], ref_gt_labels[i],
+ ref_gt_instance_ids[i])
+ gt_masks_matches.append(gt_masks_match)
+ assign_results.append(assign_result)
+ num_bboxes = scaled_mask_preds.size(2)
+ h, w = scaled_mask_preds.shape[-2:]
+ pred_masks_match = torch.einsum('fqhw->qfhw',
+ scaled_mask_preds[i]).reshape(
+ (num_bboxes, -1, w))
+ sampling_result = self.mask_sampler[stage].sample(
+ assign_results[i], pred_masks_match, gt_masks_matches[i])
+ sampling_results.append(sampling_result)
+ pred_masks_concat.append(pred_masks_match)
+ pred_masks_concat = torch.stack(pred_masks_concat)
+ mask_targets = self.mask_head[stage].get_targets(
+ sampling_results,
+ self.train_cfg,
+ True,
+ gt_sem_seg=None,
+ gt_sem_cls=None)
+
+ single_stage_loss = self.mask_head[stage].loss(
+ object_feats, cls_score, pred_masks_concat, *mask_targets)
+ for key, value in single_stage_loss.items():
+ all_stage_loss[
+ f'tracker_s{stage}_{key}'] = value * self.stage_loss_weights[
+ stage]
+
+ features = {
+ 'obj_feats': object_feats,
+ 'x_feats': x,
+ 'cls_scores': cls_score,
+ 'masks': mask_preds,
+ }
+ return all_stage_loss, features
+
+ def simple_test(self, x, img_metas, ref_img_metas, cls_scores, masks,
+ obj_feats, **kwargs):
+ num_imgs = len(ref_img_metas)
+ num_frames = len(ref_img_metas[0])
+
+ if len(obj_feats.size()) == 6:
+ object_feats = self._query_fusion(obj_feats, num_imgs, num_frames)
+ else:
+ object_feats = obj_feats
+
+ if self.with_mask_init:
+ mask_preds = self._mask_init(object_feats, x, num_imgs)
+ else:
+ mask_preds = masks
+
+ cls_score = None
+ for stage in range(self.num_stages):
+ if stage == self.assign_stages:
+ object_feats = object_feats[:, None].repeat(
+ 1, num_frames, 1, 1, 1, 1)
+ mask_results = self._mask_forward(stage, x, object_feats,
+ mask_preds)
+ mask_preds = mask_results['mask_preds']
+ scaled_mask_preds = mask_results['scaled_mask_preds']
+ cls_score = mask_results['cls_score'] if mask_results[
+ 'cls_score'] is not None else cls_score
+ object_feats = mask_results['object_feats']
+
+ num_classes = self.mask_head[-1].num_classes
+ results = []
+ if self.mask_head[-1].loss_cls.use_sigmoid:
+ cls_score = cls_score.sigmoid()
+ else:
+ cls_score = cls_score.softmax(-1)[..., :-1]
+
+ for img_id in range(num_imgs):
+ result = []
+ cls_score_per_img = cls_score[img_id]
+            # Quite tricky here: a bounding box can predict multiple results with different labels
+ scores_per_img, topk_indices = cls_score_per_img.flatten(
+ 0, 1).topk(
+ self.test_cfg['max_per_img'], sorted=True)
+ mask_indices = topk_indices // num_classes
+ # Use the following when torch >= 1.9.0
+ # mask_indices = torch.div(topk_indices, num_classes, rounding_mode='floor')
+ labels_per_img = topk_indices % num_classes
+ for frame_id in range(num_frames):
+ masks_per_img = scaled_mask_preds[img_id][frame_id][
+ mask_indices]
+ single_result = self.mask_head[-1].get_seg_masks_tracking(
+ masks_per_img, labels_per_img, scores_per_img,
+ torch.arange(self.test_cfg['max_per_img']), self.test_cfg,
+ img_metas[img_id])
+ result.append(single_result)
+ results.append(result)
+ features = {
+ 'obj_feats': object_feats,
+ 'x_feats': x,
+ 'cls_scores': cls_score,
+ 'masks': mask_preds,
+ }
+ return results, features
+
+ def init_weights(self):
+        if self.init_cfg is not None and self.init_cfg[
+                'type'] == 'Pretrained' and self.init_cfg.get(
+                    'prefix') is not None:
+ from mmcv.cnn import initialize
+ self.logger.info('Customized loading the tracker.')
+ initialize(self, self.init_cfg)
+ else:
+ super().init_weights()
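For reference, _query_fusion() above reduces per-frame kernels to clip-level kernels; with query_merge_method='mean' this is a plain average over the frame axis, while the 'attention' variants flatten the frames and cross-attend against a learned init_query. A tiny sketch of the 'mean' case with illustrative sizes:

import torch

num_imgs, num_frames, num_proposals, channels = 2, 5, 100, 256
# per-frame kernels: [num_imgs, num_frames, num_proposals, C, 1, 1]
obj_feats = torch.randn(num_imgs, num_frames, num_proposals, channels, 1, 1)

clip_kernels = obj_feats.mean(1)  # 'mean' fusion -> [2, 100, 256, 1, 1]
print(clip_kernels.shape)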
diff --git a/modelscope/models/cv/video_instance_segmentation/head/kernel_head.py b/modelscope/models/cv/video_instance_segmentation/head/kernel_head.py
new file mode 100644
index 00000000..debdc5be
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/head/kernel_head.py
@@ -0,0 +1,523 @@
+# The implementation is adopted from Video-K-Net,
+# made publicly available at https://github.com/lxtGH/Video-K-Net under the MIT License
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
+from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
+from mmdet.models.builder import HEADS, build_loss, build_neck
+from mmdet.models.losses import accuracy
+from mmdet.utils import get_root_logger
+
+
+@HEADS.register_module()
+class ConvKernelHeadVideo(nn.Module):
+
+ def __init__(self,
+ num_proposals=100,
+ in_channels=256,
+ out_channels=256,
+ num_heads=8,
+ num_cls_fcs=1,
+ num_seg_convs=1,
+ num_loc_convs=1,
+ att_dropout=False,
+ localization_fpn=None,
+ conv_kernel_size=1,
+ norm_cfg=dict(type='GN', num_groups=32),
+ semantic_fpn=True,
+ train_cfg=None,
+ num_classes=80,
+ xavier_init_kernel=False,
+ kernel_init_std=0.01,
+ use_binary=False,
+ proposal_feats_with_obj=False,
+ loss_mask=None,
+ loss_seg=None,
+ loss_cls=None,
+ loss_dice=None,
+ loss_rank=None,
+ feat_downsample_stride=1,
+ feat_refine_stride=1,
+ feat_refine=True,
+ with_embed=False,
+ feat_embed_only=False,
+ conv_normal_init=False,
+ mask_out_stride=4,
+ hard_target=False,
+ num_thing_classes=80,
+ num_stuff_classes=53,
+ mask_assign_stride=4,
+ ignore_label=255,
+ thing_label_in_seg=0,
+ cat_stuff_mask=False,
+ **kwargs):
+ super().__init__()
+ self.num_proposals = num_proposals
+ self.num_cls_fcs = num_cls_fcs
+ self.train_cfg = train_cfg
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_classes = num_classes
+ self.proposal_feats_with_obj = proposal_feats_with_obj
+ self.sampling = False
+ self.localization_fpn = build_neck(localization_fpn)
+ self.semantic_fpn = semantic_fpn
+ self.norm_cfg = norm_cfg
+ self.num_heads = num_heads
+ self.att_dropout = att_dropout
+ self.mask_out_stride = mask_out_stride
+ self.hard_target = hard_target
+ self.conv_kernel_size = conv_kernel_size
+ self.xavier_init_kernel = xavier_init_kernel
+ self.kernel_init_std = kernel_init_std
+ self.feat_downsample_stride = feat_downsample_stride
+ self.feat_refine_stride = feat_refine_stride
+ self.conv_normal_init = conv_normal_init
+ self.feat_refine = feat_refine
+ self.with_embed = with_embed
+ self.feat_embed_only = feat_embed_only
+ self.num_loc_convs = num_loc_convs
+ self.num_seg_convs = num_seg_convs
+ self.use_binary = use_binary
+ self.num_thing_classes = num_thing_classes
+ self.num_stuff_classes = num_stuff_classes
+ self.mask_assign_stride = mask_assign_stride
+ self.ignore_label = ignore_label
+ self.thing_label_in_seg = thing_label_in_seg
+ self.cat_stuff_mask = cat_stuff_mask
+
+ if loss_mask is not None:
+ self.loss_mask = build_loss(loss_mask)
+ else:
+ self.loss_mask = loss_mask
+
+ if loss_dice is not None:
+ self.loss_dice = build_loss(loss_dice)
+ else:
+ self.loss_dice = loss_dice
+
+ if loss_seg is not None:
+ self.loss_seg = build_loss(loss_seg)
+ else:
+ self.loss_seg = loss_seg
+ if loss_cls is not None:
+ self.loss_cls = build_loss(loss_cls)
+ else:
+ self.loss_cls = loss_cls
+
+ if loss_rank is not None:
+ self.loss_rank = build_loss(loss_rank)
+ else:
+ self.loss_rank = loss_rank
+
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ # use PseudoSampler when sampling is False
+ if self.sampling and hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='MaskPseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self._init_layers()
+
+ def _init_layers(self):
+ """Initialize a sparse set of proposal boxes and proposal features."""
+ self.init_kernels = nn.Conv2d(
+ self.out_channels,
+ self.num_proposals,
+ self.conv_kernel_size,
+ padding=int(self.conv_kernel_size // 2),
+ bias=False)
+
+ if self.semantic_fpn:
+ if self.loss_seg.use_sigmoid:
+ self.conv_seg = nn.Conv2d(self.out_channels, self.num_classes,
+ 1)
+ else:
+ self.conv_seg = nn.Conv2d(self.out_channels,
+ self.num_classes + 1, 1)
+
+ if self.feat_downsample_stride > 1 and self.feat_refine:
+ self.ins_downsample = ConvModule(
+ self.in_channels,
+ self.out_channels,
+ 3,
+ stride=self.feat_refine_stride,
+ padding=1,
+ norm_cfg=self.norm_cfg)
+ self.seg_downsample = ConvModule(
+ self.in_channels,
+ self.out_channels,
+ 3,
+ stride=self.feat_refine_stride,
+ padding=1,
+ norm_cfg=self.norm_cfg)
+
+ self.loc_convs = nn.ModuleList()
+ for i in range(self.num_loc_convs):
+ self.loc_convs.append(
+ ConvModule(
+ self.in_channels,
+ self.out_channels,
+ 1,
+ norm_cfg=self.norm_cfg))
+
+ self.seg_convs = nn.ModuleList()
+ for i in range(self.num_seg_convs):
+ self.seg_convs.append(
+ ConvModule(
+ self.in_channels,
+ self.out_channels,
+ 1,
+ norm_cfg=self.norm_cfg))
+
+ def init_weights(self):
+ self.localization_fpn.init_weights()
+
+ if self.feat_downsample_stride > 1 and self.conv_normal_init:
+ logger = get_root_logger()
+ logger.info('Initialize convs in KPN head by normal std 0.01')
+ for conv in [self.loc_convs, self.seg_convs]:
+ for m in conv.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, std=0.01)
+
+ if self.semantic_fpn:
+ bias_seg = bias_init_with_prob(0.01)
+ if self.loss_seg.use_sigmoid:
+ normal_init(self.conv_seg, std=0.01, bias=bias_seg)
+ else:
+ normal_init(self.conv_seg, mean=0, std=0.01)
+ if self.xavier_init_kernel:
+ logger = get_root_logger()
+ logger.info('Initialize kernels by xavier uniform')
+ nn.init.xavier_uniform_(self.init_kernels.weight)
+ else:
+ logger = get_root_logger()
+ logger.info(
+ f'Initialize kernels by normal std: {self.kernel_init_std}')
+ normal_init(self.init_kernels, mean=0, std=self.kernel_init_std)
+
+ def _decode_init_proposals(self, img, img_metas, ref_img_metas):
+ num_imgs = len(img_metas)
+ num_frames = len(ref_img_metas[0])
+
+ if self.localization_fpn.__class__.__name__.endswith('3D'):
+ localization_feats = self.localization_fpn(img, num_imgs,
+ num_frames)
+ else:
+ localization_feats = self.localization_fpn(img)
+ if isinstance(localization_feats, list):
+ loc_feats = localization_feats[0]
+ else:
+ loc_feats = localization_feats
+ for conv in self.loc_convs:
+ loc_feats = conv(loc_feats)
+ if self.feat_downsample_stride > 1 and self.feat_refine:
+ loc_feats = self.ins_downsample(loc_feats)
+ mask_preds = self.init_kernels(loc_feats)
+
+ if self.semantic_fpn:
+ if isinstance(localization_feats, list):
+ semantic_feats = localization_feats[1]
+ else:
+ semantic_feats = localization_feats
+ for conv in self.seg_convs:
+ semantic_feats = conv(semantic_feats)
+ if self.feat_downsample_stride > 1 and self.feat_refine:
+ semantic_feats = self.seg_downsample(semantic_feats)
+ else:
+ semantic_feats = None
+
+ if semantic_feats is not None:
+ seg_preds = self.conv_seg(semantic_feats)
+ else:
+ seg_preds = None
+
+ proposal_feats = self.init_kernels.weight.clone()
+ proposal_feats = proposal_feats[None].expand(num_imgs * num_frames,
+ *proposal_feats.size())
+
+ if semantic_feats is not None:
+ x_feats = semantic_feats + loc_feats
+ else:
+ x_feats = loc_feats
+
+ if self.proposal_feats_with_obj:
+ sigmoid_masks = mask_preds.sigmoid()
+ nonzero_inds = sigmoid_masks > 0.5
+ if self.use_binary:
+ sigmoid_masks = nonzero_inds.float()
+ else:
+ sigmoid_masks = nonzero_inds.float() * sigmoid_masks
+ obj_feats = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x_feats)
+
+ cls_scores = None
+
+ if self.proposal_feats_with_obj:
+ proposal_feats = proposal_feats + obj_feats.view(
+ num_imgs * num_frames, self.num_proposals, self.out_channels,
+ 1, 1)
+
+ if self.cat_stuff_mask and not self.training:
+ mask_preds = torch.cat(
+ [mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1)
+ stuff_kernels = self.conv_seg.weight[self.
+ num_thing_classes:].clone()
+ stuff_kernels = stuff_kernels[None].expand(num_imgs * num_frames,
+ *stuff_kernels.size())
+ proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1)
+
+ return proposal_feats, x_feats, mask_preds, cls_scores, seg_preds
+
+ def forward_train(self,
+ img,
+ img_metas,
+ ref_img_metas,
+ gt_masks,
+ gt_labels,
+ gt_instance_ids=None,
+ gt_sem_seg=None,
+ gt_sem_cls=None):
+ """Forward function in training stage."""
+ num_imgs = len(img_metas)
+ num_frames = len(ref_img_metas[0])
+ results = self._decode_init_proposals(img, img_metas, ref_img_metas)
+ (proposal_feats, x_feats, mask_preds, cls_scores, seg_preds) = results
+ if self.feat_downsample_stride > 1:
+ scaled_mask_preds = F.interpolate(
+ mask_preds,
+ scale_factor=self.feat_downsample_stride,
+ mode='bilinear',
+ align_corners=False)
+ if seg_preds is not None:
+ scaled_seg_preds = F.interpolate(
+ seg_preds,
+ scale_factor=self.feat_downsample_stride,
+ mode='bilinear',
+ align_corners=False)
+ else:
+ scaled_mask_preds = mask_preds
+ scaled_seg_preds = seg_preds
+
+ if self.hard_target:
+ gt_masks = [x.bool().float() for x in gt_masks]
+ else:
+ gt_masks = gt_masks
+
+ sampling_results = []
+ if cls_scores is None:
+ detached_cls_scores = [[None] * num_frames] * num_imgs
+ else:
+ detached_cls_scores = cls_scores.detach()
+
+ for i in range(num_imgs):
+ for j in range(num_frames):
+ assign_result = self.assigner.assign(
+ scaled_mask_preds[i * num_frames + j].detach(),
+ detached_cls_scores[i][j], gt_masks[i][j],
+ gt_labels[i][:,
+ 1][gt_labels[i][:,
+ 0] == j], ref_img_metas[i][j])
+ sampling_result = self.sampler.sample(
+ assign_result, scaled_mask_preds[i * num_frames + j],
+ gt_masks[i][j])
+ sampling_results.append(sampling_result)
+
+ mask_targets = self.get_targets(
+ sampling_results,
+ self.train_cfg,
+ True,
+ gt_sem_seg=gt_sem_seg,
+ gt_sem_cls=gt_sem_cls)
+
+ losses = self.loss(scaled_mask_preds, cls_scores, scaled_seg_preds,
+ proposal_feats, *mask_targets)
+
+ if self.cat_stuff_mask and self.training:
+ mask_preds = torch.cat(
+ [mask_preds, seg_preds[:, self.num_thing_classes:]], dim=1)
+ stuff_kernels = self.conv_seg.weight[self.
+ num_thing_classes:].clone()
+ stuff_kernels = stuff_kernels[None].expand(num_imgs * num_frames,
+ *stuff_kernels.size())
+ proposal_feats = torch.cat([proposal_feats, stuff_kernels], dim=1)
+
+ return losses, proposal_feats, x_feats, mask_preds, cls_scores
+
+ def loss(self,
+ mask_pred,
+ cls_scores,
+ seg_preds,
+ proposal_feats,
+ labels,
+ label_weights,
+ mask_targets,
+ mask_weights,
+ seg_targets,
+ reduction_override=None,
+ **kwargs):
+ losses = dict()
+ bg_class_ind = self.num_classes
+        # note: in Sparse R-CNN, num_gt == num_pos
+ pos_inds = (labels >= 0) & (labels < bg_class_ind)
+ num_preds = mask_pred.shape[0] * mask_pred.shape[1]
+
+ if cls_scores is not None:
+ num_pos = pos_inds.sum().float()
+ avg_factor = reduce_mean(num_pos)
+ assert mask_pred.shape[0] == cls_scores.shape[0]
+ assert mask_pred.shape[1] == cls_scores.shape[1]
+ losses['loss_rpn_cls'] = self.loss_cls(
+ cls_scores.view(num_preds, -1),
+ labels,
+ label_weights,
+ avg_factor=avg_factor,
+ reduction_override=reduction_override)
+ losses['rpn_pos_acc'] = accuracy(
+ cls_scores.view(num_preds, -1)[pos_inds], labels[pos_inds])
+
+ bool_pos_inds = pos_inds.type(torch.bool)
+ # 0~self.num_classes-1 are FG, self.num_classes is BG
+ # do not perform bounding box regression for BG anymore.
+ H, W = mask_pred.shape[-2:]
+ if pos_inds.any():
+ pos_mask_pred = mask_pred.reshape(num_preds, H, W)[bool_pos_inds]
+ pos_mask_targets = mask_targets[bool_pos_inds]
+ losses['loss_rpn_mask'] = self.loss_mask(pos_mask_pred,
+ pos_mask_targets)
+ losses['loss_rpn_dice'] = self.loss_dice(pos_mask_pred,
+ pos_mask_targets)
+
+ if self.loss_rank is not None:
+ batch_size = mask_pred.size(0)
+ rank_target = mask_targets.new_full((batch_size, H, W),
+ self.ignore_label,
+ dtype=torch.long)
+ rank_inds = pos_inds.view(batch_size,
+ -1).nonzero(as_tuple=False)
+ batch_mask_targets = mask_targets.view(batch_size, -1, H,
+ W).bool()
+ for i in range(batch_size):
+ curr_inds = (rank_inds[:, 0] == i)
+ curr_rank = rank_inds[:, 1][curr_inds]
+ for j in curr_rank:
+ rank_target[i][batch_mask_targets[i][j]] = j
+ losses['loss_rpn_rank'] = self.loss_rank(
+ mask_pred, rank_target, ignore_index=self.ignore_label)
+
+ else:
+ losses['loss_rpn_mask'] = mask_pred.sum() * 0
+ losses['loss_rpn_dice'] = mask_pred.sum() * 0
+ if self.loss_rank is not None:
+ losses['loss_rank'] = mask_pred.sum() * 0
+
+ if seg_preds is not None:
+ if self.loss_seg.use_sigmoid:
+ cls_channel = seg_preds.shape[1]
+ flatten_seg = seg_preds.view(-1, cls_channel, H * W).permute(
+ 0, 2, 1).reshape(-1, cls_channel)
+ flatten_seg_target = seg_targets.view(-1)
+ num_dense_pos = (flatten_seg_target >= 0) & (
+ flatten_seg_target < bg_class_ind)
+ num_dense_pos = num_dense_pos.sum().float().clamp(min=1.0)
+ losses['loss_rpn_seg'] = self.loss_seg(
+ flatten_seg, flatten_seg_target, avg_factor=num_dense_pos)
+ else:
+ cls_channel = seg_preds.shape[1]
+ flatten_seg = seg_preds.view(-1, cls_channel, H * W).permute(
+ 0, 2, 1).reshape(-1, cls_channel)
+ flatten_seg_target = seg_targets.view(-1)
+ losses['loss_rpn_seg'] = self.loss_seg(flatten_seg,
+ flatten_seg_target)
+
+ return losses
+
+ def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask,
+ pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls,
+ cfg):
+ num_pos = pos_mask.size(0)
+ num_neg = neg_mask.size(0)
+ num_samples = num_pos + num_neg
+ H, W = pos_mask.shape[-2:]
+        # the original implementation uses new_zeros since BG is set to 0;
+        # here new_full is used because BG cat_id = num_classes and
+        # FG cat_id is in [0, num_classes - 1]
+ labels = pos_mask.new_full((num_samples, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = pos_mask.new_zeros(num_samples)
+ mask_targets = pos_mask.new_zeros(num_samples, H, W)
+ mask_weights = pos_mask.new_zeros(num_samples, H, W)
+ seg_targets = pos_mask.new_full((H, W),
+ self.num_classes,
+ dtype=torch.long)
+
+ if gt_sem_cls is not None and gt_sem_seg is not None:
+ gt_sem_seg = gt_sem_seg.bool()
+ for sem_mask, sem_cls in zip(gt_sem_seg, gt_sem_cls):
+ seg_targets[sem_mask] = sem_cls.long()
+
+ if num_pos > 0:
+ labels[pos_inds] = pos_gt_labels
+ pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
+ label_weights[pos_inds] = pos_weight
+ mask_targets[pos_inds, ...] = pos_gt_mask
+ mask_weights[pos_inds, ...] = 1
+ for i in range(num_pos):
+ seg_targets[pos_gt_mask[i].bool()] = pos_gt_labels[i]
+
+ if num_neg > 0:
+ label_weights[neg_inds] = 1.0
+
+ return labels, label_weights, mask_targets, mask_weights, seg_targets
+
+ def get_targets(self,
+ sampling_results,
+ rpn_train_cfg,
+ concat=True,
+ gt_sem_seg=None,
+ gt_sem_cls=None):
+ num_imgs = len(sampling_results)
+ pos_inds_list = [res.pos_inds for res in sampling_results]
+ neg_inds_list = [res.neg_inds for res in sampling_results]
+ pos_mask_list = [res.pos_masks for res in sampling_results]
+ neg_mask_list = [res.neg_masks for res in sampling_results]
+ pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results]
+ pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
+ if gt_sem_seg is None:
+ gt_sem_seg = [None] * num_imgs
+ gt_sem_cls = [None] * num_imgs
+ results = multi_apply(
+ self._get_target_single,
+ pos_inds_list,
+ neg_inds_list,
+ pos_mask_list,
+ neg_mask_list,
+ pos_gt_mask_list,
+ pos_gt_labels_list,
+ gt_sem_seg,
+ gt_sem_cls,
+ cfg=rpn_train_cfg)
+ (labels, label_weights, mask_targets, mask_weights,
+ seg_targets) = results
+ if concat:
+ labels = torch.cat(labels, 0)
+ label_weights = torch.cat(label_weights, 0)
+ mask_targets = torch.cat(mask_targets, 0)
+ mask_weights = torch.cat(mask_weights, 0)
+ seg_targets = torch.stack(seg_targets, 0)
+ return labels, label_weights, mask_targets, mask_weights, seg_targets
+
+ def simple_test_rpn(self, img, img_metas, ref_img_metas):
+ """Forward function in testing stage."""
+ return self._decode_init_proposals(img, img_metas, ref_img_metas)
+
+ def forward_dummy(self, img, img_metas, ref_img_metas):
+ """Dummy forward function.
+
+ Used in flops calculation.
+ """
+ return self._decode_init_proposals(img, img_metas, ref_img_metas)
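The core trick in _decode_init_proposals() is that the weight of init_kernels (a conv with num_proposals output channels) serves both as the initial proposal kernels and, applied to the localization features, as the initial mask logits. A standalone sketch with illustrative sizes, using conv_kernel_size=1 as in the default arguments:

import torch
import torch.nn as nn

num_proposals, channels = 100, 256
init_kernels = nn.Conv2d(channels, num_proposals, 1, padding=0, bias=False)

loc_feats = torch.randn(4, channels, 64, 64)   # (num_imgs * num_frames) x C x H x W
mask_preds = init_kernels(loc_feats)           # -> [4, 100, 64, 64] initial mask logits
proposal_feats = init_kernels.weight.clone()[None].expand(
    4, num_proposals, channels, 1, 1)          # -> [4, 100, 256, 1, 1] initial kernels
print(mask_preds.shape, proposal_feats.shape)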
diff --git a/modelscope/models/cv/video_instance_segmentation/head/kernel_iter_head.py b/modelscope/models/cv/video_instance_segmentation/head/kernel_iter_head.py
new file mode 100644
index 00000000..625c0b13
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/head/kernel_iter_head.py
@@ -0,0 +1,412 @@
+# The implementation is adopted from Video-K-Net,
+# made publicly available at https://github.com/lxtGH/Video-K-Net under the MIT License
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmdet.core import build_assigner, build_sampler
+from mmdet.datasets.coco_panoptic import INSTANCE_OFFSET
+from mmdet.models.builder import HEADS, build_head
+from mmdet.models.roi_heads import BaseRoIHead
+
+
+@HEADS.register_module()
+class KernelIterHeadVideo(BaseRoIHead):
+
+ def __init__(self,
+ num_stages=6,
+ recursive=False,
+ assign_stages=5,
+ stage_loss_weights=(1, 1, 1, 1, 1, 1),
+ proposal_feature_channel=256,
+ merge_cls_scores=False,
+ do_panoptic=False,
+ post_assign=False,
+ hard_target=False,
+ num_proposals=100,
+ num_thing_classes=80,
+ num_stuff_classes=53,
+ mask_assign_stride=4,
+ thing_label_in_seg=0,
+ mask_head=dict(
+ type='KernelUpdateHead',
+ num_classes=80,
+ num_fcs=2,
+ num_heads=8,
+ num_cls_fcs=1,
+ num_reg_fcs=3,
+ feedforward_channels=2048,
+ hidden_channels=256,
+ dropout=0.0,
+ roi_feat_size=7,
+ ffn_act_cfg=dict(type='ReLU', inplace=True)),
+ mask_out_stride=4,
+ train_cfg=None,
+ test_cfg=None,
+ **kwargs):
+ assert mask_head is not None
+ assert len(stage_loss_weights) == num_stages
+ self.num_stages = num_stages
+ self.stage_loss_weights = stage_loss_weights
+ self.proposal_feature_channel = proposal_feature_channel
+ self.merge_cls_scores = merge_cls_scores
+ self.recursive = recursive
+ self.post_assign = post_assign
+ self.mask_out_stride = mask_out_stride
+ self.hard_target = hard_target
+ self.assign_stages = assign_stages
+ self.do_panoptic = do_panoptic
+ self.num_thing_classes = num_thing_classes
+ self.num_stuff_classes = num_stuff_classes
+ self.num_classes = num_thing_classes + num_stuff_classes
+ self.mask_assign_stride = mask_assign_stride
+ self.thing_label_in_seg = thing_label_in_seg
+ self.num_proposals = num_proposals
+ super().__init__(
+ mask_head=mask_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ **kwargs)
+
+ def init_bbox_head(self, mask_roi_extractor, mask_head):
+ """Initialize box head and box roi extractor.
+
+ Args:
+ mask_roi_extractor (dict): Config of box roi extractor.
+ mask_head (dict): Config of box in box head.
+ """
+ pass
+
+ def init_assigner_sampler(self):
+ """Initialize assigner and sampler for each stage."""
+ self.mask_assigner = []
+ self.mask_sampler = []
+ if self.train_cfg is not None:
+ for idx, rcnn_train_cfg in enumerate(self.train_cfg):
+ self.mask_assigner.append(
+ build_assigner(rcnn_train_cfg.assigner))
+ self.current_stage = idx
+ self.mask_sampler.append(
+ build_sampler(rcnn_train_cfg.sampler, context=self))
+
+ def init_weights(self):
+ for i in range(self.num_stages):
+ self.mask_head[i].init_weights()
+
+ def init_mask_head(self, mask_roi_extractor, mask_head):
+ """Initialize mask head and mask roi extractor.
+
+ Args:
+ mask_roi_extractor (dict): Config of mask roi extractor.
+ mask_head (dict): Config of mask in mask head.
+ """
+ self.mask_head = nn.ModuleList()
+ if not isinstance(mask_head, list):
+ mask_head = [mask_head for _ in range(self.num_stages)]
+ assert len(mask_head) == self.num_stages
+ for head in mask_head:
+ self.mask_head.append(build_head(head))
+ if self.recursive:
+ for i in range(self.num_stages):
+ self.mask_head[i] = self.mask_head[0]
+
+ def _mask_forward(self,
+ stage,
+ x,
+ object_feats,
+ mask_preds,
+ img_metas=None):
+ mask_head = self.mask_head[stage]
+ cls_score, mask_preds, object_feats = mask_head(
+ x, object_feats, mask_preds, img_metas=img_metas)
+ if mask_head.mask_upsample_stride > 1 and (stage == self.num_stages - 1
+ or self.training):
+ scaled_mask_preds = F.interpolate(
+ mask_preds,
+ scale_factor=mask_head.mask_upsample_stride,
+ align_corners=False,
+ mode='bilinear')
+ else:
+ scaled_mask_preds = mask_preds
+
+ mask_results = dict(
+ cls_score=cls_score,
+ mask_preds=mask_preds,
+ scaled_mask_preds=scaled_mask_preds,
+ object_feats=object_feats)
+ return mask_results
+
+ def forward_train(self,
+ x,
+ proposal_feats,
+ mask_preds,
+ cls_score,
+ ref_img_metas,
+ gt_masks,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ imgs_whwh=None,
+ gt_bboxes=None,
+ gt_sem_seg=None,
+ gt_sem_cls=None):
+
+ num_imgs = len(ref_img_metas)
+ num_frames = len(ref_img_metas[0])
+ if self.mask_head[0].mask_upsample_stride > 1:
+ prev_mask_preds = F.interpolate(
+ mask_preds.detach(),
+ scale_factor=self.mask_head[0].mask_upsample_stride,
+ mode='bilinear',
+ align_corners=False)
+ else:
+ prev_mask_preds = mask_preds.detach()
+
+ if cls_score is not None:
+ prev_cls_score = cls_score.detach()
+ else:
+ prev_cls_score = None
+
+ if self.hard_target:
+ gt_masks = [x.bool().float() for x in gt_masks]
+ else:
+ gt_masks = gt_masks
+
+ object_feats = proposal_feats
+ all_stage_loss = {}
+ all_stage_mask_results = []
+ assign_results = []
+ for stage in range(self.num_stages):
+ mask_results = self._mask_forward(
+ stage, x, object_feats, mask_preds, img_metas=None)
+ all_stage_mask_results.append(mask_results)
+ mask_preds = mask_results['mask_preds']
+ scaled_mask_preds = mask_results['scaled_mask_preds']
+ cls_score = mask_results['cls_score']
+ object_feats = mask_results['object_feats']
+
+ if self.post_assign:
+ prev_mask_preds = scaled_mask_preds.detach()
+ prev_cls_score = cls_score.detach()
+
+ sampling_results = []
+ if stage < self.assign_stages:
+ assign_results = []
+ for i in range(num_imgs):
+ for j in range(num_frames):
+ if stage < self.assign_stages:
+ mask_for_assign = prev_mask_preds[
+ i * num_frames + j][:self.num_proposals]
+ if prev_cls_score is not None:
+ cls_for_assign = prev_cls_score[
+ i * num_frames + j][:self.num_proposals, :self.
+ num_thing_classes]
+ else:
+ cls_for_assign = None
+ assign_result = self.mask_assigner[stage].assign(
+ mask_for_assign,
+ cls_for_assign,
+ gt_masks[i][j],
+ gt_labels[i][:, 1][gt_labels[i][:, 0] == j],
+ img_meta=None)
+ assign_results.append(assign_result)
+ sampling_result = self.mask_sampler[stage].sample(
+ assign_results[i * num_frames + j],
+ scaled_mask_preds[i * num_frames + j], gt_masks[i][j])
+ sampling_results.append(sampling_result)
+ mask_targets = self.mask_head[stage].get_targets(
+ sampling_results,
+ self.train_cfg[stage],
+ True,
+ gt_sem_seg=gt_sem_seg,
+ gt_sem_cls=gt_sem_cls)
+
+ single_stage_loss = self.mask_head[stage].loss(
+ object_feats,
+ cls_score,
+ scaled_mask_preds,
+ *mask_targets,
+ imgs_whwh=imgs_whwh)
+ for key, value in single_stage_loss.items():
+ all_stage_loss[
+ f's{stage}_{key}'] = value * self.stage_loss_weights[stage]
+
+ if not self.post_assign:
+ prev_mask_preds = scaled_mask_preds.detach()
+ prev_cls_score = cls_score.detach()
+
+ bs_nf, num_query, c, ks1, ks2 = object_feats.size()
+ bs_nf2, c2, h, w = x.size()
+ assert ks1 == ks2
+ assert bs_nf == bs_nf2
+ assert bs_nf == num_frames * num_imgs
+ assert c == c2
+ features = {
+ 'obj_feats':
+ object_feats.reshape(
+ (num_imgs, num_frames, num_query, c, ks1, ks2)),
+ # "x_feats":self.mask_head[-1].feat_transform(x).reshape((num_imgs, num_frames, c, h, w)),
+ 'x_feats':
+ x.reshape((num_imgs, num_frames, c, h, w)),
+ 'cls_scores':
+ cls_score.reshape(
+ (num_imgs, num_frames, num_query, self.num_classes)),
+ 'masks':
+ mask_preds.reshape((num_imgs, num_frames, num_query, h, w)),
+ }
+ return all_stage_loss, features
+
+ def simple_test(self,
+ x,
+ proposal_feats,
+ mask_preds,
+ cls_score,
+ img_metas,
+ ref_img_metas,
+ imgs_whwh=None,
+ rescale=False):
+
+ # Decode initial proposals
+ num_imgs = len(ref_img_metas)
+ num_frames = len(ref_img_metas[0])
+ # num_proposals = proposal_feats.size(1)
+
+ object_feats = proposal_feats
+ for stage in range(self.num_stages):
+ mask_results = self._mask_forward(stage, x, object_feats,
+ mask_preds)
+ object_feats = mask_results['object_feats']
+ cls_score = mask_results['cls_score']
+ mask_preds = mask_results['mask_preds']
+ scaled_mask_preds = mask_results['scaled_mask_preds']
+
+ num_classes = self.mask_head[-1].num_classes
+ results = []
+
+ if self.mask_head[-1].loss_cls.use_sigmoid:
+ cls_score = cls_score.sigmoid()
+ else:
+ cls_score = cls_score.softmax(-1)[..., :-1]
+
+ bs_nf, num_query, c, ks1, ks2 = object_feats.size()
+ bs_nf2, c2, h, w = x.size()
+ assert ks1 == ks2
+ assert bs_nf == bs_nf2
+ assert bs_nf == num_frames * num_imgs
+ assert c == c2
+ features = {
+ 'obj_feats':
+ object_feats.reshape(
+ (num_imgs, num_frames, num_query, c, ks1, ks2)),
+ # "x_feats":self.mask_head[-1].feat_transform(x).reshape((num_imgs, num_frames, c, h, w)),
+ 'x_feats':
+ x.reshape((num_imgs, num_frames, c, h, w)),
+ 'cls_scores':
+ cls_score.reshape(
+ (num_imgs, num_frames, num_query, self.num_classes)),
+ 'masks':
+ mask_preds.reshape((num_imgs, num_frames, num_query, h, w)),
+ }
+
+ if self.do_panoptic:
+ raise NotImplementedError
+ # for img_id in range(num_imgs):
+ # single_result = self.get_panoptic(cls_score[img_id],
+ # scaled_mask_preds[img_id],
+ # self.test_cfg,
+ # ref_img_metas[img_id])
+ # results.append(single_result)
+ else:
+ for img_id in range(num_imgs):
+ for frame_id in range(num_frames):
+ cls_score_per_img = cls_score[img_id * num_frames
+ + frame_id]
+                    # note: quite tricky here, a single query can produce multiple results with different labels
+ scores_per_img, topk_indices = cls_score_per_img.flatten(
+ 0, 1).topk(
+ self.test_cfg['max_per_img'], sorted=True)
+ mask_indices = topk_indices // num_classes
+ # Use the following when torch >= 1.9.0
+ # mask_indices = torch.div(topk_indices, num_classes, rounding_mode='floor')
+ labels_per_img = topk_indices % num_classes
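+                    # topk over the flattened (query, class) grid: mask_indices
+                    # selects the query, labels_per_img the class label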
+ masks_per_img = scaled_mask_preds[img_id * num_frames
+ + frame_id][mask_indices]
+ single_result = self.mask_head[-1].get_seg_masks(
+ masks_per_img, labels_per_img, scores_per_img,
+ self.test_cfg, img_metas[img_id])
+ results.append(single_result)
+ return results, features
+
+ def aug_test(self, features, proposal_list, img_metas, rescale=False):
+ raise NotImplementedError('SparseMask does not support `aug_test`')
+
+ def forward_dummy(self, x, proposal_boxes, proposal_feats, img_metas):
+ """Dummy forward function when do the flops computing."""
+ all_stage_mask_results = []
+ num_imgs = len(img_metas)
+ num_proposals = proposal_feats.size(1)
+ C, H, W = x.shape[-3:]
+ mask_preds = proposal_feats.bmm(x.view(num_imgs, C, -1)).view(
+ num_imgs, num_proposals, H, W)
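+        # initial masks: inner product of each proposal kernel with the
+        # flattened feature map (equivalent to a 1x1 dynamic convolution)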
+ object_feats = proposal_feats
+ for stage in range(self.num_stages):
+ mask_results = self._mask_forward(stage, x, object_feats,
+ mask_preds, img_metas)
+ all_stage_mask_results.append(mask_results)
+ return all_stage_mask_results
+
+ def get_panoptic(self, cls_scores, mask_preds, test_cfg, img_meta):
+ # resize mask predictions back
+ scores = cls_scores[:self.num_proposals][:, :self.num_thing_classes]
+ thing_scores, thing_labels = scores.max(dim=1)
+ stuff_scores = cls_scores[
+ self.num_proposals:][:, self.num_thing_classes:].diag()
+ stuff_labels = torch.arange(
+ 0, self.num_stuff_classes) + self.num_thing_classes
+ stuff_labels = stuff_labels.to(thing_labels.device)
+
+ total_masks = self.mask_head[-1].rescale_masks(mask_preds, img_meta)
+ total_scores = torch.cat([thing_scores, stuff_scores], dim=0)
+ total_labels = torch.cat([thing_labels, stuff_labels], dim=0)
+
+ panoptic_result = self.merge_stuff_thing(total_masks, total_labels,
+ total_scores,
+ test_cfg.merge_stuff_thing)
+ return dict(pan_results=panoptic_result)
+
+ def merge_stuff_thing(self,
+ total_masks,
+ total_labels,
+ total_scores,
+ merge_cfg=None):
+
+ H, W = total_masks.shape[-2:]
+ panoptic_seg = total_masks.new_full((H, W),
+ self.num_classes,
+ dtype=torch.long)
+
+ cur_prob_masks = total_scores.view(-1, 1, 1) * total_masks
+ cur_mask_ids = cur_prob_masks.argmax(0)
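+        # each pixel is claimed by the segment with the highest
+        # score-weighted mask probability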
+
+ # sort instance outputs by scores
+ sorted_inds = torch.argsort(-total_scores)
+ current_segment_id = 0
+
+ for k in sorted_inds:
+ pred_class = total_labels[k].item()
+ isthing = pred_class < self.num_thing_classes
+ if isthing and total_scores[k] < merge_cfg.instance_score_thr:
+ continue
+
+ mask = cur_mask_ids == k
+ mask_area = mask.sum().item()
+ original_area = (total_masks[k] >= 0.5).sum().item()
+
+ if mask_area > 0 and original_area > 0:
+ if mask_area / original_area < merge_cfg.overlap_thr:
+ continue
+
+ panoptic_seg[mask] = total_labels[k] \
+ + current_segment_id * INSTANCE_OFFSET
+ current_segment_id += 1
+
+ return panoptic_seg.cpu().numpy()
diff --git a/modelscope/models/cv/video_instance_segmentation/head/kernel_update_head.py b/modelscope/models/cv/video_instance_segmentation/head/kernel_update_head.py
new file mode 100644
index 00000000..0cf5b6d9
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/head/kernel_update_head.py
@@ -0,0 +1,504 @@
+# The implementation is adopted from Video-K-Net,
+# made publicly available at https://github.com/lxtGH/Video-K-Net under the MIT license
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (ConvModule, bias_init_with_prob, build_activation_layer,
+ build_norm_layer)
+from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention,
+ build_transformer_layer)
+from mmcv.runner import force_fp32
+from mmdet.core import multi_apply
+from mmdet.models.builder import HEADS, build_loss
+from mmdet.models.dense_heads.atss_head import reduce_mean
+from mmdet.models.losses import accuracy
+from mmdet.utils import get_root_logger
+
+from ..utils import outs2results
+
+
+@HEADS.register_module()
+class KernelUpdateHead(nn.Module):
+
+ def __init__(self,
+ num_classes=80,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_cls_fcs=1,
+ num_mask_fcs=3,
+ feedforward_channels=2048,
+ in_channels=256,
+ out_channels=256,
+ dropout=0.0,
+ mask_thr=0.5,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ conv_kernel_size=3,
+ feat_transform_cfg=None,
+ hard_mask_thr=0.5,
+ kernel_init=False,
+ with_ffn=True,
+ mask_out_stride=4,
+ relative_coors=False,
+ relative_coors_off=False,
+ feat_gather_stride=1,
+ mask_transform_stride=1,
+ mask_upsample_stride=1,
+ num_thing_classes=80,
+ num_stuff_classes=53,
+ mask_assign_stride=4,
+ ignore_label=255,
+ thing_label_in_seg=0,
+ kernel_updator_cfg=dict(
+ type='DynamicConv',
+ in_channels=256,
+ feat_channels=64,
+ out_channels=256,
+ input_feat_shape=1,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN')),
+ loss_rank=None,
+ loss_mask=dict(
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
+ loss_dice=dict(type='DiceLoss', loss_weight=3.0),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0)):
+ super(KernelUpdateHead, self).__init__()
+ self.num_classes = num_classes
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_mask = build_loss(loss_mask)
+ self.loss_dice = build_loss(loss_dice)
+ if loss_rank is not None:
+ self.loss_rank = build_loss(loss_rank)
+ else:
+ self.loss_rank = loss_rank
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.mask_thr = mask_thr
+ self.fp16_enabled = False
+ self.dropout = dropout
+
+ self.num_heads = num_heads
+ self.hard_mask_thr = hard_mask_thr
+ self.kernel_init = kernel_init
+ self.with_ffn = with_ffn
+ self.mask_out_stride = mask_out_stride
+ self.relative_coors = relative_coors
+ self.relative_coors_off = relative_coors_off
+ self.conv_kernel_size = conv_kernel_size
+ self.feat_gather_stride = feat_gather_stride
+ self.mask_transform_stride = mask_transform_stride
+ self.mask_upsample_stride = mask_upsample_stride
+
+ self.num_thing_classes = num_thing_classes
+ self.num_stuff_classes = num_stuff_classes
+ self.mask_assign_stride = mask_assign_stride
+ self.ignore_label = ignore_label
+ self.thing_label_in_seg = thing_label_in_seg
+
+ self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,
+ num_heads, dropout)
+ self.attention_norm = build_norm_layer(
+ dict(type='LN'), in_channels * conv_kernel_size**2)[1]
+
+ self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)
+
+ if feat_transform_cfg is not None:
+ kernel_size = feat_transform_cfg.pop('kernel_size', 1)
+ self.feat_transform = ConvModule(
+ in_channels,
+ in_channels,
+ kernel_size,
+ stride=feat_gather_stride,
+ padding=int(feat_gather_stride // 2),
+ **feat_transform_cfg)
+ else:
+ self.feat_transform = None
+
+ if self.with_ffn:
+ self.ffn = FFN(
+ in_channels,
+ feedforward_channels,
+ num_ffn_fcs,
+ act_cfg=ffn_act_cfg,
+ ffn_drop=dropout)
+ self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]
+
+ self.cls_fcs = nn.ModuleList()
+ for _ in range(num_cls_fcs):
+ self.cls_fcs.append(
+ nn.Linear(in_channels, in_channels, bias=False))
+ self.cls_fcs.append(
+ build_norm_layer(dict(type='LN'), in_channels)[1])
+ self.cls_fcs.append(build_activation_layer(act_cfg))
+
+ if self.loss_cls.use_sigmoid:
+ self.fc_cls = nn.Linear(in_channels, self.num_classes)
+ else:
+ self.fc_cls = nn.Linear(in_channels, self.num_classes + 1)
+
+ self.mask_fcs = nn.ModuleList()
+ for _ in range(num_mask_fcs):
+ self.mask_fcs.append(
+ nn.Linear(in_channels, in_channels, bias=False))
+ self.mask_fcs.append(
+ build_norm_layer(dict(type='LN'), in_channels)[1])
+ self.mask_fcs.append(build_activation_layer(act_cfg))
+
+ self.fc_mask = nn.Linear(in_channels, out_channels)
+
+ def init_weights(self):
+ """Use xavier initialization for all weight parameter and set
+ classification head bias as a specific value when use focal loss."""
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ else:
+ # adopt the default initialization for
+ # the weight and bias of the layer norm
+ pass
+ if self.loss_cls.use_sigmoid:
+ bias_init = bias_init_with_prob(0.01)
+ nn.init.constant_(self.fc_cls.bias, bias_init)
+ if self.kernel_init:
+ logger = get_root_logger()
+ logger.info(
+ 'mask kernel in mask head is normal initialized by std 0.01')
+ nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)
+
+ def forward(self,
+ x,
+ proposal_feat,
+ mask_preds,
+ prev_cls_score=None,
+ mask_shape=None,
+ img_metas=None):
+
+ N, num_proposals = proposal_feat.shape[:2]
+ if self.feat_transform is not None:
+ x = self.feat_transform(x)
+ C, H, W = x.shape[-3:]
+
+ mask_h, mask_w = mask_preds.shape[-2:]
+ if mask_h != H or mask_w != W:
+ gather_mask = F.interpolate(
+ mask_preds, (H, W), align_corners=False, mode='bilinear')
+ else:
+ gather_mask = mask_preds
+
+ sigmoid_masks = gather_mask.sigmoid()
+ nonzero_inds = sigmoid_masks > self.hard_mask_thr
+ sigmoid_masks = nonzero_inds.float()
+
+ # einsum is faster than bmm by 30%
+ x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)
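+        # x_feat[b, n] pools the feature map over the binarized region of
+        # proposal n; equivalent to
+        # torch.bmm(sigmoid_masks.flatten(2), x.flatten(2).transpose(1, 2))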
+
+ # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]
+ proposal_feat = proposal_feat.reshape(N, num_proposals,
+ self.in_channels,
+ -1).permute(0, 1, 3, 2)
+ obj_feat = self.kernel_update_conv(x_feat, proposal_feat)
+
+ # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]
+ obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2)
+ obj_feat = self.attention_norm(self.attention(obj_feat))
+ # [N, B, K*K*C] -> [B, N, K*K*C]
+ obj_feat = obj_feat.permute(1, 0, 2)
+
+ # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]
+ obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)
+
+ # FFN
+ if self.with_ffn:
+ obj_feat = self.ffn_norm(self.ffn(obj_feat))
+
+ cls_feat = obj_feat.sum(-2)
+ mask_feat = obj_feat
+
+ for cls_layer in self.cls_fcs:
+ cls_feat = cls_layer(cls_feat)
+ for reg_layer in self.mask_fcs:
+ mask_feat = reg_layer(mask_feat)
+
+ cls_score = self.fc_cls(cls_feat).view(N, num_proposals, -1)
+ # [B, N, K*K, C] -> [B, N, C, K*K]
+ mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)
+
+ if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1):
+ mask_x = F.interpolate(
+ x, scale_factor=0.5, mode='bilinear', align_corners=False)
+ H, W = mask_x.shape[-2:]
+ raise NotImplementedError
+ else:
+ mask_x = x
+ # group conv is 5x faster than unfold and uses about 1/5 memory
+ # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms
+ # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369
+ # fold_x = F.unfold(
+ # mask_x,
+ # self.conv_kernel_size,
+ # padding=int(self.conv_kernel_size // 2))
+ # mask_feat = mask_feat.reshape(N, num_proposals, -1)
+ # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)
+ # [B, N, C, K*K] -> [B*N, C, K, K]
+ mask_feat = mask_feat.reshape(N, num_proposals, C,
+ self.conv_kernel_size,
+ self.conv_kernel_size)
+        # apply each image's proposal kernels to its own feature map
+ new_mask_preds = []
+ for i in range(N):
+ new_mask_preds.append(
+ F.conv2d(
+ mask_x[i:i + 1],
+ mask_feat[i],
+ padding=int(self.conv_kernel_size // 2)))
+
+ new_mask_preds = torch.cat(new_mask_preds, dim=0)
+ new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W)
+ if self.mask_transform_stride == 2:
+ new_mask_preds = F.interpolate(
+ new_mask_preds,
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=False)
+
+ if mask_shape is not None and mask_shape[0] != H:
+ new_mask_preds = F.interpolate(
+ new_mask_preds,
+ mask_shape,
+ align_corners=False,
+ mode='bilinear')
+
+ return cls_score, new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(
+ N, num_proposals, self.in_channels, self.conv_kernel_size,
+ self.conv_kernel_size)
+
+ @force_fp32(apply_to=('cls_score', 'mask_pred'))
+ def loss(self,
+ object_feats,
+ cls_score,
+ mask_pred,
+ labels,
+ label_weights,
+ mask_targets,
+ mask_weights,
+ imgs_whwh=None,
+ reduction_override=None,
+ **kwargs):
+
+ losses = dict()
+ bg_class_ind = self.num_classes
+        # note: in Sparse R-CNN, num_gt == num_pos
+ pos_inds = (labels >= 0) & (labels < bg_class_ind)
+ num_pos = pos_inds.sum().float()
+ avg_factor = reduce_mean(num_pos).clamp_(min=1.0)
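+        # average the positive count across GPUs so the loss normalizer is
+        # consistent in distributed training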
+
+ num_preds = mask_pred.shape[0] * mask_pred.shape[1]
+ assert mask_pred.shape[0] == cls_score.shape[0]
+ assert mask_pred.shape[1] == cls_score.shape[1]
+
+ if cls_score is not None:
+ if cls_score.numel() > 0:
+ losses['loss_cls'] = self.loss_cls(
+ cls_score.view(num_preds, -1),
+ labels,
+ label_weights,
+ avg_factor=avg_factor,
+ reduction_override=reduction_override)
+ losses['pos_acc'] = accuracy(
+ cls_score.view(num_preds, -1)[pos_inds], labels[pos_inds])
+ if mask_pred is not None:
+ bool_pos_inds = pos_inds.type(torch.bool)
+            # classes 0..num_classes-1 are FG, num_classes is BG;
+            # only foreground samples receive mask supervision
+ H, W = mask_pred.shape[-2:]
+ if pos_inds.any():
+ pos_mask_pred = mask_pred.reshape(num_preds, H,
+ W)[bool_pos_inds]
+ pos_mask_targets = mask_targets[bool_pos_inds]
+ losses['loss_mask'] = self.loss_mask(pos_mask_pred,
+ pos_mask_targets)
+ losses['loss_dice'] = self.loss_dice(pos_mask_pred,
+ pos_mask_targets)
+
+ if self.loss_rank is not None:
+ batch_size = mask_pred.size(0)
+ rank_target = mask_targets.new_full((batch_size, H, W),
+ self.ignore_label,
+ dtype=torch.long)
+ rank_inds = pos_inds.view(batch_size,
+ -1).nonzero(as_tuple=False)
+ batch_mask_targets = mask_targets.view(
+ batch_size, -1, H, W).bool()
+ for i in range(batch_size):
+ curr_inds = (rank_inds[:, 0] == i)
+ curr_rank = rank_inds[:, 1][curr_inds]
+ for j in curr_rank:
+ rank_target[i][batch_mask_targets[i][j]] = j
+ losses['loss_rank'] = self.loss_rank(
+ mask_pred, rank_target, ignore_index=self.ignore_label)
+ else:
+ losses['loss_mask'] = mask_pred.sum() * 0
+ losses['loss_dice'] = mask_pred.sum() * 0
+ if self.loss_rank is not None:
+ losses['loss_rank'] = mask_pred.sum() * 0
+
+ return losses
+
+ def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask,
+ pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls,
+ cfg):
+
+ num_pos = pos_mask.size(0)
+ num_neg = neg_mask.size(0)
+ num_samples = num_pos + num_neg
+ H, W = pos_mask.shape[-2:]
+        # the original implementation used new_zeros since BG was encoded as 0;
+        # here we fill with num_classes because BG cat_id = num_classes and
+        # FG cat_id is in [0, num_classes - 1]
+ labels = pos_mask.new_full((num_samples, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = pos_mask.new_zeros((num_samples, self.num_classes))
+ mask_targets = pos_mask.new_zeros(num_samples, H, W)
+ mask_weights = pos_mask.new_zeros(num_samples, H, W)
+ if num_pos > 0:
+ labels[pos_inds] = pos_gt_labels
+ pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
+ label_weights[pos_inds] = pos_weight
+ pos_mask_targets = pos_gt_mask
+ mask_targets[pos_inds, ...] = pos_mask_targets
+ mask_weights[pos_inds, ...] = 1
+
+ if num_neg > 0:
+ label_weights[neg_inds] = 1.0
+
+ if gt_sem_cls is not None and gt_sem_seg is not None:
+ sem_labels = pos_mask.new_full((self.num_stuff_classes, ),
+ self.num_classes,
+ dtype=torch.long)
+ sem_targets = pos_mask.new_zeros(self.num_stuff_classes, H, W)
+ sem_weights = pos_mask.new_zeros(self.num_stuff_classes, H, W)
+ sem_stuff_weights = torch.eye(
+ self.num_stuff_classes, device=pos_mask.device)
+ sem_thing_weights = pos_mask.new_zeros(
+ (self.num_stuff_classes, self.num_thing_classes))
+ sem_label_weights = torch.cat(
+ [sem_thing_weights, sem_stuff_weights], dim=-1)
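+            # each stuff query is supervised only on its own stuff class
+            # (identity weights); its thing-class logits get zero weight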
+            if len(gt_sem_cls) > 0:
+ sem_inds = gt_sem_cls - self.num_thing_classes
+ sem_inds = sem_inds.long()
+ sem_labels[sem_inds] = gt_sem_cls.long()
+ sem_targets[sem_inds] = gt_sem_seg
+ sem_weights[sem_inds] = 1
+
+ label_weights[:, self.num_thing_classes:] = 0
+ labels = torch.cat([labels, sem_labels])
+ label_weights = torch.cat([label_weights, sem_label_weights])
+ mask_targets = torch.cat([mask_targets, sem_targets])
+ mask_weights = torch.cat([mask_weights, sem_weights])
+
+ return labels, label_weights, mask_targets, mask_weights
+
+ def get_targets(self,
+ sampling_results,
+ rcnn_train_cfg,
+ concat=True,
+ gt_sem_seg=None,
+ gt_sem_cls=None):
+ num_imgs = len(sampling_results)
+ pos_inds_list = [res.pos_inds for res in sampling_results]
+ neg_inds_list = [res.neg_inds for res in sampling_results]
+ pos_mask_list = [res.pos_masks for res in sampling_results]
+ neg_mask_list = [res.neg_masks for res in sampling_results]
+ pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results]
+ pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
+ if gt_sem_seg is None:
+ gt_sem_seg = [None] * num_imgs
+ gt_sem_cls = [None] * num_imgs
+
+ labels, label_weights, mask_targets, mask_weights = multi_apply(
+ self._get_target_single,
+ pos_inds_list,
+ neg_inds_list,
+ pos_mask_list,
+ neg_mask_list,
+ pos_gt_mask_list,
+ pos_gt_labels_list,
+ gt_sem_seg,
+ gt_sem_cls,
+ cfg=rcnn_train_cfg)
+ if concat:
+ labels = torch.cat(labels, 0)
+ label_weights = torch.cat(label_weights, 0)
+ mask_targets = torch.cat(mask_targets, 0)
+ mask_weights = torch.cat(mask_weights, 0)
+ return labels, label_weights, mask_targets, mask_weights
+
+ def rescale_masks(self, masks_per_img, img_meta):
+ h, w, _ = img_meta['img_shape']
+ masks_per_img = F.interpolate(
+ masks_per_img.unsqueeze(0).sigmoid(),
+ size=img_meta['batch_input_shape'],
+ mode='bilinear',
+ align_corners=False)
+
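+        # crop away the padded border (img_shape), then resize to the
+        # original image resolution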
+ masks_per_img = masks_per_img[:, :, :h, :w]
+ ori_shape = img_meta['ori_shape']
+ seg_masks = F.interpolate(
+ masks_per_img,
+ size=ori_shape[:2],
+ mode='bilinear',
+ align_corners=False).squeeze(0)
+ return seg_masks
+
+ def get_seg_masks(self, masks_per_img, labels_per_img, scores_per_img,
+ test_cfg, img_meta):
+ # resize mask predictions back
+ seg_masks = self.rescale_masks(masks_per_img, img_meta)
+ seg_masks = seg_masks > test_cfg['mask_thr']
+ bbox_result, segm_result = self.segm2result(seg_masks, labels_per_img,
+ scores_per_img)
+ return bbox_result, segm_result
+
+ def segm2result(self, mask_preds, det_labels, cls_scores):
+ num_classes = self.num_classes
+ bbox_result = None
+ segm_result = [[] for _ in range(num_classes)]
+ mask_preds = mask_preds.cpu().numpy()
+ det_labels = det_labels.cpu().numpy()
+ cls_scores = cls_scores.cpu().numpy()
+ num_ins = mask_preds.shape[0]
+ # fake bboxes
+ bboxes = np.zeros((num_ins, 5), dtype=np.float32)
+ bboxes[:, -1] = cls_scores
+ bbox_result = [bboxes[det_labels == i, :] for i in range(num_classes)]
+ for idx in range(num_ins):
+ segm_result[det_labels[idx]].append(mask_preds[idx])
+ return bbox_result, segm_result
+
+ def get_seg_masks_tracking(self, masks_per_img, labels_per_img,
+ scores_per_img, ids_per_img, test_cfg,
+ img_meta):
+ num_ins = masks_per_img.shape[0]
+ # resize mask predictions back
+ seg_masks = self.rescale_masks(masks_per_img, img_meta)
+ seg_masks = seg_masks > test_cfg['mask_thr']
+ # fake bboxes
+ bboxes = torch.zeros((num_ins, 5), dtype=torch.float32)
+ bboxes[:, -1] = scores_per_img
+ tracks = outs2results(
+ bboxes=bboxes,
+ labels=labels_per_img,
+ masks=seg_masks,
+ ids=ids_per_img,
+ num_classes=self.num_classes,
+ )
+ return tracks['bbox_results'], tracks['mask_results']
diff --git a/modelscope/models/cv/video_instance_segmentation/head/kernel_updator.py b/modelscope/models/cv/video_instance_segmentation/head/kernel_updator.py
new file mode 100644
index 00000000..4d67d59f
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/head/kernel_updator.py
@@ -0,0 +1,97 @@
+# The implementation is adopted from Video-K-Net,
+# made publicly available at https://github.com/lxtGH/Video-K-Net under the MIT license
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks.transformer import TRANSFORMER_LAYER
+
+
+@TRANSFORMER_LAYER.register_module()
+class KernelUpdator(nn.Module):
+
+ def __init__(self,
+ in_channels=256,
+ feat_channels=64,
+ out_channels=None,
+ input_feat_shape=3,
+ gate_sigmoid=True,
+ gate_norm_act=False,
+ activate_out=False,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN')):
+ super(KernelUpdator, self).__init__()
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.out_channels_raw = out_channels
+ self.gate_sigmoid = gate_sigmoid
+ self.gate_norm_act = gate_norm_act
+ self.activate_out = activate_out
+ if isinstance(input_feat_shape, int):
+ input_feat_shape = [input_feat_shape] * 2
+ self.input_feat_shape = input_feat_shape
+ self.act_cfg = act_cfg
+ self.norm_cfg = norm_cfg
+ self.out_channels = out_channels if out_channels else in_channels
+
+ self.num_params_in = self.feat_channels
+ self.num_params_out = self.feat_channels
+ self.dynamic_layer = nn.Linear(
+ self.in_channels, self.num_params_in + self.num_params_out)
+ self.input_layer = nn.Linear(self.in_channels,
+ self.num_params_in + self.num_params_out,
+ 1)
+ self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
+ self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
+ if self.gate_norm_act:
+ self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]
+
+ self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
+ self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
+ self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
+ self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
+
+ self.activation = build_activation_layer(act_cfg)
+
+ self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)
+ self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+ def forward(self, update_feature, input_feature):
+ update_feature = update_feature.reshape(-1, self.in_channels)
+ num_proposals = update_feature.size(0)
+ parameters = self.dynamic_layer(update_feature)
+ param_in = parameters[:, :self.num_params_in].view(
+ -1, self.feat_channels)
+ param_out = parameters[:, -self.num_params_out:].view(
+ -1, self.feat_channels)
+
+ input_feats = self.input_layer(
+ input_feature.reshape(num_proposals, -1, self.feat_channels))
+ input_in = input_feats[..., :self.num_params_in]
+ input_out = input_feats[..., -self.num_params_out:]
+
+ gate_feats = input_in * param_in.unsqueeze(-2)
+ if self.gate_norm_act:
+ gate_feats = self.activation(self.gate_norm(gate_feats))
+
+ input_gate = self.input_norm_in(self.input_gate(gate_feats))
+ update_gate = self.norm_in(self.update_gate(gate_feats))
+ if self.gate_sigmoid:
+ input_gate = input_gate.sigmoid()
+ update_gate = update_gate.sigmoid()
+ param_out = self.norm_out(param_out)
+ input_out = self.input_norm_out(input_out)
+
+ if self.activate_out:
+ param_out = self.activation(param_out)
+ input_out = self.activation(input_out)
+
+ # param_out has shape (batch_size, feat_channels, out_channels)
+ features = update_gate * param_out.unsqueeze(
+ -2) + input_gate * input_out
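+        # gated fusion: update_gate weights the features generated from
+        # `update_feature` (param_out), input_gate weights those generated
+        # from `input_feature` (input_out)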
+
+ features = self.fc_layer(features)
+ features = self.fc_norm(features)
+ features = self.activation(features)
+
+ return features
diff --git a/modelscope/models/cv/video_instance_segmentation/neck/__init__.py b/modelscope/models/cv/video_instance_segmentation/neck/__init__.py
new file mode 100644
index 00000000..adb283e7
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/neck/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .msdeformattn_decoder import (
+ MSDeformAttnPixelDecoder, )
+
+else:
+ _import_structure = {'msdeformattn_decoder': ['MSDeformAttnPixelDecoder']}
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/cv/video_instance_segmentation/neck/msdeformattn_decoder.py b/modelscope/models/cv/video_instance_segmentation/neck/msdeformattn_decoder.py
new file mode 100644
index 00000000..f87c92fe
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/neck/msdeformattn_decoder.py
@@ -0,0 +1,275 @@
+# The implementation is adopted from Video-K-Net,
+# made publicly available at https://github.com/lxtGH/Video-K-Net under the MIT license
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (Conv2d, ConvModule, caffe2_xavier_init, normal_init,
+ xavier_init)
+from mmcv.cnn.bricks.transformer import (build_positional_encoding,
+ build_transformer_layer_sequence)
+from mmcv.runner import BaseModule, ModuleList
+from mmdet.core.anchor import MlvlPointGenerator
+from mmdet.models.builder import NECKS
+from mmdet.models.utils.transformer import MultiScaleDeformableAttention
+
+
+@NECKS.register_module()
+class MSDeformAttnPixelDecoder(BaseModule):
+ """Pixel decoder with multi-scale deformable attention.
+
+ Args:
+ in_channels (list[int] | tuple[int]): Number of channels in the
+ input feature maps.
+ strides (list[int] | tuple[int]): Output strides of feature from
+ backbone.
+ feat_channels (int): Number of channels for feature.
+ out_channels (int): Number of channels for output.
+ num_outs (int): Number of output scales.
+ norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization.
+ Defaults to dict(type='GN', num_groups=32).
+ act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation.
+ Defaults to dict(type='ReLU').
+ encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer
+ encoder. Defaults to `DetrTransformerEncoder`.
+ positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
+ transformer encoder position encoding. Defaults to
+ dict(type='SinePositionalEncoding', num_feats=128,
+ normalize=True).
+ init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict.
+ """
+
+ def __init__(self,
+ in_channels=[256, 512, 1024, 2048],
+ strides=[4, 8, 16, 32],
+ feat_channels=256,
+ out_channels=256,
+ num_outs=3,
+ return_one_list=True,
+ norm_cfg=dict(type='GN', num_groups=32),
+ act_cfg=dict(type='ReLU'),
+ encoder=dict(
+ type='DetrTransformerEncoder',
+ num_layers=6,
+ transformerlayers=dict(
+ type='BaseTransformerLayer',
+ attn_cfgs=dict(
+ type='MultiScaleDeformableAttention',
+ embed_dims=256,
+ num_heads=8,
+ num_levels=3,
+ num_points=4,
+ im2col_step=64,
+ dropout=0.0,
+ batch_first=False,
+ norm_cfg=None,
+ init_cfg=None),
+ feedforward_channels=1024,
+ ffn_dropout=0.0,
+ operation_order=('self_attn', 'norm', 'ffn', 'norm')),
+ init_cfg=None),
+ positional_encoding=dict(
+ type='SinePositionalEncoding',
+ num_feats=128,
+ normalize=True),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.strides = strides
+ self.num_input_levels = len(in_channels)
+ self.return_one_list = return_one_list
+ self.num_encoder_levels = encoder['transformerlayers']['attn_cfgs'][
+ 'num_levels']
+ assert self.num_encoder_levels >= 1, \
+ 'num_levels in attn_cfgs must be at least one'
+ input_conv_list = []
+        # from top to bottom (low resolution to high resolution)
+ for i in range(self.num_input_levels - 1,
+ self.num_input_levels - self.num_encoder_levels - 1,
+ -1):
+ input_conv = ConvModule(
+ in_channels[i],
+ feat_channels,
+ kernel_size=1,
+ norm_cfg=norm_cfg,
+ act_cfg=None,
+ bias=True)
+ input_conv_list.append(input_conv)
+ self.input_convs = ModuleList(input_conv_list)
+
+ self.encoder = build_transformer_layer_sequence(encoder)
+        self.positional_encoding = build_positional_encoding(
+ positional_encoding)
+ # high resolution to low resolution
+ self.level_encoding = nn.Embedding(self.num_encoder_levels,
+ feat_channels)
+
+ # fpn-like structure
+ self.lateral_convs = ModuleList()
+ self.output_convs = ModuleList()
+ self.use_bias = norm_cfg is None
+        # from top to bottom (low resolution to high resolution)
+        # FPN for the remaining features that were not fed to the encoder
+ for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
+ -1):
+ lateral_conv = ConvModule(
+ in_channels[i],
+ feat_channels,
+ kernel_size=1,
+ bias=self.use_bias,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ output_conv = ConvModule(
+ feat_channels,
+ feat_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=self.use_bias,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.lateral_convs.append(lateral_conv)
+ self.output_convs.append(output_conv)
+
+ self.mask_feature = Conv2d(
+ feat_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ self.num_outs = num_outs
+ self.point_generator = MlvlPointGenerator(strides)
+
+ def init_weights(self):
+ """Initialize weights."""
+ for i in range(0, self.num_encoder_levels):
+ xavier_init(
+ self.input_convs[i].conv,
+ gain=1,
+ bias=0,
+ distribution='uniform')
+
+ for i in range(0, self.num_input_levels - self.num_encoder_levels):
+ caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
+ caffe2_xavier_init(self.output_convs[i].conv, bias=0)
+
+ caffe2_xavier_init(self.mask_feature, bias=0)
+
+ normal_init(self.level_encoding, mean=0, std=1)
+ for p in self.encoder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_normal_(p)
+
+ # init_weights defined in MultiScaleDeformableAttention
+ for layer in self.encoder.layers:
+ for attn in layer.attentions:
+ if isinstance(attn, MultiScaleDeformableAttention):
+ attn.init_weights()
+
+ def forward(self, feats):
+ """
+ Args:
+ feats (list[Tensor]): Feature maps of each level. Each has
+ shape of (batch_size, c, h, w).
+
+ Returns:
+ tuple: A tuple containing the following:
+
+ - mask_feature (Tensor): shape (batch_size, c, h, w).
+ - multi_scale_features (list[Tensor]): Multi scale \
+ features, each in shape (batch_size, c, h, w).
+ """
+ # generate padding mask for each level, for each image
+ batch_size = feats[0].shape[0]
+ encoder_input_list = []
+ padding_mask_list = []
+ level_positional_encoding_list = []
+ spatial_shapes = []
+ reference_points_list = []
+ for i in range(self.num_encoder_levels):
+ level_idx = self.num_input_levels - i - 1
+ feat = feats[level_idx]
+ feat_projected = self.input_convs[i](feat)
+ h, w = feat.shape[-2:]
+
+ # no padding
+ padding_mask_resized = feat.new_zeros(
+ (batch_size, ) + feat.shape[-2:], dtype=torch.bool)
+            pos_embed = self.positional_encoding(padding_mask_resized)
+ level_embed = self.level_encoding.weight[i]
+ level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
+ # (h_i * w_i, 2)
+ reference_points = self.point_generator.single_level_grid_priors(
+ feat.shape[-2:], level_idx, device=feat.device)
+ # normalize
+ factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
+ reference_points = reference_points / factor
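+            # normalize reference points to [0, 1] w.r.t. the padded input
+            # size (feature size times stride)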
+
+ # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c)
+ feat_projected = feat_projected.flatten(2).permute(2, 0, 1)
+ level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1)
+ padding_mask_resized = padding_mask_resized.flatten(1)
+
+ encoder_input_list.append(feat_projected)
+ padding_mask_list.append(padding_mask_resized)
+ level_positional_encoding_list.append(level_pos_embed)
+ spatial_shapes.append(feat.shape[-2:])
+ reference_points_list.append(reference_points)
+        # shape (batch_size, total_num_query),
+        # where total_num_query = sum over levels of h_i * w_i
+ padding_masks = torch.cat(padding_mask_list, dim=1)
+ # shape (total_num_query, batch_size, c)
+ encoder_inputs = torch.cat(encoder_input_list, dim=0)
+ level_positional_encodings = torch.cat(
+ level_positional_encoding_list, dim=0)
+ device = encoder_inputs.device
+ # shape (num_encoder_levels, 2), from low
+ # resolution to high resolution
+ spatial_shapes = torch.as_tensor(
+ spatial_shapes, dtype=torch.long, device=device)
+ # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
+ level_start_index = torch.cat((spatial_shapes.new_zeros(
+ (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
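+        # start offset of each level's queries in the flattened sequence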
+ reference_points = torch.cat(reference_points_list, dim=0)
+ reference_points = reference_points[None, :, None].repeat(
+ batch_size, 1, self.num_encoder_levels, 1)
+ valid_radios = reference_points.new_ones(
+ (batch_size, self.num_encoder_levels, 2))
+ # shape (num_total_query, batch_size, c)
+ memory = self.encoder(
+ query=encoder_inputs,
+ key=None,
+ value=None,
+ query_pos=level_positional_encodings,
+ key_pos=None,
+ attn_masks=None,
+ key_padding_mask=None,
+ query_key_padding_mask=padding_masks,
+ spatial_shapes=spatial_shapes,
+ reference_points=reference_points,
+ level_start_index=level_start_index,
+ valid_radios=valid_radios)
+ # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query)
+ memory = memory.permute(1, 2, 0)
+
+ # from low resolution to high resolution
+ num_query_per_level = [e[0] * e[1] for e in spatial_shapes]
+ outs = torch.split(memory, num_query_per_level, dim=-1)
+ outs = [
+ x.reshape(batch_size, -1, spatial_shapes[i][0],
+ spatial_shapes[i][1]) for i, x in enumerate(outs)
+ ]
+
+ for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
+ -1):
+ x = feats[i]
+ cur_feat = self.lateral_convs[i](x)
+ y = cur_feat + F.interpolate(
+ outs[-1],
+ size=cur_feat.shape[-2:],
+ mode='bilinear',
+ align_corners=False)
+ y = self.output_convs[i](y)
+ outs.append(y)
+ multi_scale_features = outs[:self.num_outs]
+
+ mask_feature = self.mask_feature(outs[-1])
+ multi_scale_features.append(mask_feature)
+ multi_scale_features.reverse()
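+        # after the reverse, features run from the high-resolution mask
+        # feature down to the coarsest encoder output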
+ return tuple(multi_scale_features)
diff --git a/modelscope/models/cv/video_instance_segmentation/track/__init__.py b/modelscope/models/cv/video_instance_segmentation/track/__init__.py
new file mode 100644
index 00000000..b937315b
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/track/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
diff --git a/modelscope/models/cv/video_instance_segmentation/track/kernel_update_head.py b/modelscope/models/cv/video_instance_segmentation/track/kernel_update_head.py
new file mode 100644
index 00000000..252fec89
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/track/kernel_update_head.py
@@ -0,0 +1,634 @@
+# The implementation is adopted from Video-K-Net,
+# made publicly available at https://github.com/lxtGH/Video-K-Net under the MIT license
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (ConvModule, bias_init_with_prob, build_activation_layer,
+ build_norm_layer)
+from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention,
+ build_transformer_layer)
+from mmcv.runner import force_fp32
+from mmdet.core import multi_apply
+from mmdet.models.builder import HEADS, build_loss
+from mmdet.models.dense_heads.atss_head import reduce_mean
+from mmdet.models.losses import accuracy
+from mmdet.utils import get_root_logger
+
+from ..utils import outs2results
+
+
+@HEADS.register_module()
+class KernelUpdateHeadVideo(nn.Module):
+
+ def __init__(
+ self,
+ with_cls=True,
+ num_proposals=100,
+ num_classes=80,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_cls_fcs=1,
+ num_mask_fcs=3,
+ feedforward_channels=2048,
+ in_channels=256,
+ out_channels=256,
+ dropout=0.0,
+ mask_thr=0.5,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ conv_kernel_size=3,
+ feat_transform_cfg=None,
+ hard_mask_thr=0.5,
+ kernel_init=False,
+ with_ffn=True,
+ mask_out_stride=4,
+ relative_coors=False,
+ relative_coors_off=False,
+ feat_gather_stride=1,
+ mask_transform_stride=1,
+ mask_upsample_stride=1,
+ num_thing_classes=80,
+ num_stuff_classes=53,
+ mask_assign_stride=4,
+ ignore_label=255,
+ thing_label_in_seg=0,
+ # query fusion
+ query_merge_method='mean',
+ kernel_updator_cfg=dict(
+ type='DynamicConv',
+ in_channels=256,
+ feat_channels=64,
+ out_channels=256,
+ input_feat_shape=1,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN')),
+ loss_rank=None,
+ loss_mask=dict(
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
+ loss_dice=dict(type='DiceLoss', loss_weight=3.0),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0)):
+ super().__init__()
+ self.num_proposals = num_proposals
+ self.num_classes = num_classes
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_mask = build_loss(loss_mask)
+ self.loss_dice = build_loss(loss_dice)
+ if loss_rank is not None:
+ self.loss_rank = build_loss(loss_rank)
+ else:
+ self.loss_rank = loss_rank
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.mask_thr = mask_thr
+ self.fp16_enabled = False
+ self.dropout = dropout
+
+ self.num_heads = num_heads
+ self.hard_mask_thr = hard_mask_thr
+ self.kernel_init = kernel_init
+ self.with_ffn = with_ffn
+ self.mask_out_stride = mask_out_stride
+ self.relative_coors = relative_coors
+ self.relative_coors_off = relative_coors_off
+ self.conv_kernel_size = conv_kernel_size
+ self.feat_gather_stride = feat_gather_stride
+ self.mask_transform_stride = mask_transform_stride
+ self.mask_upsample_stride = mask_upsample_stride
+
+ self.num_thing_classes = num_thing_classes
+ self.num_stuff_classes = num_stuff_classes
+ self.mask_assign_stride = mask_assign_stride
+ self.ignore_label = ignore_label
+ self.thing_label_in_seg = thing_label_in_seg
+
+ self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,
+ num_heads, dropout)
+ self.attention_norm = build_norm_layer(
+ dict(type='LN'), in_channels * conv_kernel_size**2)[1]
+
+ self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)
+
+ if feat_transform_cfg is not None:
+ kernel_size = feat_transform_cfg.pop('kernel_size', 1)
+ self.feat_transform = ConvModule(
+ in_channels,
+ in_channels,
+ kernel_size,
+ stride=feat_gather_stride,
+ padding=int(feat_gather_stride // 2),
+ **feat_transform_cfg)
+ else:
+ self.feat_transform = None
+
+ if self.with_ffn:
+ self.ffn = FFN(
+ in_channels,
+ feedforward_channels,
+ num_ffn_fcs,
+ act_cfg=ffn_act_cfg,
+ ffn_drop=dropout)
+ self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]
+
+ self.with_cls = with_cls
+ if self.with_cls:
+ self.cls_fcs = nn.ModuleList()
+ for _ in range(num_cls_fcs):
+ self.cls_fcs.append(
+ nn.Linear(in_channels, in_channels, bias=False))
+ self.cls_fcs.append(
+ build_norm_layer(dict(type='LN'), in_channels)[1])
+ self.cls_fcs.append(build_activation_layer(act_cfg))
+
+ if self.loss_cls.use_sigmoid:
+ self.fc_cls = nn.Linear(in_channels, self.num_classes)
+ else:
+ self.fc_cls = nn.Linear(in_channels, self.num_classes + 1)
+
+ # query fusion
+ self.query_merge_method = query_merge_method
+ if self.query_merge_method == 'attention' and self.with_cls:
+ _num_head = 8
+ _drop_out = 0.
+ self.query_merge_attn = MultiheadAttention(
+ self.in_channels, _num_head, _drop_out, batch_first=True)
+ self.query_merge_norm = build_norm_layer(
+ dict(type='LN'), self.in_channels)[1]
+ self.query_merge_ffn = FFN(
+ self.in_channels,
+ self.in_channels * 8,
+ num_ffn_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_drop=0.)
+ self.query_merge_ffn_norm = build_norm_layer(
+ dict(type='LN'), self.in_channels)[1]
+ elif self.query_merge_method == 'attention_pos' and self.with_cls:
+ _num_head = 8
+ _drop_out = 0.
+ self.query_merge_attn = MultiheadAttention(
+ self.in_channels, _num_head, _drop_out, batch_first=True)
+ self.query_merge_norm = build_norm_layer(
+ dict(type='LN'), self.in_channels)[1]
+ self.query_merge_ffn = FFN(
+ self.in_channels,
+ self.in_channels * 8,
+ num_ffn_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_drop=0.)
+ self.query_merge_ffn_norm = build_norm_layer(
+ dict(type='LN'), self.in_channels)[1]
+
+ self.mask_fcs = nn.ModuleList()
+ for _ in range(num_mask_fcs):
+ self.mask_fcs.append(
+ nn.Linear(in_channels, in_channels, bias=False))
+ self.mask_fcs.append(
+ build_norm_layer(dict(type='LN'), in_channels)[1])
+ self.mask_fcs.append(build_activation_layer(act_cfg))
+
+ self.fc_mask = nn.Linear(in_channels, out_channels)
+
+ def init_weights(self):
+ """Use xavier initialization for all weight parameter and set
+ classification head bias as a specific value when use focal loss."""
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ else:
+ # adopt the default initialization for
+ # the weight and bias of the layer norm
+ pass
+ if self.loss_cls.use_sigmoid:
+ bias_init = bias_init_with_prob(0.01)
+ nn.init.constant_(self.fc_cls.bias, bias_init)
+ if self.kernel_init:
+ logger = get_root_logger()
+ logger.info(
+ 'mask kernel in mask head is normal initialized by std 0.01')
+ nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)
+
+ def forward(self,
+ x,
+ proposal_feat,
+ mask_preds,
+ prev_cls_score=None,
+ mask_shape=None,
+ img_metas=None,
+ pos=None):
+ if len(proposal_feat.size()) == 6:
+ assert not self.with_cls
+ is_gather_query = False
+ N, _, num_proposals = proposal_feat.shape[:3]
+ else:
+ assert self.with_cls
+ is_gather_query = True
+ N, num_proposals = proposal_feat.shape[:2]
+ assert self.num_proposals == num_proposals
+ _, num_frames, C, H, W = x.size()
+ if self.feat_transform is not None:
+ x = self.feat_transform(x.reshape(
+ (N * num_frames, C, H, W))).reshape((N, num_frames, C, H, W))
+
+ mask_h, mask_w = mask_preds.shape[-2:]
+ if mask_h != H or mask_w != W:
+ gather_mask = F.interpolate(
+ mask_preds.reshape((N * num_proposals, C, H, W)), (H, W),
+ align_corners=False,
+ mode='bilinear').reshape((N, num_frames, C, H, W))
+ else:
+ gather_mask = mask_preds
+
+ sigmoid_masks = gather_mask.sigmoid()
+ nonzero_inds = sigmoid_masks > self.hard_mask_thr
+ sigmoid_masks = nonzero_inds.float()
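+        # when queries are shared across frames (is_gather_query), fuse the
+        # per-frame mask-pooled features: either a mean over frames or
+        # cross-attention from the current kernels to all per-frame features
+        # (optionally with positional encodings)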
+
+ # einsum is faster than bmm by 30%
+ if is_gather_query:
+ # x_feat = torch.einsum('bfnhw,bfchw->bnc', sigmoid_masks, x)
+ if self.query_merge_method == 'mean':
+ x_feat = torch.einsum('bfnhw,bfchw->bfnc', sigmoid_masks,
+ x).mean(1)
+ elif self.query_merge_method == 'attention':
+ x_feat = torch.einsum('bfnhw,bfchw->bfnc', sigmoid_masks, x)
+ x_feat = x_feat.reshape(
+ (N, num_frames * num_proposals, self.in_channels))
+ assert proposal_feat.size()[-2:] == (
+ 1, 1), 'Only supporting kernel size = 1'
+ init_query = proposal_feat.reshape(N, num_proposals,
+ self.in_channels).detach()
+ x_feat = self.query_merge_attn(
+ query=init_query, key=x_feat, value=x_feat)
+ x_feat = self.query_merge_norm(x_feat)
+ x_feat = self.query_merge_ffn_norm(
+ self.query_merge_ffn(x_feat))
+ elif self.query_merge_method == 'attention_pos':
+ x_feat = torch.einsum('bfnhw,bfchw->bfnc', sigmoid_masks, x)
+ x_feat = x_feat.reshape(
+ (N, num_frames * num_proposals, self.in_channels))
+ assert proposal_feat.size()[-2:] == (
+ 1, 1), 'Only supporting kernel size = 1'
+ init_query = proposal_feat.reshape(N, num_proposals,
+ self.in_channels).detach()
+ query_pos = pos.repeat(N, 1, 1)
+ key_pos = query_pos.repeat(1, num_frames, 1)
+ x_feat = self.query_merge_attn(
+ query=init_query,
+ key=x_feat,
+ value=x_feat,
+ query_pos=query_pos,
+ key_pos=key_pos)
+ x_feat = self.query_merge_norm(x_feat)
+ x_feat = self.query_merge_ffn_norm(
+ self.query_merge_ffn(x_feat))
+ else:
+ raise NotImplementedError
+ else:
+ x_feat = torch.einsum('bfnhw,bfchw->bfnc', sigmoid_masks, x)
+
+ # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]
+ if is_gather_query:
+ proposal_feat = proposal_feat.reshape(N, num_proposals,
+ self.in_channels,
+ -1).permute(0, 1, 3, 2)
+ obj_feat = self.kernel_update_conv(x_feat, proposal_feat)
+ else:
+ proposal_feat = proposal_feat.reshape(N * num_frames,
+ num_proposals,
+ self.in_channels,
+ -1).permute(0, 1, 3, 2)
+ obj_feat = self.kernel_update_conv(
+ x_feat.reshape(N * num_frames, num_proposals, C),
+ proposal_feat)
+ N *= num_frames
+
+ # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]
+ obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2)
+ obj_feat = self.attention_norm(self.attention(obj_feat))
+ # [N, B, K*K*C] -> [B, N, K*K*C]
+ obj_feat = obj_feat.permute(1, 0, 2)
+
+ # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]
+ obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)
+
+ # FFN
+ if self.with_ffn:
+ obj_feat = self.ffn_norm(self.ffn(obj_feat))
+
+ mask_feat = obj_feat
+
+ if is_gather_query:
+ cls_feat = obj_feat.sum(-2)
+ for cls_layer in self.cls_fcs:
+ cls_feat = cls_layer(cls_feat)
+ cls_score = self.fc_cls(cls_feat).view(N, num_proposals, -1)
+ else:
+ cls_score = None
+
+ for reg_layer in self.mask_fcs:
+ mask_feat = reg_layer(mask_feat)
+ # [B, N, K*K, C] -> [B, N, C, K*K]
+ mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)
+
+ if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1):
+ mask_x = F.interpolate(
+ x, scale_factor=0.5, mode='bilinear', align_corners=False)
+ H, W = mask_x.shape[-2:]
+ raise NotImplementedError
+ else:
+ mask_x = x
+ # group conv is 5x faster than unfold and uses about 1/5 memory
+ # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms
+ # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369
+ # fold_x = F.unfold(
+ # mask_x,
+ # self.conv_kernel_size,
+ # padding=int(self.conv_kernel_size // 2))
+ # mask_feat = mask_feat.reshape(N, num_proposals, -1)
+ # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)
+ # [B, N, C, K*K] -> [B*N, C, K, K]
+ mask_feat = mask_feat.reshape(N, num_proposals, C,
+ self.conv_kernel_size,
+ self.conv_kernel_size)
+        # apply each sample's proposal kernels to its own per-frame feature maps
+ if is_gather_query:
+ new_mask_preds = []
+ for i in range(N):
+ new_mask_preds.append(
+ F.conv2d(
+ mask_x[i],
+ mask_feat[i],
+ padding=int(self.conv_kernel_size // 2)))
+
+ new_mask_preds = torch.stack(new_mask_preds, dim=0)
+ assert new_mask_preds.size() == (N, num_frames, num_proposals, H,
+ W)
+ else:
+ N = N // num_frames
+ new_mask_preds = []
+ for i in range(N):
+ for j in range(num_frames):
+ new_mask_preds.append(
+ F.conv2d(
+ mask_x[i][j][None],
+ mask_feat[i * num_frames + j],
+ padding=int(self.conv_kernel_size // 2)))
+ new_mask_preds = torch.cat(new_mask_preds, dim=0)
+ new_mask_preds = new_mask_preds.reshape(N, num_frames,
+ num_proposals, H, W)
+ assert new_mask_preds.size() == (N, num_frames, num_proposals, H,
+ W)
+ if self.mask_transform_stride == 2:
+ new_mask_preds = F.interpolate(
+ new_mask_preds,
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=False)
+ raise NotImplementedError
+
+ if mask_shape is not None and mask_shape[0] != H:
+ new_mask_preds = F.interpolate(
+ new_mask_preds,
+ mask_shape,
+ align_corners=False,
+ mode='bilinear')
+ raise NotImplementedError
+ if is_gather_query:
+ return cls_score, new_mask_preds, obj_feat.permute(
+ 0, 1, 3, 2).reshape(N, num_proposals, self.in_channels,
+ self.conv_kernel_size,
+ self.conv_kernel_size)
+ else:
+ return None, new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(
+ N, num_frames, num_proposals, self.in_channels,
+ self.conv_kernel_size, self.conv_kernel_size)
+
+ @force_fp32(apply_to=('cls_score', 'mask_pred'))
+ def loss(self,
+ object_feats,
+ cls_score,
+ mask_pred,
+ labels,
+ label_weights,
+ mask_targets,
+ mask_weights,
+ imgs_whwh=None,
+ reduction_override=None,
+ **kwargs):
+
+ losses = dict()
+ bg_class_ind = self.num_classes
+        # note: in Sparse R-CNN, num_gt == num_pos
+ pos_inds = (labels >= 0) & (labels < bg_class_ind)
+ num_pos = pos_inds.sum().float()
+ avg_factor = reduce_mean(num_pos).clamp_(min=1.0)
+
+ num_preds = mask_pred.shape[0] * mask_pred.shape[1]
+ if cls_score is not None:
+ assert mask_pred.shape[0] == cls_score.shape[0]
+ assert mask_pred.shape[1] == cls_score.shape[1]
+
+ if cls_score is not None:
+ if cls_score.numel() > 0:
+ losses['loss_cls'] = self.loss_cls(
+ cls_score.view(num_preds, -1),
+ labels,
+ label_weights,
+ avg_factor=avg_factor,
+ reduction_override=reduction_override)
+ losses['pos_acc'] = accuracy(
+ cls_score.view(num_preds, -1)[pos_inds], labels[pos_inds])
+ if mask_pred is not None:
+ bool_pos_inds = pos_inds.type(torch.bool)
+            # classes 0..num_classes-1 are FG, num_classes is BG;
+            # only foreground samples receive mask supervision
+ H, W = mask_pred.shape[-2:]
+ if pos_inds.any():
+ pos_mask_pred = mask_pred.reshape(num_preds, H,
+ W)[bool_pos_inds]
+ pos_mask_targets = mask_targets[bool_pos_inds]
+ losses['loss_mask'] = self.loss_mask(pos_mask_pred,
+ pos_mask_targets)
+ losses['loss_dice'] = self.loss_dice(pos_mask_pred,
+ pos_mask_targets)
+
+ if self.loss_rank is not None:
+ batch_size = mask_pred.size(0)
+ rank_target = mask_targets.new_full((batch_size, H, W),
+ self.ignore_label,
+ dtype=torch.long)
+ rank_inds = pos_inds.view(batch_size,
+ -1).nonzero(as_tuple=False)
+ batch_mask_targets = mask_targets.view(
+ batch_size, -1, H, W).bool()
+ for i in range(batch_size):
+ curr_inds = (rank_inds[:, 0] == i)
+ curr_rank = rank_inds[:, 1][curr_inds]
+ for j in curr_rank:
+ rank_target[i][batch_mask_targets[i][j]] = j
+ losses['loss_rank'] = self.loss_rank(
+ mask_pred, rank_target, ignore_index=self.ignore_label)
+ else:
+ losses['loss_mask'] = mask_pred.sum() * 0
+ losses['loss_dice'] = mask_pred.sum() * 0
+ if self.loss_rank is not None:
+ losses['loss_rank'] = mask_pred.sum() * 0
+
+ return losses
+
+ def _get_target_single(self, pos_inds, neg_inds, pos_mask, neg_mask,
+ pos_gt_mask, pos_gt_labels, gt_sem_seg, gt_sem_cls,
+ cfg):
+
+ num_pos = pos_mask.size(0)
+ num_neg = neg_mask.size(0)
+ num_samples = num_pos + num_neg
+ H, W = pos_mask.shape[-2:]
+        # the original implementation used new_zeros since BG was encoded as 0;
+        # here we fill with num_classes because BG cat_id = num_classes and
+        # FG cat_id is in [0, num_classes - 1]
+ labels = pos_mask.new_full((num_samples, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = pos_mask.new_zeros((num_samples, self.num_classes))
+ mask_targets = pos_mask.new_zeros(num_samples, H, W)
+ mask_weights = pos_mask.new_zeros(num_samples, H, W)
+ if num_pos > 0:
+ labels[pos_inds] = pos_gt_labels
+ pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
+ label_weights[pos_inds] = pos_weight
+ pos_mask_targets = pos_gt_mask
+ mask_targets[pos_inds, ...] = pos_mask_targets
+ mask_weights[pos_inds, ...] = 1
+
+ if num_neg > 0:
+ label_weights[neg_inds] = 1.0
+
+ if gt_sem_cls is not None and gt_sem_seg is not None:
+ sem_labels = pos_mask.new_full((self.num_stuff_classes, ),
+ self.num_classes,
+ dtype=torch.long)
+ sem_targets = pos_mask.new_zeros(self.num_stuff_classes, H, W)
+ sem_weights = pos_mask.new_zeros(self.num_stuff_classes, H, W)
+ sem_stuff_weights = torch.eye(
+ self.num_stuff_classes, device=pos_mask.device)
+ sem_thing_weights = pos_mask.new_zeros(
+ (self.num_stuff_classes, self.num_thing_classes))
+ sem_label_weights = torch.cat(
+ [sem_thing_weights, sem_stuff_weights], dim=-1)
+            if len(gt_sem_cls) > 0:
+ sem_inds = gt_sem_cls - self.num_thing_classes
+ sem_inds = sem_inds.long()
+ sem_labels[sem_inds] = gt_sem_cls.long()
+ sem_targets[sem_inds] = gt_sem_seg
+ sem_weights[sem_inds] = 1
+
+ label_weights[:, self.num_thing_classes:] = 0
+ labels = torch.cat([labels, sem_labels])
+ label_weights = torch.cat([label_weights, sem_label_weights])
+ mask_targets = torch.cat([mask_targets, sem_targets])
+ mask_weights = torch.cat([mask_weights, sem_weights])
+
+ return labels, label_weights, mask_targets, mask_weights
+
+ def get_targets(self,
+ sampling_results,
+ rcnn_train_cfg,
+ concat=True,
+ gt_sem_seg=None,
+ gt_sem_cls=None):
+ num_imgs = len(sampling_results)
+ pos_inds_list = [res.pos_inds for res in sampling_results]
+ neg_inds_list = [res.neg_inds for res in sampling_results]
+ pos_mask_list = [res.pos_masks for res in sampling_results]
+ neg_mask_list = [res.neg_masks for res in sampling_results]
+ pos_gt_mask_list = [res.pos_gt_masks for res in sampling_results]
+ pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
+ if gt_sem_seg is None:
+ gt_sem_seg = [None] * num_imgs
+ gt_sem_cls = [None] * num_imgs
+
+ labels, label_weights, mask_targets, mask_weights = multi_apply(
+ self._get_target_single,
+ pos_inds_list,
+ neg_inds_list,
+ pos_mask_list,
+ neg_mask_list,
+ pos_gt_mask_list,
+ pos_gt_labels_list,
+ gt_sem_seg,
+ gt_sem_cls,
+ cfg=rcnn_train_cfg)
+ if concat:
+ labels = torch.cat(labels, 0)
+ label_weights = torch.cat(label_weights, 0)
+ mask_targets = torch.cat(mask_targets, 0)
+ mask_weights = torch.cat(mask_weights, 0)
+ return labels, label_weights, mask_targets, mask_weights
+
+ def rescale_masks(self, masks_per_img, img_meta):
+ h, w, _ = img_meta['img_shape']
+ masks_per_img = F.interpolate(
+ masks_per_img.unsqueeze(0).sigmoid(),
+ size=img_meta['batch_input_shape'],
+ mode='bilinear',
+ align_corners=False)
+
+ masks_per_img = masks_per_img[:, :, :h, :w]
+ ori_shape = img_meta['ori_shape']
+ seg_masks = F.interpolate(
+ masks_per_img,
+ size=ori_shape[:2],
+ mode='bilinear',
+ align_corners=False).squeeze(0)
+ return seg_masks
+
+ def get_seg_masks(self, masks_per_img, labels_per_img, scores_per_img,
+ test_cfg, img_meta):
+ # resize mask predictions back
+ seg_masks = self.rescale_masks(masks_per_img, img_meta)
+ seg_masks = seg_masks > test_cfg.mask_thr
+ bbox_result, segm_result = self.segm2result(seg_masks, labels_per_img,
+ scores_per_img)
+ return bbox_result, segm_result
+
+ def segm2result(self, mask_preds, det_labels, cls_scores):
+ num_classes = self.num_classes
+ bbox_result = None
+ segm_result = [[] for _ in range(num_classes)]
+ mask_preds = mask_preds.cpu().numpy()
+ det_labels = det_labels.cpu().numpy()
+ cls_scores = cls_scores.cpu().numpy()
+ num_ins = mask_preds.shape[0]
+ # fake bboxes
+ bboxes = np.zeros((num_ins, 5), dtype=np.float32)
+ bboxes[:, -1] = cls_scores
+ bbox_result = [bboxes[det_labels == i, :] for i in range(num_classes)]
+ for idx in range(num_ins):
+ segm_result[det_labels[idx]].append(mask_preds[idx])
+ return bbox_result, segm_result
+
+ def get_seg_masks_tracking(self, masks_per_img, labels_per_img,
+ scores_per_img, ids_per_img, test_cfg,
+ img_meta):
+ num_ins = masks_per_img.shape[0]
+ # resize mask predictions back
+ seg_masks = self.rescale_masks(masks_per_img, img_meta)
+ seg_masks = seg_masks > test_cfg['mask_thr']
+ # fake bboxes
+ bboxes = torch.zeros((num_ins, 5), dtype=torch.float32)
+ bboxes[:, -1] = scores_per_img
+ tracks = outs2results(
+ bboxes=bboxes,
+ labels=labels_per_img,
+ masks=seg_masks,
+ ids=ids_per_img,
+ num_classes=self.num_classes,
+ )
+ return tracks['bbox_results'], tracks['mask_results']
diff --git a/modelscope/models/cv/video_instance_segmentation/track/mask_hungarian_assigner.py b/modelscope/models/cv/video_instance_segmentation/track/mask_hungarian_assigner.py
new file mode 100644
index 00000000..ab7a937b
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/track/mask_hungarian_assigner.py
@@ -0,0 +1,248 @@
+# The implementation is adopted from Video-K-Net,
+# made publicly available at https://github.com/lxtGH/Video-K-Net under the MIT license
+
+import numpy as np
+import torch
+from mmdet.core import AssignResult, BaseAssigner
+from mmdet.core.bbox.builder import BBOX_ASSIGNERS
+from mmdet.core.bbox.match_costs.builder import MATCH_COST, build_match_cost
+
+try:
+ from scipy.optimize import linear_sum_assignment
+except ImportError:
+ linear_sum_assignment = None
+
+
+@MATCH_COST.register_module()
+class MaskCost(object):
+ """MaskCost.
+
+ Args:
+ weight (int | float, optional): loss_weight
+ """
+
+ def __init__(self, weight=1., pred_act=False, act_mode='sigmoid'):
+ self.weight = weight
+ self.pred_act = pred_act
+ self.act_mode = act_mode
+
+ def __call__(self, cls_pred, target):
+ """
+ Args:
+            cls_pred (Tensor): Predicted masks, shape [num_query, H, W];
+                treated as logits and activated when `pred_act` is True.
+            target (Tensor): Ground-truth masks, shape [num_gt, H, W].
+
+        Returns:
+            torch.Tensor: mask cost matrix of shape [num_query, num_gt],
+                scaled by `weight`.
+ """
+ if self.pred_act and self.act_mode == 'sigmoid':
+ cls_pred = cls_pred.sigmoid()
+ elif self.pred_act:
+ cls_pred = cls_pred.softmax(dim=0)
+
+ _, H, W = target.shape
+ # flatten_cls_pred = cls_pred.view(num_proposals, -1)
+ # eingum is ~10 times faster than matmul
+ pos_cost = torch.einsum('nhw,mhw->nm', cls_pred, target)
+ neg_cost = torch.einsum('nhw,mhw->nm', 1 - cls_pred, 1 - target)
+ cls_cost = -(pos_cost + neg_cost) / (H * W)
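+        # negative, area-normalized agreement between each predicted mask and
+        # each ground-truth mask (both foreground and background overlap)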
+ return cls_cost * self.weight
+
+
+@BBOX_ASSIGNERS.register_module()
+class MaskHungarianAssignerVideo(BaseAssigner):
+ """Computes one-to-one matching between predictions and ground truth.
+
+ This class computes an assignment between the targets and the predictions
+    based on the costs. The costs are a weighted sum of three components:
+    classification cost, mask cost and dice cost. The targets don't include
+    the no_object, so generally there are more predictions than targets.
+    After the one-to-one matching, the un-matched are treated as backgrounds.
+    Thus each query prediction will be assigned with `0` or a positive
+    integer indicating the ground truth index:
+
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        cls_cost (dict, optional): Config of the classification cost.
+            Default dict(type='ClassificationCost', weight=1.).
+        mask_cost (dict, optional): Config of the mask cost.
+            Default dict(type='SigmoidCost', weight=1.0).
+        dice_cost (dict, optional): Config of the dice cost.
+            Default dict().
+        boundary_cost (dict, optional): Config of the boundary cost.
+            Default None.
+        topk (int, optional): Number of predictions matched to each ground
+            truth. Default 1.
+ """
+
+ def __init__(self,
+ cls_cost=dict(type='ClassificationCost', weight=1.),
+ mask_cost=dict(type='SigmoidCost', weight=1.0),
+ dice_cost=dict(),
+ boundary_cost=None,
+ topk=1):
+ self.cls_cost = build_match_cost(cls_cost)
+ self.mask_cost = build_match_cost(mask_cost)
+ self.dice_cost = build_match_cost(dice_cost)
+ if boundary_cost is not None:
+ self.boundary_cost = build_match_cost(boundary_cost)
+ else:
+ self.boundary_cost = None
+ self.topk = topk
+
+ def assign(self,
+ bbox_pred,
+ cls_pred,
+ gt_bboxes,
+ gt_labels,
+ gt_instance_ids,
+ img_meta=None,
+ gt_bboxes_ignore=None,
+ eps=1e-7):
+ """Computes one-to-one matching based on the weighted costs.
+
+        This method assigns each query prediction to a ground truth or
+ background. The `assigned_gt_inds` with -1 means don't care,
+ 0 means negative sample, and positive number is the index (1-based)
+ of assigned gt.
+ The assignment is done in the following steps, the order matters.
+
+ 1. assign every prediction to -1
+ 2. compute the weighted costs
+ 3. do Hungarian matching on CPU based on the costs
+ 4. assign all to 0 (background) first, then for each matched pair
+ between predictions and gts, treat this prediction as foreground
+ and assign the corresponding gt index (plus 1) to it.
+
+ Args:
+            bbox_pred (Tensor): Predicted masks, shape
+                [num_frames, num_query, H, W].
+            cls_pred (Tensor): Predicted classification logits, shape
+                [num_query, num_class].
+            gt_bboxes (list[Tensor]): Ground truth masks of each frame, each
+                with shape [num_gt_frame, H, W].
+            gt_labels (Tensor): Pairs of (frame_id, label), shape (N, 2).
+            gt_instance_ids (Tensor): Pairs of (frame_id, instance_id),
+                shape (N, 2).
+ img_meta (dict): Meta information for current image.
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`. Default None.
+ eps (int | float, optional): A value added to the denominator for
+ numerical stability. Default 1e-7.
+
+ Returns:
+ :obj:`AssignResult`: The assigned result.
+ """
+ assert gt_bboxes_ignore is None, \
+ 'Only case when gt_bboxes_ignore is None is supported.'
+ instances = torch.unique(gt_instance_ids[:, 1])
+ num_frames = bbox_pred.size(0)
+ h, w = bbox_pred.shape[-2:]
+ gt_masks = []
+ gt_labels_tensor = []
+ for instance_id in instances:
+            gt_instance_frame_ids = gt_instance_ids[
+                gt_instance_ids[:, 1] == instance_id, 0]
+ instance_masks = []
+ gt_label_id = None
+ for frame_id in range(num_frames):
+ gt_frame_instance_ids = gt_instance_ids[
+ gt_instance_ids[:, 0] == frame_id, 1]
+ gt_frame_label_ids = gt_labels[gt_labels[:, 0] == frame_id, 1]
+                assert len(gt_frame_label_ids) == len(gt_frame_instance_ids)
+ if not (frame_id in gt_instance_frame_ids):
+ gt_mask_frame = torch.zeros(
+ (h, w),
+ device=gt_instance_frame_ids.device,
+ dtype=torch.float)
+ else:
+ gt_index = torch.nonzero(
+ (gt_frame_instance_ids == instance_id),
+ as_tuple=True)[0].item()
+ gt_mask_frame = gt_bboxes[frame_id][gt_index]
+                    if gt_label_id is None:
+                        gt_label_id = gt_frame_label_ids[gt_index].item()
+ assert gt_label_id == gt_frame_label_ids[gt_index].item()
+ instance_masks.append(gt_mask_frame)
+ gt_masks.append(torch.stack(instance_masks))
+ gt_labels_tensor.append(gt_label_id)
+ gt_masks = torch.stack(gt_masks)
+ gt_labels_tensor = torch.tensor(
+ gt_labels_tensor, device=gt_masks.device, dtype=torch.long)
+
+ num_gts, num_bboxes = len(instances), bbox_pred.size(1)
+
+ # 1. assign -1 by default
+ assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ assigned_labels = bbox_pred.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ if num_gts == 0:
+ # No ground truth, assign all to background
+ assigned_gt_inds[:] = 0
+ return AssignResult(
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
+
+ # 2. compute the weighted costs
+ # classification and bboxcost.
+ pred_masks_match = torch.einsum('fqhw->qfhw', bbox_pred).reshape(
+ (num_bboxes, -1, w))
+ gt_masks_match = gt_masks.reshape((num_gts, -1, w))
+ if self.cls_cost.weight != 0 and cls_pred is not None:
+ cls_cost = self.cls_cost(cls_pred, gt_labels_tensor)
+ else:
+ cls_cost = 0
+ if self.mask_cost.weight != 0:
+ reg_cost = self.mask_cost(pred_masks_match, gt_masks_match)
+ else:
+ reg_cost = 0
+ if self.dice_cost.weight != 0:
+ dice_cost = self.dice_cost(pred_masks_match, gt_masks_match)
+ else:
+ dice_cost = 0
+ if self.boundary_cost is not None and self.boundary_cost.weight != 0:
+ b_cost = self.boundary_cost(pred_masks_match, gt_masks_match)
+ else:
+ b_cost = 0
+ cost = cls_cost + reg_cost + dice_cost + b_cost
+
+ # 3. do Hungarian matching on CPU using linear_sum_assignment
+ cost = cost.detach().cpu()
+ if linear_sum_assignment is None:
+ raise ImportError('Please run "pip install scipy" '
+ 'to install scipy first.')
+ if self.topk == 1:
+ matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
+ else:
+ topk_matched_row_inds = []
+ topk_matched_col_inds = []
+ for i in range(self.topk):
+ matched_row_inds, matched_col_inds = linear_sum_assignment(
+ cost)
+ topk_matched_row_inds.append(matched_row_inds)
+ topk_matched_col_inds.append(matched_col_inds)
+ cost[matched_row_inds] = 1e10
+ matched_row_inds = np.concatenate(topk_matched_row_inds)
+ matched_col_inds = np.concatenate(topk_matched_col_inds)
+
+ matched_row_inds = torch.from_numpy(matched_row_inds).to(
+ bbox_pred.device)
+ matched_col_inds = torch.from_numpy(matched_col_inds).to(
+ bbox_pred.device)
+
+ # 4. assign backgrounds and foregrounds
+ # assign all indices to backgrounds first
+ assigned_gt_inds[:] = 0
+ # assign foregrounds based on matching results
+ assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
+ assigned_labels[matched_row_inds] = gt_labels_tensor[matched_col_inds]
+ return AssignResult(
+ num_gts, assigned_gt_inds, None,
+ labels=assigned_labels), gt_masks_match
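+
+
+def _hungarian_assignment_example():
+    # Illustrative sketch (not part of the upstream Video-K-Net code) of the
+    # matching convention used by MaskHungarianAssignerVideo.assign above:
+    # build a (num_query, num_gt) cost matrix, run Hungarian matching on CPU,
+    # mark every query as background (0) first, then give each matched query
+    # the 1-based index of its ground truth. The cost values are made up.
+    cost = torch.tensor([[0.2, 0.9, 0.5],
+                         [0.8, 0.1, 0.7],
+                         [0.6, 0.4, 0.3],
+                         [0.9, 0.8, 0.2]])     # 4 queries, 3 ground truths
+    matched_rows, matched_cols = linear_sum_assignment(cost.numpy())
+    matched_rows = torch.from_numpy(matched_rows).long()
+    matched_cols = torch.from_numpy(matched_cols).long()
+    assigned_gt_inds = torch.zeros(cost.size(0), dtype=torch.long)  # 0 = background
+    assigned_gt_inds[matched_rows] = matched_cols + 1               # 1-based gt index
+    return assigned_gt_inds                    # tensor([1, 2, 0, 3]) for this cost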
diff --git a/modelscope/models/cv/video_instance_segmentation/utils.py b/modelscope/models/cv/video_instance_segmentation/utils.py
new file mode 100644
index 00000000..d91e923b
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/utils.py
@@ -0,0 +1,112 @@
+# The implementation is adopted from Video-K-Net,
+# made publicly available at https://github.com/lxtGH/Video-K-Net under the MIT license
+import numpy as np
+import torch
+from mmdet.core import bbox2result
+
+
+def sem2ins_masks(gt_sem_seg, num_thing_classes=80):
+ """Convert semantic segmentation mask to binary masks
+
+ Args:
+ gt_sem_seg (torch.Tensor): Semantic masks to be converted.
+ [0, num_thing_classes-1] is the classes of things,
+ [num_thing_classes:] is the classes of stuff.
+ num_thing_classes (int, optional): Number of thing classes.
+ Defaults to 80.
+
+ Returns:
+ tuple[torch.Tensor]: (mask_labels, bin_masks).
+ Mask labels and binary masks of stuff classes.
+ """
+ # gt_sem_seg is zero-started, where zero indicates the first class
+ # since mmdet>=2.17.0, see more discussion in
+ # https://mmdetection.readthedocs.io/en/latest/conventions.html#coco-panoptic-dataset # noqa
+ classes = torch.unique(gt_sem_seg)
+ # classes ranges from 0 - N-1, where the class IDs in
+ # [0, num_thing_classes - 1] are IDs of thing classes
+ masks = []
+ labels = []
+
+ for i in classes:
+ # skip ignore class 255 and "thing classes" in semantic seg
+ if i == 255 or i < num_thing_classes:
+ continue
+ labels.append(i)
+ masks.append(gt_sem_seg == i)
+
+ if len(labels) > 0:
+ labels = torch.stack(labels)
+ masks = torch.cat(masks)
+ else:
+ labels = gt_sem_seg.new_zeros(size=[0])
+ masks = gt_sem_seg.new_zeros(
+ size=[0, gt_sem_seg.shape[-2], gt_sem_seg.shape[-1]])
+ return labels.long(), masks.float()
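+
+
+def _sem2ins_masks_example():
+    # Illustrative sketch (not part of the upstream Video-K-Net code): with 80
+    # thing classes, only stuff IDs (>= 80) survive sem2ins_masks above, while
+    # thing pixels (class 3 here) and the ignore label 255 are skipped.
+    gt_sem_seg = torch.full((1, 4, 4), 255, dtype=torch.long)
+    gt_sem_seg[0, :2, :] = 3      # a thing class -> skipped
+    gt_sem_seg[0, 2, :] = 85      # a stuff class -> kept as one binary mask
+    labels, masks = sem2ins_masks(gt_sem_seg, num_thing_classes=80)
+    assert labels.tolist() == [85] and masks.shape == (1, 4, 4)
+    return labels, masks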
+
+
+def outs2results(bboxes=None,
+ labels=None,
+ masks=None,
+ ids=None,
+ num_classes=None,
+ **kwargs):
+ """Convert tracking/detection results to a list of numpy arrays.
+ Args:
+ bboxes (torch.Tensor | np.ndarray): shape (n, 5)
+ labels (torch.Tensor | np.ndarray): shape (n, )
+ masks (torch.Tensor | np.ndarray): shape (n, h, w)
+ ids (torch.Tensor | np.ndarray): shape (n, )
+ num_classes (int): class number, not including background class
+ Returns:
+ dict[str : list(ndarray) | list[list[np.ndarray]]]: tracking/detection
+        results of each class. It may contain the following keys:
+ - bbox_results (list[np.ndarray]): Each list denotes bboxes of one
+ category.
+ - mask_results (list[list[np.ndarray]]): Each outer list denotes masks
+ of one category. Each inner list denotes one mask belonging to
+ the category. Each mask has shape (h, w).
+ """
+ assert labels is not None
+ assert num_classes is not None
+
+ results = dict()
+
+ if ids is not None:
+ valid_inds = ids > -1
+ ids = ids[valid_inds]
+ labels = labels[valid_inds]
+
+ if bboxes is not None:
+ if ids is not None:
+ bboxes = bboxes[valid_inds]
+ if bboxes.shape[0] == 0:
+ bbox_results = [
+ np.zeros((0, 6), dtype=np.float32)
+ for i in range(num_classes)
+ ]
+ else:
+ if isinstance(bboxes, torch.Tensor):
+ bboxes = bboxes.cpu().numpy()
+ labels = labels.cpu().numpy()
+ ids = ids.cpu().numpy()
+ bbox_results = [
+ np.concatenate(
+ (ids[labels == i, None], bboxes[labels == i, :]),
+ axis=1) for i in range(num_classes)
+ ]
+ else:
+ bbox_results = bbox2result(bboxes, labels, num_classes)
+ results['bbox_results'] = bbox_results
+
+ if masks is not None:
+ if ids is not None:
+ masks = masks[valid_inds]
+ if isinstance(masks, torch.Tensor):
+ masks = masks.detach().cpu().numpy()
+ masks_results = [[] for _ in range(num_classes)]
+ for i in range(bboxes.shape[0]):
+ masks_results[labels[i]].append(masks[i])
+ results['mask_results'] = masks_results
+
+ return results
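+
+
+def _outs2results_example():
+    # Illustrative sketch (not part of the upstream Video-K-Net code): two
+    # tracked instances of class 0 and one of class 1, plus one entry with
+    # id == -1 that is filtered out. With ids given, each row of a
+    # bbox_results entry is (id, x1, y1, x2, y2, score).
+    bboxes = torch.tensor([[0., 0., 10., 10., 0.9],
+                           [5., 5., 20., 20., 0.8],
+                           [2., 2., 8., 8., 0.7],
+                           [1., 1., 3., 3., 0.1]])
+    labels = torch.tensor([0, 0, 1, 1])
+    ids = torch.tensor([1, 2, 3, -1])
+    masks = torch.zeros(4, 16, 16)
+    results = outs2results(
+        bboxes=bboxes, labels=labels, masks=masks, ids=ids, num_classes=2)
+    assert results['bbox_results'][0].shape == (2, 6)  # class 0: two tracks
+    assert len(results['mask_results'][1]) == 1        # class 1: one mask
+    return results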
diff --git a/modelscope/models/cv/video_instance_segmentation/video_knet.py b/modelscope/models/cv/video_instance_segmentation/video_knet.py
new file mode 100644
index 00000000..a412edba
--- /dev/null
+++ b/modelscope/models/cv/video_instance_segmentation/video_knet.py
@@ -0,0 +1,441 @@
+# The implementation is adopted from Video-K-Net,
+# made publicly available at https://github.com/lxtGH/Video-K-Net under the MIT license
+
+import torch.nn as nn
+from mmdet.models import build_head, build_neck
+
+from modelscope.metainfo import Models
+from modelscope.models.base import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.cv.video_panoptic_segmentation.backbone.swin_transformer import \
+ SwinTransformerDIY
+from modelscope.models.cv.video_panoptic_segmentation.head.semantic_fpn_wrapper import \
+ SemanticFPNWrapper
+from modelscope.utils.constant import Tasks
+from .head.kernel_frame_iter_head import KernelFrameIterHeadVideo
+from .head.kernel_head import ConvKernelHeadVideo
+from .head.kernel_iter_head import KernelIterHeadVideo
+from .head.kernel_update_head import KernelUpdateHead
+from .head.kernel_updator import KernelUpdator
+from .neck import MSDeformAttnPixelDecoder
+from .track.kernel_update_head import KernelUpdateHeadVideo
+from .track.mask_hungarian_assigner import MaskHungarianAssignerVideo
+
+
+@MODELS.register_module(
+ Tasks.video_instance_segmentation,
+ module_name=Models.video_instance_segmentation)
+class KNetTrack(TorchModel):
+ """
+ Video K-Net: A Simple, Strong, and Unified Baseline for Video Segmentation (https://arxiv.org/pdf/2204.04656.pdf)
+ Video K-Net is a strong and unified framework for fully end-to-end video panoptic and instance segmentation.
+ The method is built upon K-Net, a method that unifies image segmentation via a group of learnable kernels.
+ K-Net learns to simultaneously segment and track “things” and “stuff” in a video with simple kernel-based
+ appearance modeling and cross-temporal kernel interaction.
+ """
+
+ def __init__(self, model_dir: str, *args, **kwargs):
+ super().__init__(model_dir, *args, **kwargs)
+
+ self.roi_head = None
+ num_stages = 3
+ num_proposals = 100
+ conv_kernel_size = 1
+ num_thing_classes = 40
+ num_stuff_classes = 0
+ mask_assign_stride = 4
+ thing_label_in_seg = 0
+ direct_tracker = False
+ tracker_num = 1
+
+ # assert self.with_rpn, 'KNet does not support external proposals'
+ self.num_thing_classes = num_thing_classes
+ self.num_stuff_classes = num_stuff_classes
+ self.mask_assign_stride = mask_assign_stride
+ self.thing_label_in_seg = thing_label_in_seg
+ self.direct_tracker = direct_tracker
+ self.tracker_num = tracker_num
+
+ train_cfg = dict(
+ rpn=dict(
+ assigner=dict(
+ type='MaskHungarianAssigner',
+ cls_cost=dict(type='FocalLossCost', weight=2.0),
+ dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),
+ mask_cost=dict(type='MaskCost', weight=1.0,
+ pred_act=True)),
+ sampler=dict(type='MaskPseudoSampler'),
+ pos_weight=1),
+ rcnn=[
+ dict(
+ assigner=dict(
+ type='MaskHungarianAssigner',
+ cls_cost=dict(type='FocalLossCost', weight=2.0),
+ dice_cost=dict(
+ type='DiceCost', weight=4.0, pred_act=True),
+ mask_cost=dict(
+ type='MaskCost', weight=1.0, pred_act=True)),
+ sampler=dict(type='MaskPseudoSampler'),
+ pos_weight=1) for _ in range(num_stages)
+ ],
+ tracker=dict(
+ assigner=dict(
+ type='MaskHungarianAssignerVideo',
+ cls_cost=dict(type='FocalLossCost', weight=2.0),
+ dice_cost=dict(type='DiceCost', weight=4.0, pred_act=True),
+ mask_cost=dict(type='MaskCost', weight=1.0,
+ pred_act=True)),
+ sampler=dict(type='MaskPseudoSampler'),
+ pos_weight=1))
+ self.train_cfg = train_cfg
+
+ test_cfg = dict(
+ rpn=None,
+ rcnn=dict(
+ max_per_img=10,
+ mask_thr=0.5,
+ merge_stuff_thing=dict(
+ iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3)),
+ tracker=dict(
+ max_per_img=10,
+ mask_thr=0.5,
+ merge_stuff_thing=dict(
+ iou_thr=0.5, stuff_max_area=4096, instance_score_thr=0.3),
+ ))
+ self.test_cfg = test_cfg
+
+ self.backbone = SwinTransformerDIY(
+ embed_dims=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=7,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.3,
+ use_abs_pos_embed=False,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ with_cp=True)
+
+ neck = dict(
+ type='MSDeformAttnPixelDecoder',
+ in_channels=[128, 256, 512, 1024],
+ num_outs=3,
+ norm_cfg=dict(type='GN', num_groups=32),
+ act_cfg=dict(type='ReLU'),
+ return_one_list=True,
+ encoder=dict(
+ type='DetrTransformerEncoder',
+ num_layers=6,
+ transformerlayers=dict(
+ type='BaseTransformerLayer',
+ attn_cfgs=dict(
+ type='MultiScaleDeformableAttention',
+ embed_dims=256,
+ num_heads=8,
+ num_levels=3,
+ num_points=4,
+ im2col_step=64,
+ dropout=0.0,
+ batch_first=False,
+ norm_cfg=None,
+ init_cfg=None),
+ ffn_cfgs=dict(
+ type='FFN',
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.0,
+ act_cfg=dict(type='ReLU', inplace=True)),
+ operation_order=('self_attn', 'norm', 'ffn', 'norm')),
+ init_cfg=None),
+ positional_encoding=dict(
+ type='SinePositionalEncoding', num_feats=128, normalize=True),
+ init_cfg=None)
+ self.neck = build_neck(neck)
+
+ rpn_head = dict(
+ type='ConvKernelHeadVideo',
+ conv_kernel_size=conv_kernel_size,
+ feat_downsample_stride=2,
+ feat_refine_stride=1,
+ feat_refine=False,
+ use_binary=True,
+ num_loc_convs=1,
+ num_seg_convs=1,
+ conv_normal_init=True,
+ localization_fpn=dict(
+ type='SemanticFPNWrapper',
+ in_channels=256,
+ feat_channels=256,
+ out_channels=256,
+ start_level=0,
+ end_level=3,
+ upsample_times=2,
+ positional_encoding=dict(
+ type='SinePositionalEncoding',
+ num_feats=128,
+ normalize=True),
+ cat_coors=False,
+ cat_coors_level=3,
+ fuse_by_cat=False,
+ return_list=False,
+ num_aux_convs=1,
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
+ num_proposals=num_proposals,
+ proposal_feats_with_obj=True,
+ xavier_init_kernel=False,
+ kernel_init_std=1,
+ num_cls_fcs=1,
+ in_channels=256,
+ num_classes=40,
+ feat_transform_cfg=None,
+ loss_seg=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_mask=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+ loss_dice=dict(type='DiceLoss', loss_weight=4.0))
+
+ self.rpn_head = build_head(rpn_head)
+
+ roi_head = dict(
+ type='KernelIterHeadVideo',
+ num_stages=num_stages,
+ stage_loss_weights=[1] * num_stages,
+ proposal_feature_channel=256,
+ num_thing_classes=40,
+ num_stuff_classes=0,
+ mask_head=[
+ dict(
+ type='KernelUpdateHead',
+ num_classes=40,
+ num_thing_classes=40,
+ num_stuff_classes=0,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_cls_fcs=1,
+ num_mask_fcs=1,
+ feedforward_channels=2048,
+ in_channels=256,
+ out_channels=256,
+ dropout=0.0,
+ mask_thr=0.5,
+ conv_kernel_size=conv_kernel_size,
+ mask_upsample_stride=2,
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ with_ffn=True,
+ feat_transform_cfg=dict(
+ conv_cfg=dict(type='Conv2d'), act_cfg=None),
+ kernel_updator_cfg=dict(
+ type='KernelUpdator',
+ in_channels=256,
+ feat_channels=256,
+ out_channels=256,
+ input_feat_shape=3,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN')),
+ loss_mask=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_dice=dict(type='DiceLoss', loss_weight=4.0),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0)) for _ in range(num_stages)
+ ])
+ roi_head.update(test_cfg=self.test_cfg['rcnn'])
+ self.roi_head = build_head(roi_head)
+
+ tracker = dict(
+ type='KernelFrameIterHeadVideo',
+ num_proposals=num_proposals,
+ num_stages=3,
+ assign_stages=2,
+ proposal_feature_channel=256,
+ stage_loss_weights=(1., 1., 1.),
+ num_thing_classes=40,
+ num_stuff_classes=0,
+ mask_head=dict(
+ type='KernelUpdateHeadVideo',
+ num_proposals=num_proposals,
+ num_classes=40,
+ num_thing_classes=40,
+ num_stuff_classes=0,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_cls_fcs=1,
+ num_mask_fcs=1,
+ feedforward_channels=2048,
+ in_channels=256,
+ out_channels=256,
+ dropout=0.0,
+ mask_thr=0.5,
+ conv_kernel_size=conv_kernel_size,
+ mask_upsample_stride=2,
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ with_ffn=True,
+ feat_transform_cfg=dict(
+ conv_cfg=dict(type='Conv2d'), act_cfg=None),
+ kernel_updator_cfg=dict(
+ type='KernelUpdator',
+ in_channels=256,
+ feat_channels=256,
+ out_channels=256,
+ input_feat_shape=3,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN')),
+ loss_mask=dict(
+ type='CrossEntropyLoss', use_sigmoid=True,
+ loss_weight=1.0),
+ loss_dice=dict(type='DiceLoss', loss_weight=4.0),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0)))
+
+ if tracker is not None:
+ rcnn_train_cfg = train_cfg[
+ 'tracker'] if train_cfg is not None else None
+ tracker.update(train_cfg=rcnn_train_cfg)
+ tracker.update(test_cfg=test_cfg['tracker'])
+ self.tracker = build_head(tracker)
+ if self.tracker_num > 1:
+ self.tracker_extra = nn.ModuleList(
+ [build_head(tracker) for _ in range(tracker_num - 1)])
+
+ def extract_feat(self, img):
+ """Directly extract features from the backbone+neck."""
+ x = self.backbone(img)
+ x = self.neck(x)
+ return x
+
+ def forward(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ """
+ for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got {type(var)}')
+
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f'num of augmentations ({len(imgs)}) '
+ f'!= num of image meta ({len(img_metas)})')
+
+ # NOTE the batched image size information may be useful, e.g.
+ # in DETR, this is needed for the construction of masks, which is
+ # then used for the transformer_head.
+ for img, img_meta in zip(imgs, img_metas):
+ batch_size = len(img_meta)
+ for img_id in range(batch_size):
+ img_meta[img_id]['batch_input_shape'] = tuple(img.size()[-2:])
+
+ if num_augs == 1:
+ # proposals (List[List[Tensor]]): the outer list indicates
+ # test-time augs (multiscale, flip, etc.) and the inner list
+ # indicates images in a batch.
+ # The Tensor should have a shape Px4, where P is the number of
+ # proposals.
+ if 'proposals' in kwargs:
+ kwargs['proposals'] = kwargs['proposals'][0]
+ kwargs['ref_img_metas'] = kwargs['ref_img_metas'][0]
+ kwargs['ref_img'] = kwargs['ref_img'][0]
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
+ else:
+ assert imgs[0].size(0) == 1, 'aug test does not support ' \
+ 'inference with batch size ' \
+ f'{imgs[0].size(0)}'
+ # TODO: support test augmentation for predefined proposals
+ assert 'proposals' not in kwargs
+ return self.aug_test(imgs, img_metas, **kwargs)
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
+ x = self.extract_feats(imgs)
+ proposal_list = self.rpn_head.aug_test_rpn(x, img_metas)
+ return self.roi_head.aug_test(
+ x, proposal_list, img_metas, rescale=rescale)
+
+ def simple_test(self, imgs, img_metas, **kwargs):
+ ref_img = kwargs['ref_img']
+ ref_img_metas = kwargs['ref_img_metas']
+ # Step 1 extract features and get masks
+ bs, num_frame, _, h, w = ref_img.size()
+ x = self.extract_feat(ref_img.reshape(bs * num_frame, _, h, w))
+
+ proposal_feats, x_feats, mask_preds, cls_scores, seg_preds = \
+ self.rpn_head.simple_test_rpn(x, img_metas, ref_img_metas)
+
+ if self.roi_head is not None:
+ segm_results_single_frame, features = self.roi_head.simple_test(
+ x_feats,
+ proposal_feats,
+ mask_preds,
+ cls_scores,
+ img_metas,
+ ref_img_metas,
+ imgs_whwh=None,
+ rescale=True)
+
+ if self.direct_tracker:
+ proposal_feats = self.rpn_head.init_kernels.weight.clone()
+ proposal_feats = proposal_feats[None].expand(
+ bs, *proposal_feats.size())
+ if mask_preds.shape[0] == bs * num_frame:
+ mask_preds = mask_preds.reshape(
+ (bs, num_frame, *mask_preds.size()[1:]))
+ x_feats = x_feats.reshape((bs, num_frame, *x_feats.size()[1:]))
+ else:
+ assert mask_preds.size()[:2] == (bs, num_frame)
+ assert x_feats.size()[:2] == (bs, num_frame)
+ segm_results, features = self.tracker.simple_test(
+ x=x_feats,
+ img_metas=img_metas,
+ ref_img_metas=ref_img_metas,
+ cls_scores=None,
+ masks=mask_preds,
+ obj_feats=proposal_feats,
+ )
+ if self.tracker_num > 1:
+ for i in range(self.tracker_num - 1):
+ segm_results, features = self.tracker_extra[i].simple_test(
+ x=features['x_feats'],
+ img_metas=img_metas,
+ ref_img_metas=ref_img_metas,
+ cls_scores=None,
+ masks=features['masks'],
+ obj_feats=features['obj_feats'],
+ )
+ else:
+ segm_results, _ = self.tracker.simple_test(
+ x=features['x_feats'],
+ img_metas=img_metas,
+ ref_img_metas=ref_img_metas,
+ cls_scores=features['cls_scores'],
+ masks=features['masks'],
+ obj_feats=features['obj_feats'],
+ )
+
+ return segm_results
diff --git a/modelscope/models/cv/video_panoptic_segmentation/head/semantic_fpn_wrapper.py b/modelscope/models/cv/video_panoptic_segmentation/head/semantic_fpn_wrapper.py
index 0cf487b8..d772096e 100644
--- a/modelscope/models/cv/video_panoptic_segmentation/head/semantic_fpn_wrapper.py
+++ b/modelscope/models/cv/video_panoptic_segmentation/head/semantic_fpn_wrapper.py
@@ -5,8 +5,10 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init
from mmcv.cnn.bricks.transformer import build_positional_encoding
+from mmdet.models.builder import NECKS
+@NECKS.register_module()
class SemanticFPNWrapper(nn.Module):
"""
Implementation of Semantic FPN used in Panoptic FPN.
diff --git a/modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py b/modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py
index 702c84f1..4eaa40e7 100644
--- a/modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py
+++ b/modelscope/models/cv/video_single_object_tracking/models/layers/attn_blocks.py
@@ -38,6 +38,10 @@ def candidate_elimination(attn: torch.Tensor, tokens: torch.Tensor,
attn_t = attn[:, :, :lens_t, lens_t:]
if box_mask_z is not None:
+ if not isinstance(box_mask_z, list):
+ box_mask_z = [box_mask_z]
+ box_mask_z_cat = torch.stack(box_mask_z, dim=1)
+ box_mask_z = box_mask_z_cat.flatten(1)
box_mask_z = box_mask_z.unsqueeze(1).unsqueeze(-1).expand(
-1, attn_t.shape[1], -1, attn_t.shape[-1])
attn_t = attn_t[box_mask_z]
diff --git a/modelscope/models/cv/video_single_object_tracking/models/layers/head.py b/modelscope/models/cv/video_single_object_tracking/models/layers/head.py
index e0dc7b59..7d296929 100644
--- a/modelscope/models/cv/video_single_object_tracking/models/layers/head.py
+++ b/modelscope/models/cv/video_single_object_tracking/models/layers/head.py
@@ -55,18 +55,17 @@ class CenterPredictor(
if p.dim() > 1:
nn.init.xavier_uniform_(p)
- def forward(self, x, gt_score_map=None):
+ def forward(self, x, return_score=False):
""" Forward pass with input x. """
score_map_ctr, size_map, offset_map = self.get_score_map(x)
- # assert gt_score_map is None
- if gt_score_map is None:
- bbox = self.cal_bbox(score_map_ctr, size_map, offset_map)
+ if return_score:
+ bbox, max_score = self.cal_bbox(
+ score_map_ctr, size_map, offset_map, return_score=True)
+ return score_map_ctr, bbox, size_map, offset_map, max_score
else:
- bbox = self.cal_bbox(
- gt_score_map.unsqueeze(1), size_map, offset_map)
-
- return score_map_ctr, bbox, size_map, offset_map
+ bbox = self.cal_bbox(score_map_ctr, size_map, offset_map)
+ return score_map_ctr, bbox, size_map, offset_map
def cal_bbox(self,
score_map_ctr,
diff --git a/modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py b/modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py
index 52704a6c..cd560252 100644
--- a/modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py
+++ b/modelscope/models/cv/video_single_object_tracking/models/ostrack/ostrack.py
@@ -49,13 +49,13 @@ class OSTrack(nn.Module):
feat_last = x
if isinstance(x, list):
feat_last = x[-1]
- out = self.forward_head(feat_last, None)
+ out = self.forward_head(feat_last)
out.update(aux_dict)
out['backbone_feat'] = x
return out
- def forward_head(self, cat_feature, gt_score_map=None):
+ def forward_head(self, cat_feature):
"""
cat_feature: output embeddings of the backbone, it can be (HW1+HW2, B, C) or (HW2, B, C)
"""
@@ -67,8 +67,7 @@ class OSTrack(nn.Module):
if self.head_type == 'CENTER':
# run the center head
- score_map_ctr, bbox, size_map, offset_map = self.box_head(
- opt_feat, gt_score_map)
+ score_map_ctr, bbox, size_map, offset_map = self.box_head(opt_feat)
outputs_coord = bbox
outputs_coord_new = outputs_coord.view(bs, Nq, 4)
out = {
diff --git a/modelscope/models/audio/tts/kantts/train/__init__.py b/modelscope/models/cv/video_single_object_tracking/models/procontext/__init__.py
similarity index 100%
rename from modelscope/models/audio/tts/kantts/train/__init__.py
rename to modelscope/models/cv/video_single_object_tracking/models/procontext/__init__.py
diff --git a/modelscope/models/cv/video_single_object_tracking/models/procontext/procontext.py b/modelscope/models/cv/video_single_object_tracking/models/procontext/procontext.py
new file mode 100644
index 00000000..adb18ae4
--- /dev/null
+++ b/modelscope/models/cv/video_single_object_tracking/models/procontext/procontext.py
@@ -0,0 +1,110 @@
+# The ProContEXT implementation is also open-sourced by the authors,
+# and available at https://github.com/jp-lan/ProContEXT
+import torch
+from torch import nn
+
+from modelscope.models.cv.video_single_object_tracking.models.layers.head import \
+ build_box_head
+from .vit_ce import vit_base_patch16_224_ce
+
+
+class ProContEXT(nn.Module):
+ """ This is the base class for ProContEXT """
+
+ def __init__(self,
+ transformer,
+ box_head,
+ aux_loss=False,
+ head_type='CORNER'):
+ """ Initializes the model.
+ Parameters:
+ transformer: torch module of the transformer architecture.
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+ """
+ super().__init__()
+ self.backbone = transformer
+ self.box_head = box_head
+
+ self.aux_loss = aux_loss
+ self.head_type = head_type
+ if head_type == 'CORNER' or head_type == 'CENTER':
+ self.feat_sz_s = int(box_head.feat_sz)
+ self.feat_len_s = int(box_head.feat_sz**2)
+
+ def forward(
+ self,
+ template: torch.Tensor,
+ search: torch.Tensor,
+ ce_template_mask=None,
+ ce_keep_rate=None,
+ ):
+ x, aux_dict = self.backbone(
+ z=template,
+ x=search,
+ ce_template_mask=ce_template_mask,
+ ce_keep_rate=ce_keep_rate,
+ )
+
+ # Forward head
+ feat_last = x
+ if isinstance(x, list):
+ feat_last = x[-1]
+ out = self.forward_head(feat_last, None)
+
+ out.update(aux_dict)
+ out['backbone_feat'] = x
+ return out
+
+ def forward_head(self, cat_feature, gt_score_map=None):
+ """
+ cat_feature: output embeddings of the backbone, it can be (HW1+HW2, B, C) or (HW2, B, C)
+ """
+        # encoder output for the search region (B, HW, C)
+        enc_opt = cat_feature[:, -self.feat_len_s:]
+ opt = (enc_opt.unsqueeze(-1)).permute((0, 3, 2, 1)).contiguous()
+ bs, Nq, C, HW = opt.size()
+ opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s)
+
+ if self.head_type == 'CENTER':
+ # run the center head
+ score_map_ctr, bbox, size_map, offset_map, score = self.box_head(
+ opt_feat, return_score=True)
+ outputs_coord = bbox
+ outputs_coord_new = outputs_coord.view(bs, Nq, 4)
+ out = {
+ 'pred_boxes': outputs_coord_new,
+ 'score_map': score_map_ctr,
+ 'size_map': size_map,
+ 'offset_map': offset_map,
+ 'score': score
+ }
+ return out
+ else:
+ raise NotImplementedError
+
+
+def build_procontext(cfg):
+ if cfg.MODEL.BACKBONE.TYPE == 'vit_base_patch16_224_ce':
+ backbone = vit_base_patch16_224_ce(
+ False,
+ drop_path_rate=cfg.MODEL.BACKBONE.DROP_PATH_RATE,
+ ce_loc=cfg.MODEL.BACKBONE.CE_LOC,
+ ce_keep_ratio=cfg.MODEL.BACKBONE.CE_KEEP_RATIO,
+ )
+ hidden_dim = backbone.embed_dim
+ patch_start_index = 1
+ else:
+ raise NotImplementedError
+
+ backbone.finetune_track(cfg=cfg, patch_start_index=patch_start_index)
+
+ box_head = build_box_head(cfg, hidden_dim)
+
+ model = ProContEXT(
+ backbone,
+ box_head,
+ aux_loss=False,
+ head_type=cfg.MODEL.HEAD.TYPE,
+ )
+
+ return model
diff --git a/modelscope/models/cv/video_single_object_tracking/models/procontext/utils.py b/modelscope/models/cv/video_single_object_tracking/models/procontext/utils.py
new file mode 100644
index 00000000..b29019cf
--- /dev/null
+++ b/modelscope/models/cv/video_single_object_tracking/models/procontext/utils.py
@@ -0,0 +1,22 @@
+# The ProContEXT implementation is also open-sourced by the authors,
+# and available at https://github.com/jp-lan/ProContEXT
+import torch
+
+
+def combine_multi_tokens(template_tokens, search_tokens, mode='direct'):
+ if mode == 'direct':
+ if not isinstance(template_tokens, list):
+ merged_feature = torch.cat((template_tokens, search_tokens), dim=1)
+ elif len(template_tokens) >= 2:
+ merged_feature = torch.cat(
+ (template_tokens[0], template_tokens[1]), dim=1)
+ for i in range(2, len(template_tokens)):
+ merged_feature = torch.cat(
+ (merged_feature, template_tokens[i]), dim=1)
+ merged_feature = torch.cat((merged_feature, search_tokens), dim=1)
+ else:
+            merged_feature = torch.cat(
+                (template_tokens[0], search_tokens), dim=1)
+ else:
+ raise NotImplementedError
+ return merged_feature
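+
+
+def _combine_multi_tokens_example():
+    # Illustrative sketch (not part of the upstream ProContEXT code): two
+    # template token sequences and one search sequence are concatenated along
+    # the token dimension, so the merged length is the sum of the inputs.
+    # The token counts and embedding size are made-up example values.
+    templates = [torch.rand(2, 64, 768), torch.rand(2, 64, 768)]
+    search = torch.rand(2, 256, 768)
+    merged = combine_multi_tokens(templates, search, mode='direct')
+    assert merged.shape == (2, 64 + 64 + 256, 768)
+    return merged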
diff --git a/modelscope/models/cv/video_single_object_tracking/models/procontext/vit_ce.py b/modelscope/models/cv/video_single_object_tracking/models/procontext/vit_ce.py
new file mode 100644
index 00000000..bd580228
--- /dev/null
+++ b/modelscope/models/cv/video_single_object_tracking/models/procontext/vit_ce.py
@@ -0,0 +1,128 @@
+# The ProContEXT implementation is also open-sourced by the authors,
+# and available at https://github.com/jp-lan/ProContEXT
+from functools import partial
+
+import torch
+import torch.nn as nn
+from timm.models.layers import to_2tuple
+
+from modelscope.models.cv.video_single_object_tracking.models.layers.attn_blocks import \
+ CEBlock
+from modelscope.models.cv.video_single_object_tracking.models.layers.patch_embed import \
+ PatchEmbed
+from modelscope.models.cv.video_single_object_tracking.models.ostrack.utils import (
+ combine_tokens, recover_tokens)
+from modelscope.models.cv.video_single_object_tracking.models.ostrack.vit_ce import \
+ VisionTransformerCE
+from .utils import combine_multi_tokens
+
+
+class VisionTransformerCE_ProContEXT(VisionTransformerCE):
+ """ Vision Transformer with candidate elimination (CE) module
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+ - https://arxiv.org/abs/2010.11929
+
+ Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
+ - https://arxiv.org/abs/2012.12877
+ """
+
+ def forward_features(
+ self,
+ z,
+ x,
+ mask_x=None,
+ ce_template_mask=None,
+ ce_keep_rate=None,
+ ):
+ B = x.shape[0]
+
+ x = self.patch_embed(x)
+ x += self.pos_embed_x
+ if not isinstance(z, list):
+ z = self.patch_embed(z)
+ z += self.pos_embed_z
+ lens_z = self.pos_embed_z.shape[1]
+ x = combine_tokens(z, x, mode=self.cat_mode)
+ else:
+ z_list = []
+ for zi in z:
+ z_list.append(self.patch_embed(zi) + self.pos_embed_z)
+ lens_z = self.pos_embed_z.shape[1] * len(z_list)
+ x = combine_multi_tokens(z_list, x, mode=self.cat_mode)
+
+ x = self.pos_drop(x)
+
+ lens_x = self.pos_embed_x.shape[1]
+
+ global_index_t = torch.linspace(0, lens_z - 1, lens_z).to(x.device)
+ global_index_t = global_index_t.repeat(B, 1)
+
+ global_index_s = torch.linspace(0, lens_x - 1, lens_x).to(x.device)
+ global_index_s = global_index_s.repeat(B, 1)
+ removed_indexes_s = []
+ for i, blk in enumerate(self.blocks):
+ x, global_index_t, global_index_s, removed_index_s, attn = \
+ blk(x, global_index_t, global_index_s, mask_x, ce_template_mask, ce_keep_rate)
+
+ if self.ce_loc is not None and i in self.ce_loc:
+ removed_indexes_s.append(removed_index_s)
+
+ x = self.norm(x)
+ lens_x_new = global_index_s.shape[1]
+ lens_z_new = global_index_t.shape[1]
+
+ z = x[:, :lens_z_new]
+ x = x[:, lens_z_new:]
+
+ if removed_indexes_s and removed_indexes_s[0] is not None:
+ removed_indexes_cat = torch.cat(removed_indexes_s, dim=1)
+
+ pruned_lens_x = lens_x - lens_x_new
+ pad_x = torch.zeros([B, pruned_lens_x, x.shape[2]],
+ device=x.device)
+ x = torch.cat([x, pad_x], dim=1)
+ index_all = torch.cat([global_index_s, removed_indexes_cat], dim=1)
+ # recover original token order
+ C = x.shape[-1]
+ x = torch.zeros_like(x).scatter_(
+ dim=1,
+ index=index_all.unsqueeze(-1).expand(B, -1, C).to(torch.int64),
+ src=x)
+
+ x = recover_tokens(x, mode=self.cat_mode)
+
+ # re-concatenate with the template, which may be further used by other modules
+ x = torch.cat([z, x], dim=1)
+
+ aux_dict = {
+ 'attn': attn,
+ 'removed_indexes_s': removed_indexes_s, # used for visualization
+ }
+
+ return x, aux_dict
+
+ def forward(self, z, x, ce_template_mask=None, ce_keep_rate=None):
+
+ x, aux_dict = self.forward_features(
+ z,
+ x,
+ ce_template_mask=ce_template_mask,
+ ce_keep_rate=ce_keep_rate,
+ )
+
+ return x, aux_dict
+
+
+def _create_vision_transformer(pretrained=False, **kwargs):
+ model = VisionTransformerCE_ProContEXT(**kwargs)
+ return model
+
+
+def vit_base_patch16_224_ce(pretrained=False, **kwargs):
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer(pretrained=pretrained, **model_kwargs)
+ return model
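+
+
+def _token_recovery_example():
+    # Illustrative sketch (not part of the upstream ProContEXT code) of the
+    # scatter-based token recovery used in forward_features above: search
+    # tokens kept after candidate elimination are padded with zeros and
+    # scattered back to their original positions, so downstream code sees a
+    # full-length, correctly ordered sequence. The sizes are made-up values.
+    B, C = 1, 4
+    kept_idx = torch.tensor([[0, 2, 5]])       # indices that survived CE
+    removed_idx = torch.tensor([[1, 3, 4]])    # indices that were pruned
+    kept_tokens = torch.rand(B, kept_idx.shape[1], C)
+    pad = torch.zeros(B, removed_idx.shape[1], C)
+    x = torch.cat([kept_tokens, pad], dim=1)   # (B, lens_x, C)
+    index_all = torch.cat([kept_idx, removed_idx], dim=1)
+    recovered = torch.zeros_like(x).scatter_(
+        dim=1,
+        index=index_all.unsqueeze(-1).expand(B, -1, C).to(torch.int64),
+        src=x)
+    assert torch.allclose(recovered[:, kept_idx[0]], kept_tokens)
+    return recovered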
diff --git a/modelscope/models/cv/video_single_object_tracking/tracker/__init__.py b/modelscope/models/cv/video_single_object_tracking/tracker/__init__.py
index e69de29b..82cc97e0 100644
--- a/modelscope/models/cv/video_single_object_tracking/tracker/__init__.py
+++ b/modelscope/models/cv/video_single_object_tracking/tracker/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .ostrack import OSTrack
+from .procontext import ProContEXT
diff --git a/modelscope/models/cv/video_single_object_tracking/tracker/procontext.py b/modelscope/models/cv/video_single_object_tracking/tracker/procontext.py
new file mode 100644
index 00000000..6a8fdfcc
--- /dev/null
+++ b/modelscope/models/cv/video_single_object_tracking/tracker/procontext.py
@@ -0,0 +1,174 @@
+# The ProContEXT implementation is also open-sourced by the authors,
+# and available at https://github.com/jp-lan/ProContEXT
+from copy import deepcopy
+
+import torch
+
+from modelscope.models.cv.video_single_object_tracking.models.procontext.procontext import \
+ build_procontext
+from modelscope.models.cv.video_single_object_tracking.utils.utils import (
+ Preprocessor, clip_box, generate_mask_cond, hann2d, sample_target,
+ transform_image_to_crop)
+
+
+class ProContEXT():
+
+ def __init__(self, ckpt_path, device, cfg):
+ network = build_procontext(cfg)
+ network.load_state_dict(
+ torch.load(ckpt_path, map_location='cpu')['net'], strict=True)
+ self.cfg = cfg
+ if device.type == 'cuda':
+ self.network = network.to(device)
+ else:
+ self.network = network
+ self.network.eval()
+ self.preprocessor = Preprocessor(device)
+ self.state = None
+
+ self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE
+        # motion constraint
+ if device.type == 'cuda':
+ self.output_window = hann2d(
+ torch.tensor([self.feat_sz, self.feat_sz]).long(),
+ centered=True).to(device)
+ else:
+ self.output_window = hann2d(
+ torch.tensor([self.feat_sz, self.feat_sz]).long(),
+ centered=True)
+ self.frame_id = 0
+        # for saving boxes from all queries
+ self.z_dict1 = {}
+ self.z_dict_list = []
+ self.update_intervals = [100]
+
+ def initialize(self, image, info: dict):
+ # crop templates
+ crop_resize_patches = [
+ sample_target(
+ image,
+ info['init_bbox'],
+ factor,
+ output_sz=self.cfg.TEST.TEMPLATE_SIZE)
+ for factor in self.cfg.TEST.TEMPLATE_FACTOR
+ ]
+ z_patch_arr, resize_factor, z_amask_arr = zip(*crop_resize_patches)
+ for idx in range(len(z_patch_arr)):
+ template = self.preprocessor.process(z_patch_arr[idx],
+ z_amask_arr[idx])
+ with torch.no_grad():
+ self.z_dict1 = template
+ self.z_dict_list.append(self.z_dict1)
+ self.box_mask_z = []
+ if self.cfg.MODEL.BACKBONE.CE_LOC:
+ for i in range(len(self.cfg.TEST.TEMPLATE_FACTOR) * 2):
+ template_bbox = self.transform_bbox_to_crop(
+ info['init_bbox'], resize_factor[0],
+ template.tensors.device).squeeze(1)
+ self.box_mask_z.append(
+ generate_mask_cond(self.cfg, 1, template.tensors.device,
+ template_bbox))
+
+ # init dynamic templates with static templates
+ for idx in range(len(self.cfg.TEST.TEMPLATE_FACTOR)):
+ self.z_dict_list.append(deepcopy(self.z_dict_list[idx]))
+
+ # save states
+ self.state = info['init_bbox']
+ self.frame_id = 0
+
+ def track(self, image, info: dict = None):
+ H, W, _ = image.shape
+ self.frame_id += 1
+ x_patch_arr, resize_factor, x_amask_arr = sample_target(
+ image,
+ self.state,
+ self.cfg.TEST.SEARCH_FACTOR,
+ output_sz=self.cfg.TEST.SEARCH_SIZE) # (x1, y1, w, h)
+ search = self.preprocessor.process(x_patch_arr, x_amask_arr)
+
+ with torch.no_grad():
+ x_dict = search
+ # merge the template and the search
+ # run the transformer
+ if isinstance(self.z_dict_list, (list, tuple)):
+ self.z_dict = []
+ for i in range(len(self.cfg.TEST.TEMPLATE_FACTOR) * 2):
+ self.z_dict.append(self.z_dict_list[i].tensors)
+ out_dict = self.network.forward(
+ template=self.z_dict,
+ search=x_dict.tensors,
+ ce_template_mask=self.box_mask_z)
+
+ # add hann windows
+ pred_score_map = out_dict['score_map']
+ conf_score = out_dict['score']
+ response = self.output_window * pred_score_map
+ pred_boxes = self.network.box_head.cal_bbox(response,
+ out_dict['size_map'],
+ out_dict['offset_map'])
+ pred_boxes = pred_boxes.view(-1, 4)
+ # Baseline: Take the mean of all pred boxes as the final result
+ pred_box = (pred_boxes.mean(dim=0) * self.cfg.TEST.SEARCH_SIZE
+ / resize_factor).tolist() # (cx, cy, w, h) [0,1]
+ # get the final box result
+ self.state = clip_box(
+ self.map_box_back(pred_box, resize_factor), H, W, margin=10)
+
+ for idx, update_i in enumerate(self.update_intervals):
+ if self.frame_id % update_i == 0 and conf_score > 0.7:
+ crop_resize_patches2 = [
+ sample_target(
+ image,
+ self.state,
+ factor,
+ output_sz=self.cfg.TEST.TEMPLATE_SIZE)
+ for factor in self.cfg.TEST.TEMPLATE_FACTOR
+ ]
+ z_patch_arr2, _, z_amask_arr2 = zip(*crop_resize_patches2)
+ for idx_s in range(len(z_patch_arr2)):
+ template_t = self.preprocessor.process(
+ z_patch_arr2[idx_s], z_amask_arr2[idx_s])
+ self.z_dict_list[
+ idx_s
+ + len(self.cfg.TEST.TEMPLATE_FACTOR)] = template_t
+
+ x1, y1, w, h = self.state
+ x2 = x1 + w
+ y2 = y1 + h
+ return {'target_bbox': [x1, y1, x2, y2]}
+
+ def map_box_back(self, pred_box: list, resize_factor: float):
+ cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[
+ 1] + 0.5 * self.state[3]
+ cx, cy, w, h = pred_box
+ half_side = 0.5 * self.cfg.TEST.SEARCH_SIZE / resize_factor
+ cx_real = cx + (cx_prev - half_side)
+ cy_real = cy + (cy_prev - half_side)
+ return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]
+
+ def transform_bbox_to_crop(self,
+ box_in,
+ resize_factor,
+ device,
+ box_extract=None,
+ crop_type='template'):
+ if crop_type == 'template':
+ crop_sz = torch.Tensor(
+ [self.cfg.TEST.TEMPLATE_SIZE, self.cfg.TEST.TEMPLATE_SIZE])
+ elif crop_type == 'search':
+ crop_sz = torch.Tensor(
+ [self.cfg.TEST.SEARCH_SIZE, self.cfg.TEST.SEARCH_SIZE])
+ else:
+ raise NotImplementedError
+
+ box_in = torch.tensor(box_in)
+ if box_extract is None:
+ box_extract = box_in
+ else:
+ box_extract = torch.tensor(box_extract)
+ template_bbox = transform_image_to_crop(
+ box_in, box_extract, resize_factor, crop_sz, normalize=True)
+ template_bbox = template_bbox.view(1, 1, 4).to(device)
+
+ return template_bbox
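+
+
+def _map_box_back_example():
+    # Illustrative sketch (not part of the upstream ProContEXT code) of the
+    # coordinate mapping done by ProContEXT.map_box_back above. The search
+    # crop is centered on the previous target center, so a predicted center
+    # given in crop coordinates (already divided by the resize factor) is
+    # shifted by (previous_center - half_crop_side) to return to full-image
+    # coordinates. All numbers are made up for illustration.
+    prev_state = [100.0, 50.0, 40.0, 30.0]    # previous (x1, y1, w, h)
+    search_size, resize_factor = 384, 2.0
+    pred_cx, pred_cy, pred_w, pred_h = 96.0, 96.0, 20.0, 16.0  # crop coords
+    cx_prev = prev_state[0] + 0.5 * prev_state[2]              # 120.0
+    cy_prev = prev_state[1] + 0.5 * prev_state[3]              # 65.0
+    half_side = 0.5 * search_size / resize_factor              # 96.0
+    cx_real = pred_cx + (cx_prev - half_side)                  # 120.0
+    cy_real = pred_cy + (cy_prev - half_side)                  # 65.0
+    return [cx_real - 0.5 * pred_w, cy_real - 0.5 * pred_h, pred_w, pred_h]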
diff --git a/modelscope/models/audio/tts/kantts/utils/__init__.py b/modelscope/models/cv/video_streaming_perception/__init__.py
similarity index 100%
rename from modelscope/models/audio/tts/kantts/utils/__init__.py
rename to modelscope/models/cv/video_streaming_perception/__init__.py
diff --git a/modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py b/modelscope/models/cv/video_streaming_perception/longshortnet/__init__.py
similarity index 75%
rename from modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py
rename to modelscope/models/cv/video_streaming_perception/longshortnet/__init__.py
index 5c1c72c1..b938b734 100644
--- a/modelscope/msdatasets/cv/hand_2d_keypoints/__init__.py
+++ b/modelscope/models/cv/video_streaming_perception/longshortnet/__init__.py
@@ -4,11 +4,11 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
- from .hand_2d_keypoints_dataset import Hand2DKeypointDataset
+ from .longshortnet import LongShortNet
else:
_import_structure = {
- 'hand_2d_keypoints_dataset': ['Hand2DKeypointDataset']
+ 'longshortnet': ['LongShortNet'],
}
import sys
diff --git a/tests/pipelines/adaseq_pipelines/__init__.py b/modelscope/models/cv/video_streaming_perception/longshortnet/exp/__init__.py
similarity index 100%
rename from tests/pipelines/adaseq_pipelines/__init__.py
rename to modelscope/models/cv/video_streaming_perception/longshortnet/exp/__init__.py
diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/exp/longshortnet_base.py b/modelscope/models/cv/video_streaming_perception/longshortnet/exp/longshortnet_base.py
new file mode 100644
index 00000000..620cbaad
--- /dev/null
+++ b/modelscope/models/cv/video_streaming_perception/longshortnet/exp/longshortnet_base.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2014-2021 Megvii Inc.
+# Copyright (c) 2022-2023 Alibaba, Inc. and its affiliates. All rights reserved.
+
+from modelscope.models.cv.stream_yolo.exp.yolox_base import Exp
+
+
+class LongShortNetExp(Exp):
+
+ def __init__(self):
+ super(Exp, self).__init__()
+ self.depth = 1.0
+ self.width = 1.0
+ self.num_classes = 8
+ self.test_size = (600, 960)
+ self.test_conf = 0.3
+ self.nmsthre = 0.65
+ self.short_cfg = dict()
+ self.long_cfg = dict()
+ self.merge_cfg = dict()
+
+ def get_model(self):
+ from ..models.longshort import LONGSHORT
+ from ..models.dfp_pafpn_long import DFPPAFPNLONG
+ from ..models.dfp_pafpn_short import DFPPAFPNSHORT
+ from ..models.longshort_backbone_neck import BACKBONENECK
+ from modelscope.models.cv.stream_yolo.models.tal_head import TALHead
+ import torch.nn as nn
+
+ if getattr(self, 'model', None) is None:
+ in_channels = [256, 512, 1024]
+ long_backbone = (
+ DFPPAFPNLONG(
+ self.depth,
+ self.width,
+ in_channels=in_channels,
+ frame_num=self.long_cfg['frame_num'],
+ with_short_cut=self.long_cfg['with_short_cut'],
+ out_channels=self.long_cfg['out_channels'])
+ if self.long_cfg['frame_num'] != 0 else None)
+ short_backbone = DFPPAFPNSHORT(
+ self.depth,
+ self.width,
+ in_channels=in_channels,
+ frame_num=self.short_cfg['frame_num'],
+ with_short_cut=self.short_cfg['with_short_cut'],
+ out_channels=self.short_cfg['out_channels'])
+ backbone_neck = BACKBONENECK(
+ self.depth, self.width, in_channels=in_channels)
+ head = TALHead(
+ self.num_classes,
+ self.width,
+ in_channels=in_channels,
+ gamma=1.0,
+ ignore_thr=0.5,
+ ignore_value=1.5)
+ self.model = LONGSHORT(
+ long_backbone,
+ short_backbone,
+ backbone_neck,
+ head,
+ merge_form=self.merge_cfg['merge_form'],
+ in_channels=in_channels,
+ width=self.width,
+ with_short_cut=self.merge_cfg['with_short_cut'],
+ long_cfg=self.long_cfg)
+ return self.model
diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/longshortnet.py b/modelscope/models/cv/video_streaming_perception/longshortnet/longshortnet.py
new file mode 100644
index 00000000..ca35c4f2
--- /dev/null
+++ b/modelscope/models/cv/video_streaming_perception/longshortnet/longshortnet.py
@@ -0,0 +1,193 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
+import logging as logger
+import os
+import os.path as osp
+import time
+
+import cv2
+import json
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from modelscope.metainfo import Models
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.cv.stream_yolo.data.data_augment import ValTransform
+from modelscope.models.cv.stream_yolo.utils import (postprocess,
+ timestamp_format)
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from .exp.longshortnet_base import LongShortNetExp
+
+
+@MODELS.register_module(
+ group_key=Tasks.video_object_detection, module_name=Models.longshortnet)
+class LongShortNet(TorchModel):
+
+ def __init__(self, model_dir: str, *args, **kwargs):
+ super().__init__(model_dir, *args, **kwargs)
+ self.depth = kwargs.get('depth', 0.33)
+ self.width = kwargs.get('width', 0.50)
+ self.num_classes = kwargs.get('num_classes', 8)
+ self.test_size = kwargs.get('test_size', (960, 600))
+ self.test_conf = kwargs.get('test_conf', 0.3)
+ self.nmsthre = kwargs.get('nmsthre', 0.55)
+ self.label_mapping = kwargs.get('labels', [
+ 'person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck',
+ 'traffic light', 'stop sign'
+ ])
+ self.model_name = kwargs.get('model_name', 'longshortnet_s.pt')
+ self.short_cfg = kwargs.get(
+ 'short_cfg',
+ dict(
+ frame_num=1,
+ delta=1,
+ with_short_cut=False,
+ out_channels=[
+ ((64, 128, 256), 1),
+ ],
+ ))
+ self.long_cfg = kwargs.get(
+ 'long_cfg',
+ dict(
+ frame_num=3,
+ delta=1,
+ with_short_cut=False,
+ include_current_frame=False,
+ out_channels=[
+ ((21, 42, 85), 3),
+ ],
+ ))
+ self.merge_cfg = kwargs.get(
+ 'merge_cfg', dict(
+ merge_form='long_fusion',
+ with_short_cut=True,
+ ))
+
+ self.exp = LongShortNetExp()
+
+ self.exp.depth = self.depth
+ self.exp.width = self.width
+ self.exp.num_classes = self.num_classes
+ self.exp.test_size = self.test_size
+ self.exp.test_conf = self.test_conf
+ self.exp.nmsthre = self.nmsthre
+ self.exp.short_cfg = self.short_cfg
+ self.exp.long_cfg = self.long_cfg
+ self.exp.merge_cfg = self.merge_cfg
+
+ # build model
+ self.model = self.exp.get_model()
+ model_path = osp.join(model_dir, self.model_name)
+ ckpt = torch.load(model_path, map_location='cpu')
+ self.model.load_state_dict(ckpt['model'])
+ self.preproc = ValTransform(legacy=False)
+
+ def forward(self, inputs):
+ return self.inference_video(inputs)
+
+ def postprocess(self, input):
+ outputs = postprocess(
+ input,
+ self.num_classes,
+ self.test_conf,
+ self.nmsthre,
+ class_agnostic=True)
+
+ if len(outputs) == 1 and (outputs[0] is not None):
+ bboxes = outputs[0][:, 0:4].cpu().numpy() / self.resize_ratio
+ scores = outputs[0][:, 5].cpu().numpy()
+ labels = outputs[0][:, 6].cpu().int().numpy()
+ pred_label_names = []
+ for lab in labels:
+ pred_label_names.append(self.label_mapping[lab])
+ else:
+ bboxes = np.asarray([])
+ scores = np.asarray([])
+ pred_label_names = np.asarray([])
+
+ return bboxes, scores, pred_label_names
+
+ def inference_video(self, v_path):
+ outputs = []
+ capture = cv2.VideoCapture(v_path)
+ self.fps = capture.get(cv2.CAP_PROP_FPS)
+ self.ori_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ self.ori_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
+ self.ori_size = (self.ori_width, self.ori_height)
+ self.resize_ratio = min(self.test_size[0] / self.ori_size[0],
+ self.test_size[1] / self.ori_size[1])
+ self.device = next(self.model.parameters()).device
+ frame_idx = 0
+
+ while capture.isOpened():
+ ret, frame = capture.read()
+ if not ret:
+ break
+ if frame_idx == 0:
+ short_imgs_queue = [
+ frame.copy() for _ in range(self.short_cfg['frame_num'])
+ ]
+ long_imgs_queue = [
+ frame.copy() for _ in range(self.long_cfg['frame_num'])
+ ]
+ short_imgs_queue = [
+ cv2.resize(
+ x, self.test_size,
+ interpolation=cv2.INTER_LINEAR).astype(np.uint8)
+ for x in short_imgs_queue
+ ]
+ long_imgs_queue = [
+ cv2.resize(
+ x, self.test_size,
+ interpolation=cv2.INTER_LINEAR).astype(np.uint8)
+ for x in long_imgs_queue
+ ]
+ short_imgs_queue = [
+ self.preproc(x, None,
+ (self.test_size[1], self.test_size[0]))[0]
+ for x in short_imgs_queue
+ ]
+ long_imgs_queue = [
+ self.preproc(x, None,
+ (self.test_size[1], self.test_size[0]))[0]
+ for x in long_imgs_queue
+ ]
+ else:
+ long_imgs_queue = long_imgs_queue[1:] + short_imgs_queue[:]
+ short_imgs_queue = [
+ frame.copy() for _ in range(self.short_cfg['frame_num'])
+ ]
+ short_imgs_queue = [
+ cv2.resize(
+ x, self.test_size,
+ interpolation=cv2.INTER_LINEAR).astype(np.uint8)
+ for x in short_imgs_queue
+ ]
+ short_imgs_queue = [
+ self.preproc(x, None,
+ (self.test_size[1], self.test_size[0]))[0]
+ for x in short_imgs_queue
+ ]
+
+ short_img = np.concatenate(short_imgs_queue, axis=0)
+ long_img = np.concatenate(long_imgs_queue, axis=0)
+ short_img = torch.from_numpy(short_img).unsqueeze(0)
+ long_img = torch.from_numpy(long_img).unsqueeze(0)
+
+ short_img = short_img.to(self.device)
+ long_img = long_img.to(self.device)
+
+ output = self.model((short_img, long_img))
+ output = self.postprocess(output)
+
+ output += (timestamp_format(seconds=frame_idx / self.fps), )
+
+ outputs.append(output)
+
+ frame_idx += 1
+
+ return outputs
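+
+
+def _frame_queue_example():
+    # Illustrative sketch (not part of the upstream LongShortNet code) of the
+    # queue update used in inference_video above, with frame indices standing
+    # in for images: the short queue always holds the current frame, while the
+    # long queue is a sliding window over the preceding frames.
+    short_num, long_num = 1, 3
+    short_queue = [0] * short_num    # frame 0 bootstraps both queues
+    long_queue = [0] * long_num
+    history = []
+    for frame_idx in range(1, 6):
+        long_queue = long_queue[1:] + short_queue[:]
+        short_queue = [frame_idx] * short_num
+        history.append((short_queue[:], long_queue[:]))
+    # At frame 5 the model sees short=[5] and long=[2, 3, 4].
+    assert history[-1] == ([5], [2, 3, 4])
+    return history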
diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/models/__init__.py b/modelscope/models/cv/video_streaming_perception/longshortnet/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/models/dfp_pafpn_long.py b/modelscope/models/cv/video_streaming_perception/longshortnet/models/dfp_pafpn_long.py
new file mode 100644
index 00000000..a93fcba0
--- /dev/null
+++ b/modelscope/models/cv/video_streaming_perception/longshortnet/models/dfp_pafpn_long.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2014-2021 Megvii Inc.
+# Copyright (c) 2022-2023 Alibaba, Inc. and its affiliates. All rights reserved.
+
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modelscope.models.cv.stream_yolo.models.darknet import CSPDarknet
+from modelscope.models.cv.stream_yolo.models.network_blocks import (BaseConv,
+ DWConv)
+
+
+class DFPPAFPNLONG(nn.Module):
+
+ def __init__(self,
+ depth=1.0,
+ width=1.0,
+ in_features=('dark3', 'dark4', 'dark5'),
+ in_channels=[256, 512, 1024],
+ depthwise=False,
+ act='silu',
+ frame_num=2,
+ with_short_cut=True,
+ merge_form='pure_concat',
+ out_channels=[
+ ((64, 128, 256), 1),
+ ]):
+ super().__init__()
+ self.in_features = in_features
+ self.in_channels = in_channels
+ self.frame_num = frame_num
+ self.with_short_cut = with_short_cut
+ self.merge_form = merge_form
+ self.out_channels = out_channels
+ self.conv_group_num = len(out_channels)
+ self.conv_group_dict = defaultdict(dict)
+ assert self.frame_num == sum([x[1] for x in out_channels])
+ Conv = DWConv if depthwise else BaseConv
+
+ for i in range(self.conv_group_num):
+ setattr(
+ self, f'group_{i}_jian2',
+ Conv(
+ in_channels=int(in_channels[0] * width),
+ out_channels=self.out_channels[i][0][0],
+ ksize=1,
+ stride=1,
+ act=act,
+ ))
+
+ setattr(
+ self, f'group_{i}_jian1',
+ Conv(
+ in_channels=int(in_channels[1] * width),
+ out_channels=self.out_channels[i][0][1],
+ ksize=1,
+ stride=1,
+ act=act,
+ ))
+
+ setattr(
+ self, f'group_{i}_jian0',
+ Conv(
+ in_channels=int(in_channels[2] * width),
+ out_channels=self.out_channels[i][0][2],
+ ksize=1,
+ stride=1,
+ act=act,
+ ))
+
+ def off_forward(self, input, backbone_neck):
+
+ rurrent_pan_out2, rurrent_pan_out1, rurrent_pan_out0 = backbone_neck(
+ torch.split(input, 3, dim=1)[0])
+
+ support_pan_out2s = []
+ support_pan_out1s = []
+ support_pan_out0s = []
+ for i in range(self.frame_num - 1):
+
+ support_pan_out2, support_pan_out1, support_pan_out0 = backbone_neck(
+ torch.split(input, 3, dim=1)[i + 1])
+
+ support_pan_out2s.append(support_pan_out2)
+ support_pan_out1s.append(support_pan_out1)
+ support_pan_out0s.append(support_pan_out0)
+
+ all_pan_out2s = [rurrent_pan_out2] + support_pan_out2s
+ all_pan_out1s = [rurrent_pan_out1] + support_pan_out1s
+ all_pan_out0s = [rurrent_pan_out0] + support_pan_out0s
+ pan_out2s = []
+ pan_out1s = []
+ pan_out0s = []
+
+ frame_start_id = 0
+ for i in range(self.conv_group_num):
+ group_frame_num = self.out_channels[i][1]
+ for j in range(group_frame_num):
+ frame_id = frame_start_id + j
+ pan_out2s.append(
+ getattr(self, f'group_{i}_jian2')(all_pan_out2s[frame_id]))
+ pan_out1s.append(
+ getattr(self, f'group_{i}_jian1')(all_pan_out1s[frame_id]))
+ pan_out0s.append(
+ getattr(self, f'group_{i}_jian0')(all_pan_out0s[frame_id]))
+ frame_start_id += group_frame_num
+
+ if self.with_short_cut:
+ if self.merge_form == 'pure_concat':
+ pan_out2 = torch.cat(pan_out2s, dim=1) + rurrent_pan_out2
+ pan_out1 = torch.cat(pan_out1s, dim=1) + rurrent_pan_out1
+ pan_out0 = torch.cat(pan_out0s, dim=1) + rurrent_pan_out0
+ elif self.merge_form == 'add':
+ pan_out2 = torch.sum(
+ torch.stack(pan_out2s), dim=0) + rurrent_pan_out2
+ pan_out1 = torch.sum(
+ torch.stack(pan_out1s), dim=0) + rurrent_pan_out1
+ pan_out0 = torch.sum(
+ torch.stack(pan_out0s), dim=0) + rurrent_pan_out0
+ else:
+ raise Exception(
+ 'merge_form must be in ["pure_concat", "add"].')
+ else:
+ if self.merge_form == 'pure_concat':
+ pan_out2 = torch.cat(pan_out2s, dim=1)
+ pan_out1 = torch.cat(pan_out1s, dim=1)
+ pan_out0 = torch.cat(pan_out0s, dim=1)
+ elif self.merge_form == 'add':
+ pan_out2 = torch.sum(torch.stack(pan_out2s), dim=0)
+ pan_out1 = torch.sum(torch.stack(pan_out1s), dim=0)
+ pan_out0 = torch.sum(torch.stack(pan_out0s), dim=0)
+ else:
+ raise Exception(
+ 'merge_form must be in ["pure_concat", "add"].')
+ outputs = (pan_out2, pan_out1, pan_out0)
+
+ return outputs
+
+ def forward(self, input, buffer=None, mode='off_pipe', backbone_neck=None):
+
+ if mode == 'off_pipe':
+            if input.size(1) == 3:
+                input = torch.cat([input, input], dim=1)
+            output = self.off_forward(input, backbone_neck)
+
+ return output
+
+ else:
+ raise NotImplementedError
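For reference, a minimal sketch of the `out_channels` bookkeeping used by DFPPAFPNLONG above: each entry `((c2, c1, c0), n)` assigns `n` frames to one group of 1x1 convs, the frame counts must sum to `frame_num` (the assert in `__init__`), and with `with_short_cut=True` and `merge_form='pure_concat'` the per-level channel sums must match the current-frame feature so the residual add is valid. The concrete numbers below are hypothetical, not taken from this diff.

out_channels = [((21, 42, 85), 3), ((1, 2, 1), 1)]     # hypothetical 4-frame split
frame_num = sum(n for _, n in out_channels)            # 4
fused = tuple(sum(c[i] * n for c, n in out_channels) for i in range(3))
# equal to the current-frame channels int(in_channels[i] * width) of a width=0.25 model
assert frame_num == 4 and fused == (64, 128, 256)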
diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/models/dfp_pafpn_short.py b/modelscope/models/cv/video_streaming_perception/longshortnet/models/dfp_pafpn_short.py
new file mode 100644
index 00000000..44c8d418
--- /dev/null
+++ b/modelscope/models/cv/video_streaming_perception/longshortnet/models/dfp_pafpn_short.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2014-2021 Megvii Inc.
+# Copyright (c) 2022-2023 Alibaba, Inc. and its affiliates. All rights reserved.
+
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modelscope.models.cv.stream_yolo.models.darknet import CSPDarknet
+from modelscope.models.cv.stream_yolo.models.network_blocks import (BaseConv,
+ DWConv)
+
+
+class DFPPAFPNSHORT(nn.Module):
+
+ def __init__(self,
+ depth=1.0,
+ width=1.0,
+ in_features=('dark3', 'dark4', 'dark5'),
+ in_channels=[256, 512, 1024],
+ depthwise=False,
+ act='silu',
+ frame_num=2,
+ with_short_cut=True,
+ out_channels=[
+ ((64, 128, 256), 1),
+ ]):
+ super().__init__()
+ self.in_features = in_features
+ self.in_channels = in_channels
+ self.frame_num = frame_num
+ self.with_short_cut = with_short_cut
+ self.out_channels = out_channels
+ self.conv_group_num = len(out_channels)
+ self.conv_group_dict = defaultdict(dict)
+ assert self.frame_num == sum([x[1] for x in out_channels])
+ Conv = DWConv if depthwise else BaseConv
+
+ for i in range(self.conv_group_num):
+ setattr(
+ self, f'group_{i}_jian2',
+ Conv(
+ in_channels=int(in_channels[0] * width),
+ out_channels=self.out_channels[i][0][0],
+ ksize=1,
+ stride=1,
+ act=act,
+ ))
+
+ setattr(
+ self, f'group_{i}_jian1',
+ Conv(
+ in_channels=int(in_channels[1] * width),
+ out_channels=self.out_channels[i][0][1],
+ ksize=1,
+ stride=1,
+ act=act,
+ ))
+
+ setattr(
+ self, f'group_{i}_jian0',
+ Conv(
+ in_channels=int(in_channels[2] * width),
+ out_channels=self.out_channels[i][0][2],
+ ksize=1,
+ stride=1,
+ act=act,
+ ))
+
+ def off_forward(self, input, backbone_neck):
+
+ rurrent_pan_out2, rurrent_pan_out1, rurrent_pan_out0 = backbone_neck(
+ torch.split(input, 3, dim=1)[0])
+
+ support_pan_out2s = []
+ support_pan_out1s = []
+ support_pan_out0s = []
+ for i in range(self.frame_num - 1):
+
+ support_pan_out2, support_pan_out1, support_pan_out0 = backbone_neck(
+ torch.split(input, 3, dim=1)[i + 1])
+
+ support_pan_out2s.append(support_pan_out2)
+ support_pan_out1s.append(support_pan_out1)
+ support_pan_out0s.append(support_pan_out0)
+
+ all_pan_out2s = [rurrent_pan_out2] + support_pan_out2s
+ all_pan_out1s = [rurrent_pan_out1] + support_pan_out1s
+ all_pan_out0s = [rurrent_pan_out0] + support_pan_out0s
+ pan_out2s = []
+ pan_out1s = []
+ pan_out0s = []
+
+ frame_start_id = 0
+ for i in range(self.conv_group_num):
+ group_frame_num = self.out_channels[i][1]
+ for j in range(group_frame_num):
+ frame_id = frame_start_id + j
+ pan_out2s.append(
+ getattr(self, f'group_{i}_jian2')(all_pan_out2s[frame_id]))
+ pan_out1s.append(
+ getattr(self, f'group_{i}_jian1')(all_pan_out1s[frame_id]))
+ pan_out0s.append(
+ getattr(self, f'group_{i}_jian0')(all_pan_out0s[frame_id]))
+ frame_start_id += group_frame_num
+
+ if self.with_short_cut:
+ pan_out2 = torch.cat(pan_out2s, dim=1) + rurrent_pan_out2
+ pan_out1 = torch.cat(pan_out1s, dim=1) + rurrent_pan_out1
+ pan_out0 = torch.cat(pan_out0s, dim=1) + rurrent_pan_out0
+ else:
+ pan_out2 = torch.cat(pan_out2s, dim=1)
+ pan_out1 = torch.cat(pan_out1s, dim=1)
+ pan_out0 = torch.cat(pan_out0s, dim=1)
+
+ outputs = (pan_out2, pan_out1, pan_out0)
+ rurrent_pan_outs = (rurrent_pan_out2, rurrent_pan_out1,
+ rurrent_pan_out0)
+
+ return outputs, rurrent_pan_outs
+
+ def forward(self, input, buffer=None, mode='off_pipe', backbone_neck=None):
+
+ if mode == 'off_pipe':
+            if input.size(1) == 3:
+                input = torch.cat([input, input], dim=1)
+            output = self.off_forward(input, backbone_neck)
+
+ return output
+
+ else:
+ raise NotImplementedError
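A small, self-contained sketch of the short-branch merge in DFPPAFPNSHORT above: the per-frame features are reduced by the 1x1 `jian` convs, concatenated along channels, and, with `with_short_cut=True`, added back to the current-frame feature, so the concatenated channels must equal the current-frame channels. Tensor sizes are illustrative.

import torch

current = torch.randn(1, 256, 76, 152)                       # rurrent_pan_out2 at stride 8
reduced = [torch.randn(1, 128, 76, 152) for _ in range(2)]   # two frames, 128 channels each
merged = torch.cat(reduced, dim=1) + current                 # 2 * 128 == 256, residual add is valid
assert merged.shape == current.shape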
diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/models/longshort.py b/modelscope/models/cv/video_streaming_perception/longshortnet/models/longshort.py
new file mode 100644
index 00000000..9b773d7e
--- /dev/null
+++ b/modelscope/models/cv/video_streaming_perception/longshortnet/models/longshort.py
@@ -0,0 +1,232 @@
+# Copyright (c) 2014-2021 Megvii Inc.
+# Copyright (c) 2022-2023 Alibaba, Inc. and its affiliates. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+from modelscope.models.cv.stream_yolo.models.network_blocks import BaseConv
+
+
+class LONGSHORT(nn.Module):
+
+ def __init__(self,
+ long_backbone=None,
+ short_backbone=None,
+ backbone_neck=None,
+ head=None,
+ merge_form='add',
+ in_channels=[256, 512, 1024],
+ width=1.0,
+ act='silu',
+ with_short_cut=False,
+ long_cfg=None,
+ jian_ratio=None):
+ super().__init__()
+
+ self.long_backbone = long_backbone
+ self.short_backbone = short_backbone
+ self.backbone = backbone_neck
+ self.head = head
+ self.merge_form = merge_form
+ self.in_channels = in_channels
+ self.with_short_cut = with_short_cut
+ if merge_form == 'concat':
+ self.jian2 = BaseConv(
+ in_channels=int(in_channels[0] * width),
+ out_channels=int(in_channels[0] * width)
+ // 2 if jian_ratio is None else int(in_channels[0] * width
+ * jian_ratio),
+ ksize=1,
+ stride=1,
+ act=act,
+ )
+
+ self.jian1 = BaseConv(
+ in_channels=int(in_channels[1] * width),
+ out_channels=int(in_channels[1] * width)
+ // 2 if jian_ratio is None else int(in_channels[1] * width
+ * jian_ratio),
+ ksize=1,
+ stride=1,
+ act=act,
+ )
+
+ self.jian0 = BaseConv(
+ in_channels=int(in_channels[2] * width),
+ out_channels=int(in_channels[2] * width)
+ // 2 if jian_ratio is None else int(in_channels[2] * width
+ * jian_ratio),
+ ksize=1,
+ stride=1,
+ act=act,
+ )
+ elif merge_form == 'long_fusion':
+ assert long_cfg is not None and 'out_channels' in long_cfg
+ self.jian2 = BaseConv(
+ in_channels=sum(
+ [x[0][0] * x[1] for x in long_cfg['out_channels']]),
+ out_channels=int(in_channels[0] * width)
+ // 2 if jian_ratio is None else int(in_channels[0] * width
+ * jian_ratio),
+ ksize=1,
+ stride=1,
+ act=act,
+ )
+
+ self.jian1 = BaseConv(
+ in_channels=sum(
+ [x[0][1] * x[1] for x in long_cfg['out_channels']]),
+ out_channels=int(in_channels[1] * width)
+ // 2 if jian_ratio is None else int(in_channels[1] * width
+ * jian_ratio),
+ ksize=1,
+ stride=1,
+ act=act,
+ )
+
+ self.jian0 = BaseConv(
+ in_channels=sum(
+ [x[0][2] * x[1] for x in long_cfg['out_channels']]),
+ out_channels=int(in_channels[2] * width)
+ // 2 if jian_ratio is None else int(in_channels[2] * width
+ * jian_ratio),
+ ksize=1,
+ stride=1,
+ act=act,
+ )
+
+ def forward(self, x, targets=None, buffer=None, mode='off_pipe'):
+ assert mode in ['off_pipe', 'on_pipe']
+
+ if mode == 'off_pipe':
+ short_fpn_outs, rurrent_pan_outs = self.short_backbone(
+ x[0],
+ buffer=buffer,
+ mode='off_pipe',
+ backbone_neck=self.backbone)
+ long_fpn_outs = self.long_backbone(
+ x[1],
+ buffer=buffer,
+ mode='off_pipe',
+ backbone_neck=self.backbone
+ ) if self.long_backbone is not None else None
+ if not self.with_short_cut:
+ if self.long_backbone is None:
+ fpn_outs = short_fpn_outs
+ else:
+ if self.merge_form == 'add':
+ fpn_outs = [
+ x + y
+ for x, y in zip(short_fpn_outs, long_fpn_outs)
+ ]
+ elif self.merge_form == 'concat':
+ jian2_outs = [
+ self.jian2(short_fpn_outs[0]),
+ self.jian2(long_fpn_outs[0])
+ ]
+ jian1_outs = [
+ self.jian1(short_fpn_outs[1]),
+ self.jian1(long_fpn_outs[1])
+ ]
+ jian0_outs = [
+ self.jian0(short_fpn_outs[2]),
+ self.jian0(long_fpn_outs[2])
+ ]
+ fpn_outs_2 = torch.cat(jian2_outs, dim=1)
+ fpn_outs_1 = torch.cat(jian1_outs, dim=1)
+ fpn_outs_0 = torch.cat(jian0_outs, dim=1)
+ fpn_outs = (fpn_outs_2, fpn_outs_1, fpn_outs_0)
+ elif self.merge_form == 'pure_concat':
+ fpn_outs_2 = torch.cat(
+ [short_fpn_outs[0], long_fpn_outs[0]], dim=1)
+ fpn_outs_1 = torch.cat(
+ [short_fpn_outs[1], long_fpn_outs[1]], dim=1)
+ fpn_outs_0 = torch.cat(
+ [short_fpn_outs[2], long_fpn_outs[2]], dim=1)
+ fpn_outs = (fpn_outs_2, fpn_outs_1, fpn_outs_0)
+ elif self.merge_form == 'long_fusion':
+ fpn_outs_2 = torch.cat(
+ [short_fpn_outs[0],
+ self.jian2(long_fpn_outs[0])],
+ dim=1)
+ fpn_outs_1 = torch.cat(
+ [short_fpn_outs[1],
+ self.jian1(long_fpn_outs[1])],
+ dim=1)
+ fpn_outs_0 = torch.cat(
+ [short_fpn_outs[2],
+ self.jian0(long_fpn_outs[2])],
+ dim=1)
+ fpn_outs = (fpn_outs_2, fpn_outs_1, fpn_outs_0)
+ else:
+                        raise Exception(
+                            'merge_form must be in '
+                            '["add", "concat", "pure_concat", "long_fusion"].')
+ else:
+ if self.long_backbone is None:
+ fpn_outs = [
+ x + y for x, y in zip(short_fpn_outs, rurrent_pan_outs)
+ ]
+ else:
+ if self.merge_form == 'add':
+ fpn_outs = [
+ x + y + z
+ for x, y, z in zip(short_fpn_outs, long_fpn_outs,
+ rurrent_pan_outs)
+ ]
+ elif self.merge_form == 'concat':
+ jian2_outs = [
+ self.jian2(short_fpn_outs[0]),
+ self.jian2(long_fpn_outs[0])
+ ]
+ jian1_outs = [
+ self.jian1(short_fpn_outs[1]),
+ self.jian1(long_fpn_outs[1])
+ ]
+ jian0_outs = [
+ self.jian0(short_fpn_outs[2]),
+ self.jian0(long_fpn_outs[2])
+ ]
+ fpn_outs_2 = torch.cat(jian2_outs, dim=1)
+ fpn_outs_1 = torch.cat(jian1_outs, dim=1)
+ fpn_outs_0 = torch.cat(jian0_outs, dim=1)
+ fpn_outs = (fpn_outs_2, fpn_outs_1, fpn_outs_0)
+ fpn_outs = [
+ x + y for x, y in zip(fpn_outs, rurrent_pan_outs)
+ ]
+ elif self.merge_form == 'pure_concat':
+ fpn_outs_2 = torch.cat(
+ [short_fpn_outs[0], long_fpn_outs[0]], dim=1)
+ fpn_outs_1 = torch.cat(
+ [short_fpn_outs[1], long_fpn_outs[1]], dim=1)
+ fpn_outs_0 = torch.cat(
+ [short_fpn_outs[2], long_fpn_outs[2]], dim=1)
+ fpn_outs = (fpn_outs_2, fpn_outs_1, fpn_outs_0)
+ fpn_outs = [
+ x + y for x, y in zip(fpn_outs, rurrent_pan_outs)
+ ]
+ elif self.merge_form == 'long_fusion':
+ fpn_outs_2 = torch.cat(
+ [short_fpn_outs[0],
+ self.jian2(long_fpn_outs[0])],
+ dim=1)
+ fpn_outs_1 = torch.cat(
+ [short_fpn_outs[1],
+ self.jian1(long_fpn_outs[1])],
+ dim=1)
+ fpn_outs_0 = torch.cat(
+ [short_fpn_outs[2],
+ self.jian0(long_fpn_outs[2])],
+ dim=1)
+ fpn_outs = (fpn_outs_2, fpn_outs_1, fpn_outs_0)
+ fpn_outs = [
+ x + y for x, y in zip(fpn_outs, rurrent_pan_outs)
+ ]
+                    else:
+                        raise Exception(
+                            'merge_form must be in '
+                            '["add", "concat", "pure_concat", "long_fusion"].')
+
+ outputs = self.head(fpn_outs)
+
+ return outputs
+ else:
+ raise NotImplementedError
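A minimal sketch of the 'long_fusion' merge handled above: the concatenated long-branch feature is projected by a 1x1 conv (a plain Conv2d stands in for BaseConv here) and concatenated with the short-branch feature of the same level. Channel numbers are illustrative.

import torch
import torch.nn as nn

short_feat = torch.randn(1, 128, 76, 152)       # short branch, stride-8 level
long_feat = torch.randn(1, 192, 76, 152)        # concatenated long-branch groups
jian2 = nn.Conv2d(192, 128, kernel_size=1)      # stand-in for the BaseConv 'jian2' projection
fused = torch.cat([short_feat, jian2(long_feat)], dim=1)
assert fused.shape[1] == 128 + 128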
diff --git a/modelscope/models/cv/video_streaming_perception/longshortnet/models/longshort_backbone_neck.py b/modelscope/models/cv/video_streaming_perception/longshortnet/models/longshort_backbone_neck.py
new file mode 100644
index 00000000..4625d10a
--- /dev/null
+++ b/modelscope/models/cv/video_streaming_perception/longshortnet/models/longshort_backbone_neck.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2014-2021 Megvii Inc.
+# Copyright (c) 2022-2023 Alibaba, Inc. and its affiliates. All rights reserved.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modelscope.models.cv.stream_yolo.models.darknet import CSPDarknet
+from modelscope.models.cv.stream_yolo.models.network_blocks import (BaseConv,
+ CSPLayer,
+ DWConv)
+
+
+class BACKBONENECK(nn.Module):
+
+ def __init__(
+ self,
+ depth=1.0,
+ width=1.0,
+ in_features=('dark3', 'dark4', 'dark5'),
+ in_channels=[256, 512, 1024],
+ depthwise=False,
+ act='silu',
+ ):
+ super().__init__()
+ self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
+ self.in_features = in_features
+ self.in_channels = in_channels
+ Conv = DWConv if depthwise else BaseConv
+
+ self.lateral_conv0 = BaseConv(
+ int(in_channels[2] * width),
+ int(in_channels[1] * width),
+ 1,
+ 1,
+ act=act)
+ self.C3_p4 = CSPLayer(
+ int(2 * in_channels[1] * width),
+ int(in_channels[1] * width),
+ round(3 * depth),
+ False,
+ depthwise=depthwise,
+ act=act,
+ ) # cat
+
+ self.reduce_conv1 = BaseConv(
+ int(in_channels[1] * width),
+ int(in_channels[0] * width),
+ 1,
+ 1,
+ act=act)
+ self.C3_p3 = CSPLayer(
+ int(2 * in_channels[0] * width),
+ int(in_channels[0] * width),
+ round(3 * depth),
+ False,
+ depthwise=depthwise,
+ act=act,
+ )
+
+ # bottom-up conv
+ self.bu_conv2 = Conv(
+ int(in_channels[0] * width),
+ int(in_channels[0] * width),
+ 3,
+ 2,
+ act=act)
+ self.C3_n3 = CSPLayer(
+ int(2 * in_channels[0] * width),
+ int(in_channels[1] * width),
+ round(3 * depth),
+ False,
+ depthwise=depthwise,
+ act=act,
+ )
+
+ # bottom-up conv
+ self.bu_conv1 = Conv(
+ int(in_channels[1] * width),
+ int(in_channels[1] * width),
+ 3,
+ 2,
+ act=act)
+ self.C3_n4 = CSPLayer(
+ int(2 * in_channels[1] * width),
+ int(in_channels[2] * width),
+ round(3 * depth),
+ False,
+ depthwise=depthwise,
+ act=act,
+ )
+
+ def forward(self, input):
+
+ rurrent_out_features = self.backbone(input)
+ rurrent_features = [rurrent_out_features[f] for f in self.in_features]
+ [rurrent_x2, rurrent_x1, rurrent_x0] = rurrent_features
+
+ rurrent_fpn_out0 = self.lateral_conv0(rurrent_x0)
+ rurrent_f_out0 = F.interpolate(
+ rurrent_fpn_out0, size=rurrent_x1.shape[2:4], mode='nearest')
+ rurrent_f_out0 = torch.cat([rurrent_f_out0, rurrent_x1], 1)
+ rurrent_f_out0 = self.C3_p4(rurrent_f_out0)
+
+ rurrent_fpn_out1 = self.reduce_conv1(rurrent_f_out0)
+ rurrent_f_out1 = F.interpolate(
+ rurrent_fpn_out1, size=rurrent_x2.shape[2:4], mode='nearest')
+ rurrent_f_out1 = torch.cat([rurrent_f_out1, rurrent_x2], 1)
+ rurrent_pan_out2 = self.C3_p3(rurrent_f_out1)
+
+ rurrent_p_out1 = self.bu_conv2(rurrent_pan_out2)
+ rurrent_p_out1 = torch.cat([rurrent_p_out1, rurrent_fpn_out1], 1)
+ rurrent_pan_out1 = self.C3_n3(rurrent_p_out1)
+
+ rurrent_p_out0 = self.bu_conv1(rurrent_pan_out1)
+ rurrent_p_out0 = torch.cat([rurrent_p_out0, rurrent_fpn_out0], 1)
+ rurrent_pan_out0 = self.C3_n4(rurrent_p_out0)
+
+ outputs = (rurrent_pan_out2, rurrent_pan_out1, rurrent_pan_out0)
+
+ return outputs
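A usage sketch for BACKBONENECK, assuming modelscope is importable and that the stream_yolo CSPDarknet follows the usual YOLOX layout (dark3/dark4/dark5 at strides 8/16/32); the input size and the depth/width values are illustrative.

import torch

neck = BACKBONENECK(depth=0.33, width=0.25)
dummy = torch.randn(1, 3, 256, 512)             # H and W must be divisible by 32
p2, p1, p0 = neck(dummy)
# expected, if the assumptions above hold:
# p2 ~ (1, 64, 32, 64), p1 ~ (1, 128, 16, 32), p0 ~ (1, 256, 8, 16)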
diff --git a/modelscope/models/cv/vidt/__init__.py b/modelscope/models/cv/vidt/__init__.py
new file mode 100644
index 00000000..785d0274
--- /dev/null
+++ b/modelscope/models/cv/vidt/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .model import VidtModel
+else:
+ _import_structure = {
+ 'model': ['VidtModel'],
+ }
+ import sys
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/cv/vidt/backbone.py b/modelscope/models/cv/vidt/backbone.py
new file mode 100644
index 00000000..198ab498
--- /dev/null
+++ b/modelscope/models/cv/vidt/backbone.py
@@ -0,0 +1,1061 @@
+# The implementation here is modified based on timm,
+# originally Apache 2.0 License and publicly available at
+# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/swin_w_ram.py
+
+import math
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+ """ Multilayer perceptron."""
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def masked_sin_pos_encoding(x,
+ mask,
+ num_pos_feats,
+ temperature=10000,
+ scale=2 * math.pi):
+ """ Masked Sinusoidal Positional Encoding
+
+ Args:
+ x: [PATCH] tokens
+ mask: the padding mask for [PATCH] tokens
+ num_pos_feats: the size of channel dimension
+ temperature: the temperature value
+ scale: the normalization scale
+
+ Returns:
+ pos: Sinusoidal positional encodings
+ """
+
+ num_pos_feats = num_pos_feats // 2
+ not_mask = ~mask
+
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale
+
+ dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
+ dim=4).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
+ dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3)
+
+ return pos
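A quick shape check for masked_sin_pos_encoding above; the sizes are illustrative, and the last two columns are marked as padding.

import torch

mask = torch.zeros(1, 4, 6, dtype=torch.bool)
mask[:, :, 4:] = True                           # padded region is marked with True
tokens = torch.randn(1, 4 * 6, 96)              # only used for its device
pos = masked_sin_pos_encoding(tokens, mask, num_pos_feats=96)
assert pos.shape == (1, 4, 6, 96)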
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
+ C)
+ windows = x.permute(0, 1, 3, 2, 4,
+ 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size,
+ window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
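A round-trip check for the two helpers above: window_reverse undoes window_partition when H and W are multiples of the window size (sizes illustrative).

import torch

feat = torch.randn(2, 14, 14, 96)
windows = window_partition(feat, window_size=7)     # (2 * 2 * 2, 7, 7, 96)
assert windows.shape == (8, 7, 7, 96)
assert torch.equal(window_reverse(windows, 7, H=14, W=14), feat)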
+
+
+class ReconfiguredAttentionModule(nn.Module):
+ """ Window based multi-head self attention (W-MSA) module with relative position bias -> extended with RAM.
+ It supports both of shifted and non-shifted window.
+
+ !!!!!!!!!!! IMPORTANT !!!!!!!!!!!
+ The original attention module in Swin is replaced with the reconfigured attention module in Section 3.
+ All the Args are shared, so only the forward function is modified.
+ See https://arxiv.org/pdf/2110.03921.pdf
+ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self,
+ dim,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :,
+ None] - coords_flatten[:,
+ None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :,
+ 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer('relative_position_index',
+ relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self,
+ x,
+ det,
+ mask=None,
+ cross_attn=False,
+ cross_attn_mask=None):
+ """ Forward function.
+        The RAM module receives [PATCH] and [DET] tokens and returns their calibrated versions
+
+ Args:
+ x: [PATCH] tokens
+ det: [DET] tokens
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None -> mask for shifted window attention
+
+ "additional inputs for RAM"
+ cross_attn: whether to use cross-attention [det x patch] (for selective cross-attention)
+ cross_attn_mask: mask for cross-attention
+
+ Returns:
+ patch_x: the calibrated [PATCH] tokens
+ det_x: the calibrated [DET] tokens
+ """
+
+ assert self.window_size[0] == self.window_size[1]
+ window_size = self.window_size[0]
+ local_map_size = window_size * window_size
+
+ # projection before window partitioning
+ if not cross_attn:
+ B, H, W, C = x.shape
+ N = H * W
+ x = x.view(B, N, C)
+ x = torch.cat([x, det], dim=1)
+ full_qkv = self.qkv(x)
+ patch_qkv, det_qkv = full_qkv[:, :N, :], full_qkv[:, N:, :]
+ else:
+ B, H, W, C = x[0].shape
+ N = H * W
+ _, ori_H, ori_W, _ = x[1].shape
+ ori_N = ori_H * ori_W
+
+ shifted_x = x[0].view(B, N, C)
+ cross_x = x[1].view(B, ori_N, C)
+ x = torch.cat([shifted_x, cross_x, det], dim=1)
+ full_qkv = self.qkv(x)
+ patch_qkv, cross_patch_qkv, det_qkv = \
+ full_qkv[:, :N, :], full_qkv[:, N:N + ori_N, :], full_qkv[:, N + ori_N:, :]
+ patch_qkv = patch_qkv.view(B, H, W, -1)
+
+ # window partitioning for [PATCH] tokens
+ patch_qkv = window_partition(
+ patch_qkv, window_size) # nW*B, window_size, window_size, C
+ B_ = patch_qkv.shape[0]
+ patch_qkv = patch_qkv.reshape(B_, window_size * window_size, 3,
+ self.num_heads, C // self.num_heads)
+ _patch_qkv = patch_qkv.permute(2, 0, 3, 1, 4)
+ patch_q, patch_k, patch_v = _patch_qkv[0], _patch_qkv[1], _patch_qkv[2]
+
+ # [PATCH x PATCH] self-attention using window partitions
+ patch_q = patch_q * self.scale
+ patch_attn = (patch_q @ patch_k.transpose(-2, -1))
+ # add relative pos bias for [patch x patch] self-attention
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1],
+ self.window_size[0] * self.window_size[1],
+ -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ patch_attn = patch_attn + relative_position_bias.unsqueeze(0)
+
+ # if shifted window is used, it needs to apply the mask
+ if mask is not None:
+ nW = mask.shape[0]
+ tmp0 = patch_attn.view(B_ // nW, nW, self.num_heads,
+ local_map_size, local_map_size)
+ tmp1 = mask.unsqueeze(1).unsqueeze(0)
+ patch_attn = tmp0 + tmp1
+ patch_attn = patch_attn.view(-1, self.num_heads, local_map_size,
+ local_map_size)
+
+ patch_attn = self.softmax(patch_attn)
+ patch_attn = self.attn_drop(patch_attn)
+ patch_x = (patch_attn @ patch_v).transpose(1, 2).reshape(
+ B_, window_size, window_size, C)
+
+ # extract qkv for [DET] tokens
+ det_qkv = det_qkv.view(B, -1, 3, self.num_heads, C // self.num_heads)
+ det_qkv = det_qkv.permute(2, 0, 3, 1, 4)
+ det_q, det_k, det_v = det_qkv[0], det_qkv[1], det_qkv[2]
+
+ # if cross-attention is activated
+ if cross_attn:
+
+ # reconstruct the spatial form of [PATCH] tokens for global [DET x PATCH] attention
+ cross_patch_qkv = cross_patch_qkv.view(B, ori_H, ori_W, 3,
+ self.num_heads,
+ C // self.num_heads)
+ patch_kv = cross_patch_qkv[:, :, :,
+ 1:, :, :].permute(3, 0, 4, 1, 2,
+ 5).contiguous()
+ patch_kv = patch_kv.view(2, B, self.num_heads, ori_H * ori_W, -1)
+
+ # extract "key and value" of [PATCH] tokens for cross-attention
+ cross_patch_k, cross_patch_v = patch_kv[0], patch_kv[1]
+
+ # bind key and value of [PATCH] and [DET] tokens for [DET X [PATCH, DET]] attention
+ det_k, det_v = torch.cat([cross_patch_k, det_k],
+ dim=2), torch.cat([cross_patch_v, det_v],
+ dim=2)
+
+ # [DET x DET] self-attention or binded [DET x [PATCH, DET]] attention
+ det_q = det_q * self.scale
+ det_attn = (det_q @ det_k.transpose(-2, -1))
+ # apply cross-attention mask if available
+ if cross_attn_mask is not None:
+ det_attn = det_attn + cross_attn_mask
+ det_attn = self.softmax(det_attn)
+ det_attn = self.attn_drop(det_attn)
+ det_x = (det_attn @ det_v).transpose(1, 2).reshape(B, -1, C)
+
+ # reverse window for [PATCH] tokens <- the output of [PATCH x PATCH] self attention
+ patch_x = window_reverse(patch_x, window_size, H, W)
+
+ # projection for outputs from multi-head
+ x = torch.cat([patch_x.view(B, H * W, C), det_x], dim=1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+        # decompose the projected output back into [PATCH] and [DET] tokens
+ patch_x = x[:, :H * W, :].view(B, H, W, C)
+ det_x = x[:, H * W:, :]
+
+ return patch_x, det_x
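A minimal sketch of the RAM forward pass without cross-attention: 14x14 [PATCH] tokens (divisible by the 7x7 window) together with 100 [DET] tokens; the sizes are illustrative.

import torch

ram = ReconfiguredAttentionModule(dim=96, window_size=(7, 7), num_heads=3)
patches = torch.randn(2, 14, 14, 96)
det = torch.randn(2, 100, 96)
patch_x, det_x = ram(patches, det, mask=None, cross_attn=False)
assert patch_x.shape == (2, 14, 14, 96) and det_x.shape == (2, 100, 96)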
+
+
+class SwinTransformerBlock(nn.Module):
+ """ Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+        assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)'
+
+ self.norm1 = norm_layer(dim)
+ self.attn = ReconfiguredAttentionModule(
+ dim,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop)
+
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+
+ self.H = None
+ self.W = None
+
+ def forward(self, x, mask_matrix, pos, cross_attn, cross_attn_mask):
+ """ Forward function.
+
+ Args:
+            x: Input feature, tensor size (B, H*W + DET, C), i.e., the concatenated [PATCH, DET] tokens
+            H, W: Spatial resolution of the input feature.
+            mask_matrix: Attention mask for cyclic shift.
+
+            "additional inputs"
+            pos: (patch_pos, det_pos)
+            cross_attn: whether to use cross attn [det x [det + patch]]
+            cross_attn_mask: attention mask for cross-attention
+
+        Returns:
+            x: calibrated & concatenated [PATCH, DET] tokens
+ """
+
+ B, L, C = x.shape
+ H, W = self.H, self.W
+
+ assert L == H * W + self.det_token_num, 'input feature has wrong size'
+
+ shortcut = x
+ x = self.norm1(x)
+ x, det = x[:, :H * W, :], x[:, H * W:, :]
+ x = x.view(B, H, W, C)
+ orig_x = x
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # projection for det positional encodings: make the channel size suitable for the current layer
+ patch_pos, det_pos = pos
+ det_pos = self.det_pos_linear(det_pos)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ attn_mask = mask_matrix
+ else:
+ shifted_x = x
+ attn_mask = None
+
+ # prepare cross-attn and add positional encodings
+ if cross_attn:
+ # patch token (for cross-attention) + Sinusoidal pos encoding
+ cross_patch = orig_x + patch_pos
+ # det token + learnable pos encoding
+ det = det + det_pos
+ shifted_x = (shifted_x, cross_patch)
+ else:
+            # if cross_attn is deactivated, only [PATCH] and [DET] self-attention is performed
+ det = det + det_pos
+ shifted_x = shifted_x
+
+ # W-MSA/SW-MSA
+ shifted_x, det = self.attn(
+ shifted_x,
+ mask=attn_mask,
+ # additional args
+ det=det,
+ cross_attn=cross_attn,
+ cross_attn_mask=cross_attn_mask)
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(
+ shifted_x,
+ shifts=(self.shift_size, self.shift_size),
+ dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+ x = torch.cat([x, det], dim=1)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ """ Patch Merging Layer
+
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, norm_layer=nn.LayerNorm, expand=True):
+ super().__init__()
+ self.dim = dim
+
+ # if expand is True, the channel size will be expanded, otherwise, return 256 size of channel
+ expand_dim = 2 * dim if expand else 256
+ self.reduction = nn.Linear(4 * dim, expand_dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ # added for detection token [please ignore, not used for training]
+ # not implemented yet.
+ self.expansion = nn.Linear(dim, expand_dim, bias=False)
+ self.norm2 = norm_layer(dim)
+
+ def forward(self, x, H, W):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C), i.e., binded [PATCH, DET] tokens
+ H, W: Spatial resolution of the input feature.
+
+ Returns:
+ x: merged [PATCH, DET] tokens;
+            only [PATCH] tokens are reduced in spatial dim, while [DET] tokens keep a fixed scale
+ """
+
+ B, L, C = x.shape
+ assert L == H * W + self.det_token_num, 'input feature has wrong size'
+
+ x, det = x[:, :H * W, :], x[:, H * W:, :]
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ # simply repeating for DET tokens
+ det = det.repeat(1, 1, 4)
+
+ x = torch.cat([x, det], dim=1)
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
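A sketch of PatchMerging with [DET] tokens appended. det_token_num is normally set by SwinTransformer.finetune_det; it is set by hand here only so the module can be run in isolation (sizes illustrative).

import torch

pm = PatchMerging(dim=96)
pm.det_token_num = 100                          # normally injected by finetune_det
tokens = torch.randn(2, 14 * 14 + 100, 96)
merged = pm(tokens, H=14, W=14)
assert merged.shape == (2, 7 * 7 + 100, 192)    # [PATCH] halved spatially, [DET] count kept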
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of feature channels
+        depth (int): Depth of this stage.
+        num_heads (int): Number of attention heads.
+ window_size (int): Local window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ dim,
+ depth,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ last=False,
+ use_checkpoint=False):
+
+ super().__init__()
+ self.window_size = window_size
+ self.shift_size = window_size // 2
+ self.depth = depth
+ self.dim = dim
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i]
+ if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer) for i in range(depth)
+ ])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(
+ dim=dim, norm_layer=norm_layer, expand=(not last))
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W, det_pos, input_mask, cross_attn=False):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ det_pos: pos encoding for det token
+ input_mask: padding mask for inputs
+ cross_attn: whether to use cross attn [det x [det + patch]]
+ """
+
+ B = x.shape[0]
+
+ # calculate attention mask for SW-MSA
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ # mask for cyclic shift
+ mask_windows = window_partition(img_mask, self.window_size)
+ mask_windows = mask_windows.view(-1,
+ self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
+ float(-100.0)).masked_fill(
+ attn_mask == 0, float(0.0))
+
+ # compute sinusoidal pos encoding and cross-attn mask here to avoid redundant computation
+ if cross_attn:
+
+ _H, _W = input_mask.shape[1:]
+ if not (_H == H and _W == W):
+ input_mask = F.interpolate(
+ input_mask[None].float(), size=(H, W)).to(torch.bool)[0]
+
+ # sinusoidal pos encoding for [PATCH] tokens used in cross-attention
+ patch_pos = masked_sin_pos_encoding(x, input_mask, self.dim)
+
+ # attention padding mask due to the zero padding in inputs
+ # the zero (padded) area is masked by 1.0 in 'input_mask'
+ cross_attn_mask = input_mask.float()
+ cross_attn_mask = cross_attn_mask.masked_fill(cross_attn_mask != 0.0, float(-100.0)). \
+ masked_fill(cross_attn_mask == 0.0, float(0.0))
+
+            # pad for detection tokens (this padding is required to process the
+            # concatenated [PATCH, DET] attention)
+ cross_attn_mask = cross_attn_mask.view(
+ B, H * W).unsqueeze(1).unsqueeze(2)
+ cross_attn_mask = F.pad(
+ cross_attn_mask, (0, self.det_token_num), value=0)
+
+ else:
+ patch_pos = None
+ cross_attn_mask = None
+
+ # zip pos encodings
+ pos = (patch_pos, det_pos)
+
+ for n_blk, blk in enumerate(self.blocks):
+ blk.H, blk.W = H, W
+
+ # for selective cross-attention
+ if cross_attn:
+ _cross_attn = True
+ _cross_attn_mask = cross_attn_mask
+ _pos = pos # i.e., (patch_pos, det_pos)
+ else:
+ _cross_attn = False
+ _cross_attn_mask = None
+ _pos = (None, det_pos)
+
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(
+ blk,
+ x,
+ attn_mask,
+ # additional inputs
+ pos=_pos,
+ cross_attn=_cross_attn,
+ cross_attn_mask=_cross_attn_mask)
+ else:
+ x = blk(
+ x,
+ attn_mask,
+ # additional inputs
+ pos=_pos,
+ cross_attn=_cross_attn,
+ cross_attn_mask=_cross_attn_mask)
+
+        # reduce the number of patch tokens while keeping the det tokens at a fixed scale
+ # meanwhile, the channel dim increases by a factor of 2
+ if self.downsample is not None:
+ x_down = self.downsample(x, H, W)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ norm_layer=None):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+
+ # padding
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x,
+ (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
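A small check of the padding behaviour above: PatchEmbed right/bottom-pads the input so both sides are divisible by the patch size before the strided projection (input size illustrative).

import torch

pe = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96)
image = torch.randn(1, 3, 223, 225)             # neither side divisible by 4
feat = pe(image)                                # padded to 224 x 228 internally
assert feat.shape == (1, 96, 56, 57)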
+
+
+class SwinTransformer(nn.Module):
+ """ Swin Transformer backbone.
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ pretrain_img_size (int): Input image size for training the pretrained model,
+            used in absolute position embedding. Default 224.
+ patch_size (int | tuple(int)): Patch size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ num_heads (tuple[int]): Number of attention head of each stage.
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): Dropout rate.
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(
+ self,
+ pretrain_img_size=224,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.2,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ out_indices=[1, 2,
+ 3], # not used in the current version, please ignore.
+ frozen_stages=-1,
+ use_checkpoint=False):
+ super().__init__()
+
+ self.pretrain_img_size = pretrain_img_size
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [
+ pretrain_img_size[0] // patch_size[0],
+ pretrain_img_size[1] // patch_size[1]
+ ]
+
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros(1, embed_dim, patches_resolution[0],
+ patches_resolution[1]))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(embed_dim * 2**i_layer),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ # modified by ViDT
+ downsample=PatchMerging if
+ (i_layer < self.num_layers) else None,
+ last=None if (i_layer < self.num_layers - 1) else True,
+ #
+ use_checkpoint=use_checkpoint)
+ self.layers.append(layer)
+
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
+ self.num_features = num_features
+
+ # add a norm layer for each output
+        # Not used in the current version -> please ignore. This will be fixed later.
+        # We leave these lines to load the pre-trained model ...
+ for i_layer in out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ if self.frozen_stages >= 1 and self.ape:
+ self.absolute_pos_embed.requires_grad = False
+
+ if self.frozen_stages >= 2:
+ self.pos_drop.eval()
+ for i in range(0, self.frozen_stages - 1):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'det_pos_embed', 'det_token'}
+
+ def finetune_det(self,
+ method,
+ det_token_num=100,
+ pos_dim=256,
+ cross_indices=[3]):
+ """ A funtion to add neccessary (leanable) variables to Swin Transformer for object detection
+
+ Args:
+ method: vidt or vidt_wo_neck
+            det_token_num: the number of objects to detect, i.e., the number of object queries
+ pos_dim: the channel dimension of positional encodings for [DET] and [PATCH] tokens
+ cross_indices: the indices where to use the [DET X PATCH] cross-attention
+ there are four possible stages in [0, 1, 2, 3]. 3 indicates Stage 4 in the ViDT paper.
+ """
+
+ # which method?
+ self.method = method
+
+        # how many objects do we detect?
+ self.det_token_num = det_token_num
+ self.det_token = nn.Parameter(
+ torch.zeros(1, det_token_num, self.num_features[0]))
+ self.det_token = trunc_normal_(self.det_token, std=.02)
+
+ # dim size of pos encoding
+ self.pos_dim = pos_dim
+
+ # learnable positional encoding for detection tokens
+ det_pos_embed = torch.zeros(1, det_token_num, pos_dim)
+ det_pos_embed = trunc_normal_(det_pos_embed, std=.02)
+ self.det_pos_embed = torch.nn.Parameter(det_pos_embed)
+
+ # info for detection
+ self.num_channels = [
+ self.num_features[i + 1]
+ for i in range(len(self.num_features) - 1)
+ ]
+ if method == 'vidt':
+ self.num_channels.append(
+ self.pos_dim) # default: 256 (same to the default pos_dim)
+ self.cross_indices = cross_indices
+ # divisor to reduce the spatial size of the mask
+ self.mask_divisor = 2**(len(self.layers) - len(self.cross_indices))
+
+ # projection matrix for det pos encoding in each Swin layer (there are 4 blocks)
+ for layer in self.layers:
+ layer.det_token_num = det_token_num
+ if layer.downsample is not None:
+ layer.downsample.det_token_num = det_token_num
+ for block in layer.blocks:
+ block.det_token_num = det_token_num
+ block.det_pos_linear = nn.Linear(pos_dim, block.dim)
+
+        # the neck-free model does not require downsampling at the last stage.
+ if method == 'vidt_wo_neck':
+ self.layers[-1].downsample = None
+
+ def forward(self, x, mask):
+ """ Forward function.
+
+ Args:
+ x: input rgb images
+ mask: input padding masks [0: rgb values, 1: padded values]
+
+ Returns:
+ patch_outs: multi-scale [PATCH] tokens (four scales are used)
+ these tokens are the first input of the neck decoder
+ det_tgt: final [DET] tokens obtained at the last stage
+                these tokens are the second input of the neck decoder
+ det_pos: the learnable pos encoding for [DET] tokens.
+ these encodings are used to generate reference points in deformable attention
+ """
+
+ # original input shape
+        B = x.shape[0]
+
+ # patch embedding
+ x = self.patch_embed(x)
+
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ # expand det_token for all examples in the batch
+ det_token = self.det_token.expand(B, -1, -1)
+
+ # det pos encoding -> will be projected in each block
+ det_pos = self.det_pos_embed
+
+ # prepare a mask for cross attention
+ mask = F.interpolate(
+ mask[None].float(),
+ size=(Wh // self.mask_divisor,
+ Ww // self.mask_divisor)).to(torch.bool)[0]
+
+ patch_outs = []
+ for stage in range(self.num_layers):
+ layer = self.layers[stage]
+
+ # whether to use cross-attention
+ cross_attn = True if stage in self.cross_indices else False
+
+ # concat input
+ x = torch.cat([x, det_token], dim=1)
+
+ # inference
+ x_out, H, W, x, Wh, Ww = layer(
+ x,
+ Wh,
+ Ww,
+ # additional input for VIDT
+ input_mask=mask,
+ det_pos=det_pos,
+ cross_attn=cross_attn)
+
+ x, det_token = x[:, :-self.det_token_num, :], x[:, -self.
+ det_token_num:, :]
+
+ # Aggregate intermediate outputs
+ if stage > 0:
+ patch_out = x_out[:, :-self.det_token_num, :].view(
+ B, H, W, -1).permute(0, 3, 1, 2)
+ patch_outs.append(patch_out)
+
+ # patch token reduced from last stage output
+ patch_outs.append(x.view(B, Wh, Ww, -1).permute(0, 3, 1, 2))
+
+ # det token
+ det_tgt = x_out[:, -self.det_token_num:, :].permute(0, 2, 1)
+
+ # det token pos encoding
+ det_pos = det_pos.permute(0, 2, 1)
+
+ features_0, features_1, features_2, features_3 = patch_outs
+ return features_0, features_1, features_2, features_3, det_tgt, det_pos
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep layers freezed."""
+ super(SwinTransformer, self).train(mode)
+ self._freeze_stages()
+
+ # not working in the current version
+ def flops(self):
+ flops = 0
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += self.num_features * self.patches_resolution[
+ 0] * self.patches_resolution[1] // (2**self.num_layers)
+ flops += self.num_features * self.num_classes
+ return flops
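A usage sketch of the backbone above: build the default Swin backbone, attach the [DET] tokens via finetune_det, and run a dummy batch with an all-valid padding mask; the expected output shapes are noted as comments.

import torch

backbone = SwinTransformer()                    # default Swin-T style configuration
backbone.finetune_det(
    method='vidt', det_token_num=100, pos_dim=256, cross_indices=[3])
images = torch.randn(1, 3, 224, 224)
pad_mask = torch.zeros(1, 224, 224, dtype=torch.bool)   # no padded pixels
with torch.no_grad():
    f0, f1, f2, f3, det_tgt, det_pos = backbone(images, pad_mask)
# expected: f0 ~ (1, 192, 28, 28), f1 ~ (1, 384, 14, 14), f2 ~ (1, 768, 7, 7),
#           f3 ~ (1, 256, 4, 4), det_tgt ~ (1, 768, 100), det_pos ~ (1, 256, 100)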
diff --git a/modelscope/models/cv/vidt/deformable_transformer.py b/modelscope/models/cv/vidt/deformable_transformer.py
new file mode 100644
index 00000000..7344ce5d
--- /dev/null
+++ b/modelscope/models/cv/vidt/deformable_transformer.py
@@ -0,0 +1,616 @@
+# The implementation here is modified based on timm,
+# originally Apache 2.0 License and publicly available at
+# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/deformable_transformer.py
+
+import copy
+import math
+import warnings
+
+import torch
+import torch.nn.functional as F
+from timm.models.layers import DropPath
+from torch import nn
+from torch.nn.init import constant_, normal_, xavier_uniform_
+
+
+class DeformableTransformer(nn.Module):
+ """ A Deformable Transformer for the neck in a detector
+
+ The transformer encoder is completely removed for ViDT
+ Args:
+ d_model: the channel dimension for attention [default=256]
+ nhead: the number of heads [default=8]
+ num_decoder_layers: the number of decoding layers [default=6]
+ dim_feedforward: the channel dim of point-wise FFNs [default=1024]
+ dropout: the degree of dropout used in FFNs [default=0.1]
+ activation: An activation function to use [default='relu']
+        return_intermediate_dec: whether to return all the intermediate outputs [default=True]
+ num_feature_levels: the number of scales for extracted features [default=4]
+ dec_n_points: the number of reference points for deformable attention [default=4]
+ drop_path: the ratio of stochastic depth for decoding layers [default=0.0]
+ token_label: whether to use the token label loss for training [default=False]. This is an additional trick
+ proposed in https://openreview.net/forum?id=LhbD74dsZFL (ICLR'22) for further improvement
+ """
+
+ def __init__(self,
+ d_model=256,
+ nhead=8,
+ num_decoder_layers=6,
+ dim_feedforward=1024,
+ dropout=0.1,
+ activation='relu',
+ return_intermediate_dec=True,
+ num_feature_levels=4,
+ dec_n_points=4,
+ drop_path=0.,
+ token_label=False):
+ super().__init__()
+
+ self.d_model = d_model
+ self.nhead = nhead
+ decoder_layer = DeformableTransformerDecoderLayer(
+ d_model,
+ dim_feedforward,
+ dropout,
+ activation,
+ num_feature_levels,
+ nhead,
+ dec_n_points,
+ drop_path=drop_path)
+ self.decoder = DeformableTransformerDecoder(decoder_layer,
+ num_decoder_layers,
+ return_intermediate_dec)
+
+ self.level_embed = nn.Parameter(
+ torch.Tensor(num_feature_levels, d_model))
+ self.token_label = token_label
+
+ self.reference_points = nn.Linear(d_model, 2)
+
+ if self.token_label:
+ self.enc_output = nn.Linear(d_model, d_model)
+ self.enc_output_norm = nn.LayerNorm(d_model)
+
+ self.token_embed = nn.Linear(d_model, 91)
+ prior_prob = 0.01
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
+ self.token_embed.bias.data = torch.ones(91) * bias_value
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MSDeformAttn):
+ m._reset_parameters()
+
+ normal_(self.level_embed)
+
+ def get_proposal_pos_embed(self, proposals):
+ num_pos_feats = 128
+ temperature = 10000
+ scale = 2 * math.pi
+
+ dim_t = torch.arange(
+ num_pos_feats, dtype=torch.float32, device=proposals.device)
+ dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
+ # N, L, 4
+ proposals = proposals.sigmoid() * scale
+ # N, L, 4, 128
+ pos = proposals[:, :, :, None] / dim_t
+ # N, L, 4, 64, 2
+ pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
+ dim=4).flatten(2)
+ return pos
+
+ def gen_encoder_output_proposals(self, memory, memory_padding_mask,
+ spatial_shapes):
+ N_, S_, C_ = memory.shape
+ proposals = []
+ _cur = 0
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+ mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(
+ N_, H_, W_, 1)
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(
+ 0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
+ torch.linspace(
+ 0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+ scale = torch.cat([valid_W.unsqueeze(-1),
+ valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
+ proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
+ proposals.append(proposal)
+ _cur += (H_ * W_)
+ output_proposals = torch.cat(proposals, 1)
+ tmp = (output_proposals > 0.01) & (output_proposals < 0.99)
+ output_proposals_valid = tmp.all(-1, keepdim=True)
+ output_proposals = torch.log(output_proposals / (1 - output_proposals))
+ output_proposals = output_proposals.masked_fill(
+ memory_padding_mask.unsqueeze(-1), float('inf'))
+ output_proposals = output_proposals.masked_fill(
+ ~output_proposals_valid, float('inf'))
+
+ output_memory = memory
+ output_memory = output_memory.masked_fill(
+ memory_padding_mask.unsqueeze(-1), float(0))
+ output_memory = output_memory.masked_fill(~output_proposals_valid,
+ float(0))
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
+ return output_memory, output_proposals
+
+ def get_valid_ratio(self, mask):
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def forward(self, srcs, masks, tgt, query_pos):
+ """ The forward step of the decoder
+
+ Args:
+            srcs: [PATCH] tokens
+ masks: input padding mask
+ tgt: [DET] tokens
+ query_pos: [DET] token pos encodings
+
+ Returns:
+ hs: calibrated [DET] tokens
+ init_reference_out: init reference points
+ inter_references_out: intermediate reference points for box refinement
+ enc_token_class_unflat: info. for token labeling
+ """
+
+ # prepare input for the Transformer decoder
+ src_flatten = []
+ mask_flatten = []
+ spatial_shapes = []
+ for lvl, (src, mask) in enumerate(zip(srcs, masks)):
+ bs, c, h, w = src.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+ src = src.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ src_flatten.append(src)
+ mask_flatten.append(mask)
+ src_flatten = torch.cat(src_flatten, 1)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ spatial_shapes = torch.as_tensor(
+ spatial_shapes, dtype=torch.long, device=src_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros(
+ (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+
+ memory = src_flatten
+ bs, _, c = memory.shape
+ tgt = tgt # [DET] tokens
+ query_pos = query_pos.expand(bs, -1, -1) # [DET] token pos encodings
+
+ # prepare input for token label
+ if self.token_label:
+ output_memory, output_proposals = self.gen_encoder_output_proposals(
+ memory, mask_flatten, spatial_shapes)
+ enc_token_class_unflat = None
+ if self.token_label:
+ enc_token_class = self.token_embed(output_memory)
+ enc_token_class_unflat = []
+ for st, (h, w) in zip(level_start_index, spatial_shapes):
+ enc_token_class_unflat.append(
+ enc_token_class[:, st:st + h * w, :].view(bs, h, w, 91))
+
+ # reference points for deformable attention
+ reference_points = self.reference_points(query_pos).sigmoid()
+ init_reference_out = reference_points # query_pos -> reference point
+
+ # decoder
+ hs, inter_references = self.decoder(tgt, reference_points, memory,
+ spatial_shapes, level_start_index,
+ valid_ratios, query_pos,
+ mask_flatten)
+
+ inter_references_out = inter_references
+
+ return hs, init_reference_out, inter_references_out, enc_token_class_unflat
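A small sketch of the multi-scale bookkeeping computed in the forward above: spatial_shapes stores each level's (H, W), and level_start_index marks where each flattened level begins in the concatenated token sequence (the shapes used here are illustrative).

import torch

spatial_shapes = torch.as_tensor([(64, 64), (32, 32), (16, 16), (8, 8)],
                                 dtype=torch.long)
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
print(level_start_index.tolist())   # [0, 4096, 5120, 5376]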
+
+
+class DeformableTransformerDecoderLayer(nn.Module):
+ """ A decoder layer.
+
+ Args:
+ d_model: the channel dimension for attention [default=256]
+ d_ffn: the channel dim of point-wise FFNs [default=1024]
+ dropout: the degree of dropout used in FFNs [default=0.1]
+ activation: An activation function to use [default='relu']
+ n_levels: the number of scales for extracted features [default=4]
+ n_heads: the number of heads [default=8]
+ n_points: the number of reference points for deformable attention [default=4]
+ drop_path: the ratio of stochastic depth for decoding layers [default=0.0]
+ """
+
+ def __init__(self,
+ d_model=256,
+ d_ffn=1024,
+ dropout=0.1,
+ activation='relu',
+ n_levels=4,
+ n_heads=8,
+ n_points=4,
+ drop_path=0.):
+ super().__init__()
+
+ # [DET x PATCH] deformable cross-attention
+ self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # [DET x DET] self-attention
+ self.self_attn = nn.MultiheadAttention(
+ d_model, n_heads, dropout=dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+        # point-wise FFN after the multi-head attention
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation)
+ self.dropout3 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout4 = nn.Dropout(dropout)
+ self.norm3 = nn.LayerNorm(d_model)
+
+ # stochastic depth
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else None
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, tgt):
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout4(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward(self,
+ tgt,
+ query_pos,
+ reference_points,
+ src,
+ src_spatial_shapes,
+ level_start_index,
+ src_padding_mask=None):
+
+ # [DET] self-attention
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(
+ q.transpose(0, 1), k.transpose(0, 1),
+ tgt.transpose(0, 1))[0].transpose(0, 1)
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+
+ # Multi-scale deformable cross-attention in Eq. (1) in the ViDT paper
+ tgt2 = self.cross_attn(
+ self.with_pos_embed(tgt, query_pos), reference_points, src,
+ src_spatial_shapes, level_start_index, src_padding_mask)
+
+ if self.drop_path is None:
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ # ffn
+ tgt = self.forward_ffn(tgt)
+ else:
+ tgt = tgt + self.drop_path(self.dropout1(tgt2))
+ tgt2 = self.linear2(
+ self.dropout3(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.drop_path(self.dropout4(tgt2))
+ tgt = self.norm3(tgt)
+
+ return tgt
+
+
+class DeformableTransformerDecoder(nn.Module):
+ """ A Decoder consisting of multiple layers
+
+ Args:
+ decoder_layer: a deformable decoding layer
+ num_layers: the number of layers
+ return_intermediate: whether to return intermediate results
+ """
+
+ def __init__(self, decoder_layer, num_layers, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.return_intermediate = return_intermediate
+ # hack implementation for iterative bounding box refinement
+ self.bbox_embed = None
+ self.class_embed = None
+
+ def forward(self,
+ tgt,
+ reference_points,
+ src,
+ src_spatial_shapes,
+ src_level_start_index,
+ src_valid_ratios,
+ query_pos=None,
+ src_padding_mask=None):
+ """ The forwared step of the Deformable Decoder
+
+ Args:
+ tgt: [DET] tokens
+ reference_points: reference points for deformable attention
+ src: the [PATCH] tokens flattened into a 1-d sequence
+ src_spatial_shapes: the spatial shape of each multi-scale feature map
+ src_level_start_index: the start index to refer different scale inputs
+ src_valid_ratios: the ratio of multi-scale feature maps
+ query_pos: the pos encoding for [DET] tokens
+ src_padding_mask: the input padding mask
+
+ Returns:
+ output: [DET] tokens calibrated (i.e., object embeddings)
+ reference_points: the reference points
+
+ If return_intermediate = True, output & reference_points are returned from all decoding layers
+ """
+
+ output = tgt
+ intermediate = []
+ intermediate_reference_points = []
+
+ # iterative bounding box refinement (handling the [DET] tokens produced from Swin with RAM)
+ if self.bbox_embed is not None:
+ tmp = self.bbox_embed[0](output)
+ if reference_points.shape[-1] == 4:
+ new_reference_points = tmp + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ else:
+ assert reference_points.shape[-1] == 2
+ new_reference_points = tmp
+ new_reference_points[
+ ..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ reference_points = new_reference_points.detach()
+ #
+
+ if self.return_intermediate:
+ intermediate.append(output)
+ intermediate_reference_points.append(reference_points)
+
+ for lid, layer in enumerate(self.layers):
+ if reference_points.shape[-1] == 4:
+ tmp0 = reference_points[:, :, None]
+ tmp1 = torch.cat([src_valid_ratios, src_valid_ratios],
+ -1)[:, None]
+ reference_points_input = tmp0 * tmp1
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = reference_points[:, :,
+ None] * src_valid_ratios[:,
+ None]
+
+ # deformable operation
+ output = layer(output, query_pos, reference_points_input, src,
+ src_spatial_shapes, src_level_start_index,
+ src_padding_mask)
+
+ # hack implementation for iterative bounding box refinement
+ if self.bbox_embed is not None:
+ tmp = self.bbox_embed[lid + 1](output)
+ if reference_points.shape[-1] == 4:
+ new_reference_points = tmp + inverse_sigmoid(
+ reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ else:
+ assert reference_points.shape[-1] == 2
+ new_reference_points = tmp
+ new_reference_points[..., :2] = tmp[
+ ..., :2] + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ reference_points = new_reference_points.detach()
+ #
+
+ if self.return_intermediate:
+ intermediate.append(output)
+ intermediate_reference_points.append(reference_points)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), torch.stack(
+ intermediate_reference_points)
+
+ return output, reference_points
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+
+ if activation == 'relu':
+ return F.relu
+ if activation == 'gelu':
+ return F.gelu
+ if activation == 'glu':
+ return F.glu
+ raise RuntimeError(f'activation should be relu/gelu/glu, not {activation}.')
+
+
+def ms_deform_attn_core_pytorch(value, value_spatial_shapes,
+ sampling_locations, attention_weights):
+ # for debug and test only,
+ # need to use cuda version instead
+ N_, S_, M_, D_ = value.shape
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
+ dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(
+ N_ * M_, D_, H_, W_)
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+ sampling_grid_l_ = sampling_grids[:, :, :,
+ lid_].transpose(1, 2).flatten(0, 1)
+ # N_*M_, D_, Lq_, P_
+ sampling_value_l_ = F.grid_sample(
+ value_l_,
+ sampling_grid_l_,
+ mode='bilinear',
+ padding_mode='zeros',
+ align_corners=False)
+ sampling_value_list.append(sampling_value_l_)
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ N_ * M_, 1, Lq_, L_ * P_)
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2)
+ * attention_weights).sum(-1).view(N_, M_ * D_, Lq_)
+ return output.transpose(1, 2).contiguous()
+
+
+def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError(
+ 'invalid input for _is_power_of_2: {} (type: {})'.format(
+ n, type(n)))
+ return (n & (n - 1) == 0) and n != 0
+
+
+class MSDeformAttn(nn.Module):
+
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
+ """
+ Multi-Scale Deformable Attention Module
+ :param d_model hidden dimension
+ :param n_levels number of feature levels
+ :param n_heads number of attention heads
+ :param n_points number of sampling points per attention head per feature level
+ """
+ super().__init__()
+ if d_model % n_heads != 0:
+ raise ValueError(
+ 'd_model must be divisible by n_heads, but got {} and {}'.
+ format(d_model, n_heads))
+ _d_per_head = d_model // n_heads
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
+ if not _is_power_of_2(_d_per_head):
+ warnings.warn(
+ "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
+ 'which is more efficient in our CUDA implementation.')
+
+ self.im2col_step = 64
+
+ self.d_model = d_model
+ self.n_levels = n_levels
+ self.n_heads = n_heads
+ self.n_points = n_points
+
+ self.sampling_offsets = nn.Linear(d_model,
+ n_heads * n_levels * n_points * 2)
+ self.attention_weights = nn.Linear(d_model,
+ n_heads * n_levels * n_points)
+ self.value_proj = nn.Linear(d_model, d_model)
+ self.output_proj = nn.Linear(d_model, d_model)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ constant_(self.sampling_offsets.weight.data, 0.)
+ thetas = torch.arange(
+ self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (grid_init
+ / grid_init.abs().max(-1, keepdim=True)[0]).view(
+ self.n_heads, 1, 1, 2).repeat(1, self.n_levels,
+ self.n_points, 1)
+ for i in range(self.n_points):
+ grid_init[:, :, i, :] *= i + 1
+ with torch.no_grad():
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ constant_(self.attention_weights.weight.data, 0.)
+ constant_(self.attention_weights.bias.data, 0.)
+ xavier_uniform_(self.value_proj.weight.data)
+ constant_(self.value_proj.bias.data, 0.)
+ xavier_uniform_(self.output_proj.weight.data)
+ constant_(self.output_proj.bias.data, 0.)
+
+ def forward(self,
+ query,
+ reference_points,
+ input_flatten,
+ input_spatial_shapes,
+ input_level_start_index,
+ input_padding_mask=None):
+ """
+ :param query (N, Length_{query}, C)
+ :param reference_points (N, Length_{query}, n_levels, 2)
+ :param input_flatten (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C)
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ :param input_level_start_index (n_levels, )
+ :param input_padding_mask (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l)
+ :return output (N, Length_{query}, C)
+ """
+ N, Len_q, _ = query.shape
+ N, Len_in, _ = input_flatten.shape
+ assert (input_spatial_shapes[:, 0]
+ * input_spatial_shapes[:, 1]).sum() == Len_in
+
+ value = self.value_proj(input_flatten)
+ if input_padding_mask is not None:
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
+ value = value.view(N, Len_in, self.n_heads,
+ self.d_model // self.n_heads)
+ sampling_offsets = self.sampling_offsets(query).view(
+ N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
+ # attn weights for each sampled query.
+ attention_weights = self.attention_weights(query).view(
+ N, Len_q, self.n_heads, self.n_levels * self.n_points)
+ attention_weights = F.softmax(attention_weights,
+ -1).view(N, Len_q, self.n_heads,
+ self.n_levels, self.n_points)
+ # N, Len_q, n_heads, n_levels, n_points, 2
+
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack(
+ [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]],
+ -1)
+ tmp0 = reference_points[:, :, None, :, None, :]
+ tmp1 = sampling_offsets / offset_normalizer[None, None, None, :,
+ None, :]
+ sampling_locations = tmp0 + tmp1
+ elif reference_points.shape[-1] == 4:
+ tmp0 = reference_points[:, :, None, :, None, :2]
+ tmp1 = sampling_offsets / self.n_points * reference_points[:, :,
+ None, :,
+ None,
+ 2:] * 0.5
+ sampling_locations = tmp0 + tmp1
+ else:
+ raise ValueError(
+ 'Last dim of reference_points must be 2 or 4, but get {} instead.'
+ .format(reference_points.shape[-1]))
+ output = ms_deform_attn_core_pytorch(value, input_spatial_shapes,
+ sampling_locations,
+ attention_weights)
+ output = self.output_proj(output)
+
+ return output
+
+
+def inverse_sigmoid(x, eps=1e-5):
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
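
As a side note on the flattening logic in `DeformableTransformer.forward` above, the sketch below (not part of the patch; the feature sizes are made up) shows how the multi-scale [PATCH] tokens are concatenated into one sequence and how `level_start_index` marks where each scale begins.

```python
import torch

# Toy multi-scale feature maps (batch 2, 256 channels); sizes are illustrative.
srcs = [torch.randn(2, 256, h, w) for h, w in [(32, 32), (16, 16), (8, 8), (4, 4)]]

src_flatten, spatial_shapes = [], []
for src in srcs:
    bs, c, h, w = src.shape
    spatial_shapes.append((h, w))
    src_flatten.append(src.flatten(2).transpose(1, 2))   # (bs, h*w, c)
src_flatten = torch.cat(src_flatten, 1)                  # (bs, sum of h*w, c)

spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long)
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))

print(src_flatten.shape)     # torch.Size([2, 1360, 256])
print(level_start_index)     # tensor([   0, 1024, 1280, 1344])
```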
diff --git a/modelscope/models/cv/vidt/fpn_fusion.py b/modelscope/models/cv/vidt/fpn_fusion.py
new file mode 100644
index 00000000..b48ba0fe
--- /dev/null
+++ b/modelscope/models/cv/vidt/fpn_fusion.py
@@ -0,0 +1,248 @@
+# The implementation here is modified based on timm,
+# originally Apache 2.0 License and publicly available at
+# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/fpn_fusion.py
+
+import torch.nn as nn
+
+
+class FPNFusionModule(nn.Module):
+ """ This is a fpn-style cross-scale feature fusion module" """
+
+ def __init__(self, embed_dims, fuse_dim=256, n_block=4, use_bn=False):
+ super().__init__()
+ """ Initializes the model.
+ Args:
+ embed_dims: the list of channel dim for different scale feature maps (i.e., the input)
+ fuse_dim: the channel dim of the fused feature map (i.e., the output)
+ n_block: the number of multi-scale features (default=4)
+ use_bn: whether to use bn
+ """
+
+ self.embed_dims = embed_dims
+ self.fuse_dim = fuse_dim
+ self.n_block = n_block
+
+ # cross-scale fusion layers
+ self.multi_scaler = _make_multi_scale_layers(
+ embed_dims, fuse_dim, use_bn=use_bn, n_block=n_block)
+
+ def forward(self, x_blocks):
+
+ x_blocks = x_blocks
+
+ # preparation: channel reduction and normalization
+ for idx in range(self.n_block - 1, -1, -1):
+ x_blocks[idx] = getattr(self.multi_scaler, f'layer_{idx}_rn')(
+ x_blocks[idx])
+ x_blocks[idx] = getattr(self.multi_scaler, f'p_norm_{idx}')(
+ x_blocks[idx])
+
+ # cross-scale fusion
+ refined_embeds = []
+ for idx in range(self.n_block - 1, -1, -1):
+ if idx == self.n_block - 1:
+ path = getattr(self.multi_scaler,
+ f'refinenet_{idx}')([x_blocks[idx]], None)
+ else:
+ path = getattr(self.multi_scaler,
+ f'refinenet_{idx}')([path, x_blocks[idx]],
+ x_blocks[idx].size()[2:])
+ refined_embeds.append(path)
+
+ return refined_embeds
+
+
+def _make_multi_scale_layers(in_shape,
+ out_shape,
+ n_block=4,
+ groups=1,
+ use_bn=False):
+
+ out_shapes = [out_shape for _ in range(n_block)]
+ multi_scaler = nn.Module()
+
+ for idx in range(n_block - 1, -1, -1):
+ """
+ 1 x 1 conv for dim reduction -> group norm
+ """
+ layer_name = f'layer_{(idx)}_rn'
+ multi_scaler.add_module(
+ layer_name,
+ nn.Conv2d(in_shape[idx], out_shapes[idx], kernel_size=1))
+
+ layer_name = f'p_norm_{(idx)}'
+ multi_scaler.add_module(layer_name, nn.GroupNorm(32, out_shapes[idx]))
+
+ layer_name = f'refinenet_{idx}'
+ multi_scaler.add_module(layer_name,
+ _make_fusion_block(out_shape, use_bn))
+
+ # initialize for the 1x1 conv
+ nn.init.xavier_uniform_(
+ getattr(multi_scaler, f'layer_{idx}_rn').weight, gain=1)
+ nn.init.constant_(getattr(multi_scaler, f'layer_{idx}_rn').bias, 0)
+
+ return multi_scaler
+
+
+def _make_fusion_block(features, use_bn):
+ """ We use a resnet bottleneck structure for fpn """
+
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(False),
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ )
+
+
+class FeatureFusionBlock(nn.Module):
+ """ Feature fusion block """
+
+ def __init__(self,
+ features,
+ activation,
+ bn=False,
+ expand=False,
+ align_corners=True):
+ """Init.
+ Args:
+ features (int): channel dim of the input feature
+ activation: activation function to use
+ bn: whether to use bn
+ expand: whether to expand the feature or not
+ align_corners: whether to use align_corners for interpolation
+ """
+
+ super(FeatureFusionBlock, self).__init__()
+ self.align_corners = align_corners
+ self.groups = 1
+ self.expand = expand
+ out_features = features
+
+ if self.expand is True:
+ out_features = features // 2
+
+ self.smoothing = nn.Conv2d(
+ features,
+ out_features,
+ kernel_size=1,
+ bias=True,
+ groups=1,
+ )
+
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, xs, up_size):
+ """ Forward pass.
+ Args:
+ xs: xs[0]: the feature refined from the previous step, xs[1]: the next scale features to fuse
+ up_size: the size for upsampling; xs[0] is upsampled before merging with xs[1]
+ Returns:
+ output: the fused feature, which is fed to the next fusion step as an input
+ """
+
+ output = xs[0]
+ if len(xs) == 2:
+ # upsampling
+ output = nn.functional.interpolate(
+ output,
+ size=up_size,
+ mode='bilinear',
+ align_corners=self.align_corners)
+ # feature smoothing since the upsampled feature is coarse-grained
+ output = self.smoothing(output)
+
+ # refine the next scale feature before fusion
+ res = self.resConfUnit1(xs[1])
+
+ # fusion
+ output = self.skip_add.add(output, res)
+
+ # post refine after fusion
+ output = self.resConfUnit2(output)
+
+ return output
+
+
+class ResidualConvUnit(nn.Module):
+ """ Residual convolution module. """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+ Args:
+ features (int): channel dim of the input
+ activation: activation function
+ bn: whether to use bn
+ """
+
+ super().__init__()
+
+ self.bn = bn
+ self.groups = 1
+
+ self.conv1 = nn.Conv2d(
+ features,
+ 64,
+ kernel_size=1,
+ stride=1,
+ bias=not self.bn,
+ groups=self.groups,
+ )
+ self.conv2 = nn.Conv2d(
+ 64,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=not self.bn,
+ groups=self.groups,
+ )
+ self.conv3 = nn.Conv2d(
+ 64,
+ features,
+ kernel_size=1,
+ stride=1,
+ bias=not self.bn,
+ groups=self.groups,
+ )
+ if self.bn is True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+ self.bn3 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """ Forward pass
+
+ Args:
+ x (tensor): input feature
+
+ Returns:
+ tensor: output feature
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn is True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn is True:
+ out = self.bn2(out)
+
+ out = self.activation(out)
+ out = self.conv3(out)
+ if self.bn is True:
+ out = self.bn3(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
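
For orientation, here is a minimal usage sketch of the `FPNFusionModule` defined above. The channel dims and spatial sizes are invented for illustration, and the assumed ordering is finest scale at index 0; the import path follows the file added in this patch.

```python
import torch
from modelscope.models.cv.vidt.fpn_fusion import FPNFusionModule

# Hypothetical multi-scale features, finest (index 0) to coarsest (index 3).
embed_dims = [96, 192, 384, 768]
feats = [
    torch.randn(1, 96, 64, 64),
    torch.randn(1, 192, 32, 32),
    torch.randn(1, 384, 16, 16),
    torch.randn(1, 768, 8, 8),
]

fusion = FPNFusionModule(embed_dims, fuse_dim=256, n_block=4)
refined = fusion(feats)

# Fusion runs from the last list entry toward the first, so the refined map
# with the last entry's resolution comes out first; every output has
# fuse_dim channels.
for f in refined:
    print(f.shape)   # [1, 256, 8, 8], [1, 256, 16, 16], [1, 256, 32, 32], [1, 256, 64, 64]
```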
diff --git a/modelscope/models/cv/vidt/head.py b/modelscope/models/cv/vidt/head.py
new file mode 100644
index 00000000..28737e96
--- /dev/null
+++ b/modelscope/models/cv/vidt/head.py
@@ -0,0 +1,413 @@
+# The implementation here is modified based on timm,
+# originally Apache 2.0 License and publicly available at
+# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/detector.py
+
+import copy
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Detector(nn.Module):
+ """ This is a combination of "Swin with RAM" and a "Neck-free Deformable Decoder" """
+
+ def __init__(
+ self,
+ backbone,
+ transformer,
+ num_classes,
+ num_queries,
+ aux_loss=False,
+ with_box_refine=False,
+ # The three additional techniques for ViDT+
+ epff=None, # (1) Efficient Pyramid Feature Fusion Module
+ with_vector=False,
+ processor_dct=None,
+ vector_hidden_dim=256, # (2) UQR Module
+ iou_aware=False,
+ token_label=False, # (3) Additional losses
+ distil=False):
+ """ Initializes the model.
+ Args:
+ backbone: torch module of the backbone to be used. See backbone.py
+ transformer: torch module of the transformer architecture. See transformer.py
+ num_classes: number of object classes
+ num_queries: number of object queries (i.e., det tokens). This is the maximal number of objects
+ DETR can detect in a single image. For COCO, we recommend 100 queries.
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+ with_box_refine: iterative bounding box refinement
+ epff: None or fusion module available
+ iou_aware: True if iou_aware is to be used.
+ see the original paper https://arxiv.org/abs/1912.05992
+ token_label: True if token_label is to be used.
+ see the original paper https://arxiv.org/abs/2104.10858
+ distil: whether to use knowledge distillation with token matching
+ """
+
+ super().__init__()
+ self.num_queries = num_queries
+ self.transformer = transformer
+ hidden_dim = transformer.d_model
+ self.class_embed = nn.Linear(hidden_dim, num_classes)
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+
+ # two essential techniques used [default use]
+ self.aux_loss = aux_loss
+ self.with_box_refine = with_box_refine
+
+ # For UQR module for ViDT+
+ self.with_vector = with_vector
+ self.processor_dct = processor_dct
+ if self.with_vector:
+ print(
+ f'Training with vector_hidden_dim {vector_hidden_dim}.',
+ flush=True)
+ self.vector_embed = MLP(hidden_dim, vector_hidden_dim,
+ self.processor_dct.n_keep, 3)
+
+ # For two additional losses for ViDT+
+ self.iou_aware = iou_aware
+ self.token_label = token_label
+
+ # distillation
+ self.distil = distil
+
+ # For EPFF module for ViDT+
+ if epff is None:
+ num_backbone_outs = len(backbone.num_channels)
+ input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = backbone.num_channels[_]
+ input_proj_list.append(
+ nn.Sequential(
+ # This is 1x1 conv -> so linear layer
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ self.input_proj = nn.ModuleList(input_proj_list)
+
+ # initialize the projection layer for [PATCH] tokens
+ for proj in self.input_proj:
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
+ nn.init.constant_(proj[0].bias, 0)
+ self.fusion = None
+ else:
+ # the cross scale fusion module has its own reduction layers
+ self.fusion = epff
+
+ # channel dim reduction for [DET] tokens
+ self.tgt_proj = nn.Sequential(
+ # This is 1x1 conv -> so linear layer
+ nn.Conv2d(backbone.num_channels[-2], hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ )
+
+ # channel dim reduction for [DET] learnable pos encodings
+ self.query_pos_proj = nn.Sequential(
+ # This is 1x1 conv -> so linear layer
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ )
+
+ # initialize detection head: box regression and classification
+ prior_prob = 0.01
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
+ self.class_embed.bias.data = torch.ones(num_classes) * bias_value
+ nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
+
+ # initialize projection layer for [DET] tokens and encodings
+ nn.init.xavier_uniform_(self.tgt_proj[0].weight, gain=1)
+ nn.init.constant_(self.tgt_proj[0].bias, 0)
+ nn.init.xavier_uniform_(self.query_pos_proj[0].weight, gain=1)
+ nn.init.constant_(self.query_pos_proj[0].bias, 0)
+
+ if self.with_vector:
+ nn.init.constant_(self.vector_embed.layers[-1].weight.data, 0)
+ nn.init.constant_(self.vector_embed.layers[-1].bias.data, 0)
+
+ # the prediction is made for each decoding layer + the standalone detector (Swin with RAM)
+ num_pred = transformer.decoder.num_layers + 1
+
+ # set up all required nn.Module for additional techniques
+ if with_box_refine:
+ self.class_embed = _get_clones(self.class_embed, num_pred)
+ self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
+ nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:],
+ -2.0)
+ # hack implementation for iterative bounding box refinement
+ self.transformer.decoder.bbox_embed = self.bbox_embed
+ else:
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
+ self.class_embed = nn.ModuleList(
+ [self.class_embed for _ in range(num_pred)])
+ self.bbox_embed = nn.ModuleList(
+ [self.bbox_embed for _ in range(num_pred)])
+ self.transformer.decoder.bbox_embed = None
+
+ if self.with_vector:
+ nn.init.constant_(self.vector_embed.layers[-1].bias.data[2:], -2.0)
+ self.vector_embed = nn.ModuleList(
+ [self.vector_embed for _ in range(num_pred)])
+
+ if self.iou_aware:
+ self.iou_embed = MLP(hidden_dim, hidden_dim, 1, 3)
+ if with_box_refine:
+ self.iou_embed = _get_clones(self.iou_embed, num_pred)
+ else:
+ self.iou_embed = nn.ModuleList(
+ [self.iou_embed for _ in range(num_pred)])
+
+ def forward(self, features_0, features_1, features_2, features_3, det_tgt,
+ det_pos, mask):
+ """ The forward step of ViDT
+
+ Args:
+ The forward pass expects the following inputs:
+ - features_0: image features at scale 0 from the backbone
+ - features_1: image features at scale 1 from the backbone
+ - features_2: image features at scale 2 from the backbone
+ - features_3: image features at scale 3 from the backbone
+ - det_tgt: [DET] token features from the backbone
+ - det_pos: [DET] token position encodings
+ - mask: image padding mask
+ Returns:
+ A tuple of (out_pred_logits, out_pred_boxes):
+ - out_pred_logits: the classification logits for all queries.
+ Shape= [batch_size x num_queries x num_classes]
+ - out_pred_boxes: the normalized box coordinates for all queries, represented as
+ (center_x, center_y, width, height). These values are normalized in [0, 1],
+ relative to the size of each individual image (disregarding possible padding).
+ See PostProcess for information on how to retrieve the unnormalized bounding boxes.
+ """
+ features = [features_0, features_1, features_2, features_3]
+
+ # [DET] token and encoding projection to compact representation for the input to the Neck-free transformer
+ det_tgt = self.tgt_proj(det_tgt.unsqueeze(-1)).squeeze(-1).permute(
+ 0, 2, 1)
+ det_pos = self.query_pos_proj(
+ det_pos.unsqueeze(-1)).squeeze(-1).permute(0, 2, 1)
+
+ # [PATCH] token projection
+ shapes = []
+ for le, src in enumerate(features):
+ shapes.append(src.shape[-2:])
+
+ srcs = []
+ if self.fusion is None:
+ for le, src in enumerate(features):
+ srcs.append(self.input_proj[le](src))
+ else:
+ # EPFF (multi-scale fusion) is used if fusion is activated
+ srcs = self.fusion(features)
+
+ masks = []
+ for le, src in enumerate(srcs):
+ # resize mask
+ shapes.append(src.shape[-2:])
+ _mask = F.interpolate(
+ mask[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
+ masks.append(_mask)
+ assert mask is not None
+
+ outputs_classes = []
+ outputs_coords = []
+
+ # return the output of the neck-free decoder
+ hs, init_reference, inter_references, enc_token_class_unflat = self.transformer(
+ srcs, masks, det_tgt, det_pos)
+
+ # perform predictions via the detection head
+ for lvl in range(hs.shape[0]):
+ reference = init_reference if lvl == 0 else inter_references[lvl
+ - 1]
+ reference = inverse_sigmoid(reference)
+
+ outputs_class = self.class_embed[lvl](hs[lvl])
+ # bbox output + reference
+ tmp = self.bbox_embed[lvl](hs[lvl])
+ if reference.shape[-1] == 4:
+ tmp += reference
+ else:
+ assert reference.shape[-1] == 2
+ tmp[..., :2] += reference
+
+ outputs_coord = tmp.sigmoid()
+ outputs_classes.append(outputs_class)
+ outputs_coords.append(outputs_coord)
+
+ # stack all predictions made from each decoding layers
+ outputs_class = torch.stack(outputs_classes)
+ outputs_coord = torch.stack(outputs_coords)
+
+ outputs_vector = None
+ if self.with_vector:
+ outputs_vectors = []
+ for lvl in range(hs.shape[0]):
+ outputs_vector = self.vector_embed[lvl](hs[lvl])
+ outputs_vectors.append(outputs_vector)
+ outputs_vector = torch.stack(outputs_vectors)
+
+ # the final prediction is made by the last decoding layer
+ out = {
+ 'pred_logits': outputs_class[-1],
+ 'pred_boxes': outputs_coord[-1]
+ }
+
+ if self.with_vector:
+ out.update({'pred_vectors': outputs_vector[-1]})
+
+ # aux loss is defined by using the rest predictions
+ if self.aux_loss and self.transformer.decoder.num_layers > 0:
+ out['aux_outputs'] = self._set_aux_loss(outputs_class,
+ outputs_coord,
+ outputs_vector)
+
+ # iou awareness loss is defined for each decoding layer similar to auxiliary decoding loss
+ if self.iou_aware:
+ outputs_ious = []
+ for lvl in range(hs.shape[0]):
+ outputs_ious.append(self.iou_embed[lvl](hs[lvl]))
+ outputs_iou = torch.stack(outputs_ious)
+ out['pred_ious'] = outputs_iou[-1]
+
+ if self.aux_loss:
+ for i, aux in enumerate(out['aux_outputs']):
+ aux['pred_ious'] = outputs_iou[i]
+
+ # token label loss
+ if self.token_label:
+ out['enc_tokens'] = {'pred_logits': enc_token_class_unflat}
+
+ if self.distil:
+ # 'patch_token': multi-scale patch tokens from each stage
+ # 'body_det_token' and 'neck_det_tgt': the input det_token for multiple detection heads
+ out['distil_tokens'] = {
+ 'patch_token': srcs,
+ 'body_det_token': det_tgt,
+ 'neck_det_token': hs
+ }
+
+ out_pred_logits = out['pred_logits']
+ out_pred_boxes = out['pred_boxes']
+ return out_pred_logits, out_pred_boxes
+
+ @torch.jit.unused
+ def _set_aux_loss(self, outputs_class, outputs_coord, outputs_vector):
+ # this is a workaround to make torchscript happy, as torchscript
+ # doesn't support dictionary with non-homogeneous values, such
+ # as a dict having both a Tensor and a list.
+
+ if outputs_vector is None:
+ return [{
+ 'pred_logits': a,
+ 'pred_boxes': b
+ } for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+ else:
+ return [{
+ 'pred_logits': a,
+ 'pred_boxes': b,
+ 'pred_vectors': c
+ } for a, b, c in zip(outputs_class[:-1], outputs_coord[:-1],
+ outputs_vector[:-1])]
+
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+def inverse_sigmoid(x, eps=1e-5):
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
+
+
+def box_cxcywh_to_xyxy(x):
+ x_c, y_c, w, h = x.unbind(-1)
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ return torch.stack(b, dim=-1)
+
+
+# filter post-processed results by the score threshold (bbox_thu) into per-image [score, label, bbox] lists
+def get_predictions(post_results, bbox_thu=0.40):
+ batch_final_res = []
+ for per_img_res in post_results:
+ per_img_final_res = []
+ for i in range(len(per_img_res['scores'])):
+ score = float(per_img_res['scores'][i].cpu())
+ label = int(per_img_res['labels'][i].cpu())
+ bbox = []
+ for it in per_img_res['boxes'][i].cpu():
+ bbox.append(int(it))
+ if score >= bbox_thu:
+ per_img_final_res.append([score, label, bbox])
+ batch_final_res.append(per_img_final_res)
+ return batch_final_res
+
+
+class PostProcess(nn.Module):
+ """ This module converts the model's output into the format expected by the coco api"""
+
+ def __init__(self, processor_dct=None):
+ super().__init__()
+ # For instance segmentation using UQR module
+ self.processor_dct = processor_dct
+
+ @torch.no_grad()
+ def forward(self, out_logits, out_bbox, target_sizes):
+ """ Perform the computation
+
+ Args:
+ out_logits: raw logits outputs of the model
+ out_bbox: raw bbox outputs of the model
+ target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
+ For evaluation, this must be the original image size (before any data augmentation)
+ For visualization, this should be the image size after data augment, but before padding
+ """
+ assert len(out_logits) == len(target_sizes)
+ assert target_sizes.shape[1] == 2
+
+ prob = out_logits.sigmoid()
+ topk_values, topk_indexes = torch.topk(
+ prob.view(out_logits.shape[0], -1), 100, dim=1)
+ scores = topk_values
+ topk_boxes = topk_indexes // out_logits.shape[2]
+ labels = topk_indexes % out_logits.shape[2]
+ boxes = box_cxcywh_to_xyxy(out_bbox)
+ boxes = torch.gather(boxes, 1,
+ topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h],
+ dim=1).to(torch.float32)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = [{
+ 'scores': s,
+ 'labels': l,
+ 'boxes': b
+ } for s, l, b in zip(scores, labels, boxes)]
+
+ return results
+
+
+def _get_clones(module, N):
+ """ Clone a moudle N times """
+
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
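
To make the post-processing above concrete, the small sketch below (values invented) converts a normalized (cx, cy, w, h) box to corner format and scales it to an assumed 640x480 image, which is exactly what `PostProcess.forward` does with `box_cxcywh_to_xyxy` and `scale_fct`.

```python
import torch


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


box = torch.tensor([[0.5, 0.5, 0.2, 0.4]])       # normalized cx, cy, w, h
xyxy = box_cxcywh_to_xyxy(box)                   # tensor([[0.4, 0.3, 0.6, 0.7]])

img_w, img_h = 640., 480.                        # assumed image size
scale_fct = torch.tensor([img_w, img_h, img_w, img_h])
print(xyxy * scale_fct)                          # tensor([[256., 144., 384., 336.]])
```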
diff --git a/modelscope/models/cv/vidt/model.py b/modelscope/models/cv/vidt/model.py
new file mode 100644
index 00000000..65940637
--- /dev/null
+++ b/modelscope/models/cv/vidt/model.py
@@ -0,0 +1,98 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import os
+
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import ModelFile, Tasks
+from .backbone import SwinTransformer
+from .deformable_transformer import DeformableTransformer
+from .fpn_fusion import FPNFusionModule
+from .head import Detector
+
+
+@MODELS.register_module(Tasks.image_object_detection, module_name=Models.vidt)
+class VidtModel(TorchModel):
+ """
+ The implementation of 'ViDT for joint-learning of object detection and instance segmentation'.
+ This model is dynamically initialized with the following parts:
+ - 'backbone': pre-trained backbone model with parameters.
+ - 'head': detection and segmentation head with fine-tuning.
+ """
+
+ def __init__(self, model_dir: str, **kwargs):
+ """ Initialize a Vidt Model.
+ Args:
+ model_dir: model id or path, where model_dir/pytorch_model.pt contains:
+ - 'backbone_weights': parameters of backbone.
+ - 'head_weights': parameters of head.
+ """
+ super(VidtModel, self).__init__()
+
+ model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
+ model_dict = torch.load(model_path, map_location='cpu')
+
+ # build backbone
+ backbone = SwinTransformer(
+ pretrain_img_size=[224, 224],
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ drop_path_rate=0.2)
+ backbone.finetune_det(
+ method='vidt', det_token_num=300, pos_dim=256, cross_indices=[3])
+ self.backbone = backbone
+ self.backbone.load_state_dict(
+ model_dict['backbone_weights'], strict=True)
+
+ # build head
+ epff = FPNFusionModule(backbone.num_channels, fuse_dim=256)
+ deform_transformers = DeformableTransformer(
+ d_model=256,
+ nhead=8,
+ num_decoder_layers=6,
+ dim_feedforward=1024,
+ dropout=0.1,
+ activation='relu',
+ return_intermediate_dec=True,
+ num_feature_levels=4,
+ dec_n_points=4,
+ token_label=False)
+ head = Detector(
+ backbone,
+ deform_transformers,
+ num_classes=2,
+ num_queries=300,
+ # two essential techniques used in ViDT
+ aux_loss=True,
+ with_box_refine=True,
+ # an epff module for ViDT+
+ epff=epff,
+ # an UQR module for ViDT+
+ with_vector=False,
+ processor_dct=None,
+ # two additional losses for VIDT+
+ iou_aware=True,
+ token_label=False,
+ vector_hidden_dim=256,
+ # distil
+ distil=False)
+ self.head = head
+ self.head.load_state_dict(model_dict['head_weights'], strict=True)
+
+ def forward(self, x, mask):
+ """ Dynamic forward function of VidtModel.
+ Args:
+ x: input images (B, 3, H, W)
+ mask: input padding masks (B, H, W)
+ """
+ features_0, features_1, features_2, features_3, det_tgt, det_pos = self.backbone(
+ x, mask)
+ out_pred_logits, out_pred_boxes = self.head(features_0, features_1,
+ features_2, features_3,
+ det_tgt, det_pos, mask)
+ return out_pred_logits, out_pred_boxes
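
A hypothetical end-to-end usage sketch of `VidtModel` follows. The model directory path is a placeholder; it must contain `pytorch_model.pt` with `backbone_weights` and `head_weights` as described above, and the printed shapes assume the 300 [DET] queries and 2 classes configured in `__init__`. The input size is illustrative only.

```python
import torch
from modelscope.models.cv.vidt.model import VidtModel

# placeholder path; must hold pytorch_model.pt with 'backbone_weights'
# and 'head_weights' entries
model = VidtModel(model_dir='/path/to/vidt_model')
model.eval()

images = torch.randn(1, 3, 640, 640)                  # padded image batch
mask = torch.zeros(1, 640, 640, dtype=torch.bool)     # True marks padded pixels

with torch.no_grad():
    pred_logits, pred_boxes = model(images, mask)

print(pred_logits.shape)   # expected: torch.Size([1, 300, 2])
print(pred_boxes.shape)    # expected: torch.Size([1, 300, 4])
```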
diff --git a/modelscope/models/cv/vision_efficient_tuning/__init__.py b/modelscope/models/cv/vision_efficient_tuning/__init__.py
index 05243554..80128f62 100644
--- a/modelscope/models/cv/vision_efficient_tuning/__init__.py
+++ b/modelscope/models/cv/vision_efficient_tuning/__init__.py
@@ -5,18 +5,11 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
- from .vision_efficient_tuning_adapter import VisionEfficientTuningAdapterModel
- from .vision_efficient_tuning_prompt import VisionEfficientTuningPromptModel
- from .vision_efficient_tuning_prefix import VisionEfficientTuningPrefixModel
- from .vision_efficient_tuning_lora import VisionEfficientTuningLoRAModel
+ from .model import VisionEfficientTuningModel
else:
_import_structure = {
- 'vision_efficient_tuning_adapter':
- ['VisionEfficientTuningAdapterModel'],
- 'vision_efficient_tuning_prompt': ['VisionEfficientTuningPromptModel'],
- 'vision_efficient_tuning_prefix': ['VisionEfficientTuningPrefixModel'],
- 'vision_efficient_tuning_lora': ['VisionEfficientTuningLoRAModel'],
+ 'model': ['VisionEfficientTuningModel'],
}
import sys
diff --git a/modelscope/models/cv/vision_efficient_tuning/backbone.py b/modelscope/models/cv/vision_efficient_tuning/backbone.py
index e7556ea1..691e4440 100644
--- a/modelscope/models/cv/vision_efficient_tuning/backbone.py
+++ b/modelscope/models/cv/vision_efficient_tuning/backbone.py
@@ -7,9 +7,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from .petl import Adapter, LoRA, Prefix, Prompt
+from .petl import Adapter, LoRA, Prefix, Prompt, SideTune
from .timm_vision_transformer import (Attention, Block, DropPath, LayerScale,
- Mlp, PatchEmbed, VisionTransformer)
+ Mlp, PatchEmbed, VisionTransformer,
+ checkpoint_seq)
class AttentionPETL(nn.Module):
@@ -212,40 +213,74 @@ class VisionTransformerPETL(VisionTransformer):
The implementation of several tuning methods (prompt, prefix, adapter, and LoRA) based on ViT.
"""
- def __init__(
- self,
- img_size=224,
- patch_size=16,
- in_chans=3,
- num_classes=1000,
- global_pool='token',
- embed_dim=768,
- depth=12,
- num_heads=12,
- mlp_ratio=4.,
- qkv_bias=True,
- init_values=None,
- class_token=True,
- no_embed_class=False,
- pre_norm=False,
- fc_norm=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.,
- weight_init='',
- embed_layer=PatchEmbed,
- norm_layer=None,
- act_layer=None,
- block_fn=Block,
- prompt_length=None,
- prompt_type=None,
- prefix_length=None,
- prefix_type=None,
- adapter_length=None,
- adapter_type=None,
- lora_length=None,
- lora_type=None,
- ):
+ def __init__(self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ num_classes=1000,
+ global_pool='token',
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ init_values=None,
+ class_token=True,
+ no_embed_class=False,
+ pre_norm=False,
+ fc_norm=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ weight_init='',
+ embed_layer=PatchEmbed,
+ norm_layer=None,
+ act_layer=None,
+ block_fn=Block,
+ prompt_length=None,
+ prompt_type=None,
+ prefix_length=None,
+ prefix_type=None,
+ adapter_length=None,
+ adapter_type=None,
+ lora_length=None,
+ lora_type=None,
+ sidetune_length=None,
+ sidetune_type=None):
+ """ Initialize a Parameter-efficient Transfer Learning Method based on Vision Transformer.
+
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ global_pool (str): type of global pooling for final sequence (default: 'token')
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ init_values: (float): layer-scale init values
+ class_token (bool): use class token
+ fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ weight_init (str): weight init scheme
+ embed_layer (nn.Module): patch embedding layer
+ norm_layer: (nn.Module): normalization layer
+ act_layer: (nn.Module): MLP activation layer
+ prompt_length: An integer indicating the length of prompt tuning.
+ prompt_type: A string indicating the type of prompt tuning.
+ prefix_length: An integer indicating the length of prefix tuning.
+ prefix_type: A string indicating the type of prefix tuning.
+ adapter_length: An integer indicating the length of adapter tuning.
+ adapter_type: A string indicating the type of adapter tuning.
+ lora_length: An integer indicating the length of LoRA tuning.
+ lora_type: A string indicating the type of LoRA tuning.
+ sidetune_length: An integer indicating the linear dimension.
+ sidetune_type: A string indicating the type of side network.
+ """
super().__init__()
assert global_pool in ('', 'avg', 'token')
@@ -349,3 +384,49 @@ class VisionTransformerPETL(VisionTransformer):
if weight_init != 'skip':
self.init_weights(weight_init)
+
+ if sidetune_type is not None:
+ self.sidetune = SideTune(sidetune_length, sidetune_type)
+ else:
+ self.sidetune = None
+
+ def forward_features(self, x):
+ """ feature forward function of VisionTransformer.
+
+ Args:
+ x (Tensor): the input data.
+ Returns:
+ res (Dict): the output data, contains:
+ - inputs: the original input.
+ - x: the intermediate feature.
+ """
+ res = dict(inputs=x)
+ x = self.patch_embed(x)
+ x = self._pos_embed(x)
+ x = self.norm_pre(x)
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint_seq(self.blocks, x)
+ else:
+ x = self.blocks(x)
+ x = self.norm(x)
+ res['x'] = x
+ return res
+
+ def forward_head(self, res, pre_logits: bool = False):
+ """ head forward function of VisionTransformer.
+
+ Args:
+ res (Dict): the input data, contains:
+ - inputs: the original input.
+ - x: the intermediate feature.
+ Returns:
+ x (Tensor): the output data.
+ """
+ x = res['x']
+ if self.global_pool:
+ x = x[:, self.num_prefix_tokens:].mean(
+ dim=1) if self.global_pool == 'avg' else x[:, 0]
+ if self.sidetune and 'inputs' in res:
+ x = self.sidetune(res['inputs'], x)
+ x = self.fc_norm(x)
+ return x if pre_logits else self.head(x)
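
The two `global_pool` branches in `forward_head` above reduce the token sequence to a single feature vector. A tiny sketch with made-up sizes (one class token plus 196 patch tokens, dim 768):

```python
import torch

x = torch.randn(2, 197, 768)     # [CLS] token followed by 14x14 patch tokens
num_prefix_tokens = 1

cls_feat = x[:, 0]                                  # global_pool == 'token'
avg_feat = x[:, num_prefix_tokens:].mean(dim=1)     # global_pool == 'avg'

print(cls_feat.shape, avg_feat.shape)               # torch.Size([2, 768]) for both
```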
diff --git a/modelscope/models/cv/vision_efficient_tuning/model.py b/modelscope/models/cv/vision_efficient_tuning/model.py
new file mode 100644
index 00000000..49b50272
--- /dev/null
+++ b/modelscope/models/cv/vision_efficient_tuning/model.py
@@ -0,0 +1,49 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+from typing import Any, Dict
+
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import Tasks
+from .vision_efficient_tuning import VisionEfficientTuning
+
+
+@MODELS.register_module(
+ Tasks.vision_efficient_tuning, module_name=Models.vision_efficient_tuning)
+class VisionEfficientTuningModel(TorchModel):
+ """ The implementation of vision efficient tuning model based on TorchModel.
+
+ This model is constructed with the following parts:
+ - 'backbone': pre-trained backbone model with parameters.
+ - 'head': classification head with fine-tuning.
+ """
+
+ def __init__(self, model_dir: str, **kwargs):
+ """ Initialize a vision efficient tuning model.
+
+ Args:
+ model_dir: model id or path, where model_dir/pytorch_model.pt contains:
+ - 'backbone_weight': parameters of backbone.
+ - 'head_weight': parameters of head.
+ """
+ super().__init__(model_dir)
+
+ self.model = VisionEfficientTuning(model_dir=model_dir, **kwargs)
+ self.CLASSES = self.model.CLASSES
+
+ self.device = torch.device(
+ 'cuda' if torch.cuda.is_available() else 'cpu')
+ self.model.to(self.device)
+
+ def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ """ Dynamic forward function of vision efficient tuning model.
+
+ Args:
+ input: the input data dict contains:
+ - imgs: (B, 3, H, W).
+ - labels: (B), when training stage.
+ """
+ output = self.model(**input)
+ return output
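
A hypothetical usage sketch of `VisionEfficientTuningModel`; the model directory is a placeholder and must contain `pytorch_model.pt` with the `backbone_weight` and `head_weight` entries documented above.

```python
import torch
from modelscope.models.cv.vision_efficient_tuning.model import VisionEfficientTuningModel

model = VisionEfficientTuningModel(model_dir='/path/to/petl_model')  # placeholder path
model.eval()                           # route the wrapped forward() to forward_test

batch = {'imgs': torch.randn(4, 3, 224, 224).to(model.device)}
with torch.no_grad():
    out = model(batch)

# out contains OutputKeys.SCORES (softmax scores) and OutputKeys.LABELS (top-1 indices)
print({k: v.shape for k, v in out.items()})
```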
diff --git a/modelscope/models/cv/vision_efficient_tuning/petl.py b/modelscope/models/cv/vision_efficient_tuning/petl.py
index f43ba10b..b92112b6 100644
--- a/modelscope/models/cv/vision_efficient_tuning/petl.py
+++ b/modelscope/models/cv/vision_efficient_tuning/petl.py
@@ -1,8 +1,10 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import math
+from collections import OrderedDict
import torch
import torch.nn as nn
+import torchvision
class Prompt(nn.Module):
@@ -172,3 +174,101 @@ class Prefix(nn.Module):
k, v = torch.cat((k, prefix_key), dim=2), torch.cat((v, prefix_value),
dim=2)
return q, k, v
+
+
+class SideTune(nn.Module):
+ """The implementation of vision side-tuning method.
+
+ Side-Tuning only needs to train one side network and
+ weights the output of pre-trained model and side network.
+ 'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks'
+ by Zhang et al. (2019)
+ See https://arxiv.org/abs/1912.13503
+
+ Attributes:
+ sidetune_length: An integer indicating the linear dimension.
+ sidetune_type: A string indicating the type of side network.
+ """
+
+ def __init__(self, sidetune_length=None, sidetune_type=None):
+ super(SideTune, self).__init__()
+ self.sidetune_length = sidetune_length
+ self.sidetune_type = sidetune_type
+ if sidetune_type.lower() == 'fcn4':
+ self.side = FCN4(out_dims=self.sidetune_length)
+ if sidetune_type.lower() == 'alexnet':
+ mm = torchvision.models.alexnet(pretrained=True)
+ self.side = nn.Sequential(
+ OrderedDict([
+ ('features', mm.features), ('avgpool', mm.avgpool),
+ ('flatten', nn.Flatten()),
+ ('fc', nn.Linear(9216, self.sidetune_length, bias=False))
+ ]))
+ self.alpha = nn.Parameter(torch.tensor(0.0))
+
+ def forward(self, x, x_base):
+ alpha_squashed = torch.sigmoid(self.alpha)
+ x_side = self.side(x)
+ x_out = alpha_squashed * x_base + (1 - alpha_squashed) * x_side
+ return x_out
+
+
+class FCN4(nn.Module):
+ """The implementation of simple FCN4 network for side network.
+ """
+
+ def __init__(self, out_dims=-1, **kwargs):
+ super(FCN4, self).__init__(**kwargs)
+
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(
+ 3,
+ 16,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ dilation=1), nn.GroupNorm(2, 16), nn.ReLU())
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(
+ 16,
+ 16,
+ kernel_size=3,
+ stride=2,
+ padding=0,
+ bias=False,
+ dilation=1), nn.GroupNorm(2, 16), nn.ReLU())
+ self.conv3 = nn.Sequential(
+ nn.Conv2d(
+ 16,
+ 32,
+ kernel_size=3,
+ stride=2,
+ padding=0,
+ bias=False,
+ dilation=1), nn.GroupNorm(2, 32), nn.ReLU())
+ self.conv4 = nn.Sequential(
+ nn.Conv2d(
+ 32,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=0,
+ bias=False,
+ dilation=1), nn.GroupNorm(2, 64), nn.ReLU())
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
+ if out_dims > 0:
+ self.fc = nn.Linear(64, out_dims)
+ else:
+ self.fc = None
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ x = self.conv4(x)
+ x = self.pool(x)
+ x = x.view(x.size(0), -1)
+ if self.fc is not None:
+ x = self.fc(x)
+ return x
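
A minimal sketch of the `SideTune` module added above: it runs the small FCN4 side network on the raw image and blends its output with a (here randomly faked) backbone feature via the learned sigmoid-squashed alpha. Shapes are made up for illustration.

```python
import torch
from modelscope.models.cv.vision_efficient_tuning.petl import SideTune

side = SideTune(sidetune_length=512, sidetune_type='fcn4')

imgs = torch.randn(2, 3, 224, 224)      # raw inputs for the side network
backbone_feat = torch.randn(2, 512)     # stand-in for the frozen ViT feature

fused = side(imgs, backbone_feat)       # sigmoid(alpha)-weighted blend
print(fused.shape)                      # torch.Size([2, 512])
```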
diff --git a/modelscope/models/cv/vision_efficient_tuning/vision_efficient_tuning.py b/modelscope/models/cv/vision_efficient_tuning/vision_efficient_tuning.py
index 629e7fac..03d1ae14 100644
--- a/modelscope/models/cv/vision_efficient_tuning/vision_efficient_tuning.py
+++ b/modelscope/models/cv/vision_efficient_tuning/vision_efficient_tuning.py
@@ -1,65 +1,154 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
+from collections import OrderedDict
import torch
+import torch.nn as nn
+import torch.nn.functional as F
-from modelscope.metainfo import Models
-from modelscope.models.base.base_torch_model import TorchModel
-from modelscope.models.builder import MODELS
-from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import ModelFile
-@MODELS.register_module(
- Tasks.vision_efficient_tuning, module_name=Models.vision_efficient_tuning)
-class VisionEfficientTuningModel(TorchModel):
+class VisionEfficientTuning(nn.Module):
""" The implementation of vision efficient tuning.
This model is constructed with the following parts:
- 'backbone': pre-trained backbone model with parameters.
- 'head': classification head with fine-tuning.
+ - 'loss': loss function for training.
"""
- def __init__(self, model_dir: str, **kwargs):
+ def __init__(self,
+ backbone=None,
+ head=None,
+ loss=None,
+ pretrained=True,
+ finetune=False,
+ **kwargs):
""" Initialize a vision efficient tuning model.
Args:
- model_dir: model id or path, where model_dir/pytorch_model.pt contains:
- - 'backbone_cfg': config of backbone.
- - 'backbone_weight': parameters of backbone.
- - 'head_cfg': config of head.
- - 'head_weight': parameters of head.
- - 'CLASSES': list of label name.
+ backbone: config of backbone.
+ head: config of head.
+ loss: config of loss.
+ pretrained: whether to load the pretrained model.
+ finetune: whether to finetune the model.
"""
-
from .backbone import VisionTransformerPETL
from .head import ClassifierHead
- super().__init__(model_dir)
- model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
- model_dict = torch.load(model_path)
+ super(VisionEfficientTuning, self).__init__()
- backbone_cfg = model_dict['backbone_cfg']
- if 'type' in backbone_cfg:
- backbone_cfg.pop('type')
- self.backbone_model = VisionTransformerPETL(**backbone_cfg)
- self.backbone_model.load_state_dict(
- model_dict['backbone_weight'], strict=True)
+ if backbone and 'type' in backbone:
+ backbone.pop('type')
+ self.backbone = VisionTransformerPETL(**backbone)
+ else:
+ self.backbone = None
- head_cfg = model_dict['head_cfg']
- if 'type' in head_cfg:
- head_cfg.pop('type')
- self.head_model = ClassifierHead(**head_cfg)
- self.head_model.load_state_dict(model_dict['head_weight'], strict=True)
+ # TODO Use a more elegant method to build the model.
+ if head and 'type' in head:
+ head.pop('type')
+ self.head = ClassifierHead(**head)
+ else:
+ self.head = None
- self.CLASSES = model_dict['CLASSES']
+ if loss and 'type' in loss:
+ self.loss = getattr(torch.nn, loss['type'])()
+ else:
+ self.loss = torch.nn.CrossEntropyLoss()
- def forward(self, inputs):
+ self.CLASSES = kwargs.pop('CLASSES', None)
+ self.pretrained_cfg = kwargs.pop('pretrained_cfg', None)
+
+ if pretrained:
+ assert 'model_dir' in kwargs, 'pretrained model dir is missing.'
+ model_path = os.path.join(kwargs['model_dir'],
+ ModelFile.TORCH_MODEL_FILE)
+ model_dict = torch.load(model_path, map_location='cpu')
+
+ if self.backbone is None and 'backbone_cfg' in model_dict:
+ model_dict['backbone_cfg'].pop('type')
+ self.backbone = VisionTransformerPETL(
+ **model_dict['backbone_cfg'])
+ if self.head is None and 'head_cfg' in model_dict:
+ model_dict['head_cfg'].pop('type')
+ self.head = ClassifierHead(**model_dict['head_cfg'])
+
+ if 'backbone_weight' in model_dict:
+ backbone_weight = model_dict['backbone_weight']
+ if finetune and self.pretrained_cfg and 'unload_part' in self.pretrained_cfg \
+ and 'backbone' in self.pretrained_cfg['unload_part']:
+ backbone_weight = self.filter_weight(
+ backbone_weight,
+ self.pretrained_cfg['unload_part']['backbone'])
+ self.backbone.load_state_dict(backbone_weight, strict=False)
+
+ if 'head_weight' in model_dict:
+ head_weight = model_dict['head_weight']
+ if finetune and self.pretrained_cfg and 'unload_part' in self.pretrained_cfg \
+ and 'head' in self.pretrained_cfg['unload_part']:
+ head_weight = self.filter_weight(
+ head_weight,
+ self.pretrained_cfg['unload_part']['head'])
+ self.head.load_state_dict(head_weight, strict=False)
+
+ self.CLASSES = model_dict[
+ 'CLASSES'] if 'CLASSES' in model_dict else self.CLASSES
+
+ def filter_weight(self, weights, unload_part=[]):
+ """ Filter parameters that the model does not need to load.
+
+ Args:
+ weights: the parameters of the model.
+ unload_part: the config of unloading parameters.
+ """
+ ret_dict = {}
+ for key, value in weights.items():
+ flag = sum([p in key for p in unload_part]) > 0
+ if not flag:
+ ret_dict[key] = value
+ return ret_dict
+
+ def forward(self, imgs, labels=None, **kwargs):
""" Dynamic forward function of vision efficient tuning.
Args:
- inputs: the input images (B, 3, H, W).
+ imgs: (B, 3, H, W).
+ labels: (B), when training stage.
"""
+ return self.forward_train(imgs, labels, **kwargs) \
+ if self.training else self.forward_test(imgs, labels, **kwargs)
- backbone_output = self.backbone_model(inputs)
- head_output = self.head_model(backbone_output)
- return head_output
+ def forward_train(self, imgs, labels=None):
+ """ Dynamic forward function of training stage.
+
+ Args:
+ imgs: (B, 3, H, W).
+ labels: (B), when training stage.
+ """
+ output = OrderedDict()
+
+ backbone_output = self.backbone(imgs)
+ head_output = self.head(backbone_output)
+ loss = self.loss(head_output, labels)
+
+ output = {OutputKeys.LOSS: loss}
+ return output
+
+ def forward_test(self, imgs, labels=None):
+ """ Dynamic forward function of testing stage.
+
+ Args:
+ imgs: (B, 3, H, W).
+ labels: (B), when training stage.
+ """
+ output = OrderedDict()
+ backbone_output = self.backbone(imgs)
+ head_output = self.head(backbone_output)
+
+ scores = F.softmax(head_output, dim=1)
+ preds = scores.topk(1, 1, True, True)[-1].squeeze(-1)
+
+ output = {OutputKeys.SCORES: scores, OutputKeys.LABELS: preds}
+ return output
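
For clarity, the prediction step at the end of `forward_test` boils down to a softmax followed by a top-1 lookup; a tiny sketch with made-up logits:

```python
import torch
import torch.nn.functional as F

head_output = torch.tensor([[2.0, 0.5, -1.0]])       # fake logits for 3 classes
scores = F.softmax(head_output, dim=1)
preds = scores.topk(1, 1, True, True)[-1].squeeze(-1)

print(preds)          # tensor([0]) -> predicted class index
print(scores.sum())   # ~1.0, per-image probabilities sum to one
```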
diff --git a/modelscope/models/cv/vop_retrieval/__init__.py b/modelscope/models/cv/vop_retrieval/__init__.py
index 5b3e762c..e3708334 100644
--- a/modelscope/models/cv/vop_retrieval/__init__.py
+++ b/modelscope/models/cv/vop_retrieval/__init__.py
@@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .basic_utils import set_seed, get_state_dict, load_data, init_transform_dict, load_frames_from_video
from .model import VoP
+ from .model_se import VideoTextRetrievalModelSeries
from .tokenization_clip import LengthAdaptiveTokenizer
else:
_import_structure = {
@@ -14,6 +15,7 @@ else:
'load_frames_from_video'
],
'model': ['VoP'],
+ 'model_se': ['VideoTextRetrievalModelSeries'],
'tokenization_clip': ['LengthAdaptiveTokenizer']
}
diff --git a/modelscope/models/cv/vop_retrieval/model_se.py b/modelscope/models/cv/vop_retrieval/model_se.py
new file mode 100644
index 00000000..c96aa88e
--- /dev/null
+++ b/modelscope/models/cv/vop_retrieval/model_se.py
@@ -0,0 +1,156 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import os
+import os.path as osp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modelscope.metainfo import Models
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from .backbone import load_clip
+from .basic_utils import get_state_dict, set_seed
+
+
+@MODELS.register_module(
+ Tasks.vop_retrieval, module_name=Models.vop_retrieval_model_se)
+class VideoTextRetrievalModelSeries(TorchModel):
+ """
+ The implementation of 'VoP: Text-Video Co-operative Prompt Tuning for Cross-Modal Retrieval'.
+ This model is dynamically initialized with the following parts:
+ - clip: the upstream pre-trained backbone model (CLIP in this code).
+ - The pretrain param (ViT-B/32) downloads from OpenAI:
+ - "https://openaipublic.azureedge.net/clip/models/
+ - 40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
+ - pool_frames: the frames pooling method
+ - visual_prompt_learner: visual prompt
+ - ImageEncoder: get image encoder
+ - TextPromptLearner: text prompt
+ - TextEncoder: get text encoder
+ """
+
+ def __init__(self, model_dir: str, *args, **kwargs):
+ """
+ Initialize a VoP Model
+
+ Args:
+ model_dir: model id or path,
+ """
+ super(VideoTextRetrievalModelSeries, self).__init__()
+ model_path = osp.join(model_dir, 'VoPSE_msrvtt9k.pth')
+ clip_arch = osp.join(model_dir, 'ViT-B-32.pt')
+ config_path = osp.join(model_dir, ModelFile.CONFIGURATION)
+
+ self.config = Config.from_file(config_path).hyperparam
+ self.clip = load_clip(name=clip_arch)
+
+ self.pool_frames = BaselinePooling(self.config.pooling_type)
+
+ # load param from pre-train model
+ self.load_state_dict(get_state_dict(model_path))
+
+ # eval model
+ self.eval()
+
+ def get_video_features(self, videos, return_all_frames=False):
+ """
+ Get video Features
+
+ Args:
+ videos: the dim is [1, 12, 3, 224, 224]
+ return_all_frames: default False
+ """
+ batch_size = videos.shape[0]
+ video_data = videos.reshape(-1, 3, self.config.input_res,
+ self.config.input_res)
+
+ video_features = self.clip.encode_image(video_data)
+
+ video_features = video_features / video_features.norm(
+ dim=-1, keepdim=True)
+ video_features = video_features.reshape(batch_size,
+ self.config.num_frames, -1)
+
+ video_features_pooled = self.pool_frames(video_features)
+
+ if return_all_frames:
+ return video_features, video_features_pooled
+
+ return video_features_pooled
+
+ def get_text_features(self, text_data):
+ """
+ Get Text Features
+
+ Args:
+ text_data: the dim is [1, 69]
+ """
+ text_features = self.clip.encode_text(text_data)
+
+ text_features = text_features / text_features.norm(
+ dim=-1, keepdim=True)
+ return text_features
+
+ def forward(self, data, return_all_frames=False):
+ """
+ Dynamic Forward Function of VoP
+
+ Args:
+ data: the input data
+ return_all_frames: default False
+ """
+ batch_size = data['video'].shape[0]
+ text_data = data['text']
+ video_data = data['video']
+ video_data = video_data.reshape(-1, 3, self.config.input_res,
+ self.config.input_res)
+
+ text_features = self.clip.encode_text(text_data)
+ video_features = self.clip.encode_image(video_data)
+
+ text_features = text_features / text_features.norm(
+ dim=-1, keepdim=True)
+ video_features = video_features / video_features.norm(
+ dim=-1, keepdim=True)
+ video_features = video_features.reshape(batch_size,
+ self.config.num_frames, -1)
+
+ video_features_pooled = self.pool_frames(video_features)
+
+ if return_all_frames:
+ return text_features, video_features, video_features_pooled
+
+ return text_features, video_features_pooled
+
+
+class BaselinePooling(TorchModel):
+ """
+    Baseline frame pooling (average pooling over the temporal dimension).
+ """
+
+ def __init__(self, pooling_type):
+ super(BaselinePooling, self).__init__()
+ if pooling_type == 'avg':
+ self.pooling_func = self._avg_pooling
+ else:
+ raise NotImplementedError
+
+ def _avg_pooling(self, video_embeds):
+ """
+        Average the frame embeddings over the temporal dimension.
+
+        Args:
+            video_embeds: the input video embeddings, e.g. of shape [1, 12, 512].
+
+        Returns:
+            video_embeds_pooled: a tensor of shape num_vids x embed_dim.
+ """
+ video_embeds_pooled = video_embeds.mean(dim=1)
+ return video_embeds_pooled
+
+ def forward(self, video_embeds):
+ return self.pooling_func(video_embeds)
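+
+
+# A minimal usage sketch (illustration only, not exercised by modelscope): the
+# baseline pooling simply averages the frame embeddings over the temporal axis.
+def _baseline_pooling_sketch():
+    import torch
+    frames = torch.randn(2, 12, 512)  # [num_vids, num_frames, embed_dim]
+    pooled = BaselinePooling('avg').forward(frames)
+    assert pooled.shape == (2, 512)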
diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py
index 4edf6212..8bf9f018 100644
--- a/modelscope/models/multi_modal/__init__.py
+++ b/modelscope/models/multi_modal/__init__.py
@@ -7,6 +7,7 @@ if TYPE_CHECKING:
from .clip import CLIPForMultiModalEmbedding
from .gemm import GEMMForMultiModalEmbedding
+ from .rleg import RLEGForMultiModalEmbedding
from .team import TEAMForMultiModalSimilarity
from .diffusion import DiffusionForTextToImageSynthesis
from .mmr import VideoCLIPForMultiModalEmbedding
@@ -17,12 +18,14 @@ if TYPE_CHECKING:
from .multi_stage_diffusion import \
MultiStageDiffusionForTextToImageSynthesis
from .vldoc import VLDocForDocVLEmbedding
+ from .video_synthesis import TextToVideoSynthesis
else:
_import_structure = {
'clip': ['CLIPForMultiModalEmbedding'],
'diffusion': ['DiffusionForTextToImageSynthesis'],
'gemm': ['GEMMForMultiModalEmbedding'],
+ 'rleg': ['RLEGForMultiModalEmbedding'],
'team': ['TEAMForMultiModalSimilarity'],
'mmr': ['VideoCLIPForMultiModalEmbedding'],
'mplug_for_all_tasks': ['MPlugForAllTasks', 'HiTeAForAllTasks'],
@@ -32,6 +35,7 @@ else:
'multi_stage_diffusion':
['MultiStageDiffusionForTextToImageSynthesis'],
'vldoc': ['VLDocForDocVLEmbedding'],
+ 'video_synthesis': ['TextToVideoSynthesis'],
}
import sys
diff --git a/modelscope/models/multi_modal/guided_diffusion/__init__.py b/modelscope/models/multi_modal/guided_diffusion/__init__.py
new file mode 100644
index 00000000..93d0ca51
--- /dev/null
+++ b/modelscope/models/multi_modal/guided_diffusion/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .unet import HFUNetModel
+ from .script import create_diffusion
+else:
+ _import_structure = {
+ 'unet': ['HFUNetModel'],
+ 'script': ['create_diffusion']
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/multi_modal/guided_diffusion/gaussian_diffusion.py b/modelscope/models/multi_modal/guided_diffusion/gaussian_diffusion.py
new file mode 100644
index 00000000..430aa378
--- /dev/null
+++ b/modelscope/models/multi_modal/guided_diffusion/gaussian_diffusion.py
@@ -0,0 +1,930 @@
+# This code is borrowed and modified from Guided Diffusion Model,
+# made publicly available under MIT license
+# at https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/main/fengshen/examples/disco_project
+
+import enum
+import math
+
+import numpy as np
+import torch as th
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == 'linear':
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ beta_start = scale * 0.0001
+ beta_end = scale * 0.02
+ return np.linspace(
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+ elif schedule_name == 'cosine':
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2,
+ )
+ else:
+ raise NotImplementedError(f'unknown beta schedule: {schedule_name}')
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
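+
+
+# Illustrative sanity check (not called anywhere): both named schedules yield a
+# 1-D beta array whose cumulative alpha product decays monotonically.
+def _check_named_schedules(num_steps=1000):
+    for name in ('linear', 'cosine'):
+        betas = get_named_beta_schedule(name, num_steps)
+        assert betas.shape == (num_steps, )
+        alphas_cumprod = np.cumprod(1.0 - betas)
+        assert np.all(np.diff(alphas_cumprod) < 0)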
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+
+ Ported directly from here, and then adapted over time to further experimentation.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
+ :param model_var_type: a ModelVarType determining how variance is output.
+ :param loss_type: a LossType determining the loss function to use.
+ :param rescale_timesteps: if True, pass floating point timesteps into the
+ model so that they are always scaled like in the
+ original paper (0 to 1000).
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type,
+ rescale_timesteps=False,
+ ):
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+ self.rescale_timesteps = rescale_timesteps
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, 'betas must be 1-D'
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps, )
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod
+ - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ v1 = betas * (1.0 - self.alphas_cumprod_prev)
+ v2 = 1.0 - self.alphas_cumprod
+ self.posterior_variance = v1 / v2
+ # log calculation clipped because the posterior variance is 0 at the
+ # beginning of the diffusion chain.
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:]))
+
+ v1 = betas * np.sqrt(self.alphas_cumprod_prev)
+ v2 = 1.0 - self.alphas_cumprod
+ self.posterior_mean_coef1 = v1 / v2
+
+ v1 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas)
+ v2 = 1.0 - self.alphas_cumprod
+ self.posterior_mean_coef2 = v1 / v2
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
+ * x_start)
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t,
+ x_start.shape)
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod,
+ t, x_start.shape)
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+
+ In other words, sample from q(x_t | x_0).
+
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
+ * x_start + _extract_into_tensor(
+ self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+
+ q(x_{t-1} | x_t, x_0)
+
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape)
+ * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape)
+ * x_t)
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t,
+ x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape)
+ assert posterior_mean.shape[0] == posterior_variance.shape[0]
+ assert posterior_mean.shape[0] == posterior_log_variance_clipped.shape[
+ 0]
+ assert posterior_mean.shape[0] == x_start.shape[0]
+
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, C = x.shape[:2]
+ assert t.shape == (B, )
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE
+ ]:
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ if self.model_var_type == ModelVarType.LEARNED:
+ model_log_variance = model_var_values
+ model_variance = th.exp(model_log_variance)
+ else:
+ min_log = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x.shape)
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(
+ np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t,
+ x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_xprev(
+ x_t=x, t=t, xprev=model_output))
+ model_mean = model_output
+ elif self.model_mean_type in [
+ ModelMeanType.START_X, ModelMeanType.EPSILON
+ ]:
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(
+ x_t=x, t=t, eps=model_output))
+ model_mean, _, _ = self.q_posterior_mean_variance(
+ x_start=pred_xstart, x_t=x, t=t)
+ else:
+ raise NotImplementedError(self.model_mean_type)
+
+ return {
+ 'mean': model_mean,
+ 'variance': model_variance,
+ 'log_variance': model_log_variance,
+ 'pred_xstart': pred_xstart,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
+ * x_t - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t,
+ x_t.shape) * eps)
+
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
+ assert x_t.shape == xprev.shape
+ return ( # (xprev - coef2*x_t) / coef1
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape)
+ * xprev - _extract_into_tensor(
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
+ x_t.shape) * x_t)
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
+ * x_t - pred_xstart) / _extract_into_tensor(
+ self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _scale_timesteps(self, t):
+ if self.rescale_timesteps:
+ return t.float() * (1000.0 / self.num_timesteps)
+ return t
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
+ new_mean = (
+ p_mean_var['mean'].float()
+ + p_mean_var['variance'] * gradient.float())
+ return new_mean
+
+ def condition_mean_with_grad(self,
+ cond_fn,
+ p_mean_var,
+ x,
+ t,
+ model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, t, p_mean_var, **model_kwargs)
+ new_mean = (
+ p_mean_var['mean'].float()
+ + p_mean_var['variance'] * gradient.float())
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+
+ See condition_mean() for details on cond_fn.
+
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var['pred_xstart'])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
+ x, self._scale_timesteps(t), **model_kwargs)
+
+ out = p_mean_var.copy()
+ out['pred_xstart'] = self._predict_xstart_from_eps(x, t, eps)
+ out['mean'], _, _ = self.q_posterior_mean_variance(
+ x_start=out['pred_xstart'], x_t=x, t=t)
+ return out
+
+ def condition_score_with_grad(self,
+ cond_fn,
+ p_mean_var,
+ x,
+ t,
+ model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+
+ See condition_mean() for details on cond_fn.
+
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var['pred_xstart'])
+
+ grad = cond_fn(x, t, p_mean_var, **model_kwargs)
+ eps = eps - (1 - alpha_bar).sqrt() * grad
+
+ out = p_mean_var.copy()
+ out['pred_xstart'] = self._predict_xstart_from_eps(x, t, eps)
+ out['mean'], _, _ = self.q_posterior_mean_variance(
+ x_start=out['pred_xstart'], x_t=x, t=t)
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out['mean'] = self.condition_mean(
+ cond_fn, out, x, t, model_kwargs=model_kwargs)
+ sample = out['mean'] + nonzero_mask * th.exp(
+ 0.5 * out['log_variance']) * noise
+ return {'sample': sample, 'pred_xstart': out['pred_xstart']}
+
+ def p_sample_with_grad(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ with th.enable_grad():
+ x = x.detach().requires_grad_()
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = ((t != 0).float().view(-1,
+ *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out['mean'] = self.condition_mean_with_grad(
+ cond_fn, out, x, t, model_kwargs=model_kwargs)
+ sample = out['mean'] + nonzero_mask * th.exp(
+ 0.5 * out['log_variance']) * noise
+ return {'sample': sample, 'pred_xstart': out['pred_xstart'].detach()}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ skip_timesteps=0,
+ init_image=None,
+ randomize_class=False,
+ cond_fn_with_grad=False,
+ ):
+ """
+ Generate samples from the model.
+
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ skip_timesteps=skip_timesteps,
+ init_image=init_image,
+ randomize_class=randomize_class,
+ cond_fn_with_grad=cond_fn_with_grad,
+ ):
+ final = sample
+ return final['sample']
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ skip_timesteps=0,
+ init_image=None,
+ randomize_class=False,
+ cond_fn_with_grad=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+
+ if skip_timesteps and init_image is None:
+ init_image = th.zeros_like(img)
+
+ indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
+
+ if init_image is not None:
+ my_t = th.ones([shape[0]], device=device,
+ dtype=th.long) * indices[0]
+ img = self.q_sample(init_image, my_t, img)
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices, desc='Steps')
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ if randomize_class and 'y' in model_kwargs:
+ model_kwargs['y'] = th.randint(
+ low=0,
+ high=model.num_classes,
+ size=model_kwargs['y'].shape,
+ device=model_kwargs['y'].device)
+ with th.no_grad():
+ sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample
+ out = sample_fn(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out['sample']
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ inpainting_mode=False,
+ orig_img=None,
+ mask_inpaint=None,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+
+ Same usage as p_sample().
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ if inpainting_mode:
+ noised_orig_img = th.sqrt(alpha_bar) * orig_img + \
+ th.sqrt(1 - alpha_bar) * th.randn_like(x)
+ x = (1 - mask_inpaint) * noised_orig_img + mask_inpaint * x
+
+ out_orig = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(
+ cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
+ else:
+ out = out_orig
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out['pred_xstart'])
+
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t,
+ x.shape)
+
+ v1 = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ v2 = th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ sigma = v1 * v2
+
+        # Equation 12 of the DDIM paper (Song et al., 2020):
+        # x_{t-1} = sqrt(alpha_bar_prev) * x0_pred
+        #           + sqrt(1 - alpha_bar_prev - sigma**2) * eps + sigma * z
+ noise = th.randn_like(x)
+ mean_pred = (
+ out['pred_xstart'] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {'sample': sample, 'pred_xstart': out_orig['pred_xstart']}
+
+ def ddim_sample_with_grad(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+
+ Same usage as p_sample().
+ """
+ with th.enable_grad():
+ x = x.detach().requires_grad_()
+ out_orig = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score_with_grad(
+ cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
+ else:
+ out = out_orig
+
+ out['pred_xstart'] = out['pred_xstart'].detach()
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out['pred_xstart'])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t,
+ x.shape)
+
+ v1 = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ v2 = th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ sigma = v1 * v2
+
+        # Equation 12 of the DDIM paper (Song et al., 2020), as in ddim_sample.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out['pred_xstart'] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {
+ 'sample': sample,
+ 'pred_xstart': out_orig['pred_xstart'].detach()
+ }
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ skip_timesteps=0,
+ init_image=None,
+ randomize_class=False,
+ cond_fn_with_grad=False,
+ ):
+ """
+ Generate samples from the model using DDIM.
+
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ skip_timesteps=skip_timesteps,
+ init_image=init_image,
+ randomize_class=randomize_class,
+ cond_fn_with_grad=cond_fn_with_grad,
+ ):
+ final = sample
+ return final['sample']
+
+ def ddim_sample_loop_progressive(self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ skip_timesteps=0,
+ init_image=None,
+ randomize_class=False,
+ cond_fn_with_grad=False,
+ transformation_fn=None,
+ transformation_percent=[],
+ inpainting_mode=False,
+ mask_inpaint=None,
+ skip_timesteps_orig=None):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+
+ Same usage as p_sample_loop_progressive().
+ """
+
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+
+ if skip_timesteps and init_image is None:
+ init_image = th.zeros_like(img)
+
+ indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
+ transformation_steps = [
+ int(len(indices) * (1 - i)) for i in transformation_percent
+ ]
+
+ if init_image is not None:
+ my_t = th.ones([shape[0]], device=device,
+ dtype=th.long) * indices[0]
+ img = self.q_sample(init_image, my_t, img)
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+ indices = tqdm(indices, desc='Steps')
+
+ if inpainting_mode and skip_timesteps_orig is None:
+ skip_timesteps_orig = self.num_timesteps
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ if randomize_class and 'y' in model_kwargs:
+ model_kwargs['y'] = th.randint(
+ low=0,
+ high=model.num_classes,
+ size=model_kwargs['y'].shape,
+ device=model_kwargs['y'].device)
+ with th.no_grad():
+ if i in transformation_steps and transformation_fn is not None:
+ img = transformation_fn(img)
+ sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample
+ if inpainting_mode \
+ and i >= self.num_timesteps - skip_timesteps_orig \
+ and not cond_fn_with_grad:
+ out = sample_fn(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ inpainting_mode=inpainting_mode,
+ orig_img=init_image,
+ mask_inpaint=mask_inpaint,
+ )
+ else:
+ out = sample_fn(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ yield out
+ img = out['sample']
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
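+
+
+# A minimal end-to-end sketch (illustration only, not used by modelscope): run
+# the full reverse process with a dummy "model" that always predicts zero noise.
+if __name__ == '__main__':
+    demo_diffusion = GaussianDiffusion(
+        betas=get_named_beta_schedule('linear', 50),
+        model_mean_type=ModelMeanType.EPSILON,
+        model_var_type=ModelVarType.FIXED_SMALL,
+        loss_type=LossType.MSE,
+    )
+
+    def demo_model(x, t, **kwargs):
+        return th.zeros_like(x)  # stand-in for a real epsilon-predicting UNet
+
+    demo_sample = demo_diffusion.p_sample_loop(
+        demo_model, (1, 3, 8, 8), device='cpu')
+    assert demo_sample.shape == (1, 3, 8, 8)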
diff --git a/modelscope/models/multi_modal/guided_diffusion/respace.py b/modelscope/models/multi_modal/guided_diffusion/respace.py
new file mode 100644
index 00000000..b179aae1
--- /dev/null
+++ b/modelscope/models/multi_modal/guided_diffusion/respace.py
@@ -0,0 +1,78 @@
+# This code is borrowed and modified from Guided Diffusion Model,
+# made publicly available under MIT license
+# at https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/main/fengshen/examples/disco_project
+
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs['betas'])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs['betas'] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
+ return super().p_mean_variance(
+ self._wrap_model(model), *args, **kwargs)
+
+ def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
+ return super().training_losses(
+ self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(
+ self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(
+ self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
+ self.original_num_steps)
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+
+ def __init__(self, model, timestep_map, rescale_timesteps,
+ original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ map_tensor = th.tensor(
+ self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ if self.rescale_timesteps:
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
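+
+
+# Illustrative sketch (never called): retaining every 10th of 1000 base steps
+# re-derives betas for a 100-step process while keeping the original indices.
+def _spaced_diffusion_sketch():
+    from .gaussian_diffusion import (LossType, ModelMeanType, ModelVarType,
+                                     get_named_beta_schedule)
+    spaced = SpacedDiffusion(
+        use_timesteps=range(0, 1000, 10),
+        betas=get_named_beta_schedule('linear', 1000),
+        model_mean_type=ModelMeanType.EPSILON,
+        model_var_type=ModelVarType.FIXED_SMALL,
+        loss_type=LossType.MSE,
+    )
+    assert spaced.num_timesteps == 100
+    assert spaced.original_num_steps == 1000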
diff --git a/modelscope/models/multi_modal/guided_diffusion/script.py b/modelscope/models/multi_modal/guided_diffusion/script.py
new file mode 100644
index 00000000..83193379
--- /dev/null
+++ b/modelscope/models/multi_modal/guided_diffusion/script.py
@@ -0,0 +1,39 @@
+# This code is borrowed and modified from Guided Diffusion Model,
+# made publicly available under MIT license
+# at https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/main/fengshen/examples/disco_project
+
+from modelscope.models.cv.motion_generation.modules.respace import \
+ space_timesteps
+from . import gaussian_diffusion as gd
+from .respace import SpacedDiffusion
+
+
+def create_diffusion(diffusion_config):
+ predict_xstart = False
+ sigma_small = False
+ learn_sigma = True
+
+ steps = diffusion_config['steps']
+ timestep_respacing = f'ddim{steps}'
+ diffusion_steps = 1000
+
+ rescale_timesteps = True
+
+ betas = gd.get_named_beta_schedule('linear', diffusion_steps)
+ loss_type = gd.LossType.MSE
+
+ if not timestep_respacing:
+ timestep_respacing = [diffusion_steps]
+
+ diffusion = SpacedDiffusion(
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=(gd.ModelMeanType.EPSILON
+ if not predict_xstart else gd.ModelMeanType.START_X),
+ model_var_type=((gd.ModelVarType.FIXED_LARGE
+ if not sigma_small else gd.ModelVarType.FIXED_SMALL)
+ if not learn_sigma else gd.ModelVarType.LEARNED_RANGE),
+ loss_type=loss_type,
+ rescale_timesteps=rescale_timesteps)
+
+ return diffusion
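+
+
+# Illustrative usage (never called; assumes the standard guided-diffusion
+# semantics of space_timesteps): only 'steps' is read from the config, and it
+# becomes an exact DDIM respacing of the 1000-step base schedule.
+def _create_diffusion_sketch():
+    diffusion = create_diffusion({'steps': 50})
+    assert diffusion.num_timesteps == 50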
diff --git a/modelscope/models/multi_modal/guided_diffusion/unet.py b/modelscope/models/multi_modal/guided_diffusion/unet.py
new file mode 100644
index 00000000..946a4179
--- /dev/null
+++ b/modelscope/models/multi_modal/guided_diffusion/unet.py
@@ -0,0 +1,1046 @@
+# This code is borrowed and modified from Guided Diffusion Model,
+# made publicly available under MIT license at
+# https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/main/fengshen/examples/disco_project
+
+import math
+from abc import abstractmethod
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PretrainedConfig, PreTrainedModel
+
+
+class GroupNorm(nn.GroupNorm):
+
+ def forward(self, x):
+ return super(GroupNorm, self).forward(x.float()).type(x.dtype)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = th.exp(-math.log(max_period)
+ * th.arange(start=0, end=half, dtype=th.float32)
+ / half).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = th.cat(
+ [embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
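+
+
+# Shape sketch (illustration only, never called): the embeddings are [N, dim]
+# and, being built from sines and cosines, bounded by 1 in absolute value.
+def _timestep_embedding_sketch():
+    emb = timestep_embedding(th.arange(4), dim=128)
+    assert emb.shape == (4, 128)
+    assert th.all(emb.abs() <= 1.0)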
+
+
+def convert_module_to_f16(ll):
+ """
+ Convert primitive modules to float16.
+ """
+ if isinstance(ll, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ ll.weight.data = ll.weight.data.half()
+ if ll.bias is not None:
+ ll.bias.data = ll.bias.data.half()
+
+
+def convert_module_to_f32(ll):
+ """
+ Convert primitive modules to float32, undoing convert_module_to_f16().
+ """
+ if isinstance(ll, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ ll.weight.data = ll.weight.data.float()
+ if ll.bias is not None:
+ ll.bias.data = ll.bias.data.float()
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f'unsupported dimensions: {dims}')
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(
+ dims, self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
+ mode='nearest')
+ else:
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=1)
+ else:
+ assert self.channels == self.out_channels
+ self.op = nn.AvgPool2d(kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ GroupNorm(32, channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(
+ emb_channels,
+ 2 * self.out_channels
+ if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ GroupNorm(32, self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1),
+ )
+
+ nn.init.zeros_(self.out_layers[-1].weight)
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1)
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels,
+ 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(self._forward, (x, emb), self.parameters(),
+ self.use_checkpoint)
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+            ), f'q,k,v channels {channels} are not divisible by num_head_channels {num_head_channels}'
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = GroupNorm(32, channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = conv_nd(1, channels, channels, 1)
+
+ nn.init.zeros_(self.proj_out.weight)
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x, ), self.parameters(),
+ self.use_checkpoint)
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial**2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
+ ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ 'bct,bcs->bts', q * scale,
+ k * scale) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum('bts,bcs->bct', weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ 'bct,bcs->bts',
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum('bts,bcs->bct', weight,
+ v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
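+
+
+# Shape sketch (illustration only, never called): both attention variants map
+# an [N, 3 * H * C, T] qkv tensor to an [N, H * C, T] output.
+def _qkv_attention_sketch():
+    qkv = th.randn(2, 3 * 4 * 8, 16)  # N=2, H=4 heads, C=8 channels, T=16
+    assert QKVAttentionLegacy(4)(qkv).shape == (2, 4 * 8, 16)
+    assert QKVAttention(4)(qkv).shape == (2, 4 * 8, 16)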
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_head_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ nn.Linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ nn.Linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ ch = input_ch = int(channel_mult[0] * model_channels)
+ self.input_blocks = nn.ModuleList([
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, ch, 3, padding=1))
+ ])
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=int(mult * model_channels),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(mult * model_channels)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ))
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ ) if resblock_updown else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch))
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=int(model_channels * mult),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(model_channels * mult)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ))
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ ) if resblock_updown else Upsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch))
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ GroupNorm(32, ch),
+ nn.SiLU(),
+ conv_nd(dims, input_ch, out_channels, 3, padding=1),
+ )
+
+ nn.init.zeros_(self.out[-1].weight)
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps, y=None):
+ """
+ Apply the model to an input batch.
+
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), 'must specify y if and only if the model is class-conditional'
+
+ hs = []
+ emb = self.time_embed(
+ timestep_embedding(timesteps, self.model_channels))
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0], )
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ hs.append(h)
+ h = self.middle_block(h, emb)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb)
+ h = h.type(x.dtype)
+ return self.out(h)
+
+
+class SuperResModel(UNetModel):
+ """
+ A UNetModel that performs super-resolution.
+
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
+ """
+
+ def __init__(self, image_size, in_channels, *args, **kwargs):
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
+
+ def forward(self, x, timesteps, low_res=None, **kwargs):
+ _, _, new_height, new_width = x.shape
+ upsampled = F.interpolate(
+ low_res, (new_height, new_width), mode='bilinear')
+ x = th.cat([x, upsampled], dim=1)
+ return super().forward(x, timesteps, **kwargs)
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+
+ For usage, see UNet.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool='adaptive',
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ nn.Linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ nn.Linear(time_embed_dim, time_embed_dim),
+ )
+
+ ch = int(channel_mult[0] * model_channels)
+ self.input_blocks = nn.ModuleList([
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, ch, 3, padding=1))
+ ])
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=int(mult * model_channels),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(mult * model_channels)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ))
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ ) if resblock_updown else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch))
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == 'adaptive':
+ self.out = nn.Sequential(
+ GroupNorm(32, ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ conv_nd(dims, ch, out_channels, 1),
+ nn.Flatten(),
+ )
+            # Zero-init the final conv layer; self.out[-1] is nn.Flatten and
+            # has no weight.
+            nn.init.zeros_(self.out[-2].weight)
+ elif pool == 'attention':
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ GroupNorm(32, ch),
+ nn.SiLU(),
+ AttentionPool2d((image_size // ds), ch, num_head_channels,
+ out_channels),
+ )
+ elif pool == 'spatial':
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == 'spatial_v2':
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ GroupNorm(32, 2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f'Unexpected {pool} pooling')
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(
+ timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith('spatial'):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith('spatial'):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
+
+
+class UNetConfig(PretrainedConfig):
+
+ def __init__(self,
+ image_size=512,
+ in_channels=3,
+ model_channels=256,
+ out_channels=6,
+ num_res_blocks=2,
+ attention_resolutions=[16, 32, 64],
+ dropout=0.0,
+ channel_mult=(0.5, 1, 1, 2, 2, 4, 4),
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=True,
+ num_heads=4,
+ num_head_channels=64,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=True,
+ resblock_updown=True,
+ use_new_attention_order=False,
+ **kwargs):
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.use_fp16 = use_fp16
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.resblock_updown = resblock_updown
+ self.use_new_attention_order = use_new_attention_order
+ super().__init__(**kwargs)
+
+
+class HFUNetModel(PreTrainedModel):
+ config_class = UNetConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = UNetModel(
+ image_size=config.image_size,
+ in_channels=config.in_channels,
+ model_channels=config.model_channels,
+ out_channels=config.out_channels,
+ num_res_blocks=config.num_res_blocks,
+ attention_resolutions=config.attention_resolutions,
+ dropout=config.dropout,
+ channel_mult=config.channel_mult,
+ num_classes=config.num_classes,
+ use_checkpoint=config.use_checkpoint,
+ use_fp16=config.use_fp16,
+ num_heads=config.num_heads,
+ num_head_channels=config.num_head_channels,
+ num_heads_upsample=config.num_heads_upsample,
+ use_scale_shift_norm=config.use_scale_shift_norm,
+ resblock_updown=config.resblock_updown,
+ use_new_attention_order=config.use_new_attention_order,
+ )
+
+ def forward(self, x, timesteps, y=None):
+ return self.model.forward(x, timesteps, y)
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.model.input_blocks.apply(convert_module_to_f16)
+ self.model.middle_block.apply(convert_module_to_f16)
+ self.model.output_blocks.apply(convert_module_to_f16)
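+
+
+# ---------------------------------------------------------------------------
+# Usage sketch (illustrative, not part of the original patch): a small CPU
+# smoke test of HFUNetModel. The configuration values below are assumptions
+# chosen only to keep the model tiny; they are not the settings of any
+# released checkpoint, and the sketch assumes the ResBlock/AttentionBlock
+# helpers defined earlier in this file behave as in OpenAI guided-diffusion.
+if __name__ == '__main__':
+    cfg = UNetConfig(
+        image_size=64,
+        in_channels=3,
+        model_channels=32,
+        out_channels=3,
+        num_res_blocks=1,
+        attention_resolutions=[2],
+        channel_mult=(1, 2),
+        use_fp16=False,
+        num_heads=2,
+        num_head_channels=32)
+    unet = HFUNetModel(cfg)
+    x = th.randn(2, 3, 64, 64)
+    t = th.randint(0, 1000, (2, ))
+    with th.no_grad():
+        out = unet(x, t)
+    print(out.shape)  # expected: torch.Size([2, 3, 64, 64])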
diff --git a/modelscope/models/multi_modal/rleg/__init__.py b/modelscope/models/multi_modal/rleg/__init__.py
new file mode 100644
index 00000000..7fec95c7
--- /dev/null
+++ b/modelscope/models/multi_modal/rleg/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+
+ from .rleg import RLEGForMultiModalEmbedding
+
+else:
+ _import_structure = {
+ 'rleg': ['RLEGForMultiModalEmbedding'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/multi_modal/rleg/model.py b/modelscope/models/multi_modal/rleg/model.py
new file mode 100644
index 00000000..efabafdc
--- /dev/null
+++ b/modelscope/models/multi_modal/rleg/model.py
@@ -0,0 +1,139 @@
+# Copyright 2021 The OpenAI Team Authors.
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+#
+# The implementation here is modified based on OpenAI CLIP,
+# originally MIT License, Copyright (c) 2021 OpenAI,
+# and publicly available at https://github.com/openai/CLIP/.
+""" Generative Multimodal Model Architecture."""
+
+import os
+
+import json
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from modelscope.models.multi_modal.gemm import gemm_base, tokenizer
+
+
+class ImageEncoder(nn.Module):
+ """Image Feature Encoder
+ ViT Style Transformer
+ """
+
+ def __init__(self, configs):
+ super().__init__()
+ (embed_dim, image_resolution, vision_layers, vision_width,
+ vision_patch_size) = configs[:5]
+ self.visual = gemm_base.VisualTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_width // 64,
+ output_dim=embed_dim,
+ use_gc=False)
+
+ def forward(self, image, return_tokens=False):
+ features = self.visual(image)
+ tokens = features[:, 1:, :]
+ embedding = features[:, 0, :]
+ return (embedding, tokens) if return_tokens else embedding
+
+
+class TextEncoder(nn.Module):
+ """Text Feature Encoder
+ BERT style transformer
+ """
+
+ def __init__(self, configs):
+ super().__init__()
+ (context_length, vocab_size, model_width, model_heads,
+ model_layers) = configs[-5:]
+ # text model
+ self.transformer = gemm_base.Transformer(
+ width=model_width,
+ layers=model_layers,
+ heads=model_heads,
+ attn_mask=self.build_attention_mask(context_length),
+ )
+ # others
+ self.token_embedding = nn.Embedding(vocab_size, model_width)
+ self.positional_embedding = nn.Parameter(
+ torch.empty(context_length, model_width))
+ self.ln_final = nn.LayerNorm(model_width)
+ self.text_projection = nn.Parameter(
+ torch.empty(model_width, configs[0]))
+
+ def build_attention_mask(self, seq_length=None):
+ mask = torch.ones(seq_length, seq_length) * -1e4
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ def forward(self, text, return_tokens=False):
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x)
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ embedding = x[torch.arange(x.shape[0]),
+ text.argmax(dim=-1), ...] @ self.text_projection
+ return (embedding, x) if return_tokens else embedding
+
+
+class RLEGModel(nn.Module):
+ """ Generative multi-modal model, trained with RLEG method.
+ It takes image or text or both of them as input, and produce
+ the corresponding features of inputs.
+ """
+
+ def __init__(self, model_dir):
+ super().__init__()
+ with open(
+ '{}/encoder_config.json'.format(model_dir), 'r',
+ encoding='utf-8') as f:
+ model_config = json.loads(f.read())
+ model_name = list(model_config.keys())[0]
+ config_args = model_config[model_name]
+ bpe_path = os.path.join(model_dir, 'bpe_vocab_16e6.txt.gz')
+ self.tokenizer = tokenizer.SimpleTokenizer(bpe_path)
+ # build model architecture
+ self.image_encoder = ImageEncoder(config_args)
+ self.text_encoder = TextEncoder(config_args)
+ self.logit_scale = nn.Parameter(torch.ones([]))
+
+ def tokenize(self, text_str):
+ text_tensor = tokenizer.clip_tokenize(self.tokenizer, [text_str])[0]
+ return text_tensor
+
+ def encode_text(self, text):
+ feature = self.text_encoder(text)
+ feature = F.normalize(feature, p=2, dim=-1)
+ return feature
+
+ def encode_image(self, image):
+ feature = self.image_encoder(image)
+ feature = F.normalize(feature, p=2, dim=-1)
+ return feature
+
+ def parse_feat(self, feat):
+ out = feat.cpu().numpy()
+ return out
+
+ @torch.no_grad()
+ def forward(self, image=None, text=None):
+ """ It takes image or text as input,
+ and extracts the features as output.
+ """
+ img_feature, text_feature = None, None
+ if image is not None:
+ img_feature = self.parse_feat(self.encode_image(image))
+ if text is not None:
+ text_feature = self.parse_feat(self.encode_text(text))
+ out = {
+ 'image_feature': img_feature,
+ 'text_feature': text_feature,
+ }
+ return out
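+
+
+# Usage sketch (illustrative, not part of the original patch): the expected
+# input/output contract of RLEGModel. Building the model requires a local
+# checkpoint directory containing `encoder_config.json` and
+# `bpe_vocab_16e6.txt.gz`; the path below is a placeholder assumption.
+if __name__ == '__main__':
+    model = RLEGModel(model_dir='/path/to/rleg_checkpoint').eval()
+    text = model.tokenize('two dogs playing on the grass').view(1, -1)
+    image = torch.randn(1, 3, 224, 224)  # an already-preprocessed RGB image
+    out = model(image=image, text=text)
+    # Both features are L2-normalized numpy arrays of shape (1, embed_dim).
+    print(out['image_feature'].shape, out['text_feature'].shape)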
diff --git a/modelscope/models/multi_modal/rleg/rleg.py b/modelscope/models/multi_modal/rleg/rleg.py
new file mode 100644
index 00000000..dd9accd7
--- /dev/null
+++ b/modelscope/models/multi_modal/rleg/rleg.py
@@ -0,0 +1,85 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+""" Generative Multimodal Model Wrapper."""
+from typing import Any, Dict
+
+import torch
+from torchvision import transforms as T
+
+from modelscope.metainfo import Models
+from modelscope.models.base import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.multi_modal.rleg.model import RLEGModel
+from modelscope.outputs import OutputKeys
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+__all__ = ['RLEGForMultiModalEmbedding']
+
+
+@MODELS.register_module(
+ Tasks.generative_multi_modal_embedding, module_name=Models.rleg)
+class RLEGForMultiModalEmbedding(TorchModel):
+ """ Generative multi-modal model for multi-modal embedding.
+ The model is trained by representation learning with embedding generation.
+ Inputs could be image or text or both of them.
+ Outputs could be features of input image or text,
+ """
+
+ def __init__(self, model_dir, device_id=0, *args, **kwargs):
+ super().__init__(
+ model_dir=model_dir, device_id=device_id, *args, **kwargs)
+ self.model = RLEGModel(model_dir=model_dir)
+ pretrained_params = torch.load('{}/{}'.format(
+ model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
+ self.model.load_state_dict(pretrained_params)
+ self.model.eval()
+ self.device_id = device_id
+ if self.device_id >= 0 and torch.cuda.is_available():
+ self.model.to('cuda:{}'.format(self.device_id))
+ logger.info('Use GPU: {}'.format(self.device_id))
+ else:
+ self.device_id = -1
+ logger.info('Use CPU for inference')
+ self.img_preprocessor = T.Compose([
+ T.Resize((224, 224)),
+ T.ToTensor(),
+ T.Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711))
+ ])
+
+ def parse_image(self, input_img):
+ if input_img is None:
+ return None
+ input_img = LoadImage.convert_to_img(input_img)
+ img_tensor = self.img_preprocessor(input_img)[None, ...]
+ if self.device_id >= 0:
+ img_tensor = img_tensor.to('cuda:{}'.format(self.device_id))
+ return img_tensor
+
+ def parse_text(self, text_str):
+ if text_str is None or len(text_str) == 0:
+ return None
+ if isinstance(text_str, str):
+ text_ids_tensor = self.model.tokenize(text_str)
+ else:
+ raise TypeError(f'text should be str, but got {type(text_str)}')
+ if self.device_id >= 0:
+ text_ids_tensor = text_ids_tensor.to('cuda:{}'.format(
+ self.device_id))
+ return text_ids_tensor.view(1, -1)
+
+ def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ image_input = input.get('image', input.get('img', None))
+ text_input = input.get('text', input.get('txt', None))
+ image = self.parse_image(image_input)
+ text = self.parse_text(text_input)
+ out = self.model(image, text)
+ output = {
+ OutputKeys.IMG_EMBEDDING: out.get('image_feature', None),
+ OutputKeys.TEXT_EMBEDDING: out.get('text_feature', None),
+ OutputKeys.CAPTION: out.get('caption', None)
+ }
+ return output
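+
+
+# Usage sketch (illustrative, not part of the original patch): calling the
+# registered TorchModel directly with the dict interface used by the
+# ModelScope pipelines. `model_dir` must point at a downloaded RLEG model
+# directory; the paths below are placeholder assumptions.
+if __name__ == '__main__':
+    model = RLEGForMultiModalEmbedding(model_dir='/path/to/rleg_model')
+    result = model.forward({
+        'image': '/path/to/an_image.jpg',
+        'text': 'a photo of two dogs on the grass',
+    })
+    print(result[OutputKeys.IMG_EMBEDDING].shape,
+          result[OutputKeys.TEXT_EMBEDDING].shape)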
diff --git a/modelscope/models/multi_modal/soonet/__init__.py b/modelscope/models/multi_modal/soonet/__init__.py
new file mode 100644
index 00000000..38b95d26
--- /dev/null
+++ b/modelscope/models/multi_modal/soonet/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .tokenizer import SimpleTokenizer
+ from .model import SOONet
+ from .utils import decode_video
+ from .clip import load_clip
+else:
+ _import_structure = {
+ 'model': ['SOONet'],
+ 'tokenizer': ['SimpleTokenizer'],
+ 'utils': ['decode_video'],
+ 'clip': ['load_clip']
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/multi_modal/soonet/blocks.py b/modelscope/models/multi_modal/soonet/blocks.py
new file mode 100644
index 00000000..28f4a553
--- /dev/null
+++ b/modelscope/models/multi_modal/soonet/blocks.py
@@ -0,0 +1,287 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Q2VRankerStage1(nn.Module):
+ """
+    Calculates the qv_ctx_score from the query embedding and the multi-scale
+    anchor context embeddings.
+    The qv_ctx_score is used to pre-rank anchors and retain the top-k most
+    related ones.
+ """
+
+ def __init__(self, nscales, hidden_dim):
+ super().__init__()
+ self.fc = nn.Linear(hidden_dim, hidden_dim)
+ self.nscales = nscales
+
+ def forward(self, ctx_feats, qfeat):
+ qfeat = self.fc(qfeat)
+ qv_ctx_scores = list()
+ for i in range(self.nscales):
+ score = torch.einsum('bld,bd->bl',
+ F.normalize(ctx_feats[i], p=2, dim=2),
+ F.normalize(qfeat, p=2, dim=1))
+ qv_ctx_scores.append(score)
+
+ return qv_ctx_scores
+
+
+class V2QRankerStage1(nn.Module):
+ """
+    Calculates the vq_ctx_score from the anchor context embeddings and the
+    query embeddings.
+ """
+
+ def __init__(self, nscales, hidden_dim):
+ super().__init__()
+ self.fc = nn.Linear(hidden_dim, hidden_dim)
+ self.nscales = nscales
+
+ def forward(self, ctx_feats, qfeat):
+ vq_ctx_scores = list()
+ for i in range(self.nscales):
+ score = torch.einsum(
+ 'bld,bd->bl', F.normalize(self.fc(ctx_feats[i]), p=2, dim=2),
+ F.normalize(qfeat, p=2, dim=1))
+ vq_ctx_scores.append(score)
+
+ return vq_ctx_scores
+
+
+class Q2VRankerStage2(nn.Module):
+ """
+    Calculates the qv_ctn_score from the query embedding and the video
+    sequence embeddings.
+    The qv_ctn_score is used to re-rank the retained anchors.
+ """
+
+ def __init__(self, nscales, hidden_dim, snippet_length=10):
+ super().__init__()
+ self.nscales = nscales
+ self.snippet_length = snippet_length
+ self.qfc = nn.Linear(hidden_dim, hidden_dim)
+ self.encoder = V2VAttention()
+
+ def forward(self, vfeats, qfeat, hit_indices, qv_ctx_scores):
+ qfeat = self.qfc(qfeat)
+
+ qv_ctn_scores = list()
+ qv_merge_scores = list()
+
+ _, L, D = vfeats.size()
+ ctn_feats = list()
+ for i in range(self.nscales):
+ anchor_length = self.snippet_length * 2**i
+ assert L // anchor_length == qv_ctx_scores[i].size(1)
+ qv_ctx_score = torch.index_select(qv_ctx_scores[i], 1,
+ hit_indices[i])
+
+ ctn_feat = vfeats.view(L // anchor_length, anchor_length,
+ D).detach()
+ ctn_feat = torch.index_select(ctn_feat, 0, hit_indices[i])
+ ctn_feat = self.encoder(
+ ctn_feat,
+ torch.ones(ctn_feat.size()[:2], device=ctn_feat.device))
+ ctn_feats.append(ctn_feat)
+
+ qv_ctn_score = torch.einsum(
+ 'bkld,bd->bkl', F.normalize(ctn_feat.unsqueeze(0), p=2, dim=3),
+ F.normalize(qfeat, p=2, dim=1))
+ qv_ctn_score, _ = torch.max(qv_ctn_score, dim=2)
+ qv_ctn_scores.append(qv_ctn_score)
+ qv_merge_scores.append(qv_ctx_score + qv_ctn_score)
+
+ return qv_merge_scores, qv_ctn_scores, ctn_feats
+
+
+class V2QRankerStage2(nn.Module):
+ """
+    Calculates the vq_ctn_score from the anchor content embeddings and the
+    query embeddings.
+ """
+
+ def __init__(self, nscales, hidden_dim):
+ super().__init__()
+ self.fc = nn.Linear(hidden_dim, hidden_dim)
+ self.nscales = nscales
+
+ def forward(self, ctn_feats, qfeat):
+ vq_ctn_scores = list()
+ for i in range(self.nscales):
+ score = torch.einsum(
+ 'bkld,bd->bkl',
+ F.normalize(self.fc(ctn_feats[i]).unsqueeze(0), p=2, dim=3),
+ F.normalize(qfeat, p=2, dim=1))
+ score = torch.mean(score, dim=2)
+ vq_ctn_scores.append(score)
+
+ return vq_ctn_scores
+
+
+class V2VAttention(nn.Module):
+ """
+    Self-attention encoder over an anchor's frame sequence, used to encode
+    intra-anchor knowledge.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.posemb = PositionEncoding(max_len=400, dim=512, dropout=0.0)
+ self.encoder = MultiHeadAttention(dim=512, n_heads=8, dropout=0.1)
+ self.dropout = nn.Dropout(0.0)
+
+ def forward(self, video_feats, video_masks):
+ mask = torch.einsum('bm,bn->bmn', video_masks,
+ video_masks).unsqueeze(1)
+ residual = video_feats
+ video_feats = video_feats + self.posemb(video_feats)
+ out = self.encoder(
+ query=video_feats, key=video_feats, value=video_feats, mask=mask)
+ video_feats = self.dropout(residual
+ + out) * video_masks.unsqueeze(2).float()
+ return video_feats
+
+
+class BboxRegressor(nn.Module):
+ """
+    Predicts the bounding-box offset for each candidate anchor.
+ """
+
+ def __init__(self, hidden_dim, enable_stage2=False):
+ super().__init__()
+ self.fc_ctx = nn.Linear(hidden_dim, hidden_dim)
+ self.fc_q = nn.Linear(hidden_dim, hidden_dim)
+
+ if enable_stage2:
+ self.fc_ctn = nn.Linear(hidden_dim, hidden_dim)
+ self.attn = SelfAttention(hidden_dim)
+ self.predictor = nn.Sequential(
+ nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU(),
+ nn.Linear(hidden_dim, 2))
+ else:
+ self.predictor = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
+ nn.Linear(hidden_dim, 2))
+ self.enable_stage2 = enable_stage2
+
+ def forward(self, ctx_feats, ctn_feats, qfeat):
+ qfeat = self.fc_q(qfeat)
+
+ ctx_feats = torch.cat(ctx_feats, dim=1)
+ ctx_fuse_feats = F.relu(self.fc_ctx(ctx_feats)) * F.relu(
+ qfeat.unsqueeze(1))
+
+ if self.enable_stage2 and ctn_feats:
+ ctn_fuse_feats = list()
+ for i in range(len(ctn_feats)):
+ out = F.relu(self.fc_ctn(ctn_feats[i]).unsqueeze(0)) * F.relu(
+ qfeat.unsqueeze(1).unsqueeze(1))
+ out = self.attn(out)
+ ctn_fuse_feats.append(out)
+ ctn_fuse_feats = torch.cat(ctn_fuse_feats, dim=1)
+ fuse_feats = torch.cat([ctx_fuse_feats, ctn_fuse_feats], dim=-1)
+ else:
+ fuse_feats = ctx_fuse_feats
+
+ out = self.predictor(fuse_feats)
+ return out
+
+
+class SelfAttention(nn.Module):
+ """
+ Obtain pooled features by self-attentive pooling.
+ """
+
+ def __init__(self, hidden_dim):
+ super().__init__()
+ self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
+ self.relu = nn.ReLU()
+ self.fc2 = nn.Linear(hidden_dim // 2, 1)
+
+ def forward(self, x):
+ att = self.fc2(self.relu(self.fc1(x))).squeeze(3)
+ att = F.softmax(att, dim=2).unsqueeze(3)
+ out = torch.sum(x * att, dim=2)
+ return out
+
+
+class PositionEncoding(nn.Module):
+ """
+ An implementation of trainable positional embedding which is added to
+ sequence features to inject time/position information.
+
+ Args:
+ max_len: The max number of trainable positional embeddings.
+        dim: The dimension of each positional embedding.
+        dropout: Dropout probability applied to the embeddings.
+ """
+
+ def __init__(self, max_len, dim, dropout=0.0):
+ super(PositionEncoding, self).__init__()
+
+ self.embed = nn.Embedding(max_len, dim)
+ self.relu = nn.ReLU()
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ batch_size, seq_len = x.shape[:2]
+ pos_ids = torch.arange(seq_len, dtype=torch.long, device=x.device)
+ pos_ids = pos_ids.unsqueeze(0).repeat(batch_size, 1)
+ pos_emb = self.dropout(self.relu(self.embed(pos_ids)))
+
+ return pos_emb
+
+
+class MultiHeadAttention(nn.Module):
+ """
+    An implementation of the multi-head attention module described in
+    'Attention Is All You Need'.
+
+    Args:
+        dim: The dimension of the hidden features.
+        n_heads: The number of attention heads.
+        dropout: Dropout probability applied to the attention weights.
+ """
+
+ def __init__(self, dim, n_heads, dropout=0.0):
+ super(MultiHeadAttention, self).__init__()
+
+ self.dim = dim
+ self.n_heads = n_heads
+ self.head_dim = dim // n_heads
+
+ self.to_q = nn.Linear(dim, dim)
+ self.to_k = nn.Linear(dim, dim)
+ self.to_v = nn.Linear(dim, dim)
+
+ self.dropout = nn.Dropout(dropout)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.n_heads, self.head_dim)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3) # (N, nh, L, dh)
+
+ def forward(self, query, key, value, mask):
+ q = self.to_q(query)
+ k = self.to_k(key)
+ v = self.to_v(value)
+
+ q_trans = self.transpose_for_scores(q)
+ k_trans = self.transpose_for_scores(k)
+ v_trans = self.transpose_for_scores(v)
+
+ att = torch.matmul(q_trans, k_trans.transpose(-1,
+ -2)) # (N, nh, Lq, L)
+ att = att / math.sqrt(self.head_dim)
+ att = mask_logits(att, mask)
+ att = self.softmax(att)
+ att = self.dropout(att)
+
+ ctx_v = torch.matmul(att, v_trans) # (N, nh, Lq, dh)
+ ctx_v = ctx_v.permute(0, 2, 1, 3).contiguous() # (N, Lq, nh, dh)
+ shape = ctx_v.size()[:-2] + (self.dim, )
+ ctx_v = ctx_v.view(*shape) # (N, Lq, D)
+ return ctx_v
+
+
+def mask_logits(inputs, mask, mask_value=-1e30):
+ mask = mask.type(torch.float32)
+ return inputs + (1.0 - mask) * mask_value
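+
+
+# Usage sketch (illustrative, not part of the original patch): stage-1
+# query-to-video ranking on random features. The shapes are assumptions --
+# batch size 1, hidden_dim 512, and two anchor scales with 8 and 4 anchors.
+if __name__ == '__main__':
+    ranker = Q2VRankerStage1(nscales=2, hidden_dim=512)
+    ctx_feats = [torch.randn(1, 8, 512), torch.randn(1, 4, 512)]
+    qfeat = torch.randn(1, 512)
+    with torch.no_grad():
+        scores = ranker(ctx_feats, qfeat)
+    # One cosine-similarity score per anchor and scale:
+    print([s.shape for s in scores])  # [torch.Size([1, 8]), torch.Size([1, 4])]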
diff --git a/modelscope/models/multi_modal/soonet/clip.py b/modelscope/models/multi_modal/soonet/clip.py
new file mode 100644
index 00000000..c43820e5
--- /dev/null
+++ b/modelscope/models/multi_modal/soonet/clip.py
@@ -0,0 +1,342 @@
+# The implementation is adopted from CLIP, made publicly available
+# under MIT License at https://github.com/openai/CLIP
+
+import warnings
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+
+class CLIP(nn.Module):
+
+ def __init__(
+ self,
+ embed_dim: int,
+ # vision
+ image_resolution: int,
+ vision_layers: Union[Tuple[int, int, int, int], int],
+ vision_width: int,
+ vision_patch_size: int,
+ # text
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int):
+ super().__init__()
+
+ self.context_length = context_length
+
+ vision_heads = vision_width // 64
+ self.visual = VisionTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_heads,
+ output_dim=embed_dim)
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask())
+
+ self.vocab_size = vocab_size
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(
+ torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(
+ torch.empty(transformer_width, embed_dim))
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ self.initialize_parameters()
+
+ def initialize_parameters(self):
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+
+ proj_std = (self.transformer.width**-0.5) * (
+ (2 * self.transformer.layers)**-0.5)
+ attn_std = self.transformer.width**-0.5
+ fc_std = (2 * self.transformer.width)**-0.5
+ for block in self.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(
+ self.text_projection, std=self.transformer.width**-0.5)
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float('-inf'))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ @property
+ def dtype(self):
+ return self.visual.conv1.weight.dtype
+
+ def encode_image(self, image):
+ return self.visual(image.type(self.dtype))
+
+ def encode_text(self, text):
+ x = self.token_embedding(text).type(
+ self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
+
+ return x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image)
+ text_features = self.encode_text(text)
+
+ # normalized features
+ image_features = image_features / image_features.norm(
+ dim=1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_image = logit_scale * image_features @ text_features.t()
+ logits_per_text = logits_per_image.t()
+
+ # shape = [global_batch_size, global_batch_size]
+ return logits_per_image, logits_per_text
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+
+ def __init__(self,
+ d_model: int,
+ n_head: int,
+ attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(
+ OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
+ ('gelu', QuickGELU()),
+ ('c_proj', nn.Linear(d_model * 4, d_model))]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(
+ dtype=x.dtype,
+ device=x.device) if self.attn_mask is not None else None
+ return self.attn(
+ x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+
+ def __init__(self,
+ width: int,
+ layers: int,
+ heads: int,
+ attn_mask: torch.Tensor = None):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(*[
+ ResidualAttentionBlock(width, heads, attn_mask)
+ for _ in range(layers)
+ ])
+
+ def forward(self, x: torch.Tensor):
+ return self.resblocks(x)
+
+
+class VisionTransformer(nn.Module):
+
+ def __init__(self, input_resolution: int, patch_size: int, width: int,
+ layers: int, heads: int, output_dim: int):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(
+ in_channels=3,
+ out_channels=width,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=False)
+
+ scale = width**-0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn(
+ (input_resolution // patch_size)**2 + 1, width))
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(width, layers, heads)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ def forward(self, x: torch.Tensor):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1],
+ -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ class_token = self.class_embedding.to(x.dtype) + torch.zeros(
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
+ x = torch.cat([class_token, x], dim=1)
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+
+def build_model(state_dict: dict):
+ vision_width = state_dict['visual.conv1.weight'].shape[0]
+ vision_layers = len([
+ k for k in state_dict.keys()
+ if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
+ ])
+ vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
+ grid_size = round(
+ (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
+ image_resolution = vision_patch_size * grid_size
+
+ embed_dim = state_dict['text_projection'].shape[1]
+ context_length = state_dict['positional_embedding'].shape[0]
+ vocab_size = state_dict['token_embedding.weight'].shape[0]
+ transformer_width = state_dict['ln_final.weight'].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(
+ set(
+ k.split('.')[2] for k in state_dict
+ if k.startswith('transformer.resblocks')))
+
+ model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
+ vision_patch_size, context_length, vocab_size,
+ transformer_width, transformer_heads, transformer_layers)
+
+ for key in ['input_resolution', 'context_length', 'vocab_size']:
+ if key in state_dict:
+ del state_dict[key]
+
+ model.load_state_dict(state_dict)
+ return model.eval()
+
+
+def load_clip(name: str,
+ device: Union[str, torch.device] = 'cuda'
+ if torch.cuda.is_available() else 'cpu',
+ jit=True):
+    # Always rebuild the model from a state dict; the JIT patching branch
+    # below is kept for reference but never taken.
+    jit = False
+ model_path = name
+ try:
+ model = torch.jit.load(
+ model_path, map_location=device if jit else 'cpu').eval()
+ state_dict = None
+ except RuntimeError:
+ if jit:
+ warnings.warn(
+ f'File {model_path} is not a JIT archive. Loading as a state dict instead'
+ )
+ jit = False
+ state_dict = torch.load(model_path, map_location='cpu')
+
+ if not jit:
+ model = build_model(state_dict or model.state_dict()).to(device)
+ if str(device) == 'cpu':
+ model.float()
+ return model
+
+ device_holder = torch.jit.trace(
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [
+ n for n in device_holder.graph.findAllNodes('prim::Constant')
+ if 'Device' in repr(n)
+ ][-1]
+
+ def patch_device(module):
+ graphs = [module.graph] if hasattr(module, 'graph') else []
+ if hasattr(module, 'forward1'):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes('prim::Constant'):
+ if 'value' in node.attributeNames() and str(
+ node['value']).startswith('cuda'):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ if str(device) == 'cpu':
+ float_holder = torch.jit.trace(
+ lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ graphs = [module.graph] if hasattr(module, 'graph') else []
+ if hasattr(module, 'forward1'):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes('aten::to'):
+ inputs = list(node.inputs())
+ for i in [1, 2]:
+ if inputs[i].node()['value'] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model
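+
+
+# Usage sketch (illustrative, not part of the original patch): loading a
+# local CLIP checkpoint and encoding a dummy image. The checkpoint path is a
+# placeholder assumption; for a ViT-B/32 model the image embedding dimension
+# would be 512.
+if __name__ == '__main__':
+    clip_model = load_clip('/path/to/ViT-B-32.pt', device='cpu')
+    image = torch.randn(1, 3, 224, 224)
+    with torch.no_grad():
+        image_feat = clip_model.encode_image(image)
+    print(image_feat.shape)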
diff --git a/modelscope/models/multi_modal/soonet/model.py b/modelscope/models/multi_modal/soonet/model.py
new file mode 100644
index 00000000..32912bfc
--- /dev/null
+++ b/modelscope/models/multi_modal/soonet/model.py
@@ -0,0 +1,156 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import os
+
+import torch
+import torch.nn as nn
+
+from modelscope.metainfo import Models
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from .blocks import (BboxRegressor, Q2VRankerStage1, Q2VRankerStage2,
+ V2QRankerStage1, V2QRankerStage2)
+from .swin_transformer import SwinTransformerV2_1D
+
+
+@MODELS.register_module(
+ Tasks.video_temporal_grounding, module_name=Models.soonet)
+class SOONet(TorchModel):
+ """
+ The implementation of 'Scanning Only Once: An End-to-end Framework for Fast Temporal Grounding
+ in Long Videos'. The model is dynamically initialized with the following parts:
+ - q2v_stage1: calculate qv_ctx_score.
+ - v2q_stage1: calculate vq_ctx_score.
+ - q2v_stage2: calculate qv_ctn_score.
+ - v2q_stage2: calculate vq_ctn_score.
+ - regressor: predict the offset of bounding box for each candidate anchor.
+ """
+
+ def __init__(self, model_dir: str, *args, **kwargs):
+ """
+ Initialize SOONet Model
+
+ Args:
+ model_dir: model id or path
+ """
+ super().__init__()
+ config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
+ self.config = Config.from_file(config_path).hyperparams
+ nscales = self.config.nscales
+ hidden_dim = self.config.hidden_dim
+ snippet_length = self.config.snippet_length
+ self.enable_stage2 = self.config.enable_stage2
+ self.stage2_topk = self.config.stage2_topk
+ self.nscales = nscales
+
+ self.video_encoder = SwinTransformerV2_1D(
+ patch_size=snippet_length,
+ in_chans=hidden_dim,
+ embed_dim=hidden_dim,
+ depths=[2] * nscales,
+ num_heads=[8] * nscales,
+ window_size=[64] * nscales,
+ mlp_ratio=2.,
+ qkv_bias=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ patch_norm=True,
+ use_checkpoint=False,
+ pretrained_window_sizes=[0] * nscales)
+
+ self.q2v_stage1 = Q2VRankerStage1(nscales, hidden_dim)
+ self.v2q_stage1 = V2QRankerStage1(nscales, hidden_dim)
+ if self.enable_stage2:
+ self.q2v_stage2 = Q2VRankerStage2(nscales, hidden_dim,
+ snippet_length)
+ self.v2q_stage2 = V2QRankerStage2(nscales, hidden_dim)
+ self.regressor = BboxRegressor(hidden_dim, self.enable_stage2)
+
+ # Load trained weights
+ model_path = os.path.join(model_dir,
+ 'SOONet_MAD_VIT-B-32_4Scale_10C.pth')
+ state_dict = torch.load(model_path, map_location='cpu')['model']
+ self.load_state_dict(state_dict, strict=True)
+
+ def forward(self, **kwargs):
+ if self.training:
+ return self.forward_train(**kwargs)
+ else:
+ return self.forward_test(**kwargs)
+
+ def forward_train(self, **kwargs):
+ raise NotImplementedError
+
+ def forward_test(self,
+ query_feats=None,
+ video_feats=None,
+ start_ts=None,
+ end_ts=None,
+ scale_boundaries=None,
+ **kwargs):
+ """
+ Obtain matching scores and bbox bias of the top-k candidate anchors, with
+ pre-extracted query features and video features as input.
+
+ Args:
+ query_feats: the pre-extracted text features.
+ video_feats: the pre-extracted video features.
+ start_ts: the start timestamps of pre-defined multi-scale anchors.
+ end_ts: the end timestamps of pre-defined multi-scale anchors.
+ scale_boundaries: the begin and end anchor index for each scale in start_ts and end_ts.
+
+ Returns:
+ [final_scores, bbox_bias, starts, ends]
+ """
+ sent_feat = query_feats
+ ctx_feats = self.video_encoder(video_feats.permute(0, 2, 1))
+ qv_ctx_scores = self.q2v_stage1(ctx_feats, sent_feat)
+ if self.enable_stage2:
+ hit_indices = list()
+ starts = list()
+ ends = list()
+ filtered_ctx_feats = list()
+ for i in range(self.nscales):
+ _, indices = torch.sort(
+ qv_ctx_scores[i], dim=1, descending=True)
+ indices, _ = torch.sort(
+ torch.LongTensor(
+ list(
+ set(indices[:, :self.stage2_topk].flatten().cpu().
+ numpy().tolist()))))
+ indices = indices.to(video_feats.device)
+ hit_indices.append(indices)
+
+ filtered_ctx_feats.append(
+ torch.index_select(ctx_feats[i], 1, indices))
+
+ scale_first = scale_boundaries[i]
+ scale_last = scale_boundaries[i + 1]
+
+ filtered_start = torch.index_select(
+ start_ts[scale_first:scale_last], 0, indices)
+ filtered_end = torch.index_select(
+ end_ts[scale_first:scale_last], 0, indices)
+ starts.append(filtered_start)
+ ends.append(filtered_end)
+
+ starts = torch.cat(starts, dim=0)
+ ends = torch.cat(ends, dim=0)
+
+ qv_merge_scores, qv_ctn_scores, ctn_feats = self.q2v_stage2(
+ video_feats, sent_feat, hit_indices, qv_ctx_scores)
+ ctx_feats = filtered_ctx_feats
+ else:
+ ctn_feats = None
+ qv_merge_scores = qv_ctx_scores
+ starts = start_ts
+ ends = end_ts
+
+ bbox_bias = self.regressor(ctx_feats, ctn_feats, sent_feat)
+ final_scores = torch.sigmoid(torch.cat(qv_merge_scores, dim=1))
+
+ return final_scores, bbox_bias, starts, ends
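+
+
+# Anchor-layout sketch (illustrative, not part of the original patch): the
+# values below are assumptions that only show how start_ts, end_ts and
+# scale_boundaries index each other for two scales with 4 and 2 anchors.
+# The anchors of all scales are flattened into one tensor, and
+# scale_boundaries[i]:scale_boundaries[i + 1] selects the slice of scale i.
+#
+#     start_ts = torch.tensor([0., 10., 20., 30., 0., 20.])
+#     end_ts = torch.tensor([10., 20., 30., 40., 20., 40.])
+#     scale_boundaries = torch.tensor([0, 4, 6])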
diff --git a/modelscope/models/multi_modal/soonet/swin_transformer.py b/modelscope/models/multi_modal/soonet/swin_transformer.py
new file mode 100644
index 00000000..459561c0
--- /dev/null
+++ b/modelscope/models/multi_modal/soonet/swin_transformer.py
@@ -0,0 +1,623 @@
+# The implementation is adopted from Swin-Transformer-1D, made publicly available
+# at https://github.com/meraks/Swin-Transformer-1D
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from torch.nn.init import trunc_normal_
+
+
+def drop_path(x,
+ drop_prob: float = 0.,
+ training: bool = False,
+ scale_by_keep: bool = True):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0], ) + (1, ) * (
+ x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+
+class Mlp(nn.Module):
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, L, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, C)
+ """
+ B, L, C = x.shape
+ x = x.view(B, L // window_size, window_size, C)
+ windows = x.permute(0, 1, 2, 3).contiguous().view(-1, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, L):
+ """
+ Args:
+        windows: (num_windows*B, window_size, C)
+ window_size (int): Window size
+ L (int): sequence length
+ Returns:
+ x: (B, L, C)
+ """
+ B = int(windows.shape[0] / (L / window_size))
+ x = windows.view(B, L // window_size, window_size, -1)
+ x = x.permute(0, 1, 2, 3).contiguous().view(B, L, -1)
+ return x
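+
+# Shape sketch (illustrative): with x of shape (B=2, L=16, C=8) and
+# window_size=4, window_partition(x, 4) returns (8, 4, 8) windows and
+# window_reverse(windows, 4, 16) restores the original (2, 16, 8) tensor;
+# the two functions are exact inverses whenever L is a multiple of the
+# window size (SwinTransformerBlock_1D pads the sequence to guarantee this).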
+
+
+class WindowAttention_1D(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both shifted and non-shifted windows.
+ Args:
+ dim (int): Number of input channels.
+        window_size (int): The length of the attention window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+        pretrained_window_size (int): The length of the window used in pre-training.
+ """
+
+ def __init__(self,
+ dim,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ attn_drop=0.,
+ proj_drop=0.,
+ pretrained_window_size=0):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wl
+ self.pretrained_window_size = pretrained_window_size
+ self.num_heads = num_heads
+
+ self.logit_scale = nn.Parameter(
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+
+ # mlp to generate continuous relative position bias
+ self.cpb_mlp = nn.Sequential(
+ nn.Linear(1, 512, bias=True), nn.ReLU(inplace=True),
+ nn.Linear(512, num_heads, bias=False))
+
+ # get relative_coords_table
+ relative_coords_l = torch.arange(
+ -(self.window_size - 1), self.window_size, dtype=torch.float32)
+ relative_coords_table = torch.stack(
+ torch.meshgrid([relative_coords_l], indexing='ij')).permute(
+ 1, 0).contiguous().unsqueeze(0) # 1, 2*Wl-1, 1
+ if pretrained_window_size > 0:
+ relative_coords_table[:, :, :] /= (pretrained_window_size - 1)
+ else:
+ relative_coords_table[:, :, :] /= (self.window_size - 1)
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
+
+ self.register_buffer('relative_coords_table', relative_coords_table)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_l = torch.arange(self.window_size)
+ coords = torch.stack(torch.meshgrid([coords_l],
+ indexing='ij')) # 1, Wl
+ coords_flatten = torch.flatten(coords, 1) # 1, Wl
+ relative_coords = coords_flatten[:, :,
+ None] - coords_flatten[:,
+ None, :] # 1, Wl, Wl
+ relative_coords = relative_coords.permute(1, 2,
+ 0).contiguous() # Wl, Wl, 1
+ relative_coords[:, :,
+ 0] += self.window_size - 1 # shift to start from 0
+ relative_position_index = relative_coords.sum(-1) # Wl, Wl
+ self.register_buffer('relative_position_index',
+ relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(dim))
+ self.v_bias = nn.Parameter(torch.zeros(dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wl, Wl) or None
+ """
+ B_, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat(
+ (self.q_bias,
+ torch.zeros_like(self.v_bias,
+ requires_grad=False), self.v_bias))
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[
+ 2] # make torchscript happy (cannot use tensor as tuple)
+
+ # cosine attention
+ attn = (
+ F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
+ logit_scale = torch.clamp(
+ self.logit_scale,
+ max=torch.log(torch.tensor(1. / 0.01, device=attn.device))).exp()
+ attn = attn * logit_scale
+
+ relative_position_bias_table = self.cpb_mlp(
+ self.relative_coords_table).view(-1, self.num_heads)
+ relative_position_bias = relative_position_bias_table[
+ self.relative_position_index.view(-1)].view(
+ self.window_size, self.window_size, -1) # Wl,l,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1).contiguous() # nH, Wl, Wl
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N,
+ N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+def compute_mask(L, window_size, shift_size):
+ Lp = int(np.ceil(L / window_size)) * window_size
+ img_mask = torch.zeros((1, Lp, 1)) # 1 Lp 1
+ pad_size = int(Lp - L)
+ if (pad_size == 0) or (pad_size + shift_size == window_size):
+ segs = (slice(-window_size), slice(-window_size, -shift_size),
+ slice(-shift_size, None))
+ elif pad_size + shift_size > window_size:
+ seg1 = int(window_size * 2 - L + shift_size)
+ segs = (slice(-seg1), slice(-seg1, -window_size),
+ slice(-window_size, -shift_size), slice(-shift_size, None))
+ elif pad_size + shift_size < window_size:
+ seg1 = int(window_size * 2 - L + shift_size)
+ segs = (slice(-window_size), slice(-window_size, -seg1),
+ slice(-seg1, -shift_size), slice(-shift_size, None))
+ cnt = 0
+ for d in segs:
+ img_mask[:, d, :] = cnt
+ cnt += 1
+ mask_windows = window_partition(img_mask, window_size) # nW, ws, 1
+ mask_windows = mask_windows.squeeze(-1) # nW, ws
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
+ float(-100.0)).masked_fill(
+ attn_mask == 0, float(0.0))
+ return attn_mask
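+
+# Note (descriptive): the mask above implements the shifted-window trick --
+# positions that come from different segments of the rolled sequence get a
+# large negative bias (-100), so after softmax each token only attends to
+# tokens from its own segment within a window.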
+
+
+class SwinTransformerBlock_1D(nn.Module):
+ r""" Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ pretrained_window_size (int): Window size in pre-training.
+ """
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ pretrained_window_size=0):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+
+ assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention_1D(
+ dim,
+ window_size=self.window_size,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ pretrained_window_size=pretrained_window_size)
+
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+
+ def forward(self, x):
+ B, L, C = x.shape
+
+ attn_mask = compute_mask(L, self.window_size,
+ self.shift_size).to(x.device)
+
+ shortcut = x
+ # x = x.view(B, L, C)
+
+ # padding x
+ pad_r = (self.window_size - L % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, 0, pad_r))
+ _, Lp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size), dims=(1))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x,
+ self.window_size) # nW*B, window_size, C
+ x_windows = x_windows.view(-1, self.window_size,
+                                   C)  # nW*B, window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(
+ x_windows, mask=attn_mask) # nW*B, window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size,
+ Lp) # B L' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size), dims=(1))
+ else:
+ x = shifted_x
+ x = x.view(B, Lp, C)
+ # reverse padding x
+ x = x[:, :L, :].contiguous()
+ x = shortcut + self.drop_path(self.norm1(x))
+
+ # FFN
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ """ Patch Merging Layer
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ # self.reduction = nn.Linear(2 * dim, dim, bias=False)
+ # self.norm = norm_layer(2 * dim)
+
+ def forward(self, x):
+ """ Forward function.
+ Args:
+ x: Input feature, tensor size (B, L, C).
+ """
+ B, L, C = x.shape
+ x = F.pad(x, (0, 0, 0, L % 2))
+
+ x0 = x[:, 0::2, :] # B L/2 C
+ x1 = x[:, 1::2, :] # B L/2 C
+
+ x = torch.maximum(x0, x1)
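+        # Note (descriptive): unlike the original Swin PatchMerging, which
+        # concatenates neighbouring tokens and halves the channel count with
+        # a linear layer (see the commented-out reduction above), this 1D
+        # variant max-pools each pair of adjacent positions, so the channel
+        # dimension stays constant across stages.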
+
+ return x
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of input channels.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ pretrained_window_size (int): Local window size in pre-training.
+ """
+
+ def __init__(self,
+ dim,
+ depth,
+ num_heads,
+ window_size,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ pretrained_window_size=0):
+
+ super().__init__()
+ self.dim = dim
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock_1D(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i]
+ if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ pretrained_window_size=pretrained_window_size)
+ for i in range(depth)
+ ])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+
+ proposal = x
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x, proposal
+
+ def _init_respostnorm(self):
+ for blk in self.blocks:
+ nn.init.constant_(blk.norm1.bias, 0)
+ nn.init.constant_(blk.norm1.weight, 0)
+ nn.init.constant_(blk.norm2.bias, 0)
+ nn.init.constant_(blk.norm2.weight, 0)
+
+
+class PatchEmbed1D(nn.Module):
+ """ Video to Patch Embedding.
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input feature channels. Default: 32.
+        embed_dim (int): Number of linear projection output channels. Default: 128.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self,
+ patch_size=4,
+ in_chans=32,
+ embed_dim=128,
+ norm_layer=None):
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv1d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ # padding
+ _, _, L = x.size()
+ pad_r = (self.patch_size - L % self.patch_size) % self.patch_size
+ x = F.pad(x, (0, pad_r))
+ x = self.proj(x) # B C Wl
+ if self.norm is not None:
+ # Wl = x.size(2)
+ x = x.transpose(1, 2)
+ x = self.norm(x)
+ # x = x.transpose(1, 2).view(-1, self.embed_dim, Wl)
+
+ return x
+
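+# A minimal shape sketch (shapes are assumptions for illustration): the input
+# is padded to a multiple of patch_size, projected with a strided Conv1d and,
+# when a norm layer is given, returned as tokens in (B, L', C) layout.
+#
+#   embed = PatchEmbed1D(patch_size=4, in_chans=32, embed_dim=128,
+#                        norm_layer=nn.LayerNorm)
+#   feats = torch.randn(2, 32, 1001)    # padded to 1004 internally
+#   tokens = embed(feats)               # -> shape (2, 251, 128)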
+
+class SwinTransformerV2_1D(nn.Module):
+
+ def __init__(self,
+ patch_size=4,
+ in_chans=32,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=[7, 7, 7, 7],
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ patch_norm=True,
+ use_checkpoint=False,
+ pretrained_window_sizes=[0, 0, 0, 0],
+ **kwargs):
+ super().__init__()
+
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.patch_norm = patch_norm
+ self.num_features = int(embed_dim * 2**(self.num_layers - 1))
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed1D(
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ # stochastic depth
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=embed_dim,
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size[i_layer],
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if
+ (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint,
+ pretrained_window_size=pretrained_window_sizes[i_layer])
+ self.layers.append(layer)
+
+ self.apply(self._init_weights)
+ for bly in self.layers:
+ bly._init_respostnorm()
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'cpb_mlp', 'logit_scale', 'relative_position_bias_table'}
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ x = self.pos_drop(x)
+
+ proposals = list()
+ for layer in self.layers:
+ x, proposal = layer(x)
+ proposals.append(proposal)
+
+ return proposals
+
+ def forward(self, x):
+ return self.forward_features(x)
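+
+
+# A minimal usage sketch (the argument values below are assumptions, not a
+# shipped configuration): the backbone consumes (B, in_chans, L) features and
+# returns one token sequence per stage, with the sequence length roughly
+# halved after each stage by the max-based PatchMerging above.
+#
+#   backbone = SwinTransformerV2_1D(patch_size=4, in_chans=32, embed_dim=96,
+#                                   depths=[2, 2, 6, 2],
+#                                   num_heads=[3, 6, 12, 24],
+#                                   window_size=[7, 7, 7, 7])
+#   proposals = backbone(torch.randn(2, 32, 1024))
+#   # len(proposals) == 4 and proposals[0].shape == (2, 256, 96)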
diff --git a/modelscope/models/multi_modal/soonet/tokenizer.py b/modelscope/models/multi_modal/soonet/tokenizer.py
new file mode 100644
index 00000000..ed4f40a7
--- /dev/null
+++ b/modelscope/models/multi_modal/soonet/tokenizer.py
@@ -0,0 +1,152 @@
+# The implementation is adopted from CLIP, made publicly available
+# under MIT License at https://github.com/openai/CLIP
+
+import gzip
+import html
+from functools import lru_cache
+
+import ftfy
+import regex as re
+import torch
+
+
+@lru_cache()
+def bytes_to_unicode():
+ bs = list(range(ord('!'),
+ ord('~') + 1)) + list(range(
+ ord('¡'),
+ ord('¬') + 1)) + list(range(ord('®'),
+ ord('ÿ') + 1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+
+ def __init__(self, bpe_path):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
+ merges = merges[1:49152 - 256 - 2 + 1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + '</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {
+ '<|startoftext|>': '<|startoftext|>',
+ '<|endoftext|>': '<|endoftext|>'
+ }
+ self.pat = re.compile(
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+ re.IGNORECASE)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + '</w>', )
+ pairs = get_pairs(word)
+
+ if not pairs:
+            return token + '</w>'
+
+ while True:
+ bigram = min(
+ pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[
+ i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b]
+ for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token]
+ for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode(
+            'utf-8', errors='replace').replace('</w>', ' ')
+ return text
+
+ def tokenize(self, texts, context_length=77):
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = self.encoder['<|startoftext|>']
+ eot_token = self.encoder['<|endoftext|>']
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token]
+ for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ tokens = tokens[:context_length]
+ tokens[-1] = eot_token
+
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
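+
+
+# A minimal usage sketch (the BPE file name below is an assumption; pass
+# whichever vocabulary file ships with the model):
+#
+#   tokenizer = SimpleTokenizer('bpe_simple_vocab_16e6.txt.gz')
+#   ids = tokenizer.tokenize(['a person surfing at the beach'], context_length=77)
+#   # ids is an int tensor of shape (1, 77), zero-padded after <|endoftext|>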
diff --git a/modelscope/models/multi_modal/soonet/utils.py b/modelscope/models/multi_modal/soonet/utils.py
new file mode 100644
index 00000000..8fe3960e
--- /dev/null
+++ b/modelscope/models/multi_modal/soonet/utils.py
@@ -0,0 +1,58 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import copy
+
+import decord
+import numpy as np
+from decord import VideoReader, cpu
+from decord._ffi.base import DECORDError
+from tqdm import tqdm
+
+
+def decode_video(video_path, target_fps=5):
+ """
+ Decode video from 'video_path' and return the sampled frames based on target_fps.
+ The default value of target_fps is 5.
+
+ Args:
+ video_path: the absolute path of video.
+ target_fps: the number of sampled video frames per second.
+
+ Returns:
+ [imgs, duration]
+ """
+ decord.bridge.set_bridge('torch')
+ vr = VideoReader(video_path, ctx=cpu(0))
+ cur_fps = vr.get_avg_fps()
+ if cur_fps > target_fps:
+ interval = float(cur_fps) / float(target_fps)
+ start = float(interval) / 2.
+ else:
+ interval = 1.0
+ start = 0.0
+
+ vid_length = len(vr)
+ duration = vid_length / cur_fps
+ sampled_idxs = np.clip(
+ np.round(np.arange(start, float(vid_length), step=interval)), 0,
+ vid_length - 1).astype(np.int32)
+
+ imgs = list()
+ for i in tqdm(sampled_idxs):
+ bias = 0
+ # avoid broken frames
+ while bias <= 10:
+ try:
+ img = vr[i - bias]
+ break
+ except DECORDError:
+ bias += 1
+ if bias > 10:
+ img = copy.deepcopy(imgs[-1])
+ imgs.append(img)
+ else:
+ img = img / 255.
+ img = img.permute(2, 0, 1)
+ imgs.append(img)
+
+ return imgs, duration
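+
+
+# A minimal usage sketch (the video path is a placeholder): sample a video at
+# 5 fps and stack the returned frames into a single tensor.
+#
+#   frames, duration = decode_video('/path/to/video.mp4', target_fps=5)
+#   clip = torch.stack(frames)   # (num_sampled_frames, 3, H, W), values in [0, 1]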
diff --git a/modelscope/models/multi_modal/video_synthesis/__init__.py b/modelscope/models/multi_modal/video_synthesis/__init__.py
new file mode 100644
index 00000000..7db72e7c
--- /dev/null
+++ b/modelscope/models/multi_modal/video_synthesis/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+
+ from .text_to_video_synthesis_model import TextToVideoSynthesis
+
+else:
+ _import_structure = {
+ 'text_to_video_synthesis_model': ['TextToVideoSynthesis'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/multi_modal/video_synthesis/autoencoder.py b/modelscope/models/multi_modal/video_synthesis/autoencoder.py
new file mode 100644
index 00000000..7885f262
--- /dev/null
+++ b/modelscope/models/multi_modal/video_synthesis/autoencoder.py
@@ -0,0 +1,569 @@
+# Part of the implementation is borrowed and modified from latent-diffusion,
+# publicly available at https://github.com/CompVis/latent-diffusion.
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ['AutoencoderKL']
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class DiagonalGaussianDistribution(object):
+
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(
+ self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(
+ self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar
+ + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
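+# A small shape sketch of how the posterior is typically used (shapes assumed):
+# sample() draws z = mean + std * eps via the reparameterization trick, and
+# kl() returns the KL divergence to a standard normal per batch element.
+#
+#   moments = torch.randn(2, 8, 32, 32)            # 2 * z_channels along dim 1
+#   posterior = DiagonalGaussianDistribution(moments)
+#   z = posterior.sample()                          # (2, 4, 32, 32)
+#   kl = posterior.kl()                             # (2,)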
+
+class ResnetBlock(nn.Module):
+
+ def __init__(self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class AttnBlock(nn.Module):
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(
+ v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class Upsample(nn.Module):
+
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(
+ x, scale_factor=2.0, mode='nearest')
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class Encoder(nn.Module):
+
+ def __init__(self,
+ *,
+ ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ **ignore_kwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1, ) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+
+ def __init__(self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ **ignorekwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute block_in and curr_res at lowest res
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2**(self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, z):
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class AutoencoderKL(nn.Module):
+
+ def __init__(self,
+ ddconfig,
+ embed_dim,
+ ckpt_path=None,
+ image_key='image',
+ colorize_nlabels=None,
+ monitor=None,
+ ema_decay=None,
+ learn_logvar=False):
+ super().__init__()
+ self.learn_logvar = learn_logvar
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ assert ddconfig['double_z']
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
+ 2 * embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim,
+ ddconfig['z_channels'], 1)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels) == int
+ self.register_buffer('colorize',
+ torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+
+ self.use_ema = ema_decay is not None
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path)
+
+ def init_from_ckpt(self, path):
+ sd = torch.load(path, map_location='cpu')['state_dict']
+ keys = list(sd.keys())
+
+ import collections
+ sd_new = collections.OrderedDict()
+
+ for k in keys:
+ if k.find('first_stage_model') >= 0:
+ k_new = k.split('first_stage_model.')[-1]
+ sd_new[k_new] = sd[k]
+
+ self.load_state_dict(sd_new, strict=True)
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self)
+
+ def encode(self, x):
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1,
+ 2).to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log['samples'] = self.decode(torch.randn_like(posterior.sample()))
+ log['reconstructions'] = xrec
+ if log_ema or self.use_ema:
+ with self.ema_scope():
+ xrec_ema, posterior_ema = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec_ema.shape[1] > 3
+ xrec_ema = self.to_rgb(xrec_ema)
+ log['samples_ema'] = self.decode(
+ torch.randn_like(posterior_ema.sample()))
+ log['reconstructions_ema'] = xrec_ema
+ log['inputs'] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == 'segmentation'
+ if not hasattr(self, 'colorize'):
+ self.register_buffer('colorize',
+ torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+
+ def __init__(self, *args, vq_interface=False, **kwargs):
+ self.vq_interface = vq_interface
+ super().__init__()
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
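+
+
+# A minimal round-trip sketch for AutoencoderKL (the config values are
+# assumptions that mirror the defaults used elsewhere in this change):
+#
+#   ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3,
+#                   out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
+#                   attn_resolutions=[], dropout=0.0)
+#   vae = AutoencoderKL(ddconfig, embed_dim=4)
+#   rec, posterior = vae(torch.randn(1, 3, 256, 256))   # rec: (1, 3, 256, 256)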
diff --git a/modelscope/models/multi_modal/video_synthesis/diffusion.py b/modelscope/models/multi_modal/video_synthesis/diffusion.py
new file mode 100644
index 00000000..4eba1f13
--- /dev/null
+++ b/modelscope/models/multi_modal/video_synthesis/diffusion.py
@@ -0,0 +1,227 @@
+# publicly available at https://github.com/CompVis/latent-diffusion.
+# publicly avaialbe at https://github.com/CompVis/latent-diffusion.
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import torch
+
+__all__ = ['GaussianDiffusion', 'beta_schedule']
+
+
+def _i(tensor, t, x):
+ r"""Index tensor using t and format the output according to x.
+ """
+ shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
+ return tensor[t].view(shape).to(x)
+
+
+def beta_schedule(schedule,
+ num_timesteps=1000,
+ init_beta=None,
+ last_beta=None):
+ if schedule == 'linear_sd':
+ return torch.linspace(
+ init_beta**0.5, last_beta**0.5, num_timesteps,
+ dtype=torch.float64)**2
+ else:
+ raise ValueError(f'Unsupported schedule: {schedule}')
+
+
+class GaussianDiffusion(object):
+ r""" Diffusion Model for DDIM.
+ "Denoising diffusion implicit models." by Song, Jiaming, Chenlin Meng, and Stefano Ermon.
+ See https://arxiv.org/abs/2010.02502
+ """
+
+ def __init__(self,
+ betas,
+ mean_type='eps',
+ var_type='learned_range',
+ loss_type='mse',
+ epsilon=1e-12,
+ rescale_timesteps=False):
+ # check input
+ if not isinstance(betas, torch.DoubleTensor):
+ betas = torch.tensor(betas, dtype=torch.float64)
+ assert min(betas) > 0 and max(betas) <= 1
+ assert mean_type in ['x0', 'x_{t-1}', 'eps']
+ assert var_type in [
+ 'learned', 'learned_range', 'fixed_large', 'fixed_small'
+ ]
+ assert loss_type in [
+ 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1',
+ 'charbonnier'
+ ]
+ self.betas = betas
+ self.num_timesteps = len(betas)
+ self.mean_type = mean_type
+ self.var_type = var_type
+ self.loss_type = loss_type
+ self.epsilon = epsilon
+ self.rescale_timesteps = rescale_timesteps
+
+ # alphas
+ alphas = 1 - self.betas
+ self.alphas_cumprod = torch.cumprod(alphas, dim=0)
+ self.alphas_cumprod_prev = torch.cat(
+ [alphas.new_ones([1]), self.alphas_cumprod[:-1]])
+ self.alphas_cumprod_next = torch.cat(
+ [self.alphas_cumprod[1:],
+ alphas.new_zeros([1])])
+
+ # q(x_t | x_{t-1})
+ self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
+ - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = torch.log(1.0
+ - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
+ - 1)
+
+ # q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
+ 1.0 - self.alphas_cumprod)
+ self.posterior_log_variance_clipped = torch.log(
+ self.posterior_variance.clamp(1e-20))
+ self.posterior_mean_coef1 = betas * torch.sqrt(
+ self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ self.posterior_mean_coef2 = (
+ 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
+ 1.0 - self.alphas_cumprod)
+
+ def p_mean_variance(self,
+ xt,
+ t,
+ model,
+ model_kwargs={},
+ clamp=None,
+ percentile=None,
+ guide_scale=None):
+ r"""Distribution of p(x_{t-1} | x_t).
+ """
+ # predict distribution
+ if guide_scale is None:
+ out = model(xt, self._scale_timesteps(t), **model_kwargs)
+ else:
+ # classifier-free guidance
+ # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
+ assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
+ y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
+ u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
+ dim = y_out.size(1) if self.var_type.startswith(
+ 'fixed') else y_out.size(1) // 2
+ a = u_out[:, :dim]
+ b = guide_scale * (y_out[:, :dim] - u_out[:, :dim])
+ c = y_out[:, dim:]
+ out = torch.cat([a + b, c], dim=1)
+
+ # compute variance
+ if self.var_type == 'fixed_small':
+ var = _i(self.posterior_variance, t, xt)
+ log_var = _i(self.posterior_log_variance_clipped, t, xt)
+
+ # compute mean and x0
+ if self.mean_type == 'eps':
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
+ self.sqrt_recipm1_alphas_cumprod, t, xt) * out
+ mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
+
+ # restrict the range of x0
+ if percentile is not None:
+ assert percentile > 0 and percentile <= 1 # e.g., 0.995
+ s = torch.quantile(
+ x0.flatten(1).abs(), percentile,
+ dim=1).clamp_(1.0).view(-1, 1, 1, 1)
+ x0 = torch.min(s, torch.max(-s, x0)) / s
+ elif clamp is not None:
+ x0 = x0.clamp(-clamp, clamp)
+ return mu, var, log_var, x0
+
+ def q_posterior_mean_variance(self, x0, xt, t):
+ r"""Distribution of q(x_{t-1} | x_t, x_0).
+ """
+ mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
+ self.posterior_mean_coef2, t, xt) * xt
+ var = _i(self.posterior_variance, t, xt)
+ log_var = _i(self.posterior_log_variance_clipped, t, xt)
+ return mu, var, log_var
+
+ @torch.no_grad()
+ def ddim_sample(self,
+ xt,
+ t,
+ model,
+ model_kwargs={},
+ clamp=None,
+ percentile=None,
+ condition_fn=None,
+ guide_scale=None,
+ ddim_timesteps=20,
+ eta=0.0):
+ r"""Sample from p(x_{t-1} | x_t) using DDIM.
+ - condition_fn: for classifier-based guidance (guided-diffusion).
+ - guide_scale: for classifier-free guidance (glide/dalle-2).
+ """
+ stride = self.num_timesteps // ddim_timesteps
+
+ # predict distribution of p(x_{t-1} | x_t)
+ _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
+ percentile, guide_scale)
+ if condition_fn is not None:
+ # x0 -> eps
+ alpha = _i(self.alphas_cumprod, t, xt)
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
+ self.sqrt_recipm1_alphas_cumprod, t, xt)
+ eps = eps - (1 - alpha).sqrt() * condition_fn(
+ xt, self._scale_timesteps(t), **model_kwargs)
+
+ # eps -> x0
+ x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i(
+ self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
+
+ # derive variables
+ eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i(
+ self.sqrt_recipm1_alphas_cumprod, t, xt)
+ alphas = _i(self.alphas_cumprod, t, xt)
+ alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
+ a = (1 - alphas_prev) / (1 - alphas)
+ b = (1 - alphas / alphas_prev)
+ sigmas = eta * torch.sqrt(a * b)
+
+ # random sample
+ noise = torch.randn_like(xt)
+ direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
+ mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
+ xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
+ return xt_1, x0
+
+ @torch.no_grad()
+ def ddim_sample_loop(self,
+ noise,
+ model,
+ model_kwargs={},
+ clamp=None,
+ percentile=None,
+ condition_fn=None,
+ guide_scale=None,
+ ddim_timesteps=20,
+ eta=0.0):
+ # prepare input
+ b = noise.size(0)
+ xt = noise
+
+ # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
+ steps = (1 + torch.arange(0, self.num_timesteps,
+ self.num_timesteps // ddim_timesteps)).clamp(
+ 0, self.num_timesteps - 1).flip(0)
+ for step in steps:
+ t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
+ xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
+ percentile, condition_fn, guide_scale,
+ ddim_timesteps, eta)
+ return xt
+
+ def _scale_timesteps(self, t):
+ if self.rescale_timesteps:
+ return t.float() * 1000.0 / self.num_timesteps
+ return t
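+
+
+# A minimal wiring sketch for DDIM sampling (the denoiser below is a stand-in
+# lambda, not the real UNetSD):
+#
+#   betas = beta_schedule('linear_sd', 1000, init_beta=0.00085, last_beta=0.0120)
+#   diffusion = GaussianDiffusion(betas, mean_type='eps', var_type='fixed_small')
+#   denoiser = lambda x, t, **kwargs: torch.zeros_like(x)
+#   sample = diffusion.ddim_sample_loop(
+#       noise=torch.randn(1, 4, 16, 32, 32), model=denoiser, ddim_timesteps=50)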
diff --git a/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py b/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py
new file mode 100644
index 00000000..1bcd6eda
--- /dev/null
+++ b/modelscope/models/multi_modal/video_synthesis/text_to_video_synthesis_model.py
@@ -0,0 +1,241 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import os
+from os import path as osp
+from typing import Any, Dict
+
+import open_clip
+import torch
+import torch.cuda.amp as amp
+from einops import rearrange
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Model
+from modelscope.models.builder import MODELS
+from modelscope.models.multi_modal.video_synthesis.autoencoder import \
+ AutoencoderKL
+from modelscope.models.multi_modal.video_synthesis.diffusion import (
+ GaussianDiffusion, beta_schedule)
+from modelscope.models.multi_modal.video_synthesis.unet_sd import UNetSD
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+
+__all__ = ['TextToVideoSynthesis']
+
+
+@MODELS.register_module(
+ Tasks.text_to_video_synthesis, module_name=Models.video_synthesis)
+class TextToVideoSynthesis(Model):
+ r"""
+    Task for text-to-video synthesis.
+
+    Attributes:
+        sd_model: the denoising UNet used in this task.
+        diffusion: the DDIM diffusion sampler.
+        autoencoder: decodes the latent representation back into pixel space.
+        clip_encoder: encodes the text prompt into a text embedding.
+ """
+
+ def __init__(self, model_dir, *args, **kwargs):
+ r"""
+ Args:
+            model_dir (`str` or `os.PathLike`):
+                Path to a local *directory* containing the model configuration
+                and pretrained weights, typically downloaded from modelscope.cn
+                (e.g., `./my_model_directory/`).
+ """
+ super().__init__(model_dir=model_dir, *args, **kwargs)
+ self.device = torch.device('cuda') if torch.cuda.is_available() \
+ else torch.device('cpu')
+ self.config = Config.from_file(
+ osp.join(model_dir, ModelFile.CONFIGURATION))
+ cfg = self.config.model.model_cfg
+ cfg['temporal_attention'] = True if cfg[
+ 'temporal_attention'] == 'True' else False
+
+ # Initialize unet
+ self.sd_model = UNetSD(
+ in_dim=cfg['unet_in_dim'],
+ dim=cfg['unet_dim'],
+ y_dim=cfg['unet_y_dim'],
+ context_dim=cfg['unet_context_dim'],
+ out_dim=cfg['unet_out_dim'],
+ dim_mult=cfg['unet_dim_mult'],
+ num_heads=cfg['unet_num_heads'],
+ head_dim=cfg['unet_head_dim'],
+ num_res_blocks=cfg['unet_res_blocks'],
+ attn_scales=cfg['unet_attn_scales'],
+ dropout=cfg['unet_dropout'],
+ temporal_attention=cfg['temporal_attention'])
+ self.sd_model.load_state_dict(
+ torch.load(
+ osp.join(model_dir, self.config.model.model_args.ckpt_unet)),
+ strict=True)
+ self.sd_model.eval()
+ self.sd_model.to(self.device)
+
+ # Initialize diffusion
+ betas = beta_schedule(
+ 'linear_sd',
+ cfg['num_timesteps'],
+ init_beta=0.00085,
+ last_beta=0.0120)
+ self.diffusion = GaussianDiffusion(
+ betas=betas,
+ mean_type=cfg['mean_type'],
+ var_type=cfg['var_type'],
+ loss_type=cfg['loss_type'],
+ rescale_timesteps=False)
+
+ # Initialize autoencoder
+ ddconfig = {
+ 'double_z': True,
+ 'z_channels': 4,
+ 'resolution': 256,
+ 'in_channels': 3,
+ 'out_ch': 3,
+ 'ch': 128,
+ 'ch_mult': [1, 2, 4, 4],
+ 'num_res_blocks': 2,
+ 'attn_resolutions': [],
+ 'dropout': 0.0
+ }
+ self.autoencoder = AutoencoderKL(
+ ddconfig, 4,
+ osp.join(model_dir, self.config.model.model_args.ckpt_autoencoder))
+ if self.config.model.model_args.tiny_gpu == 1:
+ self.autoencoder.to('cpu')
+ else:
+ self.autoencoder.to(self.device)
+ self.autoencoder.eval()
+
+ # Initialize Open clip
+ self.clip_encoder = FrozenOpenCLIPEmbedder(
+ version=osp.join(model_dir,
+ self.config.model.model_args.ckpt_clip),
+ layer='penultimate')
+ if self.config.model.model_args.tiny_gpu == 1:
+ self.clip_encoder.to('cpu')
+ else:
+ self.clip_encoder.to(self.device)
+
+ def forward(self, input: Dict[str, Any]):
+ r"""
+        The entry function of the text-to-video synthesis task.
+        1. Use the diffusion model to generate the video's latent representation.
+        2. Use the autoencoder to decode the latent representation back into pixel space.
+
+ Args:
+            input (`Dict[str, Any]`):
+ The input of the task
+ Returns:
+ A generated video (as pytorch tensor).
+ """
+ y = input['text_emb']
+ zero_y = input['text_emb_zero']
+ context = torch.cat([zero_y, y], dim=0).to(self.device)
+ # synthesis
+ with torch.no_grad():
+ num_sample = 1 # here let b = 1
+ max_frames = self.config.model.model_args.max_frames
+ latent_h, latent_w = 32, 32
+ with amp.autocast(enabled=True):
+ x0 = self.diffusion.ddim_sample_loop(
+ noise=torch.randn(num_sample, 4, max_frames, latent_h,
+ latent_w).to(
+ self.device), # shape: b c f h w
+ model=self.sd_model,
+ model_kwargs=[{
+ 'y':
+ context[1].unsqueeze(0).repeat(num_sample, 1, 1)
+ }, {
+ 'y':
+ context[0].unsqueeze(0).repeat(num_sample, 1, 1)
+ }],
+ guide_scale=9.0,
+ ddim_timesteps=50,
+ eta=0.0)
+
+ scale_factor = 0.18215
+ video_data = 1. / scale_factor * x0
+ bs_vd = video_data.shape[0]
+ video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
+ self.autoencoder.to(self.device)
+ video_data = self.autoencoder.decode(video_data)
+ if self.config.model.model_args.tiny_gpu == 1:
+ self.autoencoder.to('cpu')
+ video_data = rearrange(
+ video_data, '(b f) c h w -> b c f h w', b=bs_vd)
+ return video_data.type(torch.float32).cpu()
+
+
+class FrozenOpenCLIPEmbedder(torch.nn.Module):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ LAYERS = ['last', 'penultimate']
+
+ def __init__(self,
+ arch='ViT-H-14',
+ version='open_clip_pytorch_model.bin',
+ device='cuda',
+ max_length=77,
+ freeze=True,
+ layer='last'):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch, device=torch.device('cpu'), pretrained=version)
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == 'last':
+ self.layer_idx = 0
+ elif self.layer == 'penultimate':
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
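+
+
+# A minimal usage sketch for the text encoder (the checkpoint path is an
+# assumption; in practice it comes from the model configuration):
+#
+#   encoder = FrozenOpenCLIPEmbedder(version='open_clip_pytorch_model.bin',
+#                                    layer='penultimate', device='cpu')
+#   text_emb = encoder(['a panda eating bamboo on a rock'])   # (1, 77, 1024) for ViT-H-14
+#   text_emb_zero = encoder([''])
+#   # assuming `model` is an already constructed TextToVideoSynthesis instance:
+#   video = model({'text_emb': text_emb, 'text_emb_zero': text_emb_zero})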
diff --git a/modelscope/models/multi_modal/video_synthesis/unet_sd.py b/modelscope/models/multi_modal/video_synthesis/unet_sd.py
new file mode 100644
index 00000000..f3c764eb
--- /dev/null
+++ b/modelscope/models/multi_modal/video_synthesis/unet_sd.py
@@ -0,0 +1,1098 @@
+# Part of the implementation is borrowed and modified from stable-diffusion,
+# publicly available at https://github.com/Stability-AI/stablediffusion.
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+__all__ = ['UNetSD']
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if callable(d) else d
+
+
+class UNetSD(nn.Module):
+
+ def __init__(self,
+ in_dim=7,
+ dim=512,
+ y_dim=512,
+ context_dim=512,
+ out_dim=6,
+ dim_mult=[1, 2, 3, 4],
+ num_heads=None,
+ head_dim=64,
+ num_res_blocks=3,
+ attn_scales=[1 / 2, 1 / 4, 1 / 8],
+ use_scale_shift_norm=True,
+ dropout=0.1,
+ temporal_attn_times=2,
+ temporal_attention=True,
+ use_checkpoint=False,
+ use_image_dataset=False,
+ use_fps_condition=False,
+ use_sim_mask=False):
+ embed_dim = dim * 4
+ num_heads = num_heads if num_heads else dim // 32
+ super(UNetSD, self).__init__()
+ self.in_dim = in_dim
+ self.dim = dim
+ self.y_dim = y_dim
+ self.context_dim = context_dim
+ self.embed_dim = embed_dim
+ self.out_dim = out_dim
+ self.dim_mult = dim_mult
+ self.num_heads = num_heads
+ # parameters for spatial/temporal attention
+ self.head_dim = head_dim
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.temporal_attn_times = temporal_attn_times
+ self.temporal_attention = temporal_attention
+ self.use_checkpoint = use_checkpoint
+ self.use_image_dataset = use_image_dataset
+ self.use_fps_condition = use_fps_condition
+ self.use_sim_mask = use_sim_mask
+ use_linear_in_temporal = False
+ transformer_depth = 1
+ disabled_sa = False
+ # params
+ enc_dims = [dim * u for u in [1] + dim_mult]
+ dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+ shortcut_dims = []
+ scale = 1.0
+
+ # embeddings
+ self.time_embed = nn.Sequential(
+ nn.Linear(dim, embed_dim), nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim))
+
+ if self.use_fps_condition:
+ self.fps_embedding = nn.Sequential(
+ nn.Linear(dim, embed_dim), nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim))
+ nn.init.zeros_(self.fps_embedding[-1].weight)
+ nn.init.zeros_(self.fps_embedding[-1].bias)
+
+ # encoder
+ self.input_blocks = nn.ModuleList()
+ init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
+
+ if temporal_attention:
+ init_block.append(
+ TemporalTransformer(
+ dim,
+ num_heads,
+ head_dim,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_temporal,
+ multiply_zero=use_image_dataset))
+
+ self.input_blocks.append(init_block)
+ shortcut_dims.append(dim)
+ for i, (in_dim,
+ out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
+ for j in range(num_res_blocks):
+ # residual (+attention) blocks
+ block = nn.ModuleList([
+ ResBlock(
+ in_dim,
+ embed_dim,
+ dropout,
+ out_channels=out_dim,
+ use_scale_shift_norm=False,
+ use_image_dataset=use_image_dataset,
+ )
+ ])
+ if scale in attn_scales:
+ block.append(
+ SpatialTransformer(
+ out_dim,
+ out_dim // head_dim,
+ head_dim,
+ depth=1,
+ context_dim=self.context_dim,
+ disable_self_attn=False,
+ use_linear=True))
+ if self.temporal_attention:
+ block.append(
+ TemporalTransformer(
+ out_dim,
+ out_dim // head_dim,
+ head_dim,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_temporal,
+ multiply_zero=use_image_dataset))
+
+ in_dim = out_dim
+ self.input_blocks.append(block)
+ shortcut_dims.append(out_dim)
+
+ # downsample
+ if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
+ downsample = Downsample(
+ out_dim, True, dims=2, out_channels=out_dim)
+ shortcut_dims.append(out_dim)
+ scale /= 2.0
+ self.input_blocks.append(downsample)
+
+ # middle
+ self.middle_block = nn.ModuleList([
+ ResBlock(
+ out_dim,
+ embed_dim,
+ dropout,
+ use_scale_shift_norm=False,
+ use_image_dataset=use_image_dataset,
+ ),
+ SpatialTransformer(
+ out_dim,
+ out_dim // head_dim,
+ head_dim,
+ depth=1,
+ context_dim=self.context_dim,
+ disable_self_attn=False,
+ use_linear=True)
+ ])
+
+ if self.temporal_attention:
+ self.middle_block.append(
+ TemporalTransformer(
+ out_dim,
+ out_dim // head_dim,
+ head_dim,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_temporal,
+ multiply_zero=use_image_dataset,
+ ))
+
+ self.middle_block.append(
+ ResBlock(
+ out_dim,
+ embed_dim,
+ dropout,
+ use_scale_shift_norm=False,
+ use_image_dataset=use_image_dataset,
+ ))
+
+ # decoder
+ self.output_blocks = nn.ModuleList()
+ for i, (in_dim,
+ out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
+ for j in range(num_res_blocks + 1):
+ # residual (+attention) blocks
+ block = nn.ModuleList([
+ ResBlock(
+ in_dim + shortcut_dims.pop(),
+ embed_dim,
+ dropout,
+ out_dim,
+ use_scale_shift_norm=False,
+ use_image_dataset=use_image_dataset,
+ )
+ ])
+ if scale in attn_scales:
+ block.append(
+ SpatialTransformer(
+ out_dim,
+ out_dim // head_dim,
+ head_dim,
+ depth=1,
+ context_dim=1024,
+ disable_self_attn=False,
+ use_linear=True))
+
+ if self.temporal_attention:
+ block.append(
+ TemporalTransformer(
+ out_dim,
+ out_dim // head_dim,
+ head_dim,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_temporal,
+ multiply_zero=use_image_dataset))
+ in_dim = out_dim
+
+ # upsample
+ if i != len(dim_mult) - 1 and j == num_res_blocks:
+ upsample = Upsample(
+                        out_dim, True, dims=2, out_channels=out_dim)
+ scale *= 2.0
+ block.append(upsample)
+ self.output_blocks.append(block)
+
+ # head
+ self.out = nn.Sequential(
+ nn.GroupNorm(32, out_dim), nn.SiLU(),
+ nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
+
+ # zero out the last layer params
+ nn.init.zeros_(self.out[-1].weight)
+
+ def forward(
+ self,
+ x,
+ t,
+ y,
+ fps=None,
+ video_mask=None,
+ focus_present_mask=None,
+ prob_focus_present=0.,
+ mask_last_frame_num=0 # mask last frame num
+ ):
+ """
+ prob_focus_present: probability at which a given batch sample will focus on the present
+ (0. is all off, 1. is completely arrested attention across time)
+ """
+ batch, device = x.shape[0], x.device
+ self.batch = batch
+
+ # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
+ if mask_last_frame_num > 0:
+ focus_present_mask = None
+ video_mask[-mask_last_frame_num:] = False
+ else:
+ focus_present_mask = default(
+ focus_present_mask, lambda: prob_mask_like(
+ (batch, ), prob_focus_present, device=device))
+
+ time_rel_pos_bias = None
+ # embeddings
+ if self.use_fps_condition and fps is not None:
+ e = self.time_embed(sinusoidal_embedding(
+ t, self.dim)) + self.fps_embedding(
+ sinusoidal_embedding(fps, self.dim))
+ else:
+ e = self.time_embed(sinusoidal_embedding(t, self.dim))
+ context = y
+
+ # repeat f times for spatial e and context
+ f = x.shape[2]
+ e = e.repeat_interleave(repeats=f, dim=0)
+ context = context.repeat_interleave(repeats=f, dim=0)
+
+ # always in shape (b f) c h w, except for temporal layer
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
+ # encoder
+ xs = []
+ for block in self.input_blocks:
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias,
+ focus_present_mask, video_mask)
+ xs.append(x)
+
+ # middle
+ for block in self.middle_block:
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias,
+ focus_present_mask, video_mask)
+
+ # decoder
+ for block in self.output_blocks:
+ x = torch.cat([x, xs.pop()], dim=1)
+ x = self._forward_single(
+ block,
+ x,
+ e,
+ context,
+ time_rel_pos_bias,
+ focus_present_mask,
+ video_mask,
+ reference=xs[-1] if len(xs) > 0 else None)
+
+ # head
+ x = self.out(x)
+ # reshape back to (b c f h w)
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
+ return x
+
+ def _forward_single(self,
+ module,
+ x,
+ e,
+ context,
+ time_rel_pos_bias,
+ focus_present_mask,
+ video_mask,
+ reference=None):
+ if isinstance(module, ResidualBlock):
+ x = x.contiguous()
+ x = module(x, e, reference)
+ elif isinstance(module, ResBlock):
+ x = x.contiguous()
+ x = module(x, e, self.batch)
+ elif isinstance(module, SpatialTransformer):
+ x = module(x, context)
+ elif isinstance(module, TemporalTransformer):
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
+ x = module(x, context)
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
+ elif isinstance(module, CrossAttention):
+ x = module(x, context)
+ elif isinstance(module, BasicTransformerBlock):
+ x = module(x, context)
+ elif isinstance(module, FeedForward):
+ x = module(x, context)
+ elif isinstance(module, Upsample):
+ x = module(x)
+ elif isinstance(module, Downsample):
+ x = module(x)
+ elif isinstance(module, Resample):
+ x = module(x, reference)
+ elif isinstance(module, nn.ModuleList):
+ for block in module:
+ x = self._forward_single(block, x, e, context,
+ time_rel_pos_bias, focus_present_mask,
+ video_mask, reference)
+ else:
+ x = module(x)
+ return x
+
+
+def sinusoidal_embedding(timesteps, dim):
+ # check input
+ half = dim // 2
+ timesteps = timesteps.float()
+ # compute sinusoidal embedding
+ sinusoid = torch.outer(
+ timesteps, torch.pow(10000,
+ -torch.arange(half).to(timesteps).div(half)))
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+ if dim % 2 != 0:
+ x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
+ return x
+
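+# For example (shape sketch): sinusoidal_embedding(torch.arange(8), 320) yields
+# an (8, 320) tensor of concatenated cos/sin features, which the time_embed MLP
+# above projects to embed_dim.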
+
+class CrossAttention(nn.Module):
+
+ def __init__(self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
+ (q, k, v))
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = torch.einsum('b i j, b j d -> b i d', sim, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+class SpatialTransformer(nn.Module):
+ """
+    Transformer block for image-like data along the spatial axes.
+    First, project the input (aka embedding) and reshape to (b, t, d).
+    Then apply standard transformer action.
+    Finally, reshape back to an image.
+    NEW: use_linear replaces the 1x1 convs with linear layers for efficiency.
+ """
+
+ def __init__(self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.,
+ context_dim=None,
+ disable_self_attn=False,
+ use_linear=False,
+ use_checkpoint=True):
+ super().__init__()
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList([
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn,
+ checkpoint=use_checkpoint) for d in range(depth)
+ ])
+ if not use_linear:
+ self.proj_out = zero_module(
+ nn.Conv2d(
+ inner_dim, in_channels, kernel_size=1, stride=1,
+ padding=0))
+ else:
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
+
+class TemporalTransformer(nn.Module):
+ """
+    Transformer block for image-like data along the temporal axis.
+    First, reshape to (b, t, d).
+    Then apply standard transformer action.
+    Finally, reshape back to an image.
+ """
+
+ def __init__(self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.,
+ context_dim=None,
+ disable_self_attn=False,
+ use_linear=False,
+ use_checkpoint=True,
+ only_self_att=True,
+ multiply_zero=False):
+ super().__init__()
+ self.multiply_zero = multiply_zero
+ self.only_self_att = only_self_att
+ if self.only_self_att:
+ context_dim = None
+ if not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ if not use_linear:
+ self.proj_in = nn.Conv1d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList([
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim[d],
+ checkpoint=use_checkpoint) for d in range(depth)
+ ])
+ if not use_linear:
+ self.proj_out = zero_module(
+ nn.Conv1d(
+ inner_dim, in_channels, kernel_size=1, stride=1,
+ padding=0))
+ else:
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if self.only_self_att:
+ context = None
+ if not isinstance(context, list):
+ context = [context]
+ b, c, f, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+
+ if not self.use_linear:
+ x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
+ x = self.proj_in(x)
+ if self.use_linear:
+ x = rearrange(
+ x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous()
+ x = self.proj_in(x)
+
+ if self.only_self_att:
+ x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x)
+ x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
+ else:
+ x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
+ for i, block in enumerate(self.transformer_blocks):
+ context[i] = rearrange(
+ context[i], '(b f) l con -> b f l con',
+ f=self.frames).contiguous()
+                # process each batch element separately (some backends cannot
+                # handle a dimension larger than 65,535)
+ for j in range(b):
+ context_i_j = repeat(
+ context[i][j],
+ 'f l con -> (f r) l con',
+ r=(h * w) // self.frames,
+ f=self.frames).contiguous()
+ x[j] = block(x[j], context=context_i_j)
+
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
+ x = self.proj_out(x)
+ x = rearrange(
+ x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()
+
+ if self.multiply_zero:
+ x = 0.0 * x + x_in
+ else:
+ x = x + x_in
+ return x
+
+
+class BasicTransformerBlock(nn.Module):
+
+ def __init__(self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ disable_self_attn=False):
+ super().__init__()
+ attn_cls = CrossAttention
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else
+ None) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ x = self.attn1(
+ self.norm1(x),
+ context=context if self.disable_self_attn else None) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+# feedforward
+class GEGLU(nn.Module):
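+ """Gated GELU (GEGLU) projection: the input is projected to twice the output width and one half gates the other through GELU."""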
+
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+class FeedForward(nn.Module):
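+ """Position-wise feed-forward layer: a GELU (or GEGLU when ``glu=True``) input projection followed by dropout and a linear output projection."""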
+
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(nn.Linear(
+ dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self,
+ channels,
+ use_conv,
+ dims=2,
+ out_channels=None,
+ padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = nn.Conv2d(
+ self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
+ mode='nearest')
+ else:
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class ResBlock(nn.Module):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ :param use_temporal_conv: if True, use the temporal convolution.
+ :param use_image_dataset: if True, the temporal parameters will not be optimized.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ up=False,
+ down=False,
+ use_temporal_conv=True,
+ use_image_dataset=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.use_temporal_conv = use_temporal_conv
+
+ self.in_layers = nn.Sequential(
+ nn.GroupNorm(32, channels),
+ nn.SiLU(),
+ nn.Conv2d(channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(
+ emb_channels,
+ 2 * self.out_channels
+ if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ nn.GroupNorm(32, self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1)
+ else:
+ self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
+
+ if self.use_temporal_conv:
+ self.temopral_conv = TemporalConvBlock_v2(
+ self.out_channels,
+ self.out_channels,
+ dropout=0.1,
+ use_image_dataset=use_image_dataset)
+
+ def forward(self, x, emb, batch_size):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
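+ :param batch_size: the number of videos in the batch, used to fold the frame axis back into a 5D tensor for the temporal convolution.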
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return self._forward(x, emb, batch_size)
+
+ def _forward(self, x, emb, batch_size):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ h = self.skip_connection(x) + h
+
+ if self.use_temporal_conv:
+ h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
+ h = self.temopral_conv(h)
+ h = rearrange(h, 'b c f h w -> (b f) c h w')
+ return h
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self,
+ channels,
+ use_conv,
+ dims=2,
+ out_channels=None,
+ padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if self.use_conv:
+ self.op = nn.Conv2d(
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class Resample(nn.Module):
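+ """Spatial resampling: nearest-neighbour upsampling to the size of ``reference``, 2x average-pool downsampling, or a no-op, selected by ``mode``."""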
+
+ def __init__(self, in_dim, out_dim, mode):
+ assert mode in ['none', 'upsample', 'downsample']
+ super(Resample, self).__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.mode = mode
+
+ def forward(self, x, reference=None):
+ if self.mode == 'upsample':
+ assert reference is not None
+ x = F.interpolate(x, size=reference.shape[-2:], mode='nearest')
+ elif self.mode == 'downsample':
+ x = F.adaptive_avg_pool2d(
+ x, output_size=tuple(u // 2 for u in x.shape[-2:]))
+ return x
+
+
+class ResidualBlock(nn.Module):
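+ """Residual block with GroupNorm/SiLU convolutions, optional resampling and timestep-embedding conditioning (scale-shift or additive)."""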
+
+ def __init__(self,
+ in_dim,
+ embed_dim,
+ out_dim,
+ use_scale_shift_norm=True,
+ mode='none',
+ dropout=0.0):
+ super(ResidualBlock, self).__init__()
+ self.in_dim = in_dim
+ self.embed_dim = embed_dim
+ self.out_dim = out_dim
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.mode = mode
+
+ # layers
+ self.layer1 = nn.Sequential(
+ nn.GroupNorm(32, in_dim), nn.SiLU(),
+ nn.Conv2d(in_dim, out_dim, 3, padding=1))
+ self.resample = Resample(in_dim, in_dim, mode)
+ self.embedding = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(embed_dim,
+ out_dim * 2 if use_scale_shift_norm else out_dim))
+ self.layer2 = nn.Sequential(
+ nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
+ nn.Conv2d(out_dim, out_dim, 3, padding=1))
+ self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(
+ in_dim, out_dim, 1)
+ # zero out the last layer params
+ nn.init.zeros_(self.layer2[-1].weight)
+
+ def forward(self, x, e, reference=None):
+ identity = self.resample(x, reference)
+ x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference))
+ e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype)
+ if self.use_scale_shift_norm:
+ scale, shift = e.chunk(2, dim=1)
+ x = self.layer2[0](x) * (1 + scale) + shift
+ x = self.layer2[1:](x)
+ else:
+ x = x + e
+ x = self.layer2(x)
+ x = x + self.shortcut(identity)
+ return x
+
+
+class AttentionBlock(nn.Module):
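+ """Multi-head spatial self-attention over a 2D feature map; optional context tokens are projected and concatenated to the keys and values."""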
+
+ def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
+ # consider head_dim first, then num_heads
+ num_heads = dim // head_dim if head_dim else num_heads
+ head_dim = dim // num_heads
+ assert num_heads * head_dim == dim
+ super(AttentionBlock, self).__init__()
+ self.dim = dim
+ self.context_dim = context_dim
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.scale = math.pow(head_dim, -0.25)
+
+ # layers
+ self.norm = nn.GroupNorm(32, dim)
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+ if context_dim is not None:
+ self.context_kv = nn.Linear(context_dim, dim * 2)
+ self.proj = nn.Conv2d(dim, dim, 1)
+
+ # zero out the last layer params
+ nn.init.zeros_(self.proj.weight)
+
+ def forward(self, x, context=None):
+ r"""x: [B, C, H, W].
+ context: [B, L, C] or None.
+ """
+ identity = x
+ b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim
+
+ # compute query, key, value
+ x = self.norm(x)
+ q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
+ if context is not None:
+ ck, cv = self.context_kv(context).reshape(b, -1, n * 2,
+ d).permute(0, 2, 3,
+ 1).chunk(
+ 2, dim=1)
+ k = torch.cat([ck, k], dim=-1)
+ v = torch.cat([cv, v], dim=-1)
+
+ # compute attention
+ attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
+ attn = F.softmax(attn, dim=-1)
+
+ # gather context
+ x = torch.matmul(v, attn.transpose(-1, -2))
+ x = x.reshape(b, c, h, w)
+ # output
+ x = self.proj(x)
+ return x + identity
+
+
+class TemporalConvBlock_v2(nn.Module):
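+ """Four stacked temporal (3x1x1) convolutions with a residual connection; the last convolution is zero-initialized so the block starts as an identity mapping."""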
+
+ def __init__(self,
+ in_dim,
+ out_dim=None,
+ dropout=0.0,
+ use_image_dataset=False):
+ super(TemporalConvBlock_v2, self).__init__()
+ if out_dim is None:
+ out_dim = in_dim # int(1.5*in_dim)
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.use_image_dataset = use_image_dataset
+
+ # conv layers
+ self.conv1 = nn.Sequential(
+ nn.GroupNorm(32, in_dim), nn.SiLU(),
+ nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
+ self.conv2 = nn.Sequential(
+ nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
+ self.conv3 = nn.Sequential(
+ nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
+ self.conv4 = nn.Sequential(
+ nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
+
+ # zero out the last layer params, so the conv block is an identity mapping at initialization
+ nn.init.zeros_(self.conv4[-1].weight)
+ nn.init.zeros_(self.conv4[-1].bias)
+
+ def forward(self, x):
+ identity = x
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ x = self.conv4(x)
+
+ if self.use_image_dataset:
+ x = identity + 0.0 * x
+ else:
+ x = identity + x
+ return x
+
+
+def prob_mask_like(shape, prob, device):
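+ """
+ Sample a boolean mask of the given shape where each entry is True with probability `prob`.
+ """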
+ if prob == 1:
+ return torch.ones(shape, device=device, dtype=torch.bool)
+ elif prob == 0:
+ return torch.zeros(shape, device=device, dtype=torch.bool)
+ else:
+ mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
+ # avoid an all-True mask, which would cause a find_unused_parameters error
+ if mask.all():
+ mask[0] = False
+ return mask
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f'unsupported dimensions: {dims}')
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f'unsupported dimensions: {dims}')
diff --git a/modelscope/models/nlp/bert/siamese_uie.py b/modelscope/models/nlp/bert/siamese_uie.py
index 10b4b478..33ec925e 100644
--- a/modelscope/models/nlp/bert/siamese_uie.py
+++ b/modelscope/models/nlp/bert/siamese_uie.py
@@ -44,6 +44,20 @@ class SiameseUieModel(BertPreTrainedModel):
self.plm.encoder.layer = self.plm.encoder.layer[:self.config.
num_hidden_layers]
+ def circle_loss(self, y_pred, y_true):
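+ """Multi-label circle loss: positive-class logits are pushed above zero and negative-class logits below zero via two logsumexp terms, with an appended zero logit acting as the threshold."""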
+ batch_size = y_true.size(0)
+ y_true = y_true.view(batch_size, -1)
+ y_pred = y_pred.view(batch_size, -1)
+ y_pred = (1 - 2 * y_true) * y_pred
+ y_pred_neg = y_pred - y_true * 1e12
+ y_pred_pos = y_pred - (1 - y_true) * 1e12
+ zeros = torch.zeros_like(y_pred[:, :1])
+ y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
+ y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
+ neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
+ pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
+ return (neg_loss + pos_loss).mean()
+
def get_cross_attention_output(self, hidden_states, attention_mask,
encoder_hidden_states,
encoder_attention_mask):
@@ -72,8 +86,44 @@ class SiameseUieModel(BertPreTrainedModel):
position_ids=position_ids)[0]
return sequence_output
- def forward(self, sequence_output, attention_masks, hint_ids,
- cross_attention_masks):
+ def forward(self, input_ids, attention_masks, hint_ids,
+ cross_attention_masks, head_labels, tail_labels):
+ """train forward
+
+ Args:
+ input_ids (Tensor): input token ids of text.
+ attention_masks (Tensor): attention_masks of text.
+ hint_ids (Tensor): input token ids of prompt.
+ cross_attention_masks (Tensor): attention_masks of prompt.
+ head_labels (Tensor): labels of start position.
+ tail_labels (Tensor): labels of end position.
+
+ Returns:
+ Dict[str, float]: the loss
+ Example:
+ {"loss": 0.5091743}
+ """
+ sequence_output = self.get_plm_sequence_output(input_ids,
+ attention_masks)
+ assert hint_ids.size(1) + input_ids.size(1) <= 512
+ position_ids = torch.arange(hint_ids.size(1)).expand(
+ (1, -1)) + input_ids.size(1)
+ position_ids = position_ids.to(sequence_output.device)
+ hint_sequence_output = self.get_plm_sequence_output(
+ hint_ids, cross_attention_masks, position_ids, is_hint=True)
+ sequence_output = self.get_cross_attention_output(
+ sequence_output, attention_masks, hint_sequence_output,
+ cross_attention_masks)
+ # (b, l, n)
+ head_logits = self.head_clsf(sequence_output).squeeze(-1)
+ tail_logits = self.tail_clsf(sequence_output).squeeze(-1)
+ loss_func = self.circle_loss
+ head_loss = loss_func(head_logits, head_labels)
+ tail_loss = loss_func(tail_logits, tail_labels)
+ return {'loss': head_loss + tail_loss}
+
+ def fast_inference(self, sequence_output, attention_masks, hint_ids,
+ cross_attention_masks):
"""
Args:
diff --git a/modelscope/models/nlp/gpt3/backbone.py b/modelscope/models/nlp/gpt3/backbone.py
index a86f01e4..2f8e4699 100644
--- a/modelscope/models/nlp/gpt3/backbone.py
+++ b/modelscope/models/nlp/gpt3/backbone.py
@@ -354,6 +354,9 @@ class GPT3Model(PreTrainedModel):
return model
def generate(self, tokens, temperature=1.0, **kwargs):
+ top_k = kwargs.pop('top_k', self.config.top_k)
+ top_p = kwargs.pop('top_p', self.config.top_p)
+ max_length = kwargs.pop('max_length', tokens.size(1) + 100)
batch_size = tokens.size(0)
lengths = kwargs.pop(
@@ -361,13 +364,18 @@ class GPT3Model(PreTrainedModel):
torch.tensor([tokens.size(1)], device=tokens.device))
min_prompt_length = lengths.min().item()
- max_sequence_length = tokens.size(1)
- max_sequence_length = min(max_sequence_length,
+ max_sequence_length = min(max_length,
self.config.max_position_embeddings)
# If the context is too big, this happens
if min_prompt_length >= max_sequence_length:
- raise ValueError('context length + tokens_to_generate too large')
+ raise ValueError('context length too large')
+
+ pad_length = max_sequence_length - tokens.size(1)
+ if pad_length > 0:
+ pads = torch.zeros(
+ batch_size, pad_length, device=tokens.device).long()
+ tokens = torch.cat((tokens, pads), dim=-1)
# Added termination_id to support the case that we want to terminate the
# generation once that id is generated.
@@ -391,8 +399,8 @@ class GPT3Model(PreTrainedModel):
last_token_logits = logits[:, -1, :]
new_sample = sample(
last_token_logits,
- top_k=self.config.top_k,
- top_p=self.config.top_p,
+ top_k=top_k,
+ top_p=top_p,
temperature=temperature,
vocab_size=self.config.vocab_size)
diff --git a/modelscope/models/nlp/gpt3/distributed_gpt3.py b/modelscope/models/nlp/gpt3/distributed_gpt3.py
index 1c4505a0..d0da9659 100644
--- a/modelscope/models/nlp/gpt3/distributed_gpt3.py
+++ b/modelscope/models/nlp/gpt3/distributed_gpt3.py
@@ -952,10 +952,11 @@ class DistributedGPT3(TorchModel):
rank,
path_load_tag='model',
*args,
+ megatron_cfg=None,
**kwargs):
super().__init__(model_dir, *args, **kwargs)
- init_megatron_util(model_dir=model_dir, rank=rank)
+ init_megatron_util(megatron_cfg, model_dir, rank=rank)
self.config = GPT3Config.from_pretrained(model_dir)
# Build model.
@@ -974,8 +975,8 @@ class DistributedGPT3(TorchModel):
self.dist_model = model
tensor_ws = mpu.get_tensor_model_parallel_world_size()
- ckpt_ws = get_args().get('checkpoint_tensor_model_parallel_size',
- tensor_ws)
+ ckpt_ws = get_args().get('checkpoint_tensor_model_parallel_size', None)
+ ckpt_ws = tensor_ws if ckpt_ws is None else ckpt_ws
ckpt_rank = mpu.get_tensor_model_parallel_rank() * ckpt_ws // tensor_ws
load_model = pre_load(ckpt_rank, model_dir, tag=path_load_tag)
load_model = split_state_dict(load_model, model, tensor_ws // ckpt_ws)
@@ -1032,24 +1033,32 @@ class DistributedGPT3(TorchModel):
stop_on_double_eol=False,
stop_on_eol=False,
**kwargs):
+ top_k = kwargs.pop('top_k', self.config.top_k)
+ top_p = kwargs.pop('top_p', self.config.top_p)
+ temperature = kwargs.pop('temperature', self.config.temperature)
+ max_length = kwargs.pop(
+ 'max_length',
+ tokens.size(1) + self.config.tokens_to_generate)
+
batch_size = tokens.size(0)
lengths = prompts_len
if lengths is None:
lengths = torch.tensor([tokens.size(1)], device=tokens.device)
- pads = torch.ones(
- batch_size, self.config.tokens_to_generate,
- device=tokens.device).long() * self.config.eod_id
- tokens = torch.cat((tokens, pads), dim=-1)
min_prompt_length = lengths.min().item()
- max_sequence_length = tokens.size(1)
- max_sequence_length = min(max_sequence_length,
+ max_sequence_length = min(max_length,
self.config.max_position_embeddings)
# If the context is too big, this happens
if min_prompt_length >= max_sequence_length:
raise ValueError('context length + tokens_to_generate too large')
+ pad_length = max_sequence_length - tokens.size(1)
+ if pad_length > 0:
+ pads = torch.zeros(
+ batch_size, pad_length, device=tokens.device).long()
+ tokens = torch.cat((tokens, pads), dim=-1)
+
# Initialize inference parameters.
self.inference_params = InferenceParams(batch_size,
max_sequence_length)
@@ -1084,9 +1093,9 @@ class DistributedGPT3(TorchModel):
last_token_logits = logits[:, -1, :]
new_sample = sample(
last_token_logits,
- top_k=kwargs.pop('top_k', self.config.top_k),
- top_p=kwargs.pop('top_p', self.config.top_p),
- temperature=kwargs.pop('temperature', self.config.temperature),
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
vocab_size=self.config.vocab_size)
# If a prompt length is smaller or equal th current context
diff --git a/modelscope/models/nlp/gpt3/text_generation.py b/modelscope/models/nlp/gpt3/text_generation.py
index 368cd2b5..fbc82b8a 100644
--- a/modelscope/models/nlp/gpt3/text_generation.py
+++ b/modelscope/models/nlp/gpt3/text_generation.py
@@ -52,13 +52,14 @@ class GPT3ForTextGeneration(TorchModel):
"""
return self.model(**input)
- def generate(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
+ def generate(self, inputs: Dict[str, Tensor],
+ **kwargs) -> Dict[str, Tensor]:
if not isinstance(self.model, GPT3Model):
- return self.model.generate(**inputs)
+ return self.model.generate(**inputs, **kwargs)
tokens = inputs['input_ids']
lengths = self._get_length(inputs['attention_mask'])
- return self.model.generate(tokens, prompt_length=lengths)
+ return self.model.generate(tokens, prompt_length=lengths, **kwargs)
@staticmethod
def _get_length(attention_mask: torch.Tensor) -> Tensor:
diff --git a/modelscope/models/nlp/heads/crf_head.py b/modelscope/models/nlp/heads/crf_head.py
index 1454ed36..edccd7de 100644
--- a/modelscope/models/nlp/heads/crf_head.py
+++ b/modelscope/models/nlp/heads/crf_head.py
@@ -97,7 +97,7 @@ class TransformersCRFHead(TorchHead):
mask = label_mask
masked_lengths = mask.sum(-1).long()
masked_logits = torch.zeros_like(logits)
- for i in range(len(mask)):
+ for i in range(mask.shape[0]):
masked_logits[
i, :masked_lengths[i], :] = logits[i].masked_select(
mask[i].unsqueeze(-1)).view(masked_lengths[i], -1)
diff --git a/modelscope/models/nlp/palm_v2/text_generation.py b/modelscope/models/nlp/palm_v2/text_generation.py
index a87b5cdd..cd3ecdaf 100644
--- a/modelscope/models/nlp/palm_v2/text_generation.py
+++ b/modelscope/models/nlp/palm_v2/text_generation.py
@@ -779,8 +779,6 @@ class Translator(object):
self.end_token = self.symbols['EOS']
self.alpha = self.args.alpha
self.beam_size = self.args.beam_size
- self.min_length = self.args.min_length
- self.max_length = self.args.max_length
def from_batch(self, translation_batch):
batch = translation_batch['batch']
@@ -1065,8 +1063,7 @@ class Translator(object):
"""
self.model.eval()
with torch.no_grad():
- return self._fast_translate_batch(
- batch, self.max_length, min_length=self.min_length)
+ return self._fast_translate_batch(batch)
def _tile(self, x, count, dim=0):
perm = list(range(len(x.size())))
@@ -1121,13 +1118,13 @@ class Translator(object):
logits[indices_to_remove] = filter_value
return logits
- def _fast_translate_batch(self,
- batch: 'Batch',
- max_length: int,
- min_length: int = 0):
+ def _fast_translate_batch(self, batch: 'Batch'):
# TODO: faster code path for beam_size == 1.
# TODO: support these blacklisted features.
+ max_length = self.args.max_length
+ min_length = self.args.min_length
+
beam_size = self.beam_size
batch_size = batch.batch_size
src = batch.src
@@ -1366,7 +1363,10 @@ class PalmForTextGeneration(PalmPreTrainedModel):
logits=output[0],
)
- def generate(self, input: Dict[str, Tensor]) -> TokenGeneratorOutput:
+ def generate(self, input: Dict[str, Tensor],
+ **kwargs) -> TokenGeneratorOutput:
+ for k, v in kwargs.items():
+ setattr(self.generator.args, k, v)
outputs = self.generator(**input)
preds = outputs['predictions']
return TokenGeneratorOutput(sequences=[pred[0] for pred in preds])
diff --git a/modelscope/models/nlp/peer/__init__.py b/modelscope/models/nlp/peer/__init__.py
new file mode 100644
index 00000000..4d51a617
--- /dev/null
+++ b/modelscope/models/nlp/peer/__init__.py
@@ -0,0 +1,37 @@
+# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .configuration import PeerConfig
+ from .text_classification import PeerForSequenceClassification
+else:
+ _import_structure = {
+ 'configuration': ['PeerConfig'],
+ 'text_classification': ['PeerForSequenceClassification'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/models/nlp/peer/backbone.py b/modelscope/models/nlp/peer/backbone.py
new file mode 100644
index 00000000..2dca8dda
--- /dev/null
+++ b/modelscope/models/nlp/peer/backbone.py
@@ -0,0 +1,1256 @@
+# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
+# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PEER model. """
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from transformers.activations import ACT2FN, get_activation
+from transformers.file_utils import ModelOutput, add_start_docstrings
+from transformers.modeling_outputs import \
+ BaseModelOutputWithPastAndCrossAttentions
+from transformers.modeling_utils import (PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer)
+
+from modelscope.models import Model, TorchModel
+from modelscope.utils import logger as logging
+from modelscope.utils.nlp.utils import parse_labels_in_order
+from .configuration import PeerConfig
+from .sas_utils import SequenceSideInfo
+
+logger = logging.get_logger(__name__)
+
+PEER_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ 'google/peer-small-generator',
+ 'google/peer-base-generator',
+ 'google/peer-large-generator',
+ 'google/peer-small-discriminator',
+ 'google/peer-base-discriminator',
+ 'google/peer-large-discriminator',
+ # See all PEER models at https://huggingface.co/models?filter=peer
+]
+
+
+class PeerEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(
+ config.vocab_size,
+ config.embedding_size,
+ padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings,
+ config.embedding_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
+ config.embedding_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(
+ config.embedding_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ 'position_ids',
+ torch.arange(config.max_position_embeddings).expand((1, -1)))
+ self.position_embedding_type = getattr(config,
+ 'position_embedding_type',
+ ['absolute'])
+ if 'absolute_token_position_in_sentence' in self.position_embedding_type:
+ self.side_info_size = 16
+ self.position_embeddings__token_position_in_sentence = nn.Embedding(
+ self.side_info_size, config.embedding_size)
+
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
+ def forward(
+ self,
+ input_ids=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ past_key_values_length=0,
+ side_info_sets=dict(),
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:,
+ past_key_values_length:seq_length
+ + past_key_values_length]
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(
+ input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if 'absolute' in self.position_embedding_type:
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+
+ if 'absolute_token_position_in_sentence' in self.position_embedding_type:
+ position_idx = torch.clamp(
+ side_info_sets['ss_token_position_in_sentence'],
+ min=0,
+ max=self.side_info_size - 1)
+ position_embeddings__token_position_in_sentence = self.position_embeddings__token_position_in_sentence(
+ position_idx)
+ embeddings += position_embeddings__token_position_in_sentence
+
+ # Pass to the attention layers to calculate position-to-position attention scores
+ if 'absolute_self_only' in self.position_embedding_type:
+ if 'embeddings' not in side_info_sets:
+ side_info_sets['embeddings'] = dict()
+ side_info_sets['embeddings'][
+ 'ss_token_position_in_sequence'] = self.position_embeddings(
+ position_ids)
+
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class PeerSelfAttention(nn.Module):
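+ """Multi-head self-attention supporting several absolute and relative position-bias schemes selected by ``config.position_embedding_type``."""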
+
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
+ config, 'embedding_size'):
+ raise ValueError(
+ 'The hidden size (%d) is not a multiple of the number of attention '
+ 'heads (%d)' %
+ (config.hidden_size, config.num_attention_heads))
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size
+ / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config,
+ 'position_embedding_type',
+ ['absolute'])
+
+ if 'relative_scalar_bias' in self.position_embedding_type:
+ self.max_relative_position_embeddings = config.max_position_embeddings // 4
+ self.distance_embedding = nn.Embedding(
+ 2 * self.max_relative_position_embeddings,
+ self.num_attention_heads)
+
+ elif 'relative_scalar_bias_with_side_info_token' in self.position_embedding_type:
+ self.max_relative_position_embeddings = config.max_position_embeddings // 4
+ self.side_info_size = 16 # leverage the information of token_position_in_sentence
+ self.distance_embedding = nn.Embedding(
+ (2 * self.max_relative_position_embeddings)
+ * self.side_info_size, self.num_attention_heads)
+
+ elif 'relative_scalar_bias_token_plus_sentence' in self.position_embedding_type:
+ self.max_relative_position_embeddings = config.max_position_embeddings // 4
+ self.max_sen_relative_position_embeddings = self.max_relative_position_embeddings // 4
+
+ self.distance_embedding = nn.Embedding(
+ 2 * self.max_relative_position_embeddings,
+ self.num_attention_heads)
+ self.distance_embedding_sentence = nn.Embedding(
+ 2 * self.max_sen_relative_position_embeddings,
+ self.num_attention_heads)
+
+ elif 'relative_scalar_bias_with_side_info_sentence' in self.position_embedding_type:
+ self.max_relative_position_embeddings = config.max_position_embeddings // 4
+ self.max_sen_relative_position_embeddings = self.max_relative_position_embeddings // 4
+
+ vocab = (2 * self.max_relative_position_embeddings) * (
+ 2 * self.max_sen_relative_position_embeddings)
+ self.distance_embedding = nn.Embedding(vocab,
+ self.num_attention_heads)
+
+ elif 'relative_key' in self.position_embedding_type or 'relative_key_query' in self.position_embedding_type:
+ self.max_relative_position_embeddings = config.max_position_embeddings // 4
+ self.distance_embedding = nn.Embedding(
+ 2 * self.max_relative_position_embeddings,
+ self.attention_head_size)
+
+ self.is_decoder = config.is_decoder
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads,
+ self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ side_info_sets=dict(),
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_layer = past_key_value[0]
+ value_layer = past_key_value[1]
+ attention_mask = encoder_attention_mask
+ elif is_cross_attention:
+ key_layer = self.transpose_for_scores(
+ self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(
+ self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(
+ -1, -2)) / math.sqrt(self.attention_head_size)
+ attention_scores_terms = 1
+
+ if 'absolute_self_only' in self.position_embedding_type:
+ attention_scores += side_info_sets[
+ 'side_info_attention_scores'] # already normalized by sqrt(attention_head_size)
+ attention_scores_terms += 1
+
+ if 'relative_key' in self.position_embedding_type or 'relative_key_query' in self.position_embedding_type \
+ or 'relative_scalar_bias' in self.position_embedding_type \
+ or 'relative_scalar_bias_with_side_info_token' in self.position_embedding_type \
+ or 'relative_scalar_bias_token_plus_sentence' in self.position_embedding_type \
+ or 'relative_scalar_bias_with_side_info_sentence' in self.position_embedding_type:
+
+ distance_idx = side_info_sets['distance_idx']
+
+ positional_embedding = self.distance_embedding(distance_idx)
+ positional_embedding = positional_embedding.to(
+ dtype=query_layer.dtype) # fp16 compatibility
+
+ if 'relative_scalar_bias' in self.position_embedding_type:
+ relative_scalar_bias = positional_embedding.permute(
+ [2, 0, 1]).unsqueeze(0)
+ attention_scores = attention_scores / math.sqrt(
+ attention_scores_terms) + relative_scalar_bias
+
+ elif ('relative_scalar_bias_with_side_info_token'
+ in self.position_embedding_type
+ or 'relative_scalar_bias_with_side_info_sentence'
+ in self.position_embedding_type):
+ relative_scalar_bias = positional_embedding.permute(
+ [0, 3, 1, 2])
+ attention_scores = attention_scores / math.sqrt(
+ attention_scores_terms) + relative_scalar_bias
+
+ elif 'relative_scalar_bias_token_plus_sentence' in self.position_embedding_type:
+ relative_scalar_bias = positional_embedding.permute(
+ [2, 0, 1]).unsqueeze(0)
+
+ distance_idx_sentence = side_info_sets['distance_idx_sentence']
+ positional_embedding_sentence = self.distance_embedding_sentence(
+ distance_idx_sentence)
+ positional_embedding_sentence = positional_embedding_sentence.to(
+ dtype=query_layer.dtype) # fp16 compatibility
+ relative_scalar_bias_sentence = positional_embedding_sentence.permute(
+ [0, 3, 1, 2])
+
+ attention_scores = attention_scores / math.sqrt(
+ attention_scores_terms
+ ) + relative_scalar_bias + relative_scalar_bias_sentence
+
+ elif 'relative_key' in self.position_embedding_type:
+ relative_position_scores = torch.einsum(
+ 'bhld,lrd->bhlr', query_layer,
+ positional_embedding) / math.sqrt(self.attention_head_size)
+ attention_scores_terms += 1
+ attention_scores = (attention_scores + relative_position_scores
+ ) / math.sqrt(attention_scores_terms)
+ elif 'relative_key_query' in self.position_embedding_type:
+ relative_position_scores_query = torch.einsum(
+ 'bhld,lrd->bhlr', query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum(
+ 'bhrd,lrd->bhlr', key_layer, positional_embedding)
+ relative_position_scores = (
+ relative_position_scores_query
+ + relative_position_scores_key) / math.sqrt(
+ self.attention_head_size)
+ attention_scores_terms += 2
+ attention_scores = (attention_scores + relative_position_scores
+ ) / math.sqrt(attention_scores_terms)
+
+ else:
+ attention_scores = attention_scores / math.sqrt(
+ attention_scores_terms)
+
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in PeerModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (
+ self.all_head_size, )
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer,
+ attention_probs) if output_attentions else (context_layer, )
+
+ if self.is_decoder:
+ outputs = outputs + (past_key_value, )
+ return outputs
+
+
+class PeerSelfOutput(nn.Module):
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(
+ config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class PeerAttention(nn.Module):
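+ """Self-attention wrapped with an output projection (dense + LayerNorm + residual) and support for attention-head pruning."""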
+
+ def __init__(self, config):
+ super().__init__()
+ self.self = PeerSelfAttention(config)
+ self.output = PeerSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads,
+ self.self.attention_head_size, self.pruned_heads)
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(
+ heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ side_info_sets=dict(),
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ side_info_sets,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,
+ ) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class PeerIntermediate(nn.Module):
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class PeerOutput(nn.Module):
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(
+ config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class PeerLayer(nn.Module):
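+ """A single transformer layer: self-attention, optional cross-attention when configured as a decoder, and a chunked feed-forward block."""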
+
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = PeerAttention(config)
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ assert self.is_decoder, f'{self} should be used as a decoder model if cross attention is added'
+ self.crossattention = PeerAttention(config)
+ self.intermediate = PeerIntermediate(config)
+ self.output = PeerOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ side_info_sets=dict(),
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:
+ 2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ side_info_sets=side_info_sets,
+ )
+ attention_output = self_attention_outputs[0]
+
+ # if decoder, the last output is tuple of self-attn cache
+ if self.is_decoder:
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+ else:
+ outputs = self_attention_outputs[
+ 1:] # add self attentions if we output attention weights
+
+ cross_attn_present_key_value = None
+ if self.is_decoder and encoder_hidden_states is not None:
+ assert hasattr(
+ self, 'crossattention'
+ ), f'If `encoder_hidden_states` are passed, {self} has to be instantiated \
+ with cross-attention layers by setting `config.add_cross_attention=True`'
+
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+ cross_attn_past_key_value = past_key_value[
+ -2:] if past_key_value is not None else None
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ cross_attn_past_key_value,
+ output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[
+ 1:-1] # add cross attentions if we output attention weights
+
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
+ cross_attn_present_key_value = cross_attention_outputs[-1]
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ layer_output = apply_chunking_to_forward(self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output)
+ outputs = (layer_output, ) + outputs
+
+ # if decoder, return the attn key/values as the last output
+ if self.is_decoder:
+ outputs = outputs + (present_key_value, )
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class PeerEncoder(nn.Module):
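+ """Stack of PeerLayer modules; relative-position (and optional sentence-level) distance indices are computed once per forward pass and shared with every layer via ``side_info_sets``."""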
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList(
+ [PeerLayer(config) for _ in range(config.num_hidden_layers)])
+
+ self.position_embedding_type = getattr(config,
+ 'position_embedding_type',
+ ['absolute'])
+ if 'absolute_self_only' in self.position_embedding_type:
+ # To be used/shared in all self-attention layers. Copy their dimensions here to be consistent.
+ self.self_attention = self.layer[0].attention.self
+
+ self.num_attention_heads = self.self_attention.num_attention_heads
+ self.attention_head_size = self.self_attention.attention_head_size
+ self.all_head_size = self.self_attention.all_head_size
+
+ self.pos_query = nn.Linear(self.self_attention.query.in_features,
+ self.self_attention.query.out_features)
+ self.pos_key = nn.Linear(self.self_attention.key.in_features,
+ self.self_attention.key.out_features)
+
+ def get_position_attention_score(self, hidden_states):
+ query_layer = self.self_attention.transpose_for_scores(
+ self.pos_query(hidden_states))
+ key_layer = self.self_attention.transpose_for_scores(
+ self.pos_key(hidden_states))
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer,
+ key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(
+ self.attention_head_size)
+ return attention_scores
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ side_info_sets=dict(),
+ return_dict=True,
+ ):
+
+ if 'absolute_self_only' in self.position_embedding_type:
+ side_info_attention_scores = self.get_position_attention_score(
+ hidden_states=side_info_sets['embeddings']
+ ['ss_token_position_in_sequence'])
+ side_info_sets[
+ 'side_info_attention_scores'] = side_info_attention_scores
+
+ if 'relative_key' in self.position_embedding_type or 'relative_key_query' in self.position_embedding_type \
+ or 'relative_scalar_bias' in self.position_embedding_type \
+ or 'relative_scalar_bias_with_side_info_token' in self.position_embedding_type \
+ or 'relative_scalar_bias_token_plus_sentence' in self.position_embedding_type \
+ or 'relative_scalar_bias_with_side_info_sentence' in self.position_embedding_type:
+ seq_length = hidden_states.shape[1]
+ batch_size = hidden_states.shape[0]
+
+ position_ids_l = torch.arange(
+ seq_length, dtype=torch.long,
+ device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(
+ seq_length, dtype=torch.long,
+ device=hidden_states.device).view(1, -1)
+ max_relative_position_embeddings = self.layer[
+ 0].attention.self.max_relative_position_embeddings
+ distance_idx = torch.clamp(
+ position_ids_l - position_ids_r
+ + max_relative_position_embeddings - 2,
+ min=0,
+ max=2 * max_relative_position_embeddings - 4)
+ distance_idx[
+ 0, :] = 2 * max_relative_position_embeddings - 3 # CLS-to-others
+ distance_idx[:,
+ 0] = 2 * max_relative_position_embeddings - 2 # others-to-CLS
+ distance_idx[
+ 0, 0] = 2 * max_relative_position_embeddings - 1 # CLS-to-CLS
+ distance_idx_max = 2 * max_relative_position_embeddings
+
+ # token position-aware relative position
+ if 'relative_scalar_bias_with_side_info_token' in self.position_embedding_type:
+ idx1 = torch.clamp(
+ side_info_sets['ss_token_position_in_sentence'],
+ min=0,
+ max=self.layer[0].attention.self.side_info_size
+ - 1).unsqueeze(2).repeat(1, 1, seq_length)
+ idx2 = distance_idx.unsqueeze(0).repeat(batch_size, 1, 1)
+ distance_idx = idx1 * distance_idx_max + idx2
+ # relative token position + relative sentence position
+ elif 'relative_scalar_bias_with_side_info_sentence' in self.position_embedding_type:
+ sen_position_ids_l = side_info_sets[
+ 'ss_sentence_position_in_sequence'].view(
+ batch_size, -1, 1)
+ sen_position_ids_r = side_info_sets[
+ 'ss_sentence_position_in_sequence'].view(
+ batch_size, 1, -1)
+ max_sen_relative_position_embeddings = self.layer[
+ 0].attention.self.max_sen_relative_position_embeddings
+ idx1 = torch.clamp(
+ sen_position_ids_l - sen_position_ids_r
+ + max_sen_relative_position_embeddings,
+ min=0,
+ max=2 * max_sen_relative_position_embeddings - 1)
+ idx2 = distance_idx.unsqueeze(0).repeat(batch_size, 1, 1)
+ distance_idx = idx1 * distance_idx_max + idx2
+ elif 'relative_scalar_bias_token_plus_sentence' in self.position_embedding_type:
+ sen_position_ids_l = side_info_sets[
+ 'ss_sentence_position_in_sequence'].view(
+ batch_size, -1, 1)
+ sen_position_ids_r = side_info_sets[
+ 'ss_sentence_position_in_sequence'].view(
+ batch_size, 1, -1)
+ max_sen_relative_position_embeddings = self.layer[
+ 0].attention.self.max_sen_relative_position_embeddings
+ idx1 = torch.clamp(
+ sen_position_ids_l - sen_position_ids_r
+ + max_sen_relative_position_embeddings,
+ min=0,
+ max=2 * max_sen_relative_position_embeddings - 1)
+ side_info_sets['distance_idx_sentence'] = idx1
+
+ side_info_sets['distance_idx'] = distance_idx
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = (
+ ) if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states, )
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[
+ i] if past_key_values is not None else None
+ if getattr(self.config, 'gradient_checkpointing', False):
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value,
+ output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ side_info_sets,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ side_info_sets,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1], )
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (
+ layer_outputs[1], )
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (
+ layer_outputs[2], )
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states, )
+
+ if not return_dict:
+ return tuple(v for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ] if v is not None)
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class PeerDiscriminatorPredictions(nn.Module):
+ """Prediction module for the discriminator, made up of two dense layers."""
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dense_prediction = nn.Linear(config.hidden_size, 1)
+ self.config = config
+
+ def forward(self, discriminator_hidden_states):
+ hidden_states = self.dense(discriminator_hidden_states)
+ hidden_states = get_activation(self.config.hidden_act)(hidden_states)
+ logits = self.dense_prediction(hidden_states).squeeze(-1)
+
+ return logits
+
+
+class PeerGeneratorPredictions(nn.Module):
+ """Prediction module for the generator, made up of two dense layers."""
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.LayerNorm = nn.LayerNorm(config.embedding_size)
+ self.dense = nn.Linear(config.hidden_size, config.embedding_size)
+
+ def forward(self, generator_hidden_states):
+ hidden_states = self.dense(generator_hidden_states)
+ hidden_states = get_activation('gelu')(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+
+ return hidden_states
+
+
+class PeerPreTrainedModel(TorchModel, PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = PeerConfig
+ base_model_prefix = 'teams1_shared_bottom'
+ _keys_to_ignore_on_load_missing = [r'position_ids']
+ _keys_to_ignore_on_load_unexpected = [
+ r'peer\.embeddings_project\.weight', r'peer\.embeddings_project\.bias'
+ ]
+
+ def _init_weights(self, module):
+ """ Initialize the weights """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(
+ mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+ @classmethod
+ def _instantiate(cls, **kwargs):
+ """Instantiate the model.
+
+ Args:
+ kwargs: Input args.
+ model_dir: The model dir used to load the checkpoint and the label information.
+ num_labels: An optional arg to tell the model how many classes to initialize.
+ Method will call utils.parse_label_mapping if num_labels is not input.
+ label2id: An optional label2id mapping, which will cover the label2id in configuration (if exists).
+
+ Returns:
+ The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
+ """
+
+ model_dir = kwargs.pop('model_dir', None)
+ cfg = kwargs.pop('cfg', None)
+ model_args = parse_labels_in_order(model_dir, cfg, **kwargs)
+
+ if model_dir is None:
+ config = PeerConfig(**model_args)
+ model = cls(config)
+ else:
+ model = super(Model, cls).from_pretrained(
+ pretrained_model_name_or_path=model_dir, **model_args)
+ return model
+
+
+@dataclass
+class PeerForRTDOutput(ModelOutput):
+ """
+ Output type of :class:`~transformers.PeerForRTD`.
+
+ Args:
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
+ Total loss of the PEER objective.
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`):
+ Prediction scores of the head (scores for each token before SoftMax).
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`,
+ returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`,
+ returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+ sequence_length, sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class PeerForPreTrainingOutput(ModelOutput):
+ """
+ Output type of :class:`~transformers.PeerForPreTraining`.
+
+ Args:
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
+ Total loss of the PEER objective.
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`):
+ Prediction scores of the head (scores for each token before SoftMax).
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`,
+ returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`,
+ returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+ sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ mlm_loss: Optional[torch.FloatTensor] = None
+ rtd_loss: Optional[torch.FloatTensor] = None
+ mlm_logits: torch.FloatTensor = None
+ rtd_logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+PEER_START_DOCSTRING = r"""
+
+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
+    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
+    pruning heads, etc.)
+
+    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
+    general usage and behavior.
+
+ Parameters:
+ config (:class:`~transformers.PeerConfig`): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
+ weights.
+"""
+
+PEER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using :class:`~transformers.PeerTokenizer`. See
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
+ details.
+
+ `What are input IDs? <../glossary.html#input-ids>`__
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ `What are attention masks? <../glossary.html#attention-mask>`__
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
+ 1]``:
+
+ - 0 corresponds to a `sentence A` token,
+ - 1 corresponds to a `sentence B` token.
+
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
+ config.max_position_embeddings - 1]``.
+
+ `What are position IDs? <../glossary.html#position-ids>`_
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
+ vectors than the model's internal embedding lookup matrix.
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (:obj:`bool`, `optional`):
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
+ tensors for more detail.
+ output_hidden_states (:obj:`bool`, `optional`):
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+ more detail.
+ return_dict (:obj:`bool`, `optional`):
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ 'The bare Peer Model transformer outputting raw hidden-states without any specific head on top. Identical to '
+ 'the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the '
+    'hidden size and embedding size are different. '
+ ''
+ 'Both the generator and discriminator checkpoints may be loaded into this model.',
+ PEER_START_DOCSTRING,
+)
+class PeerModel(PeerPreTrainedModel):
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.embeddings = PeerEmbeddings(config)
+
+ if config.embedding_size != config.hidden_size:
+ self.embeddings_project = nn.Linear(config.embedding_size,
+ config.hidden_size)
+
+ self.encoder = PeerEncoder(config)
+ self.config = config
+ self.init_weights()
+
+ if self.config.seq_side_info_embeddings:
+ self.input_sequence_side_info = dict()
+ self.sequence_side_info = SequenceSideInfo()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
+        base class PreTrainedModel.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def update_seq_side_info(self, side_info_sets, input_ids):
+
+ device = input_ids.device
+ if 'input_sequence_side_info' not in side_info_sets or len(
+ side_info_sets['input_sequence_side_info']) == 0:
+ input_sequence_side_info = self.sequence_side_info.generate_seq_side_info(
+ self.config.seq_side_info_embeddings, input_ids)
+
+ else:
+            # Save compute in PEER pre-training: the extra side info is stored on the CPU during
+            # the first epoch and retrieved directly from the CPU in later epochs.
+ input_sequence_side_info = side_info_sets[
+ 'input_sequence_side_info']
+
+ for ss in input_sequence_side_info.keys():
+ input_sequence_side_info[ss] = input_sequence_side_info[ss].to(
+ device=device).long()
+ side_info_sets = {**side_info_sets, **input_sequence_side_info}
+ return side_info_sets
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ side_info_sets=dict(),
+ return_dict=None,
+ ):
+ if self.config.seq_side_info_embeddings:
+ side_info_sets = self.update_seq_side_info(side_info_sets,
+ input_ids)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else
+ self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ 'You cannot specify both input_ids and inputs_embeds at the same time'
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError(
+ 'You have to specify either input_ids or inputs_embeds')
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(
+ input_shape, dtype=torch.long, device=device)
+
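+        # Expand the 2D padding mask into the additive attention mask expected by the encoder
+        # layers (0.0 for positions to attend to, a large negative value for padded positions).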
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, input_shape, device)
+ head_mask = self.get_head_mask(head_mask,
+ self.config.num_hidden_layers)
+
+ hidden_states = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ side_info_sets=side_info_sets,
+ )
+
+ if hasattr(self, 'embeddings_project'):
+ hidden_states = self.embeddings_project(hidden_states)
+
+ hidden_states = self.encoder(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ side_info_sets=side_info_sets,
+ return_dict=return_dict,
+ )
+
+ return hidden_states
+
+
+class PeerTopModel(PeerPreTrainedModel):
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.encoder = PeerEncoder(config)
+ self.config = config
+ self.init_weights()
+
+ if self.config.seq_side_info_embeddings:
+ self.input_sequence_side_info = dict()
+ self.sequence_side_info = SequenceSideInfo()
+
+ def _prune_heads(self, heads_to_prune):
+ """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
+        base class PreTrainedModel.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def update_seq_side_info(self, side_info_sets, input_ids):
+
+ device = input_ids.device
+ if 'input_sequence_side_info' not in side_info_sets or len(
+ side_info_sets['input_sequence_side_info']) == 0:
+ input_sequence_side_info = self.sequence_side_info.generate_seq_side_info(
+ self.config.seq_side_info_embeddings, input_ids)
+
+ else:
+            # Save compute in PEER pre-training: the extra side info is stored on the CPU during
+            # the first epoch and retrieved directly from the CPU in later epochs.
+ input_sequence_side_info = side_info_sets[
+ 'input_sequence_side_info']
+
+ for ss in input_sequence_side_info.keys():
+ input_sequence_side_info[ss] = input_sequence_side_info[ss].to(
+ device=device).long()
+ side_info_sets = {**side_info_sets, **input_sequence_side_info}
+ return side_info_sets
+
+ def forward(
+ self,
+ hidden_states,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ side_info_sets=dict(),
+ return_dict=None,
+ ):
+
+ if self.config.seq_side_info_embeddings:
+ side_info_sets = self.update_seq_side_info(side_info_sets,
+ input_ids)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else
+ self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ 'You cannot specify both input_ids and inputs_embeds at the same time'
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError(
+ 'You have to specify either input_ids or inputs_embeds')
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(
+ input_shape, dtype=torch.long, device=device)
+
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, input_shape, device)
+ head_mask = self.get_head_mask(head_mask,
+ self.config.num_hidden_layers)
+
+ hidden_states = self.encoder(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ side_info_sets=side_info_sets,
+ return_dict=return_dict,
+ )
+
+ return hidden_states
+
+
+class PeerClassificationHead(nn.Module):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+ def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take the first token (equivalent to [CLS])
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = get_activation('gelu')(
+ x
+ ) # although BERT uses tanh here, it seems Peer authors used gelu here
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
diff --git a/modelscope/models/nlp/peer/configuration.py b/modelscope/models/nlp/peer/configuration.py
new file mode 100644
index 00000000..da8b0a74
--- /dev/null
+++ b/modelscope/models/nlp/peer/configuration.py
@@ -0,0 +1,224 @@
+# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PEER model configuration """
+
+from transformers.configuration_utils import PretrainedConfig
+
+from modelscope.utils import logger as logging
+
+logger = logging.get_logger(__name__)
+
+
+class PeerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a :class:`~transformers.PeerModel` or a
+ :class:`~transformers.TFPeerModel`. It is used to instantiate a PEER model according to the specified
+ arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
+    configuration to that of the PEER ``google/peer-small-discriminator`` architecture.
+
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
+ outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
+
+
+ Args:
+        vocab_size (:obj:`int`, `optional`, defaults to 30522):
+            Vocabulary size of the PEER model. Defines the number of different tokens that can be represented by the
+            :obj:`inputs_ids` passed when calling :class:`~transformers.PeerModel` or
+            :class:`~transformers.TFPeerModel`.
+        embedding_size (:obj:`int`, `optional`, defaults to 128):
+            Dimensionality of the embedding layer; projected to :obj:`hidden_size` when the two differ.
+        hidden_size (:obj:`int`, `optional`, defaults to 256):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (:obj:`int`, `optional`, defaults to 4):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (:obj:`int`, `optional`, defaults to 1024):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string,
+            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
+        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (:obj:`int`, `optional`, defaults to 2):
+            The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.PeerModel` or
+            :class:`~transformers.TFPeerModel`.
+        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        summary_type (:obj:`str`, `optional`, defaults to :obj:`"first"`):
+            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
+
+            Has to be one of the following options:
+
+            - :obj:`"last"`: Take the last token hidden state (like XLNet).
+            - :obj:`"first"`: Take the first token hidden state (like BERT).
+            - :obj:`"mean"`: Take the mean of all tokens hidden states.
+            - :obj:`"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
+            - :obj:`"attn"`: Not implemented now, use multi-head attention.
+        summary_use_proj (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
+
+            Whether or not to add a projection after the vector extraction.
+        summary_activation (:obj:`str`, `optional`):
+            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
+
+            Pass :obj:`"gelu"` for a gelu activation to the output, any other value will result in no activation.
+        summary_last_dropout (:obj:`float`, `optional`, defaults to 0.1):
+            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
+
+            The dropout ratio to be used after the projection and activation.
+        position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
+            Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
+            :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
+            :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
+            <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
+            `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
+            <https://arxiv.org/abs/2009.13658>`__.
+
+ Examples::
+
+        >>> from modelscope.models.nlp.peer.configuration import PeerConfig
+        >>> from modelscope.models.nlp.peer.backbone import PeerModel
+
+ >>> # Initializing a PEER peer-base-uncased style configuration
+ >>> configuration = PeerConfig()
+
+ >>> # Initializing a model from the peer-base-uncased style configuration
+ >>> model = PeerModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
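+
+        >>> # Illustrative only: setting the PEER-specific depths of the shared bottom and the
+        >>> # generator (the values below are examples, not recommended settings)
+        >>> configuration = PeerConfig(
+        ...     num_hidden_layers=12, num_hidden_layers_shared=3, num_hidden_layers_gen=6)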
+ """
+ model_type = 'peer'
+
+ def __init__(self,
+ vocab_size=30522,
+ embedding_size=128,
+ hidden_size=256,
+ num_hidden_layers=12,
+ num_hidden_layers_shared=3,
+ num_hidden_layers_gen=6,
+ num_attention_heads=4,
+ intermediate_size=1024,
+ hidden_act='gelu',
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ summary_type='first',
+ summary_use_proj=True,
+ summary_activation='gelu',
+ summary_last_dropout=0.1,
+ pad_token_id=0,
+ position_embedding_type='absolute',
+ gen_weight=1,
+ dis_weight=50,
+ dis_weight_scheduler=1,
+ augmentation_copies=1,
+ augmentation_temperature=1,
+ absolute_position_embedding=1,
+ relative_position_embedding=32,
+ seq_side_info_embeddings=0,
+ cold_start_epochs=1.25,
+ debug_config=dict(),
+ rtd_levels=2,
+ rtd_level_thresholds='',
+ ranking_start_epoch=1.0,
+ real_token_rank_for_good_estimate=5,
+ rank_sampl_prop=0.3,
+ rank_sampl_range=100,
+ rank_delta_factor=0.0,
+ rank_level_compare_method=0,
+ weight_loss_low_levels=1.0,
+ weight_loss_low_levels_setting='1.0-1.0',
+ weight_loss_low_levels_scheduler=0,
+ weight_loss_level_compos=1,
+ mask_da=0,
+ mask_da_start_epoch=0.0,
+ mask_da_mlm_topk_val=0,
+ mask_ratio_setting='0.15-0.15',
+ mask_ratio_scheduler=0,
+ mask_ratio_stage1_epochs=0.0,
+ **kwargs):
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.embedding_size = embedding_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_hidden_layers_shared = num_hidden_layers_shared
+ self.num_hidden_layers_gen = num_hidden_layers_gen
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+
+ self.summary_type = summary_type
+ self.summary_use_proj = summary_use_proj
+ self.summary_activation = summary_activation
+ self.summary_last_dropout = summary_last_dropout
+ if type(position_embedding_type) == str:
+ position_embedding_type = position_embedding_type.split('+')
+ self.position_embedding_type = position_embedding_type
+ self.augmentation_temperature = augmentation_temperature
+
+ self.gen_weight = gen_weight
+ self.dis_weight = dis_weight
+ self.dis_weight_scheduler = dis_weight_scheduler
+ self.augmentation_copies = augmentation_copies
+
+ self.absolute_position_embedding = absolute_position_embedding
+ self.relative_position_embedding = relative_position_embedding
+ self.seq_side_info_embeddings = seq_side_info_embeddings
+
+ self.cold_start_epochs = cold_start_epochs
+ self.debug_config = debug_config
+
+ self.rtd_levels = rtd_levels
+ self.rtd_level_thresholds = rtd_level_thresholds
+ self.ranking_start_epoch = ranking_start_epoch
+ self.real_token_rank_for_good_estimate = real_token_rank_for_good_estimate
+ self.rank_sampl_prop = rank_sampl_prop
+ self.rank_sampl_range = rank_sampl_range
+ self.rank_delta_factor = rank_delta_factor
+ self.rank_level_compare_method = rank_level_compare_method
+ self.weight_loss_low_levels = weight_loss_low_levels
+ self.weight_loss_low_levels_setting = weight_loss_low_levels_setting
+ self.weight_loss_low_levels_scheduler = weight_loss_low_levels_scheduler
+ self.weight_loss_level_compos = weight_loss_level_compos
+
+ self.mask_da = mask_da
+ self.mask_da_start_epoch = mask_da_start_epoch
+ self.mask_da_mlm_topk_val = mask_da_mlm_topk_val
+
+ self.mask_ratio_setting = mask_ratio_setting
+ self.mask_ratio_scheduler = mask_ratio_scheduler
+ self.mask_ratio_stage1_epochs = mask_ratio_stage1_epochs
diff --git a/modelscope/models/nlp/peer/sas_utils.py b/modelscope/models/nlp/peer/sas_utils.py
new file mode 100644
index 00000000..da947e4d
--- /dev/null
+++ b/modelscope/models/nlp/peer/sas_utils.py
@@ -0,0 +1,173 @@
+import random
+
+import nltk
+import numpy as np
+import torch
+
+
+def get_random_states(device=None):
+ random_states = {}
+
+ random_states['rng_state_torch'] = torch.get_rng_state()
+ random_states['rng_state_np'] = np.random.get_state()
+ random_states['rng_state_rnd'] = random.getstate()
+ if device is not None and device.type == 'cuda':
+ random_states['rng_state_torch_cuda'] = torch.cuda.get_rng_state(
+ device)
+
+ return random_states
+
+
+def set_random_states(random_states, device=None):
+
+ torch.set_rng_state(random_states['rng_state_torch'])
+ np.random.set_state(random_states['rng_state_np'])
+ random.setstate(random_states['rng_state_rnd'])
+ if device is not None and device.type == 'cuda':
+ torch.cuda.set_rng_state(random_states['rng_state_torch_cuda'])
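+
+# Illustrative usage of the two helpers above (a sketch, not part of the public API): snapshot
+# the RNG states around an auxiliary computation so that its randomness does not perturb the
+# main training stream, e.g.
+#   states = get_random_states(device)
+#   ...  # auxiliary sampling / data augmentation
+#   set_random_states(states, device)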
+
+
+# Check for any nan or inf in the data. Returns a two-element list [nan_count, inf_count],
+# or None when neither is found.
+# Inputs:
+#   data: a tensor or a tuple of tensors
+# Outputs:
+#   result: each element is the number of tensors that contain nan or inf. If data is a
+#           tuple (instead of a single tensor), 10 is added to any non-zero count.
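+# Example (illustrative):
+#   check_nan_inf(torch.tensor([1.0, float('nan')]))   -> [1, 0]
+#   check_nan_inf(torch.tensor([1.0, 2.0]))            -> None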
+def check_nan_inf(data):
+ if data is None:
+ return None
+
+ result = [0, 0]
+ if torch.is_tensor(data):
+ if torch.isnan(data).any():
+ result[0] = 1
+ if torch.isinf(data).any():
+ result[1] = 1
+
+ elif type(data) is tuple:
+ for i in range(len(data)):
+ if torch.is_tensor(data[i]):
+ if torch.isnan(data[i]).any():
+ result[0] += 1
+ if torch.isinf(data[i]).any():
+ result[1] += 1
+
+ if result[0] > 0:
+ result[0] += 10
+ if result[1] > 0:
+ result[1] += 10
+
+ return result if sum(result) > 0 else None
+
+
+class SequenceSideInfo():
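+    """Generates per-token side information for a batch of token ids.
+
+    The side information includes, for each token, the index of the sentence it belongs to, its
+    position within that sentence, its sub-token position within the whole word (when enabled),
+    and the (clipped) string length of the token. See ``generate_seq_side_info`` for details.
+    """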
+
+ def __init__(self, tokenizer=None):
+ if tokenizer is not None:
+ self.tokenizer = tokenizer
+ else:
+ from transformers import ElectraTokenizer
+ self.tokenizer = ElectraTokenizer.from_pretrained(
+ 'google/electra-small-generator')
+
+ self.sen_tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()
+
+ tokens = [
+ self.tokenizer.decode([i])
+ for i in range(self.tokenizer.vocab_size)
+ ]
+ self.ind_subtokens = set(
+ [i for i in range(len(tokens)) if tokens[i][0:2] == '##'])
+ tmp = [
+ 0 if t[0] == '[' and t[-1] == ']' else
+ (10 + min(5,
+ len(t) - 2) if t[0:2] == '##' else min(10, len(t)))
+ for t in tokens
+ ]
+ self.len_tokens = torch.tensor(tmp, dtype=torch.int8)
+
+ def getSenTokIdx(self, sentence_position_embedding, inputs_str,
+ seq_len_total):
+ sentences = self.sen_tokenizer.tokenize(inputs_str)
+ sen_lengths = np.array([
+ len(x) - 2
+ for x in self.tokenizer.batch_encode_plus(sentences)['input_ids']
+        ])  # -2: drop the [CLS] and [SEP] added by batch_encode_plus for each sentence
+
+ sen_lengths[0] = seq_len_total - sen_lengths[1:].sum()
+
+ idx_sen = np.concatenate([
+ i * np.ones(sen_lengths[i], dtype=np.int8)
+ for i in range(len(sen_lengths))
+ ])
+ idx_tok = np.concatenate([
+ np.arange(sen_lengths[i], dtype=np.int8)
+ for i in range(len(sen_lengths))
+ ])
+
+ return np.concatenate((idx_sen, idx_tok))
+
+ def generate_seq_side_info(self, sentence_position_embedding, inputs_id):
+ is_np_array = False
+ if isinstance(inputs_id[0], (list, np.ndarray)):
+ is_np_array = True
+ inputs_id = torch.tensor(inputs_id)
+
+ if hasattr(self.tokenizer, 'batch_decode'):
+ inputs_str = self.tokenizer.batch_decode(inputs_id)
+ sen_tok_idx = torch.tensor(
+ np.array([
+ self.getSenTokIdx(sentence_position_embedding, input_str,
+ inputs_id.shape[1])
+ for input_str in inputs_str
+ ]),
+ device=inputs_id.device)
+ else:
+ sen_tok_idx = torch.tensor(
+ np.array([
+ self.getSenTokIdx(sentence_position_embedding,
+ self.tokenizer.decode(input_ori),
+ inputs_id.shape[1])
+ for input_ori in inputs_id.numpy()
+ ]),
+ device=inputs_id.device)
+
+ side_info_dict = dict()
+ seq_length = inputs_id.shape[1]
+ side_info_dict[
+ 'ss_sentence_position_in_sequence'] = sen_tok_idx[:, 0:seq_length]
+ side_info_dict[
+ 'ss_token_position_in_sentence'] = sen_tok_idx[:, 1 * seq_length:2
+ * seq_length]
+
+ if sentence_position_embedding >= 2:
+ # consider sub-word tokens
+ unique, _ = np.unique(inputs_id, return_inverse=True)
+ ind_subtokens = self.ind_subtokens.intersection(set(unique))
+
+ if len(ind_subtokens) > 0:
+ idx_tok_ww = torch.stack([
+ inputs_id == st for st in ind_subtokens
+ ]).any(axis=0).char()
+ else:
+ idx_tok_ww = torch.zeros(inputs_id.shape, dtype=torch.int8)
+
+ idx_tok_ww[:, 0] = 0
+ idx_tok_ww_1 = idx_tok_ww[:, 1:]
+ for i in range(1, 11):
+ pos = torch.logical_and(idx_tok_ww_1 == i,
+ idx_tok_ww[:, 0:-1] == i)
+ if len(pos) == 0:
+ break
+ idx_tok_ww_1[pos] = i + 1
+ side_info_dict['ss_token_position_in_whole_word'] = idx_tok_ww
+
+ inputs_str_len = self.len_tokens[inputs_id.long()]
+ side_info_dict['ss_token_string_length'] = inputs_str_len
+
+ if is_np_array:
+ for key in side_info_dict.keys():
+ side_info_dict[key] = side_info_dict[key].numpy()
+
+ return side_info_dict
diff --git a/modelscope/models/nlp/peer/text_classification.py b/modelscope/models/nlp/peer/text_classification.py
new file mode 100644
index 00000000..de55652c
--- /dev/null
+++ b/modelscope/models/nlp/peer/text_classification.py
@@ -0,0 +1,121 @@
+# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from torch.nn import CrossEntropyLoss, MSELoss
+
+from modelscope.metainfo import Models
+from modelscope.models.builder import MODELS
+from modelscope.outputs import AttentionTextClassificationModelOutput
+from modelscope.utils import logger as logging
+from modelscope.utils.constant import Tasks
+from .backbone import (PeerClassificationHead, PeerModel, PeerPreTrainedModel,
+ PeerTopModel)
+
+logger = logging.get_logger()
+
+
+@MODELS.register_module(Tasks.text_classification, module_name=Models.peer)
+@MODELS.register_module(Tasks.nli, module_name=Models.peer)
+@MODELS.register_module(
+ Tasks.sentiment_classification, module_name=Models.peer)
+@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.peer)
+@MODELS.register_module(
+ Tasks.zero_shot_classification, module_name=Models.peer)
+class PeerForSequenceClassification(PeerPreTrainedModel):
+
+ def __init__(self, config, **kwargs):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ config_discr_top = copy.deepcopy(config)
+ config_shared_bottom = copy.deepcopy(config)
+
+ assert config.num_hidden_layers_shared > 0, 'config.num_hidden_layers_shared should be greater than 0!'
+
+ config_shared_bottom.num_hidden_layers = config.num_hidden_layers_shared
+ config_discr_top.num_hidden_layers = config_discr_top.num_hidden_layers \
+ - config_discr_top.num_hidden_layers_shared
+
+ self.teams1_shared_bottom = PeerModel(config_shared_bottom)
+ self.teams1_discr_top = PeerTopModel(config_discr_top)
+
+ self.classifier = PeerClassificationHead(config)
+
+ self.init_weights()
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ side_info_sets=dict(),
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss);
+            if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
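+
+        Example (illustrative; ``config`` is assumed to be a :class:`PeerConfig` with ``num_labels`` set)::
+
+            >>> import torch
+            >>> model = PeerForSequenceClassification(config)
+            >>> input_ids = torch.randint(0, config.vocab_size, (2, 16))
+            >>> outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
+            >>> logits = outputs.logits  # shape: (batch_size, num_labels)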
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states_discr_bottom = self.teams1_shared_bottom(
+ input_ids, attention_mask, token_type_ids, position_ids, head_mask,
+ inputs_embeds, output_attentions, output_hidden_states,
+ side_info_sets, return_dict)
+
+ hidden_states_discr_top = self.teams1_discr_top(
+ hidden_states_discr_bottom[0], input_ids, attention_mask,
+ token_type_ids, position_ids, head_mask, inputs_embeds,
+ output_attentions, output_hidden_states, side_info_sets,
+ return_dict)
+
+ discriminator_hidden_states = hidden_states_discr_top
+
+ sequence_output = discriminator_hidden_states[0]
+
+ logits = self.classifier(sequence_output)
+ loss = None
+ if labels is not None:
+ if self.num_labels == 1:
+ # We are doing regression
+ loss_fct = MSELoss()
+ loss = loss_fct(logits.view(-1), labels.view(-1))
+ else:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits, ) + discriminator_hidden_states[1:]
+ return ((loss, ) + output) if loss is not None else output
+
+ return AttentionTextClassificationModelOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=discriminator_hidden_states.hidden_states,
+ attentions=discriminator_hidden_states.attentions,
+ )
diff --git a/modelscope/msdatasets/__init__.py b/modelscope/msdatasets/__init__.py
index 073f9396..70200e44 100644
--- a/modelscope/msdatasets/__init__.py
+++ b/modelscope/msdatasets/__init__.py
@@ -1,3 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from . import cv
from .ms_dataset import MsDataset
diff --git a/modelscope/msdatasets/audio/__init__.py b/modelscope/msdatasets/audio/__init__.py
index e69de29b..b937315b 100644
--- a/modelscope/msdatasets/audio/__init__.py
+++ b/modelscope/msdatasets/audio/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
diff --git a/modelscope/msdatasets/audio/asr_dataset.py b/modelscope/msdatasets/audio/asr_dataset.py
index c0696615..a7a344e9 100644
--- a/modelscope/msdatasets/audio/asr_dataset.py
+++ b/modelscope/msdatasets/audio/asr_dataset.py
@@ -1,48 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-import os
+from modelscope.msdatasets.dataset_cls.custom_datasets import ASRDataset
+from modelscope.utils.logger import get_logger
-from modelscope.msdatasets.ms_dataset import MsDataset
-
-
-class ASRDataset(MsDataset):
- """ASR dataset for speech recognition.
- support load dataset from msdataset hub or local data_dir (including wav.scp and text)
- For more details, please refer to
- https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/datasets/ms_dataset.py.
- """
-
- @classmethod
- def load_core(cls, data_dir, data_set):
- wav_file = os.path.join(data_dir, data_set, 'wav.scp')
- text_file = os.path.join(data_dir, data_set, 'text')
- with open(wav_file) as f:
- wav_lines = f.readlines()
- with open(text_file) as f:
- text_lines = f.readlines()
- data_list = []
- for wav_line, text_line in zip(wav_lines, text_lines):
- item = {}
- item['Audio:FILE'] = wav_line.strip().split()[-1]
- item['Text:LABEL'] = ' '.join(text_line.strip().split()[1:])
- data_list.append(item)
- return data_list
-
- @classmethod
- def load(cls,
- dataset_name,
- namespace='speech_asr',
- train_set='train',
- dev_set='validation'):
- if os.path.exists(dataset_name):
- data_dir = dataset_name
- ds_dict = {}
- ds_dict['train'] = cls.load_core(data_dir, train_set)
- ds_dict['validation'] = cls.load_core(data_dir, dev_set)
- ds_dict['raw_data_dir'] = data_dir
- return ds_dict
- else:
- from modelscope.msdatasets import MsDataset
- ds_dict = MsDataset.load(
- dataset_name=dataset_name, namespace=namespace)
- return ds_dict
+logger = get_logger()
+logger.warning(
+    'This import path has been deprecated, please use '
+    '`from modelscope.msdatasets.dataset_cls.custom_datasets import ASRDataset` instead.'
+)
diff --git a/modelscope/msdatasets/cv/__init__.py b/modelscope/msdatasets/cv/__init__.py
deleted file mode 100644
index fad91bcf..00000000
--- a/modelscope/msdatasets/cv/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-from . import (image_classification, image_semantic_segmentation,
- object_detection)
diff --git a/modelscope/msdatasets/data_loader/data_loader.py b/modelscope/msdatasets/data_loader/data_loader.py
index c97151b0..1ef92372 100644
--- a/modelscope/msdatasets/data_loader/data_loader.py
+++ b/modelscope/msdatasets/data_loader/data_loader.py
@@ -13,6 +13,7 @@ from modelscope.msdatasets.context.dataset_context_config import \
DatasetContextConfig
from modelscope.msdatasets.data_files.data_files_manager import \
DataFilesManager
+from modelscope.msdatasets.dataset_cls.dataset import ExternalDataset
from modelscope.msdatasets.meta.data_meta_manager import DataMetaManager
from modelscope.utils.constant import DatasetFormations
@@ -62,7 +63,8 @@ class OssDataLoader(BaseDataLoader):
self.data_files_builder: Optional[DataFilesManager] = None
self.dataset: Optional[Union[Dataset, IterableDataset, DatasetDict,
- IterableDatasetDict]] = None
+ IterableDatasetDict,
+ ExternalDataset]] = None
self.builder: Optional[DatasetBuilder] = None
self.data_files_manager: Optional[DataFilesManager] = None
@@ -141,7 +143,8 @@ class OssDataLoader(BaseDataLoader):
self.builder)
def _post_process(self) -> None:
- ...
+ if isinstance(self.dataset, ExternalDataset):
+ self.dataset.custom_map = self.dataset_context_config.data_meta_config.meta_type_map
class MaxComputeDataLoader(BaseDataLoader):
diff --git a/modelscope/msdatasets/dataset_cls/__init__.py b/modelscope/msdatasets/dataset_cls/__init__.py
index b937315b..a5b2e73d 100644
--- a/modelscope/msdatasets/dataset_cls/__init__.py
+++ b/modelscope/msdatasets/dataset_cls/__init__.py
@@ -1 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from .dataset import ExternalDataset, NativeIterableDataset
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/__init__.py
new file mode 100644
index 00000000..9eb62168
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/__init__.py
@@ -0,0 +1,87 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
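+
+# Names listed in `_import_structure` below are resolved lazily through `LazyImportModule`:
+# a submodule is imported only when one of its exported names is first accessed, so importing
+# this package does not eagerly pull in every custom dataset's dependencies.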
+
+if TYPE_CHECKING:
+ from .easycv_base import EasyCVBaseDataset
+ from .builder import CUSTOM_DATASETS, build_custom_dataset
+ from .torch_custom_dataset import TorchCustomDataset
+ from .movie_scene_segmentation.movie_scene_segmentation_dataset import MovieSceneSegmentationDataset
+ from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset
+ from .gopro_image_deblurring_dataset import GoproImageDeblurringDataset
+ from .language_guided_video_summarization_dataset import LanguageGuidedVideoSummarizationDataset
+ from .mgeo_ranking_dataset import MGeoRankingDataset
+ from .reds_image_deblurring_dataset import RedsImageDeblurringDataset
+ from .text_ranking_dataset import TextRankingDataset
+ from .veco_dataset import VecoDataset
+ from .video_summarization_dataset import VideoSummarizationDataset
+ from .audio import KWSDataset, KWSDataLoader, kws_nearfield_dataset, ASRDataset
+ from .bad_image_detecting import BadImageDetectingDataset
+ from .image_inpainting import ImageInpaintingDataset
+ from .image_portrait_enhancement import ImagePortraitEnhancementDataset
+ from .image_quality_assessment_degradation import ImageQualityAssessmentDegradationDataset
+ from .image_quality_assmessment_mos import ImageQualityAssessmentMosDataset
+ from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset
+ from .sidd_image_denoising import SiddImageDenoisingDataset
+ from .video_frame_interpolation import VideoFrameInterpolationDataset
+ from .video_stabilization import VideoStabilizationDataset
+ from .video_super_resolution import VideoSuperResolutionDataset
+ from .image_semantic_segmentation import SegDataset
+ from .face_2d_keypoins import FaceKeypointDataset
+ from .hand_2d_keypoints import HandCocoWholeBodyDataset
+ from .human_wholebody_keypoint import WholeBodyCocoTopDownDataset
+ from .image_classification import ClsDataset
+ from .object_detection import DetDataset, DetImagesMixDataset
+ from .ocr_detection import DataLoader, ImageDataset, QuadMeasurer
+ from .ocr_recognition_dataset import OCRRecognitionDataset
+ from .image_colorization import ImageColorizationDataset
+else:
+ _import_structure = {
+ 'easycv_base': ['EasyCVBaseDataset'],
+ 'builder': ['CUSTOM_DATASETS', 'build_custom_dataset'],
+ 'torch_custom_dataset': ['TorchCustomDataset'],
+ 'movie_scene_segmentation_dataset': ['MovieSceneSegmentationDataset'],
+ 'image_instance_segmentation_coco_dataset':
+ ['ImageInstanceSegmentationCocoDataset'],
+ 'gopro_image_deblurring_dataset': ['GoproImageDeblurringDataset'],
+ 'language_guided_video_summarization_dataset':
+ ['LanguageGuidedVideoSummarizationDataset'],
+ 'mgeo_ranking_dataset': ['MGeoRankingDataset'],
+ 'reds_image_deblurring_dataset': ['RedsImageDeblurringDataset'],
+ 'text_ranking_dataset': ['TextRankingDataset'],
+ 'veco_dataset': ['VecoDataset'],
+ 'video_summarization_dataset': ['VideoSummarizationDataset'],
+ 'audio':
+ ['KWSDataset', 'KWSDataLoader', 'kws_nearfield_dataset', 'ASRDataset'],
+ 'bad_image_detecting': ['BadImageDetectingDataset'],
+ 'image_inpainting': ['ImageInpaintingDataset'],
+ 'image_portrait_enhancement': ['ImagePortraitEnhancementDataset'],
+ 'image_quality_assessment_degradation':
+ ['ImageQualityAssessmentDegradationDataset'],
+ 'image_quality_assmessment_mos': ['ImageQualityAssessmentMosDataset'],
+ 'referring_video_object_segmentation':
+ ['ReferringVideoObjectSegmentationDataset'],
+ 'sidd_image_denoising': ['SiddImageDenoisingDataset'],
+ 'video_frame_interpolation': ['VideoFrameInterpolationDataset'],
+ 'video_stabilization': ['VideoStabilizationDataset'],
+ 'video_super_resolution': ['VideoSuperResolutionDataset'],
+ 'image_semantic_segmentation': ['SegDataset'],
+ 'face_2d_keypoins': ['FaceKeypointDataset'],
+ 'hand_2d_keypoints': ['HandCocoWholeBodyDataset'],
+ 'human_wholebody_keypoint': ['WholeBodyCocoTopDownDataset'],
+ 'image_classification': ['ClsDataset'],
+ 'object_detection': ['DetDataset', 'DetImagesMixDataset'],
+ 'ocr_detection': ['DataLoader', 'ImageDataset', 'QuadMeasurer'],
+ 'ocr_recognition_dataset': ['OCRRecognitionDataset'],
+ 'image_colorization': ['ImageColorizationDataset'],
+ }
+
+ import sys
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/msdatasets/task_datasets/audio/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/__init__.py
similarity index 89%
rename from modelscope/msdatasets/task_datasets/audio/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/audio/__init__.py
index dc66bd8d..7291bb7f 100644
--- a/modelscope/msdatasets/task_datasets/audio/__init__.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/__init__.py
@@ -6,11 +6,13 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .kws_farfield_dataset import KWSDataset, KWSDataLoader
from .kws_nearfield_dataset import kws_nearfield_dataset
+ from .asr_dataset import ASRDataset
else:
_import_structure = {
'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'],
'kws_nearfield_dataset': ['kws_nearfield_dataset'],
+ 'asr_dataset': ['ASRDataset'],
}
import sys
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py
new file mode 100644
index 00000000..c0696615
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py
@@ -0,0 +1,48 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+
+from modelscope.msdatasets.ms_dataset import MsDataset
+
+
+class ASRDataset(MsDataset):
+ """ASR dataset for speech recognition.
+    Supports loading a dataset from the ModelScope dataset hub or from a local data_dir (containing wav.scp and text).
+ For more details, please refer to
+ https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/datasets/ms_dataset.py.
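+
+    Example (illustrative; the path is a placeholder for a directory laid out as
+    ``<data_dir>/<split>/wav.scp`` and ``<data_dir>/<split>/text``)::
+
+        >>> ds_dict = ASRDataset.load('/path/to/local_data_dir')
+        >>> train_samples = ds_dict['train']  # list of {'Audio:FILE': ..., 'Text:LABEL': ...}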
+ """
+
+ @classmethod
+ def load_core(cls, data_dir, data_set):
+ wav_file = os.path.join(data_dir, data_set, 'wav.scp')
+ text_file = os.path.join(data_dir, data_set, 'text')
+ with open(wav_file) as f:
+ wav_lines = f.readlines()
+ with open(text_file) as f:
+ text_lines = f.readlines()
+ data_list = []
+ for wav_line, text_line in zip(wav_lines, text_lines):
+ item = {}
+ item['Audio:FILE'] = wav_line.strip().split()[-1]
+ item['Text:LABEL'] = ' '.join(text_line.strip().split()[1:])
+ data_list.append(item)
+ return data_list
+
+ @classmethod
+ def load(cls,
+ dataset_name,
+ namespace='speech_asr',
+ train_set='train',
+ dev_set='validation'):
+ if os.path.exists(dataset_name):
+ data_dir = dataset_name
+ ds_dict = {}
+ ds_dict['train'] = cls.load_core(data_dir, train_set)
+ ds_dict['validation'] = cls.load_core(data_dir, dev_set)
+ ds_dict['raw_data_dir'] = data_dir
+ return ds_dict
+ else:
+ from modelscope.msdatasets import MsDataset
+ ds_dict = MsDataset.load(
+ dataset_name=dataset_name, namespace=namespace)
+ return ds_dict
diff --git a/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_farfield_dataset.py
similarity index 99%
rename from modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_farfield_dataset.py
index d4866204..69c95bbd 100644
--- a/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_farfield_dataset.py
@@ -5,7 +5,6 @@ import math
import os.path
import queue
import threading
-import time
import numpy as np
import torch
diff --git a/modelscope/msdatasets/task_datasets/audio/kws_nearfield_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_dataset.py
similarity index 98%
rename from modelscope/msdatasets/task_datasets/audio/kws_nearfield_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_dataset.py
index 43f28e01..1b784410 100644
--- a/modelscope/msdatasets/task_datasets/audio/kws_nearfield_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_dataset.py
@@ -18,7 +18,7 @@ import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
-import modelscope.msdatasets.task_datasets.audio.kws_nearfield_processor as processor
+import modelscope.msdatasets.dataset_cls.custom_datasets.audio.kws_nearfield_processor as processor
from modelscope.trainers.audio.kws_utils.file_utils import (make_pair,
read_lists)
from modelscope.utils.logger import get_logger
diff --git a/modelscope/msdatasets/task_datasets/audio/kws_nearfield_processor.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_processor.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/audio/kws_nearfield_processor.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_processor.py
diff --git a/modelscope/msdatasets/task_datasets/bad_image_detecting/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/bad_image_detecting/__init__.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/bad_image_detecting/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/bad_image_detecting/__init__.py
diff --git a/modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/bad_image_detecting/bad_image_detecting_dataset.py
similarity index 79%
rename from modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/bad_image_detecting/bad_image_detecting_dataset.py
index f3cd9a2f..539b7b25 100644
--- a/modelscope/msdatasets/task_datasets/bad_image_detecting/bad_image_detecting_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/bad_image_detecting/bad_image_detecting_dataset.py
@@ -1,12 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-import cv2
-import numpy as np
-
from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.outputs import OutputKeys
from modelscope.preprocessors import LoadImage
from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \
@@ -14,9 +10,9 @@ from modelscope.preprocessors.cv.bad_image_detecting_preprocessor import \
from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
Tasks.bad_image_detecting, module_name=Models.bad_image_detecting)
-class BadImageDetectingDataset(TorchTaskDataset):
+class BadImageDetectingDataset(TorchCustomDataset):
"""Paired image dataset for bad image detecting.
"""
diff --git a/modelscope/msdatasets/task_datasets/builder.py b/modelscope/msdatasets/dataset_cls/custom_datasets/builder.py
similarity index 56%
rename from modelscope/msdatasets/task_datasets/builder.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/builder.py
index 683bec8f..a793ea27 100644
--- a/modelscope/msdatasets/task_datasets/builder.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/builder.py
@@ -3,13 +3,13 @@
from modelscope.utils.config import ConfigDict
from modelscope.utils.registry import Registry, build_from_cfg
-TASK_DATASETS = Registry('task_datasets')
+CUSTOM_DATASETS = Registry('custom_datasets')
-def build_task_dataset(cfg: ConfigDict,
- task_name: str = None,
- default_args: dict = None):
- """ Build task specific dataset processor given model config dict and the task name.
+def build_custom_dataset(cfg: ConfigDict,
+ task_name: str,
+ default_args: dict = None):
+    """ Build a custom dataset for a user-defined dataset, given the model config and task name.
Args:
cfg (:obj:`ConfigDict`): config dict for model object.
@@ -18,4 +18,4 @@ def build_task_dataset(cfg: ConfigDict,
default_args (dict, optional): Default initialization arguments.
"""
return build_from_cfg(
- cfg, TASK_DATASETS, group_key=task_name, default_args=default_args)
+ cfg, CUSTOM_DATASETS, group_key=task_name, default_args=default_args)
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/__init__.py
similarity index 75%
rename from modelscope/msdatasets/task_datasets/damoyolo/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/__init__.py
index 2a3bccdb..dabde7a4 100644
--- a/modelscope/msdatasets/task_datasets/damoyolo/__init__.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/__init__.py
@@ -1,2 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .build import build_dataloader, build_dataset
+from .evaluation import evaluate
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/build.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/build.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/build.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/build.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/collate_batch.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/collate_batch.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/collate_batch.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/collate_batch.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/datasets/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/datasets/__init__.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/datasets/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/datasets/__init__.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/datasets/coco.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/datasets/coco.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/datasets/coco.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/datasets/coco.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/datasets/mosaic_wrapper.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/datasets/mosaic_wrapper.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/datasets/mosaic_wrapper.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/datasets/mosaic_wrapper.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/evaluation/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/__init__.py
similarity index 93%
rename from modelscope/msdatasets/task_datasets/damoyolo/evaluation/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/__init__.py
index b121b80b..b12fbf69 100644
--- a/modelscope/msdatasets/task_datasets/damoyolo/evaluation/__init__.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/__init__.py
@@ -1,6 +1,6 @@
# Copyright © Alibaba, Inc. and its affiliates.
-from modelscope.msdatasets.task_datasets.damoyolo import datasets
+from .. import datasets
from .coco import coco_evaluation
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/evaluation/coco/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/coco/__init__.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/evaluation/coco/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/coco/__init__.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/evaluation/coco/coco_eval.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/coco/coco_eval.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/evaluation/coco/coco_eval.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/evaluation/coco/coco_eval.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/samplers/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/__init__.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/samplers/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/__init__.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/samplers/distributed.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/distributed.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/samplers/distributed.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/distributed.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/samplers/grouped_batch_sampler.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/grouped_batch_sampler.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/samplers/grouped_batch_sampler.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/grouped_batch_sampler.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/samplers/iteration_based_batch_sampler.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/iteration_based_batch_sampler.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/samplers/iteration_based_batch_sampler.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/samplers/iteration_based_batch_sampler.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/transforms/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/transforms/__init__.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/transforms/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/transforms/__init__.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/transforms/build.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/transforms/build.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/transforms/build.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/transforms/build.py
diff --git a/modelscope/msdatasets/task_datasets/damoyolo/transforms/transforms.py b/modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/transforms/transforms.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/damoyolo/transforms/transforms.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/damoyolo/transforms/transforms.py
diff --git a/modelscope/msdatasets/cv/easycv_base.py b/modelscope/msdatasets/dataset_cls/custom_datasets/easycv_base.py
similarity index 100%
rename from modelscope/msdatasets/cv/easycv_base.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/easycv_base.py
diff --git a/modelscope/msdatasets/cv/face_2d_keypoins/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/face_2d_keypoins/__init__.py
similarity index 100%
rename from modelscope/msdatasets/cv/face_2d_keypoins/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/face_2d_keypoins/__init__.py
diff --git a/modelscope/msdatasets/cv/face_2d_keypoins/face_2d_keypoints_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/face_2d_keypoins/face_2d_keypoints_dataset.py
similarity index 78%
rename from modelscope/msdatasets/cv/face_2d_keypoins/face_2d_keypoints_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/face_2d_keypoins/face_2d_keypoints_dataset.py
index 2f2e03ef..9f55901f 100644
--- a/modelscope/msdatasets/cv/face_2d_keypoins/face_2d_keypoints_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/face_2d_keypoins/face_2d_keypoints_dataset.py
@@ -1,15 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.datasets.face import FaceKeypointDataset as _FaceKeypointDataset
-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
+ EasyCVBaseDataset
from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
group_key=Tasks.face_2d_keypoints,
- module_name=Datasets.Face2dKeypointsDataset)
+ module_name=CustomDatasets.Face2dKeypointsDataset)
class FaceKeypointDataset(EasyCVBaseDataset, _FaceKeypointDataset):
"""EasyCV dataset for face 2d keypoints.
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/gopro_image_deblurring_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/gopro_image_deblurring_dataset.py
new file mode 100644
index 00000000..47943885
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/gopro_image_deblurring_dataset.py
@@ -0,0 +1,63 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import cv2
+import numpy as np
+
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
+from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.data_utils import (
+ img2tensor, padding)
+from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.transforms import (
+ augment, paired_random_crop)
+from modelscope.utils.constant import Tasks
+
+
+def default_loader(path):
+ return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
+
+
+@CUSTOM_DATASETS.register_module(
+ Tasks.image_deblurring, module_name=CustomDatasets.GoproDataset)
+class GoproImageDeblurringDataset(TorchCustomDataset):
+ """Paired image dataset for image restoration.
+ """
+
+ def __init__(self, dataset, opt, is_train):
+ self.dataset = dataset
+ self.opt = opt
+ self.is_train = is_train
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, index):
+
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+ # image range: [0, 1], float32.
+ item_dict = self.dataset[index]
+ gt_path = item_dict['Sharp Image:FILE']
+ img_gt = default_loader(gt_path)
+ lq_path = item_dict['Blur Image:FILE']
+ img_lq = default_loader(lq_path)
+
+ # augmentation for training
+ if self.is_train:
+ gt_size = self.opt.gt_size
+ # padding
+ img_gt, img_lq = padding(img_gt, img_lq, gt_size)
+
+ # random crop
+ img_gt, img_lq = paired_random_crop(
+ img_gt, img_lq, gt_size, scale=1)
+
+ # flip, rotation
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip,
+ self.opt.use_rot)
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt, img_lq = img2tensor([img_gt, img_lq],
+ bgr2rgb=True,
+ float32=True)
+
+ return {'input': img_lq, 'target': img_gt}
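Note: the last step of `__getitem__` above converts BGR, HWC, float32 crops into RGB, CHW tensors via `img2tensor` from the sidd_image_denoising helpers. A self-contained sketch of that conversion (the real helper may differ in detail):

```python
# Self-contained sketch of the BGR -> RGB, HWC -> CHW, numpy -> tensor step;
# the real img2tensor helper lives in sidd_image_denoising and may differ.
import numpy as np
import torch


def bgr_hwc_to_rgb_chw(img: np.ndarray) -> torch.Tensor:
    rgb = img[..., ::-1].copy()                          # BGR -> RGB
    chw = np.ascontiguousarray(rgb.transpose(2, 0, 1))   # HWC -> CHW
    return torch.from_numpy(chw).float()


img_lq = np.random.rand(64, 64, 3).astype(np.float32)   # stand-in blur crop in [0, 1]
tensor_lq = bgr_hwc_to_rgb_chw(img_lq)
assert tensor_lq.shape == (3, 64, 64)
```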
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/hand_2d_keypoints/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/hand_2d_keypoints/__init__.py
new file mode 100644
index 00000000..3af670e3
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/hand_2d_keypoints/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .hand_2d_keypoints_dataset import HandCocoWholeBodyDataset
+
+else:
+ _import_structure = {
+ 'hand_2d_keypoints_dataset': ['HandCocoWholeBodyDataset']
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/hand_2d_keypoints/hand_2d_keypoints_dataset.py
similarity index 79%
rename from modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/hand_2d_keypoints/hand_2d_keypoints_dataset.py
index 89ee0bb8..c6163715 100644
--- a/modelscope/msdatasets/cv/hand_2d_keypoints/hand_2d_keypoints_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/hand_2d_keypoints/hand_2d_keypoints_dataset.py
@@ -2,15 +2,16 @@
from easycv.datasets.pose import \
HandCocoWholeBodyDataset as _HandCocoWholeBodyDataset
-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
+ EasyCVBaseDataset
from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
group_key=Tasks.hand_2d_keypoints,
- module_name=Datasets.HandCocoWholeBodyDataset)
+ module_name=CustomDatasets.HandCocoWholeBodyDataset)
class HandCocoWholeBodyDataset(EasyCVBaseDataset, _HandCocoWholeBodyDataset):
"""EasyCV dataset for human hand 2d keypoints.
diff --git a/modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/human_wholebody_keypoint/__init__.py
similarity index 100%
rename from modelscope/msdatasets/cv/human_wholebody_keypoint/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/human_wholebody_keypoint/__init__.py
diff --git a/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py
similarity index 79%
rename from modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py
index fc9469f2..59c97af8 100644
--- a/modelscope/msdatasets/cv/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/human_wholebody_keypoint/human_wholebody_keypoint_dataset.py
@@ -2,15 +2,16 @@
from easycv.datasets.pose import \
WholeBodyCocoTopDownDataset as _WholeBodyCocoTopDownDataset
-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
+ EasyCVBaseDataset
from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
group_key=Tasks.human_wholebody_keypoint,
- module_name=Datasets.HumanWholeBodyKeypointDataset)
+ module_name=CustomDatasets.HumanWholeBodyKeypointDataset)
class WholeBodyCocoTopDownDataset(EasyCVBaseDataset,
_WholeBodyCocoTopDownDataset):
"""EasyCV dataset for human whole body 2d keypoints.
diff --git a/modelscope/msdatasets/cv/image_classification/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_classification/__init__.py
similarity index 100%
rename from modelscope/msdatasets/cv/image_classification/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_classification/__init__.py
diff --git a/modelscope/msdatasets/cv/image_classification/classification_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_classification/classification_dataset.py
similarity index 75%
rename from modelscope/msdatasets/cv/image_classification/classification_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_classification/classification_dataset.py
index ba73e472..386810c7 100644
--- a/modelscope/msdatasets/cv/image_classification/classification_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_classification/classification_dataset.py
@@ -1,14 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.datasets.classification import ClsDataset as _ClsDataset
-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
+ EasyCVBaseDataset
from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
- group_key=Tasks.image_classification, module_name=Datasets.ClsDataset)
+@CUSTOM_DATASETS.register_module(
+ group_key=Tasks.image_classification,
+ module_name=CustomDatasets.ClsDataset)
class ClsDataset(_ClsDataset):
"""EasyCV dataset for classification.
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/image_colorization/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_colorization/__init__.py
new file mode 100644
index 00000000..3ab45a2e
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_colorization/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .image_colorization_dataset import ImageColorizationDataset
+
+else:
+ _import_structure = {
+ 'image_colorization_dataset': ['ImageColorizationDataset'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/image_colorization/image_colorization_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_colorization/image_colorization_dataset.py
new file mode 100644
index 00000000..06132473
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_colorization/image_colorization_dataset.py
@@ -0,0 +1,67 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import cv2
+import numpy as np
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
+from modelscope.utils.constant import Tasks
+
+
+def default_loader(path):
+ return cv2.imread(path).astype(np.float32) / 255.0
+
+
+@CUSTOM_DATASETS.register_module(
+ Tasks.image_colorization, module_name=Models.ddcolor)
+class ImageColorizationDataset(TorchCustomDataset):
+ """Image dataset for image colorization.
+ """
+
+ def __init__(self, dataset, opt, is_train):
+ self.dataset = dataset
+ self.opt = opt
+ self.input_size = 256
+ self.is_train = is_train
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, index):
+ # Load gt images. Dimension order: HWC; channel order: BGR;
+ # image range: [0, 1], float32.
+ item_dict = self.dataset[index]
+ gt_path = item_dict['Image:FILE']
+ img_gt = default_loader(gt_path)
+
+ # resize to 256
+ img_gt = cv2.resize(img_gt, (self.input_size, self.input_size))
+
+ # get lq
+ img_l = cv2.cvtColor(img_gt, cv2.COLOR_BGR2Lab)[:, :, :1]
+ img_gray_lab = np.concatenate(
+ (img_l, np.zeros_like(img_l), np.zeros_like(img_l)), axis=-1)
+ img_gray_rgb = cv2.cvtColor(img_gray_lab, cv2.COLOR_LAB2RGB)
+ tensor_lq_rgb = torch.from_numpy(img_gray_rgb.transpose(
+ (2, 0, 1))).float()
+ tensor_lq = torch.from_numpy(img_l.transpose((2, 0, 1))).float()
+
+ # get ab
+ img_ab = cv2.cvtColor(img_gt, cv2.COLOR_BGR2Lab)[:, :, 1:]
+ tensor_gt_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
+
+ # gt converted to RGB
+ img_gt_rgb = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)
+ gt_rgb = torch.from_numpy(img_gt_rgb.transpose((2, 0, 1))).float()
+
+ if self.is_train:
+ return {'input': tensor_lq_rgb, 'target': tensor_gt_ab}
+ else:
+ return {
+ 'input': tensor_lq_rgb,
+ 'target': tensor_gt_ab,
+ 'img_l': tensor_lq,
+ 'gt_rgb': gt_rgb
+ }
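Note: the Lab split performed in `__getitem__` above can be exercised on a synthetic image; this sketch mirrors the code but is not part of the patch, and the shapes are for illustration only.

```python
# Gray RGB input rebuilt from the L channel; ab chrominance pair as target.
import cv2
import numpy as np
import torch

img_bgr = np.random.rand(256, 256, 3).astype(np.float32)    # stands in for a loaded image in [0, 1]

img_l = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab)[:, :, :1]   # luminance only
img_ab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab)[:, :, 1:]  # chrominance target

gray_lab = np.concatenate((img_l, np.zeros_like(img_l), np.zeros_like(img_l)), axis=-1)
gray_rgb = cv2.cvtColor(gray_lab, cv2.COLOR_LAB2RGB)         # gray RGB network input

tensor_in = torch.from_numpy(gray_rgb.transpose((2, 0, 1))).float()
tensor_target = torch.from_numpy(img_ab.transpose((2, 0, 1))).float()
assert tensor_in.shape == (3, 256, 256) and tensor_target.shape == (2, 256, 256)
```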
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/__init__.py
new file mode 100644
index 00000000..0c9552bd
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .image_inpainting_dataset import ImageInpaintingDataset
+else:
+ _import_structure = {
+ 'image_inpainting_dataset': ['ImageInpaintingDataset'],
+ }
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/aug.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/aug.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/image_inpainting/aug.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/aug.py
diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/image_inpainting_dataset.py
similarity index 97%
rename from modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/image_inpainting_dataset.py
index 057b8f88..c7040c86 100644
--- a/modelscope/msdatasets/task_datasets/image_inpainting/image_inpainting_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_inpainting/image_inpainting_dataset.py
@@ -3,20 +3,16 @@ Part of the implementation is borrowed and modified from LaMa,
publicly available at https://github.com/saic-mdal/lama
"""
import glob
-import os
import os.path as osp
from enum import Enum
import albumentations as A
import cv2
-import json
import numpy as np
-import torch
from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from .aug import IAAAffine2, IAAPerspective2
@@ -296,9 +292,9 @@ def get_transforms(test_mode, out_size):
return transform
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
Tasks.image_inpainting, module_name=Models.image_inpainting)
-class ImageInpaintingDataset(TorchTaskDataset):
+class ImageInpaintingDataset(TorchCustomDataset):
def __init__(self, **kwargs):
split_config = kwargs['split_config']
diff --git a/modelscope/msdatasets/task_datasets/image_instance_segmentation_coco_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_instance_segmentation_coco_dataset.py
similarity index 98%
rename from modelscope/msdatasets/task_datasets/image_instance_segmentation_coco_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_instance_segmentation_coco_dataset.py
index 1c7bc249..4dd1af5a 100644
--- a/modelscope/msdatasets/task_datasets/image_instance_segmentation_coco_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_instance_segmentation_coco_dataset.py
@@ -6,9 +6,9 @@ import numpy as np
from pycocotools.coco import COCO
from modelscope.metainfo import Models
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.utils.constant import Tasks
-from .builder import TASK_DATASETS
-from .torch_base_dataset import TorchTaskDataset
DATASET_STRUCTURE = {
'train': {
@@ -22,10 +22,10 @@ DATASET_STRUCTURE = {
}
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
module_name=Models.cascade_mask_rcnn_swin,
group_key=Tasks.image_segmentation)
-class ImageInstanceSegmentationCocoDataset(TorchTaskDataset):
+class ImageInstanceSegmentationCocoDataset(TorchCustomDataset):
"""Coco-style dataset for image instance segmentation.
Args:
diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/__init__.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/__init__.py
diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/data_utils.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/data_utils.py
diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py
similarity index 77%
rename from modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py
index 58d40778..d2c03408 100644
--- a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py
@@ -3,10 +3,9 @@
import cv2
import numpy as np
-from modelscope.metainfo import Datasets, Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.utils.constant import Tasks
from .data_utils import img2tensor
@@ -15,9 +14,9 @@ def default_loader(path):
return cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0
-@TASK_DATASETS.register_module(
- Tasks.image_portrait_enhancement, module_name=Datasets.PairedDataset)
-class ImagePortraitEnhancementDataset(TorchTaskDataset):
+@CUSTOM_DATASETS.register_module(
+ Tasks.image_portrait_enhancement, module_name=CustomDatasets.PairedDataset)
+class ImagePortraitEnhancementDataset(TorchCustomDataset):
"""Paired image dataset for image portrait enhancement.
"""
diff --git a/modelscope/msdatasets/task_datasets/image_quality_assessment_degradation/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assessment_degradation/__init__.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/image_quality_assessment_degradation/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assessment_degradation/__init__.py
diff --git a/modelscope/msdatasets/task_datasets/image_quality_assessment_degradation/image_quality_assessment_degradation_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assessment_degradation/image_quality_assessment_degradation_dataset.py
similarity index 81%
rename from modelscope/msdatasets/task_datasets/image_quality_assessment_degradation/image_quality_assessment_degradation_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assessment_degradation/image_quality_assessment_degradation_dataset.py
index 75826065..06f0453e 100644
--- a/modelscope/msdatasets/task_datasets/image_quality_assessment_degradation/image_quality_assessment_degradation_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assessment_degradation/image_quality_assessment_degradation_dataset.py
@@ -1,21 +1,18 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-import cv2
-import numpy as np
from torchvision import transforms
from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
Tasks.image_quality_assessment_degradation,
module_name=Models.image_quality_assessment_degradation)
-class ImageQualityAssessmentDegradationDataset(TorchTaskDataset):
+class ImageQualityAssessmentDegradationDataset(TorchCustomDataset):
"""Paired image dataset for image quality assessment degradation.
"""
diff --git a/modelscope/msdatasets/task_datasets/image_quality_assmessment_mos/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assmessment_mos/__init__.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/image_quality_assmessment_mos/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assmessment_mos/__init__.py
diff --git a/modelscope/msdatasets/task_datasets/image_quality_assmessment_mos/image_quality_assessment_mos_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assmessment_mos/image_quality_assessment_mos_dataset.py
similarity index 77%
rename from modelscope/msdatasets/task_datasets/image_quality_assmessment_mos/image_quality_assessment_mos_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assmessment_mos/image_quality_assessment_mos_dataset.py
index 3d8ed297..28c163eb 100644
--- a/modelscope/msdatasets/task_datasets/image_quality_assmessment_mos/image_quality_assessment_mos_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_quality_assmessment_mos/image_quality_assessment_mos_dataset.py
@@ -1,20 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-import cv2
-import numpy as np
-
from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.preprocessors.cv import ImageQualityAssessmentMosPreprocessor
from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
Tasks.image_quality_assessment_mos,
module_name=Models.image_quality_assessment_mos)
-class ImageQualityAssessmentMosDataset(TorchTaskDataset):
+class ImageQualityAssessmentMosDataset(TorchCustomDataset):
"""Paired image dataset for image quality assessment mos.
"""
diff --git a/modelscope/msdatasets/cv/image_semantic_segmentation/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_semantic_segmentation/__init__.py
similarity index 100%
rename from modelscope/msdatasets/cv/image_semantic_segmentation/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_semantic_segmentation/__init__.py
diff --git a/modelscope/msdatasets/cv/image_semantic_segmentation/segmentation_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/image_semantic_segmentation/segmentation_dataset.py
similarity index 81%
rename from modelscope/msdatasets/cv/image_semantic_segmentation/segmentation_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/image_semantic_segmentation/segmentation_dataset.py
index b1316e2e..71e7c42b 100644
--- a/modelscope/msdatasets/cv/image_semantic_segmentation/segmentation_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/image_semantic_segmentation/segmentation_dataset.py
@@ -1,14 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from easycv.datasets.segmentation import SegDataset as _SegDataset
-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
+ EasyCVBaseDataset
from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
- group_key=Tasks.image_segmentation, module_name=Datasets.SegDataset)
+@CUSTOM_DATASETS.register_module(
+ group_key=Tasks.image_segmentation, module_name=CustomDatasets.SegDataset)
class SegDataset(EasyCVBaseDataset, _SegDataset):
"""EasyCV dataset for Sementic segmentation.
For more details, please refer to :
diff --git a/modelscope/msdatasets/task_datasets/language_guided_video_summarization_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/language_guided_video_summarization_dataset.py
similarity index 94%
rename from modelscope/msdatasets/task_datasets/language_guided_video_summarization_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/language_guided_video_summarization_dataset.py
index 94313e15..756d0050 100644
--- a/modelscope/msdatasets/task_datasets/language_guided_video_summarization_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/language_guided_video_summarization_dataset.py
@@ -25,16 +25,15 @@ import numpy as np
import torch
from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
Tasks.language_guided_video_summarization,
module_name=Models.language_guided_video_summarization)
-class LanguageGuidedVideoSummarizationDataset(TorchTaskDataset):
+class LanguageGuidedVideoSummarizationDataset(TorchCustomDataset):
def __init__(self, mode, opt, root_dir):
self.mode = mode
diff --git a/modelscope/msdatasets/task_datasets/mgeo_ranking_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/mgeo_ranking_dataset.py
similarity index 93%
rename from modelscope/msdatasets/task_datasets/mgeo_ranking_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/mgeo_ranking_dataset.py
index 9adccd7c..536451ae 100644
--- a/modelscope/msdatasets/task_datasets/mgeo_ranking_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/mgeo_ranking_dataset.py
@@ -1,24 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
-from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, List, Union
import json
import torch
-from datasets import Dataset, IterableDataset, concatenate_datasets
from torch.utils.data import ConcatDataset
-from transformers import DataCollatorWithPadding
from modelscope.metainfo import Models
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.utils.constant import ModeKeys, Tasks
-from .base import TaskDataset
-from .builder import TASK_DATASETS
-from .torch_base_dataset import TorchTaskDataset
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
group_key=Tasks.text_ranking, module_name=Models.mgeo)
-class MGeoRankingDataset(TorchTaskDataset):
+class MGeoRankingDataset(TorchCustomDataset):
def __init__(self,
datasets: Union[Any, List[Any]],
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/__init__.py
new file mode 100644
index 00000000..6157e9e8
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .movie_scene_segmentation_dataset import MovieSceneSegmentationDataset
+else:
+ _import_structure = {
+ 'movie_scene_segmentation_dataset': ['MovieSceneSegmentationDataset'],
+ }
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py
similarity index 94%
rename from modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py
index 49991b11..041976dd 100644
--- a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/movie_scene_segmentation_dataset.py
@@ -10,9 +10,8 @@ import torch
from torchvision.datasets.folder import pil_loader
from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
+ CUSTOM_DATASETS
from modelscope.utils.constant import Tasks
from . import sampler
@@ -30,9 +29,9 @@ DATASET_STRUCTURE = {
}
-@TASK_DATASETS.register_module(
- Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert)
-class MovieSceneSegmentationDataset(TorchTaskDataset):
+@CUSTOM_DATASETS.register_module(
+ group_key=Tasks.movie_scene_segmentation, module_name=Models.resnet50_bert)
+class MovieSceneSegmentationDataset(torch.utils.data.Dataset):
"""dataset for movie scene segmentation.
Args:
diff --git a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/sampler.py b/modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/sampler.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/movie_scene_segmentation/sampler.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/movie_scene_segmentation/sampler.py
diff --git a/modelscope/msdatasets/cv/object_detection/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/object_detection/__init__.py
similarity index 100%
rename from modelscope/msdatasets/cv/object_detection/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/object_detection/__init__.py
diff --git a/modelscope/msdatasets/cv/object_detection/detection_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/object_detection/detection_dataset.py
similarity index 85%
rename from modelscope/msdatasets/cv/object_detection/detection_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/object_detection/detection_dataset.py
index c7e45eea..66c11f64 100644
--- a/modelscope/msdatasets/cv/object_detection/detection_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/object_detection/detection_dataset.py
@@ -1,20 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-import os.path as osp
from easycv.datasets.detection import DetDataset as _DetDataset
from easycv.datasets.detection import \
DetImagesMixDataset as _DetImagesMixDataset
-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.cv.easycv_base import EasyCVBaseDataset
-from modelscope.msdatasets.task_datasets import TASK_DATASETS
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.easycv_base import \
+ EasyCVBaseDataset
from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
- group_key=Tasks.image_object_detection, module_name=Datasets.DetDataset)
-@TASK_DATASETS.register_module(
- group_key=Tasks.image_segmentation, module_name=Datasets.DetDataset)
+@CUSTOM_DATASETS.register_module(
+ group_key=Tasks.image_object_detection,
+ module_name=CustomDatasets.DetDataset)
+@CUSTOM_DATASETS.register_module(
+ group_key=Tasks.image_segmentation, module_name=CustomDatasets.DetDataset)
class DetDataset(EasyCVBaseDataset, _DetDataset):
"""EasyCV dataset for object detection.
For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/raw.py .
@@ -47,12 +48,12 @@ class DetDataset(EasyCVBaseDataset, _DetDataset):
_DetDataset.__init__(self, *args, **kwargs)
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
group_key=Tasks.image_object_detection,
- module_name=Datasets.DetImagesMixDataset)
-@TASK_DATASETS.register_module(
+ module_name=CustomDatasets.DetImagesMixDataset)
+@CUSTOM_DATASETS.register_module(
group_key=Tasks.domain_specific_object_detection,
- module_name=Datasets.DetImagesMixDataset)
+ module_name=CustomDatasets.DetImagesMixDataset)
class DetImagesMixDataset(EasyCVBaseDataset, _DetImagesMixDataset):
"""EasyCV dataset for object detection, a wrapper of multiple images mixed dataset.
Suitable for training on multiple images mixed data augmentation like
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/__init__.py
new file mode 100644
index 00000000..6a3847b9
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .data_loader import DataLoader
+from .image_dataset import ImageDataset
+from .measures import QuadMeasurer
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/augmenter.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/augmenter.py
new file mode 100644
index 00000000..42f2fff3
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/augmenter.py
@@ -0,0 +1,46 @@
+# ------------------------------------------------------------------------------
+# The implementation is adopted from DBNet,
+# made publicly available under the Apache License 2.0 at https://github.com/MhLiao/DB.
+# ------------------------------------------------------------------------------
+import imgaug
+import imgaug.augmenters as iaa
+
+
+class AugmenterBuilder(object):
+
+ def __init__(self):
+ pass
+
+ def build(self, args, root=True):
+ if args is None:
+ return None
+ elif isinstance(args, (int, float, str)):
+ return args
+ elif isinstance(args, list):
+ if root:
+ sequence = [self.build(value, root=False) for value in args]
+ return iaa.Sequential(sequence)
+ else:
+ return getattr(
+ iaa,
+ args[0])(*[self.to_tuple_if_list(a) for a in args[1:]])
+ elif isinstance(args, dict):
+ if 'cls' in args:
+ cls = getattr(iaa, args['cls'])
+ return cls(
+ **{
+ k: self.to_tuple_if_list(v)
+ for k, v in args.items() if not k == 'cls'
+ })
+ else:
+ return {
+ key: self.build(value, root=False)
+ for key, value in args.items()
+ }
+ else:
+ raise RuntimeError('unknown augmenter arg: ' + str(args))
+
+ def to_tuple_if_list(self, obj):
+ if isinstance(obj, list):
+ return tuple(obj)
+ return obj
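Note: a hypothetical `augmenter_args` list the builder above can consume; the import path is the one introduced by this patch, and the exact schema of a real configuration.json may differ.

```python
import imgaug.augmenters as iaa

from modelscope.msdatasets.dataset_cls.custom_datasets.ocr_detection.augmenter import \
    AugmenterBuilder

augmenter_args = [
    ['Resize', {'height': 736, 'width': 736}],   # also read back by ImageDataset-style resize code
    ['Fliplr', 0.5],
    {'cls': 'Affine', 'rotate': [-10, 10]},      # the list becomes the tuple (-10, 10)
]
aug = AugmenterBuilder().build(augmenter_args)
assert isinstance(aug, iaa.Sequential) and len(aug) == 3
```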
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/data_loader.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/data_loader.py
new file mode 100644
index 00000000..a13ad196
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/data_loader.py
@@ -0,0 +1,135 @@
+# ------------------------------------------------------------------------------
+# Part of implementation is adopted from DBNet,
+# made publicly available under the Apache License 2.0 at https://github.com/MhLiao/DB.
+# ------------------------------------------------------------------------------
+import bisect
+import math
+
+import imgaug
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import BatchSampler, ConcatDataset, Sampler
+
+from .processes import ICDARCollectFN
+
+
+def default_worker_init_fn(worker_id):
+ np.random.seed(worker_id)
+ imgaug.seed(worker_id)
+
+
+class DataLoader(torch.utils.data.DataLoader):
+
+ def __init__(self,
+ dataset,
+ cfg_dataloader,
+ is_train,
+ distributed,
+ drop_last=False,
+ shuffle=None):
+ self.dataset = dataset
+ self.batch_size = cfg_dataloader.batch_size
+ self.num_workers = cfg_dataloader.num_workers
+ self.num_gpus = cfg_dataloader.num_gpus
+ self.is_train = is_train
+ self.drop_last = drop_last
+ self.shuffle = shuffle
+
+ if hasattr(cfg_dataloader, 'collect_fn'
+ ) and cfg_dataloader.collect_fn == 'ICDARCollectFN':
+ self.collect_fn = ICDARCollectFN()
+ else:
+ self.collect_fn = torch.utils.data.dataloader.default_collate
+ if self.shuffle is None:
+ self.shuffle = self.is_train
+
+ if distributed:
+ sampler = DistributedSampler(
+ self.dataset, shuffle=self.shuffle, num_replicas=self.num_gpus)
+ batch_sampler = BatchSampler(sampler,
+ self.batch_size // self.num_gpus,
+ False)
+ torch.utils.data.DataLoader.__init__(
+ self,
+ self.dataset,
+ batch_sampler=batch_sampler,
+ num_workers=self.num_workers,
+ pin_memory=False,
+ drop_last=self.drop_last,
+ collate_fn=self.collect_fn,
+ worker_init_fn=default_worker_init_fn)
+ else:
+ torch.utils.data.DataLoader.__init__(
+ self,
+ self.dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ drop_last=self.drop_last,
+ shuffle=self.shuffle,
+ pin_memory=True,
+ collate_fn=self.collect_fn,
+ worker_init_fn=default_worker_init_fn)
+ self.collect_fn = str(self.collect_fn)
+
+
+class DistributedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+ .. note::
+ Dataset is assumed to be of constant size.
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError(
+ 'Requires distributed package to be available')
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError(
+ 'Requires distributed package to be available')
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = int(
+ math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+ self.shuffle = shuffle
+
+ def __iter__(self):
+ if self.shuffle:
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ indices += indices[:(self.total_size - len(indices))]
+ assert len(indices) == self.total_size
+
+ # subsample
+ offset = self.num_samples * self.rank
+ indices = indices[offset:offset + self.num_samples]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
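Note: single-process sketch of how a trainer might drive this sampler; `num_replicas`/`rank` are passed explicitly so no torch.distributed process group is required, and `set_epoch` re-seeds the permutation so it changes per epoch while staying identical across ranks.

```python
import torch
from torch.utils.data import TensorDataset

from modelscope.msdatasets.dataset_cls.custom_datasets.ocr_detection.data_loader import \
    DistributedSampler

dataset = TensorDataset(torch.arange(10))
sampler = DistributedSampler(dataset, num_replicas=4, rank=1, shuffle=True)

for epoch in range(2):
    sampler.set_epoch(epoch)            # re-seeds the epoch-deterministic shuffle
    indices = list(iter(sampler))
    assert len(indices) == sampler.num_samples  # ceil(10 / 4) == 3
```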
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/image_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/image_dataset.py
new file mode 100644
index 00000000..f5ea2f45
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/image_dataset.py
@@ -0,0 +1,150 @@
+# ------------------------------------------------------------------------------
+# Part of implementation is adopted from DBNet,
+# made publicly available under the Apache License 2.0 at https://github.com/MhLiao/DB.
+# ------------------------------------------------------------------------------
+import bisect
+import functools
+import glob
+import logging
+import math
+import os
+
+import cv2
+import numpy as np
+import torch.utils.data as data
+
+from .processes import (AugmentDetectionData, MakeBorderMap, MakeICDARData,
+ MakeSegDetectionData, NormalizeImage, RandomCropData)
+
+
+class ImageDataset(data.Dataset):
+ r'''Dataset reading from images.
+ '''
+
+ def __init__(self, cfg, data_dir=None, data_list=None, **kwargs):
+ self.data_dir = data_dir
+ self.data_list = data_list
+ if 'train' in self.data_list[0]:
+ self.is_training = True
+ else:
+ self.is_training = False
+ self.image_paths = []
+ self.gt_paths = []
+ self.get_all_samples()
+ self.processes = None
+ if self.is_training and hasattr(cfg.train, 'transform'):
+ self.processes = cfg.train.transform
+ elif not self.is_training and hasattr(cfg.test, 'transform'):
+ self.processes = cfg.test.transform
+
+ def get_all_samples(self):
+ for i in range(len(self.data_dir)):
+ with open(self.data_list[i], 'r') as fid:
+ image_list = fid.readlines()
+ fid.close()
+ if self.is_training:
+ image_path = [
+ self.data_dir[i] + '/train_images/' + timg.strip()
+ for timg in image_list
+ ]
+ gt_path = [
+ self.data_dir[i] + '/train_gts/'
+ + timg.strip().split('.')[0] + '.txt'
+ for timg in image_list
+ ]
+ else:
+ image_path = [
+ self.data_dir[i] + '/test_images/' + timg.strip()
+ for timg in image_list
+ ]
+ gt_path = [
+ self.data_dir[i] + '/test_gts/'
+ + timg.strip().split('.')[0] + '.txt'
+ for timg in image_list
+ ]
+ self.image_paths += image_path
+ self.gt_paths += gt_path
+ self.num_samples = len(self.image_paths)
+ self.targets = self.load_ann()
+ if self.is_training:
+ assert len(self.image_paths) == len(self.targets)
+
+ def load_ann(self):
+ res = []
+ for gt in self.gt_paths:
+ lines = []
+ reader = open(gt, 'r')
+ for line in reader.readlines():
+ item = {}
+ line = line.strip().split(',')
+ label = line[-1]
+ poly = np.array(list(map(float, line[:8]))).reshape(
+ (-1, 2)).tolist()
+ item['poly'] = poly
+ item['text'] = label
+ lines.append(item)
+ reader.close()
+ res.append(lines)
+ return res
+
+ def __getitem__(self, index, retry=0):
+ if index >= self.num_samples:
+ index = index % self.num_samples
+ data = {}
+ image_path = self.image_paths[index]
+ img = cv2.imread(image_path, cv2.IMREAD_COLOR).astype('float32')
+ if self.is_training:
+ data['filename'] = image_path
+ data['data_id'] = image_path
+ else:
+ data['filename'] = image_path.split('/')[-1]
+ data['data_id'] = image_path.split('/')[-1]
+ data['image'] = img
+ target = self.targets[index]
+ data['lines'] = target
+
+ # processing steps applied in sequence, as defined in configuration.json
+ if self.processes is not None:
+ # normal detection augment
+ if hasattr(self.processes, 'detection_augment'):
+ data_process0 = AugmentDetectionData(
+ self.processes.detection_augment)
+ data = data_process0(data)
+
+ # random crop augment
+ if hasattr(self.processes, 'random_crop'):
+ data_process1 = RandomCropData(self.processes.random_crop)
+ data = data_process1(data)
+
+ # build data in ICDAR format
+ if hasattr(self.processes, 'MakeICDARData'):
+ data_process2 = MakeICDARData()
+ data = data_process2(data)
+
+ # make the binary mask from detection data in ICDAR format
+ if hasattr(self.processes, 'MakeSegDetectionData'):
+ data_process3 = MakeSegDetectionData()
+ data = data_process3(data)
+
+ # make the border map from detection data in ICDAR format
+ if hasattr(self.processes, 'MakeBorderMap'):
+ data_process4 = MakeBorderMap()
+ data = data_process4(data)
+
+ # Image Normalization
+ if hasattr(self.processes, 'NormalizeImage'):
+ data_process5 = NormalizeImage()
+ data = data_process5(data)
+
+ if self.is_training:
+ # remove redundant data key for training
+ for key in [
+ 'polygons', 'filename', 'shape', 'ignore_tags',
+ 'is_training'
+ ]:
+ del data[key]
+
+ return data
+
+ def __len__(self):
+ return len(self.image_paths)
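Note: `load_ann` above assumes one polygon per line, eight comma-separated coordinates followed by the text label. A minimal parsing check with made-up values:

```python
import numpy as np

line = '10.0,20.0,110.0,20.0,110.0,60.0,10.0,60.0,HELLO'
parts = line.strip().split(',')
label = parts[-1]
poly = np.array(list(map(float, parts[:8]))).reshape((-1, 2)).tolist()
assert poly == [[10.0, 20.0], [110.0, 20.0], [110.0, 60.0], [10.0, 60.0]]
assert label == 'HELLO'
```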
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/__init__.py
new file mode 100644
index 00000000..c4546f1a
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/__init__.py
@@ -0,0 +1 @@
+from .quad_measurer import QuadMeasurer
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/iou_evaluator.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/iou_evaluator.py
new file mode 100644
index 00000000..86b76b81
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/iou_evaluator.py
@@ -0,0 +1,220 @@
+#!/usr/bin/env python
+from collections import namedtuple
+
+import numpy as np
+from shapely.geometry import Polygon
+
+
+class DetectionIoUEvaluator(object):
+
+ def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
+ self.iou_constraint = iou_constraint
+ self.area_precision_constraint = area_precision_constraint
+
+ def evaluate_image(self, gt, pred):
+
+ def get_union(pD, pG):
+ return Polygon(pD).union(Polygon(pG)).area
+
+ def get_intersection_over_union(pD, pG):
+ return get_intersection(pD, pG) / get_union(pD, pG)
+
+ def get_intersection(pD, pG):
+ return Polygon(pD).intersection(Polygon(pG)).area
+
+ def compute_ap(confList, matchList, numGtCare):
+ correct = 0
+ AP = 0
+ if len(confList) > 0:
+ confList = np.array(confList)
+ matchList = np.array(matchList)
+ sorted_ind = np.argsort(-confList)
+ confList = confList[sorted_ind]
+ matchList = matchList[sorted_ind]
+ for n in range(len(confList)):
+ match = matchList[n]
+ if match:
+ correct += 1
+ AP += float(correct) / (n + 1)
+
+ if numGtCare > 0:
+ AP /= numGtCare
+
+ return AP
+
+ perSampleMetrics = {}
+
+ matchedSum = 0
+
+ numGlobalCareGt = 0
+ numGlobalCareDet = 0
+
+ recall = 0
+ precision = 0
+ hmean = 0
+
+ detMatched = 0
+
+ iouMat = np.empty([1, 1])
+
+ gtPols = []
+ detPols = []
+
+ gtPolPoints = []
+ detPolPoints = []
+
+ # Array of Ground Truth Polygons' keys marked as don't Care
+ gtDontCarePolsNum = []
+ # Array of Detected Polygons' matched with a don't Care GT
+ detDontCarePolsNum = []
+
+ pairs = []
+ detMatchedNums = []
+
+ evaluationLog = ''
+
+ for n in range(len(gt)):
+ points = gt[n]['points']
+ dontCare = gt[n]['ignore']
+
+ if not Polygon(points).is_valid or not Polygon(points).is_simple:
+ continue
+
+ gtPol = points
+ gtPols.append(gtPol)
+ gtPolPoints.append(points)
+ if dontCare:
+ gtDontCarePolsNum.append(len(gtPols) - 1)
+
+ evaluationLog += 'GT polygons: ' + str(len(gtPols)) + (
+ ' (' + str(len(gtDontCarePolsNum))
+ + " don't care)\n" if len(gtDontCarePolsNum) > 0 else '\n')
+
+ for n in range(len(pred)):
+ points = pred[n]['points']
+ if not Polygon(points).is_valid or not Polygon(points).is_simple:
+ continue
+
+ detPol = points
+ detPols.append(detPol)
+ detPolPoints.append(points)
+ if len(gtDontCarePolsNum) > 0:
+ for dontCarePol in gtDontCarePolsNum:
+ dontCarePol = gtPols[dontCarePol]
+ intersected_area = get_intersection(dontCarePol, detPol)
+ pdDimensions = Polygon(detPol).area
+ precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
+ if (precision > self.area_precision_constraint):
+ detDontCarePolsNum.append(len(detPols) - 1)
+ break
+
+ evaluationLog += 'DET polygons: ' + str(len(detPols)) + (
+ ' (' + str(len(detDontCarePolsNum))
+ + " don't care)\n" if len(detDontCarePolsNum) > 0 else '\n')
+
+ if len(gtPols) > 0 and len(detPols) > 0:
+ # Calculate IoU and precision matrices
+ outputShape = [len(gtPols), len(detPols)]
+ iouMat = np.empty(outputShape)
+ gtRectMat = np.zeros(len(gtPols), np.int8)
+ detRectMat = np.zeros(len(detPols), np.int8)
+ for gtNum in range(len(gtPols)):
+ for detNum in range(len(detPols)):
+ pG = gtPols[gtNum]
+ pD = detPols[detNum]
+ iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
+
+ for gtNum in range(len(gtPols)):
+ for detNum in range(len(detPols)):
+ if gtRectMat[gtNum] == 0 and detRectMat[
+ detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
+ if iouMat[gtNum, detNum] > self.iou_constraint:
+ gtRectMat[gtNum] = 1
+ detRectMat[detNum] = 1
+ detMatched += 1
+ pairs.append({'gt': gtNum, 'det': detNum})
+ detMatchedNums.append(detNum)
+ evaluationLog += 'Match GT #' + \
+ str(gtNum) + ' with Det #' + str(detNum) + '\n'
+
+ numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
+ numDetCare = (len(detPols) - len(detDontCarePolsNum))
+ if numGtCare == 0:
+ recall = float(1)
+ precision = float(0) if numDetCare > 0 else float(1)
+ else:
+ recall = float(detMatched) / numGtCare
+ precision = 0 if numDetCare == 0 else float(
+ detMatched) / numDetCare
+
+ hmean = 0 if (precision + recall) == 0 else 2.0 * \
+ precision * recall / (precision + recall)
+
+ matchedSum += detMatched
+ numGlobalCareGt += numGtCare
+ numGlobalCareDet += numDetCare
+
+ perSampleMetrics = {
+ 'precision': precision,
+ 'recall': recall,
+ 'hmean': hmean,
+ 'pairs': pairs,
+ 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
+ 'gtPolPoints': gtPolPoints,
+ 'detPolPoints': detPolPoints,
+ 'gtCare': numGtCare,
+ 'detCare': numDetCare,
+ 'gtDontCare': gtDontCarePolsNum,
+ 'detDontCare': detDontCarePolsNum,
+ 'detMatched': detMatched,
+ 'evaluationLog': evaluationLog
+ }
+
+ return perSampleMetrics
+
+ def combine_results(self, results):
+ numGlobalCareGt = 0
+ numGlobalCareDet = 0
+ matchedSum = 0
+ for result in results:
+ numGlobalCareGt += result['gtCare']
+ numGlobalCareDet += result['detCare']
+ matchedSum += result['detMatched']
+
+ methodRecall = 0 if numGlobalCareGt == 0 else float(
+ matchedSum) / numGlobalCareGt
+ methodPrecision = 0 if numGlobalCareDet == 0 else float(
+ matchedSum) / numGlobalCareDet
+ methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
+ methodRecall * methodPrecision / (methodRecall + methodPrecision)
+
+ methodMetrics = {
+ 'precision': methodPrecision,
+ 'recall': methodRecall,
+ 'hmean': methodHmean
+ }
+
+ return methodMetrics
+
+
+if __name__ == '__main__':
+ evaluator = DetectionIoUEvaluator()
+ gts = [[{
+ 'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
+ 'text': 1234,
+ 'ignore': False,
+ }, {
+ 'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
+ 'text': 5678,
+ 'ignore': False,
+ }]]
+ preds = [[{
+ 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
+ 'text': 123,
+ 'ignore': False,
+ }]]
+ results = []
+ for gt, pred in zip(gts, preds):
+ results.append(evaluator.evaluate_image(gt, pred))
+ metrics = evaluator.combine_results(results)
+ print(metrics)
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/quad_measurer.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/quad_measurer.py
new file mode 100644
index 00000000..0d662305
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/measures/quad_measurer.py
@@ -0,0 +1,98 @@
+import numpy as np
+
+from .iou_evaluator import DetectionIoUEvaluator
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+ return
+
+
+class QuadMeasurer():
+
+ def __init__(self, **kwargs):
+ self.evaluator = DetectionIoUEvaluator()
+
+ def measure(self, batch, output, is_output_polygon=False, box_thresh=0.6):
+ '''
+ batch: a dict produced by dataloaders, with the following keys:
+ image: tensor of shape (N, C, H, W).
+ polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
+ ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
+ shape: the original shape of images.
+ filename: the original filenames of images.
+ output: (polygons, ...)
+ '''
+ results = []
+ gt_polyons_batch = batch['polygons']
+ ignore_tags_batch = batch['ignore_tags']
+ pred_polygons_batch = np.array(output[0])
+ pred_scores_batch = np.array(output[1])
+ for polygons, pred_polygons, pred_scores, ignore_tags in\
+ zip(gt_polyons_batch, pred_polygons_batch, pred_scores_batch, ignore_tags_batch):
+ gt = [
+ dict(points=polygons[i], ignore=ignore_tags[i])
+ for i in range(len(polygons))
+ ]
+ if is_output_polygon:
+ pred = [
+ dict(points=pred_polygons[i])
+ for i in range(len(pred_polygons))
+ ]
+ else:
+ pred = []
+ for i in range(pred_polygons.shape[0]):
+ if pred_scores[i] >= box_thresh:
+ pred.append(
+ dict(
+ points=pred_polygons.reshape(-1, 4, 2)[
+ i, :, :].tolist()))
+ results.append(self.evaluator.evaluate_image(gt, pred))
+ return results
+
+ def validate_measure(self,
+ batch,
+ output,
+ is_output_polygon=False,
+ box_thresh=0.6):
+ return self.measure(batch, output, is_output_polygon, box_thresh)
+
+ def evaluate_measure(self, batch, output):
+ return self.measure(batch, output),\
+ np.linspace(0, batch['image'].shape[0]).tolist()
+
+ def gather_measure(self, raw_metrics):
+ raw_metrics = [
+ image_metrics for batch_metrics in raw_metrics
+ for image_metrics in batch_metrics
+ ]
+
+ result = self.evaluator.combine_results(raw_metrics)
+
+ precision = AverageMeter()
+ recall = AverageMeter()
+ fmeasure = AverageMeter()
+
+ precision.update(result['precision'], n=len(raw_metrics))
+ recall.update(result['recall'], n=len(raw_metrics))
+ fmeasure_score = 2 * precision.val * recall.val /\
+ (precision.val + recall.val + 1e-8)
+ fmeasure.update(fmeasure_score)
+
+ return {'precision': precision, 'recall': recall, 'fmeasure': fmeasure}
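Note: numeric sanity check of the aggregation above; `AverageMeter` keeps a count-weighted running mean and the F-measure is the harmonic mean of precision and recall with an epsilon guard. The import path follows the layout introduced by this patch.

```python
from modelscope.msdatasets.dataset_cls.custom_datasets.ocr_detection.measures.quad_measurer import \
    AverageMeter

meter = AverageMeter()
meter.update(0.8, n=10)   # e.g. precision micro-averaged over 10 images
meter.update(0.6, n=5)    # plus 5 more images
assert abs(meter.avg - (0.8 * 10 + 0.6 * 5) / 15) < 1e-9

precision, recall = 0.8, 0.5
fmeasure = 2 * precision * recall / (precision + recall + 1e-8)  # epsilon guard
assert abs(fmeasure - 2 * 0.8 * 0.5 / 1.3) < 1e-7
```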
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/__init__.py
new file mode 100644
index 00000000..92a3ad7e
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/__init__.py
@@ -0,0 +1,6 @@
+from .augment_data import AugmentData, AugmentDetectionData
+from .make_border_map import MakeBorderMap
+from .make_icdar_data import ICDARCollectFN, MakeICDARData
+from .make_seg_detection_data import MakeSegDetectionData
+from .normalize_image import NormalizeImage
+from .random_crop_data import RandomCropData
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/augment_data.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/augment_data.py
new file mode 100644
index 00000000..316bf84e
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/augment_data.py
@@ -0,0 +1,99 @@
+import math
+
+import cv2
+import imgaug
+import numpy as np
+
+from ..augmenter import AugmenterBuilder
+from .data_process import DataProcess
+
+
+class AugmentData(DataProcess):
+
+ def __init__(self, cfg):
+ self.augmenter_args = cfg.augmenter_args
+ self.keep_ratio = cfg.keep_ratio
+ self.only_resize = cfg.only_resize
+ self.augmenter = AugmenterBuilder().build(self.augmenter_args)
+
+ def may_augment_annotation(self, aug, data, shape):
+ pass
+
+ def resize_image(self, image):
+ origin_height, origin_width, c = image.shape
+ resize_shape = self.augmenter_args[0][1]
+
+ new_height_pad = resize_shape['height']
+ new_width_pad = resize_shape['width']
+ if self.keep_ratio:
+ if origin_height > origin_width:
+ new_height = new_height_pad
+ new_width = int(
+ math.ceil(new_height / origin_height * origin_width / 32)
+ * 32)
+ else:
+ new_width = new_width_pad
+ new_height = int(
+ math.ceil(new_width / origin_width * origin_height / 32)
+ * 32)
+ image = cv2.resize(image, (new_width, new_height))
+
+ else:
+ image = cv2.resize(image, (new_width_pad, new_height_pad))
+
+ return image
+
+ def process(self, data):
+ image = data['image']
+ aug = None
+ shape = image.shape
+ if self.augmenter:
+ aug = self.augmenter.to_deterministic()
+ if self.only_resize:
+ data['image'] = self.resize_image(image)
+ else:
+ data['image'] = aug.augment_image(image)
+ self.may_augment_annotation(aug, data, shape)
+
+ filename = data.get('filename', data.get('data_id', ''))
+ data.update(filename=filename, shape=shape[:2])
+ if not self.only_resize:
+ data['is_training'] = True
+ else:
+ data['is_training'] = False
+ return data
+
+
+class AugmentDetectionData(AugmentData):
+
+ def may_augment_annotation(self, aug: imgaug.augmenters.Augmenter, data,
+ shape):
+ if aug is None:
+ return data
+
+ line_polys = []
+ keypoints = []
+ texts = []
+ new_polys = []
+ for line in data['lines']:
+ texts.append(line['text'])
+ new_poly = []
+ for p in line['poly']:
+ new_poly.append((p[0], p[1]))
+ keypoints.append(imgaug.Keypoint(p[0], p[1]))
+ new_polys.append(new_poly)
+ if not self.only_resize:
+ keypoints = aug.augment_keypoints(
+ [imgaug.KeypointsOnImage(keypoints=keypoints,
+ shape=shape)])[0].keypoints
+ new_polys = np.array([[p.x, p.y]
+ for p in keypoints]).reshape([-1, 4, 2])
+ for i in range(len(texts)):
+ poly = new_polys[i]
+ line_polys.append({
+ 'points': poly,
+ 'ignore': texts[i] == '###',
+ 'text': texts[i]
+ })
+
+ data['polys'] = line_polys
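
The keep-ratio branch of `resize_image` pins the longer side to the configured size and rounds the other side up to the nearest multiple of 32. A small shape-only sketch with made-up dimensions:

```python
import math

def keep_ratio_resize_shape(origin_h, origin_w, target_h, target_w):
    """Shape-only mirror of AugmentData.resize_image's keep_ratio branch."""
    if origin_h > origin_w:
        new_h = target_h
        new_w = int(math.ceil(new_h / origin_h * origin_w / 32) * 32)
    else:
        new_w = target_w
        new_h = int(math.ceil(new_w / origin_w * origin_h / 32) * 32)
    return new_h, new_w

# hypothetical 1280x720 landscape image resized toward a 736x736 target
print(keep_ratio_resize_shape(720, 1280, 736, 736))  # (416, 736)
```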
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/data_process.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/data_process.py
new file mode 100644
index 00000000..8ef7b0f1
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/data_process.py
@@ -0,0 +1,32 @@
+class DataProcess:
+ r'''Base class for processes that operate on the data dict.
+ '''
+
+ def __call__(self, data, **kwargs):
+ return self.process(data, **kwargs)
+
+ def process(self, data, **kwargs):
+ raise NotImplementedError
+
+ def render_constant(self,
+ canvas,
+ xmin,
+ xmax,
+ ymin,
+ ymax,
+ value=1,
+ shrink=0):
+
+ def shrink_rect(xmin, xmax, ratio):
+ center = (xmin + xmax) / 2
+ width = center - xmin
+ return int(center - width * ratio
+ + 0.5), int(center + width * ratio + 0.5)
+
+ if shrink > 0:
+ xmin, xmax = shrink_rect(xmin, xmax, shrink)
+ ymin, ymax = shrink_rect(ymin, ymax, shrink)
+
+ canvas[int(ymin + 0.5):int(ymax + 0.5) + 1,
+ int(xmin + 0.5):int(xmax + 0.5) + 1] = value
+ return canvas
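
A toy illustration of the `shrink_rect` helper inside `render_constant`: with `shrink=0.5` the filled span keeps half its extent around the same center. Canvas size and values are made up.

```python
import numpy as np

# Standalone mirror of the shrink_rect closure in DataProcess.render_constant.
def shrink_rect(vmin, vmax, ratio):
    center = (vmin + vmax) / 2
    half = center - vmin
    return int(center - half * ratio + 0.5), int(center + half * ratio + 0.5)

canvas = np.zeros((8, 8), dtype=np.float32)
xmin, xmax = shrink_rect(0, 7, 0.5)  # (2, 5)
ymin, ymax = shrink_rect(0, 7, 0.5)
canvas[ymin:ymax + 1, xmin:xmax + 1] = 1
print(canvas.sum())  # 16.0 -- a 4x4 block is filled
```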
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_border_map.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_border_map.py
new file mode 100644
index 00000000..bb2466f7
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_border_map.py
@@ -0,0 +1,152 @@
+import cv2
+import numpy as np
+import pyclipper
+from shapely.geometry import Polygon
+
+from .data_process import DataProcess
+
+
+class MakeBorderMap(DataProcess):
+ r'''
+ Makes the border (threshold) map from detection data in ICDAR format.
+ Typically applied after `MakeICDARData`.
+ '''
+
+ def __init__(self, *args, **kwargs):
+ self.shrink_ratio = 0.4
+ self.thresh_min = 0.3
+ self.thresh_max = 0.7
+
+ def process(self, data, *args, **kwargs):
+ r'''
+ required keys:
+ image, polygons, ignore_tags
+ adding keys:
+ thresh_map, thresh_mask
+ '''
+ image = data['image']
+ polygons = data['polygons']
+ ignore_tags = data['ignore_tags']
+ canvas = np.zeros(image.shape[:2], dtype=np.float32)
+ mask = np.zeros(image.shape[:2], dtype=np.float32)
+
+ for i in range(len(polygons)):
+ if ignore_tags[i]:
+ continue
+ self.draw_border_map(polygons[i], canvas, mask=mask)
+ canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
+ data['thresh_map'] = canvas
+ data['thresh_mask'] = mask
+ return data
+
+ def draw_border_map(self, polygon, canvas, mask):
+ polygon = np.array(polygon)
+ assert polygon.ndim == 2
+ assert polygon.shape[1] == 2
+
+ polygon_shape = Polygon(polygon)
+ distance = polygon_shape.area * \
+ (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
+ subject = [tuple(lp) for lp in polygon]
+ padding = pyclipper.PyclipperOffset()
+ padding.AddPath(subject, pyclipper.JT_ROUND,
+ pyclipper.ET_CLOSEDPOLYGON)
+ padded_polygon = np.array(padding.Execute(distance)[0])
+ cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
+
+ xmin = padded_polygon[:, 0].min()
+ xmax = padded_polygon[:, 0].max()
+ ymin = padded_polygon[:, 1].min()
+ ymax = padded_polygon[:, 1].max()
+ width = xmax - xmin + 1
+ height = ymax - ymin + 1
+
+ polygon[:, 0] = polygon[:, 0] - xmin
+ polygon[:, 1] = polygon[:, 1] - ymin
+
+ xs = np.broadcast_to(
+ np.linspace(0, width - 1, num=width).reshape(1, width),
+ (height, width))
+ ys = np.broadcast_to(
+ np.linspace(0, height - 1, num=height).reshape(height, 1),
+ (height, width))
+
+ distance_map = np.zeros((polygon.shape[0], height, width),
+ dtype=np.float32)
+ for i in range(polygon.shape[0]):
+ j = (i + 1) % polygon.shape[0]
+ absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])
+ distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
+ distance_map = distance_map.min(axis=0)
+
+ xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
+ xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
+ ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
+ ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
+ canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
+ 1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
+ xmin_valid - xmin:xmax_valid - xmax + width],
+ canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
+
+ def distance(self, xs, ys, point_1, point_2):
+ '''
+ compute the distance from each point to a line segment
+ ys: coordinates in the first axis
+ xs: coordinates in the second axis
+ point_1, point_2: (x, y), the two endpoints of the segment
+ '''
+ height, width = xs.shape[:2]
+ square_distance_1 = np.square(xs
+ - point_1[0]) + np.square(ys
+ - point_1[1])
+ square_distance_2 = np.square(xs
+ - point_2[0]) + np.square(ys
+ - point_2[1])
+ square_distance = np.square(point_1[0]
+ - point_2[0]) + np.square(point_1[1]
+ - point_2[1])
+
+ cosin = (square_distance - square_distance_1 - square_distance_2) / \
+ (2 * np.sqrt(square_distance_1 * square_distance_2))
+ square_sin = 1 - np.square(cosin)
+ square_sin = np.nan_to_num(square_sin)
+ result = np.sqrt(square_distance_1 * square_distance_2
+ * np.abs(square_sin) / (square_distance + 1e-6))
+
+ result[cosin < 0] = np.sqrt(
+ np.fmin(square_distance_1, square_distance_2))[cosin < 0]
+ # self.extend_line(point_1, point_2, result)
+ return result
+
+ def extend_line(self, point_1, point_2, result):
+ ex_point_1 = (
+ int(
+ round(point_1[0]
+ + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))),
+ int(
+ round(point_1[1]
+ + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio))))
+ cv2.line(
+ result,
+ tuple(ex_point_1),
+ tuple(point_1),
+ 4096.0,
+ 1,
+ lineType=cv2.LINE_AA,
+ shift=0)
+ ex_point_2 = (
+ int(
+ round(point_2[0]
+ + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))),
+ int(
+ round(point_2[1]
+ + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio))))
+ cv2.line(
+ result,
+ tuple(ex_point_2),
+ tuple(point_2),
+ 4096.0,
+ 1,
+ lineType=cv2.LINE_AA,
+ shift=0)
+ return ex_point_1, ex_point_2
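
The `distance` helper above computes, for every grid point, its distance to the segment (point_1, point_2) via the law of cosines, with the `cosin < 0` branch clamping to the nearer endpoint. A self-contained sketch on toy coordinates:

```python
import numpy as np

# Mirror of MakeBorderMap.distance for a single toy point; coordinates are made up.
def point_to_segment_distance(xs, ys, p1, p2):
    d1 = np.square(xs - p1[0]) + np.square(ys - p1[1])
    d2 = np.square(xs - p2[0]) + np.square(ys - p2[1])
    d12 = np.square(p1[0] - p2[0]) + np.square(p1[1] - p2[1])
    cosin = (d12 - d1 - d2) / (2 * np.sqrt(d1 * d2))
    square_sin = np.nan_to_num(1 - np.square(cosin))
    result = np.sqrt(d1 * d2 * np.abs(square_sin) / (d12 + 1e-6))
    result[cosin < 0] = np.sqrt(np.fmin(d1, d2))[cosin < 0]
    return result

xs = np.array([[2.0]])
ys = np.array([[1.0]])
# perpendicular distance from (2, 1) to the segment (0, 0)-(4, 0) is 1
print(point_to_segment_distance(xs, ys, (0.0, 0.0), (4.0, 0.0)))  # [[~1.0]]
```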
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_icdar_data.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_icdar_data.py
new file mode 100644
index 00000000..0bed212d
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_icdar_data.py
@@ -0,0 +1,65 @@
+from collections import OrderedDict
+
+import cv2
+import numpy as np
+import torch
+
+from .data_process import DataProcess
+
+
+class MakeICDARData(DataProcess):
+
+ def __init__(self, debug=False, **kwargs):
+ self.shrink_ratio = 0.4
+ self.debug = debug
+
+ def process(self, data):
+ polygons = []
+ ignore_tags = []
+ annotations = data['polys']
+ for annotation in annotations:
+ polygons.append(np.array(annotation['points']))
+ ignore_tags.append(annotation['ignore'])
+ ignore_tags = np.array(ignore_tags, dtype=np.uint8)
+ filename = data.get('filename', data['data_id'])
+ if self.debug:
+ self.draw_polygons(data['image'], polygons, ignore_tags)
+ shape = np.array(data['shape'])
+ return OrderedDict(
+ image=data['image'],
+ polygons=polygons,
+ ignore_tags=ignore_tags,
+ shape=shape,
+ filename=filename,
+ is_training=data['is_training'])
+
+ def draw_polygons(self, image, polygons, ignore_tags):
+ for i in range(len(polygons)):
+ polygon = polygons[i].reshape(-1, 2).astype(np.int32)
+ ignore = ignore_tags[i]
+ if ignore:
+ color = (255, 0, 0) # depict ignorable polygons in blue
+ else:
+ color = (0, 0, 255) # depict polygons in red
+
+ cv2.polylines(image, [polygon], True, color, 1)
+
+ polylines = staticmethod(draw_polygons)
+
+
+class ICDARCollectFN():
+
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def __call__(self, batch):
+ data_dict = OrderedDict()
+ for sample in batch:
+ for k, v in sample.items():
+ if k not in data_dict:
+ data_dict[k] = []
+ if isinstance(v, np.ndarray):
+ v = torch.from_numpy(v)
+ data_dict[k].append(v)
+ data_dict['image'] = torch.stack(data_dict['image'], 0)
+ return data_dict
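
A usage sketch of `ICDARCollectFN`: only `image` is stacked into a batch tensor, every other key is collected into a plain list. It assumes `ICDARCollectFN` from the hunk above is in scope; the sample dicts are toy data.

```python
import numpy as np

# Toy two-sample batch; assumes ICDARCollectFN from the hunk above is importable.
batch = [
    {'image': np.zeros((3, 8, 8), dtype=np.float32), 'filename': 'a.jpg'},
    {'image': np.ones((3, 8, 8), dtype=np.float32), 'filename': 'b.jpg'},
]

collate = ICDARCollectFN()
out = collate(batch)
print(out['image'].shape)  # torch.Size([2, 3, 8, 8]) -- images are stacked
print(out['filename'])     # ['a.jpg', 'b.jpg']       -- other keys stay as lists
```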
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_seg_detection_data.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_seg_detection_data.py
new file mode 100644
index 00000000..73b6b415
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/make_seg_detection_data.py
@@ -0,0 +1,100 @@
+import cv2
+import numpy as np
+import pyclipper
+from shapely.geometry import Polygon
+
+from .data_process import DataProcess
+
+
+class MakeSegDetectionData(DataProcess):
+ r'''
+ Makes the shrunk binary text mask from detection data in ICDAR format.
+ Typically applied after `MakeICDARData`.
+ '''
+
+ def __init__(self, **kwargs):
+ self.min_text_size = 6
+ self.shrink_ratio = 0.4
+
+ def process(self, data):
+ '''
+ required keys:
+ image, polygons, ignore_tags, filename
+ adding keys:
+ mask
+ '''
+ image = data['image']
+ polygons = data['polygons']
+ ignore_tags = data['ignore_tags']
+ filename = data['filename']
+
+ h, w = image.shape[:2]
+ if data['is_training']:
+ polygons, ignore_tags = self.validate_polygons(
+ polygons, ignore_tags, h, w)
+ gt = np.zeros((1, h, w), dtype=np.float32)
+ mask = np.ones((h, w), dtype=np.float32)
+
+ for i in range(len(polygons)):
+ polygon = polygons[i]
+ height = max(polygon[:, 1]) - min(polygon[:, 1])
+ width = max(polygon[:, 0]) - min(polygon[:, 0])
+ if ignore_tags[i] or min(height, width) < self.min_text_size:
+ cv2.fillPoly(mask,
+ polygon.astype(np.int32)[np.newaxis, :, :], 0)
+ ignore_tags[i] = True
+ else:
+ polygon_shape = Polygon(polygon)
+ distance = polygon_shape.area * \
+ (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
+ subject = [tuple(lp) for lp in polygons[i]]
+ padding = pyclipper.PyclipperOffset()
+ padding.AddPath(subject, pyclipper.JT_ROUND,
+ pyclipper.ET_CLOSEDPOLYGON)
+ shrinked = padding.Execute(-distance)
+ if shrinked == []:
+ cv2.fillPoly(mask,
+ polygon.astype(np.int32)[np.newaxis, :, :], 0)
+ ignore_tags[i] = True
+ continue
+ shrinked = np.array(shrinked[0]).reshape(-1, 2)
+ cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1)
+
+ if filename is None:
+ filename = ''
+ data.update(
+ image=image,
+ polygons=polygons,
+ gt=gt,
+ mask=mask,
+ filename=filename)
+ return data
+
+ def validate_polygons(self, polygons, ignore_tags, h, w):
+ '''
+ polygons (numpy.array, required): of shape (num_instances, num_points, 2)
+ '''
+ if len(polygons) == 0:
+ return polygons, ignore_tags
+ assert len(polygons) == len(ignore_tags)
+ for polygon in polygons:
+ polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
+ polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
+
+ for i in range(len(polygons)):
+ area = self.polygon_area(polygons[i])
+ if abs(area) < 1:
+ ignore_tags[i] = True
+ if area > 0:
+ polygons[i] = polygons[i][::-1, :]
+ return polygons, ignore_tags
+
+ def polygon_area(self, polygon):
+ edge = 0
+ for i in range(polygon.shape[0]):
+ next_index = (i + 1) % polygon.shape[0]
+ edge += (polygon[next_index, 0] - polygon[i, 0]) * (
+ polygon[next_index, 1] + polygon[i, 1])
+
+ return edge / 2.
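
The shrink offset used above follows the DB convention D = area * (1 - r^2) / perimeter with r = 0.4. A standalone sketch on a made-up 100x40 box:

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

# Toy 100x40 box shrunk with the same offset rule as MakeSegDetectionData.
poly = np.array([[0, 0], [100, 0], [100, 40], [0, 40]])
shrink_ratio = 0.4

shape = Polygon(poly)
distance = shape.area * (1 - np.power(shrink_ratio, 2)) / shape.length
offset = pyclipper.PyclipperOffset()
offset.AddPath([(int(x), int(y)) for x, y in poly], pyclipper.JT_ROUND,
               pyclipper.ET_CLOSEDPOLYGON)
shrinked = np.array(offset.Execute(-distance)[0])
print(round(distance, 6))  # 12.0  (area * (1 - 0.4**2) / perimeter)
print(shrinked)            # approximately the 76x16 box spanning (12, 12)-(88, 28)
```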
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/normalize_image.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/normalize_image.py
new file mode 100644
index 00000000..904467fe
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/normalize_image.py
@@ -0,0 +1,25 @@
+import numpy as np
+import torch
+
+from .data_process import DataProcess
+
+
+class NormalizeImage(DataProcess):
+ RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793])
+
+ def process(self, data):
+ assert 'image' in data, '`image` in data is required by this process'
+ image = data['image']
+ image -= self.RGB_MEAN
+ image /= 255.
+ image = torch.from_numpy(image).permute(2, 0, 1).float()
+ data['image'] = image
+ return data
+
+ @classmethod
+ def restore(cls, image):
+ image = image.permute(1, 2, 0).to('cpu').numpy()
+ image = image * 255.
+ image += cls.RGB_MEAN
+ image = image.astype(np.uint8)
+ return image
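
A round-trip sketch of the normalization above (subtract the RGB mean, divide by 255, convert to a CHW float tensor) and its `restore` inverse; the 2x2 image is toy data and float32 rounding may shift a channel value by one.

```python
import numpy as np
import torch

RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793])

# Toy 2x2 RGB image run through the same normalize / restore steps.
image = np.full((2, 2, 3), 128.0)
normalized = torch.from_numpy((image - RGB_MEAN) / 255.).permute(2, 0, 1).float()
restored = (normalized.permute(1, 2, 0).numpy() * 255. + RGB_MEAN).astype(np.uint8)
print(normalized.shape)  # torch.Size([3, 2, 2])
print(restored[0, 0])    # close to [128 128 128]; float32 rounding may shift by 1
```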
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/random_crop_data.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/random_crop_data.py
new file mode 100644
index 00000000..93d7aed0
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_detection/processes/random_crop_data.py
@@ -0,0 +1,146 @@
+import cv2
+import numpy as np
+
+from .data_process import DataProcess
+
+
+# random crop algorithm similar to https://github.com/argman/EAST
+class RandomCropData(DataProcess):
+
+ def __init__(self, cfg):
+ self.size = cfg.size
+ self.max_tries = cfg.max_tries
+ self.min_crop_side_ratio = 0.1
+ self.require_original_image = False
+
+ def process(self, data):
+
+ size = self.size
+
+ img = data['image']
+ ori_img = img
+ ori_lines = data['polys']
+
+ all_care_polys = [
+ line['points'] for line in data['polys'] if not line['ignore']
+ ]
+ crop_x, crop_y, crop_w, crop_h = self.crop_area(img, all_care_polys)
+ scale_w = size[0] / crop_w
+ scale_h = size[1] / crop_h
+ scale = min(scale_w, scale_h)
+ h = int(crop_h * scale)
+ w = int(crop_w * scale)
+ padimg = np.zeros((size[1], size[0], img.shape[2]), img.dtype)
+ padimg[:h, :w] = cv2.resize(
+ img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
+ img = padimg
+
+ lines = []
+ for line in data['polys']:
+ poly_ori = np.array(line['points']) - (crop_x, crop_y)
+ poly = (poly_ori * scale).tolist()
+ if not self.is_poly_outside_rect(poly, 0, 0, w, h):
+ lines.append({**line, 'points': poly})
+ data['polys'] = lines
+
+ if self.require_original_image:
+ data['image'] = ori_img
+ else:
+ data['image'] = img
+ data['lines'] = ori_lines
+ data['scale_w'] = scale
+ data['scale_h'] = scale
+
+ return data
+
+ def is_poly_in_rect(self, poly, x, y, w, h):
+ poly = np.array(poly)
+ if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
+ return False
+ if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
+ return False
+ return True
+
+ def is_poly_outside_rect(self, poly, x, y, w, h):
+ poly = np.array(poly)
+ if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
+ return True
+ if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
+ return True
+ return False
+
+ def split_regions(self, axis):
+ regions = []
+ min_axis = 0
+ for i in range(1, axis.shape[0]):
+ if axis[i] != axis[i - 1] + 1:
+ region = axis[min_axis:i]
+ min_axis = i
+ regions.append(region)
+ return regions
+
+ def random_select(self, axis, max_size):
+ xx = np.random.choice(axis, size=2)
+ xmin = np.min(xx)
+ xmax = np.max(xx)
+ xmin = np.clip(xmin, 0, max_size - 1)
+ xmax = np.clip(xmax, 0, max_size - 1)
+ return xmin, xmax
+
+ def region_wise_random_select(self, regions, max_size):
+ selected_index = list(np.random.choice(len(regions), 2))
+ selected_values = []
+ for index in selected_index:
+ axis = regions[index]
+ xx = int(np.random.choice(axis, size=1))
+ selected_values.append(xx)
+ xmin = min(selected_values)
+ xmax = max(selected_values)
+ return xmin, xmax
+
+ def crop_area(self, img, polys):
+ h, w, _ = img.shape
+ h_array = np.zeros(h, dtype=np.int32)
+ w_array = np.zeros(w, dtype=np.int32)
+ for points in polys:
+ points = np.round(points, decimals=0).astype(np.int32)
+ minx = np.min(points[:, 0])
+ maxx = np.max(points[:, 0])
+ w_array[minx:maxx] = 1
+ miny = np.min(points[:, 1])
+ maxy = np.max(points[:, 1])
+ h_array[miny:maxy] = 1
+ # ensure the cropped area not across a text
+ h_axis = np.where(h_array == 0)[0]
+ w_axis = np.where(w_array == 0)[0]
+
+ if len(h_axis) == 0 or len(w_axis) == 0:
+ return 0, 0, w, h
+
+ h_regions = self.split_regions(h_axis)
+ w_regions = self.split_regions(w_axis)
+
+ for i in range(self.max_tries):
+ if len(w_regions) > 1:
+ xmin, xmax = self.region_wise_random_select(w_regions, w)
+ else:
+ xmin, xmax = self.random_select(w_axis, w)
+ if len(h_regions) > 1:
+ ymin, ymax = self.region_wise_random_select(h_regions, h)
+ else:
+ ymin, ymax = self.random_select(h_axis, h)
+
+ if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h:
+ # area too small
+ continue
+ num_poly_in_rect = 0
+ for poly in polys:
+ if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
+ ymax - ymin):
+ num_poly_in_rect += 1
+ break
+
+ if num_poly_in_rect > 0:
+ return xmin, ymin, xmax - xmin, ymax - ymin
+
+ return 0, 0, w, h
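
`crop_area` marks text-covered rows/columns in `h_array`/`w_array`, then samples crops from the zero (text-free) runs; `split_regions` groups consecutive free indices. A standalone mirror on a made-up axis:

```python
import numpy as np

# Standalone mirror of RandomCropData.split_regions; the axis values are made up.
def split_regions(axis):
    regions, min_axis = [], 0
    for i in range(1, axis.shape[0]):
        if axis[i] != axis[i - 1] + 1:
            regions.append(axis[min_axis:i])
            min_axis = i
    return regions

# columns 0-2 and 6-8 are text-free in this toy example
w_axis = np.array([0, 1, 2, 6, 7, 8])
print(split_regions(w_axis))  # [array([0, 1, 2])] -- only runs that precede a gap are returned
```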
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_recognition_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_recognition_dataset.py
new file mode 100644
index 00000000..bc9cd3ca
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/ocr_recognition_dataset.py
@@ -0,0 +1,75 @@
+import os
+import sys
+
+import cv2
+import json
+import lmdb
+import numpy as np
+import six
+import torch
+from PIL import Image
+
+from modelscope.metainfo import Models
+from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
+ CUSTOM_DATASETS
+from modelscope.msdatasets.dataset_cls.custom_datasets.torch_custom_dataset import \
+ TorchCustomDataset
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+DATASET_STRUCTURE = {'image': 'image', 'label': 'label.txt', 'lmdb': 'lmdb'}
+
+
+def Q2B(uchar):
+ inside_code = ord(uchar)
+ if inside_code == 0x3000:
+ inside_code = 0x0020
+ else:
+ inside_code -= 0xfee0
+ if inside_code < 0x0020 or inside_code > 0x7e:
+ return uchar
+ return chr(inside_code)
+
+
+@CUSTOM_DATASETS.register_module(
+ Tasks.ocr_recognition, module_name=Models.ocr_recognition)
+class OCRRecognitionDataset(TorchCustomDataset):
+
+ def __init__(self, **kwargs):
+ split_config = kwargs['split_config']
+ cache_root = next(iter(split_config.values()))
+ lmdb_path = os.path.join(cache_root, DATASET_STRUCTURE['lmdb'])
+ self.env = lmdb.open(
+ lmdb_path,
+ max_readers=1,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False)
+ if not self.env:
+ print('cannot create lmdb from %s' % lmdb_path)
+ sys.exit(0)
+ self.nSamples = 0
+ with self.env.begin(write=False) as txn:
+ self.nSamples = int(txn.get('num-samples'.encode()))
+ self.reco_preprocess = kwargs['preprocessor']
+
+ def __len__(self):
+ return self.nSamples
+
+ def __getitem__(self, index):
+ index += 1
+ img_key = 'image-%09d' % index
+ with self.env.begin(write=False) as txn:
+ imgbuf = txn.get(img_key.encode())
+ buf = six.BytesIO()
+ buf.write(imgbuf)
+ buf.seek(0)
+ img = Image.open(buf).convert('L')
+ if self.reco_preprocess is not None:
+ img = self.reco_preprocess(img)
+
+ label_key = 'label-%09d' % index
+ label = txn.get(label_key.encode()).decode('utf-8')
+ label = ''.join([Q2B(c) for c in label])
+
+ return {'images': img, 'labels': label}
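
`Q2B` converts full-width (quanjiao) characters in the LMDB labels to their half-width ASCII equivalents and leaves everything else untouched. A standalone copy with a toy label:

```python
# Standalone copy of Q2B for illustration; the sample label is made up.
def Q2B(uchar):
    inside_code = ord(uchar)
    if inside_code == 0x3000:  # ideographic space -> ASCII space
        inside_code = 0x0020
    else:
        inside_code -= 0xfee0
    if inside_code < 0x0020 or inside_code > 0x7e:  # not printable ASCII: keep as-is
        return uchar
    return chr(inside_code)

print(''.join(Q2B(c) for c in 'ＡＢＣ１２３，中文'))  # 'ABC123,中文'
```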
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/reds_image_deblurring_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/reds_image_deblurring_dataset.py
new file mode 100644
index 00000000..826f5e78
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/reds_image_deblurring_dataset.py
@@ -0,0 +1,59 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import cv2
+import numpy as np
+
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
+from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.data_utils import (
+ img2tensor, padding)
+from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising.transforms import (
+ augment, paired_random_crop)
+from modelscope.utils.constant import Tasks
+
+
+def default_loader(path):
+ return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
+
+
+@CUSTOM_DATASETS.register_module(
+ Tasks.image_deblurring, module_name=CustomDatasets.RedsDataset)
+class RedsImageDeblurringDataset(TorchCustomDataset):
+ """Paired image dataset for image restoration.
+ """
+
+ def __init__(self, dataset, opt, is_train):
+ self.dataset = dataset
+ self.opt = opt
+ self.is_train = is_train
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, index):
+ item_dict = self.dataset[index]
+ hq_path = item_dict['HQ Frame:FILE']
+ img_hq = default_loader(hq_path)
+ lq_path = item_dict['LQ Frame:FILE']
+ img_lq = default_loader(lq_path)
+
+ # augmentation for training
+ if self.is_train:
+ gt_size = self.opt.gt_size
+ # padding
+ img_hq, img_lq = padding(img_hq, img_lq, gt_size)
+
+ # random crop
+ img_hq, img_lq = paired_random_crop(
+ img_hq, img_lq, gt_size, scale=1)
+
+ # flip, rotation
+ img_hq, img_lq = augment([img_hq, img_lq], self.opt.use_flip,
+ self.opt.use_rot)
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_hq, img_lq = img2tensor([img_hq, img_lq],
+ bgr2rgb=True,
+ float32=True)
+ return {'input': img_lq, 'target': img_hq}
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/__init__.py
new file mode 100644
index 00000000..7349e494
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .referring_video_object_segmentation_dataset import ReferringVideoObjectSegmentationDataset
+else:
+ _import_structure = {
+ 'referring_video_object_segmentation_dataset':
+ ['ReferringVideoObjectSegmentationDataset'],
+ }
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py
similarity index 98%
rename from modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py
index 8b6d22a4..4493fd96 100644
--- a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/referring_video_object_segmentation_dataset.py
@@ -18,9 +18,8 @@ from tqdm import tqdm
from modelscope.metainfo import Models
from modelscope.models.cv.referring_video_object_segmentation.utils import \
nested_tensor_from_videos_list
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from . import transformers as T
@@ -33,10 +32,10 @@ def get_image_id(video_id, frame_idx, ref_instance_a2d_id):
return image_id
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
Tasks.referring_video_object_segmentation,
module_name=Models.referring_video_object_segmentation)
-class ReferringVideoObjectSegmentationDataset(TorchTaskDataset):
+class ReferringVideoObjectSegmentationDataset(TorchCustomDataset):
def __init__(self, **kwargs):
split_config = kwargs['split_config']
diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py b/modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/transformers.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/referring_video_object_segmentation/transformers.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/referring_video_object_segmentation/transformers.py
diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/__init__.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/__init__.py
diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py b/modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/data_utils.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/data_utils.py
diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py
similarity index 83%
rename from modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py
index 3f0cdae0..64fb8cb3 100644
--- a/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py
@@ -3,10 +3,9 @@
import cv2
import numpy as np
-from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
+from modelscope.metainfo import CustomDatasets
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.utils.constant import Tasks
from .data_utils import img2tensor, padding
from .transforms import augment, paired_random_crop
@@ -16,9 +15,9 @@ def default_loader(path):
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
-@TASK_DATASETS.register_module(
- Tasks.image_denoising, module_name=Models.nafnet)
-class SiddImageDenoisingDataset(TorchTaskDataset):
+@CUSTOM_DATASETS.register_module(
+ Tasks.image_denoising, module_name=CustomDatasets.SiddDataset)
+class SiddImageDenoisingDataset(TorchCustomDataset):
"""Paired image dataset for image restoration.
"""
diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py b/modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/transforms.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/sidd_image_denoising/transforms.py
diff --git a/modelscope/msdatasets/task_datasets/text_ranking_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/text_ranking_dataset.py
similarity index 92%
rename from modelscope/msdatasets/task_datasets/text_ranking_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/text_ranking_dataset.py
index 19f07110..46c64bbf 100644
--- a/modelscope/msdatasets/task_datasets/text_ranking_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/text_ranking_dataset.py
@@ -1,25 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
-from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, List, Union
import torch
-from datasets import Dataset, IterableDataset, concatenate_datasets
from torch.utils.data import ConcatDataset
-from transformers import DataCollatorWithPadding
from modelscope.metainfo import Models
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.utils.constant import ModeKeys, Tasks
-from .base import TaskDataset
-from .builder import TASK_DATASETS
-from .torch_base_dataset import TorchTaskDataset
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
group_key=Tasks.text_ranking, module_name=Models.bert)
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
group_key=Tasks.sentence_embedding, module_name=Models.bert)
-class TextRankingDataset(TorchTaskDataset):
+class TextRankingDataset(TorchCustomDataset):
def __init__(self,
datasets: Union[Any, List[Any]],
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/torch_custom_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/torch_custom_dataset.py
new file mode 100644
index 00000000..54ad55b7
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/torch_custom_dataset.py
@@ -0,0 +1,51 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, List, Union
+
+import torch.utils.data
+from torch.utils.data import ConcatDataset as TorchConcatDataset
+
+from modelscope.utils.constant import ModeKeys
+
+
+class TorchCustomDataset(torch.utils.data.Dataset):
+ """The custom dataset base class for all the torch-based task processors.
+ """
+
+ def __init__(self,
+ datasets: Union[Any, List[Any]],
+ mode=ModeKeys.TRAIN,
+ preprocessor=None,
+ **kwargs):
+ self.trainer = None
+ self.mode = mode
+ self.preprocessor = preprocessor
+ self._inner_dataset = self.prepare_dataset(datasets)
+
+ def __getitem__(self, index) -> Any:
+ return self.preprocessor(
+ self._inner_dataset[index]
+ ) if self.preprocessor else self._inner_dataset[index]
+
+ def __len__(self):
+ return len(self._inner_dataset)
+
+ def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any:
+ """Prepare a dataset.
+
+ User can process the input datasets in a whole dataset perspective.
+ This method gives a default implementation of datasets merging, user can override this
+ method to write custom logics.
+
+ Args:
+ datasets: The original dataset(s)
+
+ Returns: A single dataset, which may be created after merging.
+
+ """
+ if isinstance(datasets, List):
+ if len(datasets) == 1:
+ return datasets[0]
+ elif len(datasets) > 1:
+ return TorchConcatDataset(datasets)
+ else:
+ return datasets
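
A sketch of `prepare_dataset`'s merging behavior, with toy `TensorDataset`s standing in for real custom datasets: a single-element list is unwrapped, a longer list is wrapped in `ConcatDataset`.

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Toy datasets standing in for real custom datasets.
ds_a = TensorDataset(torch.arange(4))
ds_b = TensorDataset(torch.arange(6))

merged = ConcatDataset([ds_a, ds_b])  # what prepare_dataset returns for len > 1
print(len(merged))                    # 10
print(merged[5])                      # (tensor(1),) -- index 1 of the second dataset
```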
diff --git a/modelscope/msdatasets/task_datasets/veco_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/veco_dataset.py
similarity index 91%
rename from modelscope/msdatasets/task_datasets/veco_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/veco_dataset.py
index df7c6483..047849bc 100644
--- a/modelscope/msdatasets/task_datasets/veco_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/veco_dataset.py
@@ -5,13 +5,13 @@ import numpy as np
from datasets import Dataset, IterableDataset, concatenate_datasets
from modelscope.metainfo import Models
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.utils.constant import Tasks
-from .builder import TASK_DATASETS
-from .torch_base_dataset import TorchTaskDataset
-@TASK_DATASETS.register_module(module_name=Models.veco, group_key=Tasks.nli)
-class VecoDataset(TorchTaskDataset):
+@CUSTOM_DATASETS.register_module(module_name=Models.veco, group_key=Tasks.nli)
+class VecoDataset(TorchCustomDataset):
def __init__(self,
datasets: Union[Any, List[Any]],
diff --git a/modelscope/msdatasets/task_datasets/video_frame_interpolation/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/__init__.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/video_frame_interpolation/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/__init__.py
diff --git a/modelscope/msdatasets/task_datasets/video_frame_interpolation/data_utils.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/data_utils.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/video_frame_interpolation/data_utils.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/data_utils.py
diff --git a/modelscope/msdatasets/task_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py
similarity index 79%
rename from modelscope/msdatasets/task_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py
index 44b965a7..6f47906d 100644
--- a/modelscope/msdatasets/task_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py
@@ -1,16 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from collections import defaultdict
-
import cv2
import numpy as np
import torch
from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
-from modelscope.msdatasets.task_datasets.video_frame_interpolation.data_utils import (
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
+from modelscope.msdatasets.dataset_cls.custom_datasets.video_frame_interpolation.data_utils import (
img2tensor, img_padding)
from modelscope.utils.constant import Tasks
@@ -19,10 +16,10 @@ def default_loader(path):
return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32)
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
Tasks.video_frame_interpolation,
module_name=Models.video_frame_interpolation)
-class VideoFrameInterpolationDataset(TorchTaskDataset):
+class VideoFrameInterpolationDataset(TorchCustomDataset):
"""Dataset for video frame-interpolation.
"""
diff --git a/modelscope/msdatasets/task_datasets/video_stabilization/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_stabilization/__init__.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/video_stabilization/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_stabilization/__init__.py
diff --git a/modelscope/msdatasets/task_datasets/video_stabilization/video_stabilization_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_stabilization/video_stabilization_dataset.py
similarity index 71%
rename from modelscope/msdatasets/task_datasets/video_stabilization/video_stabilization_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_stabilization/video_stabilization_dataset.py
index b0e6bdef..a0e0604c 100644
--- a/modelscope/msdatasets/task_datasets/video_stabilization/video_stabilization_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/video_stabilization/video_stabilization_dataset.py
@@ -1,15 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.utils.constant import Tasks
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
Tasks.video_stabilization, module_name=Models.video_stabilization)
-class VideoStabilizationDataset(TorchTaskDataset):
+class VideoStabilizationDataset(TorchCustomDataset):
"""Paired video dataset for video stabilization.
"""
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/video_summarization_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_summarization_dataset.py
new file mode 100644
index 00000000..4d6e0155
--- /dev/null
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/video_summarization_dataset.py
@@ -0,0 +1,72 @@
+# Part of the implementation is borrowed and modified from PGL-SUM,
+# publicly available at https://github.com/e-apostolidis/PGL-SUM
+
+import os
+
+import h5py
+import json
+import numpy as np
+import torch
+
+from modelscope.msdatasets.dataset_cls.custom_datasets import \
+ TorchCustomDataset
+
+
+class VideoSummarizationDataset(TorchCustomDataset):
+
+ def __init__(self, mode, opt, root_dir):
+ self.mode = mode
+ self.data_filename = os.path.join(root_dir, opt.dataset_file)
+ self.split_filename = os.path.join(root_dir, opt.split_file)
+ self.split_index = opt.split_index
+ hdf = h5py.File(self.data_filename, 'r')
+ self.list_frame_features, self.list_gtscores = [], []
+ self.list_user_summary = []
+ self.list_change_points = []
+ self.list_n_frames = []
+ self.list_positions = []
+
+ with open(self.split_filename, encoding='utf-8') as f:
+ data = json.loads(f.read())
+ for i, split in enumerate(data):
+ if i == self.split_index:
+ self.split = split
+ break
+
+ for video_name in self.split[self.mode + '_keys']:
+ frame_features = torch.Tensor(
+ np.array(hdf[video_name + '/features']))
+ gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore']))
+ user_summary = np.array(hdf[f'{video_name}/user_summary'])
+ change_points = np.array(hdf[f'{video_name}/change_points'])
+ n_frames = np.array(hdf[f'{video_name}/n_frames'])
+ positions = np.array(hdf[f'{video_name}/picks'])
+
+ self.list_frame_features.append(frame_features)
+ self.list_gtscores.append(gtscore)
+ self.list_user_summary.append(user_summary)
+ self.list_change_points.append(change_points)
+ self.list_n_frames.append(n_frames)
+ self.list_positions.append(positions)
+
+ hdf.close()
+
+ def __len__(self):
+ self.len = len(self.split[self.mode + '_keys'])
+ return self.len
+
+ def __getitem__(self, index):
+ frame_features = self.list_frame_features[index]
+ gtscore = self.list_gtscores[index]
+ user_summary = self.list_user_summary[index]
+ change_points = self.list_change_points[index]
+ n_frames = self.list_n_frames[index]
+ positions = self.list_positions[index]
+
+ return dict(
+ frame_features=frame_features,
+ gtscore=gtscore,
+ user_summary=user_summary,
+ change_points=change_points,
+ n_frames=n_frames,
+ positions=positions)
diff --git a/modelscope/msdatasets/task_datasets/video_super_resolution/__init__.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_super_resolution/__init__.py
similarity index 100%
rename from modelscope/msdatasets/task_datasets/video_super_resolution/__init__.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_super_resolution/__init__.py
diff --git a/modelscope/msdatasets/task_datasets/video_super_resolution/video_super_resolution_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/video_super_resolution/video_super_resolution_dataset.py
similarity index 89%
rename from modelscope/msdatasets/task_datasets/video_super_resolution/video_super_resolution_dataset.py
rename to modelscope/msdatasets/dataset_cls/custom_datasets/video_super_resolution/video_super_resolution_dataset.py
index 69faa527..86e07db1 100644
--- a/modelscope/msdatasets/task_datasets/video_super_resolution/video_super_resolution_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/video_super_resolution/video_super_resolution_dataset.py
@@ -7,9 +7,8 @@ import numpy as np
import torch
from modelscope.metainfo import Models
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import (
+ CUSTOM_DATASETS, TorchCustomDataset)
from modelscope.utils.constant import Tasks
@@ -42,9 +41,9 @@ def img2tensor(imgs, bgr2rgb=True, float32=True):
return _totensor(imgs, bgr2rgb, float32)
-@TASK_DATASETS.register_module(
+@CUSTOM_DATASETS.register_module(
Tasks.video_super_resolution, module_name=Models.real_basicvsr)
-class VideoSuperResolutionDataset(TorchTaskDataset):
+class VideoSuperResolutionDataset(TorchCustomDataset):
"""single video dataset for video super-resolution.
"""
diff --git a/modelscope/msdatasets/dataset_cls/dataset.py b/modelscope/msdatasets/dataset_cls/dataset.py
index 57ee8150..4acf51b1 100644
--- a/modelscope/msdatasets/dataset_cls/dataset.py
+++ b/modelscope/msdatasets/dataset_cls/dataset.py
@@ -14,15 +14,19 @@ logger = get_logger()
class ExternalDataset(object):
+ """Dataset class for custom datasets."""
def __init__(self, split_path_dict, config_kwargs):
self.split_path_dict = split_path_dict
self.config_kwargs = copy.deepcopy(config_kwargs)
self.config_kwargs.update({'split_config': split_path_dict})
- self.ext_dataset = None
+ # dataset for specific extensions
+ self.spec_extension_dataset = None
self.split_data_files = {k: [] for k, _ in split_path_dict.items()}
- file_ext = ''
+ self.custom_map = {}
+ # the extension of file
+ file_ext = ''
for split_name, split_dir in split_path_dict.items():
if isinstance(split_dir, str) and os.path.isdir(split_dir):
split_file_names = os.listdir(split_dir)
@@ -52,25 +56,27 @@ class ExternalDataset(object):
if file_ext:
file_ext = EXTENSIONS_TO_LOAD.get(file_ext)
- self.ext_dataset = datasets.load_dataset(
+ self.spec_extension_dataset = datasets.load_dataset(
file_ext, data_files=self.split_data_files, **config_kwargs)
def __len__(self):
- return len(self.split_path_dict
- ) if not self.ext_dataset else self.ext_dataset.__len__()
+ return len(
+ self.split_path_dict
+ ) if not self.spec_extension_dataset else self.spec_extension_dataset.__len__(
+ )
def __getitem__(self, item):
- if not self.ext_dataset:
+ if not self.spec_extension_dataset:
return self.split_path_dict.get(item)
else:
- return self.ext_dataset.__getitem__(item)
+ return self.spec_extension_dataset.__getitem__(item)
def __iter__(self):
- if not self.ext_dataset:
+ if not self.spec_extension_dataset:
for k, v in self.split_path_dict.items():
yield k, v
else:
- for k, v in self.ext_dataset.items():
+ for k, v in self.spec_extension_dataset.items():
yield k, v
@@ -99,3 +105,6 @@ class NativeIterableDataset(IterableDataset):
entity = ret
yield entity
+
+ def __len__(self):
+ return 1
diff --git a/modelscope/msdatasets/meta/data_meta_config.py b/modelscope/msdatasets/meta/data_meta_config.py
index 401a8e14..7f97108b 100644
--- a/modelscope/msdatasets/meta/data_meta_config.py
+++ b/modelscope/msdatasets/meta/data_meta_config.py
@@ -2,7 +2,35 @@
class DataMetaConfig(object):
- """Modelscope data-meta config class."""
+ """Modelscope data-meta config class.
+
+ Attributes:
+ dataset_scripts(str): The local path of dataset scripts.
+ dataset_formation(:obj:`enum.Enum`): Dataset formation, refer to modelscope.utils.constant.DatasetFormations.
+ meta_cache_dir(str): Meta cache path.
+ meta_data_files(dict): Meta data mapping, Example: {'test': 'https://xxx/mytest.csv'}
+ zip_data_files(dict): Data files mapping, Example: {'test': 'pictures.zip'}
+ meta_args_map(dict): Meta arguments mapping, Example: {'test': {'file': 'pictures.zip'}, ...}
+ target_dataset_structure(dict): Dataset Structure, like
+ {
+ "default":{
+ "train":{
+ "meta":"my_train.csv",
+ "file":"pictures.zip"
+ }
+ },
+ "subsetA":{
+ "test":{
+ "meta":"mytest.csv",
+ "file":"pictures.zip"
+ }
+ }
+ }
+ dataset_py_script(str): The python script path of dataset.
+ meta_type_map(dict): The custom dataset mapping in meta data,
+ Example: {"type": "MovieSceneSegmentationCustomDataset",
+ "preprocessor": "movie-scene-segmentation-preprocessor"}
+ """
def __init__(self):
self.dataset_scripts = None
@@ -13,3 +41,4 @@ class DataMetaConfig(object):
self.meta_args_map = None
self.target_dataset_structure = None
self.dataset_py_script = None
+ self.meta_type_map = {}
diff --git a/modelscope/msdatasets/meta/data_meta_manager.py b/modelscope/msdatasets/meta/data_meta_manager.py
index bba46e84..d90b8d5e 100644
--- a/modelscope/msdatasets/meta/data_meta_manager.py
+++ b/modelscope/msdatasets/meta/data_meta_manager.py
@@ -75,7 +75,7 @@ class DataMetaManager(object):
elif download_mode == DownloadMode.FORCE_REDOWNLOAD:
# Clean meta-files
if os.path.exists(meta_cache_dir) and os.listdir(meta_cache_dir):
- shutil.rmtree(meta_cache_dir)
+ shutil.rmtree(meta_cache_dir, ignore_errors=True)
# Re-download meta-files
with FileLock(lock_file=lock_file_path):
os.makedirs(meta_cache_dir, exist_ok=True)
@@ -129,12 +129,13 @@ class DataMetaManager(object):
else:
target_subset_name, target_dataset_structure = get_target_dataset_structure(
dataset_json, subset_name, split)
- meta_map, file_map, args_map = get_dataset_files(
+ meta_map, file_map, args_map, type_map = get_dataset_files(
target_dataset_structure, dataset_name, namespace, version)
data_meta_config.meta_data_files = meta_map
data_meta_config.zip_data_files = file_map
data_meta_config.meta_args_map = args_map
+ data_meta_config.meta_type_map = type_map
data_meta_config.target_dataset_structure = target_dataset_structure
self.dataset_context_config.data_meta_config = data_meta_config
diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py
index e4948310..06f47874 100644
--- a/modelscope/msdatasets/ms_dataset.py
+++ b/modelscope/msdatasets/ms_dataset.py
@@ -16,19 +16,27 @@ from modelscope.msdatasets.context.dataset_context_config import \
from modelscope.msdatasets.data_loader.data_loader_manager import (
LocalDataLoaderManager, LocalDataLoaderType, RemoteDataLoaderManager,
RemoteDataLoaderType)
+from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
+ build_custom_dataset
from modelscope.msdatasets.dataset_cls.dataset import (ExternalDataset,
NativeIterableDataset)
-from modelscope.msdatasets.task_datasets.builder import build_task_dataset
from modelscope.msdatasets.utils.delete_utils import DatasetDeleteManager
from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager
-from modelscope.utils.config import ConfigDict
+from modelscope.preprocessors import build_preprocessor
+from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.config_ds import MS_DATASETS_CACHE
from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
- DEFAULT_DATASET_REVISION, DownloadMode,
- Hubs, UploadMode)
+ DEFAULT_DATASET_REVISION, ConfigFields,
+ DownloadMode, Hubs, ModeKeys, Tasks,
+ UploadMode)
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from modelscope.utils.logger import get_logger
+try:
+ from tensorflow.data import Dataset as TfDataset
+except Exception as e:
+ print(e)
+
logger = get_logger()
@@ -53,6 +61,7 @@ class MsDataset:
"""
# the underlying huggingface Dataset
_hf_ds = None
+ _dataset_context_config: DatasetContextConfig = None
def __init__(self,
ds_instance: Union[Dataset, IterableDataset, ExternalDataset],
@@ -63,6 +72,7 @@ class MsDataset:
f'"target" must be a column of the dataset({list(self._hf_ds.features.keys())}, but got {target}'
)
self.target = target
+ self.is_custom = False
def __iter__(self):
for item in self._hf_ds:
@@ -77,10 +87,10 @@ class MsDataset:
def __len__(self):
if isinstance(self._hf_ds, IterableDataset) or isinstance(
self._hf_ds, NativeIterableDataset):
- logger.error(
- f'object of type `{self._hf_ds.__class__.__name__}` has no __len__()'
+ logger.warning(
+ f'object of type `{self._hf_ds.__class__.__name__}` has default length 1'
)
- return None
+ return 1
return len(self._hf_ds)
@property
@@ -163,6 +173,7 @@ class MsDataset:
REUSE_DATASET_IF_EXISTS,
cache_dir: Optional[str] = MS_DATASETS_CACHE,
use_streaming: Optional[bool] = False,
+ custom_cfg: Optional[Config] = Config(),
**config_kwargs,
) -> Union[dict, 'MsDataset', NativeIterableDataset]:
"""Load a MsDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset.
@@ -191,6 +202,8 @@ class MsDataset:
use_streaming (bool, Optional): If set to True, no need to download all data files.
Instead, it streams the data progressively, and returns
NativeIterableDataset or a dict of NativeIterableDataset.
+ custom_cfg (Config, Optional): Model configuration; it can be used to build custom datasets.
+ see https://modelscope.cn/docs/Configuration%E8%AF%A6%E8%A7%A3
**config_kwargs (additional keyword arguments): Keyword arguments to be passed
Returns:
@@ -245,305 +258,44 @@ class MsDataset:
dataset_inst = LocalDataLoaderManager(
dataset_context_config).load_dataset(
LocalDataLoaderType.HF_DATA_LOADER)
- return MsDataset.to_ms_dataset(dataset_inst, target=target)
+ dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target)
+ if isinstance(dataset_inst, MsDataset):
+ dataset_inst._dataset_context_config = dataset_context_config
+ if custom_cfg:
+ dataset_inst.to_custom_dataset(
+ custom_cfg=custom_cfg, **config_kwargs)
+ dataset_inst.is_custom = True
+ return dataset_inst
# Load from the huggingface hub
elif hub == Hubs.huggingface:
dataset_inst = RemoteDataLoaderManager(
dataset_context_config).load_dataset(
RemoteDataLoaderType.HF_DATA_LOADER)
- return MsDataset.to_ms_dataset(dataset_inst, target=target)
+ dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target)
+ dataset_inst._dataset_context_config = dataset_context_config
+ if custom_cfg:
+ dataset_inst.to_custom_dataset(
+ custom_cfg=custom_cfg, **config_kwargs)
+ dataset_inst.is_custom = True
+ return dataset_inst
# Load from the modelscope hub
elif hub == Hubs.modelscope:
- dataset_inst = RemoteDataLoaderManager(
- dataset_context_config).load_dataset(
- RemoteDataLoaderType.MS_DATA_LOADER)
- return MsDataset.to_ms_dataset(dataset_inst, target=target)
+ remote_dataloader_manager = RemoteDataLoaderManager(
+ dataset_context_config)
+ dataset_inst = remote_dataloader_manager.load_dataset(
+ RemoteDataLoaderType.MS_DATA_LOADER)
+ dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target)
+ if isinstance(dataset_inst, MsDataset):
+ dataset_inst._dataset_context_config = remote_dataloader_manager.dataset_context_config
+ if custom_cfg:
+ dataset_inst.to_custom_dataset(
+ custom_cfg=custom_cfg, **config_kwargs)
+ dataset_inst.is_custom = True
+ return dataset_inst
else:
raise 'Please adjust input args to specify a loading mode, we support following scenes: ' \
'loading from local disk, huggingface hub and modelscope hub.'
- def to_torch_dataset_with_processors(
- self,
- preprocessors: Union[Callable, List[Callable]],
- columns: Union[str, List[str]] = None,
- to_tensor: bool = True,
- ):
- import torch
- preprocessor_list = preprocessors if isinstance(
- preprocessors, list) else [preprocessors]
-
- columns = format_list(columns)
-
- columns = [
- key for key in self._hf_ds.features.keys() if key in columns
- ]
- retained_columns = []
- if to_tensor:
- sample = next(iter(self._hf_ds))
-
- sample_res = {k: np.array(sample[k]) for k in columns}
- for processor in preprocessor_list:
- sample_res.update(
- {k: np.array(v)
- for k, v in processor(sample).items()})
-
- def is_numpy_number(value):
- return np.issubdtype(value.dtype, np.integer) or np.issubdtype(
- value.dtype, np.floating)
-
- for k in sample_res.keys():
- if not is_numpy_number(sample_res[k]):
- logger.warning(
- f'Data of column {k} is non-numeric, will be removed')
- continue
- retained_columns.append(k)
-
- class MsMapDataset(torch.utils.data.Dataset):
-
- def __init__(self, dataset: Iterable, preprocessor_list,
- retained_columns, columns, to_tensor):
- super(MsDataset).__init__()
- self.dataset = dataset
- self.preprocessor_list = preprocessor_list
- self.to_tensor = to_tensor
- self.retained_columns = retained_columns
- self.columns = columns
-
- def __len__(self):
- return len(self.dataset)
-
- def type_converter(self, x):
- import torch
- if self.to_tensor:
- return torch.tensor(x)
- else:
- return x
-
- def __getitem__(self, index):
- item_dict = self.dataset[index]
- res = {
- k: self.type_converter(item_dict[k])
- for k in self.columns
- if (not self.to_tensor) or k in self.retained_columns
- }
- for preprocessor in self.preprocessor_list:
- res.update({
- k: self.type_converter(v)
- for k, v in preprocessor(item_dict).items()
- if (not self.to_tensor) or k in self.retained_columns
- })
- return res
-
- return MsMapDataset(self._hf_ds, preprocessor_list, retained_columns,
- columns, to_tensor)
-
- def to_torch_dataset(
- self,
- columns: Union[str, List[str]] = None,
- preprocessors: Union[Callable, List[Callable]] = None,
- task_name: str = None,
- task_data_config: ConfigDict = None,
- to_tensor: bool = True,
- **format_kwargs,
- ):
- """Create a torch.utils.data.Dataset from the MS Dataset. The torch.utils.data.Dataset can be passed to
- torch.utils.data.DataLoader.
-
- Args:
- preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process
- every sample of the dataset. The output type of processors is dict, and each (numeric) field of the dict
- will be used as a field of torch.utils.data.Dataset.
- columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only if
- `to_tensor` is True). If the preprocessor is None, the arg columns must have at least one column.
- If the `preprocessors` is not None, the output fields of processors will also be added.
- task_name (str, default None): task name, refer to :obj:`Tasks` for more details
- task_data_config (ConfigDict, default None): config dict for model object.
- to_tensor (bool, default None): whether convert the data types of dataset column(s) to torch.tensor or not.
- format_kwargs: A `dict` of arguments to be passed to the `torch.tensor`.
-
- Returns:
- :class:`tf.data.Dataset`
-
- """
- if not is_torch_available():
- raise ImportError(
- 'The function to_torch_dataset requires pytorch to be installed'
- )
- if isinstance(self._hf_ds, ExternalDataset):
- task_data_config.update({'preprocessor': preprocessors})
- task_data_config.update(self._hf_ds.config_kwargs)
- return build_task_dataset(task_data_config, task_name)
- if preprocessors is not None:
- return self.to_torch_dataset_with_processors(
- preprocessors, columns=columns, to_tensor=to_tensor)
- else:
- self._hf_ds.reset_format()
- self._hf_ds.set_format(
- type='torch', columns=columns, format_kwargs=format_kwargs)
- return self._hf_ds
-
- def to_tf_dataset_with_processors(
- self,
- batch_size: int,
- shuffle: bool,
- preprocessors: Union[Callable, List[Callable]],
- drop_remainder: bool = None,
- prefetch: bool = True,
- label_cols: Union[str, List[str]] = None,
- columns: Union[str, List[str]] = None,
- ):
- preprocessor_list = preprocessors if isinstance(
- preprocessors, list) else [preprocessors]
-
- label_cols = format_list(label_cols)
- columns = format_list(columns)
- cols_to_retain = list(set(label_cols + columns))
- retained_columns = [
- key for key in self._hf_ds.features.keys() if key in cols_to_retain
- ]
- import tensorflow as tf
- tf_dataset = tf.data.Dataset.from_tensor_slices(
- np.arange(len(self._hf_ds), dtype=np.int64))
- if shuffle:
- tf_dataset = tf_dataset.shuffle(buffer_size=len(self._hf_ds))
-
- def func(i, return_dict=False):
- i = int(i)
- res = {k: np.array(self._hf_ds[i][k]) for k in retained_columns}
- for preprocessor in preprocessor_list:
- # TODO preprocessor output may have the same key
- res.update({
- k: np.array(v)
- for k, v in preprocessor(self._hf_ds[i]).items()
- })
- if return_dict:
- return res
- return tuple(list(res.values()))
-
- sample_res = func(0, True)
-
- @tf.function(input_signature=[tf.TensorSpec(None, tf.int64)])
- def fetch_function(i):
- output = tf.numpy_function(
- func,
- inp=[i],
- Tout=[
- tf.dtypes.as_dtype(val.dtype)
- for val in sample_res.values()
- ],
- )
- return {key: output[i] for i, key in enumerate(sample_res)}
-
- from tensorflow.data.experimental import AUTOTUNE
- tf_dataset = tf_dataset.map(
- fetch_function, num_parallel_calls=AUTOTUNE)
- if label_cols:
-
- def split_features_and_labels(input_batch):
- labels = {
- key: tensor
- for key, tensor in input_batch.items() if key in label_cols
- }
- if len(input_batch) == 1:
- input_batch = next(iter(input_batch.values()))
- if len(labels) == 1:
- labels = next(iter(labels.values()))
- return input_batch, labels
-
- tf_dataset = tf_dataset.map(split_features_and_labels)
-
- elif len(columns) == 1:
- tf_dataset = tf_dataset.map(lambda x: next(iter(x.values())))
- if batch_size > 1:
- tf_dataset = tf_dataset.batch(
- batch_size, drop_remainder=drop_remainder)
-
- if prefetch:
- tf_dataset = tf_dataset.prefetch(AUTOTUNE)
- return tf_dataset
-
- def to_tf_dataset(
- self,
- batch_size: int,
- shuffle: bool,
- preprocessors: Union[Callable, List[Callable]] = None,
- columns: Union[str, List[str]] = None,
- collate_fn: Callable = None,
- drop_remainder: bool = None,
- collate_fn_args: Dict[str, Any] = None,
- label_cols: Union[str, List[str]] = None,
- prefetch: bool = True,
- ):
- """Create a tf.data.Dataset from the MS Dataset. This tf.data.Dataset can be passed to tf methods like
- model.fit() or model.predict().
-
- Args:
- batch_size (int): Number of samples in a single batch.
- shuffle(bool): Shuffle the dataset order.
- preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process
- every sample of the dataset. The output type of processors is dict, and each field of the dict will be
- used as a field of the tf.data. Dataset. If the `preprocessors` is None, the `collate_fn`
- shouldn't be None.
- columns (str or List[str], default None): Dataset column(s) to be loaded. If the preprocessor is None,
- the arg columns must have at least one column. If the `preprocessors` is not None, the output fields of
- processors will also be added.
- collate_fn(Callable, default None): A callable object used to collect lists of samples into a batch. If
- the `preprocessors` is None, the `collate_fn` shouldn't be None.
- drop_remainder(bool, default None): Drop the last incomplete batch when loading.
- collate_fn_args (Dict, optional): A `dict` of arguments to be passed to the`collate_fn`.
- label_cols (str or List[str], defalut None): Dataset column(s) to load as labels.
- prefetch (bool, default True): Prefetch data.
-
- Returns:
- :class:`tf.data.Dataset`
-
- """
- if not is_tf_available():
- raise ImportError(
- 'The function to_tf_dataset requires Tensorflow to be installed.'
- )
- if preprocessors is not None:
- return self.to_tf_dataset_with_processors(
- batch_size,
- shuffle,
- preprocessors,
- drop_remainder=drop_remainder,
- prefetch=prefetch,
- label_cols=label_cols,
- columns=columns)
-
- if collate_fn is None:
- logger.error(
- 'The `preprocessors` and the `collate_fn` should`t be both None.'
- )
- return None
- self._hf_ds.reset_format()
- return self._hf_ds.to_tf_dataset(
- columns,
- batch_size,
- shuffle,
- collate_fn,
- drop_remainder=drop_remainder,
- collate_fn_args=collate_fn_args,
- label_cols=label_cols,
- prefetch=prefetch)
-
- def to_hf_dataset(self) -> Dataset:
- self._hf_ds.reset_format()
- return self._hf_ds
-
- def remap_columns(self, column_mapping: Dict[str, str]) -> Dataset:
- """
- Rename columns and return the underlying hf dataset directly
- TODO: support native MsDataset column rename.
- Args:
- column_mapping: the mapping of the original and new column names
- Returns:
- underlying hf dataset
- """
- self._hf_ds.reset_format()
- return self._hf_ds.rename_columns(column_mapping)
-
@staticmethod
def upload(
object_name: str,
@@ -695,3 +447,358 @@ class MsDataset:
resp_msg = _delete_manager.delete(object_name=object_name)
logger.info(f'Object {object_name} successfully removed!')
return resp_msg
+
+ def to_torch_dataset(
+ self,
+ columns: Union[str, List[str]] = None,
+ preprocessors: Union[Callable, List[Callable]] = None,
+ task_name: str = None,
+ data_config: ConfigDict = None,
+ to_tensor: bool = True,
+ **format_kwargs,
+ ):
+ """Create a torch.utils.data.Dataset from the MS Dataset. The torch.utils.data.Dataset can be passed to
+ torch.utils.data.DataLoader.
+
+ Args:
+            preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object(s) used to process
+                every sample of the dataset. Each preprocessor returns a dict, and every (numeric) field of that dict
+                becomes a field of the resulting torch.utils.data.Dataset.
+            columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only if
+                `to_tensor` is True). If `preprocessors` is None, `columns` must contain at least one column.
+                If `preprocessors` is not None, the output fields of the preprocessors will also be added.
+ task_name (str, default None): task name, refer to :obj:`Tasks` for more details
+ data_config (ConfigDict, default None): config dict for model object.
+ Attributes of ConfigDict:
+ `preprocessor` (Callable, List[Callable], optional): preprocessors to deal with dataset
+ `type` (str): the type of task
+ `split_config` (dict, optional): get the split config for ExternalDataset
+ `test_mode` (bool, optional): is test mode or not
+            to_tensor (bool, default True): whether to convert the data of the dataset column(s) to torch.Tensor.
+            format_kwargs: A `dict` of arguments to be passed to `torch.tensor`.
+
+ Returns:
+ :class:`torch.utils.data.Dataset`
+
+ """
+ if not is_torch_available():
+ raise ImportError(
+ 'The function to_torch_dataset requires pytorch to be installed'
+ )
+ if isinstance(self._hf_ds, ExternalDataset):
+ data_config.update({'preprocessor': preprocessors})
+ data_config.update(self._hf_ds.config_kwargs)
+ return build_custom_dataset(data_config, task_name)
+ if preprocessors is not None:
+ return self._to_torch_dataset_with_processors(
+ preprocessors, columns=columns, to_tensor=to_tensor)
+ else:
+ self._hf_ds.reset_format()
+ self._hf_ds.set_format(
+ type='torch', columns=columns, format_kwargs=format_kwargs)
+ return self._hf_ds
+
+ def to_tf_dataset(
+ self,
+ batch_size: int,
+ shuffle: bool,
+ preprocessors: Union[Callable, List[Callable]] = None,
+ columns: Union[str, List[str]] = None,
+ collate_fn: Callable = None,
+ drop_remainder: bool = None,
+ collate_fn_args: Dict[str, Any] = None,
+ label_cols: Union[str, List[str]] = None,
+ prefetch: bool = True,
+ ):
+ """Create a tf.data.Dataset from the MS Dataset. This tf.data.Dataset can be passed to tf methods like
+ model.fit() or model.predict().
+
+ Args:
+ batch_size (int): Number of samples in a single batch.
+ shuffle(bool): Shuffle the dataset order.
+            preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object(s) used to process
+                every sample of the dataset. Each preprocessor returns a dict, and each field of that dict will be
+                used as a field of the tf.data.Dataset. If `preprocessors` is None, the `collate_fn`
+                shouldn't be None.
+            columns (str or List[str], default None): Dataset column(s) to be loaded. If `preprocessors` is None,
+                `columns` must contain at least one column. If `preprocessors` is not None, the output fields of
+                the preprocessors will also be added.
+ collate_fn(Callable, default None): A callable object used to collect lists of samples into a batch. If
+ the `preprocessors` is None, the `collate_fn` shouldn't be None.
+ drop_remainder(bool, default None): Drop the last incomplete batch when loading.
+            collate_fn_args (Dict, optional): A `dict` of arguments to be passed to the `collate_fn`.
+            label_cols (str or List[str], default None): Dataset column(s) to load as labels.
+ prefetch (bool, default True): Prefetch data.
+
+ Returns:
+ :class:`tf.data.Dataset`
+
+ """
+ if not is_tf_available():
+ raise ImportError(
+ 'The function to_tf_dataset requires Tensorflow to be installed.'
+ )
+ if preprocessors is not None:
+ return self._to_tf_dataset_with_processors(
+ batch_size,
+ shuffle,
+ preprocessors,
+ drop_remainder=drop_remainder,
+ prefetch=prefetch,
+ label_cols=label_cols,
+ columns=columns)
+
+ if collate_fn is None:
+ logger.error(
+                'The `preprocessors` and the `collate_fn` shouldn't both be None.'
+ )
+ return None
+ self._hf_ds.reset_format()
+ return self._hf_ds.to_tf_dataset(
+ columns,
+ batch_size,
+ shuffle,
+ collate_fn,
+ drop_remainder=drop_remainder,
+ collate_fn_args=collate_fn_args,
+ label_cols=label_cols,
+ prefetch=prefetch)
+
+ def to_hf_dataset(self) -> Dataset:
+ self._hf_ds.reset_format()
+ return self._hf_ds
+
+ def remap_columns(self, column_mapping: Dict[str, str]) -> Dataset:
+ """
+ Rename columns and return the underlying hf dataset directly
+ TODO: support native MsDataset column rename.
+ Args:
+ column_mapping: the mapping of the original and new column names
+ Returns:
+ underlying hf dataset
+ """
+ self._hf_ds.reset_format()
+ return self._hf_ds.rename_columns(column_mapping)
+
+ def _to_torch_dataset_with_processors(
+ self,
+ preprocessors: Union[Callable, List[Callable]],
+ columns: Union[str, List[str]] = None,
+ to_tensor: bool = True,
+ ):
+ preprocessor_list = preprocessors if isinstance(
+ preprocessors, list) else [preprocessors]
+
+ columns = format_list(columns)
+
+ columns = [
+ key for key in self._hf_ds.features.keys() if key in columns
+ ]
+ retained_columns = []
+ if to_tensor:
+ sample = next(iter(self._hf_ds))
+
+ sample_res = {k: np.array(sample[k]) for k in columns}
+ for processor in preprocessor_list:
+ sample_res.update(
+ {k: np.array(v)
+ for k, v in processor(sample).items()})
+
+ def is_numpy_number(value):
+ return np.issubdtype(value.dtype, np.integer) or np.issubdtype(
+ value.dtype, np.floating)
+
+ for k in sample_res.keys():
+ if not is_numpy_number(sample_res[k]):
+ logger.warning(
+ f'Data of column {k} is non-numeric, will be removed')
+ continue
+ retained_columns.append(k)
+
+ import torch
+
+ class MsMapDataset(torch.utils.data.Dataset):
+
+ def __init__(self, dataset: Iterable, preprocessor_list,
+ retained_columns, columns, to_tensor):
+                super().__init__()
+ self.dataset = dataset
+ self.preprocessor_list = preprocessor_list
+ self.to_tensor = to_tensor
+ self.retained_columns = retained_columns
+ self.columns = columns
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def type_converter(self, x):
+ if self.to_tensor:
+ return torch.tensor(x)
+ else:
+ return x
+
+ def __getitem__(self, index):
+ item_dict = self.dataset[index]
+ res = {
+ k: self.type_converter(item_dict[k])
+ for k in self.columns
+ if (not self.to_tensor) or k in self.retained_columns
+ }
+ for preprocessor in self.preprocessor_list:
+ res.update({
+ k: self.type_converter(v)
+ for k, v in preprocessor(item_dict).items()
+ if (not self.to_tensor) or k in self.retained_columns
+ })
+ return res
+
+ return MsMapDataset(self._hf_ds, preprocessor_list, retained_columns,
+ columns, to_tensor)
+
+ def _to_tf_dataset_with_processors(
+ self,
+ batch_size: int,
+ shuffle: bool,
+ preprocessors: Union[Callable, List[Callable]],
+ drop_remainder: bool = None,
+ prefetch: bool = True,
+ label_cols: Union[str, List[str]] = None,
+ columns: Union[str, List[str]] = None,
+ ):
+ preprocessor_list = preprocessors if isinstance(
+ preprocessors, list) else [preprocessors]
+
+ label_cols = format_list(label_cols)
+ columns = format_list(columns)
+ cols_to_retain = list(set(label_cols + columns))
+ retained_columns = [
+ key for key in self._hf_ds.features.keys() if key in cols_to_retain
+ ]
+ import tensorflow as tf
+ tf_dataset = tf.data.Dataset.from_tensor_slices(
+ np.arange(len(self._hf_ds), dtype=np.int64))
+ if shuffle:
+ tf_dataset = tf_dataset.shuffle(buffer_size=len(self._hf_ds))
+
+ def func(i, return_dict=False):
+ i = int(i)
+ res = {k: np.array(self._hf_ds[i][k]) for k in retained_columns}
+ for preprocessor in preprocessor_list:
+ # TODO preprocessor output may have the same key
+ res.update({
+ k: np.array(v)
+ for k, v in preprocessor(self._hf_ds[i]).items()
+ })
+ if return_dict:
+ return res
+ return tuple(list(res.values()))
+
+ sample_res = func(0, True)
+
+ @tf.function(input_signature=[tf.TensorSpec(None, tf.int64)])
+ def fetch_function(i):
+ output = tf.numpy_function(
+ func,
+ inp=[i],
+ Tout=[
+ tf.dtypes.as_dtype(val.dtype)
+ for val in sample_res.values()
+ ],
+ )
+ return {key: output[i] for i, key in enumerate(sample_res)}
+
+ from tensorflow.data.experimental import AUTOTUNE
+ tf_dataset = tf_dataset.map(
+ fetch_function, num_parallel_calls=AUTOTUNE)
+ if label_cols:
+
+ def split_features_and_labels(input_batch):
+ labels = {
+ key: tensor
+ for key, tensor in input_batch.items() if key in label_cols
+ }
+ if len(input_batch) == 1:
+ input_batch = next(iter(input_batch.values()))
+ if len(labels) == 1:
+ labels = next(iter(labels.values()))
+ return input_batch, labels
+
+ tf_dataset = tf_dataset.map(split_features_and_labels)
+
+ elif len(columns) == 1:
+ tf_dataset = tf_dataset.map(lambda x: next(iter(x.values())))
+ if batch_size > 1:
+ tf_dataset = tf_dataset.batch(
+ batch_size, drop_remainder=drop_remainder)
+
+ if prefetch:
+ tf_dataset = tf_dataset.prefetch(AUTOTUNE)
+ return tf_dataset
+
+ def to_custom_dataset(self,
+ custom_cfg: Config,
+ preprocessor=None,
+ mode=None,
+ **kwargs):
+ """Convert the input datasets to specific custom datasets by given model configuration and preprocessor.
+
+ Args:
+ custom_cfg (Config): The model configuration for custom datasets.
+ preprocessor (Preprocessor, Optional): Preprocessor for data samples.
+ mode (str, Optional): See modelscope.utils.constant.ModeKeys
+
+ Returns:
+ `MsDataset`
+ """
+
+ if not is_torch_available():
+ raise ImportError(
+ 'The function to_custom_dataset requires pytorch to be installed'
+ )
+ if not custom_cfg:
+ return
+
+ # Set the flag that it has been converted to custom dataset
+ self.is_custom = True
+
+ # Check mode
+ if mode is None:
+ if 'mode' in kwargs:
+ mode = kwargs.get('mode')
+
+ # Parse cfg
+ ds_cfg_key = 'train' if mode == ModeKeys.TRAIN else 'val'
+ data_cfg = custom_cfg.safe_get(f'dataset.{ds_cfg_key}')
+ if data_cfg is None:
+ data_cfg = ConfigDict(type=custom_cfg.model.type) if hasattr(
+ custom_cfg, ConfigFields.model) else ConfigDict(type=None)
+ data_cfg.update(dict(mode=mode))
+
+ # Get preprocessors from custom_cfg
+ task_name = custom_cfg.task
+ if 'task' in kwargs:
+ task_name = kwargs.pop('task')
+ field_name = Tasks.find_field_by_task(task_name)
+ if 'field' in kwargs:
+ field_name = kwargs.pop('field')
+ if preprocessor is None and hasattr(custom_cfg, 'preprocessor'):
+ preprocessor_cfg = custom_cfg.preprocessor
+ if preprocessor_cfg:
+ preprocessor = build_preprocessor(preprocessor_cfg, field_name)
+
+ # Build custom dataset
+ if isinstance(self._hf_ds, ExternalDataset):
+ data_cfg.update(dict(preprocessor=preprocessor))
+ data_cfg.update(self._hf_ds.config_kwargs)
+ self._hf_ds = build_custom_dataset(
+ cfg=data_cfg, task_name=custom_cfg.task)
+ return
+
+ if preprocessor is not None:
+ to_tensor = kwargs.get('to_tensor', True)
+ self._hf_ds = self._to_torch_dataset_with_processors(
+ preprocessors=preprocessor, to_tensor=to_tensor)
+ else:
+ self._hf_ds.reset_format()
+ self._hf_ds.set_format(type='torch')
+ return
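# Usage sketch for the relocated conversion API above. MsDataset.load is the standard
# entry point for fetching a dataset; the dataset id and column name below are
# illustrative placeholders, not guaranteed to exist on the hub.
from modelscope.msdatasets import MsDataset

ms_ds = MsDataset.load('my_dataset', split='train')    # placeholder dataset id
torch_ds = ms_ds.to_torch_dataset(columns=['label'])   # numeric columns only when to_tensor=True
# torch_ds is a torch.utils.data.Dataset and can be wrapped in a DataLoader;
# passing preprocessors= instead routes through _to_torch_dataset_with_processors.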
diff --git a/modelscope/msdatasets/task_datasets/__init__.py b/modelscope/msdatasets/task_datasets/__init__.py
index 167af6db..28c00b07 100644
--- a/modelscope/msdatasets/task_datasets/__init__.py
+++ b/modelscope/msdatasets/task_datasets/__init__.py
@@ -1,45 +1,25 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+
from typing import TYPE_CHECKING
-from modelscope.utils.import_utils import LazyImportModule, is_torch_available
+from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
- from .base import TaskDataset
- from .builder import TASK_DATASETS, build_task_dataset
from .torch_base_dataset import TorchTaskDataset
- from .veco_dataset import VecoDataset
- from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset
- from .movie_scene_segmentation import MovieSceneSegmentationDataset
+ from .gopro_image_deblurring_dataset import GoproImageDeblurringDataset
+ from .reds_image_deblurring_dataset import RedsImageDeblurringDataset
+ from .sidd_image_denoising import SiddImageDenoisingDataset
from .video_summarization_dataset import VideoSummarizationDataset
- from .language_guided_video_summarization_dataset import LanguageGuidedVideoSummarizationDataset
- from .image_inpainting import ImageInpaintingDataset
- from .text_ranking_dataset import TextRankingDataset
- from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset
- from .bad_image_detecting import BadImageDetectingDataset
-
else:
_import_structure = {
- 'base': ['TaskDataset'],
- 'builder': ['TASK_DATASETS', 'build_task_dataset'],
'torch_base_dataset': ['TorchTaskDataset'],
- 'text_ranking_dataset': ['TextRankingDataset'],
- 'veco_dataset': ['VecoDataset'],
- 'image_instance_segmentation_coco_dataset':
- ['ImageInstanceSegmentationCocoDataset'],
+ 'gopro_image_deblurring_dataset': ['GoproImageDeblurringDataset'],
+ 'reds_image_deblurring_dataset': ['RedsImageDeblurringDataset'],
+ 'sidd_image_denoising': ['SiddImageDenoisingDataset'],
'video_summarization_dataset': ['VideoSummarizationDataset'],
- 'language_guided_video_summarization_dataset':
- ['LanguageGuidedVideoSummarizationDataset'],
- 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'],
- 'image_inpainting': ['ImageInpaintingDataset'],
- 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'],
- 'image_portrait_enhancement_dataset':
- ['ImagePortraitEnhancementDataset'],
- 'referring_video_object_segmentation':
- ['ReferringVideoObjectSegmentationDataset'],
- 'bad_image_detecting': ['BadImageDetectingDataset'],
}
- import sys
+ import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
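# A quick sketch of what the slimmed-down import structure above implies: the old entry
# point stays importable as a lazy alias, but it now points at the class relocated to
# dataset_cls.custom_datasets and logs a deprecation warning (see the shim modules below),
# so both names should resolve to the same class object.
from modelscope.msdatasets.task_datasets import TorchTaskDataset                  # deprecated path
from modelscope.msdatasets.dataset_cls.custom_datasets import TorchCustomDataset  # preferred path

assert TorchTaskDataset is TorchCustomDataset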
diff --git a/modelscope/msdatasets/task_datasets/base.py b/modelscope/msdatasets/task_datasets/base.py
deleted file mode 100644
index 39b791b1..00000000
--- a/modelscope/msdatasets/task_datasets/base.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-from abc import ABC, abstractmethod
-from typing import Any, List, Tuple, Union
-
-
-class TaskDataset(ABC):
- """The task dataset base class for all the task specific dataset processors.
- """
-
- def __init__(self,
- datasets: Union[Any, List[Any]],
- mode,
- preprocessor=None,
- **kwargs):
- super().__init__()
- self.mode = mode
- self.preprocessor = preprocessor
- self._inner_dataset = self.prepare_dataset(datasets)
-
- @abstractmethod
- def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any:
- """Prepare a dataset.
-
- User can process the input datasets in a whole dataset perspective.
- This method also helps to merge several datasets to one.
-
- Args:
- datasets: The original dataset(s)
-
- Returns: A single dataset, which may be created after merging.
-
- """
- pass
-
- @abstractmethod
- def prepare_sample(self, data):
- """Preprocess the data fetched from the inner_dataset.
-
- If the preprocessor is None, the original data will be returned, else the preprocessor will be called.
- User can override this method to implement custom logics.
-
- Args:
- data: The data fetched from the dataset.
-
- Returns: The processed data.
-
- """
- pass
diff --git a/modelscope/msdatasets/task_datasets/gopro_image_deblurring_dataset.py b/modelscope/msdatasets/task_datasets/gopro_image_deblurring_dataset.py
index fb621551..0fa94487 100644
--- a/modelscope/msdatasets/task_datasets/gopro_image_deblurring_dataset.py
+++ b/modelscope/msdatasets/task_datasets/gopro_image_deblurring_dataset.py
@@ -1,64 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-import cv2
-import numpy as np
+from modelscope.msdatasets.dataset_cls.custom_datasets import \
+ GoproImageDeblurringDataset
+from modelscope.utils.logger import get_logger
-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.sidd_image_denoising.data_utils import (
- img2tensor, padding)
-from modelscope.msdatasets.task_datasets.sidd_image_denoising.transforms import (
- augment, paired_random_crop)
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
-from modelscope.utils.constant import Tasks
-
-
-def default_loader(path):
- return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
-
-
-@TASK_DATASETS.register_module(
- Tasks.image_deblurring, module_name=Datasets.PairedDataset)
-class GoproImageDeblurringDataset(TorchTaskDataset):
- """Paired image dataset for image restoration.
- """
-
- def __init__(self, dataset, opt, is_train):
- self.dataset = dataset
- self.opt = opt
- self.is_train = is_train
-
- def __len__(self):
- return len(self.dataset)
-
- def __getitem__(self, index):
-
- # Load gt and lq images. Dimension order: HWC; channel order: BGR;
- # image range: [0, 1], float32.
- item_dict = self.dataset[index]
- gt_path = item_dict['Sharp Image:FILE']
- img_gt = default_loader(gt_path)
- lq_path = item_dict['Blur Image:FILE']
- img_lq = default_loader(lq_path)
-
- # augmentation for training
- if self.is_train:
- gt_size = self.opt.gt_size
- # padding
- img_gt, img_lq = padding(img_gt, img_lq, gt_size)
-
- # random crop
- img_gt, img_lq = paired_random_crop(
- img_gt, img_lq, gt_size, scale=1)
-
- # flip, rotation
- img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip,
- self.opt.use_rot)
-
- # BGR to RGB, HWC to CHW, numpy to tensor
- img_gt, img_lq = img2tensor([img_gt, img_lq],
- bgr2rgb=True,
- float32=True)
-
- return {'input': img_lq, 'target': img_gt}
+logger = get_logger()
+logger.warning(
+    'This import path has been deprecated since modelscope v1.4.0, '
+ 'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import GoproImageDeblurringDataset`'
+)
diff --git a/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py b/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py
deleted file mode 100644
index 732a1bd7..00000000
--- a/modelscope/msdatasets/task_datasets/image_inpainting/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-from .image_inpainting_dataset import ImageInpaintingDataset
diff --git a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py b/modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py
deleted file mode 100644
index b1bc40f8..00000000
--- a/modelscope/msdatasets/task_datasets/movie_scene_segmentation/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-from .movie_scene_segmentation_dataset import MovieSceneSegmentationDataset
diff --git a/modelscope/msdatasets/task_datasets/reds_image_deblurring_dataset.py b/modelscope/msdatasets/task_datasets/reds_image_deblurring_dataset.py
index 17b731bc..c129a4d0 100644
--- a/modelscope/msdatasets/task_datasets/reds_image_deblurring_dataset.py
+++ b/modelscope/msdatasets/task_datasets/reds_image_deblurring_dataset.py
@@ -1,60 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-import cv2
-import numpy as np
+from modelscope.msdatasets.dataset_cls.custom_datasets import \
+ RedsImageDeblurringDataset
+from modelscope.utils.logger import get_logger
-from modelscope.metainfo import Datasets
-from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
-from modelscope.msdatasets.task_datasets.sidd_image_denoising.data_utils import (
- img2tensor, padding)
-from modelscope.msdatasets.task_datasets.sidd_image_denoising.transforms import (
- augment, paired_random_crop)
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
-from modelscope.utils.constant import Tasks
-
-
-def default_loader(path):
- return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
-
-
-@TASK_DATASETS.register_module(
- Tasks.image_deblurring, module_name=Datasets.PairedDataset)
-class RedsImageDeblurringDataset(TorchTaskDataset):
- """Paired image dataset for image restoration.
- """
-
- def __init__(self, dataset, opt, is_train):
- self.dataset = dataset
- self.opt = opt
- self.is_train = is_train
-
- def __len__(self):
- return len(self.dataset)
-
- def __getitem__(self, index):
- item_dict = self.dataset[index]
- hq_path = item_dict['LQ Frame:FILE']
- img_hq = default_loader(hq_path)
- lq_path = item_dict['HQ Frame:FILE']
- img_lq = default_loader(lq_path)
-
- # augmentation for training
- if self.is_train:
- gt_size = self.opt.gt_size
- # padding
- img_hq, img_lq = padding(img_hq, img_lq, gt_size)
-
- # random crop
- img_hq, img_lq = paired_random_crop(
- img_hq, img_lq, gt_size, scale=1)
-
- # flip, rotation
- img_hq, img_lq = augment([img_hq, img_lq], self.opt.use_flip,
- self.opt.use_rot)
-
- # BGR to RGB, HWC to CHW, numpy to tensor
- img_hq, img_lq = img2tensor([img_hq, img_lq],
- bgr2rgb=True,
- float32=True)
- return {'input': img_lq, 'target': img_hq}
+logger = get_logger()
+logger.warning(
+    'This import path has been deprecated since modelscope v1.4.0, '
+    'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import RedsImageDeblurringDataset`'
+)
diff --git a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py b/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py
deleted file mode 100644
index 7c1b724e..00000000
--- a/modelscope/msdatasets/task_datasets/referring_video_object_segmentation/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-from .referring_video_object_segmentation_dataset import \
- ReferringVideoObjectSegmentationDataset
diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising.py
new file mode 100644
index 00000000..da8dbf44
--- /dev/null
+++ b/modelscope/msdatasets/task_datasets/sidd_image_denoising.py
@@ -0,0 +1,11 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from modelscope.msdatasets.dataset_cls.custom_datasets import \
+ SiddImageDenoisingDataset
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+logger.warning(
+    'This import path has been deprecated since modelscope v1.4.0, '
+ 'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import SiddImageDenoisingDataset`'
+)
diff --git a/modelscope/msdatasets/task_datasets/torch_base_dataset.py b/modelscope/msdatasets/task_datasets/torch_base_dataset.py
index 4d82b741..314b9d1c 100644
--- a/modelscope/msdatasets/task_datasets/torch_base_dataset.py
+++ b/modelscope/msdatasets/task_datasets/torch_base_dataset.py
@@ -1,64 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import Any, List, Tuple, Union
-from torch.utils.data import ConcatDataset, Dataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import \
+ TorchCustomDataset as TorchTaskDataset
+from modelscope.utils.logger import get_logger
-from .base import TaskDataset
-
-
-class TorchTaskDataset(TaskDataset, Dataset):
- """The task dataset base class for all the torch-based task processors.
-
- This base class is enough for most cases, except there are procedures which can not be executed in
- preprocessors and Datasets like dataset merging.
- """
-
- def __init__(self,
- datasets: Union[Any, List[Any]],
- mode,
- preprocessor=None,
- **kwargs):
- TaskDataset.__init__(self, datasets, mode, preprocessor, **kwargs)
- self.trainer = None
-
- def __getitem__(self, index) -> Any:
- return self.prepare_sample(self._inner_dataset[index])
-
- def __len__(self):
- return len(self._inner_dataset)
-
- def prepare_dataset(self, datasets: Union[Any, List[Any]]) -> Any:
- """Prepare a dataset.
-
- User can process the input datasets in a whole dataset perspective.
- This method gives a default implementation of datasets merging, user can override this
- method to write custom logics.
-
- Args:
- datasets: The original dataset(s)
-
- Returns: A single dataset, which may be created after merging.
-
- """
- if isinstance(datasets, List):
- if len(datasets) == 1:
- return datasets[0]
- elif len(datasets) > 1:
- return ConcatDataset(datasets)
- else:
- return datasets
-
- def prepare_sample(self, data):
- """Preprocess the data fetched from the inner_dataset.
-
- If the preprocessor is None, the original data will be returned, else the preprocessor will be called.
- User can override this method to implement custom logics.
-
- Args:
- data: The data fetched from the dataset.
-
- Returns: The processed data.
-
- """
- return self.preprocessor(
- data) if self.preprocessor is not None else data
+logger = get_logger()
+logger.warning(
+    'This import path has been deprecated since modelscope v1.4.0, '
+ 'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import TorchCustomDataset`'
+)
diff --git a/modelscope/msdatasets/task_datasets/video_summarization_dataset.py b/modelscope/msdatasets/task_datasets/video_summarization_dataset.py
index 02639be8..24a29352 100644
--- a/modelscope/msdatasets/task_datasets/video_summarization_dataset.py
+++ b/modelscope/msdatasets/task_datasets/video_summarization_dataset.py
@@ -1,72 +1,11 @@
-# Part of the implementation is borrowed and modified from PGL-SUM,
-# publicly available at https://github.com/e-apostolidis/PGL-SUM
+# Copyright (c) Alibaba, Inc. and its affiliates.
-import os
+from modelscope.msdatasets.dataset_cls.custom_datasets import \
+ VideoSummarizationDataset
+from modelscope.utils.logger import get_logger
-import h5py
-import json
-import numpy as np
-import torch
-
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
-
-
-class VideoSummarizationDataset(TorchTaskDataset):
-
- def __init__(self, mode, opt, root_dir):
- self.mode = mode
- self.data_filename = os.path.join(root_dir, opt.dataset_file)
- self.split_filename = os.path.join(root_dir, opt.split_file)
- self.split_index = opt.split_index
- hdf = h5py.File(self.data_filename, 'r')
- self.list_frame_features, self.list_gtscores = [], []
- self.list_user_summary = []
- self.list_change_points = []
- self.list_n_frames = []
- self.list_positions = []
-
- with open(self.split_filename, encoding='utf-8') as f:
- data = json.loads(f.read())
- for i, split in enumerate(data):
- if i == self.split_index:
- self.split = split
- break
-
- for video_name in self.split[self.mode + '_keys']:
- frame_features = torch.Tensor(
- np.array(hdf[video_name + '/features']))
- gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore']))
- user_summary = np.array(hdf[f'{video_name}/user_summary'])
- change_points = np.array(hdf[f'{video_name}/change_points'])
- n_frames = np.array(hdf[f'{video_name}/n_frames'])
- positions = np.array(hdf[f'{video_name}/picks'])
-
- self.list_frame_features.append(frame_features)
- self.list_gtscores.append(gtscore)
- self.list_user_summary.append(user_summary)
- self.list_change_points.append(change_points)
- self.list_n_frames.append(n_frames)
- self.list_positions.append(positions)
-
- hdf.close()
-
- def __len__(self):
- self.len = len(self.split[self.mode + '_keys'])
- return self.len
-
- def __getitem__(self, index):
- frame_features = self.list_frame_features[index]
- gtscore = self.list_gtscores[index]
- user_summary = self.list_user_summary[index]
- change_points = self.list_change_points[index]
- n_frames = self.list_n_frames[index]
- positions = self.list_positions[index]
-
- return dict(
- frame_features=frame_features,
- gtscore=gtscore,
- user_summary=user_summary,
- change_points=change_points,
- n_frames=n_frames,
- positions=positions)
+logger = get_logger()
+logger.warning(
+    'This import path has been deprecated since modelscope v1.4.0, '
+ 'please use `from modelscope.msdatasets.dataset_cls.custom_datasets import VideoSummarizationDataset`'
+)
diff --git a/modelscope/msdatasets/utils/dataset_utils.py b/modelscope/msdatasets/utils/dataset_utils.py
index 4c80af7d..dde044d5 100644
--- a/modelscope/msdatasets/utils/dataset_utils.py
+++ b/modelscope/msdatasets/utils/dataset_utils.py
@@ -184,9 +184,11 @@ def get_dataset_files(subset_split_into: dict,
meta_map = defaultdict(dict)
file_map = defaultdict(dict)
args_map = defaultdict(dict)
+ custom_type_map = defaultdict(dict)
modelscope_api = HubApi()
for split, info in subset_split_into.items():
+ custom_type_map[split] = info.get('custom', '')
meta_map[split] = modelscope_api.get_dataset_file_url(
info.get('meta', ''), dataset_name, namespace, revision)
if info.get('file'):
@@ -221,4 +223,4 @@ def get_dataset_files(subset_split_into: dict,
if contains_dir(file_map):
file_map = get_split_objects_map(file_map, objects)
- return meta_map, file_map, args_map
+ return meta_map, file_map, args_map, custom_type_map
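# Callers of get_dataset_files now unpack four values; the new custom_type_map carries
# the per-split 'custom' field of the meta info. The arguments below are illustrative
# stubs and assume the remaining parameters keep their existing names (dataset_name,
# namespace, revision); a real call needs hub access.
from modelscope.msdatasets.utils.dataset_utils import get_dataset_files

subset_split_into = {'train': {'meta': 'train_meta.csv', 'custom': 'MyCustomDataset'}}
meta_map, file_map, args_map, custom_type_map = get_dataset_files(
    subset_split_into, 'my_dataset', 'my_namespace', 'master')
print(custom_type_map['train'])  # -> 'MyCustomDataset'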
diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py
index c3d66ff5..f4e8fbf7 100644
--- a/modelscope/outputs/outputs.py
+++ b/modelscope/outputs/outputs.py
@@ -57,6 +57,7 @@ class OutputKeys(object):
MATCHES = 'matches'
PCD12 = 'pcd12'
PCD12_ALIGN = 'pcd12_align'
+ TBOUNDS = 'tbounds'
TASK_OUTPUTS = {
@@ -70,6 +71,7 @@ TASK_OUTPUTS = {
# }
Tasks.ocr_detection: [OutputKeys.POLYGONS],
Tasks.table_recognition: [OutputKeys.POLYGONS],
+ Tasks.lineless_table_recognition: [OutputKeys.POLYGONS, OutputKeys.BOXES],
Tasks.license_plate_detection: [OutputKeys.POLYGONS, OutputKeys.TEXT],
# ocr recognition result for single sample
@@ -456,6 +458,16 @@ TASK_OUTPUTS = {
# }
Tasks.face_reconstruction: [OutputKeys.OUTPUT],
+ # 3D human reconstruction result for single sample
+ # {
+ # "output": {
+ # "vertices": np.array with shape(n, 3),
+ # "faces": np.array with shape(n, 3),
+ # "colors": np.array with shape(n, 3),
+ # }
+ # }
+ Tasks.human_reconstruction: [OutputKeys.OUTPUT],
+
# 2D hand keypoints result for single sample
# {
# "keypoints": [
@@ -851,6 +863,7 @@ TASK_OUTPUTS = {
# punctuation result for single sample
# { "text": "你好,明天!"}
Tasks.punctuation: [OutputKeys.TEXT],
+
# language model result for single sample
# { "text": " hel@@ lo 大 家 好 呀
# p( hel@@ | ) = 0.00057767 [ -7.45650959 ]
@@ -864,6 +877,22 @@ TASK_OUTPUTS = {
# "}
Tasks.language_score_prediction: [OutputKeys.TEXT],
+ # speech timestamp result for single sample
+ # {
+ # 'text': ' 0.000 0.376;一 0.376 0.556;个 0.556 0.796;东 0.796 0.976;
+ # 太 0.976 1.136;平 1.136 1.256;洋 1.256 1.436;国 1.436 1.676;
+ # 1.676 1.676;家 1.676 1.916; 1.916 2.036;为 2.036 2.196;
+ # 什 2.196 2.316;么 2.316 2.496;跑 2.496 2.676;到 2.676 2.856;
+ # 西 2.856 3.036;太 3.036 3.196;平 3.196 3.376;洋 3.376 3.496;
+ # 来 3.496 3.636;了 3.636 3.796;呢 3.796 4.148; 4.148 4.440;',
+ # 'timestamp': [[0, 376], [376, 556], [556, 795], [795, 976],
+ # [976, 1136], [1136, 1256], [1256, 1436], [1436, 1676],
+ # [1676, 1676], [1676, 1916], [1916, 2036], [2036, 2196],
+ # [2196, 2316], [2316, 2496], [2496, 2676], [2676, 2856],
+ # [2856, 3036], [3036, 3196], [3196, 3376], [3376, 3496]]
+ # }
+ Tasks.speech_timestamp: [OutputKeys.TEXT],
+
# audio processed for single file in PCM format
# {
# "output_pcm": pcm encoded audio bytes
@@ -1077,6 +1106,7 @@ TASK_OUTPUTS = {
Tasks.document_grounded_dialog_generate: [OutputKeys.TEXT],
Tasks.document_grounded_dialog_rerank: [OutputKeys.OUTPUT],
Tasks.document_grounded_dialog_retrieval: [OutputKeys.OUTPUT],
+ Tasks.video_temporal_grounding: [OutputKeys.SCORES, OutputKeys.TBOUNDS],
}
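# The registrations added above are plain table entries and can be sanity-checked
# without loading any model (assuming modelscope.outputs re-exports TASK_OUTPUTS,
# as the rest of the codebase does).
from modelscope.outputs import OutputKeys, TASK_OUTPUTS
from modelscope.utils.constant import Tasks

assert OutputKeys.TBOUNDS == 'tbounds'
assert TASK_OUTPUTS[Tasks.speech_timestamp] == [OutputKeys.TEXT]
assert TASK_OUTPUTS[Tasks.lineless_table_recognition] == [OutputKeys.POLYGONS, OutputKeys.BOXES]
assert TASK_OUTPUTS[Tasks.video_temporal_grounding] == [OutputKeys.SCORES, OutputKeys.TBOUNDS]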
diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py
index 0756ffb4..381b5eaa 100644
--- a/modelscope/pipeline_inputs.py
+++ b/modelscope/pipeline_inputs.py
@@ -316,6 +316,8 @@ TASK_INPUTS = {
},
Tasks.action_detection:
InputType.VIDEO,
+ Tasks.human_reconstruction:
+ InputType.IMAGE,
Tasks.image_reid_person:
InputType.IMAGE,
Tasks.video_inpainting: {
diff --git a/modelscope/pipelines/audio/__init__.py b/modelscope/pipelines/audio/__init__.py
index c38c9762..18e8b8b3 100644
--- a/modelscope/pipelines/audio/__init__.py
+++ b/modelscope/pipelines/audio/__init__.py
@@ -14,6 +14,7 @@ if TYPE_CHECKING:
from .speaker_verification_pipeline import SpeakerVerificationPipeline
else:
_import_structure = {
+ 'ans_dfsmn_pipeline': ['ANSDFSMNPipeline'],
'ans_pipeline': ['ANSPipeline'],
'asr_inference_pipeline': ['AutomaticSpeechRecognitionPipeline'],
'kws_farfield_pipeline': ['KWSFarfieldPipeline'],
diff --git a/modelscope/pipelines/audio/ans_dfsmn_pipeline.py b/modelscope/pipelines/audio/ans_dfsmn_pipeline.py
new file mode 100644
index 00000000..fad77091
--- /dev/null
+++ b/modelscope/pipelines/audio/ans_dfsmn_pipeline.py
@@ -0,0 +1,187 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import collections
+import io
+import os
+import sys
+from typing import Any, Dict
+
+import librosa
+import numpy as np
+import soundfile as sf
+import torch
+
+from modelscope.fileio import File
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.constant import ModelFile, Tasks
+
+HOP_LENGTH = 960
+N_FFT = 1920
+WINDOW_NAME_HAM = 'hamming'
+STFT_WIN_LEN = 1920
+WINLEN = 3840
+STRIDE = 1920
+
+
+@PIPELINES.register_module(
+ Tasks.acoustic_noise_suppression,
+ module_name=Pipelines.speech_dfsmn_ans_psm_48k_causal)
+class ANSDFSMNPipeline(Pipeline):
+ """ANS (Acoustic Noise Suppression) inference pipeline based on DFSMN model.
+
+ Args:
+        stream_mode: whether to run in streaming mode, default False.
+        In stream mode, it accepts bytes as pipeline input, which should be audio data in PCM format.
+        In normal mode, it accepts a str and treats it as the path of a local wav file
+        or the http link of a remote wav file.
+ """
+ SAMPLE_RATE = 48000
+
+ def __init__(self, model, **kwargs):
+ super().__init__(model=model, **kwargs)
+ model_bin_file = os.path.join(self.model.model_dir,
+ ModelFile.TORCH_MODEL_BIN_FILE)
+ if os.path.exists(model_bin_file):
+ checkpoint = torch.load(model_bin_file, map_location=self.device)
+ self.model.load_state_dict(checkpoint)
+ self.model.eval()
+ self.stream_mode = kwargs.get('stream_mode', False)
+ if self.stream_mode:
+ # the unit of WINLEN and STRIDE is frame, 1 frame of 16bit = 2 bytes
+ byte_buffer_length = \
+ (WINLEN + STRIDE * (self.model.lorder - 1)) * 2
+ self.buffer = collections.deque(maxlen=byte_buffer_length)
+ # padding head
+ for i in range(STRIDE * 2):
+ self.buffer.append(b'\0')
+ # it processes WINLEN frames at the first time, then STRIDE frames
+ self.byte_length_remain = (STRIDE * 2 - WINLEN) * 2
+ self.first_forward = True
+ self.tensor_give_up_length = (WINLEN - STRIDE) // 2
+
+ window = torch.hamming_window(
+ STFT_WIN_LEN, periodic=False, device=self.device)
+
+ def stft(x):
+ return torch.stft(
+ x,
+ N_FFT,
+ HOP_LENGTH,
+ STFT_WIN_LEN,
+ center=False,
+ window=window)
+
+ def istft(x, slen):
+ return librosa.istft(
+ x,
+ hop_length=HOP_LENGTH,
+ win_length=STFT_WIN_LEN,
+ window=WINDOW_NAME_HAM,
+ center=False,
+ length=slen)
+
+ self.stft = stft
+ self.istft = istft
+
+ def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
+ if self.stream_mode:
+ if not isinstance(inputs, bytes):
+                raise TypeError('Only bytes input is supported in stream mode.')
+ if len(inputs) > self.buffer.maxlen:
+ raise ValueError(
+ f'inputs length too large: {len(inputs)} > {self.buffer.maxlen}'
+ )
+ tensor_list = []
+ current_index = 0
+ while self.byte_length_remain + len(
+ inputs) - current_index >= STRIDE * 2:
+ byte_length_to_add = STRIDE * 2 - self.byte_length_remain
+ for i in range(current_index,
+ current_index + byte_length_to_add):
+ self.buffer.append(inputs[i].to_bytes(
+ 1, byteorder=sys.byteorder, signed=False))
+ bytes_io = io.BytesIO()
+ for b in self.buffer:
+ bytes_io.write(b)
+ data = np.frombuffer(bytes_io.getbuffer(), dtype=np.int16)
+ data_tensor = torch.from_numpy(data).type(torch.FloatTensor)
+ tensor_list.append(data_tensor)
+ self.byte_length_remain = 0
+ current_index += byte_length_to_add
+ for i in range(current_index, len(inputs)):
+ self.buffer.append(inputs[i].to_bytes(
+ 1, byteorder=sys.byteorder, signed=False))
+ self.byte_length_remain += 1
+ return {'audio': tensor_list}
+ else:
+ if isinstance(inputs, str):
+ data_bytes = File.read(inputs)
+ elif isinstance(inputs, bytes):
+ data_bytes = inputs
+ else:
+ raise TypeError(f'Unsupported type {type(inputs)}.')
+ data_tensor = self.bytes2tensor(data_bytes)
+ return {'audio': data_tensor}
+
+ def bytes2tensor(self, file_bytes):
+ data1, fs = sf.read(io.BytesIO(file_bytes))
+ data1 = data1.astype(np.float32)
+ if len(data1.shape) > 1:
+ data1 = data1[:, 0]
+ if fs != self.SAMPLE_RATE:
+ data1 = librosa.resample(data1, fs, self.SAMPLE_RATE)
+ data = data1 * 32768
+ data_tensor = torch.from_numpy(data).type(torch.FloatTensor)
+ return data_tensor
+
+ def forward(self, inputs: Dict[str, Any],
+ **forward_params) -> Dict[str, Any]:
+ if self.stream_mode:
+ bytes_io = io.BytesIO()
+ for origin_audio in inputs['audio']:
+ masked_sig = self._forward(origin_audio)
+ if self.first_forward:
+ masked_sig = masked_sig[:-self.tensor_give_up_length]
+ self.first_forward = False
+ else:
+ masked_sig = masked_sig[-WINLEN:]
+ masked_sig = masked_sig[self.tensor_give_up_length:-self.
+ tensor_give_up_length]
+ bytes_io.write(masked_sig.astype(np.int16).tobytes())
+ outputs = bytes_io.getvalue()
+ else:
+ origin_audio = inputs['audio']
+ masked_sig = self._forward(origin_audio)
+ outputs = masked_sig.astype(np.int16).tobytes()
+ return {OutputKeys.OUTPUT_PCM: outputs}
+
+ def _forward(self, origin_audio):
+ with torch.no_grad():
+ audio_in = origin_audio.unsqueeze(0)
+ import torchaudio
+ fbanks = torchaudio.compliance.kaldi.fbank(
+ audio_in,
+ dither=1.0,
+ frame_length=40.0,
+ frame_shift=20.0,
+ num_mel_bins=120,
+ sample_frequency=self.SAMPLE_RATE,
+ window_type=WINDOW_NAME_HAM)
+ fbanks = fbanks.unsqueeze(0)
+ masks = self.model(fbanks)
+ spectrum = self.stft(origin_audio)
+ masks = masks.permute(2, 1, 0)
+ masked_spec = (spectrum * masks).cpu()
+ masked_spec = masked_spec.detach().numpy()
+ masked_spec_complex = masked_spec[:, :, 0] + 1j * masked_spec[:, :, 1]
+ masked_sig = self.istft(masked_spec_complex, len(origin_audio))
+ return masked_sig
+
+ def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
+ if not self.stream_mode and 'output_path' in kwargs.keys():
+ sf.write(
+ kwargs['output_path'],
+ np.frombuffer(inputs[OutputKeys.OUTPUT_PCM], dtype=np.int16),
+ self.SAMPLE_RATE)
+ return inputs
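# An illustrative call of the DFSMN ANS pipeline registered above. The model id and
# file paths are assumptions for the sketch; the result carries raw PCM bytes under
# OutputKeys.OUTPUT_PCM and can also be written to a wav file via output_path.
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ans = pipeline(
    Tasks.acoustic_noise_suppression,
    model='damo/speech_dfsmn_ans_psm_48k_causal')  # assumed model id
result = ans('noisy_speech_48k.wav', output_path='denoised_48k.wav')
pcm_bytes = result[OutputKeys.OUTPUT_PCM]
# For streaming, construct the pipeline with stream_mode=True and feed successive
# chunks of 16-bit PCM bytes instead of a file path.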
diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py
index c12c9817..3719689c 100644
--- a/modelscope/pipelines/audio/ans_pipeline.py
+++ b/modelscope/pipelines/audio/ans_pipeline.py
@@ -36,8 +36,11 @@ class ANSPipeline(Pipeline):
"""
super().__init__(model=model, **kwargs)
self.model.eval()
+ self.stream_mode = kwargs.get('stream_mode', False)
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
+ if self.stream_mode:
+ raise TypeError('This model does not support stream mode!')
if isinstance(inputs, bytes):
data1, fs = sf.read(io.BytesIO(inputs))
elif isinstance(inputs, str):
diff --git a/modelscope/pipelines/audio/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr_inference_pipeline.py
index 80f5387a..f86c92a7 100644
--- a/modelscope/pipelines/audio/asr_inference_pipeline.py
+++ b/modelscope/pipelines/audio/asr_inference_pipeline.py
@@ -51,6 +51,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
punc_model_revision: Optional[str] = None,
lm_model: Optional[Union[Model, str]] = None,
lm_model_revision: Optional[str] = None,
+ timestamp_model: Optional[Union[Model, str]] = None,
+ timestamp_model_revision: Optional[str] = None,
**kwargs):
"""
Use `model` and `preprocessor` to create an asr pipeline for prediction
@@ -72,6 +74,9 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
lm_model (Optional: 'Model' or 'str'):
language model from model hub or local
example: 'damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch'
+ timestamp_model (Optional: 'Model' or 'str'):
+ timestamp model from model hub or local
+ example: 'damo/speech_timestamp_predictor-v1-16k-offline'
output_dir('str'):
output dir path
batch_size('int'):
@@ -108,6 +113,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
self.punc_model_revision = punc_model_revision
self.lm_model = lm_model
self.lm_model_revision = lm_model_revision
+ self.timestamp_model = timestamp_model
+ self.timestamp_model_revision = timestamp_model_revision
self.model_cfg = self.model.forward()
self.cmd = self.get_cmd(kwargs)
@@ -144,6 +151,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
vad_cmvn_file=self.cmd['vad_cmvn_file'],
punc_model_file=self.cmd['punc_model_file'],
punc_infer_config=self.cmd['punc_infer_config'],
+ timestamp_model_file=self.cmd['timestamp_model_file'],
+ timestamp_infer_config=self.cmd['timestamp_infer_config'],
outputs_dict=self.cmd['outputs_dict'],
param_dict=self.cmd['param_dict'],
token_num_relax=self.cmd['token_num_relax'],
@@ -288,6 +297,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
'time_stamp_writer': True,
'punc_infer_config': None,
'punc_model_file': None,
+ 'timestamp_infer_config': None,
+ 'timestamp_model_file': None,
'outputs_dict': True,
'param_dict': None,
'model_type': outputs['model_type'],
@@ -355,9 +366,16 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
self.punc_model = model_config['punc_model']
if model_config.__contains__('punc_model_revision'):
self.punc_model_revision = model_config['punc_model_revision']
+ if model_config.__contains__(
+ 'timestamp_model') and self.timestamp_model != '':
+ self.timestamp_model = model_config['timestamp_model']
+ if model_config.__contains__('timestamp_model_revision'):
+ self.timestamp_model_revision = model_config[
+ 'timestamp_model_revision']
self.load_vad_model(cmd)
self.load_punc_model(cmd)
self.load_lm_model(cmd)
+ self.load_timestamp_model(cmd)
user_args_dict = [
'output_dir',
@@ -460,6 +478,28 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
model_dir,
model_cfg['model']['model_config']['lm_model_config'])
+ # FIXME
+ def load_timestamp_model(self, cmd):
+ if self.timestamp_model is not None and self.timestamp_model != '':
+ if os.path.exists(self.timestamp_model):
+ timestamp_model = self.timestamp_model
+ else:
+ timestamp_model = snapshot_download(
+ self.timestamp_model,
+ revision=self.timestamp_model_revision)
+ logger.info(
+ 'loading timestamp model from {0} ...'.format(timestamp_model))
+ config_path = os.path.join(timestamp_model,
+ ModelFile.CONFIGURATION)
+ model_cfg = json.loads(open(config_path).read())
+ model_dir = os.path.dirname(config_path)
+ cmd['timestamp_model_file'] = os.path.join(
+ model_dir,
+ model_cfg['model']['model_config']['timestamp_model_name'])
+ cmd['timestamp_infer_config'] = os.path.join(
+ model_dir,
+ model_cfg['model']['model_config']['timestamp_model_config'])
+
def forward(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Decoding
"""
diff --git a/modelscope/pipelines/audio/kws_farfield_pipeline.py b/modelscope/pipelines/audio/kws_farfield_pipeline.py
index 5bfc31e9..fe5cb537 100644
--- a/modelscope/pipelines/audio/kws_farfield_pipeline.py
+++ b/modelscope/pipelines/audio/kws_farfield_pipeline.py
@@ -45,6 +45,9 @@ class KWSFarfieldPipeline(Pipeline):
else:
self._keyword_map = {}
+ def _sanitize_parameters(self, **pipeline_parameters):
+ return pipeline_parameters, pipeline_parameters, pipeline_parameters
+
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
if isinstance(inputs, bytes):
return dict(input_file=inputs)
@@ -65,8 +68,8 @@ class KWSFarfieldPipeline(Pipeline):
frames = numpy.stack((frames, frames, numpy.zeros_like(frames)), 1)
kws_list = []
- if 'output_file' in inputs:
- with wave.open(inputs['output_file'], 'wb') as fout:
+ if 'output_file' in forward_params:
+ with wave.open(forward_params['output_file'], 'wb') as fout:
fout.setframerate(self.SAMPLE_RATE)
fout.setnchannels(self.OUTPUT_CHANNELS)
fout.setsampwidth(self.SAMPLE_WIDTH)
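# With the change above, output_file is read from forward_params, so it is passed as
# a call-time keyword instead of being packed into the preprocessed inputs. The model
# id and wav paths are illustrative assumptions.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

kws = pipeline(
    Tasks.keyword_spotting,
    model='damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya')  # assumed model id
result = kws('far_field_example.wav', output_file='multi_channel_output.wav')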
diff --git a/modelscope/pipelines/audio/punctuation_processing_pipeline.py b/modelscope/pipelines/audio/punctuation_processing_pipeline.py
index ec1532ea..90daa421 100644
--- a/modelscope/pipelines/audio/punctuation_processing_pipeline.py
+++ b/modelscope/pipelines/audio/punctuation_processing_pipeline.py
@@ -86,9 +86,12 @@ class PunctuationProcessingPipeline(Pipeline):
rst = {}
for i in range(len(inputs)):
if i == 0:
- text = inputs[0]['value']
- if len(text) > 0:
- rst[OutputKeys.TEXT] = text
+ for key, value in inputs[0].items():
+ if key == 'value':
+ if len(value) > 0:
+ rst[OutputKeys.TEXT] = value
+ elif key != 'key':
+ rst[key] = value
else:
rst[inputs[i]['key']] = inputs[i]['value']
return rst
diff --git a/modelscope/pipelines/audio/speaker_diarization_pipeline.py b/modelscope/pipelines/audio/speaker_diarization_pipeline.py
index ed34dfb9..f800e2a5 100644
--- a/modelscope/pipelines/audio/speaker_diarization_pipeline.py
+++ b/modelscope/pipelines/audio/speaker_diarization_pipeline.py
@@ -189,6 +189,18 @@ class SpeakerDiarizationPipeline(Pipeline):
self.sv_model_revision = model_config['sv_model_revision']
self.load_sv_model(cmd)
+        # rewrite the config with the model's configuration.json
+ for user_args in user_args_dict:
+ if (user_args in self.model_cfg['model_config']
+ and self.model_cfg['model_config'][user_args] is not None):
+ if isinstance(cmd[user_args], dict) and isinstance(
+ self.model_cfg['model_config'][user_args], dict):
+ cmd[user_args].update(
+ self.model_cfg['model_config'][user_args])
+ else:
+ cmd[user_args] = self.model_cfg['model_config'][user_args]
+
+ # rewrite the config with user args
for user_args in user_args_dict:
if user_args in extra_args and extra_args[user_args] is not None:
if isinstance(cmd[user_args], dict) and isinstance(
diff --git a/modelscope/pipelines/audio/speaker_verification_pipeline.py b/modelscope/pipelines/audio/speaker_verification_pipeline.py
index 55ea95bf..1ee6b0b2 100644
--- a/modelscope/pipelines/audio/speaker_verification_pipeline.py
+++ b/modelscope/pipelines/audio/speaker_verification_pipeline.py
@@ -34,8 +34,9 @@ class SpeakerVerificationPipeline(Pipeline):
>>> from modelscope.pipelines import pipeline
>>> pipeline_sv = pipeline(
>>> task=Tasks.speaker_verification, model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch')
- >>> audio_in=('','')
+ >>> audio_in=('sv_example_enroll.wav', 'sv_example_same.wav')
>>> print(pipeline_sv(audio_in))
+ >>> # {'label': ['Same', 'Different'], 'scores': [0.8540488358969999, 0.14595116410300013]}
"""
@@ -88,12 +89,11 @@ class SpeakerVerificationPipeline(Pipeline):
"""
rst = {}
for i in range(len(inputs)):
- # for demo service(environ is 'eas'), only show the first result
+            # for a single input, re-format the output
# audio_in:
# list/tuple: return speaker verification scores
# single wav/bytes: return speaker embedding
- if 'MODELSCOPE_ENVIRONMENT' in os.environ and \
- os.environ['MODELSCOPE_ENVIRONMENT'] == 'eas':
+ if len(inputs) == 1 and i == 0:
if isinstance(self.audio_in, tuple) or isinstance(
self.audio_in, list):
score = inputs[0]['value']
@@ -103,7 +103,7 @@ class SpeakerVerificationPipeline(Pipeline):
embedding = inputs[0]['value']
rst[OutputKeys.SPK_EMBEDDING] = embedding
else:
- # for notebook/local jobs, copy results
+ # for multiple inputs
rst[inputs[i]['key']] = inputs[i]['value']
return rst
@@ -146,9 +146,25 @@ class SpeakerVerificationPipeline(Pipeline):
'param_dict',
]
+        # rewrite the config with the model's configuration.json
+ for user_args in user_args_dict:
+ if (user_args in self.model_cfg['model_config']
+ and self.model_cfg['model_config'][user_args] is not None):
+ if isinstance(cmd[user_args], dict) and isinstance(
+ self.model_cfg['model_config'][user_args], dict):
+ cmd[user_args].update(
+ self.model_cfg['model_config'][user_args])
+ else:
+ cmd[user_args] = self.model_cfg['model_config'][user_args]
+
+ # rewrite the config with user args
for user_args in user_args_dict:
if user_args in extra_args and extra_args[user_args] is not None:
- cmd[user_args] = extra_args[user_args]
+ if isinstance(cmd[user_args], dict) and isinstance(
+ extra_args[user_args], dict):
+ cmd[user_args].update(extra_args[user_args])
+ else:
+ cmd[user_args] = extra_args[user_args]
return cmd
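# A minimal sketch of the two-pass override order introduced above: the defaults in
# cmd are first overlaid with the model configuration and then with the caller's
# extra args, with nested dicts merged rather than replaced. Values are illustrative.
cmd = {'param_dict': {'nbest': 1}}
model_config = {'param_dict': {'nbest': 5, 'beam_size': 10}}
extra_args = {'param_dict': {'beam_size': 20}}

for source in (model_config, extra_args):
    for key, value in source.items():
        if isinstance(cmd.get(key), dict) and isinstance(value, dict):
            cmd[key].update(value)
        else:
            cmd[key] = value
print(cmd)  # {'param_dict': {'nbest': 5, 'beam_size': 20}}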
diff --git a/modelscope/pipelines/audio/timestamp_pipeline.py b/modelscope/pipelines/audio/timestamp_pipeline.py
new file mode 100644
index 00000000..63471b1c
--- /dev/null
+++ b/modelscope/pipelines/audio/timestamp_pipeline.py
@@ -0,0 +1,307 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from typing import Any, Dict, List, Sequence, Tuple, Union
+
+import json
+import yaml
+from funasr.utils import asr_utils
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.audio.audio_utils import (generate_scp_from_url,
+                                                generate_text_from_url)
+from modelscope.utils.constant import Frameworks, ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+__all__ = ['TimestampPipeline']
+
+
+@PIPELINES.register_module(
+ Tasks.speech_timestamp, module_name=Pipelines.speech_timestamp_inference)
+class TimestampPipeline(Pipeline):
+ """Timestamp Inference Pipeline
+ Example:
+
+ >>> from modelscope.pipelines import pipeline
+ >>> from modelscope.utils.constant import Tasks
+
+ >>> pipeline_infer = pipeline(
+ >>> task=Tasks.speech_timestamp,
+ >>> model='damo/speech_timestamp_predictor-v1-16k-offline')
+
+ >>> audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav'
+ >>> text_in='一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢'
+ >>> print(pipeline_infer(audio_in, text_in))
+
+ """
+
+ def __init__(self, model: Union[Model, str] = None, **kwargs):
+ """
+        Use `model` and `preprocessor` to create a timestamp prediction pipeline
+ Args:
+ model ('Model' or 'str'):
+ The pipeline handles three types of model:
+
+ - A model instance
+ - A model local dir
+ - A model id in the model hub
+ output_dir('str'):
+ output dir path
+ batch_size('int'):
+ the batch size for inference
+ ngpu('int'):
+ the number of gpus, 0 indicates CPU mode
+ split_with_space('bool'):
+ split the input sentence by space
+ seg_dict_file('str'):
+ seg dict file
+ param_dict('dict'):
+ extra kwargs
+ """
+ super().__init__(model=model, **kwargs)
+ config_path = os.path.join(model, ModelFile.CONFIGURATION)
+ self.cmd = self.get_cmd(config_path, kwargs)
+
+ from funasr.bin import tp_inference_launch
+ self.funasr_infer_modelscope = tp_inference_launch.inference_launch(
+ mode=self.cmd['mode'],
+ batch_size=self.cmd['batch_size'],
+ dtype=self.cmd['dtype'],
+ ngpu=self.cmd['ngpu'],
+ seed=self.cmd['seed'],
+ num_workers=self.cmd['num_workers'],
+ log_level=self.cmd['log_level'],
+ key_file=self.cmd['key_file'],
+ timestamp_infer_config=self.cmd['timestamp_infer_config'],
+ timestamp_model_file=self.cmd['timestamp_model_file'],
+ timestamp_cmvn_file=self.cmd['timestamp_cmvn_file'],
+ output_dir=self.cmd['output_dir'],
+ allow_variable_data_keys=self.cmd['allow_variable_data_keys'],
+ split_with_space=self.cmd['split_with_space'],
+ seg_dict_file=self.cmd['seg_dict_file'],
+ param_dict=self.cmd['param_dict'])
+
+ def __call__(self,
+ audio_in: Union[str, bytes],
+ text_in: str = None,
+ audio_fs: int = None,
+ recog_type: str = None,
+ audio_format: str = None,
+ output_dir: str = None,
+ param_dict: dict = None,
+ **kwargs) -> Dict[str, Any]:
+ """
+ Decode the input audio and text.
+ Args:
+ audio_in('str' or 'bytes'):
+ - A string containing a local path to a wav file
+ - A string containing a local path to a scp
+ - A string containing a wav url
+ text_in('str'):
+ - A text str input
+ - A local text file input endswith .txt or .scp
+ audio_fs('int'):
+ sampling rate of the input audio
+ recog_type('str'):
+ recog type for wav file or datasets file ('wav', 'test', 'dev', 'train')
+ audio_format('str'):
+ audio format ('pcm', 'scp', 'kaldi_ark', 'tfrecord')
+ output_dir('str'):
+ output dir
+ param_dict('dict'):
+ extra kwargs
+ Return:
+ A dictionary of results, or a list of result dictionaries.
+
+ The dictionary contains the following keys:
+ - **text** ('str') -- The timestamp prediction result.
+ """
+ self.audio_in = None
+ self.text_in = None
+ self.raw_inputs = None
+ self.recog_type = recog_type
+ self.audio_format = audio_format
+ self.audio_fs = None
+ checking_audio_fs = None
+ if output_dir is not None:
+ self.cmd['output_dir'] = output_dir
+ if param_dict is not None:
+ self.cmd['param_dict'] = param_dict
+
+ # audio
+ if isinstance(audio_in, str):
+ # for funasr code, generate wav.scp from url or local path
+ self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in)
+ elif isinstance(audio_in, bytes):
+ self.audio_in = audio_in
+ self.raw_inputs = None
+ else:
+ import numpy
+ import torch
+ if isinstance(audio_in, torch.Tensor):
+ self.audio_in = None
+ self.raw_inputs = audio_in
+ elif isinstance(audio_in, numpy.ndarray):
+ self.audio_in = None
+ self.raw_inputs = audio_in
+ # text
+ if text_in.startswith('http'):
+ self.text_in, _ = generate_text_from_url(text_in)
+ else:
+ self.text_in = text_in
+
+ # set the sample_rate of audio_in if checking_audio_fs is valid
+ if checking_audio_fs is not None:
+ self.audio_fs = checking_audio_fs
+
+ if recog_type is None or audio_format is None:
+ self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
+ audio_in=self.audio_in,
+ recog_type=recog_type,
+ audio_format=audio_format)
+
+ if hasattr(asr_utils,
+ 'sample_rate_checking') and self.audio_in is not None:
+ checking_audio_fs = asr_utils.sample_rate_checking(
+ self.audio_in, self.audio_format)
+ if checking_audio_fs is not None:
+ self.audio_fs = checking_audio_fs
+ if audio_fs is not None:
+ self.cmd['fs']['audio_fs'] = audio_fs
+ else:
+ self.cmd['fs']['audio_fs'] = self.audio_fs
+
+ output = self.forward(self.audio_in, self.text_in, **kwargs)
+ result = self.postprocess(output)
+ return result
+
+ def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ """Postprocessing
+ """
+ rst = {}
+ for i in range(len(inputs)):
+ if i == 0:
+ for key, value in inputs[0].items():
+ if key == 'value':
+ if len(value) > 0:
+ rst[OutputKeys.TEXT] = value
+ elif key != 'key':
+ rst[key] = value
+ else:
+ rst[inputs[i]['key']] = inputs[i]['value']
+ return rst
+
+ def get_cmd(self, config_path, extra_args) -> Dict[str, Any]:
+ model_cfg = json.loads(open(config_path).read())
+ model_dir = os.path.dirname(config_path)
+ # generate inference command
+ timestamp_model_file = os.path.join(
+ model_dir,
+ model_cfg['model']['model_config']['timestamp_model_file'])
+ timestamp_infer_config = os.path.join(
+ model_dir,
+ model_cfg['model']['model_config']['timestamp_infer_config'])
+ timestamp_cmvn_file = os.path.join(
+ model_dir,
+ model_cfg['model']['model_config']['timestamp_cmvn_file'])
+ mode = model_cfg['model']['model_config']['mode']
+ frontend_conf = None
+ if os.path.exists(timestamp_infer_config):
+ config_file = open(timestamp_infer_config, encoding='utf-8')
+ root = yaml.full_load(config_file)
+ config_file.close()
+ if 'frontend_conf' in root:
+ frontend_conf = root['frontend_conf']
+ seg_dict_file = None
+ if 'seg_dict_file' in model_cfg['model']['model_config']:
+ seg_dict_file = os.path.join(
+ model_dir, model_cfg['model']['model_config']['seg_dict_file'])
+
+ cmd = {
+ 'mode': mode,
+ 'batch_size': 1,
+ 'dtype': 'float32',
+ 'ngpu': 0, # 0: only CPU, ngpu>=1: gpu number if cuda is available
+ 'seed': 0,
+ 'num_workers': 0,
+ 'log_level': 'ERROR',
+ 'key_file': None,
+ 'allow_variable_data_keys': False,
+ 'split_with_space': True,
+ 'seg_dict_file': seg_dict_file,
+ 'timestamp_infer_config': timestamp_infer_config,
+ 'timestamp_model_file': timestamp_model_file,
+ 'timestamp_cmvn_file': timestamp_cmvn_file,
+ 'output_dir': None,
+ 'param_dict': None,
+ 'fs': {
+ 'model_fs': None,
+ 'audio_fs': None
+ }
+ }
+ if frontend_conf is not None and 'fs' in frontend_conf:
+ cmd['fs']['model_fs'] = frontend_conf['fs']
+
+ user_args_dict = [
+ 'output_dir',
+ 'batch_size',
+ 'mode',
+ 'ngpu',
+ 'param_dict',
+ 'num_workers',
+ 'log_level',
+ 'split_with_space',
+ 'seg_dict_file',
+ ]
+
+ for user_args in user_args_dict:
+ if user_args in extra_args and extra_args[user_args] is not None:
+ cmd[user_args] = extra_args[user_args]
+
+ return cmd
+
+ def forward(self, audio_in: Dict[str, Any], text_in: Dict[str, Any],
+ **kwargs) -> Dict[str, Any]:
+ """Decoding
+ """
+ logger.info('Timestamp Processing ...')
+ # generate inputs
+ data_cmd: Sequence[Tuple[str, str, str]] = None
+ if isinstance(self.audio_in, bytes):
+ data_cmd = [(self.audio_in, 'speech', 'bytes')]
+ data_cmd.append((text_in, 'text', 'text'))
+ elif isinstance(self.audio_in, str):
+ data_cmd = [(self.audio_in, 'speech', 'sound')]
+ data_cmd.append((text_in, 'text', 'text'))
+ elif self.raw_inputs is not None:
+ data_cmd = None
+
+ if self.raw_inputs is None and data_cmd is None:
+ raise ValueError('please check audio_in')
+
+ self.cmd['name_and_type'] = data_cmd
+ self.cmd['raw_inputs'] = self.raw_inputs
+ self.cmd['audio_in'] = self.audio_in
+
+ tp_result = self.run_inference(self.cmd, **kwargs)
+
+ return tp_result
+
+ def run_inference(self, cmd, **kwargs):
+ tp_result = []
+ if self.framework == Frameworks.torch:
+ tp_result = self.funasr_infer_modelscope(
+ data_path_and_name_and_type=cmd['name_and_type'],
+ raw_inputs=cmd['raw_inputs'],
+ output_dir_v2=cmd['output_dir'],
+ fs=cmd['fs'],
+ param_dict=cmd['param_dict'],
+ **kwargs)
+ else:
+ raise ValueError('framework is mismatched, only torch is supported')
+
+ return tp_result
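
For reference, `postprocess` above flattens the list of per-utterance dicts returned by funasr: the first utterance's value is exposed under `OutputKeys.TEXT`, and any further utterances are keyed by their own ids. A hedged standalone sketch of that flattening (the sample payload is made up for illustration):

```python
# Illustrative only: the list-of-dicts shape consumed by postprocess above.
funasr_style_output = [
    {'key': 'utt1', 'value': 'timestamp result for utt1'},
    {'key': 'utt2', 'value': 'timestamp result for utt2'},
]

rst = {}
for i, item in enumerate(funasr_style_output):
    if i == 0 and len(item['value']) > 0:
        rst['text'] = item['value']        # OutputKeys.TEXT for the first utterance
    elif i > 0:
        rst[item['key']] = item['value']   # further utterances keyed by id
print(rst)
```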
diff --git a/modelscope/pipelines/audio/voice_activity_detection_pipeline.py b/modelscope/pipelines/audio/voice_activity_detection_pipeline.py
index d80591f3..da46dd3e 100644
--- a/modelscope/pipelines/audio/voice_activity_detection_pipeline.py
+++ b/modelscope/pipelines/audio/voice_activity_detection_pipeline.py
@@ -67,7 +67,8 @@ class VoiceActivityDetectionPipeline(Pipeline):
recog_type: str = None,
audio_format: str = None,
output_dir: str = None,
- param_dict: dict = None) -> Dict[str, Any]:
+ param_dict: dict = None,
+ **kwargs) -> Dict[str, Any]:
"""
Decoding the input audios
Args:
@@ -92,15 +93,16 @@ class VoiceActivityDetectionPipeline(Pipeline):
The dictionary contain the following keys:
- **text** ('str') --The vad result.
"""
+ self.audio_in = None
+ self.raw_inputs = None
self.recog_type = recog_type
self.audio_format = audio_format
- self.audio_fs = audio_fs
+ self.audio_fs = None
checking_audio_fs = None
- self.raw_inputs = None
if output_dir is not None:
self.cmd['output_dir'] = output_dir
- if audio_fs is not None:
- self.cmd['fs']['audio_fs'] = audio_fs
+ if param_dict is not None:
+ self.cmd['param_dict'] = param_dict
if isinstance(audio_in, str):
# for funasr code, generate wav.scp from url or local path
self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in)
@@ -116,10 +118,6 @@ class VoiceActivityDetectionPipeline(Pipeline):
elif isinstance(audio_in, numpy.ndarray):
self.audio_in = None
self.raw_inputs = audio_in
- if output_dir is not None:
- self.cmd['output_dir'] = output_dir
- if param_dict is not None:
- self.cmd['param_dict'] = param_dict
# set the sample_rate of audio_in if checking_audio_fs is valid
if checking_audio_fs is not None:
@@ -137,7 +135,12 @@ class VoiceActivityDetectionPipeline(Pipeline):
self.audio_in, self.audio_format)
if checking_audio_fs is not None:
self.audio_fs = checking_audio_fs
- output = self.forward(self.audio_in)
+ if audio_fs is not None:
+ self.cmd['fs']['audio_fs'] = audio_fs
+ else:
+ self.cmd['fs']['audio_fs'] = self.audio_fs
+
+ output = self.forward(self.audio_in, **kwargs)
result = self.postprocess(output)
return result
@@ -205,7 +208,7 @@ class VoiceActivityDetectionPipeline(Pipeline):
return cmd
- def forward(self, audio_in: Dict[str, Any]) -> Dict[str, Any]:
+ def forward(self, audio_in: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Decoding
"""
logger.info('VAD Processing ...')
@@ -221,11 +224,11 @@ class VoiceActivityDetectionPipeline(Pipeline):
self.cmd['raw_inputs'] = self.raw_inputs
self.cmd['audio_in'] = self.audio_in
- vad_result = self.run_inference(self.cmd)
+ vad_result = self.run_inference(self.cmd, **kwargs)
return vad_result
- def run_inference(self, cmd):
+ def run_inference(self, cmd, **kwargs):
vad_result = []
if self.framework == Frameworks.torch:
vad_result = self.funasr_infer_modelscope(
@@ -233,7 +236,8 @@ class VoiceActivityDetectionPipeline(Pipeline):
raw_inputs=cmd['raw_inputs'],
output_dir_v2=cmd['output_dir'],
fs=cmd['fs'],
- param_dict=cmd['param_dict'])
+ param_dict=cmd['param_dict'],
+ **kwargs)
else:
raise ValueError('model type is mismatching')
diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py
index a3c15695..5479fe59 100644
--- a/modelscope/pipelines/base.py
+++ b/modelscope/pipelines/base.py
@@ -2,6 +2,7 @@
import os
import os.path as osp
+import random
from abc import ABC, abstractmethod
from functools import partial
from multiprocessing import Pool
@@ -9,6 +10,7 @@ from threading import Lock
from typing import Any, Dict, Generator, List, Mapping, Union
import numpy as np
+from packaging import version
from modelscope.models.base import Model
from modelscope.msdatasets import MsDataset
@@ -22,6 +24,7 @@ from modelscope.utils.device import (create_device, device_placement,
from modelscope.utils.hub import read_config, snapshot_download
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from modelscope.utils.logger import get_logger
+from modelscope.utils.torch_utils import compile_model
from .util import is_model, is_official_hub_path
if is_torch_available():
@@ -80,6 +83,9 @@ class Pipeline(ABC):
preprocessor: (list of) Preprocessor object
device (str): device str, should be either cpu, cuda, gpu, gpu:X or cuda:X
auto_collate (bool): automatically to convert data to tensor or not.
+ compile (bool, optional): Compile the model with torch 2.0, default False
+ compile_options (dict, optional): The compile options if compile=True,
+ default empty dict, which uses the default params of 'TorchModel.compile'.
"""
verify_device(device)
self.device_name = device
@@ -118,6 +124,8 @@ class Pipeline(ABC):
self._model_prepare = False
self._model_prepare_lock = Lock()
self._auto_collate = auto_collate
+ self._compile = kwargs.get('compile', False)
+ self._compile_options = kwargs.get('compile_options', {})
def prepare_model(self):
""" Place model on certain device for pytorch models before first inference
@@ -139,8 +147,16 @@ class Pipeline(ABC):
if self.has_multiple_models:
for m in self.models:
_prepare_single(m)
+ if self._compile:
+ self.models = [
+ compile_model(m, **self._compile_options)
+ for m in self.models
+ ]
else:
_prepare_single(self.model)
+ if self._compile:
+ self.model = compile_model(self.model,
+ **self._compile_options)
self._model_prepare = True
self._model_prepare_lock.release()
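
Assuming the `pipeline()` factory forwards extra kwargs to the `Pipeline` constructor as usual, the new compile path could be exercised roughly like this (the model id is a placeholder and `compile=True` requires torch >= 2.0):

```python
# Hedged usage sketch for the compile/compile_options kwargs added above.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

p = pipeline(
    Tasks.image_classification,                   # any torch-backed task
    model='damo/some-torch-model',                # placeholder model id
    compile=True,                                 # wrap models via compile_model()
    compile_options={'mode': 'reduce-overhead'})  # assumed to reach torch.compile
```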
@@ -421,15 +437,20 @@ class DistributedPipeline(Pipeline):
ranks = list(range(self.world_size))
self.model_pool = Pool(self.world_size)
- master_ip = '127.0.0.1' if 'master_ip' not in kwargs else kwargs[
- 'master_ip']
- os.environ['MASTER_ADDR'] = master_ip
- master_port = '29500' if 'master_port' not in kwargs else kwargs[
- 'master_port']
+
+ if 'master_ip' not in kwargs:
+ kwargs['master_ip'] = '127.0.0.1'
+ master_port = int(kwargs['master_port']
+ ) if 'master_port' in kwargs else random.randint(
+ 29500, 39500)
from modelscope.utils.torch_utils import _find_free_port, _is_free_port
- if not _is_free_port(int(master_port)):
- master_port = str(_find_free_port())
- os.environ['MASTER_PORT'] = master_port
+ if not _is_free_port(master_port):
+ master_port = _find_free_port()
+ kwargs['master_port'] = str(master_port)
+ # TODO: Pass ip and port to megatron_util for initialization
+ os.environ['MASTER_ADDR'] = kwargs['master_ip']
+ os.environ['MASTER_PORT'] = kwargs['master_port']
+
self.model_pool.map(
partial(
self.__class__._instantiate_one,
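
The port handling above now writes the resolved address and port back into `kwargs` so they can later be passed to megatron_util, falling back to a random port in [29500, 39500] and then to a free one if that is taken. In isolation the selection logic amounts to something like this standard-library sketch (the real code uses `_is_free_port`/`_find_free_port` from `modelscope.utils.torch_utils`):

```python
import random
import socket

def pick_master_port(requested=None, low=29500, high=39500):
    """Standalone restatement of the master_port selection above."""
    port = requested if requested is not None else random.randint(low, high)
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        if s.connect_ex(('127.0.0.1', port)) != 0:  # nothing listening -> treat as free
            return port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))                             # let the OS pick a free port
        return s.getsockname()[1]

print(pick_master_port())
```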
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 4987a3e0..dd39453c 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -9,6 +9,8 @@ from modelscope.models.base import Model
from modelscope.utils.config import ConfigDict, check_config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
from modelscope.utils.hub import read_config
+from modelscope.utils.plugins import (register_modelhub_repo,
+ register_plugins_repo)
from modelscope.utils.registry import Registry, build_from_cfg
from .base import Pipeline
from .util import is_official_hub_path
@@ -63,7 +65,6 @@ def pipeline(task: str = None,
framework: str = None,
device: str = 'gpu',
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
- plugins: List[str] = None,
**kwargs) -> Pipeline:
""" Factory method to build an obj:`Pipeline`.
@@ -96,8 +97,6 @@ def pipeline(task: str = None,
if task is None and pipeline_name is None:
raise ValueError('task or pipeline_name is required')
- try_import_plugins(plugins)
-
model = normalize_model_input(model, model_revision)
pipeline_props = {'type': pipeline_name}
if pipeline_name is None:
@@ -111,7 +110,8 @@ def pipeline(task: str = None,
model, str) else read_config(
model[0], revision=model_revision)
check_config(cfg)
- try_import_plugins(cfg.safe_get('plugins'))
+ register_plugins_repo(cfg.safe_get('plugins'))
+ register_modelhub_repo(model, cfg.get('allow_remote', False))
pipeline_props = cfg.pipeline
elif model is not None:
# get pipeline info from Model object
@@ -120,7 +120,6 @@ def pipeline(task: str = None,
# model is instantiated by user, we should parse config again
cfg = read_config(first_model.model_dir)
check_config(cfg)
- try_import_plugins(cfg.safe_get('plugins'))
first_model.pipeline = cfg.pipeline
pipeline_props = first_model.pipeline
else:
@@ -178,10 +177,3 @@ def get_default_pipeline_info(task):
else:
pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
return pipeline_name, default_model
-
-
-def try_import_plugins(plugins: List[str]) -> None:
- """ Try to import plugins """
- if plugins is not None:
- from modelscope.utils.plugins import import_plugins
- import_plugins(plugins)
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index 025f088b..f1c027a0 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -52,6 +52,7 @@ if TYPE_CHECKING:
from .ocr_recognition_pipeline import OCRRecognitionPipeline
from .license_plate_detection_pipeline import LicensePlateDetectionPipeline
from .table_recognition_pipeline import TableRecognitionPipeline
+ from .lineless_table_recognition_pipeline import LinelessTableRecognitionPipeline
from .skin_retouching_pipeline import SkinRetouchingPipeline
from .face_reconstruction_pipeline import FaceReconstructionPipeline
from .tinynas_classification_pipeline import TinynasClassificationPipeline
@@ -80,10 +81,12 @@ if TYPE_CHECKING:
from .vision_efficient_tuning_prefix_pipeline import VisionEfficientTuningPrefixPipeline
from .vision_efficient_tuning_lora_pipeline import VisionEfficientTuningLoRAPipeline
from .vision_middleware_pipeline import VisionMiddlewarePipeline
+ from .vidt_pipeline import VidtPipeline
from .video_frame_interpolation_pipeline import VideoFrameInterpolationPipeline
from .image_skychange_pipeline import ImageSkychangePipeline
from .image_driving_perception_pipeline import ImageDrivingPerceptionPipeline
from .vop_retrieval_pipeline import VopRetrievalPipeline
+ from .vop_retrieval_se_pipeline import VopRetrievalSEPipeline
from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline
from .video_deinterlace_pipeline import VideoDeinterlacePipeline
from .image_matching_pipeline import ImageMatchingPipeline
@@ -104,11 +107,13 @@ if TYPE_CHECKING:
from .ddpm_semantic_segmentation_pipeline import DDPMImageSemanticSegmentationPipeline
from .image_inpainting_sdv2_pipeline import ImageInpaintingSDV2Pipeline
from .image_quality_assessment_mos_pipeline import ImageQualityAssessmentMosPipeline
+ from .image_quality_assessment_man_pipeline import ImageQualityAssessmentMANPipeline
from .bad_image_detecting_pipeline import BadImageDetecingPipeline
from .mobile_image_super_resolution_pipeline import MobileImageSuperResolutionPipeline
from .image_human_parsing_pipeline import ImageHumanParsingPipeline
from .nerf_recon_acc_pipeline import NeRFReconAccPipeline
from .controllable_image_generation_pipeline import ControllableImageGenerationPipeline
+ from .image_bts_depth_estimation_pipeline import ImageBTSDepthEstimationPipeline
else:
_import_structure = {
@@ -215,6 +220,7 @@ else:
'VisionEfficientTuningLoRAPipeline'
],
'vision_middleware_pipeline': ['VisionMiddlewarePipeline'],
+ 'vidt_pipeline': ['VidtPipeline'],
'video_frame_interpolation_pipeline': [
'VideoFrameInterpolationPipeline'
],
@@ -223,6 +229,7 @@ else:
'ImageDrivingPerceptionPipeline'
],
'vop_retrieval_pipeline': ['VopRetrievalPipeline'],
+ 'vop_retrieval_se_pipeline': ['VopRetrievalSEPipeline'],
'video_object_segmentation_pipeline': [
'VideoObjectSegmentationPipeline'
],
@@ -259,6 +266,9 @@ else:
'image_quality_assessment_mos_pipeline': [
'ImageQualityAssessmentMosPipeline'
],
+ 'image_quality_assessment_man_pipeline': [
+ 'ImageQualityAssessmentMANPipeline'
+ ],
'mobile_image_super_resolution_pipeline': [
'MobileImageSuperResolutionPipeline'
],
@@ -268,6 +278,9 @@ else:
'controllable_image_generation_pipeline': [
'ControllableImageGenerationPipeline'
],
+ 'image_bts_depth_estimation_pipeline': [
+ 'ImageBTSDepthEstimationPipeline'
+ ]
}
import sys
diff --git a/modelscope/pipelines/cv/human_reconstruction_pipeline.py b/modelscope/pipelines/cv/human_reconstruction_pipeline.py
new file mode 100644
index 00000000..87186f73
--- /dev/null
+++ b/modelscope/pipelines/cv/human_reconstruction_pipeline.py
@@ -0,0 +1,109 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+from typing import Any, Dict
+
+import numpy as np
+import torch
+import trimesh
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.human_reconstruction.utils import (
+ keep_largest, reconstruction, save_obj_mesh, save_obj_mesh_with_color,
+ to_tensor)
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Input, Model, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+ Tasks.human_reconstruction, module_name=Pipelines.human_reconstruction)
+class HumanReconstructionPipeline(Pipeline):
+
+ def __init__(self, model: str, **kwargs):
+ """The inference pipeline for human reconstruction task.
+ Human Reconstruction Pipeline. Given one image generate a human mesh.
+
+ Args:
+ model (`str` or `Model` or module instance): A model instance or a model local dir
+ or a model id in the model hub.
+
+ Example:
+ >>> from modelscope.pipelines import pipeline
+ >>> test_input = 'human_reconstruction.jpg' # input image path
+ >>> pipeline_humanRecon = pipeline('human-reconstruction',
+ model='damo/cv_hrnet_image-human-reconstruction')
+ >>> result = pipeline_humanRecon(test_input)
+ >>> output = result[OutputKeys.OUTPUT]
+ """
+ super().__init__(model=model, **kwargs)
+ if not isinstance(self.model, Model):
+ logger.error('model object is not initialized.')
+ raise Exception('model object is not initialized.')
+
+ def preprocess(self, input: Input) -> Dict[str, Any]:
+ img_crop = self.model.crop_img(input)
+ img, mask = self.model.get_mask(img_crop)
+ normal_f, normal_b = self.model.generation_normal(img, mask)
+ image = to_tensor(img_crop) * 2 - 1
+ normal_b = to_tensor(normal_b) * 2 - 1
+ normal_f = to_tensor(normal_f) * 2 - 1
+ mask = to_tensor(mask)
+ result = {
+ 'img': image,
+ 'mask': mask,
+ 'normal_F': normal_f,
+ 'normal_B': normal_b
+ }
+ return result
+
+ def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ image = input['img']
+ mask = input['mask']
+ normF = input['normal_F']
+ normB = input['normal_B']
+ normF[1, ...] = -normF[1, ...]
+ normB[0, ...] = -normB[0, ...]
+ img = image * mask
+ normal_b = normB * mask
+ normal_f = normF * mask
+ img = torch.cat([img, normal_f, normal_b], dim=0).float()
+ image_tensor = img.unsqueeze(0).to(self.model.device)
+ calib_tensor = self.model.calib
+ net = self.model.meshmodel
+ net.extract_features(image_tensor)
+ verts, faces = reconstruction(net, calib_tensor, self.model.coords,
+ self.model.mat)
+ pre_mesh = trimesh.Trimesh(
+ verts, faces, process=False, maintain_order=True)
+ final_mesh = keep_largest(pre_mesh)
+ verts = final_mesh.vertices
+ faces = final_mesh.faces
+ verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to(
+ self.model.device).float()
+ color = torch.zeros(verts.shape)
+ interval = 20000
+ # query vertex colors in chunks of `interval` vertices
+ for left in range(0, len(color), interval):
+ right = min(left + interval, len(color))
+ pred_color = net.query_rgb(verts_tensor[:, :, left:right],
+ calib_tensor)
+ rgb = pred_color[0].detach().cpu() * 0.5 + 0.5
+ color[left:right] = rgb.T
+ vert_min = np.min(verts[:, 1])
+ verts[:, 1] = verts[:, 1] - vert_min
+ save_obj_mesh('human_reconstruction.obj', verts, faces)
+ save_obj_mesh_with_color('human_color.obj', verts, faces,
+ color.numpy())
+ results = {'vertices': verts, 'faces': faces, 'colors': color.numpy()}
+ return {OutputKeys.OUTPUT: results}
+
+ def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ return inputs
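
Besides returning vertices/faces/colors under `OutputKeys.OUTPUT`, the forward pass above also writes `human_reconstruction.obj` and `human_color.obj` into the working directory. A hedged usage sketch (model id and test image follow the docstring; the trimesh reload is illustrative):

```python
# Hedged usage sketch for HumanReconstructionPipeline above.
import trimesh
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline

recon = pipeline('human-reconstruction',
                 model='damo/cv_hrnet_image-human-reconstruction')
result = recon('data/test/images/human_reconstruction.jpg')

mesh_data = result[OutputKeys.OUTPUT]
print(mesh_data['vertices'].shape, mesh_data['faces'].shape)

# the colored mesh written by forward() can be reloaded for downstream use
mesh = trimesh.load('human_color.obj', process=False)
print(len(mesh.vertices))
```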
diff --git a/modelscope/pipelines/cv/image_bts_depth_estimation_pipeline.py b/modelscope/pipelines/cv/image_bts_depth_estimation_pipeline.py
new file mode 100644
index 00000000..f635f566
--- /dev/null
+++ b/modelscope/pipelines/cv/image_bts_depth_estimation_pipeline.py
@@ -0,0 +1,86 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict
+
+import albumentations as A
+import cv2
+import numpy as np
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.constant import Tasks
+from modelscope.utils.cv.image_utils import depth_to_color
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+ Tasks.image_depth_estimation,
+ module_name=Pipelines.image_bts_depth_estimation)
+class ImageBTSDepthEstimationPipeline(Pipeline):
+ r""" Image depth estimation pipeline of BTS model.
+
+ Examples:
+
+ >>> import cv2
+ >>> from modelscope.outputs import OutputKeys
+ >>> from modelscope.pipelines import pipeline
+ >>> from modelscope.utils.constant import Tasks
+
+ >>> estimator = pipeline(Tasks.image_depth_estimation, 'damo/cv_densenet161_image-depth-estimation_bts')
+ >>> result = estimator(
+ "https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_depth_estimation_kitti_007517.png")
+ >>> cv2.imwrite('result_depth_color.jpg', result[OutputKeys.DEPTHS_COLOR])
+ >>> cv2.imwrite('result_depth.jpg', result[OutputKeys.DEPTHS])
+ >>>
+ """
+
+ def __init__(self, model: str, **kwargs):
+ """
+ use `model` to create an image depth estimation pipeline for prediction
+ Args:
+ model: model id on modelscope hub.
+ """
+ super().__init__(model=model, **kwargs)
+ self.transform = A.Compose([A.Normalize(always_apply=True)])
+
+ logger.info('BTS depth estimation model, pipeline init')
+
+ def preprocess(self, input: Input) -> Dict[str, Any]:
+ img = LoadImage.convert_to_ndarray(input)
+
+ h, w, _ = img.shape
+ top, left = int(h - 352), int((w - 1216) / 2)
+ img = img[top:top + 352, left:left + 1216]
+
+ img = self.transform(image=img)['image']
+ img = torch.tensor(img).float().transpose(0, 2).transpose(1, 2)
+
+ imgs = img[None, ...]
+ data = {'imgs': imgs}
+
+ return data
+
+ def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ results = self.model.inference(input)
+ return results
+
+ def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ results = self.model.postprocess(inputs)
+ depths = results[OutputKeys.DEPTHS].detach().cpu()
+ depths = np.asarray(
+ np.squeeze(
+ (255 - torch.clamp_max(depths * 4, 250)).byte().numpy()),
+ np.uint8)
+ depths_color = depth_to_color(depths)
+
+ outputs = {
+ OutputKeys.DEPTHS: depths,
+ OutputKeys.DEPTHS_COLOR: depths_color
+ }
+
+ return outputs
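
The postprocess step above turns metric depth into an 8-bit visualization by clamping `depth * 4` at 250 and inverting, before mapping it to a color image. That mapping in isolation (random values are used only to show shapes and dtypes):

```python
import numpy as np
import torch

# Standalone restatement of the depth -> uint8 mapping in postprocess above.
depths = torch.rand(1, 352, 1216) * 80.0   # made-up metric depths
vis = np.asarray(
    np.squeeze((255 - torch.clamp_max(depths * 4, 250)).byte().numpy()),
    np.uint8)
print(vis.shape, vis.dtype)                # (352, 1216) uint8
```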
diff --git a/modelscope/pipelines/cv/image_quality_assessment_man_pipeline.py b/modelscope/pipelines/cv/image_quality_assessment_man_pipeline.py
new file mode 100644
index 00000000..8e82e615
--- /dev/null
+++ b/modelscope/pipelines/cv/image_quality_assessment_man_pipeline.py
@@ -0,0 +1,80 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import math
+import tempfile
+from typing import Any, Dict, Optional, Union
+
+import cv2
+import numpy as np
+import torch
+from torchvision import transforms
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.image_quality_assessment_man import \
+ ImageQualityAssessmentMAN
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.preprocessors.cv import ImageQualityAssessmentMANPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+ Tasks.image_quality_assessment_mos,
+ module_name=Pipelines.image_quality_assessment_man)
+class ImageQualityAssessmentMANPipeline(Pipeline):
+ """ Image Quality Assessment MAN Pipeline which will use Multi-dimension Attention Network
+ to return Mean Opinion Score (MOS) for the input image.
+
+ Example:
+
+ ```python
+ >>> from modelscope.pipelines import pipeline
+ >>> from modelscope.outputs import OutputKeys
+ >>> from modelscope.utils.constant import Tasks
+
+ >>> test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/dogs.jpg'
+ >>> assessment_predictor = pipeline(Tasks.image_quality_assessment_mos, \
+ model='damo/cv_man_image-quality-assessment')
+ >>> out_mos = assessment_predictor(test_image)[OutputKeys.SCORE]
+ >>> print('Pipeline: the output mos is {}'.format(out_mos))
+
+ ```
+ """
+
+ def __init__(self,
+ model: Union[ImageQualityAssessmentMAN, str],
+ preprocessor=ImageQualityAssessmentMANPreprocessor(),
+ **kwargs):
+ """
+ use `model` to create an image quality assessment (MAN) pipeline for prediction
+ Args:
+ model: model id on modelscope hub or `ImageQualityAssessmentMAN` Model.
+ preprocessor: preprocessor for input image
+
+ """
+ super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+ if torch.cuda.is_available():
+ self._device = torch.device('cuda')
+ else:
+ self._device = torch.device('cpu')
+
+ logger.info('load MANIQA model done')
+
+ @torch.no_grad()
+ def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ inference for image quality assessment prediction
+ Args:
+ input: dict including torch tensor.
+
+ """
+ outputs = self.model.forward({'input': input['input']})['output'].cpu()
+ return {OutputKeys.SCORE: outputs.item()}
+
+ def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ return inputs
diff --git a/modelscope/pipelines/cv/image_salient_detection_pipeline.py b/modelscope/pipelines/cv/image_salient_detection_pipeline.py
index 4a3eaa65..4b4df52c 100644
--- a/modelscope/pipelines/cv/image_salient_detection_pipeline.py
+++ b/modelscope/pipelines/cv/image_salient_detection_pipeline.py
@@ -12,6 +12,11 @@ from modelscope.utils.constant import Tasks
@PIPELINES.register_module(
Tasks.semantic_segmentation, module_name=Pipelines.salient_detection)
+@PIPELINES.register_module(
+ Tasks.semantic_segmentation,
+ module_name=Pipelines.salient_boudary_detection)
+@PIPELINES.register_module(
+ Tasks.semantic_segmentation, module_name=Pipelines.camouflaged_detection)
class ImageSalientDetectionPipeline(Pipeline):
def __init__(self, model: str, **kwargs):
diff --git a/modelscope/pipelines/cv/lineless_table_recognition_pipeline.py b/modelscope/pipelines/cv/lineless_table_recognition_pipeline.py
new file mode 100644
index 00000000..f0938a99
--- /dev/null
+++ b/modelscope/pipelines/cv/lineless_table_recognition_pipeline.py
@@ -0,0 +1,133 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import math
+import os.path as osp
+from typing import Any, Dict, Optional, Union
+
+import cv2
+import numpy as np
+import PIL
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.table_recognition import LoreModel
+from modelscope.models.cv.table_recognition.lineless_table_process import \
+ get_affine_transform_upper_left
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Model, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import load_image
+from modelscope.preprocessors.image import LoadImage
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+ Tasks.lineless_table_recognition,
+ module_name=Pipelines.lineless_table_recognition)
+class LinelessTableRecognitionPipeline(Pipeline):
+ r""" Lineless Table Recognition Pipeline.
+
+ Examples:
+
+ >>> from modelscope.pipelines import pipeline
+
+ >>> detector = pipeline('lineless-table-recognition', 'damo/cv_resnet-transformer_table-structure-recognition_lore')
+ >>> detector("data/test/images/lineless_table_recognition.jpg")
+ >>> {
+ >>> "polygons": [
+ >>> [
+ >>> 159.65718,
+ >>> 161.14981,
+ >>> 170.9718,
+ >>> 161.1621,
+ >>> 170.97322,
+ >>> 175.4334,
+ >>> 159.65717,
+ >>> 175.43259
+ >>> ],
+ >>> [
+ >>> 153.24953,
+ >>> 230.49915,
+ >>> 176.26964,
+ >>> 230.50377,
+ >>> 176.26273,
+ >>> 246.08868,
+ >>> 153.24817,
+ >>> 246.10458
+ >>> ],
+ >>> ......
+ >>> ],
+ >>> "boxes": [
+ >>> [
+ >>> 4.,
+ >>> 4.,
+ >>> 1.,
+ >>> 1.
+ >>> ],
+ >>> [
+ >>> 6.,
+ >>> 6.,
+ >>> 1.,
+ >>> 1.
+ >>> ],
+ >>> ......
+ >>> ]
+ >>> }
+ >>>
+ """
+
+ def __init__(self, model: Union[Model, str], **kwargs):
+ """
+ Args:
+ model: model id on modelscope hub.
+ """
+ assert isinstance(model, str), 'model must be a single str'
+ super().__init__(model=model, **kwargs)
+ logger.info(f'loading model from dir {model}')
+ self.model.eval()
+
+ def preprocess(self, input: Input) -> Dict[str, Any]:
+ img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
+
+ mean = np.array([0.408, 0.447, 0.470],
+ dtype=np.float32).reshape(1, 1, 3)
+ std = np.array([0.289, 0.274, 0.278],
+ dtype=np.float32).reshape(1, 1, 3)
+ height, width = img.shape[0:2]
+ inp_height, inp_width = 768, 768
+ c = np.array([0, 0], dtype=np.float32)
+ s = max(height, width) * 1.0
+ trans_input = get_affine_transform_upper_left(c, s, 0,
+ [inp_width, inp_height])
+
+ resized_image = cv2.resize(img, (width, height))
+ inp_image = cv2.warpAffine(
+ resized_image,
+ trans_input, (inp_width, inp_height),
+ flags=cv2.INTER_LINEAR)
+ inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)
+
+ images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height,
+ inp_width)
+ images = torch.from_numpy(images).to(self.device)
+ meta = {
+ 'c': c,
+ 's': s,
+ 'input_height': inp_height,
+ 'input_width': inp_width,
+ 'out_height': inp_height // 4,
+ 'out_width': inp_width // 4
+ }
+
+ result = {'img': images, 'meta': meta}
+
+ return result
+
+ def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ results = self.model(input)
+ return results
+
+ def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ return inputs
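
A hedged sketch of consuming the `polygons` field shown in the example output above, where each entry holds the four corner points of a detected cell as eight flattened coordinates (the drawing code itself is illustrative):

```python
# Hedged usage sketch: visualize the cell polygons returned above.
import cv2
import numpy as np
from modelscope.pipelines import pipeline

table_recog = pipeline(
    'lineless-table-recognition',
    'damo/cv_resnet-transformer_table-structure-recognition_lore')
img_path = 'data/test/images/lineless_table_recognition.jpg'
result = table_recog(img_path)

img = cv2.imread(img_path)
for poly in result['polygons']:
    pts = np.array(poly, dtype=np.int32).reshape(-1, 2)  # 4 corner points
    cv2.polylines(img, [pts], isClosed=True, color=(0, 255, 0), thickness=1)
cv2.imwrite('lineless_table_cells.jpg', img)
```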
diff --git a/modelscope/pipelines/cv/ocr_recognition_pipeline.py b/modelscope/pipelines/cv/ocr_recognition_pipeline.py
index f5b2f667..e81e1ff6 100644
--- a/modelscope/pipelines/cv/ocr_recognition_pipeline.py
+++ b/modelscope/pipelines/cv/ocr_recognition_pipeline.py
@@ -70,5 +70,5 @@ class OCRRecognitionPipeline(Pipeline):
return outputs
def postprocess(self, inputs):
- outputs = {OutputKeys.TEXT: inputs}
+ outputs = {OutputKeys.TEXT: inputs['preds']}
return outputs
diff --git a/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py b/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py
index ed2c0d35..6cf9379a 100644
--- a/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py
+++ b/modelscope/pipelines/cv/realtime_video_object_detection_pipeline.py
@@ -29,7 +29,6 @@ class RealtimeVideoObjectDetectionPipeline(Pipeline):
def __init__(self, model: str, **kwargs):
super().__init__(model=model, **kwargs)
- self.model = RealtimeVideoDetector(model)
def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]:
return input
diff --git a/modelscope/pipelines/cv/video_instance_segmentation_pipeline.py b/modelscope/pipelines/cv/video_instance_segmentation_pipeline.py
new file mode 100644
index 00000000..8b6fde35
--- /dev/null
+++ b/modelscope/pipelines/cv/video_instance_segmentation_pipeline.py
@@ -0,0 +1,271 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import os.path as osp
+from typing import Any, Dict
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.video_instance_segmentation.video_knet import \
+ KNetTrack
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+ Tasks.video_instance_segmentation,
+ module_name=Pipelines.video_instance_segmentation)
+class VideoInstanceSegmentationPipeline(Pipeline):
+ r""" Video Instance Segmentation Pipeline.
+
+ Examples:
+
+ >>> from modelscope.pipelines import pipeline
+
+ >>> detector = pipeline('video-instance-segmentation', 'damo/cv_swinb_video-instance-segmentation')
+ >>> detector("http://www.modelscope.cn/api/v1/models/damo/cv_swinb_video-instance-segmentation/repo?Revision=master"
+ >>> "&FilePath=resources/kitti-step_testing_image_02_0000.mp4")
+ >>> {
+ >>> "boxes": [
+ >>> [
+ >>> [
+ >>> 0,
+ >>> 446.9007568359375,
+ >>> 36.374977111816406,
+ >>> 907.0919189453125,
+ >>> 337.439208984375,
+ >>> 0.333
+ >>> ],
+ >>> [
+ >>> 1,
+ >>> 454.3310241699219,
+ >>> 336.08477783203125,
+ >>> 921.26904296875,
+ >>> 641.7871704101562,
+ >>> 0.792
+ >>> ]
+ >>> ],
+ >>> [
+ >>> [
+ >>> 0,
+ >>> 446.9007568359375,
+ >>> 36.374977111816406,
+ >>> 907.0919189453125,
+ >>> 337.439208984375,
+ >>> 0.333
+ >>> ],
+ >>> [
+ >>> 1,
+ >>> 454.3310241699219,
+ >>> 336.08477783203125,
+ >>> 921.26904296875,
+ >>> 641.7871704101562,
+ >>> 0.792
+ >>> ]
+ >>> ]
+ >>> ],
+ >>> "masks": [
+ >>> [
+ >>> [
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False],
+ >>> ...,
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False]
+ >>> ],
+ >>> [
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False],
+ >>> ...,
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False]
+ >>> ]
+ >>> ],
+ >>> [
+ >>> [
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False],
+ >>> ...,
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False]
+ >>> ],
+ >>> [
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False],
+ >>> ...,
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False],
+ >>> [False, False, False, ..., False, False, False]
+ >>> ]
+ >>> ]
+ >>> ]
+ >>> }
+ >>>
+ """
+
+ def __init__(self, model: str, **kwargs):
+ """
+ use `model` to create a video instance segmentation pipeline for prediction
+ Args:
+ model: model id on modelscope hub.
+ """
+ super().__init__(model=model, auto_collate=False, **kwargs)
+ logger.info(f'loading model from {model}')
+ model_path = osp.join(model, ModelFile.TORCH_MODEL_FILE)
+ config_path = osp.join(model, ModelFile.CONFIGURATION)
+ logger.info(f'loading config from {config_path}')
+ self.cfg = Config.from_file(config_path)
+ self.max_video_frames = kwargs.get('max_video_frames', 1000)
+
+ self.model = KNetTrack(model)
+ checkpoint = torch.load(
+ model_path, map_location=torch.device(self.device))
+ self.model.load_state_dict(checkpoint['state_dict'])
+ self.model = self.model.to(self.device).eval()
+ logger.info('load model done')
+
+ self.pad_size_divisor = 32
+ self.mean = np.array([123.675, 116.28, 103.53], np.float32)
+ self.std = np.array([58.395, 57.12, 57.375], np.float32)
+ self.to_rgb = False
+
+ def preprocess(self, input: Input) -> Dict[str, Any]:
+ """
+ Read video and process into 'imgs', 'img_metas', 'ref_img', 'ref_img_metas'
+ """
+
+ if not isinstance(input, str):
+ raise TypeError(f'input should be a str,'
+ f' but got {type(input)}')
+ imgs = []
+ img_metas = []
+
+ ref_imgs = []
+ ref_img_metas = []
+
+ cap = cv2.VideoCapture(input)
+ self.fps = cap.get(cv2.CAP_PROP_FPS)
+ self.frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
+ frame_idx = 0
+ while (cap.isOpened()):
+ ret, frame = cap.read()
+ if not ret:
+ break
+
+ if frame_idx > self.max_video_frames:
+ break
+
+ resize_frame = mmcv.imresize(frame, (640, 360))
+ norm_frame = mmcv.imnormalize(resize_frame, self.mean, self.std,
+ self.to_rgb)
+ pad_frame = mmcv.impad_to_multiple(
+ norm_frame, self.pad_size_divisor, pad_val=0)
+
+ ref_img_meta = {
+ 'flip': False,
+ 'flip_direction': None,
+ 'img_norm_cfg': {
+ 'mean': np.array([123.675, 116.28, 103.53],
+ dtype=np.float32),
+ 'std': np.array([58.395, 57.12, 57.375], dtype=np.float32),
+ 'to_rgb': True
+ },
+ 'video_id': 0,
+ 'is_video_data': True
+ }
+ ref_img_meta['ori_shape'] = frame.shape
+ ref_img_meta['img_shape'] = resize_frame.shape
+ ref_img_meta['pad_shape'] = pad_frame.shape
+ ref_img_meta['frame_id'] = frame_idx
+
+ if frame_idx == 0:
+ imgs = [
+ torch.from_numpy(
+ np.array([np.transpose(pad_frame,
+ [2, 0, 1])])).to(self.device)
+ ]
+ img_metas = [[ref_img_meta]]
+
+ ref_imgs.append(np.transpose(pad_frame, [2, 0, 1]))
+ ref_img_metas.append(ref_img_meta)
+
+ frame_idx += 1
+
+ ref_imgs = np.array([[ref_imgs]])
+ ref_img_metas = [[ref_img_metas]]
+
+ result = {
+ 'video_name': input,
+ 'imgs': imgs,
+ 'img_metas': img_metas,
+ 'ref_img': torch.from_numpy(ref_imgs).to(self.device),
+ 'ref_img_metas': ref_img_metas,
+ }
+ return result
+
+ def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Segment instances (bounding boxes and masks) in the video passed as input.
+
+ Args:
+ input (`Video`):
+ The pipeline handles two types of video input:
+
+ - A string containing an HTTP(S) link pointing to a video
+ - A string containing a local path to a video
+
+ The pipeline accepts a single video as input.
+
+
+ Return:
+ A dictionary of results for the input video.
+
+ The dictionary contains the following keys:
+
+ - **boxes** (`List[float]`) -- The bounding boxes [index, x1, y1, x2, y2, score] of instances in each frame.
+ - **masks** (`List[List[bool]]`, optional) -- The instance masks [[False,...,False],...,[False,...,False]]
+ """
+
+ bbox_results = []
+ mask_results = []
+
+ with torch.no_grad():
+ imgs = input['imgs']
+ img_metas = input['img_metas']
+ ref_img = input['ref_img']
+ ref_img_metas = input['ref_img_metas']
+
+ segm_results = self.model(
+ imgs, img_metas, ref_img=ref_img, ref_img_metas=ref_img_metas)
+
+ for ii in range(len(segm_results[0])):
+ bbox_results.append(segm_results[0][ii][0])
+ mask_results.append(segm_results[0][ii][1])
+
+ output = {
+ 'boxes': bbox_results,
+ 'masks': mask_results,
+ }
+ return output
+
+ def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ return inputs
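
The per-frame output documented in `forward()` above ([index, x1, y1, x2, y2, score] boxes plus boolean masks) can be consumed roughly as follows; the task and model strings follow the docstring, while the video path and score threshold are illustrative:

```python
# Hedged usage sketch for VideoInstanceSegmentationPipeline above.
from modelscope.pipelines import pipeline

segmenter = pipeline('video-instance-segmentation',
                     'damo/cv_swinb_video-instance-segmentation')
result = segmenter('some_local_video.mp4')           # placeholder video path

for frame_idx, (frame_boxes, frame_masks) in enumerate(
        zip(result['boxes'], result['masks'])):
    for box, mask in zip(frame_boxes, frame_masks):
        track_id, x1, y1, x2, y2, score = box
        if score < 0.3:                              # illustrative threshold
            continue
        print(f'frame {frame_idx}: id {int(track_id)} '
              f'box=({x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f}) score={score:.2f}')
```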
diff --git a/modelscope/pipelines/cv/video_single_object_tracking_pipeline.py b/modelscope/pipelines/cv/video_single_object_tracking_pipeline.py
index 4169def7..89955a53 100644
--- a/modelscope/pipelines/cv/video_single_object_tracking_pipeline.py
+++ b/modelscope/pipelines/cv/video_single_object_tracking_pipeline.py
@@ -7,8 +7,8 @@ import cv2
from modelscope.metainfo import Pipelines
from modelscope.models.cv.video_single_object_tracking.config.ostrack import \
cfg
-from modelscope.models.cv.video_single_object_tracking.tracker.ostrack import \
- OSTrack
+from modelscope.models.cv.video_single_object_tracking.tracker import (
+ OSTrack, ProContEXT)
from modelscope.models.cv.video_single_object_tracking.utils.utils import (
check_box, timestamp_format)
from modelscope.outputs import OutputKeys
@@ -20,6 +20,9 @@ from modelscope.utils.logger import get_logger
logger = get_logger()
+@PIPELINES.register_module(
+ Tasks.video_single_object_tracking,
+ module_name=Pipelines.video_single_object_tracking_procontext)
@PIPELINES.register_module(
Tasks.video_single_object_tracking,
module_name=Pipelines.video_single_object_tracking)
@@ -32,10 +35,14 @@ class VideoSingleObjectTrackingPipeline(Pipeline):
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
- self.cfg = cfg
ckpt_path = osp.join(model, ModelFile.TORCH_MODEL_BIN_FILE)
logger.info(f'loading model from {ckpt_path}')
- self.tracker = OSTrack(ckpt_path, self.device)
+ if self.cfg.get('tracker', None) == 'ProContEXT':
+ self.tracker = ProContEXT(ckpt_path, self.device, self.cfg)
+ else:
+ self.cfg = cfg
+ self.tracker = OSTrack(ckpt_path, self.device)
+
logger.info('init tracker done')
def preprocess(self, input) -> Input:
diff --git a/modelscope/pipelines/cv/vidt_pipeline.py b/modelscope/pipelines/cv/vidt_pipeline.py
new file mode 100644
index 00000000..5c16c35e
--- /dev/null
+++ b/modelscope/pipelines/cv/vidt_pipeline.py
@@ -0,0 +1,207 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+from typing import Any, Dict
+
+import torch
+import torchvision.transforms as transforms
+from torch import nn
+
+from modelscope.metainfo import Pipelines
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+ Tasks.image_object_detection, module_name=Pipelines.vidt)
+class VidtPipeline(Pipeline):
+
+ def __init__(self, model: str, **kwargs):
+ """
+ use `model` to create a vidt pipeline for prediction
+ Args:
+ model: model id on modelscope hub.
+ Example:
+ >>> from modelscope.pipelines import pipeline
+ >>> vidt_pipeline = pipeline('image-object-detection', 'damo/ViDT-logo-detection')
+ >>> result = vidt_pipeline(
+ 'data/test/images/vidt_test1.png')
+ >>> print(f'Output: {result}.')
+ """
+ super().__init__(model=model, **kwargs)
+
+ self.model.eval()
+ self.transform = transforms.Compose([
+ transforms.Resize([640, 640]),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+ self.postprocessors = PostProcess()
+ self.label_dic = {0: 'negative', 1: 'positive'}
+
+ def preprocess(self, inputs: Input, **preprocess_params):
+ img = LoadImage.convert_to_img(inputs)
+ ori_size = [img.size[1], img.size[0]]
+ image = self.transform(img)
+ tensor_list = [image]
+ orig_target_sizes = [ori_size]
+ orig_target_sizes = torch.tensor(orig_target_sizes).to(self.device)
+ samples = nested_tensor_from_tensor_list(tensor_list)
+ samples = samples.to(self.device)
+ res = {}
+ res['tensors'] = samples.tensors
+ res['mask'] = samples.mask
+ res['orig_target_sizes'] = orig_target_sizes
+ return res
+
+ def forward(self, inputs: Dict[str, Any], **forward_params):
+ tensors = inputs['tensors']
+ mask = inputs['mask']
+ orig_target_sizes = inputs['orig_target_sizes']
+ with torch.no_grad():
+ out_pred_logits, out_pred_boxes = self.model(tensors, mask)
+ res = {}
+ res['out_pred_logits'] = out_pred_logits
+ res['out_pred_boxes'] = out_pred_boxes
+ res['orig_target_sizes'] = orig_target_sizes
+ return res
+
+ def postprocess(self, inputs: Dict[str, Any], **post_params):
+ results = self.postprocessors(inputs['out_pred_logits'],
+ inputs['out_pred_boxes'],
+ inputs['orig_target_sizes'])
+ batch_predictions = get_predictions(results)[0] # only single-image inference is supported
+ scores = []
+ labels = []
+ boxes = []
+ for sub_pre in batch_predictions:
+ scores.append(sub_pre[0])
+ labels.append(self.label_dic[sub_pre[1]])
+ boxes.append(sub_pre[2]) # [xmin, ymin, xmax, ymax]
+ outputs = {}
+ outputs['scores'] = scores
+ outputs['labels'] = labels
+ outputs['boxes'] = boxes
+ return outputs
+
+
+def nested_tensor_from_tensor_list(tensor_list):
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img)
+ m[:img.shape[1], :img.shape[2]] = False
+ return NestedTensor(tensor, mask)
+
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
+class NestedTensor(object):
+
+ def __init__(self, tensors, mask):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+
+def box_cxcywh_to_xyxy(x):
+ x_c, y_c, w, h = x.unbind(-1)
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ return torch.stack(b, dim=-1)
+
+
+# process post_results
+def get_predictions(post_results, bbox_thu=0.40):
+ batch_final_res = []
+ for per_img_res in post_results:
+ per_img_final_res = []
+ for i in range(len(per_img_res['scores'])):
+ score = float(per_img_res['scores'][i].cpu())
+ label = int(per_img_res['labels'][i].cpu())
+ bbox = []
+ for it in per_img_res['boxes'][i].cpu():
+ bbox.append(int(it))
+ if score >= bbox_thu:
+ per_img_final_res.append([score, label, bbox])
+ batch_final_res.append(per_img_final_res)
+ return batch_final_res
+
+
+class PostProcess(nn.Module):
+ """ This module converts the model's output into the format expected by the coco api"""
+
+ def __init__(self, processor_dct=None):
+ super().__init__()
+ # For instance segmentation using UQR module
+ self.processor_dct = processor_dct
+
+ @torch.no_grad()
+ def forward(self, out_logits, out_bbox, target_sizes):
+ """ Perform the computation
+
+ Parameters:
+ out_logits: raw logits outputs of the model
+ out_bbox: raw bbox outputs of the model
+ target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
+ For evaluation, this must be the original image size (before any data augmentation)
+ For visualization, this should be the image size after data augmentation, but before padding
+ """
+ assert len(out_logits) == len(target_sizes)
+ assert target_sizes.shape[1] == 2
+
+ prob = out_logits.sigmoid()
+ topk_values, topk_indexes = torch.topk(
+ prob.view(out_logits.shape[0], -1), 100, dim=1)
+ scores = topk_values
+ topk_boxes = topk_indexes // out_logits.shape[2]
+ labels = topk_indexes % out_logits.shape[2]
+ boxes = box_cxcywh_to_xyxy(out_bbox)
+ boxes = torch.gather(boxes, 1,
+ topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h],
+ dim=1).to(torch.float32)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = [{
+ 'scores': s,
+ 'labels': l,
+ 'boxes': b
+ } for s, l, b in zip(scores, labels, boxes)]
+
+ return results
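
For reference, `PostProcess` above first converts the model's relative (cx, cy, w, h) boxes to corner form and then scales them by the original image size. A tiny worked example of just that conversion (the box and image size are made up):

```python
import torch

def box_cxcywh_to_xyxy(x):
    # same conversion as in vidt_pipeline.py above
    x_c, y_c, w, h = x.unbind(-1)
    return torch.stack(
        [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h], dim=-1)

rel_box = torch.tensor([[0.5, 0.5, 0.5, 0.5]])   # centred box, half of each side
img_h, img_w = 480, 640                          # made-up original size
scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
print(box_cxcywh_to_xyxy(rel_box) * scale)       # tensor([[160., 120., 480., 360.]])
```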
diff --git a/modelscope/pipelines/cv/vision_efficient_tuning_pipeline.py b/modelscope/pipelines/cv/vision_efficient_tuning_pipeline.py
index 2e3c45cc..50289168 100644
--- a/modelscope/pipelines/cv/vision_efficient_tuning_pipeline.py
+++ b/modelscope/pipelines/cv/vision_efficient_tuning_pipeline.py
@@ -10,7 +10,7 @@ from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import LoadImage
+from modelscope.preprocessors import LoadImage, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
@@ -40,25 +40,55 @@ class VisionEfficientTuningPipeline(Pipeline):
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = self.model.to(self.device)
self.model.eval()
- self.transform = transforms.Compose([
- transforms.Resize(224),
- transforms.ToTensor(),
- transforms.Normalize(
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- ])
- def preprocess(self, input: Input) -> Dict[str, Any]:
- img = LoadImage.convert_to_img(input)
- data = self.transform(img).unsqueeze(0).to(self.device)
- return data
+ self.preprocessor = Preprocessor.from_pretrained(
+ self.model.model_dir, **kwargs)
- def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+ if self.preprocessor is None:
+ self.preprocessor = transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop((224, 224)),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+ def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
+ """ Preprocess method build from transforms or Preprocessor """
+ in_key = 'img_path:FILE'
+ other_in_keys = ['image']
+ out_key = 'imgs'
+ if isinstance(self.preprocessor, Preprocessor):
+ if not isinstance(inputs, dict):
+ inputs = {in_key: inputs}
+ elif in_key not in inputs:
+ for ik in other_in_keys:
+ if ik in inputs and isinstance(inputs[ik], str):
+ inputs = {in_key: inputs[ik]}
+ break
+ data = self.preprocessor(inputs)
+ result = {out_key: data[out_key].unsqueeze(0).to(self.device)}
+ else:
+ if isinstance(inputs, dict):
+ for ik in [in_key] + other_in_keys:
+ if ik in inputs:
+ inputs = inputs[ik]
+ break
+ img = LoadImage.convert_to_img(inputs)
+ data = self.preprocessor(img)
+ result = {out_key: data.unsqueeze(0).to(self.device)}
+ return result
+
+ def forward(self, inputs: Dict[str, Any],
+ **forward_params) -> Dict[str, Any]:
with torch.no_grad():
- results = self.model(input)
+ results = self.model(inputs)
return results
- def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
- scores = F.softmax(inputs, dim=1).cpu().numpy()
+ def postprocess(self, inputs: Dict[str, Any],
+ **post_params) -> Dict[str, Any]:
+ """ Postprocess for classification """
+ scores = inputs[OutputKeys.SCORES].cpu().numpy()
pred_scores = np.sort(scores, axis=1)[0][::-1][:5]
pred_labels = np.argsort(scores, axis=1)[0][::-1][:5]
diff --git a/modelscope/pipelines/cv/vop_retrieval_se_pipeline.py b/modelscope/pipelines/cv/vop_retrieval_se_pipeline.py
new file mode 100644
index 00000000..779957c5
--- /dev/null
+++ b/modelscope/pipelines/cv/vop_retrieval_se_pipeline.py
@@ -0,0 +1,142 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import gzip
+import os.path as osp
+from typing import Any, Dict
+
+import numpy as np
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.models.cv.vop_retrieval import (LengthAdaptiveTokenizer,
+ init_transform_dict, load_data,
+ load_frames_from_video)
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+ Tasks.vop_retrieval, module_name=Pipelines.vop_retrieval_se)
+class VopRetrievalSEPipeline(Pipeline):
+
+ def __init__(self, model: str, **kwargs):
+ r""" Card VopRetrievalSE Pipeline.
+
+ Examples:
+ >>>
+ >>> from modelscope.pipelines import pipeline
+ >>> vop_pipeline = pipeline(Tasks.vop_retrieval,
+ >>> model='damo/cv_vit-b32_retrieval_vop_bias')
+ >>>
+ >>> # IF DO TEXT-TO-VIDEO:
+ >>> input_text = 'a squid is talking'
+ >>> result = vop_pipeline(input_text)
+ >>> result:
+ >>> {'output_data': array([['video8916']], dtype=...), 'mode': 't2v'}
+ >>> # IF DO VIDEO-TO-TEXT:
+ >>> input_video = 'video10.mp4'
+ >>> result = vop_pipeline(input_video)
+ >>> result:
+ >>> {'output_data': array([['assorted people are shown holding cute pets']], dtype=...), 'mode': 'v2t'}
+ """
+ super().__init__(model=model, **kwargs)
+
+ # [from pretrain] load model
+ self.model = Model.from_pretrained(model).to(self.device)
+ logger.info('load model done')
+
+ # others: load transform
+ self.local_pth = model
+ self.cfg = Config.from_file(osp.join(model, ModelFile.CONFIGURATION))
+ self.img_transform = init_transform_dict(
+ self.cfg.hyperparam.input_res)['clip_test']
+ logger.info('load transform done')
+
+ # others: load tokenizer
+ bpe_path = gzip.open(osp.join(
+ model,
+ 'bpe_simple_vocab_16e6.txt.gz')).read().decode('utf-8').split('\n')
+ self.tokenizer = LengthAdaptiveTokenizer(self.cfg.hyperparam, bpe_path)
+ logger.info('load tokenizer done')
+
+ # others: load dataset
+ if 'vop_bias' in model:
+ self.database = load_data(
+ osp.join(model, 'Bias_msrvtt9k_features.pkl'), self.device)
+ elif 'vop_partial' in model:
+ self.database = load_data(
+ osp.join(model, 'Partial_msrvtt9k_features.pkl'), self.device)
+ elif 'vop_proj' in model:
+ self.database = load_data(
+ osp.join(model, 'Proj_msrvtt9k_features.pkl'), self.device)
+ else:
+ self.database = load_data(
+ osp.join(model, 'VoP_msrvtt9k_features.pkl'), self.device)
+ logger.info('load database done')
+
+ def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
+ if isinstance(input, str):
+ if '.mp4' in input:
+ query = []
+ for video_path in [input]:
+ video_path = osp.join(self.local_pth, video_path)
+ imgs, idxs = load_frames_from_video(
+ video_path, self.cfg.hyperparam.num_frames,
+ self.cfg.hyperparam.video_sample_type)
+ imgs = self.img_transform(imgs)
+ query.append(imgs)
+ query = torch.stack(
+ query, dim=0).to(
+ self.device, non_blocking=True)
+ mode = 'v2t'
+ else:
+ query = self.tokenizer(
+ input, return_tensors='pt', padding=True, truncation=True)
+ if isinstance(query, torch.Tensor):
+ query = query.to(self.device, non_blocking=True)
+ else:
+ query = {
+ key: val.to(self.device, non_blocking=True)
+ for key, val in query.items()
+ }
+ mode = 't2v'
+ else:
+ raise TypeError(f'input should be a str,'
+ f' but got {type(input)}')
+ result = {'input_data': query, 'mode': mode}
+ return result
+
+ def forward(self, input: Dict[str, Any],
+ **forward_params) -> Dict[str, Any]:
+ text_embeds, vid_embeds_pooled, vid_ids, texts = self.database
+ with torch.no_grad():
+ if input['mode'] == 't2v':
+ query_feats = self.model.get_text_features(input['input_data'])
+ score = query_feats @ vid_embeds_pooled.T
+ retrieval_idxs = torch.topk(
+ score, k=self.cfg.hyperparam.topk,
+ dim=-1)[1].cpu().numpy()
+ res = np.array(vid_ids)[retrieval_idxs]
+ elif input['mode'] == 'v2t':
+ query_feats = self.model.get_video_features(
+ input['input_data'])
+ score = query_feats @ text_embeds.T
+ retrieval_idxs = torch.topk(
+ score, k=self.cfg.hyperparam.topk,
+ dim=-1)[1].cpu().numpy()
+ res = np.array(texts)[retrieval_idxs]
+ results = {'output_data': res, 'mode': input['mode']}
+ return results
+
+ def postprocess(self, inputs: Dict[str, Any],
+ **post_params) -> Dict[str, Any]:
+ return inputs
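
The retrieval step in forward() is a plain similarity search: compare the query embedding against the cached database embeddings and take the top-k indices. A minimal sketch of that scoring logic with made-up tensors (the real pipeline uses the features loaded from the *_msrvtt9k_features.pkl files and the topk value from its configuration):

    import numpy as np
    import torch

    # Hypothetical database: 100 pooled video embeddings of dim 512 and their ids.
    vid_embeds_pooled = torch.nn.functional.normalize(torch.randn(100, 512), dim=-1)
    vid_ids = np.array(['video%d' % i for i in range(100)])

    # Hypothetical text query embedding (1 x 512), already encoded by the model.
    query_feats = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)

    # Same scoring as the pipeline: similarity matrix, then top-k ids per query.
    score = query_feats @ vid_embeds_pooled.T                   # shape (1, 100)
    topk_idxs = torch.topk(score, k=5, dim=-1)[1].cpu().numpy()
    print(vid_ids[topk_idxs])                                   # (1, 5) array of video ids
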
diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py
index e8ca1a3c..2e496952 100644
--- a/modelscope/pipelines/multi_modal/__init__.py
+++ b/modelscope/pipelines/multi_modal/__init__.py
@@ -19,6 +19,8 @@ if TYPE_CHECKING:
from .video_captioning_pipeline import VideoCaptioningPipeline
from .video_question_answering_pipeline import VideoQuestionAnsweringPipeline
from .diffusers_wrapped import StableDiffusionWrapperPipeline, ChineseStableDiffusionPipeline
+ from .soonet_video_temporal_grounding_pipeline import SOONetVideoTemporalGroundingPipeline
+ from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline
else:
_import_structure = {
'image_captioning_pipeline': ['ImageCaptioningPipeline'],
@@ -39,7 +41,10 @@ else:
'video_question_answering_pipeline':
['VideoQuestionAnsweringPipeline'],
'diffusers_wrapped':
- ['StableDiffusionWrapperPipeline', 'ChineseStableDiffusionPipeline']
+ ['StableDiffusionWrapperPipeline', 'ChineseStableDiffusionPipeline'],
+ 'soonet_video_temporal_grounding_pipeline':
+ ['SOONetVideoTemporalGroundingPipeline'],
+ 'text_to_video_synthesis_pipeline': ['TextToVideoSynthesisPipeline'],
}
import sys
diff --git a/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/__init__.py b/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/__init__.py
new file mode 100644
index 00000000..41ee2ad4
--- /dev/null
+++ b/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+ from .disco_guided_diffusion import DiscoDiffusionPipeline
+ from .utils import resize
+else:
+ _import_structure = {
+ 'disco_guided_diffusion': ['DiscoDiffusionPipeline'],
+ 'utils': ['resize'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = LazyImportModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={},
+ )
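
This __init__.py follows the LazyImportModule convention used across modelscope: `_import_structure` maps each submodule to the public names it provides, and a submodule is only imported when one of those names is first accessed. A rough, framework-free sketch of the same idea (the class below is illustrative, not modelscope's actual implementation):

    import importlib
    import types

    class LazyModule(types.ModuleType):
        """Defers importing a submodule until one of its symbols is accessed."""

        def __init__(self, name, import_structure):
            super().__init__(name)
            # public symbol -> submodule that defines it
            self._symbol_to_module = {
                sym: mod for mod, syms in import_structure.items() for sym in syms
            }

        def __getattr__(self, item):
            mapping = self.__dict__.get('_symbol_to_module', {})
            if item not in mapping:
                raise AttributeError(item)
            submodule = importlib.import_module(
                '{}.{}'.format(self.__name__, mapping[item]))
            value = getattr(submodule, item)
            setattr(self, item, value)  # cache so later lookups skip __getattr__
            return value
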
diff --git a/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/disco_guided_diffusion.py b/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/disco_guided_diffusion.py
new file mode 100644
index 00000000..59ab67f8
--- /dev/null
+++ b/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/disco_guided_diffusion.py
@@ -0,0 +1,430 @@
+# This code is borrowed and modified from Guided Diffusion Model,
+# made publicly available under MIT license at
+# https://github.com/IDEA-CCNL/Fengshenbang-LM/tree/main/fengshen/examples/disco_project
+
+import gc
+import importlib
+import math
+import os
+
+import clip
+import cv2
+import json
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+from PIL import Image
+from torch.nn import functional as F
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal.guided_diffusion.script import \
+ create_diffusion
+from modelscope.models.multi_modal.guided_diffusion.unet import HFUNetModel
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.pipelines.multi_modal.diffusers_wrapped.diffusers_pipeline import \
+ DiffusersPipeline
+from modelscope.utils.constant import Tasks
+from .utils import resize
+
+
+def parse_prompt(prompt):
+ if prompt.startswith('http://') or prompt.startswith('https://'):
+ vals = prompt.rsplit(':', 2)
+ vals = [vals[0] + ':' + vals[1], *vals[2:]]
+ else:
+ vals = prompt.rsplit(':', 1)
+ vals = vals + ['', '1'][len(vals):]
+ return vals[0], float(vals[1])
+
+
+def sinc(x):
+ return torch.where(x != 0,
+ torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
+
+
+def lanczos(x, a):
+ cond = torch.logical_and(-a < x, x < a)
+ out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
+ return out / out.sum()
+
+
+class MakeCutoutsDango(nn.Module):
+
+ def __init__(
+ self,
+ cut_size,
+ Overview=4,
+ InnerCrop=0,
+ IC_Size_Pow=0.5,
+ IC_Grey_P=0.2,
+ ):
+ super().__init__()
+ self.padargs = {}
+ self.cutout_debug = False
+ self.cut_size = cut_size
+ self.Overview = Overview
+ self.InnerCrop = InnerCrop
+ self.IC_Size_Pow = IC_Size_Pow
+ self.IC_Grey_P = IC_Grey_P
+ self.augs = T.Compose([
+ T.RandomHorizontalFlip(p=0.5),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomAffine(
+ degrees=10,
+ translate=(0.05, 0.05),
+ interpolation=T.InterpolationMode.BILINEAR),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomGrayscale(p=0.1),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.ColorJitter(
+ brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
+ ])
+
+ def forward(self, input):
+ cutouts = []
+ gray = T.Grayscale(3)
+ sideY, sideX = input.shape[2:4]
+ max_size = min(sideX, sideY)
+ min_size = min(sideX, sideY, self.cut_size)
+ output_shape = [1, 3, self.cut_size, self.cut_size]
+ pad_input = F.pad(input,
+ ((sideY - max_size) // 2, (sideY - max_size) // 2,
+ (sideX - max_size) // 2, (sideX - max_size) // 2),
+ **self.padargs)
+ cutout = resize(pad_input, out_shape=output_shape)
+
+ if self.Overview > 0:
+ if self.Overview <= 4:
+ if self.Overview >= 1:
+ cutouts.append(cutout)
+ if self.Overview >= 2:
+ cutouts.append(gray(cutout))
+ if self.Overview >= 3:
+ cutouts.append(TF.hflip(cutout))
+ if self.Overview == 4:
+ cutouts.append(gray(TF.hflip(cutout)))
+ else:
+ cutout = resize(pad_input, out_shape=output_shape)
+ for _ in range(self.Overview):
+ cutouts.append(cutout)
+
+ if self.cutout_debug:
+ TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(
+ 'cutout_overview0.jpg', quality=99)
+
+ if self.InnerCrop > 0:
+ for i in range(self.InnerCrop):
+ size = int(
+ torch.rand([])**self.IC_Size_Pow * (max_size - min_size)
+ + min_size)
+ offsetx = torch.randint(0, sideX - size + 1, ())
+ offsety = torch.randint(0, sideY - size + 1, ())
+ cutout = input[:, :, offsety:offsety + size,
+ offsetx:offsetx + size]
+ if i <= int(self.IC_Grey_P * self.InnerCrop):
+ cutout = gray(cutout)
+ cutout = resize(cutout, out_shape=output_shape)
+ cutouts.append(cutout)
+ if self.cutout_debug:
+ TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(
+ 'cutout_InnerCrop.jpg', quality=99)
+ cutouts = torch.cat(cutouts)
+
+ cutouts = self.augs(cutouts)
+ return cutouts
+
+
+def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+def tv_loss(input):
+ """L2 total variation loss, as in Mahendran et al."""
+ input = F.pad(input, (0, 1, 0, 1), 'replicate')
+ x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
+ y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
+ return (x_diff**2 + y_diff**2).mean([1, 2, 3])
+
+
+def range_loss(input):
+ return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
+
+
+normalize = T.Normalize(
+ mean=[0.48145466, 0.4578275, 0.40821073],
+ std=[0.26862954, 0.26130258, 0.27577711])
+
+
+@PIPELINES.register_module(
+ Tasks.text_to_image_synthesis,
+ module_name=Pipelines.disco_guided_diffusion)
+class DiscoDiffusionPipeline(DiffusersPipeline):
+
+ def __init__(self, model: str, device: str = 'gpu', **kwargs):
+ """ Chinese Disco Diffusion Pipeline.
+
+ Examples:
+
+ >>> import cv2
+ >>> from modelscope.pipelines import pipeline
+ >>> from modelscope.utils.constant import Tasks
+
+ >>> prompt = '赛博朋克,城市'  # i.e. 'cyberpunk, city'
+ >>> output_image_path = './result.png'
+ >>> input = {
+ >>> 'text': prompt
+ >>> }
+ >>> pipe = pipeline(
+ >>> Tasks.text_to_image_synthesis,
+ >>> model='yyqoni/yinyueqin_cyberpunk',
+ >>> model_revision='v1.0')
+ >>> output = pipe(input)['output_imgs'][0]
+ >>> cv2.imwrite(output_image_path, output)
+ >>> print('pipeline: the output image path is {}'.format(output_image_path))
+ """
+
+ super().__init__(model, device, **kwargs)
+
+ model_path = model
+
+ model_config = {'steps': 100, 'use_fp16': True}
+ self.diffusion = create_diffusion(model_config)
+
+ self.unet = HFUNetModel.from_pretrained(f'{model_path}/unet')
+
+ self.unet.requires_grad_(False).eval().to(self.device)
+ for name, param in self.unet.named_parameters():
+ if 'qkv' in name or 'norm' in name or 'proj' in name:
+ param.requires_grad_()
+ if model_config['use_fp16']:
+ self.unet.convert_to_fp16()
+
+ with open(
+ os.path.join(model_path, 'model_index.json'),
+ 'r',
+ encoding='utf-8') as reader:
+ text = reader.read()
+ config_dict = json.loads(text)
+
+ library = importlib.import_module(config_dict['tokenizer'][0])
+ class_name = config_dict['tokenizer'][1]
+
+ self.taiyi_tokenizer = getattr(
+ library, class_name).from_pretrained(f'{model_path}/tokenizer')
+
+ library = importlib.import_module(config_dict['text_encoder'][0])
+ class_name = config_dict['text_encoder'][1]
+
+ self.taiyi_transformer = getattr(library, class_name).from_pretrained(
+ f'{model_path}/text_encoder').eval().to(self.device)
+
+ self.clip_models = []
+ self.clip_models.append(
+ clip.load('ViT-L/14',
+ jit=False)[0].eval().requires_grad_(False).to(
+ self.device))
+
+ def forward(self,
+ inputs,
+ init=None,
+ init_scale=2000,
+ skip_steps=10,
+ randomize_class=True,
+ eta=0.8,
+ output_type='pil',
+ return_dict=True,
+ clip_guidance_scale=7500):
+ if not isinstance(inputs, dict):
+ raise ValueError(
+ f'Expected the input to be a dictionary, but got {type(inputs)}'
+ )
+ if 'text' not in inputs:
+ raise ValueError('input should contain "text", but not found')
+
+ batch_size = 1
+ cutn_batches = 1
+
+ tv_scale = 0
+ range_scale = 150
+ sat_scale = 0
+
+ cut_overview = [12] * 400 + [4] * 600
+ cut_innercut = [4] * 400 + [12] * 600
+ cut_ic_pow = [1] * 1000
+ cut_icgray_p = [0.2] * 400 + [0] * 600
+
+ side_x = 512
+ side_y = 512
+
+ if 'width' in inputs:
+ side_x = inputs['width']
+ if 'height' in inputs:
+ side_y = inputs['height']
+ frame_prompt = [inputs.get('text')]
+ loss_values = []
+
+ model_stats = []
+ for clip_model in self.clip_models:
+ # cutn = 16
+ model_stat = {
+ 'clip_model': None,
+ 'target_embeds': [],
+ 'make_cutouts': None,
+ 'weights': []
+ }
+ model_stat['clip_model'] = clip_model
+
+ for prompt in frame_prompt:
+ txt, weight = parse_prompt(prompt)
+ # NOTE use chinese CLIP
+ txt = self.taiyi_transformer(
+ self.taiyi_tokenizer(txt,
+ return_tensors='pt')['input_ids'].to(
+ self.device)).logits
+
+ model_stat['target_embeds'].append(txt)
+ model_stat['weights'].append(weight)
+
+ model_stat['target_embeds'] = torch.cat(
+ model_stat['target_embeds'])
+ model_stat['weights'] = torch.tensor(
+ model_stat['weights'], device=self.device)
+ if model_stat['weights'].sum().abs() < 1e-3:
+ raise RuntimeError('The weights must not sum to 0.')
+ model_stat['weights'] /= model_stat['weights'].sum().abs()
+ model_stats.append(model_stat)
+
+ init = None
+ cur_t = None
+
+ def cond_fn(x, t, y=None):
+ with torch.enable_grad():
+ x_is_NaN = False
+ x = x.detach().requires_grad_()
+ n = x.shape[0]
+
+ my_t = torch.ones([n], device=self.device,
+ dtype=torch.long) * cur_t
+ out = self.diffusion.p_mean_variance(
+ self.unet,
+ x,
+ my_t,
+ clip_denoised=False,
+ model_kwargs={'y': y})
+ fac = self.diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
+ x_in = out['pred_xstart'] * fac + x * (1 - fac)
+ x_in_grad = torch.zeros_like(x_in)
+
+ for model_stat in model_stats:
+ for i in range(cutn_batches):
+ t_int = int(t.item()) + 1
+ input_resolution = model_stat[
+ 'clip_model'].visual.input_resolution
+
+ cuts = MakeCutoutsDango(
+ input_resolution,
+ Overview=cut_overview[1000 - t_int],
+ InnerCrop=cut_innercut[1000 - t_int],
+ IC_Size_Pow=cut_ic_pow[1000 - t_int],
+ IC_Grey_P=cut_icgray_p[1000 - t_int],
+ )
+ clip_in = normalize(cuts(x_in.add(1).div(2)))
+ image_embeds = model_stat['clip_model'].encode_image(
+ clip_in).float()
+ dists = spherical_dist_loss(
+ image_embeds.unsqueeze(1),
+ model_stat['target_embeds'].unsqueeze(0))
+ dists = dists.view([
+ cut_overview[1000 - t_int]
+ + cut_innercut[1000 - t_int], n, -1
+ ])
+ losses = dists.mul(
+ model_stat['weights']).sum(2).mean(0)
+ loss_values.append(losses.sum().item(
+ )) # log loss, probably shouldn't do per cutn_batch
+ x_in_grad += torch.autograd.grad(
+ losses.sum() * clip_guidance_scale,
+ x_in)[0] / cutn_batches
+ tv_losses = tv_loss(x_in)
+ range_losses = range_loss(out['pred_xstart'])
+ sat_losses = torch.abs(x_in - x_in.clamp(min=-1, max=1)).mean()
+ loss = tv_losses.sum() * tv_scale + range_losses.sum(
+ ) * range_scale + sat_losses.sum() * sat_scale
+ if init is not None and init_scale:
+ init_losses = self.lpips_model(x_in, init)
+ loss = loss + init_losses.sum() * init_scale
+ x_in_grad += torch.autograd.grad(loss, x_in)[0]
+ if not torch.isnan(x_in_grad).any():
+ grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]
+ else:
+ x_is_NaN = True
+ grad = torch.zeros_like(x)
+ if not x_is_NaN:
+ magnitude = grad.square().mean().sqrt()
+ return grad * magnitude.clamp(max=0.05) / magnitude
+ return grad
+
+ sample_fn = self.diffusion.ddim_sample_loop_progressive
+
+ n_batches = 1
+
+ for i in range(n_batches):
+ gc.collect()
+ torch.cuda.empty_cache()
+ cur_t = self.diffusion.num_timesteps - skip_steps - 1
+
+ samples = sample_fn(
+ self.unet,
+ (batch_size, 3, side_y, side_x),
+ clip_denoised=False,
+ model_kwargs={},
+ cond_fn=cond_fn,
+ progress=True,
+ skip_timesteps=skip_steps,
+ init_image=init,
+ randomize_class=randomize_class,
+ eta=eta,
+ )
+
+ for j, sample in enumerate(samples):
+ image = sample['pred_xstart']
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ if output_type == 'pil':
+ image = self.numpy_to_pil(image)
+ return image
+
+ if not return_dict:
+ return (image, None)
+
+ @staticmethod
+ def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype('uint8')
+ if images.shape[-1] == 1:
+ # special case for grayscale (single channel) images
+ pil_images = [
+ Image.fromarray(image.squeeze(), mode='L') for image in images
+ ]
+ else:
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+ def postprocess(self, inputs):
+ images = []
+ for img in inputs:
+ if isinstance(img, Image.Image):
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+ images.append(img)
+ return {OutputKeys.OUTPUT_IMGS: images}
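
Two helpers above are easy to sanity-check in isolation: `parse_prompt` strips an optional ':weight' suffix from a prompt, and `spherical_dist_loss` is the squared geodesic distance between L2-normalized CLIP embeddings that `cond_fn` backpropagates through. A self-contained sketch with toy inputs (the URL-prefix branch of parse_prompt is omitted here for brevity):

    import torch
    import torch.nn.functional as F

    def parse_prompt(prompt):
        # 'cyberpunk city:1.5' -> ('cyberpunk city', 1.5); a missing weight defaults to 1.
        vals = prompt.rsplit(':', 1)
        vals = vals + ['', '1'][len(vals):]
        return vals[0], float(vals[1])

    def spherical_dist_loss(x, y):
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)
        return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

    print(parse_prompt('cyberpunk city:1.5'))     # ('cyberpunk city', 1.5)
    print(parse_prompt('a quiet harbor'))         # ('a quiet harbor', 1.0)

    a = torch.tensor([[1.0, 0.0]])
    b = torch.tensor([[0.0, 1.0]])
    print(spherical_dist_loss(a, a))              # tensor([0.]) for identical embeddings
    print(spherical_dist_loss(a, b))              # ~1.2337 (pi**2 / 8) for orthogonal ones
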
diff --git a/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/utils.py b/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/utils.py
new file mode 100644
index 00000000..09772ccc
--- /dev/null
+++ b/modelscope/pipelines/multi_modal/disco_guided_diffusion_pipeline/utils.py
@@ -0,0 +1,468 @@
+# The implementation is adopted from https://github.com/assafshocher/ResizeRight
+import warnings
+from fractions import Fraction
+from math import ceil
+
+
+class NoneClass:
+ pass
+
+
+try:
+ import torch
+ from torch import nn
+ nnModuleWrapped = nn.Module
+except ImportError:
+ warnings.warn('No PyTorch found, will work only with Numpy')
+ torch = None
+ nnModuleWrapped = NoneClass
+
+try:
+ import numpy
+except ImportError:
+ warnings.warn('No Numpy found, will work only with PyTorch')
+ numpy = None
+
+if numpy is None and torch is None:
+ raise ImportError('Must have either Numpy or PyTorch but both not found')
+
+
+def set_framework_dependencies(x):
+ if type(x) is numpy.ndarray:
+
+ def to_dtype(a):
+ return a
+
+ fw = numpy
+ else:
+
+ def to_dtype(a):
+ return a.to(x.dtype)
+
+ fw = torch
+ eps = fw.finfo(fw.float32).eps
+ return fw, to_dtype, eps
+
+
+def support_sz(sz):
+
+ def wrapper(f):
+ f.support_sz = sz
+ return f
+
+ return wrapper
+
+
+@support_sz(4)
+def cubic(x):
+ fw, to_dtype, eps = set_framework_dependencies(x)
+ absx = fw.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ v1 = (1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.)
+ v2 = (-0.5 * absx3 + 2.5 * absx2 - 4. * absx
+ + 2.) * to_dtype((1. < absx) & (absx <= 2.))
+ return v1 + v2
+
+
+def resize(input,
+ scale_factors=None,
+ out_shape=None,
+ interp_method=cubic,
+ support_sz=None,
+ antialiasing=True,
+ by_convs=False,
+ scale_tolerance=None,
+ max_numerator=10,
+ pad_mode='constant'):
+ # get properties of the input tensor
+ in_shape, n_dims = input.shape, input.ndim
+
+ # fw stands for framework that can be either numpy or torch,
+ # determined by the input type
+ fw = numpy if type(input) is numpy.ndarray else torch
+ eps = fw.finfo(fw.float32).eps
+ device = input.device if fw is torch else None
+
+ # set missing scale factors or output shape, one according to the other,
+ # and raise an error if both are missing. this is also where all the
+ # default policies take place. also handle the by_convs attribute carefully.
+ scale_factors, out_shape, by_convs = set_scale_and_out_sz(
+ in_shape, out_shape, scale_factors, by_convs, scale_tolerance,
+ max_numerator, eps, fw)
+
+ # sort indices of dimensions according to scale of each dimension.
+ # since we are going dim by dim this is efficient
+ sorted_filtered_dims_and_scales = [
+ (dim, scale_factors[dim], by_convs[dim], in_shape[dim], out_shape[dim])
+ for dim in sorted(range(n_dims), key=lambda ind: scale_factors[ind])
+ if scale_factors[dim] != 1.
+ ]
+
+ # unless support size is specified by the user, it is an attribute
+ # of the interpolation method
+ if support_sz is None:
+ support_sz = interp_method.support_sz
+
+ # output begins identical to input and changes with each iteration
+ output = input
+
+ # iterate over dims
+ for (dim, scale_factor, dim_by_convs, in_sz,
+ out_sz) in sorted_filtered_dims_and_scales:
+ # STEP 1- PROJECTED GRID: The non-integer locations of the projection
+ # of output pixel locations to the input tensor
+ projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw,
+ dim_by_convs, device)
+
+ # STEP 1.5: ANTIALIASING- If antialiasing is taking place, we modify
+ # the window size and the interpolation method (see inside function)
+ cur_interp_method, cur_support_sz = apply_antialiasing_if_needed(
+ interp_method, support_sz, scale_factor, antialiasing)
+
+ # STEP 2- FIELDS OF VIEW: for each output pixel, map the input pixels
+ # that influence it. Also calculate needed padding and update the grid
+ # accordingly
+ field_of_view = get_field_of_view(projected_grid, cur_support_sz, fw,
+ eps, device)
+
+ # STEP 2.5- CALCULATE PAD AND UPDATE: according to the field of view,
+ # the input should be padded to handle the boundaries, coordinates
+ # should be updated. actual padding only occurs when weights are
+ # applied (step 4). if using by_convs for this dim, then we need to
+ # calc right and left boundaries for each filter instead.
+ pad_sz, projected_grid, field_of_view = calc_pad_sz(
+ in_sz, out_sz, field_of_view, projected_grid, scale_factor,
+ dim_by_convs, fw, device)
+
+ # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in
+ # the field of view for each output pixel
+ weights = get_weights(cur_interp_method, projected_grid, field_of_view)
+
+ # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying
+ # its set of weights with the pixel values in its field of view.
+ # We now multiply the fields of view with their matching weights.
+ # We do this by tensor multiplication and broadcasting.
+ # if by_convs is true for this dim, then we do this action by
+ # convolutions. this is equivalent but faster.
+ if not dim_by_convs:
+ output = apply_weights(output, field_of_view, weights, dim, n_dims,
+ pad_sz, pad_mode, fw)
+ else:
+ output = apply_convs(output, scale_factor, in_sz, out_sz, weights,
+ dim, pad_sz, pad_mode, fw)
+ return output
+
+
+def get_projected_grid(in_sz, out_sz, scale_factor, fw, by_convs, device=None):
+ # we start by having the output coordinates, which are just integer locations.
+ # in the special case when using by_convs, we only need two cycles of grid
+ # points. the first and last.
+ grid_sz = out_sz if not by_convs else scale_factor.numerator
+ out_coordinates = fw_arange(grid_sz, fw, device)
+
+ # This is projecting the output pixel locations in 1d to the input tensor,
+ # as non-integer locations.
+ # the following formula is derived in the paper
+ # "From Discrete to Continuous Convolutions" by Shocher et al.
+ v1 = out_coordinates / float(scale_factor) + (in_sz - 1) / 2
+ v2 = (out_sz - 1) / (2 * float(scale_factor))
+ return v1 - v2
+
+
+def get_field_of_view(projected_grid, cur_support_sz, fw, eps, device):
+ # for each output pixel, map which input pixels influence it, in 1d.
+ # we start by calculating the leftmost neighbor, using half of the window
+ # size (eps is for when boundary is exact int)
+ left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)
+
+ # then we simply take all the pixel centers in the field by counting
+ # window size pixels from the left boundary
+ ordinal_numbers = fw_arange(ceil(cur_support_sz - eps), fw, device)
+ return left_boundaries[:, None] + ordinal_numbers
+
+
+def calc_pad_sz(in_sz, out_sz, field_of_view, projected_grid, scale_factor,
+ dim_by_convs, fw, device):
+ if not dim_by_convs:
+ # determine padding according to neighbor coords out of bound.
+ # this is a generalized notion of padding, when pad<0 it means crop
+ pad_sz = [
+ -field_of_view[0, 0].item(),
+ field_of_view[-1, -1].item() - in_sz + 1
+ ]
+
+ # since input image will be changed by padding, coordinates of both
+ # field_of_view and projected_grid need to be updated
+ field_of_view += pad_sz[0]
+ projected_grid += pad_sz[0]
+
+ else:
+ # only used for by_convs, to calc the boundaries of each filter the
+ # number of distinct convolutions is the numerator of the scale factor
+ num_convs, stride = scale_factor.numerator, scale_factor.denominator
+
+ # calculate left and right boundaries for each conv. left can also be
+ # negative, and right can be bigger than in_sz. such cases imply padding
+ # if needed. however, if both are in-bounds, it means we need to crop,
+ # i.e. practically apply the conv only on part of the image.
+ left_pads = -field_of_view[:, 0]
+
+ # next calc is tricky, explanation by rows:
+ # 1) counting output pixels between the first position of each filter
+ # to the right boundary of the input
+ # 2) dividing it by number of filters to count how many 'jumps'
+ # each filter does
+ # 3) multiplying by the stride gives us the distance over the input
+ # coords done by all these jumps for each filter
+ # 4) to this distance we add the right boundary of the filter when
+ # placed in its leftmost position. so now we get the right boundary
+ # of that filter in input coord.
+ # 5) the padding size needed is obtained by subtracting the rightmost
+ # input coordinate. if the result is positive padding is needed. if
+ # negative then negative padding means shaving off pixel columns.
+ right_pads = (((out_sz - fw_arange(num_convs, fw, device) - 1) # (1)
+ // num_convs) # (2)
+ * stride # (3)
+ + field_of_view[:, -1] # (4)
+ - in_sz + 1) # (5)
+
+ # in the by_convs case pad_sz is a list of left-right pairs. one per
+ # each filter
+
+ pad_sz = list(zip(left_pads, right_pads))
+
+ return pad_sz, projected_grid, field_of_view
+
+
+def get_weights(interp_method, projected_grid, field_of_view):
+ # the set of weights per each output pixels is the result of the chosen
+ # interpolation method applied to the distances between projected grid
+ # locations and the pixel-centers in the field of view (distances are
+ # directed, can be positive or negative)
+ weights = interp_method(projected_grid[:, None] - field_of_view)
+
+ # we now carefully normalize the weights to sum to 1 per each output pixel
+ sum_weights = weights.sum(1, keepdims=True)
+ sum_weights[sum_weights == 0] = 1
+ return weights / sum_weights
+
+
+def apply_weights(input, field_of_view, weights, dim, n_dims, pad_sz, pad_mode,
+ fw):
+ # for this operation we assume the resized dim is the first one.
+ # so we transpose and will transpose back after multiplying
+ tmp_input = fw_swapaxes(input, dim, 0, fw)
+
+ # apply padding
+ tmp_input = fw_pad(tmp_input, fw, pad_sz, pad_mode)
+
+ # field_of_view is a tensor of order 2: for each output (1d location
+ # along cur dim)- a list of 1d neighbors locations.
+ # note that this whole operation is applied to each dim separately,
+ # this is why it is all in 1d.
+ # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1:
+ # for each output pixel (this time indicated in all dims), these are the
+ # values of the neighbors in the 1d field of view. note that we only
+ # consider neighbors along the current dim, but such set exists for every
+ # multi-dim location, hence the final tensor order is image_dims+1.
+ neighbors = tmp_input[field_of_view]
+
+ # weights is an order 2 tensor: for each output location along 1d- a list
+ # of weights matching the field of view. we augment it with ones, for
+ # broadcasting, so that when it multiplies some tensor, the weights affect
+ # only its first dim.
+ tmp_weights = fw.reshape(weights, (*weights.shape, *[1] * (n_dims - 1)))
+
+ # now we simply multiply the weights with the neighbors, and then sum
+ # along the field of view, to get a single value per out pixel
+ tmp_output = (neighbors * tmp_weights).sum(1)
+
+ # we transpose back the resized dim to its original position
+ return fw_swapaxes(tmp_output, 0, dim, fw)
+
+
+def apply_convs(input, scale_factor, in_sz, out_sz, weights, dim, pad_sz,
+ pad_mode, fw):
+ # for this operation we assume the resized dim is the last one.
+ # so we transpose and will transpose back after multiplying
+ input = fw_swapaxes(input, dim, -1, fw)
+
+ # the stride for all convs is the denominator of the scale factor
+ stride, num_convs = scale_factor.denominator, scale_factor.numerator
+
+ # prepare an empty tensor for the output
+ tmp_out_shape = list(input.shape)
+ tmp_out_shape[-1] = out_sz
+ tmp_output = fw_empty(tuple(tmp_out_shape), fw, input.device)
+
+ # iterate over the conv operations. we have as many as the numerator
+ # of the scale-factor. for each we need boundaries and a filter.
+ for conv_ind, (pad_sz, filt) in enumerate(zip(pad_sz, weights)):
+ # apply padding (we pad last dim, padding can be negative)
+ pad_dim = input.ndim - 1
+ tmp_input = fw_pad(input, fw, pad_sz, pad_mode, dim=pad_dim)
+
+ # apply convolution over last dim. store in the output tensor with
+ # positional strides so that when the loop is complete the conv results
+ # are interleaved
+ tmp_output[..., conv_ind::num_convs] = fw_conv(tmp_input, filt, stride)
+
+ return fw_swapaxes(tmp_output, -1, dim, fw)
+
+
+def set_scale_and_out_sz(in_shape, out_shape, scale_factors, by_convs,
+ scale_tolerance, max_numerator, eps, fw):
+ # eventually we must have both scale-factors and out-sizes for all in/out
+ # dims. however, we support many possible partial arguments
+ if scale_factors is None and out_shape is None:
+ raise ValueError('either scale_factors or out_shape should be '
+ 'provided')
+ if out_shape is not None:
+ # if out_shape has fewer dims than in_shape, we by default resize the
+ # first dims for numpy and the last dims for torch
+ out_shape = (
+ list(out_shape) + list(in_shape[len(out_shape):]) if fw is numpy
+ else list(in_shape[:-len(out_shape)]) + list(out_shape))
+ if scale_factors is None:
+ # if no scale given, we calculate it as the out to in ratio
+ # (not recommended)
+ scale_factors = [
+ out_sz / in_sz for out_sz, in_sz in zip(out_shape, in_shape)
+ ]
+
+ if scale_factors is not None:
+ # by default, if a single number is given as scale, we assume resizing
+ # two dims (most common are images with 2 spatial dims)
+ scale_factors = (
+ scale_factors if isinstance(scale_factors, (list, tuple)) else
+ [scale_factors, scale_factors])
+ # if there are fewer scale_factors than in_shape dims, we by default
+ # resize the first dims for numpy and the last dims for torch
+ if fw is numpy:
+ scale_factors = list(scale_factors) + [1] * (
+ len(in_shape) - len(scale_factors))
+ else:
+ scale_factors = [1] * (len(in_shape)
+ - len(scale_factors)) + list(scale_factors)
+ if out_shape is None:
+ # when no out_shape given, it is calculated by multiplying the
+ # scale by the in_shape (not recommended)
+ out_shape = [
+ ceil(scale_factor * in_sz)
+ for scale_factor, in_sz in zip(scale_factors, in_shape)
+ ]
+ # next part intentionally after out_shape determined for stability
+ # we fix by_convs to be a list of truth values in case it is not
+ if not isinstance(by_convs, (list, tuple)):
+ by_convs = [by_convs] * len(out_shape)
+
+ # next loop fixes the scale for each dim to be either frac or float.
+ # this is determined by by_convs and by tolerance for scale accuracy.
+ for ind, (sf, dim_by_convs) in enumerate(zip(scale_factors, by_convs)):
+ # first we convert the scale to a fraction
+ if dim_by_convs:
+ frac = Fraction(1 / sf).limit_denominator(max_numerator)
+ frac = Fraction(
+ numerator=frac.denominator, denominator=frac.numerator)
+
+ # if accuracy is within tolerance scale will be frac. if not, then
+ # it will be float and the by_convs attr will be set false for
+ # this dim
+ if scale_tolerance is None:
+ scale_tolerance = eps
+ if dim_by_convs and abs(frac - sf) < scale_tolerance:
+ scale_factors[ind] = frac
+ else:
+ scale_factors[ind] = float(sf)
+ by_convs[ind] = False
+
+ return scale_factors, out_shape, by_convs
+
+
+def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
+ antialiasing):
+ # antialiasing is "stretching" the field of view according to the scale
+ # factor (only for downscaling). this is low-pass filtering. this
+ # requires modifying both the interpolation (stretching the 1d
+ # function and multiplying by the scale-factor) and the window size.
+ scale_factor = float(scale_factor)
+ if scale_factor >= 1.0 or not antialiasing:
+ return interp_method, support_sz
+ cur_interp_method = (
+ lambda arg: scale_factor * interp_method(scale_factor * arg))
+ cur_support_sz = support_sz / scale_factor
+ return cur_interp_method, cur_support_sz
+
+
+def fw_ceil(x, fw):
+ if fw is numpy:
+ return fw.int_(fw.ceil(x))
+ else:
+ return x.ceil().long()
+
+
+def fw_floor(x, fw):
+ if fw is numpy:
+ return fw.int_(fw.floor(x))
+ else:
+ return x.floor().long()
+
+
+def fw_cat(x, fw):
+ if fw is numpy:
+ return fw.concatenate(x)
+ else:
+ return fw.cat(x)
+
+
+def fw_swapaxes(x, ax_1, ax_2, fw):
+ if fw is numpy:
+ return fw.swapaxes(x, ax_1, ax_2)
+ else:
+ return x.transpose(ax_1, ax_2)
+
+
+def fw_pad(x, fw, pad_sz, pad_mode, dim=0):
+ if pad_sz == (0, 0):
+ return x
+ if fw is numpy:
+ pad_vec = [(0, 0)] * x.ndim
+ pad_vec[dim] = pad_sz
+ return fw.pad(x, pad_width=pad_vec, mode=pad_mode)
+ else:
+ if x.ndim < 3:
+ x = x[None, None, ...]
+
+ pad_vec = [0] * ((x.ndim - 2) * 2)
+ pad_vec[0:2] = pad_sz
+ return fw.nn.functional.pad(
+ x.transpose(dim, -1), pad=pad_vec,
+ mode=pad_mode).transpose(dim, -1)
+
+
+def fw_conv(input, filter, stride):
+ # we want to apply 1d conv to any nd array. the way to do it is to reshape
+ # the input to a 4D tensor. first two dims are singletons, 3rd dim stores
+ # all the spatial dims that we are not convolving along now. then we can
+ # apply conv2d with a 1xK filter. This convolves the same way all the other
+ # dims stored in the 3d dim. like depthwise conv over these.
+ # TODO: numpy support
+ reshaped_input = input.reshape(1, 1, -1, input.shape[-1])
+ reshaped_output = torch.nn.functional.conv2d(
+ reshaped_input, filter.view(1, 1, 1, -1), stride=(1, stride))
+ return reshaped_output.reshape(*input.shape[:-1], -1)
+
+
+def fw_arange(upper_bound, fw, device):
+ if fw is numpy:
+ return fw.arange(upper_bound)
+ else:
+ return fw.arange(upper_bound, device=device)
+
+
+def fw_empty(shape, fw, device):
+ if fw is numpy:
+ return fw.empty(shape)
+ else:
+ return fw.empty(size=(*shape, ), device=device)
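
utils.py is essentially the ResizeRight resampler: resizing happens one dim at a time, each output pixel gets a field of view of input pixels, and the (optionally antialiased) cubic kernel supplies the weights. A short usage sketch on a random tensor, assuming the module is importable under the path shown in the diff header:

    import torch
    from modelscope.pipelines.multi_modal.disco_guided_diffusion_pipeline.utils import resize

    img = torch.rand(1, 3, 128, 96)     # NCHW tensor in [0, 1]

    # A single scale factor resizes the two trailing (spatial) dims; for
    # downscaling, antialiasing widens the cubic kernel to low-pass filter first.
    small = resize(img, scale_factors=0.5)
    print(small.shape)                  # torch.Size([1, 3, 64, 48])

    # Alternatively, give an explicit output shape for the trailing dims.
    fixed = resize(img, out_shape=(224, 224))
    print(fixed.shape)                  # torch.Size([1, 3, 224, 224])
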
diff --git a/modelscope/pipelines/multi_modal/soonet_video_temporal_grounding_pipeline.py b/modelscope/pipelines/multi_modal/soonet_video_temporal_grounding_pipeline.py
new file mode 100644
index 00000000..0251e745
--- /dev/null
+++ b/modelscope/pipelines/multi_modal/soonet_video_temporal_grounding_pipeline.py
@@ -0,0 +1,222 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import os
+from typing import Any, Dict
+
+import numpy as np
+import torch
+from torchvision import transforms
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal.soonet import (SimpleTokenizer,
+ decode_video, load_clip)
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+ Tasks.video_temporal_grounding,
+ module_name=Pipelines.soonet_video_temporal_grounding)
+class SOONetVideoTemporalGroundingPipeline(Pipeline):
+
+ def __init__(self, model: str, **kwargs):
+ """
+ SOONet pipeline for video temporal grounding.
+
+ Examples:
+
+ >>> from modelscope.pipelines import pipeline
+
+ >>> soonet_pipeline = pipeline("video-temporal-grounding", "damo/multi-modal_soonet_video-temporal-grounding")
+ >>> soonet_pipeline(
+ >>> ('a man takes food out of the refrigerator.',
+ >>> 'soonet_video_temporal_grounding_test_video.mp4'))
+
+ >>> {
+ >>> "scores": [
+ >>> 0.80661213,
+ >>> 0.8060084,
+ >>> 0.8018835,
+ >>> 0.79837507,
+ >>> 0.7963626,
+ >>> 0.7949013,
+ >>> 0.79353744,
+ >>> 0.79287416,
+ >>> 0.79066336,
+ >>> 0.79027915
+ >>> ],
+ >>> "tbounds": [
+ >>> [
+ >>> 0,
+ >>> 2.9329566955566406
+ >>> ],
+ >>> [
+ >>> 1.0630402565002441,
+ >>> 4.9339457750320435
+ >>> ],
+ >>> [
+ >>> 300.96919429302216,
+ >>> 304.8546848297119
+ >>> ],
+ >>> [
+ >>> 302.96981167793274,
+ >>> 306.7714672088623
+ >>> ],
+ >>> [
+ >>> 0,
+ >>> 5.0421366691589355
+ >>> ],
+ >>> [
+ >>> 304.9119266271591,
+ >>> 308.7636929154396
+ >>> ],
+ >>> [
+ >>> 258.96133184432983,
+ >>> 262.805901825428
+ >>> ],
+ >>> [
+ >>> 122.9599289894104,
+ >>> 126.86622190475464
+ >>> ],
+ >>> [
+ >>> 126.94010400772095,
+ >>> 130.8090701699257
+ >>> ],
+ >>> [
+ >>> 121.04773849248886,
+ >>> 124.79261875152588
+ >>> ]
+ >>> ]
+ >>> }
+ """
+ super().__init__(model=model, **kwargs)
+
+ self.model_dir = model
+ self.clip = load_clip(os.path.join(self.model_dir,
+ 'ViT-B-32.pt')).to(self.device)
+ self.model = self.model.float().to(self.device)
+ self.model.eval()
+
+ # Load Configuration from File
+ config_path = os.path.join(self.model_dir, ModelFile.CONFIGURATION)
+ self.config = Config.from_file(config_path).hyperparams
+ self.nscales = self.config.nscales
+ self.snippet_length = self.config.snippet_length
+ self.max_anchor_length = self.snippet_length * 2**(self.nscales - 1)
+ self.topk = 10
+ self.fps = 5
+ # Define image transform
+ self.img_transform = transforms.Compose([
+ transforms.Resize(224),
+ transforms.CenterCrop(224),
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711)),
+ ])
+ logger.info('Init transform done')
+
+ # Init tokenizer
+ bpe_path = os.path.join(self.model_dir, 'bpe_simple_vocab_16e6.txt.gz')
+ self.tokenizer = SimpleTokenizer(bpe_path)
+ logger.info('Init tokenizer done')
+
+ def pad(self, arr, pad_len):
+ new_arr = np.zeros((pad_len, ), dtype=float)
+ new_arr[:len(arr)] = arr
+ return new_arr
+
+ def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
+ text, video_name = input
+ video_path = os.path.join(self.model_dir, video_name)
+ imgs, duration = decode_video(video_path, self.fps)
+ trans_imgs = list()
+ for i, img in enumerate(imgs):
+ trans_imgs.append(self.img_transform(img))
+ imgs = trans_imgs
+ token_ids = self.tokenizer.tokenize(text).to(
+ self.device, non_blocking=True)
+ # get the start and end timestamps of anchors
+ start_ts, end_ts, scale_boundaries = list(), list(), [0]
+ ori_video_length = len(imgs)
+ pad_video_length = int(
+ np.math.ceil(ori_video_length / self.max_anchor_length)
+ * self.max_anchor_length)
+ for i in range(self.config.nscales):
+ anchor_length = self.config.snippet_length * (2**i)
+ pad_feat_length = pad_video_length // anchor_length
+ nfeats = np.math.ceil(ori_video_length / anchor_length)
+ s_times = np.arange(0, nfeats).astype(np.float32) * (
+ anchor_length // self.fps)
+ e_times = np.arange(1, nfeats + 1).astype(np.float32) * (
+ anchor_length // self.fps)
+ e_times = np.minimum(duration, e_times)
+ start_ts.append(self.pad(s_times, pad_feat_length))
+ end_ts.append(self.pad(e_times, pad_feat_length))
+ scale_boundaries.append(scale_boundaries[-1] + pad_feat_length)
+
+ start_ts = torch.from_numpy(np.concatenate(start_ts, axis=0))
+ end_ts = torch.from_numpy(np.concatenate(end_ts, axis=0))
+ scale_boundaries = torch.LongTensor(scale_boundaries)
+ result = {
+ 'token_ids': token_ids,
+ 'imgs': torch.stack(imgs, dim=0),
+ 'start_ts': start_ts,
+ 'end_ts': end_ts,
+ 'scale_boundaries': scale_boundaries
+ }
+ return result
+
+ def forward(self, input: Dict[str, Any],
+ **forward_params) -> Dict[str, Any]:
+ with torch.no_grad():
+ video_feats = self.clip.encode_image(input['imgs'].to(self.device))
+ query_feats = self.clip.encode_text(input['token_ids'].to(
+ self.device))
+ #
+ ori_video_length, feat_dim = video_feats.shape
+ pad_video_length = int(
+ np.math.ceil(ori_video_length / self.max_anchor_length)
+ * self.max_anchor_length)
+ pad_video_feats = torch.zeros((pad_video_length, feat_dim),
+ dtype=float)
+ pad_video_feats[:ori_video_length, :] = video_feats
+ final_scores, bbox_bias, starts, ends = self.model(
+ query_feats=query_feats.float().to(self.device),
+ video_feats=pad_video_feats.unsqueeze(0).float().to(
+ self.device),
+ start_ts=input['start_ts'].float().to(self.device),
+ end_ts=input['end_ts'].float().to(self.device),
+ scale_boundaries=input['scale_boundaries'])
+ #
+ final_scores = final_scores.cpu().numpy()
+ bbox_bias = bbox_bias.cpu().numpy()
+ starts = starts.cpu().numpy()
+ ends = ends.cpu().numpy()
+ pred_scores, pred_bboxes = list(), list()
+ rank_id = np.argsort(final_scores[0])[::-1]
+ for j in range(self.topk):
+ if j >= len(rank_id):
+ break
+ pred_scores.append(final_scores[0, rank_id[j]])
+ ori_end = float(ends[rank_id[j]])
+ ori_start = float(starts[rank_id[j]])
+ duration = ori_end - ori_start
+ sbias = bbox_bias[0, rank_id[j], 0]
+ ebias = bbox_bias[0, rank_id[j], 1]
+ pred_start = max(0, ori_start + sbias * duration)
+ pred_end = ori_end + ebias * duration
+ pred_bboxes.append([pred_start, pred_end])
+
+ return {
+ OutputKeys.SCORES: pred_scores,
+ OutputKeys.TBOUNDS: pred_bboxes
+ }
+
+ def postprocess(self, inputs: Dict[str, Any],
+ **post_params) -> Dict[str, Any]:
+ return inputs
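
The preprocess step builds a multi-scale anchor grid: at scale i each anchor covers `snippet_length * 2**i` sampled frames, its start/end times come from dividing those frame offsets by the sampling fps, and every scale is zero-padded to a length divisible by the largest anchor. A standalone sketch of that computation, using float division and example values rather than the model's configuration file:

    import math
    import numpy as np

    def build_anchor_grid(num_frames, duration, snippet_length=2, nscales=3, fps=5):
        """Return flattened per-scale (start, end) offsets in seconds plus scale boundaries."""
        max_anchor = snippet_length * 2 ** (nscales - 1)
        pad_len = math.ceil(num_frames / max_anchor) * max_anchor
        starts, ends, boundaries = [], [], [0]
        for i in range(nscales):
            anchor_len = snippet_length * 2 ** i
            pad_feats = pad_len // anchor_len                # padded anchors at this scale
            nfeats = math.ceil(num_frames / anchor_len)      # anchors actually covered by video
            s = np.arange(0, nfeats, dtype=np.float32) * (anchor_len / fps)
            e = np.minimum(duration,
                           np.arange(1, nfeats + 1, dtype=np.float32) * (anchor_len / fps))
            starts.append(np.pad(s, (0, pad_feats - nfeats)))
            ends.append(np.pad(e, (0, pad_feats - nfeats)))
            boundaries.append(boundaries[-1] + pad_feats)
        return np.concatenate(starts), np.concatenate(ends), boundaries

    s, e, b = build_anchor_grid(num_frames=23, duration=4.6)
    print(b)             # [0, 12, 18, 21]: where each scale sits in the flat anchor list
    print(s[:4], e[:4])  # first few fine-scale anchors in seconds
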
diff --git a/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py b/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py
new file mode 100644
index 00000000..ee6635a6
--- /dev/null
+++ b/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py
@@ -0,0 +1,89 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import tempfile
+from typing import Any, Dict, Optional
+
+import cv2
+import torch
+from einops import rearrange
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Model, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+ Tasks.text_to_video_synthesis,
+ module_name=Pipelines.text_to_video_synthesis)
+class TextToVideoSynthesisPipeline(Pipeline):
+ r""" Text To Video Synthesis Pipeline.
+
+ Examples:
+ >>> from modelscope.pipelines import pipeline
+ >>> from modelscope.outputs import OutputKeys
+
+ >>> p = pipeline('text-to-video-synthesis', 'damo/text-to-video-synthesis')
+ >>> test_text = {
+ >>> 'text': 'A panda eating bamboo on a rock.',
+ >>> }
+ >>> p(test_text,)
+
+ >>> {OutputKeys.OUTPUT_VIDEO: path-to-the-generated-video}
+ >>>
+ """
+
+ def __init__(self, model: str, **kwargs):
+ """
+ use `model` to create a text-to-video synthesis pipeline for prediction
+ Args:
+ model: model id on modelscope hub.
+ """
+ super().__init__(model=model, **kwargs)
+
+ def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
+ self.model.clip_encoder.to(self.model.device)
+ text_emb = self.model.clip_encoder(input['text'])
+ text_emb_zero = self.model.clip_encoder('')
+ if self.model.config.model.model_args.tiny_gpu == 1:
+ self.model.clip_encoder.to('cpu')
+ return {'text_emb': text_emb, 'text_emb_zero': text_emb_zero}
+
+ def forward(self, input: Dict[str, Any],
+ **forward_params) -> Dict[str, Any]:
+ video = self.model(input)
+ return {'video': video}
+
+ def postprocess(self, inputs: Dict[str, Any],
+ **post_params) -> Dict[str, Any]:
+ video = tensor2vid(inputs['video'])
+ output_video_path = post_params.get('output_video', None)
+ if output_video_path is None:
+ output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
+
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ h, w, c = video[0].shape
+ video_writer = cv2.VideoWriter(
+ output_video_path, fourcc, fps=8, frameSize=(w, h))
+ for i in range(len(video)):
+ img = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
+ video_writer.write(img)
+ video_writer.release()
+ return {OutputKeys.OUTPUT_VIDEO: output_video_path}
+
+
+def tensor2vid(video, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
+ mean = torch.tensor(
+ mean, device=video.device).reshape(1, -1, 1, 1, 1) # ncfhw
+ std = torch.tensor(
+ std, device=video.device).reshape(1, -1, 1, 1, 1) # ncfhw
+ video = video.mul_(std).add_(mean) # unnormalize back to [0,1]
+ video.clamp_(0, 1)
+ images = rearrange(video, 'i c f h w -> f h (i w) c')
+ images = images.unbind(dim=0)
+ images = [(image.numpy() * 255).astype('uint8')
+ for image in images] # f h w c
+ return images
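
tensor2vid un-normalizes the model output from [-1, 1] back to [0, 1] and folds the batch axis into the width axis; the postprocess step then streams the frames through OpenCV's VideoWriter. A hedged end-to-end sketch with a dummy tensor standing in for the synthesized video:

    import tempfile
    import cv2
    import torch
    from einops import rearrange

    def tensor2vid(video, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
        std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
        video = video.mul(std).add(mean).clamp(0, 1)            # unnormalize to [0, 1]
        frames = rearrange(video, 'i c f h w -> f h (i w) c').unbind(dim=0)
        return [(f.numpy() * 255).astype('uint8') for f in frames]

    dummy = torch.rand(1, 3, 16, 256, 256) * 2 - 1              # fake (i, c, f, h, w) output
    frames = tensor2vid(dummy)

    out_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
    h, w, _ = frames[0].shape
    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), 8, (w, h))
    for frame in frames:
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))    # OpenCV expects BGR
    writer.release()
    print(out_path)
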
diff --git a/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py b/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py
index e098823b..1738f2da 100644
--- a/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py
+++ b/modelscope/pipelines/nlp/distributed_gpt3_pipeline.py
@@ -43,7 +43,7 @@ class DistributedGPT3Pipeline(DistributedPipeline):
def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
tokens = inputs['inputs']['input_ids'].cuda(
torch.cuda.current_device())
- return cls.model.generate(tokens)
+ return cls.model.generate(tokens, **inputs['forward_params'])
def postprocess(self, inputs: Dict[str, Any],
**postprocess_params) -> Dict[str, str]:
@@ -61,3 +61,6 @@ class DistributedGPT3Pipeline(DistributedPipeline):
self.preprocessor.tokenizer.detokenize(
inputs.sequences[0].tolist())
}
+
+ def _sanitize_parameters(self, **pipeline_parameters):
+ return {}, pipeline_parameters, {}
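
With `_sanitize_parameters` routing all call-time kwargs into forward_params, options passed when invoking the pipeline now reach `model.generate` in `_forward_one`. A hedged usage sketch; the model id and kwarg names are illustrative, and which generation options are honored depends on the model's generate signature:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Assumed: a distributed GPT-3 checkpoint registered for text generation.
    pipe = pipeline(Tasks.text_generation, model='damo/nlp_gpt3_text-generation_1.3B')

    # Illustrative kwargs: they are collected by _sanitize_parameters and
    # forwarded as model.generate(tokens, **forward_params).
    print(pipe('The weather today is', max_length=64, top_k=10))
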
diff --git a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py
index 8a25c415..ba174bae 100644
--- a/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py
+++ b/modelscope/pipelines/nlp/named_entity_recognition_pipeline.py
@@ -48,7 +48,7 @@ class NamedEntityRecognitionPipeline(TokenClassificationPipeline):
>>> input = '这与温岭市新河镇的一个神秘的传说有关。'
>>> print(pipeline_ins(input))
- To view other examples plese check the tests/pipelines/test_named_entity_recognition.py.
+ To view other examples please check the tests/pipelines/test_plugin_model.py.
"""
super().__init__(
model=model,
diff --git a/modelscope/pipelines/nlp/siamese_uie_pipeline.py b/modelscope/pipelines/nlp/siamese_uie_pipeline.py
index c9f86893..21582900 100644
--- a/modelscope/pipelines/nlp/siamese_uie_pipeline.py
+++ b/modelscope/pipelines/nlp/siamese_uie_pipeline.py
@@ -245,7 +245,7 @@ class SiameseUiePipeline(Pipeline):
with torch.no_grad():
with autocast():
for batch_data in zip(*all_tensor_data):
- batch_head_probs, batch_tail_probs = self.model(
+ batch_head_probs, batch_tail_probs = self.model.fast_inference(
*batch_data)
batch_head_probs, batch_tail_probs = batch_head_probs.tolist(
), batch_tail_probs.tolist() # (b, n, l)
diff --git a/modelscope/pipelines/nlp/text_classification_pipeline.py b/modelscope/pipelines/nlp/text_classification_pipeline.py
index 5b76a571..a300b008 100644
--- a/modelscope/pipelines/nlp/text_classification_pipeline.py
+++ b/modelscope/pipelines/nlp/text_classification_pipeline.py
@@ -61,7 +61,9 @@ class TextClassificationPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
- auto_collate=auto_collate)
+ auto_collate=auto_collate,
+ compile=kwargs.pop('compile', False),
+ compile_options=kwargs.pop('compile_options', {}))
assert isinstance(self.model, Model), \
f'please check whether model config exists in {ModelFile.CONFIGURATION}'
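
The two new kwargs let callers opt the classification model into compilation when constructing the pipeline (the options dict is passed through as-is). A hedged usage sketch; the model id is an assumption and the options follow torch.compile's interface:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    clf = pipeline(
        Tasks.text_classification,
        model='damo/nlp_structbert_sentiment-classification_chinese-base',  # assumed id
        compile=True,                                   # popped and forwarded by __init__
        compile_options={'mode': 'reduce-overhead'})    # assumed torch.compile options

    # Input language should match whichever model is actually loaded.
    print(clf('The service was quick and the food was great.'))
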
diff --git a/modelscope/preprocessors/common.py b/modelscope/preprocessors/common.py
index aa1db84c..68aaae36 100644
--- a/modelscope/preprocessors/common.py
+++ b/modelscope/preprocessors/common.py
@@ -7,6 +7,7 @@ from typing import Mapping
import numpy as np
import torch
+from modelscope.utils.registry import default_group
from .builder import PREPROCESSORS, build_preprocessor
@@ -28,13 +29,14 @@ class Compose(object):
for transform in transforms:
if isinstance(transform, dict):
if self.field_name is None:
- transform = build_preprocessor(transform, field_name)
+ transform = build_preprocessor(transform, default_group)
else:
# if not found key in field_name, try field_name=None(default_group)
try:
transform = build_preprocessor(transform, field_name)
except KeyError:
- transform = build_preprocessor(transform, None)
+ transform = build_preprocessor(transform,
+ default_group)
elif callable(transform):
pass
else:
@@ -108,7 +110,8 @@ class ToTensor(object):
self.keys = list(data.keys())
for key in self.keys:
- data[key] = to_tensor(data[key])
+ if key in data:
+ data[key] = to_tensor(data[key])
else:
data = to_tensor(data)
@@ -135,9 +138,93 @@ class Filter(object):
reserved_data = {}
for key in self.reserved_keys:
- reserved_data[key] = data[key]
+ if key in data:
+ reserved_data[key] = data[key]
return reserved_data
def __repr__(self):
return self.__class__.__name__ + f'(keys={self.reserved_keys})'
+
+
+def to_numpy(data):
+ """Convert objects of various python types to `numpy.ndarray`.
+
+ Args:
+ data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
+ be converted.
+ """
+
+ if isinstance(data, torch.Tensor):
+ return data.numpy()
+ elif isinstance(data, np.ndarray):
+ return data
+ elif isinstance(data, Sequence) and not isinstance(data, str):
+ return np.asarray(data)
+ elif isinstance(data, int):
+ return np.asarray(data, dtype=np.int64)
+ elif isinstance(data, float):
+ return np.asarray(data, dtype=np.float64)
+ else:
+ raise TypeError(f'type {type(data)} cannot be converted to numpy.ndarray.')
+
+
+@PREPROCESSORS.register_module()
+class ToNumpy(object):
+ """Convert target object to numpy.ndarray.
+
+ Args:
+ keys (Sequence[str]): Key of data to be converted to numpy.ndarray.
+ Only valid when data is type of `Mapping`. If `keys` is None,
+ all values of keys will be converted to numpy.ndarray by default.
+ """
+
+ def __init__(self, keys=None):
+ self.keys = keys
+
+ def __call__(self, data):
+ if isinstance(data, Mapping):
+ if self.keys is None:
+ self.keys = list(data.keys())
+
+ for key in self.keys:
+ if key in data:
+ data[key] = to_numpy(data[key])
+ else:
+ data = to_numpy(data)
+
+ return data
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PREPROCESSORS.register_module()
+class Rename(object):
+ """Change the name of the input keys to output keys, respectively.
+ """
+
+ def __init__(self, input_keys=[], output_keys=[]):
+ self.input_keys = input_keys
+ self.output_keys = output_keys
+
+ def __call__(self, data):
+ if isinstance(data, Mapping):
+ for in_key, out_key in zip(self.input_keys, self.output_keys):
+ if in_key in data and out_key not in data:
+ data[out_key] = data[in_key]
+ data.pop(in_key)
+ return data
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(input_keys={self.input_keys}, output_keys={self.output_keys})'
+
+
+@PREPROCESSORS.register_module()
+class Identity(object):
+
+ def __init__(self):
+ pass
+
+ def __call__(self, item):
+ return item
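
ToNumpy, Rename and Identity are small dict-in/dict-out callables in the same spirit as the existing ToTensor and Filter transforms. A quick demonstration of Rename and ToNumpy on a hand-built sample, assuming they are importable from modelscope.preprocessors.common:

    import torch
    from modelscope.preprocessors.common import Rename, ToNumpy

    sample = {'img': torch.rand(3, 8, 8), 'label': torch.tensor(1), 'meta': 'a.jpg'}

    # Rename moves a value to a new key when the target key is not already taken.
    sample = Rename(input_keys=['img'], output_keys=['image'])(sample)

    # ToNumpy converts only the listed keys (all keys when keys is None).
    sample = ToNumpy(keys=['image', 'label'])(sample)

    print({k: type(v).__name__ for k, v in sample.items()})
    # {'label': 'ndarray', 'meta': 'str', 'image': 'ndarray'}
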
diff --git a/modelscope/preprocessors/cv/__init__.py b/modelscope/preprocessors/cv/__init__.py
index b9165a9d..b832f1e6 100644
--- a/modelscope/preprocessors/cv/__init__.py
+++ b/modelscope/preprocessors/cv/__init__.py
@@ -9,9 +9,11 @@ if TYPE_CHECKING:
from .mmcls_preprocessor import ImageClassificationMmcvPreprocessor
from .image_quality_assessment_mos import ImageQualityAssessmentMosPreprocessor
+ from .image_quality_assessment_man import ImageQualityAssessmentMANPreprocessor
from .image_restoration_preprocessor import ImageRestorationPreprocessor
from .bad_image_detecting_preprocessor import BadImageDetectingPreprocessor
from .controllable_image_generation import ControllableImageGenerationPreprocessor
+ from .image_classification_preprocessor import ImageClassificationPreprocessor
else:
_import_structure = {
@@ -20,10 +22,14 @@ else:
'mmcls_preprocessor': ['ImageClassificationMmcvPreprocessor'],
'image_quality_assessment_mos':
['ImageQualityAssessmentMosPreprocessor'],
+ 'image_quality_assessment_man':
+ ['ImageQualityAssessmentMANPreprocessor'],
'image_restoration_preprocessor': ['ImageRestorationPreprocessor'],
'bad_image_detecting_preprocessor': ['BadImageDetectingPreprocessor'],
'controllable_image_generation':
['ControllableImageGenerationPreprocessor'],
+ 'image_classification_preprocessor':
+ ['ImageClassificationPreprocessor']
}
import sys
diff --git a/modelscope/preprocessors/cv/action_detection_mapper.py b/modelscope/preprocessors/cv/action_detection_mapper.py
new file mode 100644
index 00000000..9bb6d422
--- /dev/null
+++ b/modelscope/preprocessors/cv/action_detection_mapper.py
@@ -0,0 +1,185 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import copy
+import random
+
+import decord
+import numpy as np
+import torch
+from detectron2.data.transforms import (ExtentTransform, RandomBrightness,
+ RandomFlip, ResizeShortestEdge)
+from detectron2.structures import Boxes, Instances
+from scipy.interpolate import interp1d
+
+
+def inp_boxes(boxes: dict, start, end):
+ idxs = sorted([int(i) for i in boxes.keys()])
+ bbox = [boxes[str(i)] for i in idxs]
+ new_bboxes = []
+ for i in range(4):
+ f = interp1d(idxs, [b[i] for b in bbox])
+ new_b = f(list(range(start, end + 1)))
+ new_bboxes.append(new_b)
+ new_bboxes = np.stack(new_bboxes, axis=1)
+ return new_bboxes
+
+
+def assign_label(start, end, data_dict):
+ """
+ Assign bbox detection labels to this short video clip based on its start and
+ end positions and the annotated actions: take the temporal intersection, and
+ if it covers enough of the clip or of the annotation (70% in the code below),
+ assign that action's label and interpolated box to the clip.
+ :param start: start frame index (inclusive)
+ :param end: end frame index (inclusive)
+ :param data_dict: annotation dict containing 'actions' and 'scale'
+ :return: [[action_label, x1, y1, x2, y2], ...]
+ """
+ if 'actions' not in data_dict:
+ return []
+ scale = data_dict['scale']
+ gt_labels = []
+ for action in data_dict['actions']:
+ low = max(int(action['start']), start)
+ high = min(int(action['end']), end)
+ inter = 0 if low > high else high - low
+ if inter > (end - start) * 0.7 or inter > (action['end']
+ - action['start']) * 0.7:
+ boxes = inp_boxes(action['boxes'], low, high)
+ box = boxes.mean(axis=0) / scale
+ label = [action['label']] + box.tolist()
+ gt_labels.append(label)
+ return gt_labels
+
+
+class VideoDetMapper:
+
+ def __init__(self,
+ classes_id_map,
+ used_seconds=2,
+ input_frames=4,
+ is_train=True,
+ tile=False):
+ self.classes_id = classes_id_map
+ self.is_train = is_train
+ self.used_seconds = used_seconds
+ self.input_frames = input_frames
+ self.tile = tile
+ self.trans = [RandomBrightness(0.5, 1.5)]
+ self.tfm_gens = [
+ ResizeShortestEdge((480, 512, 544, 576, 608, 640, 672, 704, 736,
+ 768) if is_train else 512,
+ 1280 if is_train else 896, 'choice')
+ ]
+ if is_train:
+ self.tfm_gens.append(RandomFlip())
+
+ def __call__(self, data_dict):
+ data_dict = copy.deepcopy(data_dict)
+ try:
+ data_dict = self._call(data_dict)
+ except Exception as e:
+ print(data_dict['path:FILE'], e)
+ data_dict = None
+ return data_dict
+
+ def _call(self, data_dict):
+ video_name = data_dict['path:FILE']
+ if data_dict['actions'] is not None:
+ data_dict['actions'] = eval(data_dict['actions'])
+ else:
+ data_dict['actions'] = []
+
+ v = decord.VideoReader(video_name, ctx=decord.cpu(0))
+ num_frames = len(v)
+ used_frames = max(int((1 + random.random()) * v.get_avg_fps()), 1)
+ if self.is_train:
+ start_idx = random.randint(0, max(0, num_frames - used_frames))
+ else:
+ start_idx = max(0, num_frames - used_frames) // 2
+ idxs = np.linspace(start_idx, min(start_idx + used_frames, num_frames) - 1, self.input_frames) \
+ .round().astype('int32').tolist()
+ imgs = v.get_batch(idxs).asnumpy()
+ del v
+ labels = assign_label(idxs[0], idxs[-1] + 1, data_dict)
+ bboxes = np.array([label[-4:] for label in labels])
+
+ if self.is_train:
+ if self.tile:
+ imgs, labels, bboxes = self.random_tile(
+ video_name, imgs, labels, bboxes, pos_choices=[1, 1, 2, 4])
+ else:
+ imgs, labels, bboxes = self.random_tile(
+ video_name, imgs, labels, bboxes, pos_choices=[1])
+
+ for g in self.trans:
+ tfm = g.get_transform(imgs)
+ imgs = tfm.apply_image(imgs)
+ imgs, bboxes = self.random_extent(imgs, bboxes)
+
+ for trans in self.tfm_gens:
+ tfm = trans.get_transform(imgs[0])
+ imgs = np.stack([tfm.apply_image(img) for img in imgs])
+ bboxes = tfm.apply_box(bboxes)
+
+ _, h, w, c = imgs.shape
+ data_dict['height'] = h
+ data_dict['width'] = w
+ gt_boxes = Boxes(torch.from_numpy(bboxes)) # XYXY_ABS
+ gt_classes = [self.classes_id[label[0]]
+ for label in labels] # N is background
+ instances = Instances((data_dict['height'], data_dict['width']))
+ instances.set('gt_boxes', gt_boxes)
+ instances.set('gt_classes',
+ torch.as_tensor(gt_classes, dtype=torch.int64))
+ data_dict['instances'] = instances
+ data_dict['frames'] = torch.as_tensor(
+ np.ascontiguousarray(imgs.transpose([3, 0, 1, 2])))
+ return data_dict
+
+ def random_tile(self, name, imgs, labels, bboxes,
+ pos_choices=(1, 1, 2, 4)):
+ _, h, w, c = imgs.shape
+ bboxes = bboxes.tolist()
+ if len(labels) == 0: # negative sample: scale factor 1/2, 1, 2 or 4
+ ratio = random.choice([0, 1, 2, 4])
+ if ratio == 0: # randomly take a sub-region
+ h0, w0 = random.randint(0, h // 2), random.randint(0, w // 2)
+ imgs = imgs[:, h0:h0 + h // 2, w0:w0 + h // 2]
+ elif ratio == 2:
+ imgs = np.tile(imgs,
+ (1, 1, 2,
+ 1)) if h > w else np.tile(imgs, (1, 2, 1, 1))
+ elif ratio == 4:
+ imgs = np.tile(imgs, (1, 2, 2, 1))
+ else: # positive sample: tile 1, 2 or 4 times
+ ratio = random.choice(pos_choices)
+ if ratio == 2:
+ labels = labels * 2
+ if h >= w: # tile left and right (side by side)
+ imgs = np.tile(imgs, (1, 1, 2, 1))
+ bbox2 = [[x1 + w, y1, x2 + w, y2]
+ for x1, y1, x2, y2 in bboxes]
+ else: # tile top and bottom (stacked)
+ imgs = np.tile(imgs, (1, 2, 1, 1))
+ bbox2 = [[x1, y1 + h, x2, y2 + h]
+ for x1, y1, x2, y2 in bboxes]
+ bboxes = bboxes + bbox2
+ elif ratio == 4:
+ labels = labels * 4
+ imgs = np.tile(imgs, (1, 2, 2, 1))
+ bbox2 = [[x1 + w, y1, x2 + w, y2] for x1, y1, x2, y2 in bboxes] + \
+ [[x1, y1 + h, x2, y2 + h] for x1, y1, x2, y2 in bboxes] + \
+ [[x1 + w, y1 + h, x2 + w, y2 + h] for x1, y1, x2, y2 in bboxes]
+ bboxes = bboxes + bbox2
+ bboxes = np.array(bboxes)
+ return imgs.copy(), labels, bboxes
+
+ def random_extent(self, imgs, bboxes):
+ t, h, w, c = imgs.shape
+ r_h, r_w = int(h * 0.1), int(w * 0.1)
+ x0, y0 = random.randint(-r_w, r_w), random.randint(-r_h, r_h)
+ x1, y1 = random.randint(w - r_w,
+ w + r_w), random.randint(h - r_h, h + r_h)
+ tfm = ExtentTransform((x0, y0, x1, y1), output_size=(y1 - y0, x1 - x0))
+ imgs = np.stack([tfm.apply_image(img) for img in imgs])
+ bboxes = tfm.apply_box(bboxes)
+ return imgs, bboxes
diff --git a/modelscope/preprocessors/cv/cv2_transforms.py b/modelscope/preprocessors/cv/cv2_transforms.py
new file mode 100644
index 00000000..cb8b8b1f
--- /dev/null
+++ b/modelscope/preprocessors/cv/cv2_transforms.py
@@ -0,0 +1,559 @@
+# The implementation is adopted from opencv_transforms,
+# made publicly available under the MIT license at
+# https://github.com/jbohnslav/opencv_transforms/blob/master/opencv_transforms/functional.py
+# https://github.com/jbohnslav/opencv_transforms/blob/master/opencv_transforms/transforms.py
+
+import collections
+import math
+import numbers
+import random
+
+import cv2
+import numpy as np
+import torch
+
+_cv2_pad_to_str = {
+ 'constant': cv2.BORDER_CONSTANT,
+ 'edge': cv2.BORDER_REPLICATE,
+ 'reflect': cv2.BORDER_REFLECT_101,
+ 'symmetric': cv2.BORDER_REFLECT
+}
+_cv2_interpolation_to_str = {
+ 'nearest': cv2.INTER_NEAREST,
+ 'bilinear': cv2.INTER_LINEAR,
+ 'area': cv2.INTER_AREA,
+ 'bicubic': cv2.INTER_CUBIC,
+ 'lanczos': cv2.INTER_LANCZOS4
+}
+_cv2_interpolation_from_str = {
+ v: k
+ for k, v in _cv2_interpolation_to_str.items()
+}
+
+
+def _is_tensor_image(img):
+ return torch.is_tensor(img) and img.ndimension() == 3
+
+
+def _is_numpy_image(img):
+ return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
+
+
+def to_tensor(pic):
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+ See ``ToTensor`` for more details.
+ Args:
+ pic (numpy.ndarray): Image to be converted to tensor.
+ Returns:
+ Tensor: Converted image.
+ """
+ if not (_is_numpy_image(pic)):
+ raise TypeError('pic should be ndarray. Got {}'.format(type(pic)))
+
+ # handle numpy array
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
+ # backward compatibility
+ if isinstance(img, torch.ByteTensor) or img.dtype == torch.uint8:
+ return img.float().div(255)
+ else:
+ return img
+
+
+def normalize(tensor, mean, std):
+ """Normalize a tensor image with mean and standard deviation.
+ .. note::
+ This transform acts in-place, i.e., it mutates the input tensor.
+ See :class:`~torchvision.transforms.Normalize` for more details.
+ Args:
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+ mean (sequence): Sequence of means for each channel.
+ std (sequence): Sequence of standard deviations for each channel.
+ Returns:
+ Tensor: Normalized Tensor image.
+ """
+ if not _is_tensor_image(tensor):
+ raise TypeError('tensor is not a torch image.')
+
+ # This is faster than using broadcasting, don't change without benchmarking
+ for t, m, s in zip(tensor, mean, std):
+ t.sub_(m).div_(s)
+ return tensor
+
+
+def resize(img, size, interpolation=cv2.INTER_LINEAR):
+ r"""Resize the input numpy ndarray to the given size.
+ Args:
+ img (numpy ndarray): Image to be resized.
+ size (sequence or int): Desired output size. If size is a sequence like
+ (h, w), the output size will be matched to this. If size is an int,
+ the smaller edge of the image will be matched to this number maintaining
+ the aspect ratio, i.e., if height > width, then the image will be rescaled to
+ :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
+ interpolation (int, optional): Desired interpolation. Default is
+ ``cv2.INTER_LINEAR``
+ Returns:
+ numpy ndarray: Resized image.
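+ Example (illustrative, assuming a uint8 HWC array):
+ >>> img = np.zeros((100, 200, 3), dtype=np.uint8)
+ >>> resize(img, 50).shape # smaller edge becomes 50, aspect ratio kept
+ (50, 100, 3)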
+ """
+ if not _is_numpy_image(img):
+ raise TypeError('img should be numpy image. Got {}'.format(type(img)))
+ if not (isinstance(size, int) or # noqa: W504
+ (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
+ raise TypeError('Got inappropriate size arg: {}'.format(size))
+ h, w = img.shape[0], img.shape[1]
+
+ if isinstance(size, int):
+ if (w <= h and w == size) or (h <= w and h == size):
+ return img
+ if w < h:
+ ow = size
+ oh = int(size * h / w)
+ else:
+ oh = size
+ ow = int(size * w / h)
+ else:
+ ow, oh = size[1], size[0]
+ output = cv2.resize(img, dsize=(ow, oh), interpolation=interpolation)
+ if img.shape[2] == 1:
+ return output[:, :, np.newaxis]
+ else:
+ return output
+
+
+def pad(img, padding, fill=0, padding_mode='constant'):
+ r"""Pad the given numpy ndarray on all sides with specified padding mode and fill value.
+ Args:
+ img (numpy ndarray): image to be padded.
+ padding (int or tuple): Padding on each border. If a single int is provided this
+ is used to pad all borders. If tuple of length 2 is provided this is the padding
+ on left/right and top/bottom respectively. If a tuple of length 4 is provided
+ this is the padding for the left, top, right and bottom borders
+ respectively.
+ fill: Pixel fill value for constant fill. Default is 0. If a tuple of
+ length 3, it is used to fill R, G, B channels respectively.
+ This value is only used when the padding_mode is constant
+ padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
+ - constant: pads with a constant value, this value is specified with fill
+ - edge: pads with the last value on the edge of the image
+ - reflect: pads with reflection of image (without repeating the last value on the edge)
+ padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
+ - symmetric: pads with reflection of image (repeating the last value on the edge)
+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
+ Returns:
+ Numpy image: padded image.
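+ Example (illustrative):
+ >>> img = np.zeros((4, 4, 3), dtype=np.uint8)
+ >>> pad(img, (1, 2)).shape # 1 px left/right, 2 px top/bottom
+ (8, 6, 3)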
+ """
+ if not _is_numpy_image(img):
+ raise TypeError('img should be numpy ndarray. Got {}'.format(
+ type(img)))
+ if not isinstance(padding, (numbers.Number, tuple, list)):
+ raise TypeError('Got inappropriate padding arg')
+ if not isinstance(fill, (numbers.Number, str, tuple)):
+ raise TypeError('Got inappropriate fill arg')
+ if not isinstance(padding_mode, str):
+ raise TypeError('Got inappropriate padding_mode arg')
+ if isinstance(padding,
+ collections.abc.Sequence) and len(padding) not in [2, 4]:
+ raise ValueError(
+ 'Padding must be an int or a 2, or 4 element tuple, not a '
+ + '{} element tuple'.format(len(padding)))
+
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
+ 'Padding mode should be either constant, edge, reflect or symmetric'
+
+ if isinstance(padding, int):
+ pad_left = pad_right = pad_top = pad_bottom = padding
+ if isinstance(padding, collections.abc.Sequence) and len(padding) == 2:
+ pad_left = pad_right = padding[0]
+ pad_top = pad_bottom = padding[1]
+ if isinstance(padding, collections.abc.Sequence) and len(padding) == 4:
+ pad_left = padding[0]
+ pad_top = padding[1]
+ pad_right = padding[2]
+ pad_bottom = padding[3]
+ if img.shape[2] == 1:
+ return cv2.copyMakeBorder(
+ img,
+ top=pad_top,
+ bottom=pad_bottom,
+ left=pad_left,
+ right=pad_right,
+ borderType=_cv2_pad_to_str[padding_mode],
+ value=fill)[:, :, np.newaxis]
+ else:
+ return cv2.copyMakeBorder(
+ img,
+ top=pad_top,
+ bottom=pad_bottom,
+ left=pad_left,
+ right=pad_right,
+ borderType=_cv2_pad_to_str[padding_mode],
+ value=fill)
+
+
+def crop(img, i, j, h, w):
+ """Crop the given PIL Image.
+ Args:
+ img (numpy ndarray): Image to be cropped.
+ i: Upper pixel coordinate.
+ j: Left pixel coordinate.
+ h: Height of the cropped image.
+ w: Width of the cropped image.
+ Returns:
+ numpy ndarray: Cropped image.
+ """
+ if not _is_numpy_image(img):
+ raise TypeError('img should be numpy image. Got {}'.format(type(img)))
+
+ return img[i:i + h, j:j + w, :]
+
+
+def center_crop(img, output_size):
+ if isinstance(output_size, numbers.Number):
+ output_size = (int(output_size), int(output_size))
+ h, w = img.shape[0:2]
+ th, tw = output_size
+ i = int(round((h - th) / 2.))
+ j = int(round((w - tw) / 2.))
+ return crop(img, i, j, th, tw)
+
+
+def resized_crop(img, i, j, h, w, size, interpolation=cv2.INTER_LINEAR):
+ """Crop the given numpy ndarray and resize it to desired size.
+ Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.
+ Args:
+ img (numpy ndarray): Image to be cropped.
+ i: Upper pixel coordinate.
+ j: Left pixel coordinate.
+ h: Height of the cropped image.
+ w: Width of the cropped image.
+ size (sequence or int): Desired output size. Same semantics as ``resize``.
+ interpolation (int, optional): Desired interpolation. Default is
+ ``cv2.INTER_LINEAR``.
+ Returns:
+ numpy ndarray: Cropped and resized image.
+ """
+ assert _is_numpy_image(img), 'img should be numpy image'
+ img = crop(img, i, j, h, w)
+ img = resize(img, size, interpolation=interpolation)
+ return img
+
+
+def hflip(img):
+ """Horizontally flip the given numpy ndarray.
+ Args:
+ img (numpy ndarray): image to be flipped.
+ Returns:
+ numpy ndarray: Horizontally flipped image.
+ """
+ if not _is_numpy_image(img):
+ raise TypeError('img should be numpy image. Got {}'.format(type(img)))
+ # img[:,::-1] is much faster, but doesn't work with torch.from_numpy()!
+ if img.shape[2] == 1:
+ return cv2.flip(img, 1)[:, :, np.newaxis]
+ else:
+ return cv2.flip(img, 1)
+
+
+class ToTensor(object):
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+ Converts a PIL Image or numpy.ndarray (H x W x C) in the range
+ [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
+ """
+
+ def __call__(self, pic):
+ """
+ Args:
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+ Returns:
+ Tensor: Converted image.
+ """
+ return to_tensor(pic)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '()'
+
+
+class Normalize(object):
+ """Normalize a tensor image with mean and standard deviation.
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
+ .. note::
+ This transform acts in-place, i.e., it mutates the input tensor.
+ Args:
+ mean (sequence): Sequence of means for each channel.
+ std (sequence): Sequence of standard deviations for each channel.
+ """
+
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, tensor):
+ """
+ Args:
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+ Returns:
+ Tensor: Normalized Tensor image.
+ """
+ return normalize(tensor, self.mean, self.std)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(
+ self.mean, self.std)
+
+
+class Resize(object):
+ """Resize the input numpy ndarray to the given size.
+ Args:
+ size (sequence or int): Desired output size. If size is a sequence like
+ (h, w), output size will be matched to this. If size is an int,
+ smaller edge of the image will be matched to this number.
+ i.e, if height > width, then image will be rescaled to
+ (size * height / width, size)
+ interpolation (int, optional): Desired interpolation. Default is
+ ``cv2.INTER_LINEAR``, bilinear interpolation
+ """
+
+ def __init__(self, size, interpolation=cv2.INTER_LINEAR):
+ # assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
+ if isinstance(size, int):
+ self.size = size
+ elif isinstance(size, collections.abc.Iterable) and len(size) == 2:
+ if isinstance(size, list):
+ size = tuple(size)
+ self.size = size
+ else:
+ raise ValueError('Unknown inputs for size: {}'.format(size))
+ self.interpolation = interpolation
+
+ def __call__(self, img):
+ """
+ Args:
+ img (numpy ndarray): Image to be scaled.
+ Returns:
+ numpy ndarray: Rescaled image.
+ """
+ return resize(img, self.size, self.interpolation)
+
+ def __repr__(self):
+ interpolate_str = _cv2_interpolation_from_str[self.interpolation]
+ return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(
+ self.size, interpolate_str)
+
+
+class CenterCrop(object):
+ """Crops the given numpy ndarray at the center.
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made.
+ """
+
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, img):
+ """
+ Args:
+ img (numpy ndarray): Image to be cropped.
+ Returns:
+ numpy ndarray: Cropped image.
+ """
+ return center_crop(img, self.size)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
+
+
+class RandomCrop(object):
+ """Crop the given numpy ndarray at a random location.
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made.
+ padding (int or sequence, optional): Optional padding on each border
+ of the image. Default is None, i.e no padding. If a sequence of length
+ 4 is provided, it is used to pad left, top, right, bottom borders
+ respectively. If a sequence of length 2 is provided, it is used to
+ pad left/right, top/bottom borders, respectively.
+ pad_if_needed (boolean): It will pad the image if smaller than the
+ desired size to avoid raising an exception.
+ fill: Pixel fill value for constant fill. Default is 0. If a tuple of
+ length 3, it is used to fill R, G, B channels respectively.
+ This value is only used when the padding_mode is constant
+ padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
+ - constant: pads with a constant value, this value is specified with fill
+ - edge: pads with the last value on the edge of the image
+ - reflect: pads with reflection of image (without repeating the last value on the edge)
+ padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
+ - symmetric: pads with reflection of image (repeating the last value on the edge)
+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
+ """
+
+ def __init__(self,
+ size,
+ padding=None,
+ pad_if_needed=False,
+ fill=0,
+ padding_mode='constant'):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+ self.padding = padding
+ self.pad_if_needed = pad_if_needed
+ self.fill = fill
+ self.padding_mode = padding_mode
+
+ @staticmethod
+ def get_params(img, output_size):
+ """Get parameters for ``crop`` for a random crop.
+ Args:
+ img (numpy ndarray): Image to be cropped.
+ output_size (tuple): Expected output size of the crop.
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
+ """
+ h, w = img.shape[0:2]
+ th, tw = output_size
+ if w == tw and h == th:
+ return 0, 0, h, w
+
+ i = random.randint(0, h - th)
+ j = random.randint(0, w - tw)
+ return i, j, th, tw
+
+ def __call__(self, img):
+ """
+ Args:
+ img (numpy ndarray): Image to be cropped.
+ Returns:
+ numpy ndarray: Cropped image.
+ """
+ if self.padding is not None:
+ img = pad(img, self.padding, self.fill, self.padding_mode)
+
+ # pad the width if needed
+ if self.pad_if_needed and img.shape[1] < self.size[1]:
+ img = pad(img, (self.size[1] - img.shape[1], 0), self.fill,
+ self.padding_mode)
+ # pad the height if needed
+ if self.pad_if_needed and img.shape[0] < self.size[0]:
+ img = pad(img, (0, self.size[0] - img.shape[0]), self.fill,
+ self.padding_mode)
+
+ i, j, h, w = self.get_params(img, self.size)
+
+ return crop(img, i, j, h, w)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(size={0}, padding={1})'.format(
+ self.size, self.padding)
+
+
+class RandomResizedCrop(object):
+ """Crop the given numpy ndarray to random size and aspect ratio.
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+ is finally resized to given size.
+ This is popularly used to train the Inception networks.
+ Args:
+ size: expected output size of each edge
+ scale: range of size of the origin size cropped
+ ratio: range of aspect ratio of the origin aspect ratio cropped
+ interpolation: Default: cv2.INTER_LINEAR
+ """
+
+ def __init__(self,
+ size,
+ scale=(0.08, 1.0),
+ ratio=(3. / 4., 4. / 3.),
+ interpolation=cv2.INTER_LINEAR):
+ self.size = (size, size)
+ self.interpolation = interpolation
+ self.scale = scale
+ self.ratio = ratio
+
+ @staticmethod
+ def get_params(img, scale, ratio):
+ """Get parameters for ``crop`` for a random sized crop.
+ Args:
+ img (numpy ndarray): Image to be cropped.
+ scale (tuple): range of size of the origin size cropped
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+ sized crop.
+ """
+ for attempt in range(10):
+ area = img.shape[0] * img.shape[1]
+ target_area = random.uniform(*scale) * area
+ aspect_ratio = random.uniform(*ratio)
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if random.random() < 0.5:
+ w, h = h, w
+
+ if w <= img.shape[1] and h <= img.shape[0]:
+ i = random.randint(0, img.shape[0] - h)
+ j = random.randint(0, img.shape[1] - w)
+ return i, j, h, w
+
+ # Fallback: after 10 failed attempts, take a centered square crop of the shorter side
+ w = min(img.shape[0], img.shape[1])
+ i = (img.shape[0] - w) // 2
+ j = (img.shape[1] - w) // 2
+ return i, j, w, w
+
+ def __call__(self, img):
+ """
+ Args:
+ img (numpy ndarray): Image to be cropped and resized.
+ Returns:
+ numpy ndarray: Randomly cropped and resized image.
+ """
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
+ return resized_crop(img, i, j, h, w, self.size, self.interpolation)
+
+ def __repr__(self):
+ interpolate_str = _cv2_interpolation_from_str[self.interpolation]
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+ format_string += ', scale={0}'.format(
+ tuple(round(s, 4) for s in self.scale))
+ format_string += ', ratio={0}'.format(
+ tuple(round(r, 4) for r in self.ratio))
+ format_string += ', interpolation={0})'.format(interpolate_str)
+ return format_string
+
+
+class RandomHorizontalFlip(object):
+ """Horizontally flip the given PIL Image randomly with a given probability.
+ Args:
+ p (float): probability of the image being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img):
+ """random
+ Args:
+ img (numpy ndarray): Image to be flipped.
+ Returns:
+ numpy ndarray: Randomly flipped image.
+ """
+ if random.random() < self.p:
+ return hflip(img)
+ return img
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(p={})'.format(self.p)
diff --git a/modelscope/preprocessors/cv/image_classification_preprocessor.py b/modelscope/preprocessors/cv/image_classification_preprocessor.py
new file mode 100644
index 00000000..fa98315b
--- /dev/null
+++ b/modelscope/preprocessors/cv/image_classification_preprocessor.py
@@ -0,0 +1,340 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+# The part implementation is also open-sourced by the authors,
+# and available at https://github.com/alibaba/EssentialMC2
+import os
+from typing import Any, Dict
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+from torchvision.transforms.functional import InterpolationMode
+
+import modelscope.preprocessors.cv.cv2_transforms as cv2_transforms
+from modelscope.fileio import File
+from modelscope.metainfo import Preprocessors
+from modelscope.preprocessors.base import Preprocessor
+from modelscope.preprocessors.builder import PREPROCESSORS, build_preprocessor
+from modelscope.utils.constant import Fields, ModeKeys
+from modelscope.utils.registry import default_group
+
+BACKEND_TORCHVISION = 'torchvision'
+BACKEND_PILLOW = 'pillow'
+BACKEND_CV2 = 'cv2'
+BACKENDS = (BACKEND_PILLOW, BACKEND_CV2, BACKEND_TORCHVISION)
+
+INTERPOLATION_STYLE = {
+ 'bilinear': InterpolationMode('bilinear'),
+ 'nearest': InterpolationMode('nearest'),
+ 'bicubic': InterpolationMode('bicubic'),
+}
+INTERPOLATION_STYLE_CV2 = {
+ 'bilinear': cv2.INTER_LINEAR,
+ 'nearest': cv2.INTER_NEAREST,
+ 'bicubic': cv2.INTER_CUBIC,
+}
+
+
+def is_pil_image(img):
+ return isinstance(img, Image.Image)
+
+
+def is_cv2_image(img):
+ return isinstance(img, np.ndarray) and img.dtype == np.uint8
+
+
+def is_tensor(t):
+ return isinstance(t, torch.Tensor)
+
+
+class ImageTransform(object):
+
+ def __init__(self,
+ backend=BACKEND_PILLOW,
+ input_key=None,
+ output_key=None):
+ self.input_key = input_key or 'img'
+ self.output_key = output_key or 'img'
+ self.backend = backend
+
+ def check_image_type(self, input_img):
+ if self.backend == BACKEND_PILLOW:
+ assert is_pil_image(input_img), 'input should be PIL Image'
+ elif self.backend == BACKEND_CV2:
+ assert is_cv2_image(
+ input_img), 'input should be cv2 image(uint8 np.ndarray)'
+
+
+@PREPROCESSORS.register_module(Fields.cv)
+class RandomCrop(ImageTransform):
+ """ Crop a random portion of image.
+ If the image is torch Tensor, it is expected to have [..., H, W] shape.
+
+ Args:
+ size (sequence or int): Desired output size.
+ If size is a sequence like (h, w), the output size will be matched to this.
+ If size is an int, the output size will be matched to (size, size).
+ padding (sequence or int): Optional padding on each border of the image. Default is None.
+ pad_if_needed (bool): It will pad the image if smaller than the desired size to avoid raising an exception.
+ fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
+ padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
+ Default is constant.
+ """
+
+ def __init__(self,
+ size,
+ padding=None,
+ pad_if_needed=False,
+ fill=0,
+ padding_mode='constant',
+ **kwargs):
+
+ super(RandomCrop, self).__init__(**kwargs)
+ assert self.backend in BACKENDS
+ if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION):
+ self.callable = transforms.RandomCrop(
+ size,
+ padding=padding,
+ pad_if_needed=pad_if_needed,
+ fill=fill,
+ padding_mode=padding_mode)
+ else:
+ self.callable = cv2_transforms.RandomCrop(
+ size,
+ padding=padding,
+ pad_if_needed=pad_if_needed,
+ fill=fill,
+ padding_mode=padding_mode)
+
+ def __call__(self, item):
+ self.check_image_type(item[self.input_key])
+ item[self.output_key] = self.callable(item[self.input_key])
+ return item
+
+
+@PREPROCESSORS.register_module(Fields.cv)
+class RandomResizedCrop(ImageTransform):
+ """Crop a random portion of image and resize it to a given size.
+
+ If the image is torch Tensor, it is expected to have [..., H, W] shape.
+
+ Args:
+ size (int or sequence): Desired output size.
+ If size is a sequence like (h, w), the output size will be matched to this.
+ If size is an int, the output size will be matched to (size, size).
+ scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
+ before resizing. The scale is defined with respect to the area of the original image.
+ ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
+ resizing.
+ interpolation (str): Desired interpolation string, 'bilinear', 'nearest', 'bicubic' are supported.
+ """
+
+ def __init__(self,
+ size,
+ scale=(0.08, 1.0),
+ ratio=(3. / 4., 4. / 3.),
+ interpolation='bilinear',
+ **kwargs):
+ super(RandomResizedCrop, self).__init__(**kwargs)
+ assert self.backend in BACKENDS
+ self.interpolation = interpolation
+ if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION):
+ assert interpolation in INTERPOLATION_STYLE
+ else:
+ assert interpolation in INTERPOLATION_STYLE_CV2
+ self.callable = transforms.RandomResizedCrop(size, scale, ratio, INTERPOLATION_STYLE[interpolation]) \
+ if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION) \
+ else cv2_transforms.RandomResizedCrop(size, scale, ratio, INTERPOLATION_STYLE_CV2[interpolation])
+
+ def __call__(self, item):
+ self.check_image_type(item[self.input_key])
+ item[self.output_key] = self.callable(item[self.input_key])
+ return item
+
+
+@PREPROCESSORS.register_module(Fields.cv)
+class Resize(ImageTransform):
+ """Resize image to a given size.
+
+ If the image is torch Tensor, it is expected to have [..., H, W] shape.
+
+ Args:
+ size (int or sequence): Desired output size.
+ If size is a sequence like (h, w), the output size will be matched to this.
+ If size is an int, the smaller edge of the image will be matched to this
+ number maintaining the aspect ratio.
+ interpolation (str): Desired interpolation string, 'bilinear', 'nearest', 'bicubic' are supported.
+ """
+
+ def __init__(self, size, interpolation='bilinear', **kwargs):
+ super(Resize, self).__init__(**kwargs)
+ assert self.backend in BACKENDS
+ self.size = size
+ self.interpolation = interpolation
+ if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION):
+ assert interpolation in INTERPOLATION_STYLE
+ else:
+ assert interpolation in INTERPOLATION_STYLE_CV2
+ self.callable = transforms.Resize(size, INTERPOLATION_STYLE[interpolation]) \
+ if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION) \
+ else cv2_transforms.Resize(size, INTERPOLATION_STYLE_CV2[interpolation])
+
+ def __call__(self, item):
+ self.check_image_type(item[self.input_key])
+ item[self.output_key] = self.callable(item[self.input_key])
+ return item
+
+
+@PREPROCESSORS.register_module(Fields.cv)
+class CenterCrop(ImageTransform):
+ """ Crops the given image at the center.
+
+ If the image is torch Tensor, it is expected to have [..., H, W] shape.
+
+ Args:
+ size (sequence or int): Desired output size.
+ If size is a sequence like (h, w), the output size will be matched to this.
+ If size is an int, the output size will be matched to (size, size).
+ """
+
+ def __init__(self, size, **kwargs):
+ super(CenterCrop, self).__init__(**kwargs)
+ assert self.backend in BACKENDS
+ self.size = size
+ self.callable = transforms.CenterCrop(size) \
+ if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION) else cv2_transforms.CenterCrop(size)
+
+ def __call__(self, item):
+ self.check_image_type(item[self.input_key])
+ item[self.output_key] = self.callable(item[self.input_key])
+ return item
+
+
+@PREPROCESSORS.register_module(Fields.cv)
+class RandomHorizontalFlip(ImageTransform):
+ """ Horizontally flip the given image randomly with a given probability.
+
+ If the image is torch Tensor, it is expected to have [..., H, W] shape.
+
+ Args:
+ p (float): probability of the image being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5, **kwargs):
+ super(RandomHorizontalFlip, self).__init__(**kwargs)
+ assert self.backend in BACKENDS
+ self.callable = transforms.RandomHorizontalFlip(p) \
+ if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION) else cv2_transforms.RandomHorizontalFlip(p)
+
+ def __call__(self, item):
+ self.check_image_type(item[self.input_key])
+ item[self.output_key] = self.callable(item[self.input_key])
+ return item
+
+
+@PREPROCESSORS.register_module(Fields.cv)
+class Normalize(ImageTransform):
+ """ Normalize a tensor image with mean and standard deviation.
+ This transform only supports tensor images.
+
+ Args:
+ mean (sequence): Sequence of means for each channel.
+ std (sequence): Sequence of standard deviations for each channel.
+ """
+
+ def __init__(self, mean, std, **kwargs):
+ super(Normalize, self).__init__(**kwargs)
+ assert self.backend in BACKENDS
+ self.mean = np.array(mean, dtype=np.float32)
+ self.std = np.array(std, dtype=np.float32)
+ self.callable = transforms.Normalize(self.mean, self.std) \
+ if self.backend in (BACKEND_PILLOW, BACKEND_TORCHVISION) else cv2_transforms.Normalize(self.mean, self.std)
+
+ def __call__(self, item):
+ item[self.output_key] = self.callable(item[self.input_key])
+ return item
+
+
+@PREPROCESSORS.register_module(Fields.cv)
+class ImageToTensor(ImageTransform):
+ """ Convert a ``PIL Image`` or ``numpy.ndarray`` or uint8 type tensor to a float32 tensor,
+ and scale output to [0.0, 1.0].
+ """
+
+ def __init__(self, **kwargs):
+ super(ImageToTensor, self).__init__(**kwargs)
+ assert self.backend in BACKENDS
+
+ if self.backend == BACKEND_PILLOW:
+ self.callable = transforms.ToTensor()
+ elif self.backend == BACKEND_CV2:
+ self.callable = cv2_transforms.ToTensor()
+ else:
+ self.callable = transforms.ConvertImageDtype(torch.float)
+
+ def __call__(self, item):
+ item[self.output_key] = self.callable(item[self.input_key])
+ return item
+
+
+def build_preprocess_pipeline(pipeline, group_name=Fields.cv):
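+ # Dispatch on the config type: a list of transform configs is wrapped in a
+ # 'Compose' preprocessor, a single dict builds that transform directly, and
+ # None or an empty list falls back to an 'Identity' pass-through.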
+ if isinstance(pipeline, list):
+ if len(pipeline) == 0:
+ return build_preprocessor(
+ dict(type='Identity'), field_name=default_group)
+ elif len(pipeline) == 1:
+ return build_preprocess_pipeline(pipeline[0])
+ else:
+ return build_preprocessor(
+ dict(
+ type='Compose', transforms=pipeline,
+ field_name=group_name),
+ field_name=default_group)
+ elif isinstance(pipeline, dict):
+ return build_preprocessor(pipeline, field_name=group_name)
+ elif pipeline is None:
+ return build_preprocessor(
+ dict(type='Identity'), field_name=default_group)
+ else:
+ raise TypeError(
+ f'Expect pipeline_cfg to be dict or list or None, got {type(pipeline)}'
+ )
+
+
+@PREPROCESSORS.register_module(
+ Fields.cv, module_name=Preprocessors.image_classification_preprocessor)
+class ImageClassificationPreprocessor(Preprocessor):
+
+ def __init__(self, *args, **kwargs):
+ """image classification preprocessor in the fine-tune scenario
+ """
+ super().__init__(*args, **kwargs)
+
+ self.training = kwargs.pop('training', True)
+ self.preprocessor_train_cfg = kwargs.pop('train', None)
+ self.preprocessor_test_cfg = kwargs.pop('val', None)
+
+ if self.preprocessor_train_cfg is not None:
+ self.train_preprocess_pipeline = build_preprocess_pipeline(
+ self.preprocessor_train_cfg)
+
+ if self.preprocessor_test_cfg is not None:
+ self.test_preprocess_pipeline = build_preprocess_pipeline(
+ self.preprocessor_test_cfg)
+
+ def __call__(self, results: Dict[str, Any]):
+ """process the raw input data
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ Dict[str, Any] | None: the preprocessed data
+ """
+ if self.mode == ModeKeys.TRAIN:
+ pipeline = self.train_preprocess_pipeline
+ else:
+ pipeline = self.test_preprocess_pipeline
+
+ return pipeline(results)
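+
+
+# A minimal usage sketch (illustrative only; the 'train'/'val' pipeline configs
+# are normally read from the model's configuration file, and the transforms
+# below are assumed examples):
+#
+# preprocessor = ImageClassificationPreprocessor(
+# train=[dict(type='RandomResizedCrop', size=224),
+# dict(type='RandomHorizontalFlip'),
+# dict(type='ImageToTensor'),
+# dict(type='Normalize', mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])],
+# val=[dict(type='Resize', size=256),
+# dict(type='CenterCrop', size=224),
+# dict(type='ImageToTensor'),
+# dict(type='Normalize', mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
+# output = preprocessor({'img': pil_image})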
diff --git a/modelscope/preprocessors/cv/image_quality_assessment_man.py b/modelscope/preprocessors/cv/image_quality_assessment_man.py
new file mode 100644
index 00000000..0f34dca8
--- /dev/null
+++ b/modelscope/preprocessors/cv/image_quality_assessment_man.py
@@ -0,0 +1,38 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import math
+from typing import Any, Dict
+
+import torch
+import torch.nn.functional as F
+from numpy import ndarray
+from PIL import Image
+from torchvision import transforms
+
+from modelscope.metainfo import Preprocessors
+from modelscope.preprocessors import load_image
+from modelscope.preprocessors.base import Preprocessor
+from modelscope.preprocessors.builder import PREPROCESSORS
+from modelscope.utils.constant import Fields
+from modelscope.utils.type_assert import type_assert
+
+
+@PREPROCESSORS.register_module(
+ Fields.cv,
+ module_name=Preprocessors.image_quality_assessment_man_preprocessor)
+class ImageQualityAssessmentMANPreprocessor(Preprocessor):
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
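+ # Evaluation pipeline: resize the shorter edge to 224, center-crop to
+ # 224x224, scale to a [0, 1] tensor, then normalize channels to [-1, 1].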
+ self.transform_input = transforms.Compose([
+ transforms.Resize(224),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ ])
+
+ @type_assert(object, object)
+ def __call__(self, data) -> Dict[str, Any]:
+ image = load_image(data)
+ data = self.transform_input(image)
+ data = data.unsqueeze(0)
+ return {'input': data.float()}
diff --git a/modelscope/preprocessors/image.py b/modelscope/preprocessors/image.py
index 666d2b29..36ab2f2f 100644
--- a/modelscope/preprocessors/image.py
+++ b/modelscope/preprocessors/image.py
@@ -24,10 +24,12 @@ class LoadImage:
"scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
Args:
mode (str): See :ref:`PIL.Mode`.
+ backend (str): Image loading backend, either 'cv2' or 'pillow'. Default is 'pillow'.
"""
- def __init__(self, mode='rgb'):
+ def __init__(self, mode='rgb', backend='pillow'):
self.mode = mode.upper()
+ self.backend = backend
def __call__(self, input: Union[str, Dict[str, str]]):
"""Call functions to load image and get image meta information.
@@ -42,21 +44,38 @@ class LoadImage:
else:
image_path_or_url = input
- bytes = File.read(image_path_or_url)
- # TODO @wenmeng.zwm add opencv decode as optional
- # we should also look at the input format which is the most commonly
- # used in Mind' image related models
- with io.BytesIO(bytes) as infile:
- img = Image.open(infile)
- img = ImageOps.exif_transpose(img)
- img = img.convert(self.mode)
+ if self.backend == 'cv2':
+ storage = File._get_storage(image_path_or_url)
+ with storage.as_local_path(image_path_or_url) as img_path:
+ img = cv2.imread(img_path, cv2.IMREAD_COLOR)
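+ # cv2.imread returns BGR; convert to RGB in place when requested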
+ if self.mode == 'RGB':
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ img_h, img_w, img_c = img.shape[0], img.shape[1], img.shape[2]
+ img_shape = (img_h, img_w, img_c)
+ elif self.backend == 'pillow':
+ bytes = File.read(image_path_or_url)
+ # TODO @wenmeng.zwm add opencv decode as optional
+ # we should also look at the input format which is the most commonly
+ # used in Mind' image related models
+ with io.BytesIO(bytes) as infile:
+ img = Image.open(infile)
+ img = ImageOps.exif_transpose(img)
+ img = img.convert(self.mode)
+ img_shape = (img.size[1], img.size[0], 3)
+ else:
+ raise TypeError(f'backend should be either cv2 or pillow, '
+ f'but got {self.backend}')
results = {
'filename': image_path_or_url,
'img': img,
- 'img_shape': (img.size[1], img.size[0], 3),
+ 'img_shape': img_shape,
'img_field': 'img',
}
+ if isinstance(input, dict):
+ input_ret = input.copy()
+ input_ret.update(results)
+ results = input_ret
return results
def __repr__(self):
diff --git a/modelscope/preprocessors/nlp/siamese_uie_preprocessor.py b/modelscope/preprocessors/nlp/siamese_uie_preprocessor.py
index a224cd67..d77a9cd3 100644
--- a/modelscope/preprocessors/nlp/siamese_uie_preprocessor.py
+++ b/modelscope/preprocessors/nlp/siamese_uie_preprocessor.py
@@ -25,7 +25,6 @@ class SiameseUiePreprocessor(Preprocessor):
**kwargs,
):
"""preprocess the data
-`
Args:
model_dir (str): model path
"""
diff --git a/modelscope/preprocessors/nlp/text_generation_preprocessor.py b/modelscope/preprocessors/nlp/text_generation_preprocessor.py
index c3b91485..734ddbc2 100644
--- a/modelscope/preprocessors/nlp/text_generation_preprocessor.py
+++ b/modelscope/preprocessors/nlp/text_generation_preprocessor.py
@@ -107,7 +107,7 @@ class TextGenerationTransformersPreprocessor(TextGenerationPreprocessorBase):
mode: str = ModeKeys.INFERENCE,
src_txt='src_txt',
tgt_txt='tgt_txt',
- max_length: int = None,
+ sequence_length: int = None,
use_fast: bool = None,
keep_original_columns=None,
**kwargs):
@@ -118,7 +118,7 @@ class TextGenerationTransformersPreprocessor(TextGenerationPreprocessorBase):
mode: The mode for the preprocessor.
src_txt: The key of the source sentence.
tgt_txt: The key of the generated sentence.
- max_length: The max sequence length which the model supported,
+ sequence_length: The max sequence length supported by the model, which
will be passed into tokenizer as the 'max_length' param.
use_fast: Whether to use the fast tokenizer or not.
**kwargs: Extra args input into the tokenizer's __call__ method.
@@ -130,10 +130,10 @@ class TextGenerationTransformersPreprocessor(TextGenerationPreprocessorBase):
kwargs['padding'] = kwargs.get('padding', 'max_length')
kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids',
False)
- kwargs[
- 'max_length'] = max_length if max_length is not None else kwargs.get(
- 'sequence_length', 128)
- kwargs.pop('sequence_length', None)
+ # 'sequence_length' takes precedence over any 'max_length' passed in kwargs
+ kwargs['max_length'] = sequence_length if sequence_length is not None \
+ else kwargs.get('max_length', 128)
+
self.src_length = kwargs['max_length']
self.tgt_length = kwargs.pop('target_max_length', kwargs['max_length'])
model_type = None
diff --git a/modelscope/preprocessors/nlp/token_classification_preprocessor.py b/modelscope/preprocessors/nlp/token_classification_preprocessor.py
index c07012e0..66e57cc8 100644
--- a/modelscope/preprocessors/nlp/token_classification_preprocessor.py
+++ b/modelscope/preprocessors/nlp/token_classification_preprocessor.py
@@ -57,16 +57,15 @@ class WordSegmentationBlankSetToLabelPreprocessor(Preprocessor):
class TokenClassificationPreprocessorBase(Preprocessor):
- def __init__(
- self,
- model_dir: str = None,
- first_sequence: str = None,
- label: str = 'label',
- label2id: Dict = None,
- label_all_tokens: bool = False,
- mode: str = ModeKeys.INFERENCE,
- keep_original_columns: List[str] = None,
- ):
+ def __init__(self,
+ model_dir: str = None,
+ first_sequence: str = None,
+ label: str = 'label',
+ label2id: Dict = None,
+ label_all_tokens: bool = False,
+ mode: str = ModeKeys.INFERENCE,
+ keep_original_columns: List[str] = None,
+ return_text: bool = True):
"""The base class for all the token-classification tasks.
Args:
@@ -82,6 +81,7 @@ class TokenClassificationPreprocessorBase(Preprocessor):
mode: The preprocessor mode.
keep_original_columns(List[str], `optional`): The original columns to keep,
only available when the input is a `dict`, default None
+ return_text: Whether to return `text` field in inference mode, default: True.
"""
super().__init__(mode)
self.model_dir = model_dir
@@ -90,6 +90,7 @@ class TokenClassificationPreprocessorBase(Preprocessor):
self.label2id = label2id
self.label_all_tokens = label_all_tokens
self.keep_original_columns = keep_original_columns
+ self.return_text = return_text
if self.label2id is None and self.model_dir is not None:
self.label2id = parse_label_mapping(self.model_dir)
@@ -164,7 +165,7 @@ class TokenClassificationPreprocessorBase(Preprocessor):
if self.keep_original_columns and isinstance(data, dict):
for column in self.keep_original_columns:
outputs[column] = data[column]
- if self.mode == ModeKeys.INFERENCE:
+ if self.mode == ModeKeys.INFERENCE and self.return_text:
outputs['text'] = text
return outputs
@@ -208,6 +209,7 @@ class TokenClassificationTransformersPreprocessor(
max_length=None,
use_fast=None,
keep_original_columns=None,
+ return_text=True,
**kwargs):
"""
@@ -218,7 +220,8 @@ class TokenClassificationTransformersPreprocessor(
**kwargs: Extra args input into the tokenizer's __call__ method.
"""
super().__init__(model_dir, first_sequence, label, label2id,
- label_all_tokens, mode, keep_original_columns)
+ label_all_tokens, mode, keep_original_columns,
+ return_text)
self.is_lstm_model = 'lstm' in model_dir
model_type = None
if self.is_lstm_model:
diff --git a/modelscope/preprocessors/tts.py b/modelscope/preprocessors/tts.py
index 7c7d8005..4357f54f 100644
--- a/modelscope/preprocessors/tts.py
+++ b/modelscope/preprocessors/tts.py
@@ -3,9 +3,9 @@
import os
from typing import Any, Dict, List, Union
+from kantts.preprocess.data_process import process_data
+
from modelscope.metainfo import Preprocessors
-from modelscope.models.audio.tts.kantts.preprocess.data_process import \
- process_data
from modelscope.models.base import Model
from modelscope.utils.audio.tts_exceptions import (
TtsDataPreprocessorAudioConfigNotExistsException,
@@ -28,22 +28,22 @@ class KanttsDataPreprocessor(Preprocessor):
def __call__(self,
data_dir,
output_dir,
- lang_dir,
audio_config_path,
speaker_name='F7',
target_lang='PinYin',
- skip_script=False):
- self.do_data_process(data_dir, output_dir, lang_dir, audio_config_path,
- speaker_name, target_lang, skip_script)
+ skip_script=False,
+ se_model=None):
+ self.do_data_process(data_dir, output_dir, audio_config_path,
+ speaker_name, target_lang, skip_script, se_model)
def do_data_process(self,
datadir,
outputdir,
- langdir,
audio_config,
speaker_name='F7',
targetLang='PinYin',
- skip_script=False):
+ skip_script=False,
+ se_model=None):
if not os.path.exists(datadir):
raise TtsDataPreprocessorDirNotExistsException(
'Preprocessor: dataset dir not exists')
@@ -53,8 +53,5 @@ class KanttsDataPreprocessor(Preprocessor):
if not os.path.exists(audio_config):
raise TtsDataPreprocessorAudioConfigNotExistsException(
'Preprocessor: audio config not exists')
- if not os.path.exists(langdir):
- raise TtsDataPreprocessorDirNotExistsException(
- 'Preprocessor: language dir not exists')
- process_data(datadir, outputdir, langdir, audio_config, speaker_name,
- targetLang, skip_script)
+ process_data(datadir, outputdir, audio_config, speaker_name,
+ targetLang, skip_script, se_model)
diff --git a/modelscope/tools/__init__.py b/modelscope/tools/__init__.py
new file mode 100644
index 00000000..59a1c5eb
--- /dev/null
+++ b/modelscope/tools/__init__.py
@@ -0,0 +1 @@
+from .speech_tts_autolabel import run_auto_label
diff --git a/modelscope/tools/speech_tts_autolabel.py b/modelscope/tools/speech_tts_autolabel.py
new file mode 100644
index 00000000..0fcd41fd
--- /dev/null
+++ b/modelscope/tools/speech_tts_autolabel.py
@@ -0,0 +1,141 @@
+import argparse
+import os
+import sys
+import zipfile
+
+from modelscope.hub.check_model import check_local_model_is_latest
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.utils.constant import ThirdParty
+from modelscope.utils.logger import get_logger
+
+try:
+ from tts_autolabel import AutoLabeling
+except ImportError:
+ raise ImportError('please install tts-autolabel with \
+ "pip install tts-autolabel -f \
+ https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html"'
+ )
+
+DEFAULT_RESOURCE_MODEL_ID = 'damo/speech_ptts_autolabel_16k'
+logger = get_logger()
+
+
+# Suggested params:
+# --para_ids all --resource_revision v1.0.2 --input_wav data/test/audios/autolabel
+# --work_dir ../ptts/test/diff2 --develop_mode 1 --stage 1 --process_num 2 --no_para --disable_enh
+def run_auto_label(input_wav,
+ work_dir,
+ para_ids='all',
+ resource_model_id=DEFAULT_RESOURCE_MODEL_ID,
+ resource_revision=None,
+ gender='female',
+ stage=1,
+ process_num=4,
+ develop_mode=0,
+ has_para=False,
+ enable_enh=False):
+ if not os.path.exists(input_wav):
+ raise ValueError(f'input_wav: {input_wav} does not exist')
+ if not os.path.exists(work_dir):
+ raise ValueError(f'work_dir: {work_dir} does not exist')
+
+ def _download_and_unzip_resource(model, model_revision=None):
+ if os.path.exists(model):
+ model_cache_dir = model if os.path.isdir(
+ model) else os.path.dirname(model)
+ check_local_model_is_latest(
+ model_cache_dir,
+ user_agent={ThirdParty.KEY: 'speech_tts_autolabel'})
+ else:
+ model_cache_dir = snapshot_download(
+ model,
+ revision=model_revision,
+ user_agent={ThirdParty.KEY: 'speech_tts_autolabel'})
+ if not os.path.exists(model_cache_dir):
+ raise ValueError(f'model_cache_dir: {model_cache_dir} does not exist')
+ zip_file = os.path.join(model_cache_dir, 'model.zip')
+ if not os.path.exists(zip_file):
+ raise ValueError(f'zip_file: {zip_file} does not exist')
+ z = zipfile.ZipFile(zip_file)
+ z.extractall(model_cache_dir)
+ target_resource = os.path.join(model_cache_dir, 'model')
+ return target_resource
+
+ model_resource = _download_and_unzip_resource(resource_model_id,
+ resource_revision)
+ auto_labeling = AutoLabeling(
+ os.path.abspath(input_wav),
+ model_resource,
+ False,
+ os.path.abspath(work_dir),
+ gender,
+ develop_mode,
+ has_para,
+ para_ids,
+ stage,
+ process_num,
+ enable_enh=enable_enh)
+ ret_code, report = auto_labeling.run()
+ return ret_code, report
+
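+# Minimal usage sketch (paths are illustrative):
+#
+# from modelscope.tools import run_auto_label
+# ret_code, report = run_auto_label(
+# input_wav='data/test/audios/autolabel', work_dir='./autolabel_work_dir')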
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--para_ids',
+ default='all',
+ help=
+ 'paragraph ids to auto-label: "all" means every paragraph in the dir, \
+ "none" means none, "1" means paragraph 1 only, "1 2" means paragraphs 1 and 2; \
+ transcript/prosody/wav files must share exactly the same names'
+ )
+ parser.add_argument(
+ '--resource', type=str, default=DEFAULT_RESOURCE_MODEL_ID)
+ parser.add_argument(
+ '--resource_revision',
+ type=str,
+ default=None,
+ help='resource directory')
+ parser.add_argument('--input_wav', help='personal user input wav dir')
+ parser.add_argument('--work_dir', help='autolabel work dir')
+ parser.add_argument(
+ '--gender', default='female', help='personal user gender')
+ parser.add_argument('--develop_mode', type=int, default=1)
+ parser.add_argument(
+ '--stage',
+ type=int,
+ default=1,
+ help='auto labeling stage, 0 means qualification and 1 means labeling')
+ parser.add_argument(
+ '--process_num',
+ type=int,
+ default=4,
+ help='kaldi bin parallel execution process number')
+ parser.add_argument(
+ '--has_para', dest='has_para', action='store_true', help='paragraph')
+ parser.add_argument(
+ '--no_para',
+ dest='has_para',
+ action='store_false',
+ help='no paragraph')
+ parser.add_argument(
+ '--enable_enh',
+ dest='enable_enh',
+ action='store_true',
+ help='enable audio enhancement')
+ parser.add_argument(
+ '--disable_enh',
+ dest='enable_enh',
+ action='store_false',
+ help='disable audio enhancement')
+ parser.set_defaults(has_para=True)
+ parser.set_defaults(enable_enh=False)
+ args = parser.parse_args()
+ logger.info(args.enable_enh)
+ ret_code, report = run_auto_label(args.input_wav, args.work_dir,
+ args.para_ids, args.resource,
+ args.resource_revision, args.gender,
+ args.stage, args.process_num,
+ args.develop_mode, args.has_para,
+ args.enable_enh)
+ logger.info(f'ret_code={ret_code}')
+ logger.info(f'report={report}')
diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py
index cb635a91..90f73a7f 100644
--- a/modelscope/trainers/__init__.py
+++ b/modelscope/trainers/__init__.py
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
MovieSceneSegmentationTrainer, ImageInpaintingTrainer,
ReferringVideoObjectSegmentationTrainer)
from .multi_modal import CLIPTrainer
- from .nlp import SequenceClassificationTrainer, TextRankingTrainer
+ from .nlp import SequenceClassificationTrainer, TextRankingTrainer, SiameseUIETrainer
from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer
from .trainer import EpochBasedTrainer
@@ -27,7 +27,10 @@ else:
'ImageInpaintingTrainer'
],
'multi_modal': ['CLIPTrainer'],
- 'nlp': ['SequenceClassificationTrainer', 'TextRankingTrainer'],
+ 'nlp': [
+ 'SequenceClassificationTrainer', 'TextRankingTrainer',
+ 'SiameseUIETrainer'
+ ],
'nlp_trainer': ['NlpEpochBasedTrainer', 'VecoTrainer'],
'trainer': ['EpochBasedTrainer']
}
diff --git a/modelscope/trainers/audio/kws_farfield_trainer.py b/modelscope/trainers/audio/kws_farfield_trainer.py
index 276bf85f..205947b7 100644
--- a/modelscope/trainers/audio/kws_farfield_trainer.py
+++ b/modelscope/trainers/audio/kws_farfield_trainer.py
@@ -1,6 +1,8 @@
import datetime
+import glob
import math
import os
+import pickle
from typing import Callable, Dict, Optional
import numpy as np
@@ -10,7 +12,8 @@ from torch import optim as optim
from modelscope.metainfo import Trainers
from modelscope.models import Model, TorchModel
-from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets.audio import (
+ KWSDataLoader, KWSDataset)
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.audio.audio_utils import update_conf
@@ -29,6 +32,7 @@ BASETRAIN_CONF_HARD = 'basetrain_hard'
FINETUNE_CONF_EASY = 'finetune_easy'
FINETUNE_CONF_NORMAL = 'finetune_normal'
FINETUNE_CONF_HARD = 'finetune_hard'
+CKPT_PREFIX = 'checkpoint'
EASY_RATIO = 0.1
NORMAL_RATIO = 0.6
@@ -110,9 +114,27 @@ class KWSFarfieldTrainer(BaseTrainer):
if 'single_rate' in kwargs:
self._single_rate = kwargs['single_rate']
self._batch_size = dataloader_config.batch_size_per_gpu
+ next_epoch = kwargs.get('next_epoch', 1)
+ self._current_epoch = next_epoch - 1
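+ # 'next_epoch' (1-based) enables resuming: when it is greater than 1 and no
+ # 'model_bin' is given, the checkpoint of epoch next_epoch - 1 is reloaded below.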
if 'model_bin' in kwargs:
model_bin_file = os.path.join(self.model_dir, kwargs['model_bin'])
self.model = torch.load(model_bin_file)
+ elif self._current_epoch > 0:
+ # load checkpoint
+ ckpt_file_pattern = os.path.join(
+ self.work_dir, f'{CKPT_PREFIX}_{self._current_epoch:04d}*.pth')
+ ckpt_files = glob.glob(ckpt_file_pattern)
+ if len(ckpt_files) == 1:
+ logger.info('Loading model from checkpoint: %s', ckpt_files[0])
+ self.model = torch.load(ckpt_files[0])
+ elif len(ckpt_files) == 0:
+ raise FileNotFoundError(
+ f'Failed to load checkpoint file like '
+ f'{ckpt_file_pattern}. File not found!')
+ else:
+ raise AssertionError(f'Expected exactly one checkpoint'
+ f' file but found several: {ckpt_files}')
+
# build corresponding optimizer and loss function
lr = self.cfg.train.optimizer.lr
self.optimizer = optim.Adam(self.model.parameters(), lr)
@@ -123,10 +145,9 @@ class KWSFarfieldTrainer(BaseTrainer):
self.conf_files = []
for conf_key in self.conf_keys:
template_file = os.path.join(self.model_dir, conf_key)
- conf_file = os.path.join(self.model_dir, f'{conf_key}.conf')
+ conf_file = os.path.join(self.work_dir, f'{conf_key}.conf')
update_conf(template_file, conf_file, custom_conf[conf_key])
self.conf_files.append(conf_file)
- self._current_epoch = 0
self.stages = (math.floor(self._max_epochs * EASY_RATIO),
math.floor(self._max_epochs * NORMAL_RATIO),
math.floor(self._max_epochs * HARD_RATIO))
@@ -151,30 +172,33 @@ class KWSFarfieldTrainer(BaseTrainer):
logger.info('Start training...')
totaltime = datetime.datetime.now()
+ next_stage_head_epoch = 0
for stage, num_epoch in enumerate(self.stages):
- self.run_stage(stage, num_epoch)
+ next_stage_head_epoch += num_epoch
+ epochs_to_run = next_stage_head_epoch - self._current_epoch
+ self.run_stage(stage, epochs_to_run)
# total time spent
totaltime = datetime.datetime.now() - totaltime
logger.info('Total time spent: {:.2f} hours\n'.format(
totaltime.total_seconds() / 3600.0))
- def run_stage(self, stage, num_epoch):
+ def run_stage(self, stage, epochs_to_run):
"""
Run training stages with correspond data
Args:
stage: id of stage
- num_epoch: the number of epoch to run in this stage
+ epochs_to_run: the number of epoch to run in this stage
"""
- if num_epoch <= 0:
+ if epochs_to_run <= 0:
logger.warning(f'Invalid epoch number, stage {stage} exit!')
return
logger.info(f'Starting stage {stage}...')
dataset, dataloader = self.create_dataloader(
self.conf_files[stage * 2], self.conf_files[stage * 2 + 1])
it = iter(dataloader)
- for _ in range(num_epoch):
+ for _ in range(epochs_to_run):
self._current_epoch += 1
epochtime = datetime.datetime.now()
logger.info('Start epoch %d...', self._current_epoch)
@@ -211,8 +235,9 @@ class KWSFarfieldTrainer(BaseTrainer):
logger.info(val_result)
self._dump_log(val_result)
# check point
- ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
- self._current_epoch, loss_train_epoch, loss_val_epoch)
+ ckpt_name = '{}_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
+ CKPT_PREFIX, self._current_epoch, loss_train_epoch,
+ loss_val_epoch)
save_path = os.path.join(self.work_dir, ckpt_name)
logger.info(f'Save model to {save_path}')
torch.save(self.model, save_path)
@@ -229,6 +254,14 @@ class KWSFarfieldTrainer(BaseTrainer):
"""
generate validation set
"""
+ val_dump_file = os.path.join(self.work_dir, 'val_dataset.bin')
+ if self._current_epoch > 0:
+ logger.info('Start loading validation set...')
+ with open(val_dump_file, 'rb') as f:
+ self.data_val = pickle.load(f)
+ logger.info('Finish loading validation set!')
+ return
+
logger.info('Start generating validation set...')
dataset, dataloader = self.create_dataloader(self.conf_files[2],
self.conf_files[3])
@@ -243,6 +276,9 @@ class KWSFarfieldTrainer(BaseTrainer):
dataloader.stop()
dataset.release()
+
+ with open(val_dump_file, 'wb') as f:
+ pickle.dump(self.data_val, f)
logger.info('Finish generating validation set!')
def create_dataloader(self, base_path, finetune_path):
diff --git a/modelscope/trainers/audio/kws_nearfield_trainer.py b/modelscope/trainers/audio/kws_nearfield_trainer.py
index bf00c435..5e63e87e 100644
--- a/modelscope/trainers/audio/kws_nearfield_trainer.py
+++ b/modelscope/trainers/audio/kws_nearfield_trainer.py
@@ -1,42 +1,30 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import datetime
-import math
import os
-import random
import re
-import sys
-from shutil import copyfile
from typing import Callable, Dict, Optional
-import numpy as np
import torch
-import torch.distributed as dist
-import torch.nn.functional as F
import yaml
from tensorboardX import SummaryWriter
from torch import nn as nn
from torch import optim as optim
-from torch.distributed import ReduceOp
-from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from modelscope.metainfo import Trainers
from modelscope.models import Model, TorchModel
-from modelscope.msdatasets.task_datasets.audio.kws_nearfield_dataset import \
+from modelscope.msdatasets.dataset_cls.custom_datasets.audio.kws_nearfield_dataset import \
kws_nearfield_dataset
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
-from modelscope.utils.audio.audio_utils import update_conf
from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
-from modelscope.utils.data_utils import to_device
from modelscope.utils.device import create_device
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
- init_dist, is_master,
- set_random_seed)
+ init_dist, set_random_seed)
from .kws_utils.batch_utils import executor_cv, executor_test, executor_train
from .kws_utils.det_utils import compute_det
from .kws_utils.file_utils import query_tokens_id, read_lexicon, read_token
diff --git a/modelscope/trainers/audio/tts_trainer.py b/modelscope/trainers/audio/tts_trainer.py
index e835f24e..e6964729 100644
--- a/modelscope/trainers/audio/tts_trainer.py
+++ b/modelscope/trainers/audio/tts_trainer.py
@@ -2,11 +2,13 @@
import os
import shutil
import tempfile
+import zipfile
from typing import Callable, Dict, List, Optional, Tuple, Union
import json
from modelscope.metainfo import Preprocessors, Trainers
+from modelscope.models import Model
from modelscope.models.audio.tts import SambertHifigan
from modelscope.msdatasets import MsDataset
from modelscope.preprocessors.builder import build_preprocessor
@@ -17,6 +19,7 @@ from modelscope.utils.audio.tts_exceptions import (
TtsTrainingCfgNotExistsException, TtsTrainingDatasetInvalidException,
TtsTrainingHparamsInvalidException, TtsTrainingInvalidModelException,
TtsTrainingWorkDirNotExistsException)
+from modelscope.utils.config import Config
from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION, ModelFile,
@@ -35,12 +38,12 @@ class KanttsTrainer(BaseTrainer):
ORIG_MODEL_DIR = 'orig_model'
def __init__(self,
- model: str,
+ model: Union[Model, str],
work_dir: str = None,
speaker: str = 'F7',
lang_type: str = 'PinYin',
- cfg_file: Optional[str] = None,
- train_dataset: Optional[Union[MsDataset, str]] = None,
+ cfg_file: str = None,
+ train_dataset: Union[MsDataset, str] = None,
train_dataset_namespace: str = DEFAULT_DATASET_NAMESPACE,
train_dataset_revision: str = DEFAULT_DATASET_REVISION,
train_type: dict = {
@@ -76,7 +79,6 @@ class KanttsTrainer(BaseTrainer):
self.train_type[TtsTrainType.TRAIN_TYPE_VOC] = {}
logger.info(f'Set workdir to {self.work_dir}')
-
self.data_dir = os.path.join(self.work_dir, self.DATA_DIR)
self.am_tmp_dir = os.path.join(self.work_dir, self.AM_TMP_DIR)
self.voc_tmp_dir = os.path.join(self.work_dir, self.VOC_TMP_DIR)
@@ -84,7 +86,6 @@ class KanttsTrainer(BaseTrainer):
self.raw_dataset_path = ''
self.skip_script = preprocess_skip_script
self.audio_config_path = ''
- self.lang_path = ''
self.am_config_path = ''
self.voc_config_path = ''
@@ -99,15 +100,29 @@ class KanttsTrainer(BaseTrainer):
if train_dataset:
if isinstance(train_dataset, str):
- logger.info(f'load {train_dataset_namespace}/{train_dataset}')
- train_dataset = MsDataset.load(
- dataset_name=train_dataset,
- namespace=train_dataset_namespace,
- version=train_dataset_revision)
- logger.info(f'train dataset:{train_dataset.config_kwargs}')
- self.raw_dataset_path = self.load_dataset_raw_path(train_dataset)
+ if os.path.exists(train_dataset):
+ logger.info(f'load {train_dataset}')
+ self.raw_dataset_path = train_dataset
+ else:
+ logger.info(
+ f'load {train_dataset_namespace}/{train_dataset}')
+ train_dataset = MsDataset.load(
+ dataset_name=train_dataset,
+ namespace=train_dataset_namespace,
+ version=train_dataset_revision)
+ logger.info(f'train dataset:{train_dataset.config_kwargs}')
+ self.raw_dataset_path = self.load_dataset_raw_path(
+ train_dataset)
+ else:
+ self.raw_dataset_path = self.load_dataset_raw_path(
+ train_dataset)
- model_dir = self.get_or_download_model_dir(model, model_revision)
+ if not model:
+ raise TtsTrainingInvalidModelException('model param is none')
+ if isinstance(model, str):
+ model_dir = self.get_or_download_model_dir(model, model_revision)
+ else:
+ model_dir = model.model_dir
shutil.copytree(model_dir, self.orig_model_dir)
self.model_dir = self.orig_model_dir
@@ -121,11 +136,10 @@ class KanttsTrainer(BaseTrainer):
self.finetune_from_pretrain = False
self.speaker = speaker
- self.lang_type = lang_type
self.model = None
self.device = kwargs.get('device', 'gpu')
- self.model = self.get_model(self.model_dir, self.speaker,
- self.lang_type)
+ self.model = self.get_model(self.model_dir, self.speaker)
+ self.lang_type = self.model.lang_type
if TtsTrainType.TRAIN_TYPE_SAMBERT in self.train_type or TtsTrainType.TRAIN_TYPE_VOC in self.train_type:
self.audio_data_preprocessor = build_preprocessor(
dict(type=Preprocessors.kantts_data_preprocessor),
@@ -152,26 +166,25 @@ class KanttsTrainer(BaseTrainer):
config['train']['voc_config'])
if os.path.exists(voc_config):
self.voc_config_path = voc_config
- if 'language_path' in config['train']:
- lang_path = os.path.join(cur_dir,
- config['train']['language_path'])
- if os.path.exists(lang_path):
- self.lang_path = lang_path
if not self.raw_dataset_path:
if 'train_dataset' in config['train']:
dataset = config['train']['train_dataset']
- if 'id' in dataset:
- namespace = dataset.get('namespace',
- DEFAULT_DATASET_NAMESPACE)
- revision = dataset.get('revision',
- DEFAULT_DATASET_REVISION)
- ms = MsDataset.load(
- dataset_name=dataset['id'],
- namespace=namespace,
- version=revision)
- self.raw_dataset_path = self.load_dataset_raw_path(ms)
- elif 'path' in dataset:
- self.raw_dataset_path = dataset['path']
+ if os.path.exists(dataset):
+ self.raw_dataset_path = dataset
+ else:
+ if 'id' in dataset:
+ namespace = dataset.get('namespace',
+ DEFAULT_DATASET_NAMESPACE)
+ revision = dataset.get('revision',
+ DEFAULT_DATASET_REVISION)
+ ms = MsDataset.load(
+ dataset_name=dataset['id'],
+ namespace=namespace,
+ version=revision)
+ self.raw_dataset_path = self.load_dataset_raw_path(
+ ms)
+ elif 'path' in dataset:
+ self.raw_dataset_path = dataset['path']
def load_dataset_raw_path(self, dataset: MsDataset):
if 'split_config' not in dataset.config_kwargs:
@@ -188,19 +201,21 @@ class KanttsTrainer(BaseTrainer):
if not audio_config or not os.path.exists(audio_config):
audio_config = self.model.get_voice_audio_config_path(
self.speaker)
- lang_path = self.lang_path
- if not lang_path or not os.path.exists(lang_path):
- lang_path = self.model.get_voice_lang_path(self.speaker)
+ se_model = self.model.get_voice_se_model_path(self.speaker)
self.audio_data_preprocessor(self.raw_dataset_path, self.data_dir,
- lang_path, audio_config, self.speaker,
- self.lang_type, self.skip_script)
+ audio_config, self.speaker,
+ self.lang_type, self.skip_script,
+ se_model)
def prepare_text(self):
pass
- def get_model(self, model_dir, speaker, lang_type):
+ def get_model(self, model_dir, speaker):
+ cfg = Config.from_file(
+ os.path.join(self.model_dir, ModelFile.CONFIGURATION))
+ model_cfg = cfg.get('model', {})
model = SambertHifigan(
- model_dir=self.model_dir, lang_type=self.lang_type, is_train=True)
+ model_dir=self.model_dir, is_train=True, **model_cfg)
return model
def train(self, *args, **kwargs):
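A minimal usage sketch of the updated KanttsTrainer constructor above: `model` may now be a Model instance or a model id, and `train_dataset` may be a local directory path that is used as the raw dataset directly. The model id, paths and work dir below are placeholders, not values from this patch.

from modelscope.models import Model
from modelscope.trainers.audio.tts_trainer import KanttsTrainer

# Either pass a model id string ...
trainer = KanttsTrainer(
    model='damo/speech_sambert-hifigan_tts_zh-cn_16k',  # placeholder model id
    train_dataset='/path/to/local/raw_dataset',         # existing local path, used as-is
    work_dir='/tmp/tts_finetune')                        # hypothetical work dir
# ... or a loaded Model instance, whose model_dir is reused:
# trainer = KanttsTrainer(model=Model.from_pretrained('damo/...'), ...)
trainer.train()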
diff --git a/modelscope/trainers/base.py b/modelscope/trainers/base.py
index 29fb3d2e..e5d708c0 100644
--- a/modelscope/trainers/base.py
+++ b/modelscope/trainers/base.py
@@ -3,7 +3,7 @@
import os
import time
from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, Optional
from modelscope.hub.check_model import check_local_model_is_latest
from modelscope.hub.snapshot_download import snapshot_download
diff --git a/modelscope/trainers/cv/__init__.py b/modelscope/trainers/cv/__init__.py
index c31342ae..d6aa6c30 100644
--- a/modelscope/trainers/cv/__init__.py
+++ b/modelscope/trainers/cv/__init__.py
@@ -12,7 +12,9 @@ if TYPE_CHECKING:
from .referring_video_object_segmentation_trainer import ReferringVideoObjectSegmentationTrainer
from .image_defrcn_fewshot_detection_trainer import ImageDefrcnFewshotTrainer
from .cartoon_translation_trainer import CartoonTranslationTrainer
+ from .ocr_recognition_trainer import OCRRecognitionTrainer
from .nerf_recon_acc_trainer import NeRFReconAccTrainer
+ from .vision_efficient_tuning_trainer import VisionEfficientTuningTrainer
else:
_import_structure = {
@@ -27,7 +29,9 @@ else:
'image_defrcn_fewshot_detection_trainer':
['ImageDefrcnFewshotTrainer'],
'cartoon_translation_trainer': ['CartoonTranslationTrainer'],
+ 'ocr_recognition_trainer': ['OCRRecognitionTrainer'],
'nerf_recon_acc_trainer': ['NeRFReconAccTrainer'],
+ 'vision_efficient_tuning_trainer': ['VisionEfficientTuningTrainer'],
}
import sys
diff --git a/modelscope/trainers/cv/action_detection_trainer.py b/modelscope/trainers/cv/action_detection_trainer.py
new file mode 100644
index 00000000..b89b604c
--- /dev/null
+++ b/modelscope/trainers/cv/action_detection_trainer.py
@@ -0,0 +1,184 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import os.path as osp
+from typing import Callable, Dict, Optional
+
+import torch
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.data import (build_detection_test_loader,
+ build_detection_train_loader)
+from detectron2.engine import SimpleTrainer, hooks, launch
+from detectron2.engine.defaults import create_ddp_model, default_writers
+from detectron2.evaluation import inference_on_dataset, print_csv_format
+from detectron2.solver import LRMultiplier, WarmupParamScheduler
+from detectron2.solver.build import get_default_optimizer_params
+from detectron2.utils import comm
+from detectron2.utils.file_io import PathManager
+from detectron2.utils.logger import setup_logger
+from fvcore.common.param_scheduler import CosineParamScheduler
+
+from modelscope.hub.check_model import check_local_model_is_latest
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.metainfo import Trainers
+from modelscope.metrics.action_detection_evaluator import DetEvaluator
+from modelscope.models.cv.action_detection.modules.action_detection_pytorch import \
+ build_action_detection_model
+from modelscope.preprocessors.cv.action_detection_mapper import VideoDetMapper
+from modelscope.trainers.base import BaseTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.utils.constant import Invoke, ModelFile, Tasks
+
+
+@TRAINERS.register_module(module_name=Trainers.action_detection)
+class ActionDetectionTrainer(BaseTrainer):
+
+ def __init__(self,
+ model_id,
+ train_dataset,
+ test_dataset,
+ cfg_file: str = None,
+ cfg_modify_fn: Optional[Callable] = None,
+ *args,
+ **kwargs):
+ model_cache_dir = self.get_or_download_model_dir(model_id)
+ if cfg_file is None:
+ cfg_file = os.path.join(model_cache_dir, ModelFile.CONFIGURATION)
+
+ super().__init__(cfg_file)
+ if cfg_modify_fn is not None:
+ self.cfg = cfg_modify_fn(self.cfg)
+ self.total_step = self.cfg.train.max_iter
+ self.warmup_step = self.cfg.train.lr_scheduler['warmup_step']
+ self.lr = self.cfg.train.optimizer.lr
+ self.total_batch_size = max(
+ 1, self.cfg.train.num_gpus
+ ) * self.cfg.train.dataloader['batch_size_per_gpu']
+ self.num_classes = len(self.cfg.train.classes_id_map)
+ self.resume = kwargs.get('resume', False)
+ self.train_dataset = train_dataset
+ self.test_dataset = test_dataset
+ self.pretrained_model = kwargs.get(
+ 'pretrained_model',
+ osp.join(model_cache_dir, ModelFile.TORCH_MODEL_FILE))
+
+ def start(self, output_dir):
+ if comm.is_main_process() and output_dir:
+ PathManager.mkdirs(output_dir)
+ self.cfg.dump(osp.join(output_dir, 'config.py'))
+ rank = comm.get_rank()
+ setup_logger(output_dir, distributed_rank=rank, name='fvcore')
+ logger = setup_logger(output_dir, distributed_rank=rank)
+ logger.info('Rank of current process: {}. World size: {}'.format(
+ rank, comm.get_world_size()))
+
+ def train(self, *args, **kwargs):
+ if self.cfg.train.num_gpus <= 1:
+ self.do_train()
+ else:
+ launch(
+ self.do_train,
+ self.cfg.train.num_gpus,
+ 1,
+ machine_rank=0,
+ dist_url='auto',
+ args=())
+
+ def evaluate(self, checkpoint_path: str, *args,
+ **kwargs) -> Dict[str, float]:
+ if self.cfg.train.num_gpus <= 1:
+ self.do_train(just_eval=True, checkpoint_path=checkpoint_path)
+ else:
+ launch(
+ self.do_train,
+ self.cfg.train.num_gpus,
+ 1,
+ machine_rank=0,
+ dist_url='auto',
+ args=(True, checkpoint_path))
+
+ def do_train(
+ self,
+ just_eval=False,
+ checkpoint_path=None,
+ ):
+ self.start(self.cfg.train.work_dir)
+ model = build_action_detection_model(num_classes=self.num_classes)
+ if self.cfg.train.num_gpus > 0:
+ model.cuda()
+ model = create_ddp_model(model, broadcast_buffers=False)
+ if just_eval:
+ checkpoint = DetectionCheckpointer(model)
+ checkpoint.load(checkpoint_path)
+ result = self.do_test(model)
+ return result
+ optim = torch.optim.AdamW(
+ params=get_default_optimizer_params(model, base_lr=self.lr),
+ lr=self.lr,
+ weight_decay=0.1)
+ lr_scheduler = LRMultiplier(
+ optim,
+ WarmupParamScheduler(
+ CosineParamScheduler(1, 1e-3),
+ warmup_factor=0,
+ warmup_length=self.warmup_step / self.total_step),
+ max_iter=self.total_step,
+ )
+ train_loader = build_detection_train_loader(
+ self.train_dataset,
+ mapper=VideoDetMapper(
+ self.cfg.train.classes_id_map, is_train=True),
+ total_batch_size=self.total_batch_size,
+ num_workers=self.cfg.train.dataloader.workers_per_gpu)
+ trainer = SimpleTrainer(model, train_loader, optim)
+ checkpointer = DetectionCheckpointer(
+ model, self.cfg.train.work_dir, trainer=trainer)
+
+ trainer.register_hooks([
+ hooks.IterationTimer(),
+ hooks.LRScheduler(scheduler=lr_scheduler),
+ hooks.PeriodicCheckpointer(
+ checkpointer, period=self.cfg.train.checkpoint_interval)
+ if comm.is_main_process() else None,
+ hooks.EvalHook(
+ eval_period=self.cfg.evaluation.interval,
+ eval_function=lambda: self.do_test(model)),
+ hooks.PeriodicWriter(
+ default_writers(checkpointer.save_dir, self.total_step),
+ period=20) if comm.is_main_process() else None,
+ ])
+ checkpointer.resume_or_load(self.pretrained_model, resume=False)
+ if self.resume:
+ checkpointer.resume_or_load(resume=self.resume)
+ start_iter = trainer.iter + 1
+ else:
+ start_iter = 0
+ trainer.train(start_iter, self.total_step)
+
+ def do_test(self, model):
+ evaluator = DetEvaluator(
+ list(self.cfg.train.classes_id_map.keys()),
+ self.cfg.train.work_dir,
+ distributed=self.cfg.train.num_gpus > 1)
+ test_loader = build_detection_test_loader(
+ self.test_dataset,
+ mapper=VideoDetMapper(
+ self.cfg.train.classes_id_map, is_train=False),
+ num_workers=self.cfg.evaluation.dataloader.workers_per_gpu)
+
+ result = inference_on_dataset(model, test_loader, evaluator)
+ print_csv_format(result)
+ return result
+
+ def get_or_download_model_dir(self, model, model_revision=None):
+ if os.path.exists(model):
+ model_cache_dir = model if os.path.isdir(
+ model) else os.path.dirname(model)
+ check_local_model_is_latest(
+ model_cache_dir, user_agent={Invoke.KEY: Invoke.LOCAL_TRAINER})
+ else:
+ model_cache_dir = snapshot_download(
+ model,
+ revision=model_revision,
+ user_agent={Invoke.KEY: Invoke.TRAINER})
+ return model_cache_dir
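A hedged usage sketch of the new ActionDetectionTrainer; the model id is a placeholder and the datasets are assumed to be in a format accepted by detectron2's build_detection_train_loader (a registered dataset name or a list of dicts).

from modelscope.trainers.cv.action_detection_trainer import ActionDetectionTrainer

train_items = [...]  # assumed: list of dicts in detectron2 dataset format
test_items = [...]   # assumed: list of dicts in detectron2 dataset format

trainer = ActionDetectionTrainer(
    model_id='damo/cv_ResNetC3D_action-detection_detection2d',  # placeholder model id
    train_dataset=train_items,
    test_dataset=test_items)
trainer.train()                              # launches multi-GPU via detectron2 when num_gpus > 1
trainer.evaluate('/path/to/checkpoint.pth')  # hypothetical checkpoint path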
diff --git a/modelscope/trainers/cv/image_detection_damoyolo_trainer.py b/modelscope/trainers/cv/image_detection_damoyolo_trainer.py
index 734c8915..8d8b32ae 100644
--- a/modelscope/trainers/cv/image_detection_damoyolo_trainer.py
+++ b/modelscope/trainers/cv/image_detection_damoyolo_trainer.py
@@ -1,11 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-import copy
import datetime
import math
import os
-import os.path as osp
import time
-from typing import Callable, Dict, Optional
+from typing import Dict
import torch
import torch.distributed as dist
@@ -25,12 +23,13 @@ from modelscope.models.cv.tinynas_detection.damo.detectors.detector import (
build_ddp_model, build_local_model)
from modelscope.models.cv.tinynas_detection.damo.utils import (
cosine_scheduler, ema_model)
-from modelscope.msdatasets.task_datasets.damoyolo import (build_dataloader,
- build_dataset)
+from modelscope.msdatasets.dataset_cls.custom_datasets.damoyolo import (
+ build_dataloader, build_dataset)
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.checkpoint import save_checkpoint
-from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
+from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModelFile,
+ ThirdParty)
from modelscope.utils.logger import get_logger
from modelscope.utils.metric import MeterBuffer
from modelscope.utils.torch_utils import get_rank, synchronize
@@ -64,14 +63,19 @@ class ImageDetectionDamoyoloTrainer(BaseTrainer):
train_ann: the path of train set annotation file.
val_ann: the path of val set annotation file.
num_classes: class number.
- base_lr_per_img: learning rate per image. The final learning rate is base_lr_per_img*batch_size.
+ base_lr_per_img: learning rate per image.
+ The final learning rate is base_lr_per_img*batch_size.
pretrain_model: the path of pretrained model.
work_dir: the directory of work folder.
exp_name: the name of experiment.
+            third_party: the name of the third-party library from which this trainer is invoked.
"""
if model is not None:
+ third_party = kwargs.get(ThirdParty.KEY)
+ if third_party is not None:
+ kwargs.pop(ThirdParty.KEY)
self.cache_path = self.get_or_download_model_dir(
- model, model_revision)
+ model, model_revision, third_party)
if cfg_file is None:
self.cfg_file = os.path.join(self.cache_path,
ModelFile.CONFIGURATION)
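A short sketch of the new third_party handling: the ThirdParty.KEY entry is popped from kwargs and forwarded to the model download call as user-agent information. The model id and third-party name are placeholders; dataset-related kwargs are omitted.

from modelscope.trainers.cv.image_detection_damoyolo_trainer import \
    ImageDetectionDamoyoloTrainer
from modelscope.utils.constant import ThirdParty

trainer = ImageDetectionDamoyoloTrainer(
    model='damo/cv_tinynas_object-detection_damoyolo',  # placeholder model id
    **{ThirdParty.KEY: 'adadet'})                        # assumed third-party tag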
diff --git a/modelscope/trainers/cv/ocr_detection_db_trainer.py b/modelscope/trainers/cv/ocr_detection_db_trainer.py
new file mode 100644
index 00000000..3a9d51aa
--- /dev/null
+++ b/modelscope/trainers/cv/ocr_detection_db_trainer.py
@@ -0,0 +1,433 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import copy
+import datetime
+import math
+import os
+import time
+from typing import Callable, Dict, Optional
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+from easydict import EasyDict as easydict
+from tqdm import tqdm
+
+from modelscope.metainfo import Trainers
+from modelscope.models.cv.ocr_detection.modules.dbnet import (DBModel,
+ DBModel_v2)
+from modelscope.models.cv.ocr_detection.utils import (boxes_from_bitmap,
+ polygons_from_bitmap)
+from modelscope.msdatasets.dataset_cls.custom_datasets.ocr_detection import (
+ DataLoader, ImageDataset, QuadMeasurer)
+from modelscope.trainers.base import BaseTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
+from modelscope.utils.logger import get_logger
+from modelscope.utils.torch_utils import get_rank, synchronize
+
+
+@TRAINERS.register_module(module_name=Trainers.ocr_detection_db)
+class OCRDetectionDBTrainer(BaseTrainer):
+
+ def __init__(self,
+ model: str = None,
+ cfg_file: str = None,
+ load_pretrain: bool = True,
+ cache_path: str = None,
+ model_revision: str = DEFAULT_MODEL_REVISION,
+ *args,
+ **kwargs):
+ """ High-level finetune api for dbnet.
+
+ Args:
+ model: Model id of modelscope models.
+ cfg_file: Path to configuration file.
+            load_pretrain: Whether to load a pretrained model for finetuning.
+                If False, training starts from scratch.
+ cache_path: cache path of model files.
+ """
+ if model is not None:
+ self.cache_path = self.get_or_download_model_dir(
+ model, model_revision)
+ if cfg_file is None:
+ self.cfg_file = os.path.join(self.cache_path,
+ ModelFile.CONFIGURATION)
+ else:
+ assert cfg_file is not None and cache_path is not None, \
+                'cfg_file and cache_path are needed if model is not provided'
+
+ if cfg_file is not None:
+ self.cfg_file = cfg_file
+ if cache_path is not None:
+ self.cache_path = cache_path
+ super().__init__(self.cfg_file)
+ cfg = self.cfg
+ if load_pretrain:
+ if 'pretrain_model' in kwargs:
+ cfg.train.finetune_path = kwargs['pretrain_model']
+ else:
+ cfg.train.finetune_path = os.path.join(self.cache_path,
+ self.cfg.model.weights)
+
+ if 'framework' in self.cfg:
+ cfg = self._config_transform(cfg)
+
+ if 'gpu_ids' in kwargs:
+ cfg.train.gpu_ids = kwargs['gpu_ids']
+ if 'batch_size' in kwargs:
+ cfg.train.batch_size = kwargs['batch_size']
+ if 'max_epochs' in kwargs:
+ cfg.train.total_epochs = kwargs['max_epochs']
+ if 'base_lr' in kwargs:
+ cfg.train.base_lr = kwargs['base_lr']
+ if 'train_data_dir' in kwargs:
+ cfg.dataset.train_data_dir = kwargs['train_data_dir']
+ if 'val_data_dir' in kwargs:
+ cfg.dataset.val_data_dir = kwargs['val_data_dir']
+ if 'train_data_list' in kwargs:
+ cfg.dataset.train_data_list = kwargs['train_data_list']
+ if 'val_data_list' in kwargs:
+ cfg.dataset.val_data_list = kwargs['val_data_list']
+
+ self.gpu_ids = cfg.train.gpu_ids
+ self.world_size = len(self.gpu_ids)
+
+ self.cfg = cfg
+
+ def train(self):
+ trainer = DBTrainer(self.cfg)
+ trainer.train(local_rank=0)
+
+ def evaluate(self,
+ checkpoint_path: str = None,
+ *args,
+ **kwargs) -> Dict[str, float]:
+ if checkpoint_path is not None:
+ self.cfg.test.checkpoint_path = checkpoint_path
+        evaluator = DBTrainer(self.cfg)
+        evaluator.evaluate(local_rank=0)
+
+ def _config_transform(self, config):
+ new_config = easydict({})
+ new_config.miscs = config.train.miscs
+ new_config.miscs.output_dir = config.train.work_dir
+ new_config.model = config.model
+ new_config.dataset = config.dataset
+ new_config.train = config.train
+ new_config.test = config.evaluation
+
+ new_config.train.dataloader.num_gpus = len(config.train.gpu_ids)
+ new_config.train.dataloader.batch_size = len(
+ config.train.gpu_ids) * config.train.dataloader.batch_size_per_gpu
+ new_config.train.dataloader.num_workers = len(
+ config.train.gpu_ids) * config.train.dataloader.workers_per_gpu
+ new_config.train.total_epochs = config.train.max_epochs
+
+ new_config.test.dataloader.num_gpus = 1
+ new_config.test.dataloader.num_workers = 4
+ new_config.test.dataloader.collect_fn = config.evaluation.transform.collect_fn
+
+ return new_config
+
+
+class DBTrainer:
+
+ def __init__(self, cfg):
+ self.init_device()
+
+ self.cfg = cfg
+ self.dir_path = cfg.miscs.output_dir
+ self.lr = cfg.train.base_lr
+ self.current_lr = 0
+
+ self.total = 0
+
+ if len(cfg.train.gpu_ids) > 1:
+ self.distributed = True
+ else:
+ self.distributed = False
+
+ self.file_name = os.path.join(cfg.miscs.output_dir, cfg.miscs.exp_name)
+
+ # setup logger
+ if get_rank() == 0:
+ os.makedirs(self.file_name, exist_ok=True)
+
+ self.logger = get_logger(os.path.join(self.file_name, 'train_log.txt'))
+
+ # logger
+ self.logger.info('cfg value:\n{}'.format(self.cfg))
+
+ def init_device(self):
+ if torch.cuda.is_available():
+ self.device = torch.device('cuda')
+ else:
+ self.device = torch.device('cpu')
+
+ def init_model(self, local_rank):
+ model = DBModel_v2(self.device, self.distributed, local_rank)
+ return model
+
+ def get_learning_rate(self, epoch, step=None):
+ # DecayLearningRate
+ factor = 0.9
+ rate = np.power(1.0 - epoch / float(self.cfg.train.total_epochs + 1),
+ factor)
+ return rate * self.lr
+
+ def update_learning_rate(self, optimizer, epoch, step):
+ lr = self.get_learning_rate(epoch, step)
+
+ for group in optimizer.param_groups:
+ group['lr'] = lr
+ self.current_lr = lr
+
+ def restore_model(self, model, model_path, device):
+ state_dict = torch.load(model_path, map_location=device)
+ model.load_state_dict(state_dict, strict=False)
+
+ def create_optimizer(self, lr=0.007, momentum=0.9, weight_decay=0.0001):
+ bn_group, weight_group, bias_group = [], [], []
+
+ for k, v in self.model.named_modules():
+ if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
+ bias_group.append(v.bias)
+ if isinstance(v, nn.BatchNorm2d) or 'bn' in k:
+ bn_group.append(v.weight)
+ elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
+ weight_group.append(v.weight)
+
+ optimizer = torch.optim.SGD(
+ bn_group, lr=lr, momentum=momentum, nesterov=True)
+ optimizer.add_param_group({
+ 'params': weight_group,
+ 'weight_decay': weight_decay
+ })
+ optimizer.add_param_group({'params': bias_group})
+ return optimizer
+
+ def maybe_save_model(self, model, epoch, step):
+ if step % self.cfg.miscs.save_interval == 0:
+ self.logger.info('save interval model for step ' + str(step))
+ self.save_model(model, epoch, step)
+
+ def save_model(self, model, epoch=None, step=None):
+ if isinstance(model, dict):
+ for name, net in model.items():
+ checkpoint_name = self.make_checkpoint_name(name, epoch, step)
+ self.save_checkpoint(net, checkpoint_name)
+ else:
+ checkpoint_name = self.make_checkpoint_name('model', epoch, step)
+ self.save_checkpoint(model, checkpoint_name)
+
+ def save_checkpoint(self, net, name):
+ os.makedirs(self.dir_path, exist_ok=True)
+ torch.save(net.state_dict(), os.path.join(self.dir_path, name))
+ self.logger.info('save_checkpoint to: '
+ + os.path.join(self.dir_path, name))
+
+ def convert_model_for_inference(self, finetune_model_name,
+ infer_model_name):
+        # Convert the finetuned model to an inference model,
+        # dropping parameters that are only needed for training.
+ infer_model = DBModel().to(self.device)
+ model_state_dict = infer_model.state_dict()
+ model_keys = list(model_state_dict.keys())
+ saved_dict = torch.load(
+ os.path.join(self.dir_path, finetune_model_name),
+ map_location=self.device)
+ saved_keys = set(saved_dict.keys())
+ prefix = 'model.module.'
+ for i in range(len(model_keys)):
+ if prefix + model_keys[i] in saved_keys:
+ model_state_dict[model_keys[i]] = (
+ saved_dict[prefix + model_keys[i]].cpu().float())
+ infer_model.load_state_dict(model_state_dict)
+ torch.save(infer_model.state_dict(),
+ os.path.join(self.dir_path, infer_model_name))
+
+ def make_checkpoint_name(self, name, epoch=None, step=None):
+ if epoch is None or step is None:
+ c_name = name + '_latest.pt'
+ else:
+ c_name = '{}_epoch_{}_minibatch_{}.pt'.format(name, epoch, step)
+ return c_name
+
+ def get_data_loader(self, cfg, distributed=False):
+ train_dataset = ImageDataset(cfg, cfg.dataset.train_data_dir,
+ cfg.dataset.train_data_list)
+ train_dataloader = DataLoader(
+ train_dataset,
+ cfg.train.dataloader,
+ is_train=True,
+ distributed=distributed)
+ test_dataset = ImageDataset(cfg, cfg.dataset.val_data_dir,
+ cfg.dataset.val_data_list)
+ test_dataloader = DataLoader(
+ test_dataset,
+ cfg.test.dataloader,
+ is_train=False,
+ distributed=distributed)
+ return train_dataloader, test_dataloader
+
+ def train(self, local_rank):
+ # Build model for training
+ self.model = self.init_model(local_rank)
+
+ # Build dataloader
+ self.train_data_loader, self.validation_loaders = self.get_data_loader(
+ self.cfg, self.distributed)
+ # Resume model from finetune_path
+ self.steps = 0
+ if self.cfg.train.finetune_path is not None:
+ self.logger.info(f'finetune from {self.cfg.train.finetune_path}')
+ self.restore_model(self.model, self.cfg.train.finetune_path,
+ self.device)
+ epoch = 0
+
+ # Build optimizer
+ optimizer = self.create_optimizer(self.lr)
+
+ self.logger.info('Start Training...')
+
+ self.model.train()
+
+ # Training loop
+ while True:
+ self.logger.info('Training epoch ' + str(epoch))
+ self.total = len(self.train_data_loader)
+ for batch in self.train_data_loader:
+ self.update_learning_rate(optimizer, epoch, self.steps)
+
+ self.train_step(
+ self.model, optimizer, batch, epoch=epoch, step=self.steps)
+
+ # Save interval model
+ self.maybe_save_model(self.model, epoch, self.steps)
+
+ self.steps += 1
+
+ epoch += 1
+ if epoch > self.cfg.train.total_epochs:
+ self.save_checkpoint(self.model, 'final.pt')
+ self.convert_model_for_inference('final.pt',
+ 'pytorch_model.pt')
+ self.logger.info('Training done')
+ break
+
+ def train_step(self, model, optimizer, batch, epoch, step):
+ optimizer.zero_grad()
+
+ results = model.forward(batch, training=True)
+ if len(results) == 2:
+ l, pred = results
+ metrics = {}
+ elif len(results) == 3:
+ l, pred, metrics = results
+
+ if isinstance(l, dict):
+ line = []
+ loss = torch.tensor(0.).cuda()
+ for key, l_val in l.items():
+ loss += l_val.mean()
+ line.append('loss_{0}:{1:.4f}'.format(key, l_val.mean()))
+ else:
+ loss = l.mean()
+ loss.backward()
+ optimizer.step()
+
+ if step % self.cfg.train.miscs.print_interval_iters == 0:
+ if isinstance(l, dict):
+ line = '\t'.join(line)
+ log_info = '\t'.join(
+ ['step:{:6d}', 'epoch:{:3d}', '{}',
+ 'lr:{:.4f}']).format(step, epoch, line, self.current_lr)
+ self.logger.info(log_info)
+ else:
+ self.logger.info('step: %6d, epoch: %3d, loss: %.6f, lr: %f' %
+ (step, epoch, loss.item(), self.current_lr))
+ for name, metric in metrics.items():
+ self.logger.info('%s: %6f' % (name, metric.mean()))
+
+ def init_torch_tensor(self):
+ # Use gpu or not
+ torch.set_default_tensor_type('torch.FloatTensor')
+ if torch.cuda.is_available():
+ self.device = torch.device('cuda')
+ torch.set_default_tensor_type('torch.cuda.FloatTensor')
+ else:
+ self.device = torch.device('cpu')
+
+ def represent(self, batch, _pred, is_output_polygon=False):
+ '''
+        Args:
+            batch: a dict produced by dataloaders, containing:
+                image: tensor of shape (N, C, H, W).
+                polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
+                ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
+                shape: the original shape of images.
+                filename: the original filenames of images.
+            pred:
+                binary: text region segmentation map, with shape (N, 1, H, W).
+                thresh: [if exists] threshold prediction with shape (N, 1, H, W).
+                thresh_binary: [if exists] binarized with threshold, (N, 1, H, W).
+ '''
+ images = batch['image']
+ if isinstance(_pred, dict):
+ pred = _pred['binary']
+ else:
+ pred = _pred
+ segmentation = pred > self.cfg.test.thresh
+ boxes_batch = []
+ scores_batch = []
+ for batch_index in range(images.size(0)):
+ height, width = batch['shape'][batch_index]
+ if is_output_polygon:
+ boxes, scores = polygons_from_bitmap(pred[batch_index],
+ segmentation[batch_index],
+ width, height)
+ else:
+ boxes, scores = boxes_from_bitmap(pred[batch_index],
+ segmentation[batch_index],
+ width, height)
+ boxes_batch.append(boxes)
+ scores_batch.append(scores)
+ return boxes_batch, scores_batch
+
+ def evaluate(self, local_rank):
+ self.init_torch_tensor()
+ # Build model for evaluation
+ model = self.init_model(local_rank)
+
+ # Restore model from checkpoint_path
+ self.restore_model(model, self.cfg.test.checkpoint_path, self.device)
+
+ # Build dataloader for evaluation
+ self.train_data_loader, self.validation_loaders = self.get_data_loader(
+ self.cfg, self.distributed)
+ # Build evaluation metric
+ quad_measurer = QuadMeasurer()
+ model.eval()
+
+ with torch.no_grad():
+ raw_metrics = []
+ for i, batch in tqdm(
+ enumerate(self.validation_loaders),
+ total=len(self.validation_loaders)):
+ pred = model.forward(batch, training=False)
+ output = self.represent(batch, pred,
+ self.cfg.test.return_polygon)
+ raw_metric = quad_measurer.validate_measure(
+ batch,
+ output,
+ is_output_polygon=self.cfg.test.return_polygon,
+ box_thresh=0.3)
+ raw_metrics.append(raw_metric)
+ metrics = quad_measurer.gather_measure(raw_metrics)
+ for key, metric in metrics.items():
+ self.logger.info('%s : %f (%d)' %
+ (key, metric.avg, metric.count))
+
+ self.logger.info('Evaluation done')
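A hedged usage sketch of OCRDetectionDBTrainer: the kwargs below map to the overrides handled in the constructor above, and all ids and paths are placeholders.

from modelscope.trainers.cv.ocr_detection_db_trainer import OCRDetectionDBTrainer

trainer = OCRDetectionDBTrainer(
    model='damo/cv_resnet18_ocr-detection-db-line-level_damo',  # placeholder model id
    gpu_ids=[0],
    batch_size=8,
    max_epochs=10,
    base_lr=0.002,
    train_data_dir='/path/to/train_images',
    train_data_list='/path/to/train_list.txt',
    val_data_dir='/path/to/val_images',
    val_data_list='/path/to/val_list.txt')
trainer.train()
trainer.evaluate(checkpoint_path='/path/to/work_dir/final.pt')  # hypothetical checkpoint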
diff --git a/modelscope/trainers/cv/ocr_recognition_trainer.py b/modelscope/trainers/cv/ocr_recognition_trainer.py
new file mode 100644
index 00000000..fe7b6ef7
--- /dev/null
+++ b/modelscope/trainers/cv/ocr_recognition_trainer.py
@@ -0,0 +1,84 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import time
+from collections.abc import Mapping
+
+import torch
+from torch import distributed as dist
+
+from modelscope.metainfo import Trainers
+from modelscope.trainers.builder import TRAINERS
+from modelscope.trainers.trainer import EpochBasedTrainer
+from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
+ ConfigKeys, Hubs, ModeKeys, ModelFile,
+ Tasks, TrainerStages)
+from modelscope.utils.data_utils import to_device
+from modelscope.utils.file_utils import func_receive_dict_inputs
+
+
+@TRAINERS.register_module(module_name=Trainers.ocr_recognition)
+class OCRRecognitionTrainer(EpochBasedTrainer):
+
+ def evaluate(self, *args, **kwargs):
+ metric_values = super().evaluate(*args, **kwargs)
+ return metric_values
+
+ def prediction_step(self, model, inputs):
+ pass
+
+ def train_step(self, model, inputs):
+ """ Perform a training step on a batch of inputs.
+
+ Subclass and override to inject custom behavior.
+
+ Args:
+ model (`TorchModel`): The model to train.
+ inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+ The inputs and targets of the model.
+
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+ argument `labels`. Check your model's documentation for all accepted arguments.
+
+ Return:
+ `torch.Tensor`: The tensor with training loss on this batch.
+ """
+        # EvaluationHook runs evaluation and switches the mode to val; switch back to train mode here.
+        # TODO: find a cleaner way to change the mode
+ model.train()
+ self._mode = ModeKeys.TRAIN
+ train_outputs = model.do_step(inputs)
+
+ if not isinstance(train_outputs, dict):
+ raise TypeError('"model.forward()" must return a dict')
+
+ # add model output info to log
+ if 'log_vars' not in train_outputs:
+ default_keys_pattern = ['loss']
+ match_keys = set([])
+ for key_p in default_keys_pattern:
+ match_keys.update(
+ [key for key in train_outputs.keys() if key_p in key])
+
+ log_vars = {}
+ for key in match_keys:
+ value = train_outputs.get(key, None)
+ if value is not None:
+ if dist.is_available() and dist.is_initialized():
+ value = value.data.clone()
+ dist.all_reduce(value.div_(dist.get_world_size()))
+ log_vars.update({key: value.item()})
+ self.log_buffer.update(log_vars)
+ else:
+ self.log_buffer.update(train_outputs['log_vars'])
+
+ self.train_outputs = train_outputs
+
+ def evaluation_step(self, data):
+ """Perform a evaluation step on a batch of inputs.
+
+ Subclass and override to inject custom behavior.
+
+ """
+ model = self.model.module if self._dist else self.model
+ model.eval()
+ result = model.do_step(data)
+ return result
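A minimal standalone sketch of the loss-averaging pattern used in train_step above: when torch.distributed is initialized, each matched scalar is cloned, all-reduced, and divided by the world size before logging. The helper name is hypothetical.

import torch
from torch import distributed as dist

def average_scalar_across_ranks(value: torch.Tensor) -> float:
    """Mean of a 0-dim tensor over all ranks, or its own value when not distributed."""
    if dist.is_available() and dist.is_initialized():
        value = value.data.clone()
        dist.all_reduce(value.div_(dist.get_world_size()))
    return value.item()

# usage inside a training step (loss assumed to be a 0-dim tensor):
# self.log_buffer.update({'loss': average_scalar_across_ranks(loss)})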
diff --git a/modelscope/trainers/cv/vision_efficient_tuning_trainer.py b/modelscope/trainers/cv/vision_efficient_tuning_trainer.py
new file mode 100644
index 00000000..4c7dca73
--- /dev/null
+++ b/modelscope/trainers/cv/vision_efficient_tuning_trainer.py
@@ -0,0 +1,114 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+from typing import Union
+
+from torch import nn
+
+from modelscope.metainfo import Trainers
+from modelscope.models.base import Model, TorchModel
+from modelscope.trainers.builder import TRAINERS
+from modelscope.trainers.default_config import merge_hooks
+from modelscope.trainers.trainer import EpochBasedTrainer
+from modelscope.utils.constant import ModeKeys
+
+
+@TRAINERS.register_module(module_name=Trainers.vision_efficient_tuning)
+class VisionEfficientTuningTrainer(EpochBasedTrainer):
+ """ Vision Efficient Tuning Trainer based on EpochBasedTrainer
+
+ The trainer freezes the parameters of the pre-trained model and
+ tunes the extra parameters of the different parameter-efficient
+ transfer learning (PETL) method.
+
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def build_model(self) -> Union[nn.Module, TorchModel]:
+ """ Instantiate a pytorch model and return.
+
+ By default, we will create a model using config from configuration file. You can
+ override this method in a subclass.
+
+ """
+ model = Model.from_pretrained(self.model_dir, cfg_dict=self.cfg)
+ if 'freeze_cfg' in self.cfg['model']:
+ model = self.freeze(model, **self.cfg['model']['freeze_cfg'])
+ if not isinstance(model, nn.Module) and hasattr(model, 'model'):
+ return model.model
+ elif isinstance(model, nn.Module):
+ return model
+
+ def train(self, *args, **kwargs):
+ self.print_model_params_status()
+ super().train(*args, **kwargs)
+
+ def evaluate(self, *args, **kwargs):
+ metric_values = super().evaluate(*args, **kwargs)
+ return metric_values
+
+ def freeze(self, model, freeze_part=[], train_part=[]):
+ """ Freeze or train the model based on the config.
+
+ Args:
+ model: the current model.
+ freeze_part: the config of frozen parameters.
+ train_part: the config of trainable parameters.
+ """
+ if hasattr(model, 'module'):
+ freeze_model = model.module
+ else:
+ freeze_model = model
+
+ if freeze_part and len(freeze_part) > 0:
+ if 'backbone' in freeze_part:
+ part = freeze_part['backbone']
+ for name, param in freeze_model.model.backbone.named_parameters(
+ ):
+ freeze_flag = sum([p in name for p in part]) > 0
+ if freeze_flag:
+ param.requires_grad = False
+ elif 'head' in freeze_part:
+ part = freeze_part['head']
+ for name, param in freeze_model.model.head.named_parameters():
+ freeze_flag = sum([p in name for p in part]) > 0
+ if freeze_flag:
+ param.requires_grad = False
+
+ if train_part and len(train_part) > 0:
+ if 'backbone' in train_part:
+ part = train_part['backbone']
+ for name, param in freeze_model.model.backbone.named_parameters(
+ ):
+ freeze_flag = sum([p in name for p in part]) > 0
+ if freeze_flag:
+ param.requires_grad = True
+ elif 'head' in train_part:
+ part = train_part['head']
+ for name, param in freeze_model.model.head.named_parameters():
+ freeze_flag = sum([p in name for p in part]) > 0
+ if freeze_flag:
+ param.requires_grad = True
+ return model
+
+ def print_model_params_status(self, model=None, logger=None):
+ """Print the status and parameters of the model"""
+ if model is None:
+ model = self.model
+ if logger is None:
+ logger = self.logger
+ train_param_dict = {}
+ all_param_numel = 0
+ for key, val in model.named_parameters():
+ if val.requires_grad:
+ sub_key = '.'.join(key.split('.', 1)[-1].split('.', 2)[:2])
+ if sub_key in train_param_dict:
+ train_param_dict[sub_key] += val.numel()
+ else:
+ train_param_dict[sub_key] = val.numel()
+ all_param_numel += val.numel()
+ train_param_numel = sum(train_param_dict.values())
+ logger.info(
+ f'Load trainable params {train_param_numel} / {all_param_numel} = '
+ f'{train_param_numel/all_param_numel:.2%}, '
+ f'train part: {train_param_dict}.')
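A sketch of the freeze_cfg structure consumed by build_model/freeze above; the parameter-name fragments are illustrative assumptions. Parameters whose names contain any listed fragment are frozen (freeze_part) or kept trainable (train_part).

# Hypothetical 'model.freeze_cfg' section of configuration.json, as a Python dict:
freeze_cfg = {
    'freeze_part': {'backbone': ['patch_embed', 'blocks']},  # freeze the pre-trained weights
    'train_part': {'backbone': ['prompt']},                  # keep the PETL parameters trainable
}
# Applied inside build_model as:
#   model = self.freeze(model, **self.cfg['model']['freeze_cfg'])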
diff --git a/modelscope/trainers/default_config.py b/modelscope/trainers/default_config.py
index 7619633f..5f9aa625 100644
--- a/modelscope/trainers/default_config.py
+++ b/modelscope/trainers/default_config.py
@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
from modelscope.utils.config import Config
@@ -21,10 +21,11 @@ DEFAULT_CONFIG = Config({
'type': 'StepLR',
'step_size': 2
},
- 'hooks': [{
- 'type': 'CheckpointHook',
- 'interval': 1
- }]
+ 'checkpoint': {
+ 'period': {
+ 'interval': 1
+ }
+ }
},
'evaluation': {
'dataloader': {
@@ -49,6 +50,13 @@ DEFAULT_HOOKS_CONFIG = {
}
}
+_HOOK_KEY_CHAIN_MAP = {
+ 'TextLoggerHook': 'train.logging',
+ 'CheckpointHook': 'train.checkpoint.period',
+ 'BestCkptSaverHook': 'train.checkpoint.best',
+ 'EvaluationHook': 'evaluation.period',
+}
+
def merge_cfg(cfg: Config):
"""Merge the default config into the input cfg.
@@ -62,20 +70,31 @@ def merge_cfg(cfg: Config):
def merge_hooks(cfg: Config) -> List[Dict]:
- key_chain_hook_map = {
- 'train.logging': 'TextLoggerHook',
- 'train.checkpoint.period': 'CheckpointHook',
- 'train.checkpoint.best': 'BestCkptSaverHook',
- 'evaluation.period': 'EvaluationHook'
- }
hooks = cfg.train.hooks.copy()
- for key_chain, hook_type in key_chain_hook_map.items():
+ for hook_type, key_chain in _HOOK_KEY_CHAIN_MAP.items():
hook = _key_chain_to_hook(cfg, key_chain, hook_type)
if hook is not None:
hooks.append(hook)
return hooks
+def update_cfg(cfg: Config) -> Config:
+ if 'hooks' not in cfg.train:
+ return cfg
+ key_chain_map = {}
+ for hook in cfg.train.hooks:
+ if not hook:
+ continue
+ key, value = _hook_split(hook)
+ if key not in _HOOK_KEY_CHAIN_MAP:
+ continue
+ key_chain_map[_HOOK_KEY_CHAIN_MAP[key]] = value
+ hook.clear()
+ cfg.train.hooks = list(filter(bool, cfg.train.hooks))
+ cfg.merge_from_dict(key_chain_map)
+ return cfg
+
+
def _key_chain_to_hook(cfg: Config, key_chain: str,
hook_type: str) -> Optional[Dict]:
if not _check_basic_hook(cfg, key_chain, hook_type):
@@ -95,3 +114,8 @@ def _check_basic_hook(cfg: Config, key_chain: str, hook_type: str) -> bool:
f'cannot exist at the same time, ' \
f'please delete {hook_type} in the configuration file.'
return True
+
+
+def _hook_split(hook: Dict) -> Tuple[str, Dict]:
+ hook = hook.copy()
+ return hook.pop('type'), hook
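A minimal sketch of update_cfg on a configuration that still uses the old hook-list style, assuming merge_from_dict expands dotted keys as used above: mapped hooks are moved to their key chains and removed from train.hooks, while unmapped hooks stay in place.

from modelscope.trainers.default_config import update_cfg
from modelscope.utils.config import Config

cfg = Config({
    'train': {
        'hooks': [
            {'type': 'CheckpointHook', 'interval': 1},
            {'type': 'EvaluationHook', 'interval': 1},
            {'type': 'SomeCustomHook'},  # hypothetical hook not in _HOOK_KEY_CHAIN_MAP
        ]
    }
})
cfg = update_cfg(cfg)
# Expected result under the assumptions above:
#   cfg.train.checkpoint.period == {'interval': 1}
#   cfg.evaluation.period == {'interval': 1}
#   cfg.train.hooks == [{'type': 'SomeCustomHook'}]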
diff --git a/modelscope/trainers/easycv/trainer.py b/modelscope/trainers/easycv/trainer.py
index a1ad0649..58d6a440 100644
--- a/modelscope/trainers/easycv/trainer.py
+++ b/modelscope/trainers/easycv/trainer.py
@@ -88,22 +88,6 @@ class EasyCVEpochBasedTrainer(EpochBasedTrainer):
collate,
samples_per_gpu=self.cfg.evaluation.dataloader.batch_size_per_gpu)
- # Register easycv hooks dynamicly. If the hook already exists in modelscope,
- # the hook in modelscope will be used, otherwise register easycv hook into ms.
- # We must manually trigger lazy import to detect whether the hook is in modelscope.
- # TODO: use ast index to detect whether the hook is in modelscope
- for h_i in self.cfg.train.get('hooks', []):
- sig = ('HOOKS', default_group, h_i['type'])
- LazyImportModule.import_module(sig)
- if h_i['type'] not in HOOKS._modules[default_group]:
- if h_i['type'] in [
- 'TensorboardLoggerHookV2', 'WandbLoggerHookV2'
- ]:
- raise ValueError(
- 'Not support hook %s now, we will support it in the future!'
- % h_i['type'])
- register_util.register_hook_to_ms(h_i['type'], self.logger)
-
# load pretrained model
load_from = self.cfg.get('load_from', None)
if load_from is not None:
@@ -125,6 +109,25 @@ class EasyCVEpochBasedTrainer(EpochBasedTrainer):
device_ids=[torch.cuda.current_device()])
self.model = build_parallel(dp_cfg)
+ def rebuild_config(self, cfg: Config):
+ cfg = super().rebuild_config(cfg)
+        # Register easycv hooks dynamically. If the hook already exists in modelscope,
+        # the hook in modelscope will be used; otherwise the easycv hook is registered into modelscope.
+ # We must manually trigger lazy import to detect whether the hook is in modelscope.
+ # TODO: use ast index to detect whether the hook is in modelscope
+ for h_i in cfg.train.get('hooks', []):
+ sig = ('HOOKS', default_group, h_i['type'])
+ LazyImportModule.import_module(sig)
+ if h_i['type'] not in HOOKS._modules[default_group]:
+ if h_i['type'] in [
+ 'TensorboardLoggerHookV2', 'WandbLoggerHookV2'
+ ]:
+ raise ValueError(
+                        'Hook %s is not supported now, we will support it in the future!'
+ % h_i['type'])
+ register_util.register_hook_to_ms(h_i['type'])
+ return cfg
+
def create_optimizer_and_scheduler(self):
""" Create optimizer and lr scheduler
"""
diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py
index 57698f3a..b7ff2bc5 100644
--- a/modelscope/trainers/hooks/checkpoint_hook.py
+++ b/modelscope/trainers/hooks/checkpoint_hook.py
@@ -2,19 +2,16 @@
import os
import random
import re
-from shutil import rmtree
import numpy as np
import torch
from packaging import version
-from modelscope import __version__
from modelscope.metainfo import Hooks, Pipelines
from modelscope.utils.checkpoint import (load_checkpoint, save_checkpoint,
save_configuration)
from modelscope.utils.constant import LogKeys, ModelFile
from modelscope.utils.logger import get_logger
-from modelscope.utils.megatron_utils import is_megatron_initialized
from modelscope.utils.torch_utils import is_master
from .builder import HOOKS
from .hook import Hook
diff --git a/modelscope/trainers/hooks/ddp_hook.py b/modelscope/trainers/hooks/ddp_hook.py
new file mode 100644
index 00000000..eaae2d89
--- /dev/null
+++ b/modelscope/trainers/hooks/ddp_hook.py
@@ -0,0 +1,43 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from modelscope.metainfo import Hooks
+from modelscope.utils.constant import DistributedParallelType
+from modelscope.utils.device import create_device
+from modelscope.utils.torch_utils import get_local_rank, init_dist
+from .builder import HOOKS
+from .hook import Hook
+from .priority import Priority
+
+
+@HOOKS.register_module(module_name=Hooks.DDPHook)
+class DDPHook(Hook):
+
+ PRIORITY = Priority.LOW
+
+ def __init__(self, launcher):
+ """The DDP Hook for data parallel
+
+ Args:
+ launcher(str, required): The launcher info, can be 'pytorch' or 'mpi' or 'slurm'
+ """
+ assert launcher is not None
+ self.launcher = launcher
+ self.wrapped = False
+ # TODO support single GPU evaluate & multi GPU train
+
+ def after_init(self, trainer):
+ init_dist(self.launcher)
+ local_rank = get_local_rank()
+ trainer.device = create_device(f'cuda:{local_rank}')
+ trainer.model.to(trainer.device)
+ trainer.parallel_groups[DistributedParallelType.DP] = None
+
+ def before_run(self, trainer):
+ self.wrap_module(trainer)
+
+ def before_val(self, trainer):
+ self.wrap_module(trainer)
+
+ def wrap_module(self, trainer):
+ if not self.wrapped:
+ trainer.model = trainer.to_parallel(trainer.model)
+ self.wrapped = True
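A short sketch of enabling the new DDPHook; the configuration entry mirrors the constructor above, and the manual registration call is an assumed trainer API shown only for illustration.

from modelscope.trainers.hooks.ddp_hook import DDPHook

# Programmatic equivalent of a {"type": "DDPHook", "launcher": "pytorch"} entry
# in cfg.train.hooks; 'pytorch' is one of the supported launcher values.
ddp_hook = DDPHook(launcher='pytorch')
# trainer.register_hook(ddp_hook)  # assumed registration API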
diff --git a/modelscope/trainers/hooks/deepspeed_hook.py b/modelscope/trainers/hooks/deepspeed_hook.py
index 3f423059..a34b3f6f 100644
--- a/modelscope/trainers/hooks/deepspeed_hook.py
+++ b/modelscope/trainers/hooks/deepspeed_hook.py
@@ -138,6 +138,9 @@ class DeepspeedHook(MegatronHook):
checkpoint, strict=strict)
return meta
+ def before_val(self, trainer):
+ pass
+
def before_run(self, trainer):
if not hasattr(trainer, 'logger'):
self.logger = get_logger()
diff --git a/modelscope/trainers/hooks/hook.py b/modelscope/trainers/hooks/hook.py
index 02ab249d..70e06fbd 100644
--- a/modelscope/trainers/hooks/hook.py
+++ b/modelscope/trainers/hooks/hook.py
@@ -12,20 +12,28 @@ class Hook:
The Hook base class of any modelscope trainer. You can build your own hook inherited from this class.
"""
- stages = (TrainerStages.before_run, TrainerStages.before_train_epoch,
+ stages = (TrainerStages.after_init, TrainerStages.before_run,
+ TrainerStages.before_val, TrainerStages.before_train_epoch,
TrainerStages.before_train_iter, TrainerStages.after_train_iter,
TrainerStages.after_train_epoch, TrainerStages.before_val_epoch,
TrainerStages.before_val_iter, TrainerStages.after_val_iter,
- TrainerStages.after_val_epoch, TrainerStages.after_run)
+ TrainerStages.after_val_epoch, TrainerStages.after_run,
+ TrainerStages.after_val)
PRIORITY = Priority.NORMAL
# The strategic function dict.
_strategies = dict()
+ def after_init(self, trainer):
+ """
+ Will be called at the end of the trainer's `__init__` method
+ """
+ pass
+
def before_run(self, trainer):
"""
- Will be called before any loop begins.
+        Will be called before the trainer loop begins.
Args:
trainer: The trainer instance.
@@ -36,7 +44,29 @@ class Hook:
def after_run(self, trainer):
"""
- Will be called after all loops end.
+        Will be called after the trainer loop ends.
+ Args:
+ trainer: The trainer instance.
+
+ Returns: None
+
+ """
+ pass
+
+ def before_val(self, trainer):
+ """
+        Will be called before the eval loop begins.
+ Args:
+ trainer: The trainer instance.
+
+ Returns: None
+
+ """
+ pass
+
+ def after_val(self, trainer):
+ """
+        Will be called after the eval loop ends.
Args:
trainer: The trainer instance.
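A minimal sketch of a custom hook using the new lifecycle stages added above (after_init, before_val, after_val); the hook name is hypothetical.

from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.hook import Hook
from modelscope.utils.logger import get_logger

logger = get_logger()

@HOOKS.register_module(module_name='ExampleStageHook')  # hypothetical hook name
class ExampleStageHook(Hook):

    def after_init(self, trainer):
        # Runs at the end of the trainer's __init__, before any loop starts.
        logger.info('trainer constructed')

    def before_val(self, trainer):
        # Runs before a standalone evaluation loop begins.
        logger.info('starting evaluation')

    def after_val(self, trainer):
        # Runs after the evaluation loop ends.
        logger.info('evaluation finished')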
diff --git a/modelscope/trainers/hooks/megatron_hook.py b/modelscope/trainers/hooks/megatron_hook.py
index fbb77e1c..601c1cae 100644
--- a/modelscope/trainers/hooks/megatron_hook.py
+++ b/modelscope/trainers/hooks/megatron_hook.py
@@ -1,4 +1,5 @@
import os
+from copy import deepcopy
import torch
from megatron_util import mpu
@@ -6,7 +7,12 @@ from megatron_util import mpu
from modelscope.metainfo import Hooks
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.hook import Hook
+from modelscope.trainers.parallel.builder import build_parallel
from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint
+from modelscope.utils.constant import DistributedParallelType
+from modelscope.utils.device import create_device
+from modelscope.utils.megatron_utils import is_megatron_initialized
+from modelscope.utils.torch_utils import get_local_rank
from .checkpoint_hook import CheckpointHook, LoadCheckpointHook
@@ -15,6 +21,9 @@ class MegatronHook(Hook):
_BIN_FILE_DIR = 'model'
+ def __init__(self):
+ self.wrapped = False
+
def register_strategy(self):
Hook.overload(
name='CheckpointHook.should_save_on_rank',
@@ -31,6 +40,30 @@ class MegatronHook(Hook):
Hook.overload(
name='CheckpointHook.prepare_output', function=self.prepare_output)
+ def after_init(self, trainer):
+ assert is_megatron_initialized()
+ local_rank = get_local_rank()
+ trainer.device = create_device(f'cuda:{local_rank}')
+ trainer.model.to(trainer.device)
+ trainer.parallel_groups[
+ DistributedParallelType.DP] = mpu.get_data_parallel_group()
+ trainer.parallel_groups[DistributedParallelType.
+ TP] = mpu.get_tensor_model_parallel_group()
+ trainer.parallel_groups[DistributedParallelType.
+ PP] = mpu.get_pipeline_model_parallel_group()
+
+ def before_run(self, trainer):
+ self.wrap_module(trainer)
+
+ def before_val(self, trainer):
+ self.wrap_module(trainer)
+
+ def wrap_module(self, trainer):
+ if trainer._dist:
+ if not self.wrapped:
+ trainer.model = trainer.to_parallel(trainer.model)
+ self.wrapped = True
+
def should_save_on_rank(self, trainer):
# TODO
return (not torch.distributed.is_initialized()
diff --git a/modelscope/trainers/lrscheduler/builder.py b/modelscope/trainers/lrscheduler/builder.py
index 3a892001..e1827383 100644
--- a/modelscope/trainers/lrscheduler/builder.py
+++ b/modelscope/trainers/lrscheduler/builder.py
@@ -1,6 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
+import torch
+from packaging import version
+
from modelscope.utils.config import ConfigDict
from modelscope.utils.registry import Registry, build_from_cfg, default_group
@@ -35,7 +38,10 @@ def build_lr_scheduler(cfg: ConfigDict, default_args: dict = None):
def register_torch_lr_scheduler():
from torch.optim import lr_scheduler
- from torch.optim.lr_scheduler import _LRScheduler
+ if version.parse(torch.__version__) < version.parse('2.0.0.dev'):
+ from torch.optim.lr_scheduler import _LRScheduler
+ else:
+ from torch.optim.lr_scheduler import LRScheduler as _LRScheduler
members = inspect.getmembers(lr_scheduler)
diff --git a/modelscope/trainers/nlp/__init__.py b/modelscope/trainers/nlp/__init__.py
index 125e82c6..755e5387 100644
--- a/modelscope/trainers/nlp/__init__.py
+++ b/modelscope/trainers/nlp/__init__.py
@@ -9,13 +9,15 @@ if TYPE_CHECKING:
from .text_ranking_trainer import TextRankingTrainer
from .text_generation_trainer import TextGenerationTrainer
from .sentence_embedding_trainer import SentenceEmbeddingTrainer
+ from .siamese_uie_trainer import SiameseUIETrainer
else:
_import_structure = {
'sequence_classification_trainer': ['SequenceClassificationTrainer'],
'csanmt_translation_trainer': ['CsanmtTranslationTrainer'],
'text_ranking_trainer': ['TextRankingTrainer'],
'text_generation_trainer': ['TextGenerationTrainer'],
- 'sentence_emebedding_trainer': ['SentenceEmbeddingTrainer']
+        'sentence_embedding_trainer': ['SentenceEmbeddingTrainer'],
+ 'siamese_uie_trainer': ['SiameseUIETrainer']
}
import sys
diff --git a/modelscope/trainers/nlp/gpt3_trainer.py b/modelscope/trainers/nlp/gpt3_trainer.py
index 22f244eb..ee5fbfae 100644
--- a/modelscope/trainers/nlp/gpt3_trainer.py
+++ b/modelscope/trainers/nlp/gpt3_trainer.py
@@ -1,13 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
-from typing import Any, Dict, List
+from copy import deepcopy
+from typing import Any, Dict, List, Union
+
+import torch
+from torch import nn
from modelscope.metainfo import Trainers
+from modelscope.models.base import Model, TorchModel
from modelscope.models.nlp import GPT3ForTextGeneration
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
+from modelscope.trainers.parallel.builder import build_parallel
from modelscope.utils.config import Config
+from modelscope.utils.megatron_utils import is_megatron_initialized
@TRAINERS.register_module(module_name=Trainers.gpt3_trainer)
@@ -18,6 +25,29 @@ class GPT3Trainer(NlpEpochBasedTrainer):
cfg.model.rank = int(os.environ.get('RANK', 0))
return cfg
+ def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
+ # config format to reserve custom ddp
+ if self.cfg.get('parallel', None) is not None:
+ dp_cfg = deepcopy(self.cfg['parallel'])
+ dp_cfg.update(
+ dict(module=model, device_ids=[torch.cuda.current_device()]))
+ return build_parallel(dp_cfg)
+
+ dp_cfg = dict(
+ type='DistributedDataParallel',
+ module=model,
+ find_unused_parameters=True,
+ device_ids=[torch.cuda.current_device()])
+
+ if is_megatron_initialized():
+ from megatron_util import mpu
+ dp_cfg.update({
+ 'output_device': torch.cuda.current_device(),
+ 'process_group': mpu.get_data_parallel_group()
+ })
+
+ return build_parallel(dp_cfg)
+
def _decode(self, tokens):
tokenizer = self.eval_preprocessor.tokenizer
return tokenizer.detokenize(tokens.tolist())
@@ -51,3 +81,7 @@ class GPT3Trainer(NlpEpochBasedTrainer):
def _forward_eval(self, model: GPT3ForTextGeneration,
data: Dict[str, Any]) -> Dict[str, Any]:
return model.forward(data)
+
+ def build_model(self) -> TorchModel:
+ return Model.from_pretrained(
+ self.model_dir, cfg_dict=self.cfg, megatron_cfg=self.cfg.megatron)
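A sketch of the optional cfg.parallel override consumed by to_parallel above; the values are illustrative, with DistributedDataParallel spelled out as the wrapper type.

# Hypothetical 'parallel' section of configuration.json, as a Python dict:
parallel_cfg = {
    'type': 'DistributedDataParallel',
    'find_unused_parameters': True,
}
# to_parallel() deep-copies cfg['parallel'], injects `module` and `device_ids`,
# and builds the wrapper via build_parallel(); without this key it falls back to
# the default DDP config and, when megatron is initialized, adds the megatron
# data-parallel process group and output_device.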
diff --git a/modelscope/trainers/nlp/siamese_uie_trainer.py b/modelscope/trainers/nlp/siamese_uie_trainer.py
new file mode 100644
index 00000000..e3289976
--- /dev/null
+++ b/modelscope/trainers/nlp/siamese_uie_trainer.py
@@ -0,0 +1,377 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import random
+import time
+from collections import defaultdict
+from math import ceil
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import json
+import numpy as np
+import torch
+from torch import distributed as dist
+from torch import nn
+from torch.utils.data import Dataset
+
+from modelscope.metainfo import Trainers
+from modelscope.models.base import TorchModel
+from modelscope.msdatasets import MsDataset
+from modelscope.pipelines import pipeline
+from modelscope.preprocessors.base import Preprocessor
+from modelscope.trainers import EpochBasedTrainer, NlpEpochBasedTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.trainers.optimizer.builder import build_optimizer
+from modelscope.utils.config import Config
+from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModeKeys, Tasks
+from modelscope.utils.file_utils import func_receive_dict_inputs
+from modelscope.utils.logger import get_logger
+from ..parallel.utils import is_parallel
+
+PATH = None
+logger = get_logger(PATH)
+
+os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+
+
+@TRAINERS.register_module(module_name=Trainers.siamese_uie_trainer)
+class SiameseUIETrainer(EpochBasedTrainer):
+
+ def __init__(
+ self,
+ model: Optional[Union[TorchModel, nn.Module, str]] = None,
+ cfg_file: Optional[str] = None,
+ cfg_modify_fn: Optional[Callable] = None,
+ train_dataset: Optional[Union[MsDataset, Dataset]] = None,
+ eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
+ preprocessor: Optional[Union[Preprocessor,
+ Dict[str, Preprocessor]]] = None,
+ optimizers: Tuple[torch.optim.Optimizer,
+ torch.optim.lr_scheduler._LRScheduler] = (None,
+ None),
+ model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
+ seed: int = 42,
+ negative_sampling_rate=1,
+ slide_len=352,
+ max_len=384,
+ hint_max_len=128,
+ **kwargs):
+ """Epoch based Trainer, a training helper for PyTorch.
+
+ Args:
+ model (:obj:`torch.nn.Module` or :obj:`TorchModel` or `str`): The model to be run, or a valid model dir
+ or a model id. If model is None, build_model method will be called.
+ cfg_file(str): The local config file.
+            cfg_modify_fn (Callable, *optional*): A function used to modify the loaded config.
+ train_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*):
+ The dataset to use for training.
+
+ Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
+                distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
+ `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
+ manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
+ sets the seed of the RNGs used.
+ eval_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): The dataset to use for evaluation.
+ preprocessor (:obj:`Preprocessor`, *optional*): The optional preprocessor.
+ NOTE: If the preprocessor has been called before the dataset fed into this
+ trainer by user's custom code,
+ this parameter should be None, meanwhile remove the 'preprocessor' key from the cfg_file.
+ Else the preprocessor will be instantiated from the cfg_file or assigned from this parameter and
+ this preprocessing action will be executed every time the dataset's __getitem__ is called.
+ optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`, *optional*): A tuple
+ containing the optimizer and the scheduler to use.
+ model_revision (str): The model version to use in modelhub.
+            negative_sampling_rate (float): The probability of keeping each candidate negative sample.
+            slide_len (int): The sliding-window stride used to split inputs longer than max_len.
+            max_len (int): The max token length of prompt + text per window.
+            hint_max_len (int): The max token length of the prompt.
+ seed (int): The optional random seed for torch, cuda, numpy and random.
+ """
+ self.slide_len = slide_len
+ self.max_len = max_len
+ self.hint_max_len = hint_max_len
+ self.negative_sampling_rate = negative_sampling_rate
+
+ super().__init__(
+ model=model,
+ cfg_file=cfg_file,
+ cfg_modify_fn=cfg_modify_fn,
+ data_collator=self._nn_collate_fn,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ preprocessor=preprocessor,
+ optimizers=optimizers,
+ model_revision=model_revision,
+ seed=seed,
+ **kwargs)
+
+ def build_dataset(self,
+ datasets: Union[torch.utils.data.Dataset, MsDataset,
+ List[torch.utils.data.Dataset]],
+ model_cfg: Config,
+ mode: str,
+ preprocessor: Optional[Preprocessor] = None,
+ **kwargs):
+ if mode == ModeKeys.TRAIN:
+ datasets = self.load_dataset(datasets)
+ return super(SiameseUIETrainer, self).build_dataset(
+ datasets=datasets,
+ model_cfg=self.cfg,
+ mode=mode,
+ preprocessor=preprocessor,
+ **kwargs)
+
+ def get_train_dataloader(self):
+ """ Builder torch dataloader for training.
+
+        We provide a reasonable default that works well. If you want to use something else, you can change
+        the config for data.train in the configuration file, or subclass and override this method
+        (or `get_train_dataloader` in a subclass).
+ """
+ self.train_dataset.preprocessor = None
+ data_loader = self._build_dataloader_with_dataset(
+ self.train_dataset,
+ dist=self._dist,
+ seed=self._seed,
+ collate_fn=self.train_data_collator,
+ **self.cfg.train.get('dataloader', {}))
+ return data_loader
+
+ def get_brother_type_map(self, schema, brother_type_map, prefix_types):
+ if not schema:
+ return
+ for k in schema:
+ brother_type_map[tuple(prefix_types
+ + [k])] += [v for v in schema if v != k]
+ self.get_brother_type_map(schema[k], brother_type_map,
+ prefix_types + [k])
+
+ def load_dataset(self, raw_dataset):
+ data = []
+ for num_line, raw_sample in enumerate(raw_dataset):
+ raw_sample['info_list'] = json.loads(raw_sample['info_list'])
+ raw_sample['schema'] = json.loads(raw_sample['schema'])
+ hint_spans_map = defaultdict(list)
+ # positive sampling
+ for info in raw_sample['info_list']:
+ hint = ''
+ for item in info:
+ hint += f'{item["type"]}: '
+ span = {'span': item['span'], 'offset': item['offset']}
+ if span not in hint_spans_map[hint]:
+ hint_spans_map[hint].append(span)
+ hint += f'{item["span"]}, '
+ # negative sampling
+ brother_type_map = defaultdict(list)
+ self.get_brother_type_map(raw_sample['schema'], brother_type_map,
+ [])
+
+ for info in raw_sample['info_list']:
+ hint = ''
+ for i, item in enumerate(info):
+ key = tuple([info[j]['type'] for j in range(i + 1)])
+ for st in brother_type_map.get(key, []):
+ neg_hint = hint + f'{st}: '
+ if neg_hint not in hint_spans_map and random.random(
+ ) < self.negative_sampling_rate:
+ hint_spans_map[neg_hint] = []
+ hint += f'{item["type"]}: '
+ hint += f'{item["span"]}, '
+            # handle the case where info_list is empty
+ for k in raw_sample['schema']:
+ neg_hint = f'{k}: '
+ if neg_hint not in hint_spans_map and random.random(
+ ) < self.negative_sampling_rate:
+ hint_spans_map[neg_hint] = []
+
+ for i, hint in enumerate(hint_spans_map):
+ sample = {
+ 'id': f'{raw_sample["id"]}-{i}',
+ 'hint': hint,
+ 'text': raw_sample['text'],
+ 'spans': hint_spans_map[hint]
+ }
+ uuid = sample['id']
+ text = sample['text']
+ tokenized_input = self.train_preprocessor([text])[0]
+ tokenized_hint = self.train_preprocessor(
+ [hint], max_length=self.hint_max_len, truncation=True)[0]
+ sample['offsets'] = tokenized_input.offsets
+ entities = sample.get('spans', [])
+ head_labels, tail_labels = self._get_labels(
+ text, tokenized_input, sample['offsets'], entities)
+
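+ # Split long inputs into windows of at most max_len tokens, moving the window by slide_len each step.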
+ split_num = ceil(
+ (len(tokenized_input) - self.max_len) / self.slide_len
+ ) + 1 if len(tokenized_input) > self.max_len else 1
+ for j in range(split_num):
+ a, b = j * self.slide_len, j * self.slide_len + self.max_len
+ item = {
+ 'id': uuid,
+ 'shift': a,
+ 'tokens': tokenized_input.tokens[a:b],
+ 'token_ids': tokenized_input.ids[a:b],
+ 'hint_tokens': tokenized_hint.tokens,
+ 'hint_token_ids': tokenized_hint.ids,
+ 'attention_masks': tokenized_input.attention_mask[a:b],
+ 'cross_attention_masks': tokenized_hint.attention_mask,
+ 'head_labels': head_labels[a:b],
+ 'tail_labels': tail_labels[a:b]
+ }
+ data.append(item)
+
+ from datasets import Dataset
+ train_dataset = Dataset.from_list(data)
+ for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
+ logger.info(
+ f'Sample {index} of the training set: {train_dataset[index]}.')
+ return train_dataset
+
+ def _get_labels(self, text, tokenized_input, offsets, entities):
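+ """Build binary pointer labels over tokens: head_labels[i] = 1 if token i starts a gold span,
+ tail_labels[i] = 1 if token i ends one."""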
+ num_tokens = len(tokenized_input)
+ head_labels = [0] * num_tokens
+ tail_labels = [0] * num_tokens
+ char_index_to_token_index_map = {}
+ for i in range(len(offsets)):
+ offset = offsets[i]
+ for j in range(offset[0], offset[1]):
+ char_index_to_token_index_map[j] = i
+ for e in entities:
+ h, t = e['offset']
+ t -= 1
+ while h not in char_index_to_token_index_map:
+ h += 1
+ if h > len(text):
+ logger.warning('head char index not aligned to any token: %s %s %s',
+ e['offset'], e['span'], text[e['offset'][0]:e['offset'][1]])
+ break
+ while t not in char_index_to_token_index_map:
+ t -= 1
+ if t < 0:
+ logger.warning('tail char index not aligned to any token: %s %s %s',
+ e['offset'], e['span'], text[e['offset'][0]:e['offset'][1]])
+ break
+ if h > len(text) or t < 0:
+ continue
+ token_head = char_index_to_token_index_map[h]
+ token_tail = char_index_to_token_index_map[t]
+ head_labels[token_head] = 1
+ tail_labels[token_tail] = 1
+ return head_labels, tail_labels
+
+ def _padding(self, data, val=0):
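+ """Right-pad every sequence in `data` to self.max_len with `val`."""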
+ res = []
+ for seq in data:
+ res.append(seq + [val] * (self.max_len - len(seq)))
+ return res
+
+ def _nn_collate_fn(self, batch):
+ token_ids = torch.tensor(
+ self._padding([item['token_ids'] for item in batch]),
+ dtype=torch.long)
+ hint_token_ids = torch.tensor(
+ self._padding([item['hint_token_ids'] for item in batch]),
+ dtype=torch.long)
+ attention_masks = torch.tensor(
+ self._padding([item['attention_masks'] for item in batch]),
+ dtype=torch.long)
+ cross_attention_masks = torch.tensor(
+ self._padding([item['cross_attention_masks'] for item in batch]),
+ dtype=torch.long)
+ head_labels = torch.tensor(
+ self._padding([item['head_labels'] for item in batch]),
+ dtype=torch.float)
+ tail_labels = torch.tensor(
+ self._padding([item['tail_labels'] for item in batch]),
+ dtype=torch.float)
+ # `batch` is a list of per-sample dicts with token ids, hint token ids, attention masks and head/tail labels
+ # for fp16 acceleration, truncate seq_len to multiples of 8
+ batch_max_len = token_ids.gt(0).sum(dim=-1).max().item()
+ batch_max_len += (8 - batch_max_len % 8) % 8
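+ # e.g. a batch max length of 45 is rounded up to 48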
+ truncate_len = min(self.max_len, batch_max_len)
+ token_ids = token_ids[:, :truncate_len]
+ attention_masks = attention_masks[:, :truncate_len]
+ head_labels = head_labels[:, :truncate_len]
+ tail_labels = tail_labels[:, :truncate_len]
+
+ # for fp16 acceleration, truncate seq_len to multiples of 8
+ batch_max_len = hint_token_ids.gt(0).sum(dim=-1).max().item()
+ batch_max_len += (8 - batch_max_len % 8) % 8
+ hint_truncate_len = min(self.hint_max_len, batch_max_len)
+ hint_token_ids = hint_token_ids[:, :hint_truncate_len]
+ cross_attention_masks = cross_attention_masks[:, :hint_truncate_len]
+
+ return {
+ 'input_ids': token_ids,
+ 'attention_masks': attention_masks,
+ 'hint_ids': hint_token_ids,
+ 'cross_attention_masks': cross_attention_masks,
+ 'head_labels': head_labels,
+ 'tail_labels': tail_labels
+ }
+
+ def evaluate(self,
+ checkpoint_path: Optional[str] = None,
+ *args,
+ **kwargs) -> Dict[str, float]:
+ """evaluate a dataset
+
+ evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path`
+ does not exist, read from the config file.
+
+ Args:
+ checkpoint_path (Optional[str], optional): the model path. Defaults to None.
+
+ Returns:
+ Dict[str, float]: the results about the evaluation
+ Example:
+ {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
+ """
+ pipeline_uie = pipeline(Tasks.siamese_uie, self.model)
+ if checkpoint_path is not None and os.path.isfile(checkpoint_path):
+ from modelscope.trainers.hooks import LoadCheckpointHook
+ LoadCheckpointHook.load_checkpoint(checkpoint_path, self)
+ self.model.eval()
+ self._mode = ModeKeys.EVAL
+ self.eval_dataloader = self.train_dataloader
+ num_pred = num_recall = num_correct = 1e-10
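+ # Start from a tiny epsilon so that precision/recall never divide by zero.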
+ self.eval_dataset.preprocessor = None
+ for sample in self.eval_dataset:
+ text = sample['text']
+ schema = json.loads(sample['schema'])
+ gold_info_list = json.loads(sample['info_list'])
+ pred_info_list = pipeline_uie(input=text, schema=schema)['output']
+ pred_info_list_set = set([str(item) for item in pred_info_list])
+ gold_info_list_set = set([str(item) for item in gold_info_list])
+ a, b, c = len(pred_info_list_set), len(gold_info_list_set), len(
+ pred_info_list_set.intersection(gold_info_list_set))
+ num_pred += a
+ num_recall += b
+ num_correct += c
+ precision, recall, f1 = self.compute_metrics(num_pred, num_recall,
+ num_correct)
+ return {'precision': precision, 'recall': recall, 'f1': f1}
+
+ def get_metrics(self) -> List[Union[str, Dict]]:
+ """Get the metric class types.
+
+ The first choice will be the metrics configured in the config file, if not found, the default metrics will be
+ used.
+ If no metrics is found and the eval dataset exists, the method will raise an error.
+
+ Returns: The metric types.
+
+ """
+ return self.compute_metrics
+
+ def compute_metrics(self, num_pred, num_recall, num_correct):
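+ """Micro-averaged metrics: precision = correct / predicted, recall = correct / gold, f1 = their harmonic mean.
+ Illustrative example: num_pred=10, num_recall=8, num_correct=6 -> precision 0.6, recall 0.75, f1 ~ 0.667."""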
+ if num_pred == num_recall == 1e-10:
+ return 1, 1, 1
+ precision = num_correct / float(num_pred)
+ recall = num_correct / float(num_recall)
+ f1 = 2 * precision * recall / (precision + recall)
+ if num_correct == 1e-10:
+ return 0, 0, 0
+ return precision, recall, f1
diff --git a/modelscope/trainers/nlp/text_generation_trainer.py b/modelscope/trainers/nlp/text_generation_trainer.py
index fa6a448f..0021f7fc 100644
--- a/modelscope/trainers/nlp/text_generation_trainer.py
+++ b/modelscope/trainers/nlp/text_generation_trainer.py
@@ -22,12 +22,7 @@ class TextGenerationTrainer(NlpEpochBasedTrainer):
model.eval()
with torch.no_grad():
- if isinstance(
- data,
- Mapping) and not func_receive_dict_inputs(model.generate):
- result = model.generate(**data)
- else:
- result = model.generate(data)
+ result = model.generate(data)
result['preds'] = [self._decode(seq) for seq in result['sequences']]
data['tgts'] = [self._decode(seq) for seq in data['labels']]
diff --git a/modelscope/trainers/nlp_trainer.py b/modelscope/trainers/nlp_trainer.py
index bbdd080f..455fc907 100644
--- a/modelscope/trainers/nlp_trainer.py
+++ b/modelscope/trainers/nlp_trainer.py
@@ -150,7 +150,7 @@ class VecoTrainer(NlpEpochBasedTrainer):
"""Veco evaluates the datasets one by one.
"""
- from modelscope.msdatasets.task_datasets import VecoDataset
+ from modelscope.msdatasets.dataset_cls.custom_datasets import VecoDataset
if checkpoint_path is not None:
from modelscope.trainers.hooks import LoadCheckpointHook
LoadCheckpointHook.load_checkpoint(checkpoint_path, self)
@@ -159,9 +159,10 @@ class VecoTrainer(NlpEpochBasedTrainer):
metric_values = {}
if self.eval_dataset is None:
- val_data = self.cfg.dataset.val
- self.eval_dataset = self.build_dataset(
- val_data, mode=ModeKeys.EVAL)
+ self.eval_dataset = self.build_dataset_from_cfg(
+ model_cfg=self.cfg,
+ mode=self._mode,
+ preprocessor=self.eval_preprocessor)
idx = 0
dataset_cnt = 1
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index 843b1c2f..a21e154c 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -20,10 +20,11 @@ from modelscope.metrics import build_metric, task_default_metrics
from modelscope.metrics.prediction_saving_wrapper import \
PredictionSavingWrapper
from modelscope.models.base import Model, TorchModel
+from modelscope.msdatasets.dataset_cls.custom_datasets import \
+ TorchCustomDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
+ build_custom_dataset
from modelscope.msdatasets.ms_dataset import MsDataset
-from modelscope.msdatasets.task_datasets.builder import build_task_dataset
-from modelscope.msdatasets.task_datasets.torch_base_dataset import \
- TorchTaskDataset
from modelscope.outputs import ModelOutputBase
from modelscope.preprocessors.base import Preprocessor
from modelscope.trainers.hooks.builder import HOOKS
@@ -32,20 +33,20 @@ from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
from modelscope.trainers.optimizer.builder import build_optimizer
from modelscope.utils.config import Config, ConfigDict, JSONIteratorEncoder
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
- ConfigKeys, ModeKeys, ModelFile,
- ThirdParty, TrainerStages)
+ ConfigKeys, DistributedParallelType,
+ ModeKeys, ModelFile, ThirdParty,
+ TrainerStages)
from modelscope.utils.data_utils import to_device
from modelscope.utils.device import create_device
from modelscope.utils.file_utils import func_receive_dict_inputs
from modelscope.utils.logger import get_logger
-from modelscope.utils.megatron_utils import is_megatron_initialized
from modelscope.utils.registry import build_from_cfg
-from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
- init_dist, is_dist, is_master,
- set_random_seed)
+from modelscope.utils.torch_utils import (compile_model, get_dist_info,
+ get_local_rank, init_dist, is_dist,
+ is_master, set_random_seed)
from .base import BaseTrainer
from .builder import TRAINERS
-from .default_config import merge_cfg, merge_hooks
+from .default_config import merge_cfg, merge_hooks, update_cfg
from .hooks.hook import Hook
from .parallel.builder import build_parallel
from .parallel.utils import is_parallel
@@ -83,6 +84,9 @@ class EpochBasedTrainer(BaseTrainer):
remove_unused_data: Automatically remove unused data keys in mini-batches.
The remove action based on the `inspect` on the model's forward method, the removed columns will be
moved to the mini-batch's attributes.
+ compile (bool, optional): Compile the model with torch 2.0, default False
+ compile_options (dict, optional): The compile options if compile=True,
+ default None to use the default params of 'TorchModel.compile'.
Examples of cfg_modify_fn:
>>> def cfg_modify_fn(cfg):
@@ -108,6 +112,7 @@ class EpochBasedTrainer(BaseTrainer):
None),
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
seed: int = 42,
+ callbacks: Optional[List[Hook]] = None,
**kwargs):
self._seed = seed
@@ -120,6 +125,11 @@ class EpochBasedTrainer(BaseTrainer):
self._iter = 0
self._inner_iter = 0
self._stop_training = False
+ self._compile = kwargs.get('compile', False)
+
+ self.train_dataloader = None
+ self.eval_dataloader = None
+ self.data_loader = None
if isinstance(model, str):
third_party = kwargs.get(ThirdParty.KEY)
@@ -142,12 +152,20 @@ class EpochBasedTrainer(BaseTrainer):
self.cfg = self.rebuild_config(self.cfg)
if 'cfg_options' in kwargs:
self.cfg.merge_from_dict(kwargs['cfg_options'])
+ self.cfg = update_cfg(self.cfg)
if isinstance(model, (TorchModel, nn.Module)):
self.model = model
else:
self.model = self.build_model()
+ if self._compile:
+ # Compile the model with torch 2.0
+ compile_options = kwargs.get('compile_options')
+ if compile_options is None:
+ compile_options = {}
+ self.model = compile_model(self.model, **compile_options)
+
if 'work_dir' in kwargs:
self.work_dir = kwargs['work_dir']
else:
@@ -156,46 +174,33 @@ class EpochBasedTrainer(BaseTrainer):
self.train_preprocessor, self.eval_preprocessor = self.get_preprocessors(
preprocessor)
- self._dist = self.init_dist(kwargs.get('launcher'))
-
- if is_master() and not os.path.exists(self.work_dir):
- os.makedirs(self.work_dir)
-
- self.device = self.get_device(kwargs.get('device'))
+ if not os.path.exists(self.work_dir):
+ # TODO duplicate makedirs may cause errors in dlc envs.
+ os.makedirs(self.work_dir, exist_ok=True)
# init logger after distribution init
log_file = os.path.join(self.work_dir, '{}.log'.format(self.timestamp))
self.logger = get_logger(
log_file=log_file, log_level=self.cfg.get('log_level', 'INFO'))
- if is_master():
- self.logger.info(
- '==========================Training Config Start=========================='
- )
- self.logger.info(
- json.dumps(
- self.cfg._cfg_dict, indent=4, cls=JSONIteratorEncoder))
- self.logger.info(
- '===========================Training Config End==========================='
- )
-
- self.train_dataset = self.to_task_dataset(
- train_dataset,
+ # Get train datasets
+ self.train_dataset = self.build_dataset(
+ datasets=train_dataset,
+ model_cfg=self.cfg,
mode=ModeKeys.TRAIN,
- task_data_config=self.cfg.safe_get('dataset.train'),
preprocessor=self.train_preprocessor,
**kwargs)
- self.eval_dataset = self.to_task_dataset(
- eval_dataset,
+ # Get evaluation datasets
+ self.eval_dataset = self.build_dataset(
+ datasets=eval_dataset,
+ model_cfg=self.cfg,
mode=ModeKeys.EVAL,
- task_data_config=self.cfg.safe_get('dataset.val'),
preprocessor=self.eval_preprocessor,
**kwargs)
self.train_data_collator, self.eval_data_collator = self.get_data_collator(
data_collator,
remove_unused_data=kwargs.get('remove_unused_data', False))
- self.metrics = self.get_metrics()
self._max_epochs = kwargs.get('max_epochs',
self.cfg.safe_get('train.max_epochs'))
assert self._max_epochs is not None, 'max_epochs should be provided by the init arguments or configured ' \
@@ -207,11 +212,51 @@ class EpochBasedTrainer(BaseTrainer):
'val_iters_per_epoch',
self.cfg.safe_get('evaluation.val_iters_per_epoch'))
self.use_fp16 = kwargs.get('use_fp16', False)
- # model placement
- self.place_model()
+ self.launcher = kwargs.get('launcher')
+ self.device = kwargs.get('device')
+ # The parallel_groups field will be initialized in the hooks' after_init stage.
+ # Please check the DDPHook and MegatronHook for details.
+ self.parallel_groups = {}
+
+ # Clear the Hook overload functions to avoid duplication.
Hook.clear_strategies()
+ if self.launcher is not None and not self.cfg.safe_get(
+ 'train.hooks.DDPHook'):
+ # Compatibility logic for the existing code path:
+ # automatically register a DDPHook when a launcher is provided.
+ if 'hooks' not in self.cfg.train:
+ self.cfg.train['hooks'] = ConfigDict([])
+ self.cfg.train['hooks'].append({
+ 'type': 'DDPHook',
+ 'launcher': self.launcher
+ })
+
+ hooks = merge_hooks(self.cfg)
+ self.register_hook_from_cfg(hooks)
+ # Add user callback to hooks
+ if callbacks is not None and not isinstance(callbacks, (list, tuple)):
+ callbacks = [callbacks]
+ for callback in callbacks or []:
+ self.register_hook(callback)
+ self.invoke_hook(TrainerStages.after_init)
+
+ # _dist indicates whether data parallelism is initialized with a world size > 1
+ self._dist = self.is_dp_group_available() and dist.get_world_size(
+ self.dp_group) > 1
+
+ self.metrics = self.get_metrics()
+
+ if not self.parallel_groups:
+ # If not working in parallel scenario, put model to device as a default logic.
+ device_name = self.device if self.device is not None else 'gpu'
+ self.device = create_device(device_name)
+ if self.device.type == 'cuda':
+ self.model.to(self.device)
+
+ self.print_cfg()
+
def place_model(self):
"""Place model to device, or to DDP
"""
@@ -330,6 +375,45 @@ class EpochBasedTrainer(BaseTrainer):
cfg = self.cfg_modify_fn(cfg)
return cfg
+ @property
+ def dp_group(self):
+ """
+ Get the data parallel group.
+ """
+ return self.parallel_groups[DistributedParallelType.DP]
+
+ @property
+ def tp_group(self):
+ """
+ Get the tensor parallel group.
+ """
+ return self.parallel_groups[DistributedParallelType.TP]
+
+ @property
+ def pp_group(self):
+ """
+ Get the pipeline parallel group.
+ """
+ return self.parallel_groups[DistributedParallelType.PP]
+
+ def is_dp_group_available(self):
+ """
+ Get whether the data parallel group is initialized.
+ """
+ return DistributedParallelType.DP in self.parallel_groups
+
+ def is_tp_group_available(self):
+ """
+ Get whether the tensor parallel group is initialized.
+ """
+ return DistributedParallelType.TP in self.parallel_groups
+
+ def is_pp_group_available(self):
+ """
+ Get whether the pipeline parallel group is initialized.
+ """
+ return DistributedParallelType.PP in self.parallel_groups
+
@property
def mode(self):
return self._mode
@@ -389,85 +473,125 @@ class EpochBasedTrainer(BaseTrainer):
else:
return _get_data_len(self.eval_dataloader)
- def to_task_dataset(self,
- datasets: Union[Dataset, List[Dataset]],
- mode: str,
- task_data_config: Config = None,
- preprocessor: Optional[Preprocessor] = None,
- **kwargs):
- """Build the task specific dataset processor for this trainer.
+ def build_dataset(self,
+ datasets: Union[Dataset, MsDataset, List[Dataset]],
+ model_cfg: Config,
+ mode: str,
+ preprocessor: Optional[Preprocessor] = None,
+ **kwargs):
+ """Build input datasets by given model configuration and preprocessor.
- Returns: The task dataset processor for the task. If no result for the very model-type and task,
- the default TaskDataset will be returned.
+ Args:
+ datasets (Union[Dataset, MsDataset, List[Dataset]]): The input datasets.
+ model_cfg (Config): The model configuration.
+ mode (str): `train`, `eval` or `inference`. See modelscope.utils.constant.ModeKeys
+ preprocessor (Preprocessor, Optional): The preprocessor for input data samples.
+
+ Returns:
+ Preprocessed datasets.
"""
try:
- to_tensor = kwargs.get('to_tensor', True)
if not datasets:
- return datasets
- if isinstance(datasets, TorchTaskDataset):
+ return EpochBasedTrainer.build_dataset_from_cfg(
+ model_cfg=model_cfg, mode=mode, preprocessor=preprocessor)
+
+ if isinstance(datasets, TorchCustomDataset):
return datasets
elif isinstance(datasets, MsDataset):
- if task_data_config is None:
- # adapt to some special models
- task_data_config = ConfigDict(
- type=self.cfg.model.type) if hasattr(
- self.cfg, ConfigFields.model) else ConfigDict(
- type=None)
- task_data_config.update(dict(mode=mode))
- return datasets.to_torch_dataset(
- task_data_config=task_data_config,
- task_name=self.cfg.task,
- preprocessors=preprocessor,
- to_tensor=to_tensor)
+ if not datasets.is_custom:
+ datasets.to_custom_dataset(
+ custom_cfg=model_cfg,
+ preprocessor=preprocessor,
+ mode=mode,
+ **kwargs)
+ return datasets.ds_instance
elif isinstance(datasets, List) and isinstance(
datasets[0], MsDataset):
- if task_data_config is None:
- # adapt to some special models
- task_data_config = ConfigDict(
- type=self.cfg.model.type) if hasattr(
- self.cfg, ConfigFields.model) else ConfigDict(
- type=None)
- task_data_config.update(dict(mode=mode))
- datasets = [
- d.to_torch_dataset(
- task_data_config=task_data_config,
- task_name=self.cfg.task,
- preprocessors=preprocessor,
- to_tensor=to_tensor) for d in datasets
- ]
- cfg = ConfigDict(
- type=self.cfg.model.type, mode=mode, datasets=datasets)
- task_dataset = build_task_dataset(cfg, self.cfg.task)
- task_dataset.trainer = self
- return task_dataset
+ custom_datasets = []
+ for dataset in datasets:
+ if not dataset.is_custom:
+ dataset.to_custom_dataset(
+ custom_cfg=model_cfg,
+ preprocessor=preprocessor,
+ mode=mode,
+ **kwargs)
+ custom_datasets.append(dataset.ds_instance)
+ torch_custom_dataset = TorchCustomDataset(
+ datasets=custom_datasets,
+ mode=mode,
+ preprocessor=None,
+ **kwargs)
+ torch_custom_dataset.trainer = self
+ return torch_custom_dataset
else:
- if task_data_config is None:
+ dataset_mode_key = 'train' if mode == ModeKeys.TRAIN else 'val'
+ data_config = model_cfg.safe_get(f'dataset.{dataset_mode_key}')
+ if data_config is None:
# adapt to some special models
- task_data_config = {}
+ data_config = {}
# avoid add no str value datasets, preprocessors in cfg
- task_data_build_config = ConfigDict(
- type=self.cfg.model.type,
+ data_build_config = ConfigDict(
+ type=model_cfg.model.type,
mode=mode,
datasets=datasets,
preprocessor=preprocessor)
- task_data_build_config.update(task_data_config)
- task_dataset = build_task_dataset(task_data_build_config,
- self.cfg.task)
- task_dataset.trainer = self
- return task_dataset
- except Exception:
+ data_build_config.update(data_config)
+ custom_dataset = build_custom_dataset(data_build_config,
+ model_cfg.task)
+ custom_dataset.trainer = self
+ return custom_dataset
+ except Exception as e:
+ self.logger.error('build_dataset error: %s', e)
if isinstance(datasets, (List, Tuple)) or preprocessor is not None:
- task_dataset = TorchTaskDataset(
+ custom_dataset = TorchCustomDataset(
datasets,
mode=mode,
preprocessor=preprocessor,
- **(dict(type=self.cfg.model.type) if hasattr(
- self.cfg, 'model') else {}))
- task_dataset.trainer = self
- return task_dataset
+ **(dict(type=model_cfg.model.type) if hasattr(
+ model_cfg, 'model') else {}))
+ custom_dataset.trainer = self
+ return custom_dataset
else:
return datasets
+ def to_task_dataset(self, dataset: Dataset, mode: str,
+ preprocessor: Preprocessor,
+ **kwargs) -> TorchCustomDataset:
+ r"""
+ @deprecated
+ This method is deprecated and may be removed in future releases; please use `build_dataset()` instead.
+ It is kept only for compatibility with subclasses that override `to_task_dataset`.
+ """
+ self.logger.warning(
+ 'This to_task_dataset method is deprecated, please use build_dataset instead.'
+ )
+
+ task_dataset = TorchCustomDataset(
+ dataset, mode=mode, preprocessor=preprocessor, **kwargs)
+ task_dataset.trainer = self
+ return task_dataset
+
+ @staticmethod
+ def build_dataset_from_cfg(model_cfg: Config,
+ mode: str,
+ preprocessor: Preprocessor = None):
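+ # Build a dataset purely from the cfg's `dataset` section (name / subset / split_{mode});
+ # returns None if the name or split is not configured.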
+ dataset = None
+ dataset_name = model_cfg.safe_get('dataset.name')
+ subset_name = model_cfg.safe_get('dataset.subset', default='default')
+ split_name = model_cfg.safe_get(f'dataset.split_{mode}')
+ if not dataset_name or not split_name:
+ return dataset
+ dataset = MsDataset.load(
+ dataset_name=dataset_name,
+ subset_name=subset_name,
+ split=split_name,
+ custom_cfg=model_cfg)
+ if not dataset.is_custom:
+ dataset.to_custom_dataset(
+ custom_cfg=model_cfg, preprocessor=preprocessor, mode=mode)
+
+ return dataset.ds_instance
+
def build_preprocessor(self) -> Tuple[Preprocessor, Preprocessor]:
"""Build train and eval preprocessor.
@@ -544,10 +668,7 @@ class EpochBasedTrainer(BaseTrainer):
self.train_dataloader = self.get_train_dataloader()
self.data_loader = self.train_dataloader
self.register_optimizers_hook()
- hooks = merge_hooks(self.cfg)
- self.register_hook_from_cfg(hooks)
- if is_master():
- self.logger.info(self.get_hook_info())
+ self.print_hook_info()
self.set_checkpoint_file_to_hook(checkpoint_path, load_all_state,
kwargs.get('strict', False))
self.model.train()
@@ -586,18 +707,14 @@ class EpochBasedTrainer(BaseTrainer):
strict(`boolean`): If strict, any unmatched keys will cause an error.
"""
- if not self._hooks:
- hooks = merge_hooks(self.cfg)
- self.register_hook_from_cfg(hooks)
- if is_master():
- self.logger.info(self.get_hook_info())
+ self.print_hook_info()
if checkpoint_path is not None:
from modelscope.trainers.hooks import LoadCheckpointHook
LoadCheckpointHook.load_checkpoint(
checkpoint_path, self, strict=strict)
self.model.eval()
self._mode = ModeKeys.EVAL
- predict_dataloader = self.get_predict_data_loader(predict_datasets)
+ predict_dataloader = self.get_predict_dataloader(predict_datasets)
metric_classes = [PredictionSavingWrapper(saving_fn=saving_fn)]
for m in metric_classes:
@@ -628,11 +745,7 @@ class EpochBasedTrainer(BaseTrainer):
kwargs:
strict(`boolean`): If strict, any unmatched keys will cause an error.
"""
- if not self._hooks:
- hooks = merge_hooks(self.cfg)
- self.register_hook_from_cfg(hooks)
- if is_master():
- self.logger.info(self.get_hook_info())
+ self.print_hook_info()
if checkpoint_path is not None:
from modelscope.trainers.hooks import LoadCheckpointHook
LoadCheckpointHook.load_checkpoint(
@@ -682,14 +795,8 @@ class EpochBasedTrainer(BaseTrainer):
type='DistributedDataParallel',
module=model,
find_unused_parameters=True,
- device_ids=[torch.cuda.current_device()])
-
- if is_megatron_initialized():
- from megatron_util import mpu
- dp_cfg.update({
- 'output_device': torch.cuda.current_device(),
- 'process_group': mpu.get_data_parallel_group()
- })
+ device_ids=[torch.cuda.current_device()],
+ process_group=self.dp_group)
return build_parallel(dp_cfg)
@@ -776,12 +883,7 @@ class EpochBasedTrainer(BaseTrainer):
(or `get_train_dataloader` in a subclass.
"""
if self.train_dataset is None:
- train_data = self.cfg.dataset.train
- self.train_dataset = self.build_dataset(
- train_data,
- mode=ModeKeys.TRAIN,
- preprocessor=self.train_preprocessor)
-
+ raise ValueError('The train_dataset cannot be None.')
data_loader = self._build_dataloader_with_dataset(
self.train_dataset,
dist=self._dist,
@@ -798,11 +900,7 @@ class EpochBasedTrainer(BaseTrainer):
pass
"""
if self.eval_dataset is None:
- val_data = self.cfg.dataset.val
- self.eval_dataset = self.build_dataset(
- val_data,
- mode=ModeKeys.EVAL,
- preprocessor=self.eval_preprocessor)
+ raise ValueError('The eval_dataset cannot be None.')
default_config = {'shuffle': False}
default_config.update(self.cfg.evaluation.get('dataloader', {}))
@@ -814,15 +912,16 @@ class EpochBasedTrainer(BaseTrainer):
**default_config)
return data_loader
- def get_predict_data_loader(self, predict_datasets: Union[Dataset,
- List[Dataset]]):
+ def get_predict_dataloader(self, predict_datasets: Union[Dataset,
+ List[Dataset]]):
""" Builder torch dataloader for prediction with the config of evaluation.
Args:
predict_datasets(Union[Dataset, List[Dataset]]): The datasets used to predict ground truth.
"""
- dataset = self.to_task_dataset(
- predict_datasets,
+ dataset = self.build_dataset(
+ datasets=predict_datasets,
+ model_cfg=self.cfg,
mode=ModeKeys.EVAL,
preprocessor=self.eval_preprocessor)
@@ -836,26 +935,6 @@ class EpochBasedTrainer(BaseTrainer):
**default_config)
return data_loader
- def build_dataset(self, data_cfg, mode, preprocessor=None):
- """ Build torch dataset object using data config
- """
- # TODO: support MsDataset load for cv
- if hasattr(data_cfg, 'name'):
- dataset_name = data_cfg.pop('name')
- dataset = MsDataset.load(
- dataset_name=dataset_name,
- **data_cfg,
- )
- cfg = ConfigDict(type=self.cfg.model.type, mode=mode)
- torch_dataset = dataset.to_torch_dataset(
- task_data_config=cfg,
- task_name=self.cfg.task,
- preprocessors=preprocessor)
- else:
- torch_dataset = build_task_dataset(data_cfg, self.cfg.task)
- dataset = self.to_task_dataset(torch_dataset, mode)
- return dataset
-
def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
try:
return build_optimizer(
@@ -1024,7 +1103,11 @@ class EpochBasedTrainer(BaseTrainer):
Returns:
DataLoader: A PyTorch dataloader.
"""
- rank, world_size = get_dist_info()
+ rank = 0
+ world_size = 1
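+ # Fall back to single-process values when no data-parallel group has been initialized.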
+ if self.is_dp_group_available():
+ rank = torch.distributed.get_rank(self.dp_group)
+ world_size = torch.distributed.get_world_size(self.dp_group)
if dist:
# When model is :obj:`DistributedDataParallel`,
@@ -1132,6 +1215,7 @@ class EpochBasedTrainer(BaseTrainer):
vis_closure = partial(
self.visualization, dataset=self.eval_dataset, **vis_cfg)
+ self.invoke_hook(TrainerStages.before_val)
if self._dist:
from modelscope.trainers.utils.inference import multi_gpu_test
# list of batched result and data samples
@@ -1154,6 +1238,7 @@ class EpochBasedTrainer(BaseTrainer):
vis_closure=vis_closure,
data_loader_iters=self._eval_iters_per_epoch)
+ self.invoke_hook(TrainerStages.after_val)
return metric_values
def visualization(self, batch_result, dataset, **kwargs):
@@ -1239,7 +1324,26 @@ class EpochBasedTrainer(BaseTrainer):
"before_train_epoch".
"""
for hook in self._hooks:
- getattr(hook, fn_name)(self)
+ if hasattr(hook, fn_name):
+ getattr(hook, fn_name)(self)
+
+ def print_cfg(self):
+ if is_master():
+ cfg = deepcopy(self.cfg)
+ cfg.train.work_dir = self.work_dir
+ self.logger.info(
+ '==========================Training Config Start=========================='
+ )
+ self.logger.info(
+ json.dumps(cfg._cfg_dict, indent=4, cls=JSONIteratorEncoder))
+ self.logger.info(
+ '===========================Training Config End==========================='
+ )
+
+ def print_hook_info(self):
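+ # Log the hook summary only once per trainer instance; train/evaluate/predict may each call this.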
+ if is_master() and not getattr(self, '_hook_info_printed', False):
+ self.logger.info(self.get_hook_info())
+ self._hook_info_printed = True
def get_hook_info(self) -> str:
# Get hooks info in each stage
@@ -1251,8 +1355,9 @@ class EpochBasedTrainer(BaseTrainer):
priority = Priority.NORMAL # type: ignore
classname = hook.__class__.__name__
hook_info = f'({priority:<12}) {classname:<35}'
- for trigger_stage in hook.get_triggered_stages():
- stage_hook_map[trigger_stage].append(hook_info)
+ if hasattr(hook, 'get_triggered_stages'):
+ for trigger_stage in hook.get_triggered_stages():
+ stage_hook_map[trigger_stage].append(hook_info)
stage_hook_infos = []
for stage in Hook.stages:
diff --git a/modelscope/trainers/utils/inference.py b/modelscope/trainers/utils/inference.py
index b0dbe4bf..e922b430 100644
--- a/modelscope/trainers/utils/inference.py
+++ b/modelscope/trainers/utils/inference.py
@@ -10,11 +10,9 @@ import torch
from torch import distributed as dist
from tqdm import tqdm
-from modelscope.utils.constant import DistributedParallelType
from modelscope.utils.data_utils import to_device
-from modelscope.utils.megatron_utils import is_megatron_initialized
-from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master,
- make_tmp_dir)
+from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_dist,
+ is_master, make_tmp_dir)
def single_gpu_test(trainer,
@@ -107,7 +105,7 @@ def multi_gpu_test(trainer,
list: The prediction results.
"""
dataset = data_loader.dataset
- rank, world_size = get_dist_info()
+ rank, world_size = get_dist_info(trainer.dp_group)
progress_with_iters = False
if data_loader_iters_per_gpu is None:
@@ -164,12 +162,13 @@ def multi_gpu_test(trainer,
# collect results and data from all ranks
if gpu_collect:
- metric_classes_list = collect_results_gpu(metric_classes)
+ metric_classes_list = collect_results_gpu(metric_classes,
+ trainer.dp_group)
else:
if tmpdir is None:
tmpdir = make_tmp_dir()
metric_classes_list = collect_results_cpu(
- metric_classes, os.path.join(tmpdir, 'metrics'))
+ metric_classes, trainer, os.path.join(tmpdir, 'metrics'))
metric_classes = merge_metrics(metric_classes_list)
@@ -189,18 +188,16 @@ def evaluate_batch(trainer, data, metric_classes, vis_closure):
def get_metric_values(metric_classes):
- _, world_size = get_dist_info()
metric_values = {}
- if is_master(
- DistributedParallelType.DP if is_megatron_initialized() else None):
+ if is_master():
for metric_cls in metric_classes:
metric_values.update(metric_cls.evaluate())
- if world_size > 1:
+ if is_dist():
metric_values = broadcast(metric_values, 0)
return metric_values
-def collect_results_cpu(result_part, tmpdir=None):
+def collect_results_cpu(result_part, trainer, tmpdir=None):
"""Collect results under cpu mode.
On cpu mode, this function will save the results on different gpus to
@@ -209,8 +206,7 @@ def collect_results_cpu(result_part, tmpdir=None):
Args:
result_part (list): Result list containing result parts
to be collected.
- size (int): Size of the results, commonly equal to length of
- the results.
+ trainer(`EpochBasedTrainer`): The trainer instance to get the parallel groups.
tmpdir (str | None): temporal directory for collected results to
store. If set to None, it will create a random temporal directory
for it.
@@ -218,15 +214,16 @@ def collect_results_cpu(result_part, tmpdir=None):
Returns:
list: The collected results.
"""
- rank, world_size = get_dist_info()
+ rank, world_size = get_dist_info(trainer.dp_group)
if tmpdir is None:
tmpdir = make_tmp_dir()
- if not os.path.exists(tmpdir) and is_master(DistributedParallelType.TP):
- os.makedirs(tmpdir)
+ if not os.path.exists(tmpdir):
+ os.makedirs(tmpdir, exist_ok=True)
dist.barrier()
# dump the part result to the dir
- if is_master(DistributedParallelType.TP):
+ if (not trainer.is_tp_group_available() or is_master(trainer.tp_group)) \
+ and (not trainer.is_pp_group_available() or is_master(trainer.pp_group)):
with open(os.path.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f:
pickle.dump(result_part, f)
dist.barrier()
@@ -250,7 +247,7 @@ def collect_results_cpu(result_part, tmpdir=None):
return part_list
-def collect_results_gpu(result_part):
+def collect_results_gpu(result_part, dp_group=None):
"""Collect results under gpu mode.
On gpu mode, this function will encode results to gpu tensors and use gpu
@@ -259,17 +256,12 @@ def collect_results_gpu(result_part):
Args:
result_part (list): Result list containing result parts
to be collected.
- size (int): Size of the results, commonly equal to length of
- the results.
+ dp_group(`ProcessGroup` or None): The data parallel group, default None for global group.
Returns:
list: The collected results.
"""
- _, world_size = get_dist_info()
- group = None
- if is_megatron_initialized():
- from megatron_util import mpu
- group = mpu.get_data_parallel_group()
+ _, world_size = get_dist_info(dp_group)
# dump result part to tensor with pickle
part_tensor = torch.tensor(
@@ -277,7 +269,7 @@ def collect_results_gpu(result_part):
# gather all result part tensor shape
shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
shape_list = [shape_tensor.clone() for _ in range(world_size)]
- dist.all_gather(shape_list, shape_tensor, group)
+ dist.all_gather(shape_list, shape_tensor, dp_group)
# padding result part tensor to max length
shape_max = torch.tensor(shape_list).max()
part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
@@ -286,7 +278,7 @@ def collect_results_gpu(result_part):
part_tensor.new_zeros(shape_max) for _ in range(world_size)
]
# gather all result part
- dist.all_gather(part_recv_list, part_send, group)
+ dist.all_gather(part_recv_list, part_send, dp_group)
if is_master():
part_list = []
diff --git a/modelscope/utils/ast_utils.py b/modelscope/utils/ast_utils.py
index 4b73ed26..76f15e56 100644
--- a/modelscope/utils/ast_utils.py
+++ b/modelscope/utils/ast_utils.py
@@ -16,7 +16,7 @@ import json
from modelscope import __version__
from modelscope.fileio.file import LocalStorage
-from modelscope.metainfo import (Datasets, Heads, Hooks, LR_Schedulers,
+from modelscope.metainfo import (CustomDatasets, Heads, Hooks, LR_Schedulers,
Metrics, Models, Optimizers, Pipelines,
Preprocessors, TaskModels, Trainers)
from modelscope.utils.constant import Fields, Tasks
@@ -35,7 +35,8 @@ INDEXER_FILE_DIR = get_default_cache_dir()
REGISTER_MODULE = 'register_module'
IGNORED_PACKAGES = ['modelscope', '.']
SCAN_SUB_FOLDERS = [
- 'models', 'metrics', 'pipelines', 'preprocessors', 'trainers', 'msdatasets'
+ 'models', 'metrics', 'pipelines', 'preprocessors', 'trainers',
+ 'msdatasets', 'exporters'
]
INDEXER_FILE = 'ast_indexer'
DECORATOR_KEY = 'decorators'
@@ -82,11 +83,11 @@ class AstScanning(object):
else:
return True
- def _skip_function(self, node: ast.AST) -> bool:
- if type(node).__name__ == 'FunctionDef' and SKIP_FUNCTION_SCANNING:
- return True
- else:
- return False
+ def _skip_function(self, node: Union[ast.AST, str]) -> bool:
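+ # Accept either an AST node or a node-type name string (the parent_node_name) when deciding whether to skip.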
+ if SKIP_FUNCTION_SCANNING:
+ if type(node).__name__ == 'FunctionDef' or node == 'FunctionDef':
+ return True
+ return False
def _fields(self, n: ast.AST, show_offsets: bool = True) -> tuple:
if show_offsets:
@@ -120,9 +121,7 @@ class AstScanning(object):
def scan_import(
self,
node: Union[ast.AST, None, str],
- indent: Union[str, int] = ' ',
show_offsets: bool = True,
- _indent: int = 0,
parent_node_name: str = '',
) -> tuple:
if node is None:
@@ -131,23 +130,11 @@ class AstScanning(object):
return self._leaf(node, show_offsets=show_offsets)
else:
- class state:
- indent = _indent
-
- @contextlib.contextmanager
- def indented() -> Generator[None, None, None]:
- state.indent += 1
- yield
- state.indent -= 1
-
def _scan_import(el: Union[ast.AST, None, str],
- _indent: int = 0,
parent_node_name: str = '') -> str:
return self.scan_import(
el,
- indent=indent,
show_offsets=show_offsets,
- _indent=_indent,
parent_node_name=parent_node_name)
outputs = dict()
@@ -162,80 +149,73 @@ class AstScanning(object):
setattr(node, 'module', path_level)
else:
setattr(node, 'module', path_level + module_name)
- with indented():
- for field in self._fields(node, show_offsets=show_offsets):
- attr = getattr(node, field)
- if attr == []:
- outputs[field] = []
- elif (isinstance(attr, list) and len(attr) == 1
- and isinstance(attr[0], ast.AST)
- and self._skip_function(attr[0])):
- continue
- elif (isinstance(attr, list) and len(attr) == 1
- and isinstance(attr[0], ast.AST)
- and self._is_leaf(attr[0])):
- local_out = _scan_import(attr[0])
- outputs[field] = local_out
- elif isinstance(attr, list):
- el_dict = dict()
- with indented():
- for el in attr:
- local_out = _scan_import(
- el, state.indent,
- type(el).__name__)
- name = type(el).__name__
- if (name == 'Import' or name == 'ImportFrom'
- or parent_node_name == 'ImportFrom'
- or parent_node_name == 'Import'):
- if name not in el_dict:
- el_dict[name] = []
- el_dict[name].append(local_out)
- outputs[field] = el_dict
- elif isinstance(attr, ast.AST):
- output = _scan_import(attr, state.indent)
- outputs[field] = output
- else:
- outputs[field] = attr
+ for field in self._fields(node, show_offsets=show_offsets):
+ attr = getattr(node, field)
+ if attr == []:
+ outputs[field] = []
+ elif self._skip_function(parent_node_name):
+ continue
+ elif (isinstance(attr, list) and len(attr) == 1
+ and isinstance(attr[0], ast.AST)
+ and self._is_leaf(attr[0])):
+ local_out = _scan_import(attr[0])
+ outputs[field] = local_out
+ elif isinstance(attr, list):
+ el_dict = dict()
+ for el in attr:
+ local_out = _scan_import(el, type(el).__name__)
+ name = type(el).__name__
+ if (name == 'Import' or name == 'ImportFrom'
+ or parent_node_name == 'ImportFrom'
+ or parent_node_name == 'Import'):
+ if name not in el_dict:
+ el_dict[name] = []
+ el_dict[name].append(local_out)
+ outputs[field] = el_dict
+ elif isinstance(attr, ast.AST):
+ output = _scan_import(attr)
+ outputs[field] = output
+ else:
+ outputs[field] = attr
- if (type(node).__name__ == 'Import'
- or type(node).__name__ == 'ImportFrom'):
- if type(node).__name__ == 'ImportFrom':
- if field == 'module':
+ if (type(node).__name__ == 'Import'
+ or type(node).__name__ == 'ImportFrom'):
+ if type(node).__name__ == 'ImportFrom':
+ if field == 'module':
+ self.result_from_import[outputs[field]] = dict()
+ if field == 'names':
+ if isinstance(outputs[field]['alias'], list):
+ item_name = []
+ for item in outputs[field]['alias']:
+ local_name = item['alias']['name']
+ item_name.append(local_name)
self.result_from_import[
- outputs[field]] = dict()
- if field == 'names':
- if isinstance(outputs[field]['alias'], list):
- item_name = []
- for item in outputs[field]['alias']:
- local_name = item['alias']['name']
- item_name.append(local_name)
- self.result_from_import[
- outputs['module']] = item_name
- else:
- local_name = outputs[field]['alias'][
- 'name']
- self.result_from_import[
- outputs['module']] = [local_name]
-
- if type(node).__name__ == 'Import':
- final_dict = outputs[field]['alias']
- if isinstance(final_dict, list):
- for item in final_dict:
- self.result_import[
- item['alias']['name']] = item['alias']
+ outputs['module']] = item_name
else:
- self.result_import[outputs[field]['alias']
- ['name']] = final_dict
+ local_name = outputs[field]['alias']['name']
+ self.result_from_import[outputs['module']] = [
+ local_name
+ ]
- if 'decorator_list' == field and attr != []:
- for item in attr:
- setattr(item, CLASS_NAME, node.name)
- self.result_decorator.extend(attr)
+ if type(node).__name__ == 'Import':
+ final_dict = outputs[field]['alias']
+ if isinstance(final_dict, list):
+ for item in final_dict:
+ self.result_import[item['alias']
+ ['name']] = item['alias']
+ else:
+ self.result_import[outputs[field]['alias']
+ ['name']] = final_dict
- if attr != [] and type(
- attr
- ).__name__ == 'Call' and parent_node_name == 'Expr':
- self.result_express.append(attr)
+ if 'decorator_list' == field and attr != []:
+ for item in attr:
+ setattr(item, CLASS_NAME, node.name)
+ self.result_decorator.extend(attr)
+
+ if attr != [] and type(
+ attr
+ ).__name__ == 'Call' and parent_node_name == 'Expr':
+ self.result_express.append(attr)
return {
IMPORT_KEY: self.result_import,
@@ -384,7 +364,7 @@ class AstScanning(object):
data = ''.join(data)
node = gast.parse(data)
- output = self.scan_import(node, indent=' ', show_offsets=False)
+ output = self.scan_import(node, show_offsets=False)
output[DECORATOR_KEY] = self.parse_decorators(output[DECORATOR_KEY])
output[EXPRESS_KEY] = self.parse_decorators(output[EXPRESS_KEY])
output[DECORATOR_KEY].extend(output[EXPRESS_KEY])
@@ -396,6 +376,7 @@ class FilesAstScanning(object):
def __init__(self) -> None:
self.astScaner = AstScanning()
self.file_dirs = []
+ self.requirement_dirs = []
def _parse_import_path(self,
import_package: str,
@@ -456,15 +437,15 @@ class FilesAstScanning(object):
ignored.add(item)
return list(set(output) - set(ignored))
- def traversal_files(self, path, check_sub_dir):
+ def traversal_files(self, path, check_sub_dir=None):
self.file_dirs = []
if check_sub_dir is None or len(check_sub_dir) == 0:
self._traversal_files(path)
-
- for item in check_sub_dir:
- sub_dir = os.path.join(path, item)
- if os.path.isdir(sub_dir):
- self._traversal_files(sub_dir)
+ else:
+ for item in check_sub_dir:
+ sub_dir = os.path.join(path, item)
+ if os.path.isdir(sub_dir):
+ self._traversal_files(sub_dir)
def _traversal_files(self, path):
dir_list = os.scandir(path)
@@ -475,6 +456,8 @@ class FilesAstScanning(object):
self._traversal_files(item.path)
elif item.is_file() and item.name.endswith('.py'):
self.file_dirs.append(item.path)
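+ # Also collect requirement files (e.g. requirements.txt) so model/plugin dependencies can be installed before import.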
+ elif item.is_file() and 'requirement' in item.name:
+ self.requirement_dirs.append(item.path)
def _get_single_file_scan_result(self, file):
try:
diff --git a/modelscope/utils/audio/audio_utils.py b/modelscope/utils/audio/audio_utils.py
index 9be97016..3f79e8b0 100644
--- a/modelscope/utils/audio/audio_utils.py
+++ b/modelscope/utils/audio/audio_utils.py
@@ -21,6 +21,18 @@ class TtsTrainType(object):
TRAIN_TYPE_VOC = 'train-type-voc'
+class TtsCustomParams(object):
+ VOICE_NAME = 'voice_name'
+ AM_CKPT = 'am_ckpt'
+ VOC_CKPT = 'voc_ckpt'
+ AM_CONFIG = 'am_config'
+ VOC_CONFIG = 'voc_config'
+ AUDIO_CONFIG = 'audio_config'
+ SE_FILE = 'se_file'
+ SE_MODEL = 'se_model'
+ MVN_FILE = 'mvn_file'
+
+
def to_segment(batch, segment_length=SEGMENT_LENGTH_TRAIN):
"""
Dataset mapping function to split one audio into segments.
@@ -105,6 +117,28 @@ def extract_pcm_from_wav(wav: bytes) -> bytes:
return data, sample_rate
+def expect_token_number(instr, token):
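+ """Consume `token` and a following number at the head of `instr`.
+ Returns (remaining_str, number_str), or None if the pattern does not match.
+ Illustrative example: expect_token_number('<Rate> 0.5 ...', '<Rate>') -> (' ...', '0.5')."""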
+ first_token = re.match(r'^\s*' + token, instr)
+ if first_token is None:
+ return None
+ instr = instr[first_token.end():]
+ lr = re.match(r'^\s*(-?\d+\.?\d*e?-?\d*?)', instr)
+ if lr is None:
+ return None
+ return instr[lr.end():], lr.groups()[0]
+
+
+def expect_kaldi_matrix(instr):
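+ """Parse a Kaldi-style text matrix '[ ... ]' from `instr`.
+ Returns (remaining_str, ndarray); illustrative example: '[ 1 2\n 3 4 ]' -> ('', array([[1., 2.], [3., 4.]]))."""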
+ pos2 = instr.find('[', 0)
+ pos3 = instr.find(']', pos2)
+ mat = []
+ for stt in instr[pos2 + 1:pos3].split('\n'):
+ tmp_mat = np.fromstring(stt, dtype=np.float32, sep=' ')
+ if tmp_mat.size > 0:
+ mat.append(tmp_mat)
+ return instr[pos3 + 1:], np.array(mat)
+
+
# This implementation is adopted from scipy.io.wavfile.write,
# made publicly available under the BSD-3-Clause license at
# https://github.com/scipy/scipy/blob/v1.9.3/scipy/io/wavfile.py
@@ -172,22 +206,24 @@ def load_bytes_from_url(url: str) -> Union[bytes, str]:
def generate_scp_from_url(url: str, key: str = None):
wav_scp_path = None
raw_inputs = None
- # for local wav.scp inputs
- if os.path.exists(url) and url.lower().endswith('.scp'):
- wav_scp_path = url
- return wav_scp_path, raw_inputs
- # for local wav file inputs
- if os.path.exists(url) and (url.lower().endswith(SUPPORT_AUDIO_TYPE_SETS)):
+ # for local inputs
+ if os.path.exists(url):
wav_scp_path = url
return wav_scp_path, raw_inputs
# for wav url, download bytes data
- result = urlparse(url)
- if result.scheme is not None and len(result.scheme) > 0:
- storage = HTTPStorage()
- # bytes
- wav_scp_path = storage.read(url)
-
- return wav_scp_path, raw_inputs
+ if url.startswith('http'):
+ result = urlparse(url)
+ if result.scheme is not None and len(result.scheme) > 0:
+ storage = HTTPStorage()
+ # bytes
+ data = storage.read(url)
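+ # Persist the downloaded bytes to a temporary wav file so downstream code receives a local path.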
+ work_dir = tempfile.mkdtemp()
+ wav_path = os.path.join(work_dir, os.path.basename(url))
+ with open(wav_path, 'wb') as fb:
+ fb.write(data)
+ return wav_path, raw_inputs
return wav_scp_path, raw_inputs
diff --git a/modelscope/utils/chinese_utils.py b/modelscope/utils/chinese_utils.py
index 86cf91a2..77ea34ce 100644
--- a/modelscope/utils/chinese_utils.py
+++ b/modelscope/utils/chinese_utils.py
@@ -3,8 +3,6 @@
import re
import string
-from zhconv import convert
-
CHINESE_PUNCTUATION = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、\u3000、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·!?。。'
ENGLISH_PUNCTUATION = string.punctuation
@@ -58,6 +56,8 @@ def _is_chinese_char(cp: str) -> bool:
def normalize_chinese_number(text):
+ from zhconv import convert
+
chinese_number = ['零', '一', '二', '三', '四', '五', '六', '七', '八', '九']
new_text = ''
for x in text:
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 667575ff..68630c81 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -17,6 +17,7 @@ class CVTasks(object):
ocr_detection = 'ocr-detection'
ocr_recognition = 'ocr-recognition'
table_recognition = 'table-recognition'
+ lineless_table_recognition = 'lineless-table-recognition'
license_plate_detection = 'license-plate-detection'
# human face body related
@@ -111,6 +112,7 @@ class CVTasks(object):
referring_video_object_segmentation = 'referring-video-object-segmentation'
video_human_matting = 'video-human-matting'
video_panoptic_segmentation = 'video-panoptic-segmentation'
+ video_instance_segmentation = 'video-instance-segmentation'
# video editing
video_inpainting = 'video-inpainting'
@@ -140,6 +142,9 @@ class CVTasks(object):
# 3d face reconstruction
face_reconstruction = 'face-reconstruction'
+ # 3d human reconstruction
+ human_reconstruction = 'human-reconstruction'
+
# image quality assessment mos
image_quality_assessment_mos = 'image-quality-assessment-mos'
# motion generation
@@ -217,6 +222,7 @@ class AudioTasks(object):
speaker_diarization = 'speaker-diarization'
voice_activity_detection = 'voice-activity-detection'
language_score_prediction = 'language-score-prediction'
+ speech_timestamp = 'speech-timestamp'
class MultiModalTasks(object):
@@ -234,6 +240,8 @@ class MultiModalTasks(object):
document_vl_embedding = 'document-vl-embedding'
video_captioning = 'video-captioning'
video_question_answering = 'video-question-answering'
+ video_temporal_grounding = 'video-temporal-grounding'
+ text_to_video_synthesis = 'text-to-video-synthesis'
class ScienceTasks(object):
@@ -391,6 +399,7 @@ class ThirdParty(object):
KEY = 'third_party'
EASYCV = 'easycv'
ADASEQ = 'adaseq'
+ ADADET = 'adadet'
class ConfigFields(object):
@@ -460,7 +469,9 @@ class LogKeys:
class TrainerStages:
+ after_init = 'after_init'
before_run = 'before_run'
+ before_val = 'before_val'
before_train_epoch = 'before_train_epoch'
before_train_iter = 'before_train_iter'
after_train_iter = 'after_train_iter'
@@ -470,6 +481,7 @@ class TrainerStages:
after_val_iter = 'after_val_iter'
after_val_epoch = 'after_val_epoch'
after_run = 'after_run'
+ after_val = 'after_val'
class ColorCodes:
@@ -516,3 +528,8 @@ class DistributedParallelType(object):
DP = 'data_parallel'
TP = 'tensor_model_parallel'
PP = 'pipeline_model_parallel'
+
+
+class DatasetTensorflowConfig:
+ BATCH_SIZE = 'batch_size'
+ DEFAULT_BATCH_SIZE_VALUE = 5
diff --git a/modelscope/utils/demo_utils.py b/modelscope/utils/demo_utils.py
index 82bf1ada..99e61d45 100644
--- a/modelscope/utils/demo_utils.py
+++ b/modelscope/utils/demo_utils.py
@@ -30,6 +30,7 @@ TASKS_INPUT_TEMPLATES = {
Tasks.ocr_detection: TasksIODescriptions.image_to_text,
Tasks.ocr_recognition: TasksIODescriptions.image_to_text,
Tasks.body_2d_keypoints: TasksIODescriptions.image_to_text,
+ Tasks.vision_efficient_tuning: TasksIODescriptions.image_to_text,
# nlp tasks
Tasks.text_classification: TasksIODescriptions.text_to_text,
diff --git a/modelscope/utils/error.py b/modelscope/utils/error.py
index 44e6b238..d3000130 100644
--- a/modelscope/utils/error.py
+++ b/modelscope/utils/error.py
@@ -153,3 +153,12 @@ MPI4PY_IMPORT_ERROR = """
`pip install mpi4py' and with following the instruction to install openmpi,
https://docs.open-mpi.org/en/v5.0.x/installing-open-mpi/quickstart.html`
"""
+
+# docstyle-ignore
+OPENCLIP_IMPORT_ERROR = """
+{0} requires the fasttext library but it was not found in your environment.
+You can install it with pip on linux or mac:
+`pip install open_clip_torch`
+Or you can checkout the instructions on the
+installation page: https://github.com/mlfoundations/open_clip and follow the ones that match your environment.
+"""
diff --git a/modelscope/utils/import_utils.py b/modelscope/utils/import_utils.py
index 3517ea3d..ea123ed7 100644
--- a/modelscope/utils/import_utils.py
+++ b/modelscope/utils/import_utils.py
@@ -304,6 +304,7 @@ REQUIREMENTS_MAAPING = OrderedDict([
('text2sql_lgesql', (is_package_available('text2sql_lgesql'),
TEXT2SQL_LGESQL_IMPORT_ERROR)),
('mpi4py', (is_package_available('mpi4py'), MPI4PY_IMPORT_ERROR)),
+ ('open_clip', (is_package_available('open_clip'), OPENCLIP_IMPORT_ERROR)),
])
SYSTEM_PACKAGE = set(['os', 'sys', 'typing'])
diff --git a/modelscope/utils/megatron_utils.py b/modelscope/utils/megatron_utils.py
index 9f2b2c09..11e79831 100644
--- a/modelscope/utils/megatron_utils.py
+++ b/modelscope/utils/megatron_utils.py
@@ -20,24 +20,24 @@ _DEFAULT_CFG_WITH_MODEL_TYPE = {
_IS_MEGATRON_INITIALIZED = False
-def init_megatron_util(cfg=None, model_dir=None, **kwargs):
+def init_megatron_util(megatron_cfg=None, model_dir=None, **kwargs):
from modelscope.utils.hub import read_config
from megatron_util import initialize_megatron
- assert not (cfg is None and model_dir is None), \
+ assert not (megatron_cfg is None and model_dir is None), \
- 'cfg and model_dir cannot both be None when initializing megatron_util'
+ 'megatron_cfg and model_dir cannot both be None when initializing megatron_util'
- if cfg is None:
+ if megatron_cfg is None:
cfg = read_config(model_dir)
- try:
- megatron_cfg = cfg.megatron
- except AttributeError:
try:
- model_type = cfg.model.type
+ megatron_cfg = cfg.megatron
except AttributeError:
- # Fit models without model type, such as mglm
- model_type = cfg.pipeline.type
- megatron_cfg = _DEFAULT_CFG_WITH_MODEL_TYPE[model_type] \
- if model_type in _DEFAULT_CFG_WITH_MODEL_TYPE else {}
+ try:
+ model_type = cfg.model.type
+ except AttributeError:
+ # Fit models without model type, such as mglm
+ model_type = cfg.pipeline.type
+ megatron_cfg = _DEFAULT_CFG_WITH_MODEL_TYPE[model_type] \
+ if model_type in _DEFAULT_CFG_WITH_MODEL_TYPE else {}
megatron_cfg.update(kwargs)
initialize_megatron(megatron_cfg)
global _IS_MEGATRON_INITIALIZED
diff --git a/modelscope/utils/plugins.py b/modelscope/utils/plugins.py
index 6c2f2975..e62f775d 100644
--- a/modelscope/utils/plugins.py
+++ b/modelscope/utils/plugins.py
@@ -1,17 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
+import copy
import importlib
import os
import pkgutil
import sys
+import venv
from contextlib import contextmanager
from fnmatch import fnmatch
from pathlib import Path
-from typing import Iterable, List, Optional, Set
+from typing import Any, Iterable, List, Optional, Set, Union
+import json
+import pkg_resources
+
+from modelscope.fileio.file import LocalStorage
+from modelscope.utils.ast_utils import FilesAstScanning
+from modelscope.utils.constant import DEFAULT_MODEL_REVISION
+from modelscope.utils.file_utils import get_default_cache_dir
+from modelscope.utils.hub import read_config, snapshot_download
from modelscope.utils.logger import get_logger
logger = get_logger()
+storage = LocalStorage()
+
+MODELSCOPE_FILE_DIR = get_default_cache_dir()
+PLUGINS_FILENAME = '.modelscope_plugins'
+OFFICIAL_PLUGINS = [
+ {
+ 'name': 'adaseq',
+ 'desc':
+ 'Provides hundreds of additional NER algorithms, see: https://github.com/modelscope/AdaSeq',
+ 'version': '',
+ 'url': ''
+ },
+]
LOCAL_PLUGINS_FILENAME = '.modelscope_plugins'
GLOBAL_PLUGINS_FILENAME = os.path.join(Path.home(), '.modelscope', 'plugins')
@@ -65,24 +88,36 @@ def discover_file_plugins(
yield module_name
-def discover_plugins() -> Iterable[str]:
+def discover_plugins(requirement_path=None) -> Iterable[str]:
"""
Discover plugins
+
+ Args:
+ requirement_path: The file path of requirement
+
"""
plugins: Set[str] = set()
- if os.path.isfile(LOCAL_PLUGINS_FILENAME):
- with push_python_path('.'):
- for plugin in discover_file_plugins(LOCAL_PLUGINS_FILENAME):
+ if requirement_path is None:
+ if os.path.isfile(LOCAL_PLUGINS_FILENAME):
+ with push_python_path('.'):
+ for plugin in discover_file_plugins(LOCAL_PLUGINS_FILENAME):
+ if plugin in plugins:
+ continue
+ yield plugin
+ plugins.add(plugin)
+ if os.path.isfile(GLOBAL_PLUGINS_FILENAME):
+ for plugin in discover_file_plugins(GLOBAL_PLUGINS_FILENAME):
+ if plugin in plugins:
+ continue
+ yield plugin
+ plugins.add(plugin)
+ else:
+ if os.path.isfile(requirement_path):
+ for plugin in discover_file_plugins(requirement_path):
if plugin in plugins:
continue
yield plugin
plugins.add(plugin)
- if os.path.isfile(GLOBAL_PLUGINS_FILENAME):
- for plugin in discover_file_plugins(GLOBAL_PLUGINS_FILENAME):
- if plugin in plugins:
- continue
- yield plugin
- plugins.add(plugin)
def import_all_plugins(plugins: List[str] = None) -> List[str]:
@@ -142,9 +177,13 @@ def import_plugins(plugins: List[str] = None) -> List[str]:
return imported_plugins
-def import_file_plugins() -> List[str]:
+def import_file_plugins(requirement_path=None) -> List[str]:
"""
Imports the plugins found with `discover_plugins()`.
+
+ Args:
+ requirement_path: The file path of requirement
+
"""
imported_plugins: List[str] = []
@@ -153,7 +192,7 @@ def import_file_plugins() -> List[str]:
if cwd not in sys.path:
sys.path.append(cwd)
- for module_name in discover_plugins():
+ for module_name in discover_plugins(requirement_path):
try:
importlib.import_module(module_name)
logger.info('Plugin %s available', module_name)
@@ -174,7 +213,7 @@ def import_module_and_submodules(package_name: str,
include = include if include else set()
exclude = exclude if exclude else set()
- def fn_in(packge_name: str, pattern_set: Set[str]) -> bool:
+ def fn_in(package_name: str, pattern_set: Set[str]) -> bool:
for pattern in pattern_set:
if fnmatch(package_name, pattern):
return True
@@ -213,3 +252,473 @@ def import_module_and_submodules(package_name: str,
logger.warning(f'{package_name} not imported: {str(e)}')
if len(package_name.split('.')) == 1:
raise ModuleNotFoundError('Package not installed')
+
+
+def install_module_from_requirements(requirement_path):
+ """Install packages from a pip requirements file.
+
+ Args:
+ requirement_path: The path of the requirements file.
+ """
+
+ install_args = ['-r', requirement_path]
+ status_code, _, args = PluginsManager.pip_command(
+ 'install',
+ install_args,
+ )
+ if status_code != 0:
+ raise ImportError(
+ f'Failed to install requirements from {requirement_path}')
+
+
+def import_module_from_file(module_name, file_path):
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ return module
+
+
+def import_module_from_model_dir(model_dir):
+ from pathlib import Path
+ file_scanner = FilesAstScanning()
+ file_scanner.traversal_files(model_dir)
+ file_dirs = file_scanner.file_dirs
+ requirements = file_scanner.requirement_dirs
+
+ # install the requirements firstly
+ install_requirements_by_files(requirements)
+
+ # then import the modules
+ import sys
+ sys.path.insert(0, model_dir)
+ for file in file_dirs:
+ module_name = Path(file).stem
+ import_module_from_file(module_name, file)
+
+
+def install_modelscope_if_need():
+ plugin_installed, version = PluginsManager.check_plugin_installed(
+ 'modelscope')
+ if not plugin_installed:
+ status_code, _, args = PluginsManager.pip_command(
+ 'install',
+ ['modelscope'],
+ )
+ if status_code != 0:
+ raise ImportError('Failed to install package modelscope')
+
+
+def install_requirements_by_names(plugins: List[str]):
+ plugins_manager = PluginsManager()
+ uninstalled_plugins = []
+ for plugin in plugins:
+ plugin_installed, version = plugins_manager.check_plugin_installed(
+ plugin)
+ if not plugin_installed:
+ uninstalled_plugins.append(plugin)
+ status, _ = plugins_manager.install_plugins(uninstalled_plugins)
+ if status != 0:
+ raise EnvironmentError(
+ f'The required packages {",".join(uninstalled_plugins)} are not installed.',
+ f'Please run the command `modelscope plugin install {" ".join(uninstalled_plugins)}` to install them.'
+ )
+ install_modelscope_if_need()
+
+
+def install_requirements_by_files(requirements: List[str]):
+ for requirement in requirements:
+ install_module_from_requirements(requirement)
+ install_modelscope_if_need()
+
+
+def register_plugins_repo(plugins: List[str]) -> None:
+ """ Try to install and import plugins from repo"""
+ if plugins is not None:
+ install_requirements_by_names(plugins)
+ import_plugins(plugins)
+
+
+def register_modelhub_repo(model_dir, allow_remote=False) -> None:
+ """ Try to install and import remote model from modelhub"""
+ if allow_remote:
+ try:
+ import_module_from_model_dir(model_dir)
+ except KeyError:
+ logger.warning(
+ 'Multiple component keys from the hub are registered in the same file')
+
+
+class PluginsManager(object):
+
+ def __init__(self,
+ cache_dir=MODELSCOPE_FILE_DIR,
+ plugins_file=PLUGINS_FILENAME):
+ cache_dir = os.getenv('MODELSCOPE_CACHE', cache_dir)
+ plugins_file = os.getenv('MODELSCOPE_PLUGINS_FILE', plugins_file)
+ self._file_path = os.path.join(cache_dir, plugins_file)
+
+ @property
+ def file_path(self):
+ return self._file_path
+
+ @file_path.setter
+ def file_path(self, value):
+ self._file_path = value
+
+ @staticmethod
+ def check_plugin_installed(package):
+ try:
+ importlib.reload(pkg_resources)
+ package_meta_info = pkg_resources.working_set.by_key[package]
+ version = package_meta_info.version
+ installed = True
+ except KeyError:
+ version = ''
+ installed = False
+
+ return installed, version
+
+ @staticmethod
+ def pip_command(
+ command,
+ command_args: List[str],
+ ):
+ """
+
+ Args:
+ command: the pip command to run, e.g. install or uninstall
+ command_args: the arguments passed to the command, given as a list,
+ e.g. ['-r', 'requirements.txt']
+
+ Returns:
+
+ """
+ from pip._internal.commands import create_command
+ importlib.reload(pkg_resources)
+ command = create_command(command)
+ options, args = command.parse_args(command_args)
+
+ status_code = command.main(command_args)
+ return status_code, options, args
+
+ def install_plugins(self,
+ install_args: List[str],
+ index_url: Optional[str] = None,
+ force_update=False) -> Any:
+ """Install packages via pip
+ Args:
+ install_args (list): List of arguments passed to `pip install`.
+ index_url (str, optional): The pypi index url.
+ """
+
+ if len(install_args) == 0:
+ return 0, []
+
+ if index_url is not None:
+ install_args += ['-i', index_url]
+
+ if force_update:
+ install_args += ['-U']
+
+ status_code, options, args = PluginsManager.pip_command(
+ 'install',
+ install_args,
+ )
+
+ if status_code == 0:
+ logger.info(f'The plugins {",".join(args)} are installed')
+
+ # TODO Add Ast index for ast update record
+
+ # Add the plugins info to the local record
+ installed_package = self.parse_args_info(args, options)
+ self.update_plugins_file(installed_package)
+
+ return status_code, install_args
+
+ def parse_args_info(self, args: List[str], options):
+ installed_package = []
+
+ # the case of installing from a requirements file
+ if len(args) == 0:
+ src_dir = options.src_dir
+ requirements = options.requirements
+ for requirement in requirements:
+ package_info = {
+ 'name': requirement,
+ 'url': os.path.join(src_dir, requirement),
+ 'desc': '',
+ 'version': ''
+ }
+
+ installed_package.append(package_info)
+
+ def get_package_info(package_name):
+ from pathlib import Path
+ package_info = {
+ 'name': package_name,
+ 'url': options.index_url,
+ 'desc': ''
+ }
+
+ # the case of a git+http(s) url
+ if package_name.split('.')[-1] == 'git':
+ package_name = Path(package_name).stem
+
+ plugin_installed, version = self.check_plugin_installed(
+ package_name)
+ if plugin_installed:
+ package_info['version'] = version
+ package_info['name'] = package_name
+ else:
+ logger.warning(
+ f'The package {package_name} is not found in the environment; this may happen'
+ f' when installing a package via git+https and can be ignored'
+ )
+ package_info['version'] = ''
+
+ return package_info
+
+ for package in args:
+ package_info = get_package_info(package)
+ installed_package.append(package_info)
+
+ return installed_package
+
+ def uninstall_plugins(self,
+ uninstall_args: Union[str, List],
+ is_yes=False):
+ if is_yes is not None:
+ uninstall_args += ['-y']
+
+ status_code, options, args = PluginsManager.pip_command(
+ 'uninstall',
+ uninstall_args,
+ )
+
+ if status_code == 0:
+ logger.info(f'The plugins {",".join(args)} are uninstalled')
+
+ # TODO Add Ast index for ast update record
+
+ # Add to the local record
+ self.remove_plugins_from_file(args)
+
+ return status_code, uninstall_args
+
+ def _get_plugins_from_file(self):
+ """ get plugins from file
+
+ """
+ logger.info(f'Loading plugins information from {self.file_path}')
+ if os.path.exists(self.file_path):
+ local_plugins_info_bytes = storage.read(self.file_path)
+ local_plugins_info = json.loads(local_plugins_info_bytes)
+ else:
+ local_plugins_info = {}
+ return local_plugins_info
+
+ def _update_plugins(
+ self,
+ new_plugins_list,
+ local_plugins_info,
+ override=False,
+ ):
+ for item in new_plugins_list:
+ package_name = item.pop('name')
+
+ # update the package information if it already exists
+ if package_name in local_plugins_info and not override:
+ original_item = local_plugins_info[package_name]
+ from pkg_resources import parse_version
+ item_version = parse_version(
+ item['version'] if item['version'] != '' else '0.0.0')
+ origin_version = parse_version(
+ original_item['version']
+ if original_item['version'] != '' else '0.0.0')
+ desc = item['desc']
+ if original_item['desc'] != '' and desc == '':
+ desc = original_item['desc']
+ item = item if item_version > origin_version else original_item
+ item['desc'] = desc
+
+ # Double-check if the item is installed with the version number
+ if item['version'] == '':
+ plugin_installed, version = self.check_plugin_installed(
+ package_name)
+ item['version'] = version
+
+ local_plugins_info[package_name] = item
+
+ return local_plugins_info
+
+ def _print_plugins_info(self, local_plugins_info):
+ print('{:<15} |{:<10} |{:<100}'.format('NAME', 'VERSION',
+ 'DESCRIPTION'))
+ print('')
+ for k, v in local_plugins_info.items():
+ print('{:<15} |{:<10} |{:<100}'.format(k, v['version'], v['desc']))
+
+ def list_plugins(
+ self,
+ show_all=False,
+ ):
+ """
+
+ Args:
+ show_all: if True, show both installed and officially supported plugins, otherwise only installed ones
+
+ Returns:
+
+ """
+ local_plugins_info = self._get_plugins_from_file()
+
+ # update plugins with default
+
+ local_official_plugins = copy.deepcopy(OFFICIAL_PLUGINS)
+ local_plugins_info = self._update_plugins(local_official_plugins,
+ local_plugins_info)
+
+ if show_all is True:
+ self._print_plugins_info(local_plugins_info)
+ return local_plugins_info
+
+ # Only packages with a version number are considered installed
+ not_installed_list = []
+ for item in local_plugins_info:
+ if local_plugins_info[item]['version'] == '':
+ not_installed_list.append(item)
+
+ for item in not_installed_list:
+ local_plugins_info.pop(item)
+
+ self._print_plugins_info(local_plugins_info)
+ return local_plugins_info
+
+ def update_plugins_file(
+ self,
+ plugins_list,
+ override=False,
+ ):
+ """update the plugins file in order to maintain the latest plugins information
+
+ Args:
+ plugins_list: A list of plugin records, each containing the plugin
+ name, version, description, install url and its delete/update status
+ override: Overwrite the stored records with the list if True, else merge and update them.
+
+ Returns:
+
+ """
+ local_plugins_info = self._get_plugins_from_file()
+
+ # local_plugins_info is empty if first time loading, should add OFFICIAL_PLUGINS information
+ if local_plugins_info == {}:
+ plugins_list.extend(copy.deepcopy(OFFICIAL_PLUGINS))
+
+ local_plugins_info = self._update_plugins(plugins_list,
+ local_plugins_info, override)
+
+ local_plugins_info_json = json.dumps(local_plugins_info)
+ storage.write(local_plugins_info_json.encode(), self.file_path)
+
+ return local_plugins_info_json
+
+ def remove_plugins_from_file(
+ self,
+ package_names: Union[str, list],
+ ):
+ """
+
+ Args:
+ package_names: a package name or a list of package names
+
+ Returns:
+
+ """
+ local_plugins_info = self._get_plugins_from_file()
+
+ if isinstance(package_names, str):
+ package_names = [package_names]
+
+ for item in package_names:
+ if item in local_plugins_info:
+ local_plugins_info.pop(item)
+
+ local_plugins_info_json = json.dumps(local_plugins_info)
+ storage.write(local_plugins_info_json.encode(), self.file_path)
+
+ return local_plugins_info_json
+
+
+class EnvsManager(object):
+ name = 'envs'
+
+ def __init__(self,
+ model_id,
+ model_revision=DEFAULT_MODEL_REVISION,
+ cache_dir=MODELSCOPE_FILE_DIR):
+ """
+
+ Args:
+ model_id: the id of the model, not a local directory
+ model_revision: the revision of the model, defaults to master
+ cache_dir: the system modelscope cache dir
+ """
+ cache_dir = os.getenv('MODELSCOPE_CACHE', cache_dir)
+ self.env_dir = os.path.join(cache_dir, EnvsManager.name, model_id)
+ model_dir = snapshot_download(model_id, revision=model_revision)
+ cfg = read_config(model_dir)
+ self.plugins = cfg.get('plugins', [])
+ self.allow_remote = cfg.get('allow_remote', False)
+ self.env_builder = venv.EnvBuilder(
+ system_site_packages=True,
+ clear=False,
+ symlinks=True,
+ with_pip=False)
+
+ def get_env_dir(self):
+ return self.env_dir
+
+ def get_activate_dir(self):
+ return os.path.join(self.env_dir, 'bin', 'activate')
+
+ def check_if_need_env(self):
+ if len(self.plugins) or self.allow_remote:
+ return True
+ else:
+ return False
+
+ def create_env(self):
+ if not os.path.exists(self.env_dir):
+ os.makedirs(self.env_dir)
+ try:
+ self.env_builder.create(self.env_dir)
+ except Exception as e:
+ self.clean_env()
+ raise EnvironmentError(
+ f'Failed to create virtual env at {self.env_dir} with error: {e}'
+ )
+
+ def clean_env(self):
+ if os.path.exists(self.env_dir):
+ self.env_builder.clear_directory(self.env_dir)
+
+ @staticmethod
+ def run_process(cmd):
+ import subprocess
+ status, result = subprocess.getstatusoutput(cmd)
+ logger.debug('The status and the results are: {}, {}'.format(
+ status, result))
+ if status != 0:
+ raise Exception(
+ 'running the cmd: {} failed, with message: {}'.format(
+ cmd, result))
+ return result
+
+
+if __name__ == '__main__':
+ install_requirements_by_names(['adaseq'])
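
A minimal usage sketch for the plugin helpers added above (the package name 'adaseq' comes from the tests further below; the model-dir path is a placeholder):

    from modelscope.utils.plugins import (PluginsManager, register_modelhub_repo,
                                          register_plugins_repo)

    # install a plugin by package name and record it in the local plugins file
    manager = PluginsManager()
    manager.install_plugins(['adaseq'])
    manager.list_plugins(show_all=True)

    # or let the helpers install and import everything a model repo declares
    register_plugins_repo(['adaseq'])
    register_modelhub_repo('/path/to/model_dir', allow_remote=True)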
diff --git a/modelscope/utils/torch_utils.py b/modelscope/utils/torch_utils.py
index 3d315716..8dedf022 100644
--- a/modelscope/utils/torch_utils.py
+++ b/modelscope/utils/torch_utils.py
@@ -12,9 +12,9 @@ from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
import torch.multiprocessing as mp
+from packaging import version
from torch import distributed as dist
-from modelscope.utils.constant import DistributedParallelType
from modelscope.utils.megatron_utils import is_megatron_initialized
@@ -36,6 +36,20 @@ def _is_free_port(port: int) -> bool:
return all(s.connect_ex((ip, port)) != 0 for ip in ips)
+def compile_model(model, **compile_options):
+ # Compile the model with torch 2.0
+ if hasattr(model, 'compile'):
+ model = model.compile(**compile_options)
+ elif version.parse(torch.__version__) >= version.parse('2.0.0.dev'):
+ model = torch.compile(model, **compile_options)
+ else:
+ print(
+ 'Compiling the model requires torch>=2.0.0, '
+ f'but the installed torch version is {torch.__version__}, so the original model will be returned.'
+ )
+ return model
+
+
def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
@@ -108,10 +122,17 @@ def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
dist.init_process_group(backend=backend)
-def get_dist_info() -> Tuple[int, int]:
+def get_dist_info(group=None) -> Tuple[int, int]:
+ """Get dist info of a specified group
+
+ Args:
+ group: The parallel group, default None, for the global group
+
+ Returns:
+ A tuple of the current rank and world_size of the group
+ """
if is_dist():
- group = None
- if is_megatron_initialized():
+ if group is None and is_megatron_initialized():
from megatron_util import mpu
group = mpu.get_data_parallel_group()
rank = dist.get_rank(group)
@@ -162,24 +183,9 @@ def is_dist():
def is_master(group=None):
- if isinstance(group, str):
- group = _parse_parallel_group(group)
return dist.get_rank(group) == 0 if is_dist() else True
-def _parse_parallel_group(group: str):
- from megatron_util import mpu
- if group == DistributedParallelType.DP:
- return mpu.get_data_parallel_group()
- if group == DistributedParallelType.TP:
- return mpu.get_tensor_model_parallel_group()
- if group == DistributedParallelType.PP:
- return mpu.get_pipeline_model_parallel_group()
- raise ValueError(
- f"Wrong group '{group}'. Supported groups are '{DistributedParallelType.DP}', "
- f"'{DistributedParallelType.TP}' or '{DistributedParallelType.PP}'")
-
-
def master_only(group=None):
def decorate(func: Callable) -> Callable:
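
A short sketch of the two torch_utils additions above (the model id is the one used in the export tests below; any torch-based model should behave the same way):

    from modelscope.models import Model
    from modelscope.utils.torch_utils import compile_model, get_dist_info

    model = Model.from_pretrained(
        'damo/nlp_raner_named-entity-recognition_chinese-base-news')
    # compiles the model when torch >= 2.0 (or the model exposes .compile),
    # otherwise returns it unchanged
    model = compile_model(model)

    # rank and world size of the global group by default; pass group= for a
    # specific parallel group
    rank, world_size = get_dist_info()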
diff --git a/modelscope/version.py b/modelscope/version.py
index 4fa90b93..7a913906 100644
--- a/modelscope/version.py
+++ b/modelscope/version.py
@@ -1,5 +1,5 @@
# Make sure to modify __release_datetime__ to release time when making official release.
-__version__ = '1.3.0'
+__version__ = '1.4.1'
# default release datetime for branches under active development is set
# to be a time far-far-away-into-the-future
__release_datetime__ = '2099-10-13 08:56:12'
diff --git a/requirements/audio/audio_asr.txt b/requirements/audio/audio_asr.txt
index 2c9a201f..2709d960 100644
--- a/requirements/audio/audio_asr.txt
+++ b/requirements/audio/audio_asr.txt
@@ -1,2 +1,2 @@
easyasr>=0.0.2
-funasr>=0.2.2
+funasr>=0.3.0
diff --git a/requirements/audio/audio_kws.txt b/requirements/audio/audio_kws.txt
index 12b73bea..4118f3ed 100644
--- a/requirements/audio/audio_kws.txt
+++ b/requirements/audio/audio_kws.txt
@@ -1,5 +1,5 @@
kaldiio
-kwsbp>=0.0.2
+kwsbp>=0.0.6
matplotlib
numpy
py_sound_connect>=0.1
diff --git a/requirements/audio/audio_signal.txt b/requirements/audio/audio_signal.txt
index 6082a2e1..61e688f3 100644
--- a/requirements/audio/audio_signal.txt
+++ b/requirements/audio/audio_signal.txt
@@ -1,5 +1,5 @@
hyperpyyaml
-librosa
+librosa<=0.9.2
MinDAEC
mir_eval>=0.7
numpy
diff --git a/requirements/audio/audio_tts.txt b/requirements/audio/audio_tts.txt
index b9974294..b1a85faf 100644
--- a/requirements/audio/audio_tts.txt
+++ b/requirements/audio/audio_tts.txt
@@ -2,7 +2,8 @@ bitstring
greenlet>=1.1.2
inflect
jedi>=0.18.1
-librosa
+kantts
+librosa<=0.9.2
lxml
matplotlib
msgpack>=1.0.4
@@ -22,6 +23,6 @@ sox
tensorboardx
tqdm
traitlets>=5.3.0
-ttsfrd>=0.1.1
+ttsfrd>=0.1.2
unidecode
wcwidth>=0.2.5
diff --git a/requirements/cv.txt b/requirements/cv.txt
index d505bff4..c33d5b96 100644
--- a/requirements/cv.txt
+++ b/requirements/cv.txt
@@ -9,6 +9,7 @@ ddpm_guided_diffusion
diffusers
easydict
easyrobust
+edit_distance
face_alignment>=1.3.5
fairscale>=0.4.1
fastai>=1.0.51
@@ -34,11 +35,11 @@ networkx
numba
omegaconf
onnx
-onnx-simplifier
onnxruntime>=1.10
+onnxsim
open-clip-torch>=2.7.0
opencv-python
-pai-easycv>=0.8
+pai-easycv>=0.8,<0.10.0
paint_ldm
pandas
panopticapi
@@ -59,6 +60,7 @@ torchmetrics>=0.6.2
torchsummary>=1.5.1
torchvision
transformers>=4.26.0
+trimesh
ujson
utils
videofeatures_clipit>=1.0
diff --git a/requirements/framework.txt b/requirements/framework.txt
index 9a6a8998..d701b860 100644
--- a/requirements/framework.txt
+++ b/requirements/framework.txt
@@ -5,7 +5,7 @@ einops
filelock>=3.3.0
gast>=0.2.2
jsonplus
-numpy
+numpy<1.24.0
oss2
Pillow>=6.2.0
# pyarrow 9.0.0 introduced event_loop core dump
diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt
index 8a86be8e..49b79f2c 100644
--- a/requirements/multi-modal.txt
+++ b/requirements/multi-modal.txt
@@ -1,7 +1,7 @@
accelerate
diffusers>=0.11.1
ftfy>=6.0.3
-librosa
+librosa<=0.9.2
opencv-python
pycocoevalcap>=1.2
pycocotools>=2.0.4
@@ -12,11 +12,13 @@ rapidfuzz
# which introduced compatability issues that are being investigated
rouge_score<=0.0.4
sacrebleu
+# scikit-video
soundfile
taming-transformers-rom1504
timm
tokenizers
torchvision
transformers>=4.12.0
+# triton==2.0.0.dev20221120
unicodedata2
zhconv
diff --git a/setup.py b/setup.py
index 011a4796..9affe028 100644
--- a/setup.py
+++ b/setup.py
@@ -204,7 +204,7 @@ if __name__ == '__main__':
author_email='modelscope@list.alibaba-inc.com',
keywords='python,nlp,science,cv,speech,multi-modal',
url='https://github.com/modelscope/modelscope',
- packages=find_packages(exclude=('configs', 'tools', 'demo')),
+ packages=find_packages(exclude=('configs', 'demo')),
include_package_data=True,
package_data={
'': ['*.h', '*.cpp', '*.cu'],
diff --git a/tests/cli/test_custom_pipeline_cmd.py b/tests/cli/test_custom_pipeline_cmd.py
new file mode 100644
index 00000000..5682c123
--- /dev/null
+++ b/tests/cli/test_custom_pipeline_cmd.py
@@ -0,0 +1,23 @@
+import os
+import shutil
+import subprocess
+import tempfile
+import unittest
+import uuid
+
+
+class CustomPipelineCMDTest(unittest.TestCase):
+
+ def setUp(self):
+ self.task_name = 'task-%s' % (uuid.uuid4().hex)
+ print(self.task_name)
+
+ def test_create_custom_pipeline(self):
+ cmd = f'python -m modelscope.cli.cli pipeline --action create --task_name {self.task_name} '
+ stat, output = subprocess.getstatusoutput(cmd)
+ if stat != 0:
+ print(output)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/cli/test_modelcard_cmd.py b/tests/cli/test_modelcard_cmd.py
new file mode 100644
index 00000000..3484895b
--- /dev/null
+++ b/tests/cli/test_modelcard_cmd.py
@@ -0,0 +1,54 @@
+import os
+import os.path as osp
+import shutil
+import subprocess
+import tempfile
+import unittest
+import uuid
+
+from modelscope.hub.api import HubApi
+from modelscope.utils.test_utils import TEST_ACCESS_TOKEN1, TEST_MODEL_ORG
+
+
+class ModelUploadCMDTest(unittest.TestCase):
+
+ def setUp(self):
+ print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+ self.tmp_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(self.tmp_dir):
+ os.makedirs(self.tmp_dir)
+ print(self.tmp_dir)
+ self.api = HubApi()
+ self.api.login(TEST_ACCESS_TOKEN1)
+ self.task_name = 'task-%s' % (uuid.uuid4().hex)
+ self.model_name = 'op-%s' % (uuid.uuid4().hex)
+ self.model_id = '%s/%s' % (TEST_MODEL_ORG, self.model_name)
+ print(self.tmp_dir, self.task_name, self.model_name)
+
+ def tearDown(self):
+ self.api.delete_model(model_id=self.model_id)
+ shutil.rmtree(self.tmp_dir)
+ super().tearDown()
+
+ def test_upload_modelcard(self):
+ cmd = f'python -m modelscope.cli.cli pipeline --action create --task_name {self.task_name} ' \
+ f'--save_file_path {self.tmp_dir} --configuration_path {self.tmp_dir}'
+ stat, output = subprocess.getstatusoutput(cmd)
+ if stat != 0:
+ print(output)
+
+ cmd = f'python {self.tmp_dir}/ms_wrapper.py'
+ stat, output = subprocess.getstatusoutput(cmd)
+ if stat != 0:
+ print(output)
+ self.assertEqual(stat, 0)
+
+ cmd = f'python -m modelscope.cli.cli modelcard --action upload -tk {TEST_ACCESS_TOKEN1} ' \
+ f'--model_id {self.model_id} --model_dir {self.tmp_dir}'
+ stat, output = subprocess.getstatusoutput(cmd)
+ if stat != 0:
+ print(output)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/cli/test_plugins_cmd.py b/tests/cli/test_plugins_cmd.py
new file mode 100644
index 00000000..b11c67ab
--- /dev/null
+++ b/tests/cli/test_plugins_cmd.py
@@ -0,0 +1,50 @@
+import subprocess
+import unittest
+
+from modelscope.utils.plugins import PluginsManager
+
+
+class PluginsCMDTest(unittest.TestCase):
+
+ def setUp(self):
+ print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+ self.package = 'adaseq'
+ self.plugins_manager = PluginsManager()
+
+ def tearDown(self):
+ super().tearDown()
+
+ def test_plugins_install(self):
+ cmd = f'python -m modelscope.cli.cli plugin install {self.package}'
+ stat, output = subprocess.getstatusoutput(cmd)
+ self.assertEqual(stat, 0)
+
+ # moved out of tearDown to avoid an unexpected uninstall
+ uninstall_args = [self.package, '-y']
+ self.plugins_manager.uninstall_plugins(uninstall_args)
+
+ def test_plugins_uninstall(self):
+ # moved out of tearDown to avoid an unexpected uninstall
+ uninstall_args = [self.package, '-y']
+ self.plugins_manager.uninstall_plugins(uninstall_args)
+
+ cmd = f'python -m modelscope.cli.cli plugin install {self.package}'
+ stat, output = subprocess.getstatusoutput(cmd)
+ self.assertEqual(stat, 0)
+
+ cmd = f'python -m modelscope.cli.cli plugin uninstall {self.package}'
+ stat, output = subprocess.getstatusoutput(cmd)
+ self.assertEqual(stat, 0)
+
+ # moved out of tearDown to avoid an unexpected uninstall
+ uninstall_args = [self.package, '-y']
+ self.plugins_manager.uninstall_plugins(uninstall_args)
+
+ def test_plugins_list(self):
+ cmd = 'python -m modelscope.cli.cli plugin list'
+ stat, output = subprocess.getstatusoutput(cmd)
+ self.assertEqual(stat, 0)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/export/test_export_object_detection_damoyolo.py b/tests/export/test_export_object_detection_damoyolo.py
new file mode 100644
index 00000000..d7e51165
--- /dev/null
+++ b/tests/export/test_export_object_detection_damoyolo.py
@@ -0,0 +1,32 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+from collections import OrderedDict
+
+from modelscope.exporters import Exporter
+from modelscope.models import Model
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class TestExportObjectDetectionDamoyolo(unittest.TestCase):
+
+ def setUp(self):
+ print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+ self.tmp_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(self.tmp_dir):
+ os.makedirs(self.tmp_dir)
+ self.model_id = 'damo/cv_tinynas_object-detection_damoyolo'
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_export_object_detection_damoyolo(self):
+
+ model = Model.from_pretrained(self.model_id)
+ Exporter.from_model(model).export_onnx(
+ input_shape=(1, 3, 640, 640), output_dir=self.tmp_dir)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/export/test_export_token_classification.py b/tests/export/test_export_token_classification.py
new file mode 100644
index 00000000..951f3616
--- /dev/null
+++ b/tests/export/test_export_token_classification.py
@@ -0,0 +1,41 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+from collections import OrderedDict
+
+from modelscope.exporters import Exporter, TorchModelExporter
+from modelscope.models import Model
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class TestExportTokenClassification(unittest.TestCase):
+
+ def setUp(self):
+ print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+ self.tmp_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(self.tmp_dir):
+ os.makedirs(self.tmp_dir)
+ self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
+
+ def tearDown(self):
+ shutil.rmtree(self.tmp_dir)
+ super().tearDown()
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_export_token_classification(self):
+ model = Model.from_pretrained(self.model_id)
+ with self.subTest(format='onnx'):
+ print(
+ Exporter.from_model(model).export_onnx(
+ output_dir=self.tmp_dir))
+ with self.subTest(format='torchscript'):
+ print(
+ Exporter.from_model(model).export_torch_script(
+ output_dir=self.tmp_dir))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/msdatasets/test_custom_datasets_compatibility.py b/tests/msdatasets/test_custom_datasets_compatibility.py
new file mode 100644
index 00000000..f9cd7fa1
--- /dev/null
+++ b/tests/msdatasets/test_custom_datasets_compatibility.py
@@ -0,0 +1,76 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import unittest
+
+from datasets import Dataset
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.msdatasets.dataset_cls.custom_datasets import \
+ TorchCustomDataset
+from modelscope.preprocessors import Preprocessor
+from modelscope.trainers.trainer import EpochBasedTrainer
+from modelscope.utils import logger as logging
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModeKeys, ModelFile, Tasks
+from modelscope.utils.test_utils import test_level
+
+logger = logging.get_logger()
+
+
+class TestDummyEpochBasedTrainer(EpochBasedTrainer):
+
+ def __init__(self,
+ dataset: Dataset = None,
+ mode: str = ModeKeys.TRAIN,
+ preprocessor: Preprocessor = None,
+ **kwargs):
+ super(TestDummyEpochBasedTrainer, self).__init__(**kwargs)
+ self.train_dataset = self.to_task_dataset(dataset, mode, preprocessor)
+
+ def to_task_dataset(self, dataset: Dataset, mode: str,
+ preprocessor: Preprocessor,
+ **kwargs) -> TorchCustomDataset:
+ src_dataset_dict = {
+ 'src_txt': [
+ 'This is test sentence1-1', 'This is test sentence2-1',
+ 'This is test sentence3-1'
+ ]
+ }
+ dataset = Dataset.from_dict(src_dataset_dict)
+ dataset_res = TorchCustomDataset(
+ datasets=dataset, mode=mode, preprocessor=preprocessor)
+ dataset_res.trainer = self
+ return dataset_res
+
+
+class TestCustomDatasetsCompatibility(unittest.TestCase):
+
+ def setUp(self):
+ self.task = Tasks.movie_scene_segmentation
+ self.model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'
+
+ cache_path = snapshot_download(self.model_id)
+ self.config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
+ self.cfg = Config.from_file(self.config_path)
+
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_adaseq_import_task_datasets(self):
+ from modelscope.msdatasets.task_datasets.torch_base_dataset import TorchTaskDataset
+ from modelscope.msdatasets.task_datasets import GoproImageDeblurringDataset
+ from modelscope.msdatasets.task_datasets import RedsImageDeblurringDataset
+ from modelscope.msdatasets.task_datasets import SiddImageDenoisingDataset
+ from modelscope.msdatasets.task_datasets import VideoSummarizationDataset
+
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_adaseq_trainer_overwrite(self):
+ test_trainer = TestDummyEpochBasedTrainer(cfg_file=self.config_path)
+
+ assert isinstance(test_trainer.train_dataset.trainer,
+ TestDummyEpochBasedTrainer)
+ assert test_trainer.train_dataset.mode == ModeKeys.TRAIN
+ assert isinstance(test_trainer.train_dataset._inner_dataset, Dataset)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/msdatasets/test_ms_dataset.py b/tests/msdatasets/test_ms_dataset.py
index 51074bca..8ded9a46 100644
--- a/tests/msdatasets/test_ms_dataset.py
+++ b/tests/msdatasets/test_ms_dataset.py
@@ -3,12 +3,16 @@ import hashlib
import os
import unittest
+from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.msdatasets import MsDataset
-from modelscope.msdatasets.audio.asr_dataset import ASRDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets.audio.asr_dataset import \
+ ASRDataset
from modelscope.preprocessors import TextClassificationTransformersPreprocessor
from modelscope.preprocessors.base import Preprocessor
-from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE, DownloadMode
+from modelscope.utils.config import Config
+from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE, DownloadMode,
+ ModelFile)
from modelscope.utils.test_utils import require_tf, require_torch, test_level
@@ -68,6 +72,7 @@ class MsDatasetTest(unittest.TestCase):
ms_ds_train = MsDataset.load('movie_scene_seg_toydata', split='train')
print(ms_ds_train._hf_ds.config_kwargs)
assert next(iter(ms_ds_train.config_kwargs['split_config'].values()))
+ assert next(iter(ms_ds_train))
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_coco(self):
@@ -260,6 +265,34 @@ class MsDatasetTest(unittest.TestCase):
print(data_example)
assert data_example.values()
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_to_custom_dataset_movie_scene_toydata(self):
+ from modelscope.msdatasets.dataset_cls.custom_datasets.movie_scene_segmentation import \
+ MovieSceneSegmentationDataset
+ from modelscope.msdatasets.dataset_cls.dataset import ExternalDataset
+
+ model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'
+ cache_path = snapshot_download(model_id)
+ config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
+ cfg = Config.from_file(config_path)
+
+ # ds_test.ds_instance is a 'MovieSceneSegmentationDataset' object when custom_cfg is not None.
+ ds_test_1 = MsDataset.load(
+ 'modelscope/movie_scene_seg_toydata',
+ split='test',
+ custom_cfg=cfg,
+ test_mode=True)
+ assert ds_test_1.is_custom
+ assert isinstance(ds_test_1.ds_instance, MovieSceneSegmentationDataset)
+
+ # ds_test.ds_instance is an 'ExternalDataset' object when custom_cfg is None (the default).
+ ds_test_2 = MsDataset.load(
+ 'modelscope/movie_scene_seg_toydata',
+ split='test',
+ custom_cfg=None)
+ assert not ds_test_2.is_custom
+ assert isinstance(ds_test_2.ds_instance, ExternalDataset)
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/pipelines/adaseq_pipelines/test_named_entity_recognition.py b/tests/pipelines/adaseq_pipelines/test_named_entity_recognition.py
deleted file mode 100644
index 4a0af955..00000000
--- a/tests/pipelines/adaseq_pipelines/test_named_entity_recognition.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-import os
-import unittest
-
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-from modelscope.utils.demo_utils import DemoCompatibilityCheck
-from modelscope.utils.test_utils import test_level
-
-
-class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
-
- def setUp(self):
- os.system('pip install adaseq>=0.5.0')
-
- @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
- def test_run_span_based_ner_pipeline(self):
- pipeline_ins = pipeline(
- Tasks.named_entity_recognition,
- 'damo/nlp_nested-ner_named-entity-recognition_chinese-base-med')
- print(
- pipeline_ins(
- '1、可测量目标: 1周内胸闷缓解。2、下一步诊疗措施:1.心内科护理常规,一级护理,低盐低脂饮食,留陪客。'
- '2.予“阿司匹林肠溶片”抗血小板聚集,“呋塞米、螺内酯”利尿减轻心前负荷,“瑞舒伐他汀”调脂稳定斑块,“厄贝沙坦片片”降血压抗心机重构'
- ))
diff --git a/tests/pipelines/plugin_remote_pipelines/__init__.py b/tests/pipelines/plugin_remote_pipelines/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/pipelines/plugin_remote_pipelines/test_allow_remote_model.py b/tests/pipelines/plugin_remote_pipelines/test_allow_remote_model.py
new file mode 100644
index 00000000..0453cf64
--- /dev/null
+++ b/tests/pipelines/plugin_remote_pipelines/test_allow_remote_model.py
@@ -0,0 +1,37 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.pipelines import pipeline
+from modelscope.utils.plugins import PluginsManager
+from modelscope.utils.test_utils import test_level
+
+
+class AllowRemoteModelTest(unittest.TestCase):
+
+ def setUp(self):
+ self.package = 'moviepy'
+
+ def tearDown(self):
+ # make sure the package is uninstalled after the test installs it
+ uninstall_args = [self.package, '-y']
+ PluginsManager.pip_command('uninstall', uninstall_args)
+ super().tearDown()
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_bilibili_image(self):
+
+ model_path = snapshot_download(
+ 'bilibili/cv_bilibili_image-super-resolution', revision='v1.0.5')
+ file_path = f'{model_path}/demos/title-compare1.png'
+ weight_path = f'{model_path}/weights_v3/up2x-latest-denoise3x.pth'
+ inference = pipeline(
+ 'image-super-resolution',
+ model='bilibili/cv_bilibili_image-super-resolution',
+ weight_path=weight_path,
+ device='cpu',
+ half=False) # can be set to True in a GPU environment
+
+ output = inference(file_path, tile_mode=0, cache_mode=1, alpha=1)
+ print(output)
diff --git a/tests/pipelines/plugin_remote_pipelines/test_plugin_model.py b/tests/pipelines/plugin_remote_pipelines/test_plugin_model.py
new file mode 100644
index 00000000..40124dac
--- /dev/null
+++ b/tests/pipelines/plugin_remote_pipelines/test_plugin_model.py
@@ -0,0 +1,45 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.plugins import PluginsManager
+from modelscope.utils.test_utils import test_level
+
+
+class PluginModelTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self):
+ self.package = 'adaseq'
+
+ def tearDown(self):
+ # make sure the package is uninstalled after the test installs it
+ uninstall_args = [self.package, '-y']
+ PluginsManager.pip_command('uninstall', uninstall_args)
+ super().tearDown()
+ import subprocess
+ result = subprocess.run(
+ ['pip', 'install', 'adaseq>=0.6.2', '--no-deps'],
+ stdout=subprocess.PIPE)
+ print(result.stdout.decode('utf-8'))
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_span_based_ner_pipeline(self):
+ pipeline_ins = pipeline(
+ Tasks.named_entity_recognition,
+ 'damo/nlp_nested-ner_named-entity-recognition_chinese-base-med')
+ print(
+ pipeline_ins(
+ '1、可测量目标: 1周内胸闷缓解。2、下一步诊疗措施:1.心内科护理常规,一级护理,低盐低脂饮食,留陪客。'
+ '2.予“阿司匹林肠溶片”抗血小板聚集,“呋塞米、螺内酯”利尿减轻心前负荷,“瑞舒伐他汀”调脂稳定斑块,“厄贝沙坦片片”降血压抗心机重构'
+ ))
+
+ def test_maoe_pipelines(self):
+ pipeline_ins = pipeline(
+ Tasks.named_entity_recognition,
+ 'damo/nlp_maoe_named-entity-recognition_chinese-base-general')
+ print(
+ pipeline_ins(
+ '刘培强,男,生理年龄40岁(因为在太空中进入休眠状态),实际年龄52岁,领航员国际空间站中的中国航天员,机械工程专家,军人,军衔中校。'
+ ))
diff --git a/tests/pipelines/test_base.py b/tests/pipelines/test_base.py
index 75aa8afc..32bf8877 100644
--- a/tests/pipelines/test_base.py
+++ b/tests/pipelines/test_base.py
@@ -179,13 +179,13 @@ class CustomPipelineTest(unittest.TestCase):
img_url = 'data/test/images/dogs.jpg'
output = pipe(img_url)
self.assertEqual(output['filename'], img_url)
- self.assertEqual(output[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3))
+ self.assertEqual(output[OutputKeys.OUTPUT_IMG].shape, (598, 600, 3))
outputs = pipe([img_url for i in range(4)])
self.assertEqual(len(outputs), 4)
for out in outputs:
self.assertEqual(out['filename'], img_url)
- self.assertEqual(out[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3))
+ self.assertEqual(out[OutputKeys.OUTPUT_IMG].shape, (598, 600, 3))
if __name__ == '__main__':
diff --git a/tests/pipelines/test_disco_guided_diffusion.py b/tests/pipelines/test_disco_guided_diffusion.py
new file mode 100644
index 00000000..d7be7292
--- /dev/null
+++ b/tests/pipelines/test_disco_guided_diffusion.py
@@ -0,0 +1,46 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+import cv2
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class DiscoGuidedDiffusionTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.task = Tasks.text_to_image_synthesis
+ self.model_id1 = 'yyqoni/yinyueqin_test'
+ self.model_id2 = 'yyqoni/yinyueqin_cyberpunk'
+
+ test_input1 = '夕阳西下'
+ test_input2 = '城市,赛博朋克'
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run(self):
+ diffusers_pipeline = pipeline(
+ task=self.task, model=self.model_id1, model_revision='v1.0')
+ output = diffusers_pipeline({
+ 'text': self.test_input1,
+ 'height': 256,
+ 'width': 256
+ })
+ cv2.imwrite('output1.png', output['output_imgs'][0])
+ print('Image saved to output1.png')
+
+ diffusers_pipeline = pipeline(
+ task=self.task, model=self.model_id2, model_revision='v1.0')
+ output = diffusers_pipeline({
+ 'text': self.test_input2,
+ 'height': 256,
+ 'width': 256
+ })
+ cv2.imwrite('output2.png', output['output_imgs'][0])
+ print('Image saved to output2.png')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_generative_multi_modal_embedding.py b/tests/pipelines/test_generative_multi_modal_embedding.py
index 7061d736..18b96f65 100644
--- a/tests/pipelines/test_generative_multi_modal_embedding.py
+++ b/tests/pipelines/test_generative_multi_modal_embedding.py
@@ -13,7 +13,7 @@ class GEMMMultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self) -> None:
self.task = Tasks.generative_multi_modal_embedding
- self.model_id = 'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'
+ self.model_id = 'damo/multi-modal_rleg-vit-large-patch14'
test_input = {
'image': 'data/test/images/generative_multimodal.jpg',
diff --git a/tests/pipelines/test_gpt3_text_generation.py b/tests/pipelines/test_gpt3_text_generation.py
index 7f7722b5..1d938384 100644
--- a/tests/pipelines/test_gpt3_text_generation.py
+++ b/tests/pipelines/test_gpt3_text_generation.py
@@ -27,6 +27,11 @@ class TextGPT3GenerationTest(unittest.TestCase):
pipe = pipeline(Tasks.text_generation, model=self.model_id_2_7B)
print(pipe(self.input))
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_gpt3_1_3B_with_args(self):
+ pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B)
+ print(pipe(self.input, top_p=0.9, temperature=0.9, max_length=32))
+
@unittest.skip('distributed gpt3 13B, skipped')
def test_gpt3_13B(self):
""" The model can be downloaded from the link on
diff --git a/tests/pipelines/test_human_reconstruction.py b/tests/pipelines/test_human_reconstruction.py
new file mode 100644
index 00000000..9b856958
--- /dev/null
+++ b/tests/pipelines/test_human_reconstruction.py
@@ -0,0 +1,46 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+import sys
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+sys.path.append('.')
+
+
+class HumanReconstructionTest(unittest.TestCase):
+
+ def setUp(self) -> None:
+ self.task = Tasks.human_reconstruction
+ self.model_id = 'damo/cv_hrnet_image-human-reconstruction'
+ self.test_image = 'data/test/images/human_reconstruction.jpg'
+
+ def pipeline_inference(self, pipeline: Pipeline, input_location: str):
+ result = pipeline(input_location)
+ mesh = result[OutputKeys.OUTPUT]
+ print(
+ f'Output to {osp.abspath("human_reconstruction.obj")}, vertices num: {mesh["vertices"].shape}'
+ )
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_run_by_direct_model_download(self):
+ model_dir = snapshot_download(self.model_id)
+ human_reconstruction = pipeline(
+ Tasks.human_reconstruction, model=model_dir)
+ print('running')
+ self.pipeline_inference(human_reconstruction, self.test_image)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_modelhub(self):
+ human_reconstruction = pipeline(
+ Tasks.human_reconstruction, model=self.model_id)
+ self.pipeline_inference(human_reconstruction, self.test_image)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_image_depth_estimation_bts.py b/tests/pipelines/test_image_depth_estimation_bts.py
new file mode 100644
index 00000000..bda7a41f
--- /dev/null
+++ b/tests/pipelines/test_image_depth_estimation_bts.py
@@ -0,0 +1,54 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+import cv2
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class ImageDepthEstimationBtsTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.task = Tasks.image_depth_estimation
+ self.model_id = 'damo/cv_densenet161_image-depth-estimation_bts'
+ self.image = 'data/test/images/image_depth_estimation_kitti_007517.png'
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_with_model_from_modelhub(self):
+ model = Model.from_pretrained(self.model_id)
+ pipeline_bts = pipeline(task=self.task, model=model)
+ result = pipeline_bts(input=self.image)
+ depth_vis = result[OutputKeys.DEPTHS_COLOR]
+ cv2.imwrite('result_modelhub.jpg', depth_vis)
+ print('Test run with model from modelhub ok.')
+
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_run_with_model_name(self):
+ pipeline_bts = pipeline(task=self.task, model=self.model_id)
+ result = pipeline_bts(input=self.image)
+ depth_vis = result[OutputKeys.DEPTHS_COLOR]
+ cv2.imwrite('result_modelname.jpg', depth_vis)
+ print('Test run with model name ok.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_run_by_direct_model_download(self):
+ cache_path = snapshot_download(self.model_id)
+ pipeline_bts = pipeline(self.task, model=cache_path)
+ result = pipeline_bts(input=self.image)
+ depth_vis = result[OutputKeys.DEPTHS_COLOR]
+ cv2.imwrite('result_snapshot.jpg', depth_vis)
+ print('Test run with snapshot ok.')
+
+ @unittest.skip('demo compatibility test is only enabled on a needed-basis')
+ def test_demo_compatibility(self):
+ self.compatibility_check()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_image_open_vocabulary_detection.py b/tests/pipelines/test_image_open_vocabulary_detection.py
index 28ae636e..52dc1d11 100644
--- a/tests/pipelines/test_image_open_vocabulary_detection.py
+++ b/tests/pipelines/test_image_open_vocabulary_detection.py
@@ -45,7 +45,7 @@ class ImageOpenVocabularyDetectionTest(unittest.TestCase,
logger.info('degrade tensorflow finished')
return super().tearDown()
- @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
vild_pipeline = pipeline(task=self.task, model=model)
diff --git a/tests/pipelines/test_image_quality_assessment_man.py b/tests/pipelines/test_image_quality_assessment_man.py
new file mode 100644
index 00000000..2668d45d
--- /dev/null
+++ b/tests/pipelines/test_image_quality_assessment_man.py
@@ -0,0 +1,56 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.cv import ImageQualityAssessmentMANPipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class ImageQualityAssessmentMANTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.task = Tasks.image_quality_assessment_mos
+ self.model_id = 'damo/cv_man_image-quality-assessment'
+ self.test_img = 'data/test/images/dogs.jpg'
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_run_by_direct_model_download(self):
+ cache_path = snapshot_download(self.model_id)
+ pipeline = ImageQualityAssessmentMANPipeline(cache_path)
+ pipeline.group_key = self.task
+ out_path = pipeline(input=self.test_img)[OutputKeys.SCORE]
+ print('pipeline: the out_path is {}'.format(out_path))
+
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_run_with_model_from_modelhub(self):
+ model = Model.from_pretrained(self.model_id)
+ pipeline_ins = pipeline(
+ task=Tasks.image_quality_assessment_mos, model=model)
+ out_path = pipeline_ins(input=self.test_img)[OutputKeys.SCORE]
+ print('pipeline: the out_path is {}'.format(out_path))
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_with_model_name(self):
+ pipeline_ins = pipeline(
+ task=Tasks.image_quality_assessment_mos, model=self.model_id)
+ out_path = pipeline_ins(input=self.test_img)[OutputKeys.SCORE]
+ print('pipeline: the out_path is {}'.format(out_path))
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_run_with_default_model(self):
+ pipeline_ins = pipeline(task=Tasks.image_quality_assessment_mos)
+ out_path = pipeline_ins(input=self.test_img)[OutputKeys.SCORE]
+ print('pipeline: the out_path is {}'.format(out_path))
+
+ @unittest.skip('demo compatibility test is only enabled on a needed-basis')
+ def test_demo_compatibility(self):
+ self.compatibility_check()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py
index 85f3370f..13f7a308 100644
--- a/tests/pipelines/test_key_word_spotting.py
+++ b/tests/pipelines/test_key_word_spotting.py
@@ -180,6 +180,14 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
'model_id': 'damo/speech_charctc_kws_phone-xiaoyun',
'wav_path': 'data/test/audios/kws_xiaoyunxiaoyun.wav',
'keywords': '小云小云'
+ }, {
+ 'model_id': 'damo/speech_charctc_kws_phone-speechcommands',
+ 'wav_path': 'data/test/audios/kws_xiaoyunxiaoyun.wav',
+ 'keywords': '小云小云'
+ }, {
+ 'model_id': 'damo/speech_charctc_kws_phone-wenwen',
+ 'wav_path': 'data/test/audios/kws_xiaoyunxiaoyun.wav',
+ 'keywords': '小云小云'
}]
def setUp(self) -> None:
@@ -330,10 +338,11 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
wav_path = item['wav_path']
keywords = item['keywords']
- logger.info('run with model_id:' + model_id)
+ logger.info('run with model_id:' + model_id + ' with keywords:'
+ + keywords)
kws_result = self.run_pipeline(
model_id=model_id, audio_in=wav_path, keywords=keywords)
- self.check_result('test_run_with_all_models', kws_result)
+ logger.info(ColorCodes.YELLOW + str(kws_result) + ColorCodes.END)
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
diff --git a/tests/pipelines/test_key_word_spotting_farfield.py b/tests/pipelines/test_key_word_spotting_farfield.py
index 69d6a953..e736f48b 100644
--- a/tests/pipelines/test_key_word_spotting_farfield.py
+++ b/tests/pipelines/test_key_word_spotting_farfield.py
@@ -7,6 +7,8 @@ from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
+OUTPUT_WAV = 'output.wav'
+
TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
TEST_SPEECH_FILE_MONO = 'data/test/audios/1ch_nihaomiya.wav'
TEST_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \
@@ -17,6 +19,8 @@ class KWSFarfieldTest(unittest.TestCase):
def setUp(self) -> None:
self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'
+ if os.path.isfile(OUTPUT_WAV):
+ os.remove(OUTPUT_WAV)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_normal(self):
@@ -25,6 +29,16 @@ class KWSFarfieldTest(unittest.TestCase):
self.assertEqual(len(result['kws_list']), 5)
print(result['kws_list'][-1])
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_output(self):
+ kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
+ result = kws(
+ os.path.join(os.getcwd(), TEST_SPEECH_FILE),
+ output_file=OUTPUT_WAV)
+ self.assertEqual(len(result['kws_list']), 5)
+ self.assertTrue(os.path.exists(OUTPUT_WAV))
+ print(result['kws_list'][-1])
+
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_mono(self):
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
diff --git a/tests/pipelines/test_language_identification.py b/tests/pipelines/test_language_identification.py
index a17cd439..ddd91e69 100644
--- a/tests/pipelines/test_language_identification.py
+++ b/tests/pipelines/test_language_identification.py
@@ -13,7 +13,8 @@ class LanguageIdentificationTest(unittest.TestCase, DemoCompatibilityCheck):
self.task = Tasks.text_classification
self.model_id = 'damo/nlp_language_identification-classification-base'
- @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ @unittest.skipUnless(test_level() >= 0,
+ 'skip test case in current test level')
def test_run_with_model_name_for_en2de(self):
inputs = 'Elon Musk, co-founder and chief executive officer of Tesla Motors.\n' \
'Gleichzeitig nahm die Legion an der Befriedung Algeriens teil, die von.\n' \
@@ -21,7 +22,8 @@ class LanguageIdentificationTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline_ins = pipeline(self.task, model=self.model_id)
print(pipeline_ins(input=inputs))
- @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ @unittest.skipUnless(test_level() >= 0,
+ 'skip test case in current test level')
def test_demo_compatibility(self):
self.compatibility_check()
diff --git a/tests/pipelines/test_lineless_table_recognition.py b/tests/pipelines/test_lineless_table_recognition.py
new file mode 100644
index 00000000..53fde8a1
--- /dev/null
+++ b/tests/pipelines/test_lineless_table_recognition.py
@@ -0,0 +1,44 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import unittest
+
+import cv2
+import numpy as np
+
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class TableRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.model_id = 'damo/cv_resnet-transformer_table-structure-recognition_lore'
+ self.test_image = 'data/test/images/lineless_table_recognition.jpg'
+ self.task = Tasks.lineless_table_recognition
+
+ def pipeline_inference(self, pipe: Pipeline, input_location: str):
+ result = pipe(input_location)
+ print('lineless table recognition results: ')
+ print(result)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_with_model_from_modelhub(self):
+ lineless_table_recognition = pipeline(
+ Tasks.lineless_table_recognition, model=self.model_id)
+ self.pipeline_inference(lineless_table_recognition, self.test_image)
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_run_modelhub_default_model(self):
+ lineless_table_recognition = pipeline(Tasks.lineless_table_recognition)
+ self.pipeline_inference(lineless_table_recognition, self.test_image)
+
+ @unittest.skip('demo compatibility test is only enabled on a needed-basis')
+ def test_demo_compatibility(self):
+ self.compatibility_check()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_movie_scene_segmentation.py b/tests/pipelines/test_movie_scene_segmentation.py
index affd5140..0ac8b716 100644
--- a/tests/pipelines/test_movie_scene_segmentation.py
+++ b/tests/pipelines/test_movie_scene_segmentation.py
@@ -1,8 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import tempfile
import unittest
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
+from modelscope.trainers import build_trainer
+from modelscope.utils.config import Config, ConfigDict
+from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
@@ -13,6 +20,12 @@ class MovieSceneSegmentationTest(unittest.TestCase, DemoCompatibilityCheck):
self.task = Tasks.movie_scene_segmentation
self.model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'
+ cache_path = snapshot_download(self.model_id)
+ config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
+ self.cfg = Config.from_file(config_path)
+
+ self.tmp_dir = tempfile.TemporaryDirectory().name
+
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_movie_scene_segmentation(self):
input_location = 'data/test/videos/movie_scene_segmentation_test_video.mp4'
@@ -24,6 +37,81 @@ class MovieSceneSegmentationTest(unittest.TestCase, DemoCompatibilityCheck):
else:
raise ValueError('process error')
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_movie_scene_segmentation_finetune(self):
+
+ train_data_cfg = ConfigDict(
+ name='movie_scene_seg_toydata',
+ split='train',
+ cfg=self.cfg.preprocessor,
+ test_mode=False)
+
+ train_dataset = MsDataset.load(
+ dataset_name=train_data_cfg.name,
+ split=train_data_cfg.split,
+ cfg=train_data_cfg.cfg,
+ test_mode=train_data_cfg.test_mode)
+
+ test_data_cfg = ConfigDict(
+ name='movie_scene_seg_toydata',
+ split='test',
+ cfg=self.cfg.preprocessor,
+ test_mode=True)
+
+ test_dataset = MsDataset.load(
+ dataset_name=test_data_cfg.name,
+ split=test_data_cfg.split,
+ cfg=test_data_cfg.cfg,
+ test_mode=test_data_cfg.test_mode)
+
+ kwargs = dict(
+ model=self.model_id,
+ train_dataset=train_dataset,
+ eval_dataset=test_dataset,
+ work_dir=self.tmp_dir)
+
+ trainer = build_trainer(
+ name=Trainers.movie_scene_segmentation, default_args=kwargs)
+ trainer.train()
+ results_files = os.listdir(trainer.work_dir)
+ print(results_files)
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_movie_scene_segmentation_finetune_with_custom_dataset(self):
+
+ data_cfg = ConfigDict(
+ dataset_name='movie_scene_seg_toydata',
+ namespace='modelscope',
+ train_split='train',
+ test_split='test',
+ model_cfg=self.cfg)
+
+ train_dataset = MsDataset.load(
+ dataset_name=data_cfg.dataset_name,
+ namespace=data_cfg.namespace,
+ split=data_cfg.train_split,
+ custom_cfg=data_cfg.model_cfg,
+ test_mode=False)
+
+ test_dataset = MsDataset.load(
+ dataset_name=data_cfg.dataset_name,
+ namespace=data_cfg.namespace,
+ split=data_cfg.test_split,
+ custom_cfg=data_cfg.model_cfg,
+ test_mode=True)
+
+ kwargs = dict(
+ model=self.model_id,
+ train_dataset=train_dataset,
+ eval_dataset=test_dataset,
+ work_dir=self.tmp_dir)
+
+ trainer = build_trainer(
+ name=Trainers.movie_scene_segmentation, default_args=kwargs)
+ trainer.train()
+ results_files = os.listdir(trainer.work_dir)
+ print(results_files)
+
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_movie_scene_segmentation_with_default_task(self):
input_location = 'data/test/videos/movie_scene_segmentation_test_video.mp4'
diff --git a/tests/pipelines/test_nerf_recon_acc.py b/tests/pipelines/test_nerf_recon_acc.py
index bc5ad1b2..95d879fb 100644
--- a/tests/pipelines/test_nerf_recon_acc.py
+++ b/tests/pipelines/test_nerf_recon_acc.py
@@ -8,7 +8,7 @@ from modelscope.hub.snapshot_download import snapshot_download
from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
+from modelscope.utils.constant import DownloadMode, Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
@@ -18,8 +18,11 @@ class NeRFReconAccTest(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self) -> None:
self.model_id = 'damo/cv_nerf-3d-reconstruction-accelerate_damo'
data_dir = MsDataset.load(
- 'nerf_recon_dataset', namespace='damo',
- split='train').config_kwargs['split_config']['train']
+ 'nerf_recon_dataset',
+ namespace='damo',
+ split='train',
+ download_mode=DownloadMode.FORCE_REDOWNLOAD
+ ).config_kwargs['split_config']['train']
nerf_synthetic_dataset = os.path.join(data_dir, 'nerf_synthetic')
blender_scene = 'lego'
self.data_dir = os.path.join(nerf_synthetic_dataset, blender_scene)
diff --git a/tests/pipelines/test_nli.py b/tests/pipelines/test_nli.py
index 9d985d25..a7d2a236 100644
--- a/tests/pipelines/test_nli.py
+++ b/tests/pipelines/test_nli.py
@@ -18,9 +18,12 @@ class NLITest(unittest.TestCase, DemoCompatibilityCheck):
self.task = Tasks.nli
self.model_id = 'damo/nlp_structbert_nli_chinese-base'
self.model_id_fact_checking = 'damo/nlp_structbert_fact-checking_chinese-base'
+ self.model_id_peer = 'damo/nlp_peer_mnli_english-base'
sentence1 = '四川商务职业学院和四川财经职业学院哪个好?'
sentence2 = '四川商务职业学院商务管理在哪个校区?'
+ en_sentence1 = 'Conceptually cream skimming has two basic dimensions - product and geography.'
+ en_sentence2 = 'Product and geography are what make cream skimming work.'
regress_tool = MsRegressTool(baseline=False)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@@ -61,6 +64,15 @@ class NLITest(unittest.TestCase, DemoCompatibilityCheck):
model_revision='v1.0.1')
print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_peer_model(self):
+ pipeline_ins = pipeline(
+ task=Tasks.nli,
+ model=self.model_id_peer,
+ model_revision='v1.0.0',
+ )
+ print(pipeline_ins(input=(self.en_sentence1, self.en_sentence2)))
+
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.nli)
diff --git a/tests/pipelines/test_ocr_recognition.py b/tests/pipelines/test_ocr_recognition.py
index 372a4bc4..145ae22a 100644
--- a/tests/pipelines/test_ocr_recognition.py
+++ b/tests/pipelines/test_ocr_recognition.py
@@ -26,7 +26,7 @@ class OCRRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
ocr_recognition = pipeline(
Tasks.ocr_recognition,
model=self.model_id,
- model_revision='v1.0.0')
+ model_revision='v2.2.1')
self.pipeline_inference(ocr_recognition, self.test_image)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@@ -34,14 +34,14 @@ class OCRRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
ocr_recognition = pipeline(
Tasks.ocr_recognition,
model=self.model_id,
- model_revision='v1.0.0')
+ model_revision='v2.2.1')
imagePIL = PIL.Image.open(self.test_image)
self.pipeline_inference(ocr_recognition, imagePIL)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
ocr_recognition = pipeline(
- Tasks.ocr_recognition, model_revision='v2.0.0')
+ Tasks.ocr_recognition, model_revision='v2.3.0')
self.pipeline_inference(ocr_recognition, self.test_image)
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
diff --git a/tests/pipelines/test_realtime_video_object_detection.py b/tests/pipelines/test_realtime_video_object_detection.py
index d65313a3..716c9260 100644
--- a/tests/pipelines/test_realtime_video_object_detection.py
+++ b/tests/pipelines/test_realtime_video_object_detection.py
@@ -37,6 +37,22 @@ class RealtimeVideoObjectDetectionTest(unittest.TestCase,
else:
raise ValueError('process error')
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_longshortnet(self):
+ model_id = 'damo/cv_cspnet_video-object-detection_longshortnet'
+ test_video = 'data/test/videos/test_realtime_vod.mp4'
+ realtime_video_object_detection = pipeline(
+ Tasks.video_object_detection, model=model_id)
+ result = realtime_video_object_detection(test_video)
+ if result:
+ logger.info('Video output to test_vod_results.avi')
+ show_video_object_detection_result(test_video,
+ result[OutputKeys.BOXES],
+ result[OutputKeys.LABELS],
+ 'test_vod_results.avi')
+ else:
+ raise ValueError('process error')
+
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.compatibility_check()
diff --git a/tests/pipelines/test_salient_detection.py b/tests/pipelines/test_salient_detection.py
index bcb904e6..3101213c 100644
--- a/tests/pipelines/test_salient_detection.py
+++ b/tests/pipelines/test_salient_detection.py
@@ -23,6 +23,27 @@ class SalientDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
import cv2
cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS])
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_salient_boundary_detection(self):
+ input_location = 'data/test/images/image_salient_detection.jpg'
+ model_id = 'damo/cv_res2net_salient-detection'
+ salient_detect = pipeline(Tasks.semantic_segmentation, model=model_id)
+ result = salient_detect(input_location)
+ import cv2
+ cv2.imwrite(input_location + '_boundary_salient.jpg',
+ result[OutputKeys.MASKS])
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_camouflaged_detection(self):
+ input_location = 'data/test/images/image_camouflag_detection.jpg'
+ model_id = 'damo/cv_res2net_camouflaged-detection'
+ camouflaged_detect = pipeline(
+ Tasks.semantic_segmentation, model=model_id)
+ result = camouflaged_detect(input_location)
+ import cv2
+ cv2.imwrite(input_location + '_camouflaged.jpg',
+ result[OutputKeys.MASKS])
+
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.compatibility_check()
diff --git a/tests/pipelines/test_sentence_similarity.py b/tests/pipelines/test_sentence_similarity.py
index 846b72c3..233bd3a1 100644
--- a/tests/pipelines/test_sentence_similarity.py
+++ b/tests/pipelines/test_sentence_similarity.py
@@ -1,8 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
+import torch
+from packaging import version
+
from modelscope.hub.snapshot_download import snapshot_download
-from modelscope.models import Model
+from modelscope.models import Model, TorchModel
from modelscope.models.nlp import SbertForSequenceClassification
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TextClassificationPipeline
@@ -90,6 +93,18 @@ class SentenceSimilarityTest(unittest.TestCase, DemoCompatibilityCheck):
model_revision='v1.0.0')
print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+ @unittest.skipIf(
+ version.parse(torch.__version__) < version.parse('2.0.0.dev'),
+ 'skip when torch version < 2.0')
+ def test_compile(self):
+ pipeline_ins = pipeline(
+ task=Tasks.sentence_similarity,
+ model=self.model_id_retail,
+ model_revision='v1.0.0',
+ compile=True)
+ print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+ self.assertTrue(isinstance(pipeline_ins.model._orig_mod, TorchModel))
+
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.sentence_similarity)
diff --git a/tests/pipelines/test_siamese_uie.py b/tests/pipelines/test_siamese_uie.py
index 9097813c..30b38d2e 100644
--- a/tests/pipelines/test_siamese_uie.py
+++ b/tests/pipelines/test_siamese_uie.py
@@ -31,12 +31,12 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck):
tokenizer = SiameseUiePreprocessor(cache_path)
model = SiameseUieModel.from_pretrained(cache_path)
pipeline1 = SiameseUiePipeline(
- model, preprocessor=tokenizer, model_revision='v1.0')
+ model, preprocessor=tokenizer, model_revision='v1.1')
pipeline2 = pipeline(
Tasks.siamese_uie,
model=model,
preprocessor=tokenizer,
- model_revision='v1.0')
+ model_revision='v1.1')
print(
f'sentence: {self.sentence}\n'
@@ -53,18 +53,18 @@ class ZeroShotClassificationTest(unittest.TestCase, DemoCompatibilityCheck):
task=Tasks.siamese_uie,
model=model,
preprocessor=tokenizer,
- model_revision='v1.0')
+ model_revision='v1.1')
print(pipeline_ins(input=self.sentence, schema=self.schema))
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
- task=Tasks.siamese_uie, model=self.model_id, model_revision='v1.0')
+ task=Tasks.siamese_uie, model=self.model_id, model_revision='v1.1')
print(pipeline_ins(input=self.sentence, schema=self.schema))
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
- pipeline_ins = pipeline(task=Tasks.siamese_uie, model_revision='v1.0')
+ pipeline_ins = pipeline(task=Tasks.siamese_uie, model_revision='v1.1')
print(pipeline_ins(input=self.sentence, schema=self.schema))
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
diff --git a/tests/pipelines/test_soonet_video_temporal_grounding.py b/tests/pipelines/test_soonet_video_temporal_grounding.py
new file mode 100644
index 00000000..21f8027c
--- /dev/null
+++ b/tests/pipelines/test_soonet_video_temporal_grounding.py
@@ -0,0 +1,34 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+import unittest
+
+from modelscope.models import Model
+from modelscope.models.multi_modal.soonet import SOONet
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class SOONetVideoTemporalGroundingTest(unittest.TestCase,
+ DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.task = Tasks.video_temporal_grounding
+ self.model_id = 'damo/multi-modal_soonet_video-temporal-grounding'
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_modelhub(self):
+ soonet_pipeline = pipeline(self.task, self.model_id)
+ result = soonet_pipeline(
+ ('a man takes food out of the refrigerator.',
+ 'soonet_video_temporal_grounding_test_video.mp4'))
+ print(f'soonet output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_load_model_from_pretrained(self):
+ model = Model.from_pretrained(self.model_id)
+ self.assertTrue(model.__class__ == SOONet)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py
index 2916d31a..2c26cee6 100644
--- a/tests/pipelines/test_speech_signal_process.py
+++ b/tests/pipelines/test_speech_signal_process.py
@@ -4,6 +4,7 @@ import os.path
import unittest
from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
@@ -17,6 +18,8 @@ FAREND_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \
'test/audios/farend_speech.wav'
NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav'
+NOISE_SPEECH_FILE_48K = 'data/test/audios/speech_with_noise_48k.wav'
+NOISE_SPEECH_FILE_48K_PCM = 'data/test/audios/speech_with_noise_48k.pcm'
NOISE_SPEECH_URL = 'https://modelscope.oss-cn-beijing.aliyuncs.com/' \
'test/audios/speech_with_noise.wav'
@@ -83,7 +86,7 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
print(f'Processed audio saved to {output_path}')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
- def test_ans(self):
+ def test_frcrn_ans(self):
model_id = 'damo/speech_frcrn_ans_cirm_16k'
ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id)
output_path = os.path.abspath('output.wav')
@@ -112,6 +115,41 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
ans(data, output_path=output_path)
print(f'Processed audio saved to {output_path}')
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_dfsmn_ans(self):
+ model_id = 'damo/speech_dfsmn_ans_psm_48k_causal'
+ ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id)
+ output_path = os.path.abspath('output.wav')
+ ans(os.path.join(os.getcwd(), NOISE_SPEECH_FILE_48K),
+ output_path=output_path)
+ print(f'Processed audio saved to {output_path}')
+
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_dfsmn_ans_bytes(self):
+ model_id = 'damo/speech_dfsmn_ans_psm_48k_causal'
+ ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id)
+ output_path = os.path.abspath('output.wav')
+ with open(os.path.join(os.getcwd(), NOISE_SPEECH_FILE_48K), 'rb') as f:
+ data = f.read()
+ ans(data, output_path=output_path)
+ print(f'Processed audio saved to {output_path}')
+
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_dfsmn_ans_stream(self):
+ model_id = 'damo/speech_dfsmn_ans_psm_48k_causal'
+ ans = pipeline(
+ Tasks.acoustic_noise_suppression, model=model_id, stream_mode=True)
+ with open(os.path.join(os.getcwd(), NOISE_SPEECH_FILE_48K_PCM),
+ 'rb') as f:
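+ # feed the noisy PCM in 3840-byte blocks (40 ms of 48 kHz 16-bit mono audio, assuming that raw format) until the file is exhausted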
+ block_size = 3840
+ audio = f.read(block_size)
+ with open('output.pcm', 'wb') as w:
+ while len(audio) >= block_size:
+ result = ans(audio)
+ pcm = result[OutputKeys.OUTPUT_PCM]
+ w.write(pcm)
+ audio = f.read(block_size)
+
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.compatibility_check()
diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py
index cbb1b29b..a729d4da 100644
--- a/tests/pipelines/test_text_generation.py
+++ b/tests/pipelines/test_text_generation.py
@@ -67,6 +67,17 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
self.run_pipeline_with_model_id(self.palm_model_id_zh_base,
self.palm_input_zh)
+ @unittest.skipUnless(test_level() >= -1, 'skip test in current test level')
+ def test_palm_zh_base_with_model_name_with_args(self):
+ self.run_pipeline_with_model_id(
+ self.palm_model_id_zh_base,
+ self.palm_input_zh,
+ run_kwargs={
+ 'top_p': 0.9,
+ 'temperature': 0.9,
+ 'max_length': 64
+ })
+
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_palm_zh_base_with_model_name_batch(self):
self.run_pipeline_with_model_id(
@@ -95,6 +106,17 @@ class TextGenerationTest(unittest.TestCase, DemoCompatibilityCheck):
self.run_pipeline_with_model_id(self.gpt3_base_model_id,
self.gpt3_input)
+ @unittest.skipUnless(test_level() >= -1, 'skip test in current test level')
+ def test_gpt_base_with_model_name_with_args(self):
+ self.run_pipeline_with_model_id(
+ self.gpt3_base_model_id,
+ self.gpt3_input,
+ run_kwargs={
+ 'top_p': 0.9,
+ 'temperature': 0.9,
+ 'max_length': 64
+ })
+
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_gpt_base_with_model_name_batch(self):
self.run_pipeline_with_model_id(
diff --git a/tests/pipelines/test_text_to_video_synthesis.py b/tests/pipelines/test_text_to_video_synthesis.py
new file mode 100644
index 00000000..6463c155
--- /dev/null
+++ b/tests/pipelines/test_text_to_video_synthesis.py
@@ -0,0 +1,36 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import unittest
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class TextToVideoSynthesisTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.task = Tasks.text_to_video_synthesis
+ self.model_id = 'damo/text-to-video-synthesis'
+
+ test_text = {
+ 'text': 'A panda eating bamboo on a rock.',
+ }
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_with_model_from_modelhub(self):
+ pipe_line_text_to_video_synthesis = pipeline(
+ task=self.task, model=self.model_id)
+ output_video_path = pipe_line_text_to_video_synthesis(
+ self.test_text)[OutputKeys.OUTPUT_VIDEO]
+ print(output_video_path)
+
+ @unittest.skip('demo compatibility test is only enabled on a needed-basis')
+ def test_demo_compatibility(self):
+ self.compatibility_check()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_tinynas_detection.py b/tests/pipelines/test_tinynas_detection.py
index 4c3735dc..f7c513ff 100644
--- a/tests/pipelines/test_tinynas_detection.py
+++ b/tests/pipelines/test_tinynas_detection.py
@@ -196,6 +196,28 @@ class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
OutputKeys.LABELS in result) and (OutputKeys.BOXES in result)
print('results: ', result)
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_smokefire_detection_damoyolo(self):
+ tinynas_object_detection = pipeline(
+ Tasks.domain_specific_object_detection,
+ model='damo/cv_tinynas_object-detection_damoyolo_smokefire')
+ result = tinynas_object_detection(
+ 'data/test/images/image_smokefire_detection.jpg')
+ assert result and (OutputKeys.SCORES in result) and (
+ OutputKeys.LABELS in result) and (OutputKeys.BOXES in result)
+ print('results: ', result)
+
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_smokefire_detection_damoyolo_with_image(self):
+ tinynas_object_detection = pipeline(
+ Tasks.domain_specific_object_detection,
+ model='damo/cv_tinynas_object-detection_damoyolo_smokefire')
+ img = Image.open('data/test/images/image_smokefire_detection.jpg')
+ result = tinynas_object_detection(img)
+ assert result and (OutputKeys.SCORES in result) and (
+ OutputKeys.LABELS in result) and (OutputKeys.BOXES in result)
+ print('results: ', result)
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/pipelines/test_video_instance_segmentation.py b/tests/pipelines/test_video_instance_segmentation.py
new file mode 100644
index 00000000..0a76d260
--- /dev/null
+++ b/tests/pipelines/test_video_instance_segmentation.py
@@ -0,0 +1,42 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class VideoInstanceSegmentationTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.task = Tasks.video_panoptic_segmentation
+ self.model_id = 'damo/cv_swinb_video-instance-segmentation'
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_modelhub(self):
+ video_path = 'data/test/videos/kitti-step_testing_image_02_0000.mp4'
+ seg_pipeline = pipeline(
+ Tasks.video_instance_segmentation,
+ model=self.model_id,
+ max_video_frames=20)
+ result = seg_pipeline(video_path)
+
+ print(f'video instance segmentation output: \n{result}.')
+
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_run_modelhub_default_model(self):
+ video_path = 'data/test/videos/kitti-step_testing_image_02_0000.mp4'
+ seg_pipeline = pipeline(
+ Tasks.video_instance_segmentation, max_video_frames=20)
+ result = seg_pipeline(video_path)
+
+ print(f'video instance segmentation output:\n {result}.')
+
+ @unittest.skip('demo compatibility test is only enabled on a needed-basis')
+ def test_demo_compatibility(self):
+ self.compatibility_check()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_video_single_object_tracking.py b/tests/pipelines/test_video_single_object_tracking.py
index 7f3a9226..e75ccbb0 100644
--- a/tests/pipelines/test_video_single_object_tracking.py
+++ b/tests/pipelines/test_video_single_object_tracking.py
@@ -14,6 +14,7 @@ class SingleObjectTracking(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self) -> None:
self.task = Tasks.video_single_object_tracking
self.model_id = 'damo/cv_vitb_video-single-object-tracking_ostrack'
+ self.model_id_procontext = 'damo/cv_vitb_video-single-object-tracking_procontext'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_end2end(self):
@@ -26,6 +27,16 @@ class SingleObjectTracking(unittest.TestCase, DemoCompatibilityCheck):
show_video_tracking_result(video_path, result[OutputKeys.BOXES],
'./tracking_result.avi')
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_end2end_procontext(self):
+ video_single_object_tracking = pipeline(
+ Tasks.video_single_object_tracking, model=self.model_id_procontext)
+ video_path = 'data/test/videos/dog.avi'
+ init_bbox = [414, 343, 514, 449] # [x1, y1, x2, y2]
+ result = video_single_object_tracking((video_path, init_bbox))
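+ # presumably one tracked box per frame of the 139-frame dog.avi clip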
+ assert OutputKeys.BOXES in result.keys() and len(
+ result[OutputKeys.BOXES]) == 139
+
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_modelhub_default_model(self):
video_single_object_tracking = pipeline(
diff --git a/tests/pipelines/test_vidt_face.py b/tests/pipelines/test_vidt_face.py
new file mode 100644
index 00000000..8640d128
--- /dev/null
+++ b/tests/pipelines/test_vidt_face.py
@@ -0,0 +1,31 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+import unittest
+
+from modelscope.models import Model
+from modelscope.models.cv.vidt import VidtModel
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class VidtTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.task = Tasks.image_object_detection
+ self.model_id = 'damo/ViDT-face-detection'
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_pipeline(self):
+ vidt_pipeline = pipeline(self.task, self.model_id)
+ result = vidt_pipeline('data/test/images/vidt_test1.jpg')
+ print(f'Vidt output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_load_model_from_pretrained(self):
+ model = Model.from_pretrained('damo/ViDT-face-detection')
+ self.assertTrue(model.__class__ == VidtModel)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_vidt_logo.py b/tests/pipelines/test_vidt_logo.py
new file mode 100644
index 00000000..143eb205
--- /dev/null
+++ b/tests/pipelines/test_vidt_logo.py
@@ -0,0 +1,31 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+import unittest
+
+from modelscope.models import Model
+from modelscope.models.cv.vidt import VidtModel
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class VidtTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.task = Tasks.image_object_detection
+ self.model_id = 'damo/ViDT-logo-detection'
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_pipeline(self):
+ vidt_pipeline = pipeline(self.task, self.model_id)
+ result = vidt_pipeline('data/test/images/vidt_test1.jpg')
+ print(f'Vidt output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_load_model_from_pretrained(self):
+ model = Model.from_pretrained('damo/ViDT-logo-detection')
+ self.assertTrue(model.__class__ == VidtModel)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_vision_efficient_tuning.py b/tests/pipelines/test_vision_efficient_tuning.py
new file mode 100644
index 00000000..c88ed478
--- /dev/null
+++ b/tests/pipelines/test_vision_efficient_tuning.py
@@ -0,0 +1,154 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+import unittest
+
+from modelscope.models import Model
+from modelscope.models.cv.vision_efficient_tuning.model import \
+ VisionEfficientTuningModel
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class VisionEfficientTuningTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.task = Tasks.vision_efficient_tuning
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_adapter_run_pipeline(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter'
+ img_path = 'data/test/images/vision_efficient_tuning_test_1.png'
+ petl_pipeline = pipeline(self.task, model_id)
+ result = petl_pipeline(img_path)
+ print(f'Vision-efficient-tuning-adapter output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_vision_efficient_tuning_adapter_load_model_from_pretrained(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter'
+ model = Model.from_pretrained(model_id)
+ self.assertTrue(model.__class__ == VisionEfficientTuningModel)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_adapter_demo_compatibility(self):
+ self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter'
+ self.compatibility_check()
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_lora_run_pipeline(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora'
+ img_path = 'data/test/images/vision_efficient_tuning_test_1.png'
+ petl_pipeline = pipeline(self.task, model_id)
+ result = petl_pipeline(img_path)
+ print(f'Vision-efficient-tuning-lora output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_vision_efficient_tuning_lora_load_model_from_pretrained(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora'
+ model = Model.from_pretrained(model_id)
+ self.assertTrue(model.__class__ == VisionEfficientTuningModel)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_lora_demo_compatibility(self):
+ self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora'
+ self.compatibility_check()
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_prefix_run_pipeline(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix'
+ img_path = 'data/test/images/vision_efficient_tuning_test_1.png'
+ petl_pipeline = pipeline(self.task, model_id)
+ result = petl_pipeline(img_path)
+ print(f'Vision-efficient-tuning-prefix output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_vision_efficient_tuning_prefix_load_model_from_pretrained(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix'
+ model = Model.from_pretrained(model_id)
+ self.assertTrue(model.__class__ == VisionEfficientTuningModel)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_prefix_demo_compatibility(self):
+ self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix'
+ self.compatibility_check()
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_prompt_run_pipeline(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt'
+ img_path = 'data/test/images/vision_efficient_tuning_test_1.png'
+ petl_pipeline = pipeline(self.task, model_id)
+ result = petl_pipeline(img_path)
+ print(f'Vision-efficient-tuning-prompt output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_vision_efficient_tuning_prompt_load_model_from_pretrained(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt'
+ model = Model.from_pretrained(model_id)
+ self.assertTrue(model.__class__ == VisionEfficientTuningModel)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_prompt_demo_compatibility(self):
+ self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt'
+ self.compatibility_check()
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_bitfit_run_pipeline(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-bitfit'
+ img_path = 'data/test/images/vision_efficient_tuning_test_1.png'
+ petl_pipeline = pipeline(self.task, model_id)
+ result = petl_pipeline(img_path)
+ print(f'Vision-efficient-tuning-bitfit output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_vision_efficient_tuning_bitfit_load_model_from_pretrained(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-bitfit'
+ model = Model.from_pretrained(model_id)
+ self.assertTrue(model.__class__ == VisionEfficientTuningModel)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_bitfit_demo_compatibility(self):
+ self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-bitfit'
+ self.compatibility_check()
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_sidetuning_run_pipeline(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-sidetuning'
+ img_path = 'data/test/images/vision_efficient_tuning_test_1.png'
+ petl_pipeline = pipeline(self.task, model_id)
+ result = petl_pipeline(img_path)
+ print(f'Vision-efficient-tuning-sidetuning output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_vision_efficient_tuning_sidetuning_load_model_from_pretrained(
+ self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-sidetuning'
+ model = Model.from_pretrained(model_id)
+ self.assertTrue(model.__class__ == VisionEfficientTuningModel)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_sidetuning_demo_compatibility(self):
+ self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-sidetuning'
+ self.compatibility_check()
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_utuning_run_pipeline(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-utuning'
+ img_path = 'data/test/images/vision_efficient_tuning_test_1.png'
+ petl_pipeline = pipeline(self.task, model_id)
+ result = petl_pipeline(img_path)
+ print(f'Vision-efficient-tuning-utuning output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_vision_efficient_tuning_utuning_load_model_from_pretrained(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-utuning'
+ model = Model.from_pretrained(model_id)
+ self.assertTrue(model.__class__ == VisionEfficientTuningModel)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_utuning_demo_compatibility(self):
+ self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-utuning'
+ self.compatibility_check()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_vision_efficient_tuning_adapter.py b/tests/pipelines/test_vision_efficient_tuning_adapter.py
deleted file mode 100644
index 4a06a40a..00000000
--- a/tests/pipelines/test_vision_efficient_tuning_adapter.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
-import unittest
-
-from modelscope.models import Model
-from modelscope.models.cv.vision_efficient_tuning.vision_efficient_tuning import \
- VisionEfficientTuningModel
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-from modelscope.utils.demo_utils import DemoCompatibilityCheck
-from modelscope.utils.test_utils import test_level
-
-
-class VisionEfficientTuningAdapterTest(unittest.TestCase,
- DemoCompatibilityCheck):
-
- def setUp(self) -> None:
- self.task = Tasks.vision_efficient_tuning
- self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter'
-
- @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
- def test_run_pipeline(self):
-
- petl_pipeline = pipeline(self.task, self.model_id)
- result = petl_pipeline(
- 'data/test/images/vision_efficient_tuning_test_1.png')
-
- print(f'Vision-efficient-tuning-adapter output: {result}.')
-
- @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
- def test_load_model_from_pretrained(self):
- model = Model.from_pretrained(
- 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter')
- self.assertTrue(model.__class__ == VisionEfficientTuningModel)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/tests/pipelines/test_vision_efficient_tuning_lora.py b/tests/pipelines/test_vision_efficient_tuning_lora.py
deleted file mode 100644
index 6c49453a..00000000
--- a/tests/pipelines/test_vision_efficient_tuning_lora.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
-import unittest
-
-from modelscope.models import Model
-from modelscope.models.cv.vision_efficient_tuning.vision_efficient_tuning import \
- VisionEfficientTuningModel
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-from modelscope.utils.demo_utils import DemoCompatibilityCheck
-from modelscope.utils.test_utils import test_level
-
-
-class VisionEfficientTuningLoRATest(unittest.TestCase, DemoCompatibilityCheck):
-
- def setUp(self) -> None:
- self.task = Tasks.vision_efficient_tuning
- self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora'
-
- @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
- def test_run_pipeline(self):
-
- petl_pipeline = pipeline(self.task, self.model_id)
- result = petl_pipeline(
- 'data/test/images/vision_efficient_tuning_test_1.png')
-
- print(f'Vision-efficient-tuning-lora output: {result}.')
-
- @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
- def test_load_model_from_pretrained(self):
- model = Model.from_pretrained(
- 'damo/cv_vitb16_classification_vision-efficient-tuning-lora')
- self.assertTrue(model.__class__ == VisionEfficientTuningModel)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/tests/pipelines/test_vision_efficient_tuning_prefix.py b/tests/pipelines/test_vision_efficient_tuning_prefix.py
deleted file mode 100644
index 0eca5819..00000000
--- a/tests/pipelines/test_vision_efficient_tuning_prefix.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
-import unittest
-
-from modelscope.models import Model
-from modelscope.models.cv.vision_efficient_tuning.vision_efficient_tuning import \
- VisionEfficientTuningModel
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-from modelscope.utils.demo_utils import DemoCompatibilityCheck
-from modelscope.utils.test_utils import test_level
-
-
-class VisionEfficientTuningPrefixTest(unittest.TestCase,
- DemoCompatibilityCheck):
-
- def setUp(self) -> None:
- self.task = Tasks.vision_efficient_tuning
- self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix'
-
- @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
- def test_run_pipeline(self):
-
- petl_pipeline = pipeline(self.task, self.model_id)
- result = petl_pipeline(
- 'data/test/images/vision_efficient_tuning_test_1.png')
-
- print(f'Vision-efficient-tuning-prefix output: {result}.')
-
- @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
- def test_load_model_from_pretrained(self):
- model = Model.from_pretrained(
- 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix')
- self.assertTrue(model.__class__ == VisionEfficientTuningModel)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/tests/pipelines/test_vision_efficient_tuning_prompt.py b/tests/pipelines/test_vision_efficient_tuning_prompt.py
deleted file mode 100644
index 97d97811..00000000
--- a/tests/pipelines/test_vision_efficient_tuning_prompt.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
-import unittest
-
-from modelscope.models import Model
-from modelscope.models.cv.vision_efficient_tuning.vision_efficient_tuning import \
- VisionEfficientTuningModel
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-from modelscope.utils.demo_utils import DemoCompatibilityCheck
-from modelscope.utils.test_utils import test_level
-
-
-class VisionEfficientTuningPromptTest(unittest.TestCase,
- DemoCompatibilityCheck):
-
- def setUp(self) -> None:
- self.task = Tasks.vision_efficient_tuning
- self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt'
-
- @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
- def test_run_pipeline(self):
-
- petl_pipeline = pipeline(self.task, self.model_id)
- result = petl_pipeline(
- 'data/test/images/vision_efficient_tuning_test_1.png')
-
- print(f'Vision-efficient-tuning-prompt output: {result}.')
-
- @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
- def test_load_model_from_pretrained(self):
- model = Model.from_pretrained(
- 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt')
- self.assertTrue(model.__class__ == VisionEfficientTuningModel)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/tests/pipelines/test_vop_retrieval_sebias.py b/tests/pipelines/test_vop_retrieval_sebias.py
new file mode 100644
index 00000000..bea1bc45
--- /dev/null
+++ b/tests/pipelines/test_vop_retrieval_sebias.py
@@ -0,0 +1,36 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.models import Model
+from modelscope.models.cv.vop_retrieval import VideoTextRetrievalModelSeries
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class VopRetrievalTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.task = Tasks.vop_retrieval
+ # self.model_id = '../cv_vit-b32_retrieval_vop_bias'
+ self.model_id = 'damo/cv_vit-b32_retrieval_vop_bias'
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_modelhub(self):
+ vop_pipeline = pipeline(self.task, self.model_id)
+ # t2v
+ result = vop_pipeline('a squid is talking')
+ # v2t
+ # result = vop_pipeline('video10.mp4')
+ print(f'vop output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_load_model_from_pretrained(self):
+ # model = Model.from_pretrained('../cv_vit-b32_retrieval_vop_bias')
+ model = Model.from_pretrained('damo/cv_vit-b32_retrieval_vop_bias')
+ self.assertTrue(model.__class__ == VideoTextRetrievalModelSeries)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_vop_retrieval_separtial.py b/tests/pipelines/test_vop_retrieval_separtial.py
new file mode 100644
index 00000000..942fbd3b
--- /dev/null
+++ b/tests/pipelines/test_vop_retrieval_separtial.py
@@ -0,0 +1,36 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.models import Model
+from modelscope.models.cv.vop_retrieval import VideoTextRetrievalModelSeries
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class VopRetrievalTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.task = Tasks.vop_retrieval
+ # self.model_id = '../cv_vit-b32_retrieval_vop'
+ self.model_id = 'damo/cv_vit-b32_retrieval_vop_partial'
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_modelhub(self):
+ vop_pipeline = pipeline(self.task, self.model_id)
+ # t2v
+ result = vop_pipeline('a squid is talking')
+ # v2t
+ # result = vop_pipeline('video10.mp4')
+ print(f'vop output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_load_model_from_pretrained(self):
+ # model = Model.from_pretrained('../cv_vit-b32_retrieval_vop')
+ model = Model.from_pretrained('damo/cv_vit-b32_retrieval_vop_partial')
+ self.assertTrue(model.__class__ == VideoTextRetrievalModelSeries)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/pipelines/test_vop_retrieval_seproj.py b/tests/pipelines/test_vop_retrieval_seproj.py
new file mode 100644
index 00000000..a371ac36
--- /dev/null
+++ b/tests/pipelines/test_vop_retrieval_seproj.py
@@ -0,0 +1,36 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.models import Model
+from modelscope.models.cv.vop_retrieval import VideoTextRetrievalModelSeries
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+
+class VopRetrievalTest(unittest.TestCase, DemoCompatibilityCheck):
+
+ def setUp(self) -> None:
+ self.task = Tasks.vop_retrieval
+ # self.model_id = '../cv_vit-b32_retrieval_vop'
+ self.model_id = 'damo/cv_vit-b32_retrieval_vop_proj'
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_run_modelhub(self):
+ vop_pipeline = pipeline(self.task, self.model_id)
+ # t2v
+ result = vop_pipeline('a squid is talking')
+ # v2t
+ # result = vop_pipeline('video10.mp4')
+ print(f'vop output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+ def test_load_model_from_pretrained(self):
+ # model = Model.from_pretrained('../cv_vit-b32_retrieval_vop')
+ model = Model.from_pretrained('damo/cv_vit-b32_retrieval_vop_proj')
+ self.assertTrue(model.__class__ == VideoTextRetrievalModelSeries)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/run_analysis.py b/tests/run_analysis.py
index d6a526ac..ca0a0018 100644
--- a/tests/run_analysis.py
+++ b/tests/run_analysis.py
@@ -259,7 +259,7 @@ def get_test_suites_to_run():
affected_trainer_cases.extend(
model_trainer_map[model_id])
elif (affected_register_module[0] == 'HOOKS'
- or affected_register_module[0] == 'TASK_DATASETS'):
+ or affected_register_module[0] == 'CUSTOM_DATASETS'):
# ["HOOKS", "", "CheckpointHook", "CheckpointHook"]
# ["HOOKS", "", hook_name, class_name]
# modifying HOOKS or CUSTOM_DATASETS runs all trainer cases
diff --git a/tests/run_config.yaml b/tests/run_config.yaml
index e7466af6..773c6397 100644
--- a/tests/run_config.yaml
+++ b/tests/run_config.yaml
@@ -61,11 +61,13 @@ isolated: # test cases that may require an excessive amount of GPU memory or run for a long time
- test_bad_image_detecting.py
- test_image_portrait_stylization_trainer.py
- test_controllable_image_generation.py
+ - test_image_colorization_trainer.py
envs:
default: # default env (pytorch); cases not assigned to another env run here.
dependencies: # required packages, pip-installed before test cases run.
- numpy>=1.20
+ - protobuf<4,>=3.20.2
tensorflow1x: # cases executed with the tensorflow 1.x framework.
requirements: # requirements files installed before test cases run.
- tensorflow1x.txt
@@ -82,3 +84,4 @@ envs:
- test_skin_retouching.py
- test_image_style_transfer.py
- test_image_portrait_stylization_trainer.py
+ - test_language_identification.py
diff --git a/tests/taskdataset/test_veco_dataset.py b/tests/taskdataset/test_veco_dataset.py
index 76da1681..c220c363 100644
--- a/tests/taskdataset/test_veco_dataset.py
+++ b/tests/taskdataset/test_veco_dataset.py
@@ -2,7 +2,8 @@
import unittest
-from modelscope.msdatasets.task_datasets.veco_dataset import VecoDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets.veco_dataset import \
+ VecoDataset
from modelscope.utils.test_utils import test_level
diff --git a/tests/trainers/audio/test_kws_farfield_trainer.py b/tests/trainers/audio/test_kws_farfield_trainer.py
index 70b68a11..cc2b38f6 100644
--- a/tests/trainers/audio/test_kws_farfield_trainer.py
+++ b/tests/trainers/audio/test_kws_farfield_trainer.py
@@ -81,3 +81,5 @@ class TestKwsFarfieldTrainer(unittest.TestCase):
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files,
f'work_dir:{self.tmp_dir}')
+ self.assertIn('val_dataset.bin', results_files,
+ f'work_dir:{self.tmp_dir}')
diff --git a/tests/trainers/hooks/compression/test_sparsity_hook.py b/tests/trainers/hooks/compression/test_sparsity_hook.py
index d8dcc879..499d6cc1 100644
--- a/tests/trainers/hooks/compression/test_sparsity_hook.py
+++ b/tests/trainers/hooks/compression/test_sparsity_hook.py
@@ -92,7 +92,6 @@ class SparsityHookTest(unittest.TestCase):
train_dataloader = trainer._build_dataloader_with_dataset(
trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
trainer.register_optimizers_hook()
- trainer.register_hook_from_cfg(trainer.cfg.train.hooks)
trainer.train_dataloader = train_dataloader
trainer.data_loader = train_dataloader
trainer.invoke_hook(TrainerStages.before_run)
diff --git a/tests/trainers/hooks/test_lr_scheduler_hook.py b/tests/trainers/hooks/test_lr_scheduler_hook.py
index 9e1865d5..cd28b055 100644
--- a/tests/trainers/hooks/test_lr_scheduler_hook.py
+++ b/tests/trainers/hooks/test_lr_scheduler_hook.py
@@ -15,6 +15,7 @@ from modelscope.metainfo import Trainers
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.models.base import TorchModel
from modelscope.trainers import build_trainer
+from modelscope.trainers.default_config import merge_hooks
from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages
from modelscope.utils.registry import default_group
from modelscope.utils.test_utils import create_dummy_test_dataset
@@ -104,7 +105,10 @@ class LrSchedulerHookTest(unittest.TestCase):
train_dataloader = trainer._build_dataloader_with_dataset(
trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
trainer.register_optimizers_hook()
-
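+ # strip the default CheckpointHook/TextLoggerHook/IterTimerHook so they do not interfere with this LR scheduler check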
+ trainer._hooks = [
+ hook for hook in trainer._hooks if hook.__class__.__name__ not in
+ ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook']
+ ]
trainer.invoke_hook(TrainerStages.before_run)
log_lrs = []
optim_lrs = []
@@ -173,7 +177,10 @@ class LrSchedulerHookTest(unittest.TestCase):
train_dataloader = trainer._build_dataloader_with_dataset(
trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
trainer.register_optimizers_hook()
-
+ trainer._hooks = [
+ hook for hook in trainer._hooks if hook.__class__.__name__ not in
+ ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook']
+ ]
trainer.invoke_hook(TrainerStages.before_run)
log_lrs = []
optim_lrs = []
@@ -254,7 +261,10 @@ class LrSchedulerHookTest(unittest.TestCase):
train_dataloader = trainer._build_dataloader_with_dataset(
trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
trainer.register_optimizers_hook()
-
+ trainer._hooks = [
+ hook for hook in trainer._hooks if hook.__class__.__name__ not in
+ ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook']
+ ]
trainer.invoke_hook(TrainerStages.before_run)
log_lrs = []
optim_lrs = []
@@ -355,8 +365,10 @@ class PlateauLrSchedulerHookTest(unittest.TestCase):
trainer.train_dataloader = train_dataloader
trainer.data_loader = train_dataloader
trainer.register_optimizers_hook()
- trainer.register_hook_from_cfg(trainer.cfg.train.hooks)
-
+ trainer._hooks = [
+ hook for hook in trainer._hooks if hook.__class__.__name__ not in
+ ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook']
+ ]
trainer.invoke_hook(TrainerStages.before_run)
log_lrs = []
optim_lrs = []
diff --git a/tests/trainers/hooks/test_optimizer_hook.py b/tests/trainers/hooks/test_optimizer_hook.py
index 1bf9d292..b9899c36 100644
--- a/tests/trainers/hooks/test_optimizer_hook.py
+++ b/tests/trainers/hooks/test_optimizer_hook.py
@@ -80,7 +80,10 @@ class OptimizerHookTest(unittest.TestCase):
train_dataloader = trainer._build_dataloader_with_dataset(
trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
trainer.register_optimizers_hook()
-
+ trainer._hooks = [
+ hook for hook in trainer._hooks if hook.__class__.__name__ not in
+ ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook']
+ ]
trainer.invoke_hook(TrainerStages.before_run)
for _ in range(trainer._epoch, trainer._max_epochs):
@@ -147,7 +150,10 @@ class TorchAMPOptimizerHookTest(unittest.TestCase):
train_dataloader = trainer._build_dataloader_with_dataset(
trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
trainer.register_optimizers_hook()
-
+ trainer._hooks = [
+ hook for hook in trainer._hooks if hook.__class__.__name__ not in
+ ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook']
+ ]
trainer.invoke_hook(TrainerStages.before_run)
for _ in range(trainer._epoch, trainer._max_epochs):
@@ -223,6 +229,10 @@ class TorchApexOptimizerHookTest(unittest.TestCase):
trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
trainer.register_optimizers_hook()
trainer.register_hook_from_cfg([{'type': 'ApexAMPOptimizerHook'}])
+ trainer._hooks = [
+ hook for hook in trainer._hooks if hook.__class__.__name__ not in
+ ['CheckpointHook', 'TextLoggerHook', 'IterTimerHook']
+ ]
trainer.invoke_hook(TrainerStages.before_run)
for _ in range(trainer._epoch, trainer._max_epochs):
diff --git a/tests/trainers/hooks/test_timer_hook.py b/tests/trainers/hooks/test_timer_hook.py
index 62ed1262..9755becb 100644
--- a/tests/trainers/hooks/test_timer_hook.py
+++ b/tests/trainers/hooks/test_timer_hook.py
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import MultiStepLR
from modelscope.metainfo import Trainers
from modelscope.models.base import TorchModel
from modelscope.trainers import build_trainer
+from modelscope.trainers.default_config import merge_hooks
from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages
from modelscope.utils.test_utils import create_dummy_test_dataset
@@ -83,7 +84,6 @@ class IterTimerHookTest(unittest.TestCase):
train_dataloader = trainer._build_dataloader_with_dataset(
trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
trainer.register_optimizers_hook()
- trainer.register_hook_from_cfg(trainer.cfg.train.hooks)
trainer.train_dataloader = train_dataloader
trainer.data_loader = train_dataloader
trainer.invoke_hook(TrainerStages.before_run)
diff --git a/tests/trainers/test_action_detection_trainer.py b/tests/trainers/test_action_detection_trainer.py
new file mode 100644
index 00000000..96f02cf9
--- /dev/null
+++ b/tests/trainers/test_action_detection_trainer.py
@@ -0,0 +1,79 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+import unittest
+
+from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import DownloadMode
+from modelscope.utils.test_utils import test_level
+
+
+class TestActionDetectionTrainer(unittest.TestCase):
+
+ def setUp(self):
+ os.environ['OMP_NUM_THREADS'] = '1'
+ print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+ cmd_uninstall = ['pip', 'uninstall', '-y', 'detectron2']
+ cmd = [
+ 'pip', 'install', '--upgrade',
+ 'git+https://gitee.com/lllcho/detectron2.git'
+ ]
+ subprocess.run(cmd_uninstall)
+ subprocess.run(cmd)
+ import detectron2
+ print(f'Install detectron2 done, version {detectron2.__version__}')
+ self.model_id = 'damo/cv_ResNetC3D_action-detection_detection2d'
+
+ self.train_dataset = MsDataset.load(
+ 'lllcho/ivi_action',
+ subset_name='default',
+ split='train',
+ download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
+ self.tmp_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(self.tmp_dir):
+ os.makedirs(self.tmp_dir)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmp_dir)
+ super().tearDown()
+
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_trainer(self):
+
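+ # shrink the config to a tiny CPU run (5 iterations, batch size 1) so the trainer test finishes quickly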
+ def cfg_modify_fn(cfg):
+ cfg.train.max_iter = 5
+ cfg.train.dataloader.batch_size_per_gpu = 1
+ cfg.train.dataloader.workers_per_gpu = 1
+ cfg.train.optimizer.lr = 1e-4
+ cfg.train.lr_scheduler.warmup_step = 1
+ cfg.train.checkpoint_interval = 5000
+
+ cfg.evaluation.interval = 5000
+ cfg.evaluation.dataloader.batch_size_per_gpu = 1
+ cfg.evaluation.dataloader.workers_per_gpu = 1
+
+ cfg.train.work_dir = self.tmp_dir
+ cfg.train.num_gpus = 0
+ return cfg
+
+ trainer = build_trainer(
+ Trainers.action_detection,
+ dict(
+ model_id=self.model_id,
+ train_dataset=self.train_dataset,
+ test_dataset=self.train_dataset,
+ cfg_modify_fn=cfg_modify_fn))
+ trainer.train()
+ results_files = os.listdir(self.tmp_dir)
+ self.assertIn('config.py', results_files)
+ self.assertIn('model_final.pth', results_files)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/trainers/test_finetune_vision_efficient_tuning.py b/tests/trainers/test_finetune_vision_efficient_tuning.py
new file mode 100644
index 00000000..8719c64f
--- /dev/null
+++ b/tests/trainers/test_finetune_vision_efficient_tuning.py
@@ -0,0 +1,355 @@
+# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.test_utils import test_level
+
+
+class TestVisionEfficientTuningTrainer(unittest.TestCase):
+
+ def setUp(self):
+ print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+
+ self.train_dataset = MsDataset.load(
+ 'foundation_model_evaluation_benchmark',
+ namespace='damo',
+ subset_name='OxfordFlowers',
+ split='train')
+
+ self.eval_dataset = MsDataset.load(
+ 'foundation_model_evaluation_benchmark',
+ namespace='damo',
+ subset_name='OxfordFlowers',
+ split='eval')
+
+ self.max_epochs = 1
+ self.num_classes = 102
+ self.tune_length = 10
+
+ self.tmp_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(self.tmp_dir):
+ os.makedirs(self.tmp_dir)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmp_dir)
+ super().tearDown()
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_adapter_train(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter'
+
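+ # adapt the classification head to the 102 OxfordFlowers classes and run a single training epoch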
+ def cfg_modify_fn(cfg):
+ cfg.model.head.num_classes = self.num_classes
+ cfg.model.finetune = True
+ cfg.train.max_epochs = self.max_epochs
+ cfg.train.lr_scheduler.T_max = self.max_epochs
+ cfg.model.backbone.adapter_length = self.tune_length
+ return cfg
+
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=self.train_dataset,
+ eval_dataset=self.eval_dataset,
+ cfg_modify_fn=cfg_modify_fn)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ trainer.train()
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-adapter train output: {result}.')
+
+ results_files = os.listdir(self.tmp_dir)
+ self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+ for i in range(self.max_epochs):
+ self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_adapter_eval(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter'
+
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=None,
+ eval_dataset=self.eval_dataset)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-adapter eval output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_lora_train(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora'
+
+ def cfg_modify_fn(cfg):
+ cfg.model.head.num_classes = self.num_classes
+ cfg.model.finetune = True
+ cfg.train.max_epochs = self.max_epochs
+ cfg.train.lr_scheduler.T_max = self.max_epochs
+ cfg.model.backbone.lora_length = self.tune_length
+ return cfg
+
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=self.train_dataset,
+ eval_dataset=self.eval_dataset,
+ cfg_modify_fn=cfg_modify_fn)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ trainer.train()
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-lora train output: {result}.')
+
+ results_files = os.listdir(self.tmp_dir)
+ self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+ for i in range(self.max_epochs):
+ self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_lora_eval(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora'
+
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=None,
+ eval_dataset=self.eval_dataset)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-lora eval output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_prefix_train(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix'
+
+ def cfg_modify_fn(cfg):
+ cfg.model.head.num_classes = self.num_classes
+ cfg.model.finetune = True
+ cfg.train.max_epochs = self.max_epochs
+ cfg.train.lr_scheduler.T_max = self.max_epochs
+ cfg.model.backbone.prefix_length = self.tune_length
+ return cfg
+
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=self.train_dataset,
+ eval_dataset=self.eval_dataset,
+ cfg_modify_fn=cfg_modify_fn)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ trainer.train()
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-prefix train output: {result}.')
+
+ results_files = os.listdir(self.tmp_dir)
+ self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+ for i in range(self.max_epochs):
+ self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_prefix_eval(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix'
+
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=None,
+ eval_dataset=self.eval_dataset)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-prefix eval output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_prompt_train(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt'
+
+ def cfg_modify_fn(cfg):
+ cfg.model.head.num_classes = self.num_classes
+ cfg.model.finetune = True
+ cfg.train.max_epochs = self.max_epochs
+ cfg.train.lr_scheduler.T_max = self.max_epochs
+ cfg.model.backbone.prompt_length = self.tune_length
+ return cfg
+
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=self.train_dataset,
+ eval_dataset=self.eval_dataset,
+ cfg_modify_fn=cfg_modify_fn)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ trainer.train()
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-prompt train output: {result}.')
+
+ results_files = os.listdir(self.tmp_dir)
+ self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+ for i in range(self.max_epochs):
+ self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_prompt_eval(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt'
+
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=None,
+ eval_dataset=self.eval_dataset)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-prompt eval output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_bitfit_train(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-bitfit'
+
+ # model_id = '../modelcard/cv_vitb16_classification_vision-efficient-tuning-bitfit'
+ def cfg_modify_fn(cfg):
+ cfg.model.head.num_classes = self.num_classes
+ cfg.model.finetune = True
+ cfg.train.max_epochs = self.max_epochs
+ cfg.train.lr_scheduler.T_max = self.max_epochs
+ return cfg
+
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=self.train_dataset,
+ eval_dataset=self.eval_dataset,
+ cfg_modify_fn=cfg_modify_fn)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ trainer.train()
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-bitfit train output: {result}.')
+
+ results_files = os.listdir(self.tmp_dir)
+ self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+ for i in range(self.max_epochs):
+ self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_bitfit_eval(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-bitfit'
+ # model_id = '../modelcard/cv_vitb16_classification_vision-efficient-tuning-bitfit'
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=None,
+ eval_dataset=self.eval_dataset)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-bitfit eval output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_sidetuning_train(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-sidetuning'
+
+ def cfg_modify_fn(cfg):
+ cfg.model.head.num_classes = self.num_classes
+ cfg.model.finetune = True
+ cfg.train.max_epochs = self.max_epochs
+ cfg.train.lr_scheduler.T_max = self.max_epochs
+ return cfg
+
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=self.train_dataset,
+ eval_dataset=self.eval_dataset,
+ cfg_modify_fn=cfg_modify_fn)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ trainer.train()
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-sidetuning train output: {result}.')
+
+ results_files = os.listdir(self.tmp_dir)
+ self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+ for i in range(self.max_epochs):
+ self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_sidetuning_eval(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-sidetuning'
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=None,
+ eval_dataset=self.eval_dataset)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-sidetuning eval output: {result}.')
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_utuning_train(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-utuning'
+
+ def cfg_modify_fn(cfg):
+ cfg.model.head.num_classes = self.num_classes
+ cfg.model.finetune = True
+ cfg.train.max_epochs = self.max_epochs
+ cfg.train.lr_scheduler.T_max = self.max_epochs
+ return cfg
+
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=self.train_dataset,
+ eval_dataset=self.eval_dataset,
+ cfg_modify_fn=cfg_modify_fn)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ trainer.train()
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-utuning train output: {result}.')
+
+ results_files = os.listdir(self.tmp_dir)
+ self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+ for i in range(self.max_epochs):
+ self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_vision_efficient_tuning_utuning_eval(self):
+ model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-utuning'
+ kwargs = dict(
+ model=model_id,
+ work_dir=self.tmp_dir,
+ train_dataset=None,
+ eval_dataset=self.eval_dataset)
+
+ trainer = build_trainer(
+ name=Trainers.vision_efficient_tuning, default_args=kwargs)
+ result = trainer.evaluate()
+ print(f'Vision-efficient-tuning-utuning eval output: {result}.')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/trainers/test_image_colorization_trainer.py b/tests/trainers/test_image_colorization_trainer.py
new file mode 100644
index 00000000..0c736c4b
--- /dev/null
+++ b/tests/trainers/test_image_colorization_trainer.py
@@ -0,0 +1,94 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models.cv.image_colorization import DDColorForImageColorization
+from modelscope.msdatasets import MsDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets.image_colorization import \
+ ImageColorizationDataset
+from modelscope.pipelines import pipeline
+from modelscope.trainers import build_trainer
+from modelscope.utils.config import Config
+from modelscope.utils.constant import DownloadMode, ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import test_level
+
+logger = get_logger()
+
+
+class ImageColorizationTrainerTest(unittest.TestCase):
+
+ def setUp(self):
+ print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+ self.tmp_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(self.tmp_dir):
+ os.makedirs(self.tmp_dir)
+
+ self.model_id = 'damo/cv_ddcolor_image-colorization'
+ self.cache_path = snapshot_download(self.model_id)
+ self.config = Config.from_file(
+ os.path.join(self.cache_path, ModelFile.CONFIGURATION))
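+ # The same 5k-image ImageNet validation split serves as both training and
+ # evaluation data to keep this trainer test lightweight.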
+ dataset_train = MsDataset.load(
+ 'imagenet-val5k-image',
+ namespace='damo',
+ subset_name='default',
+ split='validation',
+ download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
+ dataset_val = MsDataset.load(
+ 'imagenet-val5k-image',
+ namespace='damo',
+ subset_name='default',
+ split='validation',
+ download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
+ self.dataset_train = ImageColorizationDataset(
+ dataset_train, self.config.dataset, is_train=True)
+ self.dataset_val = ImageColorizationDataset(
+ dataset_val, self.config.dataset, is_train=False)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmp_dir, ignore_errors=True)
+ super().tearDown()
+
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_trainer(self):
+ kwargs = dict(
+ model=self.model_id,
+ train_dataset=self.dataset_train,
+ eval_dataset=self.dataset_val,
+ work_dir=self.tmp_dir)
+ trainer = build_trainer(default_args=kwargs)
+ trainer.train()
+ results_files = os.listdir(self.tmp_dir)
+ self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+ for i in range(1):
+ self.assertIn(f'epoch_{i+1}.pth', results_files)
+ pipeline_colorization = pipeline(
+ task=Tasks.image_colorization, model=f'{self.tmp_dir}/output')
+ pipeline_colorization('data/test/images/marilyn_monroe_4.jpg')
+
+ @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ def test_trainer_with_model_and_args(self):
+ model = DDColorForImageColorization.from_pretrained(self.cache_path)
+ kwargs = dict(
+ cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION),
+ model=model,
+ train_dataset=self.dataset_train,
+ eval_dataset=self.dataset_val,
+ max_epochs=1,
+ work_dir=self.tmp_dir)
+ trainer = build_trainer(default_args=kwargs)
+ trainer.train()
+ results_files = os.listdir(self.tmp_dir)
+ self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+ for i in range(1):
+ self.assertIn(f'epoch_{i+1}.pth', results_files)
+ pipeline_colorization = pipeline(
+ task=Tasks.image_colorization, model=f'{self.tmp_dir}/output')
+ pipeline_colorization('data/test/images/marilyn_monroe_4.jpg')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/trainers/test_image_deblur_trainer.py b/tests/trainers/test_image_deblur_trainer.py
index 6ae88726..f07db1bb 100644
--- a/tests/trainers/test_image_deblur_trainer.py
+++ b/tests/trainers/test_image_deblur_trainer.py
@@ -7,7 +7,7 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.image_deblur import NAFNetForImageDeblur
from modelscope.msdatasets import MsDataset
-from modelscope.msdatasets.task_datasets.gopro_image_deblurring_dataset import \
+from modelscope.msdatasets.dataset_cls.custom_datasets.gopro_image_deblurring_dataset import \
GoproImageDeblurringDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
diff --git a/tests/trainers/test_image_denoise_trainer.py b/tests/trainers/test_image_denoise_trainer.py
index 3b5882bd..e2b65b32 100644
--- a/tests/trainers/test_image_denoise_trainer.py
+++ b/tests/trainers/test_image_denoise_trainer.py
@@ -7,7 +7,7 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.image_denoise import NAFNetForImageDenoise
from modelscope.msdatasets import MsDataset
-from modelscope.msdatasets.task_datasets.sidd_image_denoising import \
+from modelscope.msdatasets.dataset_cls.custom_datasets.sidd_image_denoising import \
SiddImageDenoisingDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
diff --git a/tests/trainers/test_image_instance_segmentation_trainer.py b/tests/trainers/test_image_instance_segmentation_trainer.py
index 03f7eea3..923eca2c 100644
--- a/tests/trainers/test_image_instance_segmentation_trainer.py
+++ b/tests/trainers/test_image_instance_segmentation_trainer.py
@@ -11,8 +11,6 @@ from modelscope.metainfo import Trainers
from modelscope.models.cv.image_instance_segmentation import \
CascadeMaskRCNNSwinModel
from modelscope.msdatasets import MsDataset
-from modelscope.msdatasets.task_datasets import \
- ImageInstanceSegmentationCocoDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import DownloadMode, ModelFile
diff --git a/tests/trainers/test_image_portrait_enhancement_trainer.py b/tests/trainers/test_image_portrait_enhancement_trainer.py
index a9fc74cb..b556a13b 100644
--- a/tests/trainers/test_image_portrait_enhancement_trainer.py
+++ b/tests/trainers/test_image_portrait_enhancement_trainer.py
@@ -1,21 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
-import os.path as osp
import shutil
import tempfile
import unittest
-from typing import Callable, List, Optional, Tuple, Union
-
-import cv2
-import torch
-from torch.utils import data as data
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.cv.image_portrait_enhancement import \
ImagePortraitEnhancement
from modelscope.msdatasets import MsDataset
-from modelscope.msdatasets.task_datasets.image_portrait_enhancement import \
+from modelscope.msdatasets.dataset_cls.custom_datasets.image_portrait_enhancement import \
ImagePortraitEnhancementDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, ModelFile
diff --git a/tests/trainers/test_language_guided_video_summarization_trainer.py b/tests/trainers/test_language_guided_video_summarization_trainer.py
index 3ff0e102..2673e4b9 100644
--- a/tests/trainers/test_language_guided_video_summarization_trainer.py
+++ b/tests/trainers/test_language_guided_video_summarization_trainer.py
@@ -7,7 +7,7 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.language_guided_video_summarization import \
ClipItVideoSummarization
-from modelscope.msdatasets.task_datasets import \
+from modelscope.msdatasets.dataset_cls.custom_datasets import \
LanguageGuidedVideoSummarizationDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
diff --git a/tests/trainers/test_nerf_recon_acc_trainer.py b/tests/trainers/test_nerf_recon_acc_trainer.py
index 514aa262..4b6c8091 100644
--- a/tests/trainers/test_nerf_recon_acc_trainer.py
+++ b/tests/trainers/test_nerf_recon_acc_trainer.py
@@ -4,6 +4,7 @@ import unittest
from modelscope.msdatasets import MsDataset
from modelscope.trainers.cv import NeRFReconAccTrainer
+from modelscope.utils.constant import DownloadMode
from modelscope.utils.test_utils import test_level
@@ -14,8 +15,11 @@ class TestNeRFReconAccTrainer(unittest.TestCase):
model_id = 'damo/cv_nerf-3d-reconstruction-accelerate_damo'
data_dir = MsDataset.load(
- 'nerf_recon_dataset', namespace='damo',
- split='train').config_kwargs['split_config']['train']
+ 'nerf_recon_dataset',
+ namespace='damo',
+ split='train',
+ download_mode=DownloadMode.FORCE_REDOWNLOAD
+ ).config_kwargs['split_config']['train']
trainer = NeRFReconAccTrainer(
model=model_id,
diff --git a/tests/trainers/test_ocr_detection_db_trainer.py b/tests/trainers/test_ocr_detection_db_trainer.py
new file mode 100644
index 00000000..10097fea
--- /dev/null
+++ b/tests/trainers/test_ocr_detection_db_trainer.py
@@ -0,0 +1,74 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import glob
+import os
+import shutil
+import tempfile
+import unittest
+
+import torch
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.metainfo import Trainers
+from modelscope.pipelines import pipeline
+from modelscope.trainers import build_trainer
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.test_utils import DistributedTestCase, test_level
+
+
+def _setup():
+ model_id = 'damo/cv_resnet18_ocr-detection-db-line-level_damo'
+ cache_path = snapshot_download(model_id)
+ return cache_path
+
+
+class TestOCRDetectionDBTrainerSingleGPU(unittest.TestCase):
+
+ def setUp(self):
+ self.model_id = 'damo/cv_resnet18_ocr-detection-db-line-level_damo'
+ self.test_image = 'data/test/images/ocr_detection/test_images/X51007339105.jpg'
+ self.cache_path = _setup()
+ self.config_file = os.path.join(self.cache_path, 'configuration.json')
+ self.pretrained_model = os.path.join(
+ self.cache_path, 'db_resnet18_public_line_640x640.pt')
+ self.saved_dir = './workdirs'
+ self.saved_finetune_model = os.path.join(self.saved_dir, 'final.pt')
+ self.saved_infer_model = os.path.join(self.saved_dir,
+ 'pytorch_model.pt')
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_trainer_finetune_singleGPU(self):
+
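+ # Fine-tune from the downloaded pretrained checkpoint on the small OCR
+ # sample set bundled under data/test/images/ocr_detection/.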
+ kwargs = dict(
+ cfg_file=self.config_file,
+ gpu_ids=[
+ 0,
+ ],
+ batch_size=8,
+ max_epochs=5,
+ base_lr=0.007,
+ load_pretrain=True,
+ pretrain_model=self.pretrained_model,
+ cache_path=self.cache_path,
+ train_data_dir=['./data/test/images/ocr_detection/'],
+ train_data_list=[
+ './data/test/images/ocr_detection/train_list.txt'
+ ],
+ val_data_dir=['./data/test/images/ocr_detection/'],
+ val_data_list=['./data/test/images/ocr_detection/test_list.txt'])
+ trainer = build_trainer(
+ name=Trainers.ocr_detection_db, default_args=kwargs)
+ trainer.train()
+ trainer.evaluate(checkpoint_path=self.saved_finetune_model)
+
+ # inference with pipeline using saved inference model
+ shutil.copy(self.config_file, self.saved_dir)
+ ocr_detection = pipeline(Tasks.ocr_detection, model=self.saved_dir)
+ result = ocr_detection(self.test_image)
+ print('ocr detection results: ')
+ print(result)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/trainers/test_ocr_recognition_trainer.py b/tests/trainers/test_ocr_recognition_trainer.py
new file mode 100644
index 00000000..ddebc3fe
--- /dev/null
+++ b/tests/trainers/test_ocr_recognition_trainer.py
@@ -0,0 +1,95 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.metainfo import Trainers
+from modelscope.models.cv.ocr_recognition import OCRRecognition
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.config import Config, ConfigDict
+from modelscope.utils.constant import DownloadMode, ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestOCRRecognitionTrainer(unittest.TestCase):
+
+ model_id = 'damo/cv_crnn_ocr-recognition-general_damo'
+
+ def setUp(self):
+ print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+
+ cache_path = snapshot_download(self.model_id, revision='v2.2.2')
+ config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
+ cfg = Config.from_file(config_path)
+
+ max_epochs = cfg.train.max_epochs
+
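+ # The small 'test' split is used for both training and evaluation so the
+ # trainer test stays lightweight.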
+ train_data_cfg = ConfigDict(
+ name='ICDAR13_HCTR_Dataset', split='test', namespace='damo')
+
+ test_data_cfg = ConfigDict(
+ name='ICDAR13_HCTR_Dataset', split='test', namespace='damo')
+
+ self.train_dataset = MsDataset.load(
+ dataset_name=train_data_cfg.name,
+ split=train_data_cfg.split,
+ namespace=train_data_cfg.namespace,
+ download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
+ assert next(
+ iter(self.train_dataset.config_kwargs['split_config'].values()))
+
+ self.test_dataset = MsDataset.load(
+ dataset_name=test_data_cfg.name,
+ split=test_data_cfg.split,
+ namespace=test_data_cfg.namespace,
+ download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
+ assert next(
+ iter(self.test_dataset.config_kwargs['split_config'].values()))
+
+ self.max_epochs = max_epochs
+
+ self.tmp_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(self.tmp_dir):
+ os.makedirs(self.tmp_dir)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmp_dir)
+ super().tearDown()
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_trainer(self):
+ kwargs = dict(
+ model=self.model_id,
+ train_dataset=self.train_dataset,
+ eval_dataset=self.test_dataset,
+ work_dir=self.tmp_dir)
+
+ trainer = build_trainer(
+ name=Trainers.ocr_recognition, default_args=kwargs)
+ trainer.train()
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_trainer_with_model_and_args(self):
+ tmp_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(tmp_dir):
+ os.makedirs(tmp_dir)
+
+ cache_path = snapshot_download(self.model_id, revision='v2.2.2')
+ model = OCRRecognition.from_pretrained(cache_path)
+ kwargs = dict(
+ cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+ model=model,
+ train_dataset=self.train_dataset,
+ eval_dataset=self.test_dataset,
+ work_dir=tmp_dir)
+
+ trainer = build_trainer(
+ name=Trainers.ocr_recognition, default_args=kwargs)
+ trainer.train()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/trainers/test_siamese_uie_trainer.py b/tests/trainers/test_siamese_uie_trainer.py
new file mode 100644
index 00000000..bf21ece9
--- /dev/null
+++ b/tests/trainers/test_siamese_uie_trainer.py
@@ -0,0 +1,66 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.pipelines import pipeline
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import DownloadMode, Tasks
+
+
+class TestFinetuneSiameseUIE(unittest.TestCase):
+
+ def setUp(self):
+ print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+ self.tmp_dir = tempfile.TemporaryDirectory().name
+ os.makedirs(self.tmp_dir, exist_ok=True)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmp_dir)
+ super().tearDown()
+
+ @unittest.skip(
+ 'skip since the test requires multiple GPUs and takes a long time to run'
+ )
+ def test_finetune_people_daily(self):
+ model_id = 'damo/nlp_structbert_siamese-uie_chinese-base'
+ WORK_DIR = '/tmp'
+ train_dataset = MsDataset.load(
+ 'people_daily_ner_1998_tiny',
+ namespace='damo',
+ split='train',
+ download_mode=DownloadMode.FORCE_REDOWNLOAD)
+ eval_dataset = MsDataset.load(
+ 'people_daily_ner_1998_tiny',
+ namespace='damo',
+ split='validation',
+ download_mode=DownloadMode.FORCE_REDOWNLOAD)
+ max_epochs = 3
+ kwargs = dict(
+ model=model_id,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ max_epochs=max_epochs,
+ work_dir=WORK_DIR)
+ trainer = build_trainer('siamese-uie-trainer', default_args=kwargs)
+ trainer.train()
+ for i in range(max_epochs):
+ eval_results = trainer.evaluate(f'{WORK_DIR}/epoch_{i+1}.pth')
+ print(f'epoch {i + 1} evaluation result:')
+ print(eval_results)
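+ # Schema keys: 人物 (person), 地理位置 (location), 组织机构 (organization).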
+ pipeline_uie = pipeline(
+ task=Tasks.siamese_uie, model=f'{WORK_DIR}/output')
+ pipeline_uie(
+ input='1944年毕业于北大的名古屋铁道会长谷口清太郎等人在日本积极筹资',
+ schema={
+ '人物': None,
+ '地理位置': None,
+ '组织机构': None
+ })
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/trainers/test_tinynas_damoyolo_trainer.py b/tests/trainers/test_tinynas_damoyolo_trainer.py
index d08980da..5dd9e928 100644
--- a/tests/trainers/test_tinynas_damoyolo_trainer.py
+++ b/tests/trainers/test_tinynas_damoyolo_trainer.py
@@ -1,18 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-import glob
-import os
-import shutil
-import tempfile
-import unittest
-import torch
+import os
+import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
-from modelscope.utils.config import Config
-from modelscope.utils.constant import ModelFile
-from modelscope.utils.test_utils import DistributedTestCase, test_level
+from modelscope.utils.test_utils import test_level
def _setup():
diff --git a/tests/trainers/test_trainer.py b/tests/trainers/test_trainer.py
index 1fb915c6..2cf4b2e9 100644
--- a/tests/trainers/test_trainer.py
+++ b/tests/trainers/test_trainer.py
@@ -22,6 +22,7 @@ from modelscope.trainers.base import DummyTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.trainer import EpochBasedTrainer
from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile, Tasks
+from modelscope.utils.hub import read_config
from modelscope.utils.test_utils import create_dummy_test_dataset, test_level
@@ -549,6 +550,146 @@ class TrainerTest(unittest.TestCase):
for i in [2, 5, 8]:
self.assertIn(MetricKeys.ACCURACY, lines[i])
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_train_with_old_and_new_cfg(self):
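+ # Both the legacy hook-style config and the new checkpoint/logging/period
+ # layout should be normalized to the new schema by the trainer.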
+ old_cfg = {
+ 'task': Tasks.image_classification,
+ 'train': {
+ 'work_dir':
+ self.tmp_dir,
+ 'dataloader': {
+ 'batch_size_per_gpu': 2,
+ 'workers_per_gpu': 1
+ },
+ 'optimizer': {
+ 'type': 'SGD',
+ 'lr': 0.01,
+ 'options': {
+ 'grad_clip': {
+ 'max_norm': 2.0
+ }
+ }
+ },
+ 'lr_scheduler': {
+ 'type': 'StepLR',
+ 'step_size': 2,
+ 'options': {
+ 'warmup': {
+ 'type': 'LinearWarmup',
+ 'warmup_iters': 2
+ }
+ }
+ },
+ 'hooks': [{
+ 'type': 'CheckpointHook',
+ 'interval': 1
+ }, {
+ 'type': 'TextLoggerHook',
+ 'interval': 1
+ }, {
+ 'type': 'IterTimerHook'
+ }, {
+ 'type': 'EvaluationHook',
+ 'interval': 1
+ }, {
+ 'type': 'TensorboardHook',
+ 'interval': 1
+ }]
+ },
+ 'evaluation': {
+ 'dataloader': {
+ 'batch_size_per_gpu': 2,
+ 'workers_per_gpu': 1,
+ 'shuffle': False
+ },
+ 'metrics': [Metrics.seq_cls_metric],
+ }
+ }
+
+ new_cfg = {
+ 'task': Tasks.image_classification,
+ 'train': {
+ 'work_dir':
+ self.tmp_dir,
+ 'dataloader': {
+ 'batch_size_per_gpu': 2,
+ 'workers_per_gpu': 1
+ },
+ 'optimizer': {
+ 'type': 'SGD',
+ 'lr': 0.01,
+ 'options': {
+ 'grad_clip': {
+ 'max_norm': 2.0
+ }
+ }
+ },
+ 'lr_scheduler': {
+ 'type': 'StepLR',
+ 'step_size': 2,
+ 'options': {
+ 'warmup': {
+ 'type': 'LinearWarmup',
+ 'warmup_iters': 2
+ }
+ }
+ },
+ 'checkpoint': {
+ 'period': {
+ 'interval': 1
+ }
+ },
+ 'logging': {
+ 'interval': 1
+ },
+ 'hooks': [{
+ 'type': 'IterTimerHook'
+ }, {
+ 'type': 'TensorboardHook',
+ 'interval': 1
+ }]
+ },
+ 'evaluation': {
+ 'dataloader': {
+ 'batch_size_per_gpu': 2,
+ 'workers_per_gpu': 1,
+ 'shuffle': False
+ },
+ 'metrics': [Metrics.seq_cls_metric],
+ 'period': {
+ 'interval': 1
+ }
+ }
+ }
+
+ def assert_new_cfg(cfg):
+ self.assertNotIn('CheckpointHook', cfg.train.hooks)
+ self.assertNotIn('TextLoggerHook', cfg.train.hooks)
+ self.assertNotIn('EvaluationHook', cfg.train.hooks)
+ self.assertIn('checkpoint', cfg.train)
+ self.assertIn('logging', cfg.train)
+ self.assertIn('period', cfg.evaluation)
+
+ for json_cfg in (new_cfg, old_cfg):
+ config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
+ with open(config_path, 'w') as f:
+ json.dump(json_cfg, f)
+ trainer_name = Trainers.default
+ kwargs = dict(
+ cfg_file=config_path,
+ model=DummyModel(),
+ data_collator=None,
+ train_dataset=dummy_dataset_small,
+ eval_dataset=dummy_dataset_small,
+ max_epochs=3,
+ device='cpu')
+
+ trainer = build_trainer(trainer_name, kwargs)
+ assert_new_cfg(trainer.cfg)
+ trainer.train()
+ cfg = read_config(os.path.join(self.tmp_dir, 'output'))
+ assert_new_cfg(cfg)
+
class DummyTrainerTest(unittest.TestCase):
diff --git a/tests/trainers/test_trainer_gpu.py b/tests/trainers/test_trainer_gpu.py
index 1d3df533..c4173c78 100644
--- a/tests/trainers/test_trainer_gpu.py
+++ b/tests/trainers/test_trainer_gpu.py
@@ -8,15 +8,17 @@ import unittest
import json
import numpy as np
import torch
+from packaging import version
from torch import nn
+from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import IterableDataset
from modelscope.metainfo import Metrics, Trainers
from modelscope.metrics.builder import MetricKeys
-from modelscope.models.base import Model, TorchModel
-from modelscope.trainers import EpochBasedTrainer, build_trainer
+from modelscope.models.base import TorchModel
+from modelscope.trainers import build_trainer
from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile, Tasks
from modelscope.utils.test_utils import (DistributedTestCase,
create_dummy_test_dataset, test_level)
@@ -229,7 +231,7 @@ class TrainerTestMultiGpus(DistributedTestCase):
super().tearDown()
shutil.rmtree(self.tmp_dir)
- @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_multi_gpus(self):
self.start(train_func, num_gpus=2, work_dir=self.tmp_dir, dist=True)
@@ -288,7 +290,7 @@ class TrainerTestMultiGpus(DistributedTestCase):
for i in [1, 3, 5]:
self.assertIn(MetricKeys.ACCURACY, lines[i])
- @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_multi_gpus_forward_inputs(self):
self.start(
train_func,
@@ -327,5 +329,186 @@ class TrainerTestMultiGpus(DistributedTestCase):
print(results_files, lines)
+def train_func_2(work_dir,
+ dist=False,
+ iterable_dataset=False,
+ forward_inputs=False,
+ **kwargs):
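+ # Train a dummy model under DistributedDataParallel with Apex AMP enabled,
+ # then assert that the wrapped model produced fp16 logits.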
+ json_cfg = {
+ 'task': Tasks.image_classification,
+ 'model': {},
+ 'train': {
+ 'work_dir': work_dir,
+ 'dataloader': {
+ 'batch_size_per_gpu': 2,
+ 'workers_per_gpu': 1
+ },
+ 'hooks': [{
+ 'type': 'EvaluationHook',
+ 'interval': 1
+ }]
+ },
+ 'evaluation': {
+ 'dataloader': {
+ 'batch_size_per_gpu': 1,
+ 'workers_per_gpu': 1,
+ 'shuffle': False
+ },
+ 'metrics': [Metrics.seq_cls_metric]
+ }
+ }
+
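+ # ApexAMPOptimizerHook is added so training runs with NVIDIA Apex mixed
+ # precision; the fp16 assertion at the end relies on it.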
+ extra_hooks = [{'type': 'ApexAMPOptimizerHook'}]
+ json_cfg['train']['hooks'].extend(extra_hooks)
+ config_path = os.path.join(work_dir, ModelFile.CONFIGURATION)
+ with open(config_path, 'w') as f:
+ json.dump(json_cfg, f)
+
+ if forward_inputs:
+ model = DummyModelForwardInputs()
+ else:
+ model = DummyModel()
+ optimizer = SGD(model.parameters(), lr=0.01)
+ lr_scheduler = StepLR(optimizer, 2)
+ trainer_name = Trainers.default
+ if iterable_dataset:
+ train_dataset = DummyIterableDataset()
+ eval_dataset = DummyIterableDataset()
+ else:
+ train_dataset = dummy_dataset_big
+ eval_dataset = dummy_dataset_small
+ _kwargs = dict(
+ cfg_file=config_path,
+ model=model,
+ data_collator=None,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ optimizers=(optimizer, lr_scheduler),
+ max_epochs=3,
+ device='gpu',
+ launcher='pytorch' if dist else None,
+ **kwargs)
+
+ trainer = build_trainer(trainer_name, _kwargs)
+ trainer.train()
+ assert isinstance(trainer.model, DistributedDataParallel)
+ assert isinstance(trainer.model.module, DummyModel)
+ assert trainer.train_outputs['logits'].dtype == torch.float16
+
+
+@unittest.skipIf(not torch.cuda.is_available()
+ or torch.cuda.device_count() <= 1
+ or version.parse(torch.__version__) >= version.parse('1.9.0'),
+ 'skip when CUDA is unavailable, fewer than 2 GPUs are present, or torch >= 1.9')
+class TrainerTestDDPAndApex(DistributedTestCase):
+
+ def setUp(self):
+ print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+ self.tmp_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(self.tmp_dir):
+ os.makedirs(self.tmp_dir)
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tmp_dir)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_multi_gpus_apex(self):
+ self.start(train_func_2, num_gpus=2, work_dir=self.tmp_dir, dist=True)
+
+
+def test_func(work_dir,
+ dist=False,
+ iterable_dataset=False,
+ forward_inputs=False,
+ **kwargs):
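+ # Run evaluation twice under DistributedDataParallel, once with the
+ # in-memory model and once from a saved checkpoint, and check that both
+ # runs report the same metrics.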
+ json_cfg = {
+ 'task': Tasks.image_classification,
+ 'model': {},
+ 'train': {
+ 'work_dir': work_dir,
+ 'dataloader': {
+ 'batch_size_per_gpu': 2,
+ 'workers_per_gpu': 1
+ },
+ 'hooks': [{
+ 'type': 'EvaluationHook',
+ 'interval': 1
+ }]
+ },
+ 'evaluation': {
+ 'dataloader': {
+ 'batch_size_per_gpu': 1,
+ 'workers_per_gpu': 1,
+ 'shuffle': False
+ },
+ 'metrics': [Metrics.seq_cls_metric]
+ }
+ }
+
+ config_path = os.path.join(work_dir, ModelFile.CONFIGURATION)
+ with open(config_path, 'w') as f:
+ json.dump(json_cfg, f)
+
+ if forward_inputs:
+ model = DummyModelForwardInputs()
+ else:
+ model = DummyModel()
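+ # Save an initial checkpoint so the second evaluate() call below can load it from disk.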
+ torch.save(model.state_dict(), os.path.join(work_dir, 'pytorch_model.bin'))
+ optimizer = SGD(model.parameters(), lr=0.01)
+ lr_scheduler = StepLR(optimizer, 2)
+ trainer_name = Trainers.default
+ if iterable_dataset:
+ train_dataset = DummyIterableDataset()
+ eval_dataset = DummyIterableDataset()
+ else:
+ train_dataset = dummy_dataset_big
+ eval_dataset = dummy_dataset_small
+ _kwargs = dict(
+ cfg_file=config_path,
+ model=model,
+ data_collator=None,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ optimizers=(optimizer, lr_scheduler),
+ max_epochs=3,
+ device='gpu',
+ launcher='pytorch' if dist else None,
+ **kwargs)
+
+ trainer = build_trainer(trainer_name, _kwargs)
+ trainer.evaluate()
+ assert isinstance(trainer.model, DistributedDataParallel)
+ assert isinstance(trainer.model.module, DummyModel)
+ metric_values = trainer.metric_values
+ trainer.evaluate(os.path.join(work_dir, 'pytorch_model.bin'))
+ assert isinstance(trainer.model, DistributedDataParallel)
+ assert isinstance(trainer.model.module, DummyModel)
+ print(metric_values)
+ print(trainer.metric_values)
+ for key in metric_values:
+ assert np.isclose(metric_values[key], trainer.metric_values[key])
+
+
+@unittest.skipIf(not torch.cuda.is_available()
+ or torch.cuda.device_count() <= 1,
+ 'skip when CUDA is unavailable or fewer than 2 GPUs are present')
+class TrainerTestDDPTest(DistributedTestCase):
+
+ def setUp(self):
+ print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+ self.tmp_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(self.tmp_dir):
+ os.makedirs(self.tmp_dir)
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tmp_dir)
+
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_multi_gpus_apex_test(self):
+ self.start(test_func, num_gpus=2, work_dir=self.tmp_dir, dist=True)
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py
index ac4e67a3..ceb04e15 100644
--- a/tests/trainers/test_trainer_with_nlp.py
+++ b/tests/trainers/test_trainer_with_nlp.py
@@ -7,10 +7,12 @@ import tempfile
import unittest
import numpy as np
+import torch
+from packaging import version
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Metrics
-from modelscope.models.base import Model
+from modelscope.models.base import Model, TorchModel
from modelscope.models.nlp import SbertForSequenceClassification
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
@@ -76,6 +78,52 @@ class TestTrainerWithNlp(unittest.TestCase):
output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)
pipeline_sentence_similarity(output_dir)
+ @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+ def test_trainer_callback(self):
+ model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
+
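+ # Minimal callback that requests early stopping by setting trainer._stop_training.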
+ class CustomCallback:
+
+ def after_train_iter(self, trainer):
+ if trainer.iter == 2:
+ trainer._stop_training = True
+
+ kwargs = dict(
+ model=model_id,
+ train_dataset=self.dataset,
+ eval_dataset=self.dataset,
+ work_dir=self.tmp_dir,
+ callbacks=[CustomCallback()])
+
+ trainer = build_trainer(default_args=kwargs)
+ trainer.train()
+
+ self.assertEqual(trainer.iter, 3)
+
+ @unittest.skipIf(
+ version.parse(torch.__version__) < version.parse('2.0.0.dev'),
+ 'skip test when torch version < 2.0')
+ def test_trainer_compile(self):
+ model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
+
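+ # compile=True is expected to wrap the model with torch.compile; the
+ # original module stays reachable via _orig_mod, which the assertion below checks.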
+ class CustomCallback:
+
+ def after_train_iter(self, trainer):
+ if trainer.iter == 5:
+ trainer._stop_training = True
+
+ kwargs = dict(
+ model=model_id,
+ train_dataset=self.dataset,
+ eval_dataset=self.dataset,
+ work_dir=self.tmp_dir,
+ callbacks=[CustomCallback()],
+ compile=True)
+
+ trainer = build_trainer(default_args=kwargs)
+ self.assertTrue(isinstance(trainer.model._orig_mod, TorchModel))
+ trainer.train()
+
@unittest.skip
def test_trainer_with_backbone_head(self):
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
diff --git a/tests/trainers/test_video_summarization_trainer.py b/tests/trainers/test_video_summarization_trainer.py
index 1cea1eea..35eee2bc 100644
--- a/tests/trainers/test_video_summarization_trainer.py
+++ b/tests/trainers/test_video_summarization_trainer.py
@@ -6,7 +6,8 @@ import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.video_summarization import PGLVideoSummarization
-from modelscope.msdatasets.task_datasets import VideoSummarizationDataset
+from modelscope.msdatasets.dataset_cls.custom_datasets import \
+ VideoSummarizationDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
@@ -17,6 +18,7 @@ logger = get_logger()
class VideoSummarizationTrainerTest(unittest.TestCase):
+ # TODO: to be added to the CUSTOM_DATASETS registry
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
diff --git a/tests/utils/test_envs.py b/tests/utils/test_envs.py
new file mode 100644
index 00000000..e87297ac
--- /dev/null
+++ b/tests/utils/test_envs.py
@@ -0,0 +1,25 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import unittest
+
+from modelscope.utils.plugins import EnvsManager
+
+
+class PluginTest(unittest.TestCase):
+
+ def setUp(self):
+ self.model_id = 'damo/nlp_nested-ner_named-entity-recognition_chinese-base-med'
+ self.env_manager = EnvsManager(self.model_id)
+
+ def tearDown(self):
+ self.env_manager.clean_env()
+ super().tearDown()
+
+ def test_create_env(self):
+ need_env = self.env_manager.check_if_need_env()
+ self.assertTrue(need_env)
+ activate_dir = self.env_manager.create_env()
+ remote = 'source {}'.format(activate_dir)
+ cmd = f'{remote};'
+ print(cmd)
+ # EnvsManager.run_process(cmd) is skipped: no sh is available in the CI environment
diff --git a/tests/utils/test_plugin.py b/tests/utils/test_plugin.py
index 40d86f9d..447ce1c9 100644
--- a/tests/utils/test_plugin.py
+++ b/tests/utils/test_plugin.py
@@ -1,16 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
import unittest
from modelscope.models.builder import MODELS
-from modelscope.utils.plugins import (discover_plugins, import_all_plugins,
- import_file_plugins, import_plugins,
- pushd)
+from modelscope.utils.plugins import (PluginsManager, discover_plugins,
+ import_all_plugins, import_file_plugins,
+ import_plugins, pushd)
+from modelscope.utils.test_utils import test_level
class PluginTest(unittest.TestCase):
def setUp(self):
self.plugins_root = 'tests/utils/plugins/'
+ self.tmp_dir = tempfile.TemporaryDirectory().name
+ if not os.path.exists(self.tmp_dir):
+ os.makedirs(self.tmp_dir)
+ self.package = 'adaseq'
+ self.plugins_manager = PluginsManager()
+
+ def tearDown(self):
+ shutil.rmtree(self.tmp_dir)
+ super().tearDown()
def test_no_plugins(self):
available_plugins = set(discover_plugins())
@@ -39,3 +52,75 @@ class PluginTest(unittest.TestCase):
import_all_plugins()
assert MODELS.get('dummy-model', 'dummy-group') is not None
+
+ def test_install_plugins(self):
+ """
+ examples for the modelscope install method
+ > modelscope install adaseq ofasys
+ > modelscope install git+https://github.com/modelscope/AdaSeq.git
+ > modelscope install adaseq -i -f
+ > modelscope install adaseq --extra-index-url --trusted-host
+ """
+ install_args = [self.package]
+ status_code, install_args = self.plugins_manager.install_plugins(
+ install_args)
+ self.assertEqual(status_code, 0)
+
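+ # Installing a non-existent package should fail with a non-zero status.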
+ install_args = ['random_blabla']
+ status_code, install_args = self.plugins_manager.install_plugins(
+ install_args)
+ self.assertEqual(status_code, 1)
+
+ install_args = [self.package, 'random_blabla']
+ status_code, install_args = self.plugins_manager.install_plugins(
+ install_args)
+ self.assertEqual(status_code, 1)
+
+ # moved here from tearDown to avoid an unexpected uninstall
+ uninstall_args = [self.package, '-y']
+ self.plugins_manager.uninstall_plugins(uninstall_args)
+
+ @unittest.skip
+ def test_install_plugins_with_git(self):
+
+ install_args = ['git+https://github.com/modelscope/AdaSeq.git']
+ status_code, install_args = self.plugins_manager.install_plugins(
+ install_args)
+ self.assertEqual(status_code, 0)
+
+ # moved here from tearDown to avoid an unexpected uninstall
+ uninstall_args = ['git+https://github.com/modelscope/AdaSeq.git', '-y']
+ self.plugins_manager.uninstall_plugins(uninstall_args)
+
+ def test_uninstall_plugins(self):
+ """
+ examples for the modelscope uninstall method
+ > modelscope uninstall adaseq
+ > modelscope uninstall -y adaseq
+ """
+ install_args = [self.package]
+ status_code, install_args = self.plugins_manager.install_plugins(
+ install_args)
+ self.assertEqual(status_code, 0)
+
+ uninstall_args = [self.package, '-y']
+ status_code, uninstall_args = self.plugins_manager.uninstall_plugins(
+ uninstall_args)
+ self.assertEqual(status_code, 0)
+
+ def test_list_plugins(self):
+ """
+ examples for the modelscope list method
+ > modelscope list
+ > modelscope list --all
+ > modelscope list -a
+ # """
+ modelscope_plugin = os.path.join(self.tmp_dir, 'modelscope_plugin')
+ self.plugins_manager.file_path = modelscope_plugin
+ result = self.plugins_manager.list_plugins()
+ self.assertEqual(len(result.items()), 0)
+
+ from modelscope.utils.plugins import OFFICIAL_PLUGINS
+
+ result = self.plugins_manager.list_plugins(show_all=True)
+ self.assertEqual(len(result.items()), len(OFFICIAL_PLUGINS))